Neural Networks: General-purpose learning algorithm for modeling non-linearity
... if you train it with "enough" data
Transform $x$ into $\phi(x)$ so that the data becomes linearly separable
$\phi(x)$ is the basis for a "neuron"
$$y = W\phi(x) + b$$
$$\phi(x) = g(W'x + b')$$
Trainable: $W', b', W, b$
$g(x)$ is a non-linear function, e.g. Sigmoid
$$y = \mathrm{sigmoid}(Wx + b)$$
$$y = \mathrm{relu}(Wx + b)$$
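To make the formulas above concrete, here is a minimal pure-Python sketch of $y = W\phi(x) + b$ with $\phi(x) = g(W'x + b')$. The weight values and the helper names (`affine`, `forward`) are made up for illustration; a real network would learn $W', b', W, b$ from data.

```python
import math

def sigmoid(z):
    """Squash z into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def relu(z):
    """Pass positive values through, clamp negatives to zero."""
    return max(0.0, z)

def affine(W, x, b):
    """Compute W x + b for a matrix W (list of rows), a vector x and a bias vector b."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
            for row, b_i in zip(W, b)]

def forward(x, W_hidden, b_hidden, W_out, b_out, g=sigmoid):
    """phi(x) = g(W'x + b'), then y = W phi(x) + b."""
    phi = [g(z) for z in affine(W_hidden, x, b_hidden)]
    return affine(W_out, phi, b_out)

# Hypothetical weights: 2 inputs -> 2 hidden neurons -> 1 output
W_hidden = [[1.0, -1.0], [0.5, 0.5]]
b_hidden = [0.0, 0.0]
W_out = [[1.0, 1.0]]
b_out = [0.0]

y_relu = forward([1.0, 2.0], W_hidden, b_hidden, W_out, b_out, g=relu)
y_sigmoid = forward([0.0, 0.0], W_hidden, b_hidden, W_out, b_out, g=sigmoid)
print(y_relu, y_sigmoid)
```

Swapping `g` is the only change needed to move between the sigmoid and ReLU variants.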
Multiple "hidden" layers of neurons make up a "Deep Neural Network"
(image: Goldberg, 2017)
Term | Description | Examples |
---|---|---|
Input dimension | How many inputs | 4 |
Output dimension | How many outputs | 3 |
Number of hidden layers | Number of layers, excluding input and output | 2 |
Activation type | Type of non-linear function | sigmoid, ReLU, tanh |
Hidden layer type | How the neurons are connected together | Fully-connected, Convolutional |
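As a rough illustration of how the terms in the table fit together, the sketch below counts the trainable parameters of a fully-connected network with 4 inputs, 2 hidden layers and 3 outputs. The hidden layer width of 5 is an assumption chosen for the example, and `dense_param_count` is a hypothetical helper, not a Keras function.

```python
def dense_param_count(layer_sizes):
    """Count weights and biases of a fully-connected network.

    Each pair of consecutive layers contributes n_in * n_out weights
    plus n_out biases.
    """
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

# input dimension 4, two hidden layers (assumed width 5), output dimension 3
layer_sizes = [4, 5, 5, 3]
num_hidden_layers = len(layer_sizes) - 2  # excludes input and output
total = dense_param_count(layer_sizes)

print(num_hidden_layers, total)
```

Here the three weight matrices contribute 25, 30 and 18 parameters respectively.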
How the neurons are connected together, and what operations are performed with x, W, and b:
More detail to come...
In this walkthrough, we will use Keras to examine the architecture of some well-known neural networks.
Environment: mldds03
a. Launch an Anaconda Python command window
b. Run the following commands, then open this notebook:
    conda create -n mldds03 python=3
    conda activate mldds03
    conda install jupyter numpy pandas matplotlib keras pydot python-graphviz
    cd mldds-courseware
    jupyter notebook
Install: conda install keras
Install: conda install pydot python-graphviz
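After installing, a quick check like the one below can confirm that the packages are visible to the active environment. This is only a sketch: the list of import names is an assumption based on the conda packages above (for instance, python-graphviz is imported as `graphviz`), and `check_packages` is a hypothetical helper.

```python
import importlib.util

# Import names assumed for the conda packages listed above
# (the python-graphviz package is imported as "graphviz").
PACKAGES = ["numpy", "pandas", "matplotlib", "keras", "pydot", "graphviz"]

def check_packages(names):
    """Return a dict mapping each import name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

status = check_packages(PACKAGES)
for name, found in status.items():
    print(f"{name:12s} {'ok' if found else 'MISSING'}")
```

Any package reported as MISSING can be installed with the matching conda install command.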
"Pre-trained" neural networks are available under keras.applications
https://keras.io/applications/
These are trained on the ImageNet dataset (http://www.image-net.org/), which contains millions of images.
The neural network architectures in Keras are taken from previous years' submissions to the annual ImageNet challenge.
import keras
print(keras.__version__)
Using TensorFlow backend.
2.2.0
MobileNet is a pre-trained ImageNet DNN optimized to run on smaller devices.
Documentation: https://keras.io/applications/#mobilenet
Implementation: https://github.com/keras-team/keras-applications/blob/master/keras_applications/mobilenet.py
from keras.applications import mobilenet
mobilenet_model = mobilenet.MobileNet(weights='imagenet')
mobilenet_model.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) (None, 224, 224, 3) 0 _________________________________________________________________ conv1_pad (ZeroPadding2D) (None, 226, 226, 3) 0 _________________________________________________________________ conv1 (Conv2D) (None, 112, 112, 32) 864 _________________________________________________________________ conv1_bn (BatchNormalization (None, 112, 112, 32) 128 _________________________________________________________________ conv1_relu (Activation) (None, 112, 112, 32) 0 _________________________________________________________________ conv_pad_1 (ZeroPadding2D) (None, 114, 114, 32) 0 _________________________________________________________________ conv_dw_1 (DepthwiseConv2D) (None, 112, 112, 32) 288 _________________________________________________________________ conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32) 128 _________________________________________________________________ conv_dw_1_relu (Activation) (None, 112, 112, 32) 0 _________________________________________________________________ conv_pw_1 (Conv2D) (None, 112, 112, 64) 2048 _________________________________________________________________ conv_pw_1_bn (BatchNormaliza (None, 112, 112, 64) 256 _________________________________________________________________ conv_pw_1_relu (Activation) (None, 112, 112, 64) 0 _________________________________________________________________ conv_pad_2 (ZeroPadding2D) (None, 114, 114, 64) 0 _________________________________________________________________ conv_dw_2 (DepthwiseConv2D) (None, 56, 56, 64) 576 _________________________________________________________________ conv_dw_2_bn (BatchNormaliza (None, 56, 56, 64) 256 _________________________________________________________________ conv_dw_2_relu (Activation) (None, 56, 56, 64) 0 
_________________________________________________________________ conv_pw_2 (Conv2D) (None, 56, 56, 128) 8192 _________________________________________________________________ conv_pw_2_bn (BatchNormaliza (None, 56, 56, 128) 512 _________________________________________________________________ conv_pw_2_relu (Activation) (None, 56, 56, 128) 0 _________________________________________________________________ conv_pad_3 (ZeroPadding2D) (None, 58, 58, 128) 0 _________________________________________________________________ conv_dw_3 (DepthwiseConv2D) (None, 56, 56, 128) 1152 _________________________________________________________________ conv_dw_3_bn (BatchNormaliza (None, 56, 56, 128) 512 _________________________________________________________________ conv_dw_3_relu (Activation) (None, 56, 56, 128) 0 _________________________________________________________________ conv_pw_3 (Conv2D) (None, 56, 56, 128) 16384 _________________________________________________________________ conv_pw_3_bn (BatchNormaliza (None, 56, 56, 128) 512 _________________________________________________________________ conv_pw_3_relu (Activation) (None, 56, 56, 128) 0 _________________________________________________________________ conv_pad_4 (ZeroPadding2D) (None, 58, 58, 128) 0 _________________________________________________________________ conv_dw_4 (DepthwiseConv2D) (None, 28, 28, 128) 1152 _________________________________________________________________ conv_dw_4_bn (BatchNormaliza (None, 28, 28, 128) 512 _________________________________________________________________ conv_dw_4_relu (Activation) (None, 28, 28, 128) 0 _________________________________________________________________ conv_pw_4 (Conv2D) (None, 28, 28, 256) 32768 _________________________________________________________________ conv_pw_4_bn (BatchNormaliza (None, 28, 28, 256) 1024 _________________________________________________________________ conv_pw_4_relu (Activation) (None, 28, 28, 256) 0 
_________________________________________________________________ conv_pad_5 (ZeroPadding2D) (None, 30, 30, 256) 0 _________________________________________________________________ conv_dw_5 (DepthwiseConv2D) (None, 28, 28, 256) 2304 _________________________________________________________________ conv_dw_5_bn (BatchNormaliza (None, 28, 28, 256) 1024 _________________________________________________________________ conv_dw_5_relu (Activation) (None, 28, 28, 256) 0 _________________________________________________________________ conv_pw_5 (Conv2D) (None, 28, 28, 256) 65536 _________________________________________________________________ conv_pw_5_bn (BatchNormaliza (None, 28, 28, 256) 1024 _________________________________________________________________ conv_pw_5_relu (Activation) (None, 28, 28, 256) 0 _________________________________________________________________ conv_pad_6 (ZeroPadding2D) (None, 30, 30, 256) 0 _________________________________________________________________ conv_dw_6 (DepthwiseConv2D) (None, 14, 14, 256) 2304 _________________________________________________________________ conv_dw_6_bn (BatchNormaliza (None, 14, 14, 256) 1024 _________________________________________________________________ conv_dw_6_relu (Activation) (None, 14, 14, 256) 0 _________________________________________________________________ conv_pw_6 (Conv2D) (None, 14, 14, 512) 131072 _________________________________________________________________ conv_pw_6_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_6_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pad_7 (ZeroPadding2D) (None, 16, 16, 512) 0 _________________________________________________________________ conv_dw_7 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_7_bn (BatchNormaliza (None, 14, 14, 512) 2048 
_________________________________________________________________ conv_dw_7_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_7 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_7_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_7_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pad_8 (ZeroPadding2D) (None, 16, 16, 512) 0 _________________________________________________________________ conv_dw_8 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_8_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_8_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_8 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_8_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_8_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pad_9 (ZeroPadding2D) (None, 16, 16, 512) 0 _________________________________________________________________ conv_dw_9 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_9_bn (BatchNormaliza (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_9_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_9 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_9_bn (BatchNormaliza (None, 14, 14, 512) 2048 
_________________________________________________________________ conv_pw_9_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pad_10 (ZeroPadding2D) (None, 16, 16, 512) 0 _________________________________________________________________ conv_dw_10 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_10_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_10_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_10 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_10_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_10_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pad_11 (ZeroPadding2D) (None, 16, 16, 512) 0 _________________________________________________________________ conv_dw_11 (DepthwiseConv2D) (None, 14, 14, 512) 4608 _________________________________________________________________ conv_dw_11_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_dw_11_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pw_11 (Conv2D) (None, 14, 14, 512) 262144 _________________________________________________________________ conv_pw_11_bn (BatchNormaliz (None, 14, 14, 512) 2048 _________________________________________________________________ conv_pw_11_relu (Activation) (None, 14, 14, 512) 0 _________________________________________________________________ conv_pad_12 (ZeroPadding2D) (None, 16, 16, 512) 0 _________________________________________________________________ conv_dw_12 (DepthwiseConv2D) (None, 7, 7, 512) 4608 
_________________________________________________________________ conv_dw_12_bn (BatchNormaliz (None, 7, 7, 512) 2048 _________________________________________________________________ conv_dw_12_relu (Activation) (None, 7, 7, 512) 0 _________________________________________________________________ conv_pw_12 (Conv2D) (None, 7, 7, 1024) 524288 _________________________________________________________________ conv_pw_12_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_pw_12_relu (Activation) (None, 7, 7, 1024) 0 _________________________________________________________________ conv_pad_13 (ZeroPadding2D) (None, 9, 9, 1024) 0 _________________________________________________________________ conv_dw_13 (DepthwiseConv2D) (None, 7, 7, 1024) 9216 _________________________________________________________________ conv_dw_13_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_dw_13_relu (Activation) (None, 7, 7, 1024) 0 _________________________________________________________________ conv_pw_13 (Conv2D) (None, 7, 7, 1024) 1048576 _________________________________________________________________ conv_pw_13_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_pw_13_relu (Activation) (None, 7, 7, 1024) 0 _________________________________________________________________ global_average_pooling2d_1 ( (None, 1024) 0 _________________________________________________________________ reshape_1 (Reshape) (None, 1, 1, 1024) 0 _________________________________________________________________ dropout (Dropout) (None, 1, 1, 1024) 0 _________________________________________________________________ conv_preds (Conv2D) (None, 1, 1, 1000) 1025000 _________________________________________________________________ act_softmax (Activation) (None, 1, 1, 1000) 0 
_________________________________________________________________ reshape_2 (Reshape) (None, 1000) 0 ================================================================= Total params: 4,253,864 Trainable params: 4,231,976 Non-trainable params: 21,888 _________________________________________________________________
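The "Param #" column in the summary above can be reproduced by hand. The sketch below assumes 3x3 kernels for conv1 and the depthwise layers, 1x1 kernels for the pointwise layers, no bias terms except in conv_preds, and four values per channel in each batch normalization layer (two of them non-trainable). These assumptions are inferred from the printed counts, and the helper functions are hypothetical.

```python
def conv2d_params(kernel, in_ch, out_ch, use_bias):
    """Weights of a kernel x kernel convolution, plus optional biases."""
    return kernel * kernel * in_ch * out_ch + (out_ch if use_bias else 0)

def depthwise_params(kernel, channels):
    """One kernel x kernel filter per channel, no bias (assumed)."""
    return kernel * kernel * channels

def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance per channel."""
    return 4 * channels

checks = {
    "conv1": conv2d_params(3, 3, 32, use_bias=False),
    "conv1_bn": batchnorm_params(32),
    "conv_dw_1": depthwise_params(3, 32),
    "conv_pw_1": conv2d_params(1, 32, 64, use_bias=False),
    "conv_pw_13": conv2d_params(1, 1024, 1024, use_bias=False),
    "conv_preds": conv2d_params(1, 1024, 1000, use_bias=True),
}

# Channel counts of the 27 batch normalization layers, read off the summary.
bn_channels = ([32, 32, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512]
               + [512] * 10
               + [512, 1024, 1024, 1024])
# Moving mean and moving variance are not updated by gradient descent.
non_trainable = 2 * sum(bn_channels)

for name, count in checks.items():
    print(name, count)
print("non-trainable", non_trainable)
```

Under these assumptions the computed values agree with the summary, including the 21,888 non-trainable parameters.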
ResNet50 is another pre-trained ImageNet DNN. This is a larger network than MobileNet (almost 26 million parameters). It improves accuracy by introducing residual connections, which are connections that skip layers.
Documentation: https://keras.io/applications/#resnet50
Implementation: https://github.com/keras-team/keras-applications/blob/master/keras_applications/resnet50.py
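The idea of a residual connection can be sketched in a few lines of pure Python: the block's input is added back to the output of the layers it skips, and the activation is applied to the sum. The `transform` function here is a hypothetical stand-in for the stacked convolution and batch normalization layers of a real ResNet block.

```python
def relu_vec(v):
    """Element-wise ReLU."""
    return [max(0.0, a) for a in v]

def residual_block(x, transform):
    """Compute relu(F(x) + x), where F is the skipped-over transformation."""
    fx = transform(x)
    return relu_vec([f + a for f, a in zip(fx, x)])

def negate_half(v):
    """Hypothetical stand-in for the block's inner layers."""
    return [-0.5 * a for a in v]

out = residual_block([2.0, -4.0], negate_half)
print(out)
```

Because the input is carried forward unchanged, the inner layers only need to learn a correction to it. This sketch requires F(x) and x to have the same shape; when they differ, the shortcut needs its own transformation.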
from keras.applications import resnet50
resnet_model = resnet50.ResNet50(weights='imagenet')
resnet_model.summary()
__________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_2 (InputLayer) (None, 224, 224, 3) 0 __________________________________________________________________________________________________ conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 input_2[0][0] __________________________________________________________________________________________________ conv1 (Conv2D) (None, 112, 112, 64) 9472 conv1_pad[0][0] __________________________________________________________________________________________________ bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0] __________________________________________________________________________________________________ activation_1 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0] __________________________________________________________________________________________________ max_pooling2d_1 (MaxPooling2D) (None, 55, 55, 64) 0 activation_1[0][0] __________________________________________________________________________________________________ res2a_branch2a (Conv2D) (None, 55, 55, 64) 4160 max_pooling2d_1[0][0] __________________________________________________________________________________________________ bn2a_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2a[0][0] __________________________________________________________________________________________________ activation_2 (Activation) (None, 55, 55, 64) 0 bn2a_branch2a[0][0] __________________________________________________________________________________________________ res2a_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_2[0][0] __________________________________________________________________________________________________ bn2a_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2b[0][0] 
__________________________________________________________________________________________________ activation_3 (Activation) (None, 55, 55, 64) 0 bn2a_branch2b[0][0] __________________________________________________________________________________________________ res2a_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_3[0][0] __________________________________________________________________________________________________ res2a_branch1 (Conv2D) (None, 55, 55, 256) 16640 max_pooling2d_1[0][0] __________________________________________________________________________________________________ bn2a_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2a_branch2c[0][0] __________________________________________________________________________________________________ bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256) 1024 res2a_branch1[0][0] __________________________________________________________________________________________________ add_1 (Add) (None, 55, 55, 256) 0 bn2a_branch2c[0][0] bn2a_branch1[0][0] __________________________________________________________________________________________________ activation_4 (Activation) (None, 55, 55, 256) 0 add_1[0][0] __________________________________________________________________________________________________ res2b_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_4[0][0] __________________________________________________________________________________________________ bn2b_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2a[0][0] __________________________________________________________________________________________________ activation_5 (Activation) (None, 55, 55, 64) 0 bn2b_branch2a[0][0] __________________________________________________________________________________________________ res2b_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_5[0][0] __________________________________________________________________________________________________ bn2b_branch2b (BatchNormalizati 
(None, 55, 55, 64) 256 res2b_branch2b[0][0] __________________________________________________________________________________________________ activation_6 (Activation) (None, 55, 55, 64) 0 bn2b_branch2b[0][0] __________________________________________________________________________________________________ res2b_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_6[0][0] __________________________________________________________________________________________________ bn2b_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2b_branch2c[0][0] __________________________________________________________________________________________________ add_2 (Add) (None, 55, 55, 256) 0 bn2b_branch2c[0][0] activation_4[0][0] __________________________________________________________________________________________________ activation_7 (Activation) (None, 55, 55, 256) 0 add_2[0][0] __________________________________________________________________________________________________ res2c_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_7[0][0] __________________________________________________________________________________________________ bn2c_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2a[0][0] __________________________________________________________________________________________________ activation_8 (Activation) (None, 55, 55, 64) 0 bn2c_branch2a[0][0] __________________________________________________________________________________________________ res2c_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_8[0][0] __________________________________________________________________________________________________ bn2c_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2b[0][0] __________________________________________________________________________________________________ activation_9 (Activation) (None, 55, 55, 64) 0 bn2c_branch2b[0][0] 
__________________________________________________________________________________________________ res2c_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_9[0][0] __________________________________________________________________________________________________ bn2c_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2c_branch2c[0][0] __________________________________________________________________________________________________ add_3 (Add) (None, 55, 55, 256) 0 bn2c_branch2c[0][0] activation_7[0][0] __________________________________________________________________________________________________ activation_10 (Activation) (None, 55, 55, 256) 0 add_3[0][0] __________________________________________________________________________________________________ res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0] __________________________________________________________________________________________________ bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0] __________________________________________________________________________________________________ activation_11 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0] __________________________________________________________________________________________________ res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0] __________________________________________________________________________________________________ bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0] __________________________________________________________________________________________________ activation_12 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0] __________________________________________________________________________________________________ res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0] __________________________________________________________________________________________________ res3a_branch1 (Conv2D) 
(None, 28, 28, 512) 131584 activation_10[0][0] __________________________________________________________________________________________________ bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0] __________________________________________________________________________________________________ bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0] __________________________________________________________________________________________________ add_4 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0] bn3a_branch1[0][0] __________________________________________________________________________________________________ activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0] __________________________________________________________________________________________________ res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0] __________________________________________________________________________________________________ bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0] __________________________________________________________________________________________________ activation_14 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0] __________________________________________________________________________________________________ res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0] __________________________________________________________________________________________________ bn3b_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2b[0][0] __________________________________________________________________________________________________ activation_15 (Activation) (None, 28, 28, 128) 0 bn3b_branch2b[0][0] __________________________________________________________________________________________________ res3b_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_15[0][0] 
__________________________________________________________________________________________________ bn3b_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3b_branch2c[0][0] __________________________________________________________________________________________________ add_5 (Add) (None, 28, 28, 512) 0 bn3b_branch2c[0][0] activation_13[0][0] __________________________________________________________________________________________________ activation_16 (Activation) (None, 28, 28, 512) 0 add_5[0][0] __________________________________________________________________________________________________ res3c_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_16[0][0] __________________________________________________________________________________________________ bn3c_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2a[0][0] __________________________________________________________________________________________________ activation_17 (Activation) (None, 28, 28, 128) 0 bn3c_branch2a[0][0] __________________________________________________________________________________________________ res3c_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_17[0][0] __________________________________________________________________________________________________ bn3c_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2b[0][0] __________________________________________________________________________________________________ activation_18 (Activation) (None, 28, 28, 128) 0 bn3c_branch2b[0][0] __________________________________________________________________________________________________ res3c_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_18[0][0] __________________________________________________________________________________________________ bn3c_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3c_branch2c[0][0] __________________________________________________________________________________________________ add_6 (Add) 
(None, 28, 28, 512) 0 bn3c_branch2c[0][0] activation_16[0][0] __________________________________________________________________________________________________ activation_19 (Activation) (None, 28, 28, 512) 0 add_6[0][0] __________________________________________________________________________________________________ res3d_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_19[0][0] __________________________________________________________________________________________________ bn3d_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2a[0][0] __________________________________________________________________________________________________ activation_20 (Activation) (None, 28, 28, 128) 0 bn3d_branch2a[0][0] __________________________________________________________________________________________________ res3d_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_20[0][0] __________________________________________________________________________________________________ bn3d_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2b[0][0] __________________________________________________________________________________________________ activation_21 (Activation) (None, 28, 28, 128) 0 bn3d_branch2b[0][0] __________________________________________________________________________________________________ res3d_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_21[0][0] __________________________________________________________________________________________________ bn3d_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3d_branch2c[0][0] __________________________________________________________________________________________________ add_7 (Add) (None, 28, 28, 512) 0 bn3d_branch2c[0][0] activation_19[0][0] __________________________________________________________________________________________________ activation_22 (Activation) (None, 28, 28, 512) 0 add_7[0][0] 
__________________________________________________________________________________________________ res4a_branch2a (Conv2D) (None, 14, 14, 256) 131328 activation_22[0][0] __________________________________________________________________________________________________ bn4a_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2a[0][0] __________________________________________________________________________________________________ activation_23 (Activation) (None, 14, 14, 256) 0 bn4a_branch2a[0][0] __________________________________________________________________________________________________ res4a_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_23[0][0] __________________________________________________________________________________________________ bn4a_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2b[0][0] __________________________________________________________________________________________________ activation_24 (Activation) (None, 14, 14, 256) 0 bn4a_branch2b[0][0] __________________________________________________________________________________________________ res4a_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_24[0][0] __________________________________________________________________________________________________ res4a_branch1 (Conv2D) (None, 14, 14, 1024) 525312 activation_22[0][0] __________________________________________________________________________________________________ bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4a_branch2c[0][0] __________________________________________________________________________________________________ bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096 res4a_branch1[0][0] __________________________________________________________________________________________________ add_8 (Add) (None, 14, 14, 1024) 0 bn4a_branch2c[0][0] bn4a_branch1[0][0] __________________________________________________________________________________________________ 
activation_25 (Activation) (None, 14, 14, 1024) 0 add_8[0][0] __________________________________________________________________________________________________ res4b_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_25[0][0] __________________________________________________________________________________________________ bn4b_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2a[0][0] __________________________________________________________________________________________________ activation_26 (Activation) (None, 14, 14, 256) 0 bn4b_branch2a[0][0] __________________________________________________________________________________________________ res4b_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_26[0][0] __________________________________________________________________________________________________ bn4b_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2b[0][0] __________________________________________________________________________________________________ activation_27 (Activation) (None, 14, 14, 256) 0 bn4b_branch2b[0][0] __________________________________________________________________________________________________ res4b_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_27[0][0] __________________________________________________________________________________________________ bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4b_branch2c[0][0] __________________________________________________________________________________________________ add_9 (Add) (None, 14, 14, 1024) 0 bn4b_branch2c[0][0] activation_25[0][0] __________________________________________________________________________________________________ activation_28 (Activation) (None, 14, 14, 1024) 0 add_9[0][0] __________________________________________________________________________________________________ res4c_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_28[0][0] 
__________________________________________________________________________________________________ bn4c_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2a[0][0] __________________________________________________________________________________________________ activation_29 (Activation) (None, 14, 14, 256) 0 bn4c_branch2a[0][0] __________________________________________________________________________________________________ res4c_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_29[0][0] __________________________________________________________________________________________________ bn4c_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2b[0][0] __________________________________________________________________________________________________ activation_30 (Activation) (None, 14, 14, 256) 0 bn4c_branch2b[0][0] __________________________________________________________________________________________________ res4c_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_30[0][0] __________________________________________________________________________________________________ bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4c_branch2c[0][0] __________________________________________________________________________________________________ add_10 (Add) (None, 14, 14, 1024) 0 bn4c_branch2c[0][0] activation_28[0][0] __________________________________________________________________________________________________ activation_31 (Activation) (None, 14, 14, 1024) 0 add_10[0][0] __________________________________________________________________________________________________ res4d_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_31[0][0] __________________________________________________________________________________________________ bn4d_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2a[0][0] __________________________________________________________________________________________________ 
activation_32 (Activation) (None, 14, 14, 256) 0 bn4d_branch2a[0][0] __________________________________________________________________________________________________ res4d_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_32[0][0] __________________________________________________________________________________________________ bn4d_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2b[0][0] __________________________________________________________________________________________________ activation_33 (Activation) (None, 14, 14, 256) 0 bn4d_branch2b[0][0] __________________________________________________________________________________________________ res4d_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_33[0][0] __________________________________________________________________________________________________ bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4d_branch2c[0][0] __________________________________________________________________________________________________ add_11 (Add) (None, 14, 14, 1024) 0 bn4d_branch2c[0][0] activation_31[0][0] __________________________________________________________________________________________________ activation_34 (Activation) (None, 14, 14, 1024) 0 add_11[0][0] __________________________________________________________________________________________________ res4e_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_34[0][0] __________________________________________________________________________________________________ bn4e_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2a[0][0] __________________________________________________________________________________________________ activation_35 (Activation) (None, 14, 14, 256) 0 bn4e_branch2a[0][0] __________________________________________________________________________________________________ res4e_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_35[0][0] 
__________________________________________________________________________________________________ bn4e_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2b[0][0] __________________________________________________________________________________________________ activation_36 (Activation) (None, 14, 14, 256) 0 bn4e_branch2b[0][0] __________________________________________________________________________________________________ res4e_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_36[0][0] __________________________________________________________________________________________________ bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4e_branch2c[0][0] __________________________________________________________________________________________________ add_12 (Add) (None, 14, 14, 1024) 0 bn4e_branch2c[0][0] activation_34[0][0] __________________________________________________________________________________________________ activation_37 (Activation) (None, 14, 14, 1024) 0 add_12[0][0] __________________________________________________________________________________________________ res4f_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_37[0][0] __________________________________________________________________________________________________ bn4f_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2a[0][0] __________________________________________________________________________________________________ activation_38 (Activation) (None, 14, 14, 256) 0 bn4f_branch2a[0][0] __________________________________________________________________________________________________ res4f_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_38[0][0] __________________________________________________________________________________________________ bn4f_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2b[0][0] __________________________________________________________________________________________________ 
activation_39 (Activation) (None, 14, 14, 256) 0 bn4f_branch2b[0][0] __________________________________________________________________________________________________ res4f_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_39[0][0] __________________________________________________________________________________________________ bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4f_branch2c[0][0] __________________________________________________________________________________________________ add_13 (Add) (None, 14, 14, 1024) 0 bn4f_branch2c[0][0] activation_37[0][0] __________________________________________________________________________________________________ activation_40 (Activation) (None, 14, 14, 1024) 0 add_13[0][0] __________________________________________________________________________________________________ res5a_branch2a (Conv2D) (None, 7, 7, 512) 524800 activation_40[0][0] __________________________________________________________________________________________________ bn5a_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2a[0][0] __________________________________________________________________________________________________ activation_41 (Activation) (None, 7, 7, 512) 0 bn5a_branch2a[0][0] __________________________________________________________________________________________________ res5a_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_41[0][0] __________________________________________________________________________________________________ bn5a_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2b[0][0] __________________________________________________________________________________________________ activation_42 (Activation) (None, 7, 7, 512) 0 bn5a_branch2b[0][0] __________________________________________________________________________________________________ res5a_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_42[0][0] 
__________________________________________________________________________________________________ res5a_branch1 (Conv2D) (None, 7, 7, 2048) 2099200 activation_40[0][0] __________________________________________________________________________________________________ bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5a_branch2c[0][0] __________________________________________________________________________________________________ bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048) 8192 res5a_branch1[0][0] __________________________________________________________________________________________________ add_14 (Add) (None, 7, 7, 2048) 0 bn5a_branch2c[0][0] bn5a_branch1[0][0] __________________________________________________________________________________________________ activation_43 (Activation) (None, 7, 7, 2048) 0 add_14[0][0] __________________________________________________________________________________________________ res5b_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_43[0][0] __________________________________________________________________________________________________ bn5b_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2a[0][0] __________________________________________________________________________________________________ activation_44 (Activation) (None, 7, 7, 512) 0 bn5b_branch2a[0][0] __________________________________________________________________________________________________ res5b_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_44[0][0] __________________________________________________________________________________________________ bn5b_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2b[0][0] __________________________________________________________________________________________________ activation_45 (Activation) (None, 7, 7, 512) 0 bn5b_branch2b[0][0] __________________________________________________________________________________________________ res5b_branch2c (Conv2D) 
(None, 7, 7, 2048) 1050624 activation_45[0][0] __________________________________________________________________________________________________ bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5b_branch2c[0][0] __________________________________________________________________________________________________ add_15 (Add) (None, 7, 7, 2048) 0 bn5b_branch2c[0][0] activation_43[0][0] __________________________________________________________________________________________________ activation_46 (Activation) (None, 7, 7, 2048) 0 add_15[0][0] __________________________________________________________________________________________________ res5c_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_46[0][0] __________________________________________________________________________________________________ bn5c_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2a[0][0] __________________________________________________________________________________________________ activation_47 (Activation) (None, 7, 7, 512) 0 bn5c_branch2a[0][0] __________________________________________________________________________________________________ res5c_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_47[0][0] __________________________________________________________________________________________________ bn5c_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2b[0][0] __________________________________________________________________________________________________ activation_48 (Activation) (None, 7, 7, 512) 0 bn5c_branch2b[0][0] __________________________________________________________________________________________________ res5c_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_48[0][0] __________________________________________________________________________________________________ bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5c_branch2c[0][0] 
__________________________________________________________________________________________________ add_16 (Add) (None, 7, 7, 2048) 0 bn5c_branch2c[0][0] activation_46[0][0] __________________________________________________________________________________________________ activation_49 (Activation) (None, 7, 7, 2048) 0 add_16[0][0] __________________________________________________________________________________________________ avg_pool (AveragePooling2D) (None, 1, 1, 2048) 0 activation_49[0][0] __________________________________________________________________________________________________ flatten_1 (Flatten) (None, 2048) 0 avg_pool[0][0] __________________________________________________________________________________________________ fc1000 (Dense) (None, 1000) 2049000 flatten_1[0][0] ================================================================================================== Total params: 25,636,712 Trainable params: 25,583,592 Non-trainable params: 53,120 __________________________________________________________________________________________________
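The parameter counts in the summary above can be checked by hand. The sketch below assumes the usual formulas: a Conv2D layer with bias has kernel_h * kernel_w * in_channels * out_channels + out_channels parameters, and a BatchNormalization layer has 4 values per channel. The helper names are hypothetical, and the 3x3 and 1x1 kernel sizes are assumptions based on the ResNet50 design, since the summary does not print them.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels, use_bias=True):
    """Weights of one Conv2D layer: one kernel per output channel, plus biases."""
    weights = kernel_h * kernel_w * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)


def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance: 4 values per channel."""
    return 4 * channels


# Compare with the "Param #" column of the ResNet50 summary.
print(conv2d_params(3, 3, 256, 256))    # res4a_branch2b -> 590080
print(conv2d_params(1, 1, 256, 1024))   # res4a_branch2c -> 263168
print(conv2d_params(3, 3, 512, 512))    # res5a_branch2b -> 2359808
print(batchnorm_params(256))            # bn4a_branch2a  -> 1024
print(batchnorm_params(2048))           # bn5a_branch2c  -> 8192
```

The moving mean and variance of each BatchNormalization layer are updated during training but not by gradient descent, which is why the summary reports some parameters as non-trainable.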
Finally, let's try something simpler.
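Let's create a 1-layer network with a single output. With the sigmoid activation used below, it performs logistic regression; without an activation, it would perform linear regression.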
# Reference: https://gist.github.com/fchollet/b7507f373a3446097f26840330c1c378
from keras.models import Sequential
from keras.layers import Dense
simple_model = Sequential()
simple_model.add(Dense(1, input_dim=4, activation='sigmoid')) # 4 inputs, 1 output
simple_model.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 1) 5 ================================================================= Total params: 5 Trainable params: 5 Non-trainable params: 0 _________________________________________________________________
keras.models.Sequential?
keras.layers.Dense?
keras.Model.compile?
How about a 2-layer network to make it a deep neural network?
deeper_model = Sequential()
deeper_model.add(Dense(256, input_dim=16, activation='relu'))
deeper_model.add(Dense(1, activation='sigmoid'))
deeper_model.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_2 (Dense) (None, 256) 4352 _________________________________________________________________ dense_3 (Dense) (None, 1) 257 ================================================================= Total params: 4,609 Trainable params: 4,609 Non-trainable params: 0 _________________________________________________________________
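The "Param #" column for the two Dense models can be reproduced with simple arithmetic: a fully-connected layer has one weight per (input, unit) pair plus one bias per unit. The sketch below uses a hypothetical helper, `dense_param_count`, which is not part of Keras.

```python
def dense_param_count(input_dim, units):
    """Weights (input_dim * units) plus one bias per unit."""
    return input_dim * units + units


simple_total = dense_param_count(4, 1)        # Dense(1, input_dim=4)
deeper_hidden = dense_param_count(16, 256)    # Dense(256, input_dim=16)
deeper_output = dense_param_count(256, 1)     # Dense(1) fed by 256 hidden units

print(simple_total)                    # 5
print(deeper_hidden, deeper_output)    # 4352 257
print(deeper_hidden + deeper_output)   # 4609
```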
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
model_to_dot?
SVG(model_to_dot(simple_model, show_shapes=True).create(prog='dot', format='svg'))
SVG(model_to_dot(deeper_model, show_shapes=True).create(prog='dot', format='svg'))
SVG(model_to_dot(mobilenet_model, show_shapes=True).create(prog='dot', format='svg'))
SVG(model_to_dot(resnet_model, show_shapes=True).create(prog='dot', format='svg'))
If pydot is not able to find graphviz, you can try installing graphviz manually.
After installing, add the Graphviz bin folder to your PATH, for example C:/Program Files (x86)/Graphviz2.38/bin
Then open a new Anaconda Prompt and re-run the Jupyter notebook.
A neural network is trained using Stochastic Gradient Descent
2 layers of neurons:
$$x_1 = W_1'g(W_1x + b_1) + b_1'$$
$$y = x_2 = W_2'g(W_2x_1 + b_2) + b_2'$$
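The two equations above can be traced numerically. This is a minimal NumPy sketch that assumes $g$ is the sigmoid function and uses arbitrary layer sizes and random parameters; the `layer` helper and all sizes are illustrative choices, not part of the original material.

```python
import numpy as np


def sigmoid(z):
    """Element-wise logistic function, used here as the non-linearity g."""
    return 1.0 / (1.0 + np.exp(-z))


def layer(x, W, b, W_out, b_out, g=sigmoid):
    """One layer as written above: W' g(W x + b) + b'."""
    return W_out @ g(W @ x + b) + b_out


rng = np.random.default_rng(0)              # arbitrary seed, for repeatability
n_in, n_hidden, n_mid, n_out = 4, 8, 5, 3   # illustrative sizes

x = rng.normal(size=n_in)

# Layer 1 parameters: W_1, b_1, W_1', b_1'
W1, b1 = rng.normal(size=(n_hidden, n_in)), np.zeros(n_hidden)
W1p, b1p = rng.normal(size=(n_mid, n_hidden)), np.zeros(n_mid)

# Layer 2 parameters: W_2, b_2, W_2', b_2'
W2, b2 = rng.normal(size=(n_hidden, n_mid)), np.zeros(n_hidden)
W2p, b2p = rng.normal(size=(n_out, n_hidden)), np.zeros(n_out)

x1 = layer(x, W1, b1, W1p, b1p)    # x_1 in the first equation
y = layer(x1, W2, b2, W2p, b2p)    # y = x_2 in the second equation
print(x1.shape, y.shape)           # (5,) (3,)
```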
For layer $l$, single layer operation:
$$x_l = \sigma_l(W_lx_{l-1} + b_l)$$
where $\sigma_l(z) = W_l'g(z) + b_l'$
for $l = 1$ to $\,L$:
$\,\,\,\,x_l = \sigma_l(W_lx_{l-1} + b_l)$
Where:
Objective
$$W_l^j := W_l^j - \epsilon \frac{\partial J}{\partial W_l^j}$$
$$b_l^j := b_l^j - \epsilon \frac{\partial J}{\partial b_l^j}$$
$\epsilon$ = learning rate
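To make the update rule concrete, here is a small NumPy sketch of minibatch stochastic gradient descent. It assumes $J$ is a loss to be minimized (mean squared error here), so each step moves the parameters against the gradient. The single linear layer, the synthetic data, and all variable names are illustrative assumptions, not the network trained later in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data from a known linear rule (assumed purely for illustration).
X = rng.normal(size=(200, 3))
true_w, true_b = np.array([2.0, -1.0, 0.5]), 0.3
y = X @ true_w + true_b

w, b = np.zeros(3), 0.0
epsilon = 0.1       # learning rate
batch_size = 20


def loss(w, b):
    """Mean squared error J over the full dataset."""
    return np.mean((X @ w + b - y) ** 2)


initial_loss = loss(w, b)
for epoch in range(50):
    order = rng.permutation(len(X))                # reshuffle every epoch
    for start in range(0, len(X), batch_size):
        idx = order[start:start + batch_size]
        error = X[idx] @ w + b - y[idx]
        grad_w = 2 * X[idx].T @ error / len(idx)   # dJ/dw on this minibatch
        grad_b = 2 * error.mean()                  # dJ/db on this minibatch
        w -= epsilon * grad_w                      # step against the gradient
        b -= epsilon * grad_b
final_loss = loss(w, b)
print(initial_loss, final_loss)
```

Each update uses only one minibatch to estimate the gradient, which is what makes the procedure "stochastic".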
In this workshop, you'll train a neural network to perform logistic regression on the MNIST dataset.
Credits: https://medium.com/@the1ju/simple-logistic-regression-using-keras-249e0cc9a970
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras import backend as K
# Training settings
BATCH_SIZE = 128
NUM_CLASSES = 10
EPOCHS = 30
# Input size settings
IMG_ROWS = 28 # 28 pixels high
IMG_COLS = 28 # 28 pixels wide
# Import the dataset, split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Input processing
input_dim = IMG_ROWS * IMG_COLS
X_train = X_train.reshape(X_train.shape[0], input_dim)  # flatten each 28x28 image to 784 values
X_test = X_test.reshape(X_test.shape[0], input_dim)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255 # scale to between 0 and 1 (pixel: 0-255)
X_test /= 255
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
X_train shape: (60000, 784) 60000 train samples 10000 test samples
import matplotlib.pyplot as plt
(X_train1, y_train1), (X_test1, y_test1) = mnist.load_data()
plt.imshow(X_train1[3, :, :], cmap=plt.cm.gray)
plt.show()
print(X_train1[3, :, :]/255)
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.48627451 0.99215686 1. 0.24705882 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.37647059 0.95686275 0.98431373 0.99215686 0.24313725 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.49803922 0.98431373 0.98431373 0.99215686 0.24313725 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.26666667 0.9254902 0.98431373 0.82745098 0.12156863 0.03137255 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.23529412 0.89411765 0.98431373 0.98431373 0.36862745 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.60784314 0.99215686 0.99215686 0.74117647 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.07843137 0.99215686 0.98431373 0.92156863 0.25882353 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.1254902 0.80392157 0.99215686 0.98431373 0.49411765 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.40784314 0.98431373 0.99215686 0.72156863 0.05882353 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.31372549 0.94117647 0.98431373 0.75686275 0.09019608 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.1254902 0.99215686 0.99215686 0.99215686 0.62352941 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.59215686 0.98431373 0.98431373 0.98431373 0.15294118 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 
0.18823529 0.86666667 0.98431373 0.98431373 0.6745098 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.91764706 0.98431373 0.98431373 0.76862745 0.04705882 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.99215686 0.98431373 0.98431373 0.34901961 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.62352941 1. 0.99215686 0.99215686 0.12156863 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0.18823529 0.89411765 0.99215686 0.96862745 0.54901961 0.03137255 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0.25098039 0.98431373 0.99215686 0.8627451 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0.25098039 0.98431373 0.99215686 0.8627451 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0.09411765 0.75686275 0.99215686 0.8627451 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]]
# We are doing multi-class classification
# Convert class vectors to binary class matrices
y_train_cat = keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test_cat = keras.utils.to_categorical(y_test, NUM_CLASSES)
# Show what the classes look like
print('y_train shape:', y_train_cat.shape)
# Note: index 1 selects the second training sample, not the first
print('First y_train sample:', y_train_cat[1])
y_train shape: (60000, 10) First y_train sample: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
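The conversion to binary class matrices is one-hot encoding: each integer label becomes a row with a single 1 in the column of its class. Below is a minimal NumPy sketch of the same idea. The `one_hot` helper is hypothetical and only illustrates what `to_categorical` produces for integer labels.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Return a (len(labels), num_classes) float matrix with a single 1 per row."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, num_classes), dtype="float32")
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded


sample_labels = [5, 0, 4]           # illustrative digit labels
encoded = one_hot(sample_labels, 10)
print(encoded.shape)                # (3, 10)
print(encoded[1])                   # the row for label 0 has its 1 in column 0
print(encoded.argmax(axis=1))       # recovers the integer labels: [5 0 4]
```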
# Create the model
from keras.layers import Activation
model = Sequential()
model.add(Dense(NUM_CLASSES, input_dim=input_dim))
model.add(Activation('softmax'))
model.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_6 (Dense) (None, 10) 7850 _________________________________________________________________ activation_52 (Activation) (None, 10) 0 ================================================================= Total params: 7,850 Trainable params: 7,850 Non-trainable params: 0 _________________________________________________________________
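The model above is a single Dense layer followed by softmax, so it computes softmax(xW + b) with 784 * 10 weights and 10 biases. The NumPy sketch below shows that computation on random inputs. The weights and inputs are placeholders (assumptions for illustration), not trained values.

```python
import numpy as np


def softmax(z):
    """Softmax over the last axis, shifted by the max for numerical stability."""
    shifted = z - z.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
input_dim, num_classes = 784, 10

W = rng.normal(scale=0.01, size=(input_dim, num_classes))   # 7,840 weights
b = np.zeros(num_classes)                                   # 10 biases
x = rng.random(size=(2, input_dim))                         # 2 fake "images" in [0, 1)

probs = softmax(x @ W + b)
print(W.size + b.size)      # 7850, matching the summary
print(probs.shape)          # (2, 10)
print(probs.sum(axis=1))    # each row sums to 1 (up to rounding)
```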
Steps:
1. Compile the model with the sgd optimizer.
2. Train with minibatch size 128 on X_train, y_train_cat.
3. Evaluate on X_test, y_test_cat.
How to get the accuracy metrics:
model.compile(..., metrics=['accuracy'])
history = model.fit(...)
...
loss = history.history['loss']
val_loss = history.history['val_loss']
You may reference this example for steps 1 and 2: https://medium.com/@the1ju/simple-logistic-regression-using-keras-249e0cc9a970
# Compile and train model
# Your code here
model.compile(optimizer='sgd',
loss='categorical_crossentropy', metrics=['accuracy'])
history = model.fit(X_train, y_train_cat, batch_size=BATCH_SIZE,
epochs=EPOCHS,
verbose=1,
validation_data=(X_test, y_test_cat))
score = model.evaluate(X_test, y_test_cat, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])
pred = model.predict(X_test)
Train on 60000 samples, validate on 10000 samples Epoch 1/30 60000/60000 [==============================] - 3s 43us/step - loss: 1.2983 - acc: 0.6889 - val_loss: 0.8124 - val_acc: 0.8318 Epoch 2/30 60000/60000 [==============================] - 2s 27us/step - loss: 0.7178 - acc: 0.8397 - val_loss: 0.6065 - val_acc: 0.8625 Epoch 3/30 60000/60000 [==============================] - 2s 28us/step - loss: 0.5879 - acc: 0.8587 - val_loss: 0.5249 - val_acc: 0.8746 Epoch 4/30 60000/60000 [==============================] - 2s 30us/step - loss: 0.5260 - acc: 0.8691 - val_loss: 0.4793 - val_acc: 0.8813 Epoch 5/30 60000/60000 [==============================] - 3s 46us/step - loss: 0.4882 - acc: 0.8752 - val_loss: 0.4497 - val_acc: 0.8860 Epoch 6/30 60000/60000 [==============================] - 2s 41us/step - loss: 0.4623 - acc: 0.8793 - val_loss: 0.4283 - val_acc: 0.8915 Epoch 7/30 60000/60000 [==============================] - 2s 28us/step - loss: 0.4430 - acc: 0.8829 - val_loss: 0.4121 - val_acc: 0.8941 Epoch 8/30 60000/60000 [==============================] - 1s 24us/step - loss: 0.4281 - acc: 0.8862 - val_loss: 0.3991 - val_acc: 0.8959 Epoch 9/30 60000/60000 [==============================] - 2s 32us/step - loss: 0.4160 - acc: 0.8883 - val_loss: 0.3893 - val_acc: 0.8978 Epoch 10/30 60000/60000 [==============================] - 2s 37us/step - loss: 0.4060 - acc: 0.8903 - val_loss: 0.3802 - val_acc: 0.8994 Epoch 11/30 60000/60000 [==============================] - 2s 32us/step - loss: 0.3975 - acc: 0.8920 - val_loss: 0.3727 - val_acc: 0.9011 Epoch 12/30 60000/60000 [==============================] - 1s 22us/step - loss: 0.3902 - acc: 0.8933 - val_loss: 0.3665 - val_acc: 0.9016 Epoch 13/30 60000/60000 [==============================] - 1s 24us/step - loss: 0.3838 - acc: 0.8949 - val_loss: 0.3611 - val_acc: 0.9026 Epoch 14/30 60000/60000 [==============================] - 1s 18us/step - loss: 0.3782 - acc: 0.8964 - val_loss: 0.3558 - val_acc: 0.9038 Epoch 15/30 60000/60000 
[==============================] - 1s 22us/step - loss: 0.3731 - acc: 0.8973 - val_loss: 0.3516 - val_acc: 0.9051 Epoch 16/30 60000/60000 [==============================] - 1s 23us/step - loss: 0.3686 - acc: 0.8982 - val_loss: 0.3476 - val_acc: 0.9064 Epoch 17/30 60000/60000 [==============================] - 1s 21us/step - loss: 0.3644 - acc: 0.8992 - val_loss: 0.3439 - val_acc: 0.9065 Epoch 18/30 60000/60000 [==============================] - 1s 22us/step - loss: 0.3607 - acc: 0.9000 - val_loss: 0.3408 - val_acc: 0.9073 Epoch 19/30 60000/60000 [==============================] - 1s 21us/step - loss: 0.3572 - acc: 0.9008 - val_loss: 0.3378 - val_acc: 0.9074 Epoch 20/30 60000/60000 [==============================] - 1s 18us/step - loss: 0.3540 - acc: 0.9015 - val_loss: 0.3350 - val_acc: 0.9080 Epoch 21/30 60000/60000 [==============================] - 1s 25us/step - loss: 0.3511 - acc: 0.9024 - val_loss: 0.3324 - val_acc: 0.9085 Epoch 22/30 60000/60000 [==============================] - 1s 18us/step - loss: 0.3483 - acc: 0.9030 - val_loss: 0.3302 - val_acc: 0.9091 Epoch 23/30 60000/60000 [==============================] - 1s 18us/step - loss: 0.3458 - acc: 0.9035 - val_loss: 0.3281 - val_acc: 0.9099 Epoch 24/30 60000/60000 [==============================] - 1s 20us/step - loss: 0.3435 - acc: 0.9040 - val_loss: 0.3261 - val_acc: 0.9104 Epoch 25/30 60000/60000 [==============================] - 1s 20us/step - loss: 0.3412 - acc: 0.9046 - val_loss: 0.3243 - val_acc: 0.9109 Epoch 26/30 60000/60000 [==============================] - 2s 25us/step - loss: 0.3391 - acc: 0.9052 - val_loss: 0.3224 - val_acc: 0.9109 Epoch 27/30 60000/60000 [==============================] - 1s 23us/step - loss: 0.3371 - acc: 0.9055 - val_loss: 0.3206 - val_acc: 0.9119 Epoch 28/30 60000/60000 [==============================] - 1s 22us/step - loss: 0.3353 - acc: 0.9063 - val_loss: 0.3191 - val_acc: 0.9117 Epoch 29/30 60000/60000 [==============================] - 1s 25us/step - loss: 0.3335 - 
acc: 0.9066 - val_loss: 0.3180 - val_acc: 0.9126 Epoch 30/30 60000/60000 [==============================] - 1s 19us/step - loss: 0.3318 - acc: 0.9073 - val_loss: 0.3162 - val_acc: 0.9127 Test score: 0.31617396712303164 Test accuracy: 0.9127
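The loss and acc columns in the log can be understood with a small worked example. The sketch below computes categorical cross-entropy, the mean over samples of $-\sum_k y_k \log p_k$, and accuracy on three made-up samples. The function names, the clipping constant, and the sample values are assumptions for illustration; this is not the Keras implementation.

```python
import numpy as np


def categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean of -sum(y_true * log(y_prob)) over the samples."""
    y_prob = np.clip(y_prob, eps, 1.0)   # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_prob), axis=1)))


def accuracy(y_true, y_prob):
    """Fraction of samples whose most probable class matches the label."""
    return float(np.mean(y_true.argmax(axis=1) == y_prob.argmax(axis=1)))


# Three made-up samples over three classes (illustrative values only).
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="float32")
y_prob = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.4, 0.3]])

print(categorical_crossentropy(y_true, y_prob))   # mean of -log(0.7), -log(0.8), -log(0.3)
print(accuracy(y_true, y_prob))                   # 2 of 3 samples are classified correctly
```

The third sample shows why loss and accuracy can move differently: a wrong but not very confident prediction still contributes a finite loss.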
pred = model.predict(X_test)
print(pred)
pred_classes = model.predict_classes(X_test)
print(pred_classes)
[[1.92300795e-04 3.05666418e-07 2.73727986e-04 ... 9.94899571e-01 1.48842766e-04 2.02709646e-03] [7.04828463e-03 1.23849735e-04 9.05364752e-01 ... 6.24142373e-08 5.18862577e-03 1.11505744e-06] [1.44006510e-04 9.55273390e-01 1.45111205e-02 ... 5.02551626e-03 9.41167492e-03 1.76334998e-03] ... [1.89898765e-06 8.51190543e-06 7.51070256e-05 ... 4.82808892e-03 1.33694988e-02 4.35379557e-02] [3.01771378e-03 4.32022521e-03 1.24727038e-03 ... 7.10607506e-04 3.68826538e-01 1.30194030e-03] [2.76002334e-04 8.18013834e-09 8.10557511e-04 ... 1.97109102e-08 4.32123943e-06 7.32404715e-07]] [7 2 1 ... 4 5 6]
(X_train3, y_train3), (X_test3, y_test3) = mnist.load_data()
test = X_test3[:20]
plt.imshow(test[8], cmap=plt.cm.gray)
plt.show()
# preprocessing
test = test.reshape(test.shape[0], input_dim)
test = test.astype('float32')
test /= 255
predicted_numbers = model.predict_classes(test)
predicted_prob = model.predict(test)
print('prediction', predicted_numbers)
print('truth', y_test3[:20])
print(predicted_prob[8])
print(predicted_numbers[8])
print(y_test3[8])
prediction [7 2 1 0 4 1 4 9 6 9 0 6 9 0 1 5 9 7 3 4] truth [7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4] [8.9296000e-03 1.3875663e-04 3.4160007e-02 2.3498736e-05 1.7972823e-02 8.6964490e-03 9.2203307e-01 2.3101331e-05 6.4728083e-03 1.5499224e-03] 6 5
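The predicted class is the index of the largest probability. The sketch below reuses the numbers printed above, copied by hand into NumPy arrays, to show that sample 8 is the only mistake among the first 20 test images.

```python
import numpy as np

# Values copied from the printed output above.
prediction = np.array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4])
truth = np.array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4])
prob_sample_8 = np.array([8.9296000e-03, 1.3875663e-04, 3.4160007e-02, 2.3498736e-05,
                          1.7972823e-02, 8.6964490e-03, 9.2203307e-01, 2.3101331e-05,
                          6.4728083e-03, 1.5499224e-03])

print(prob_sample_8.argmax())                 # 6, the class with the highest probability
print(prob_sample_8[5])                       # probability assigned to the true class, 5
print(np.flatnonzero(prediction != truth))    # [8], the only misclassified sample
print((prediction == truth).mean())           # 0.95
```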
https://keras.io/models/model/
Workflow:
Cheatsheet: https://s3.amazonaws.com/assets.datacamp.com/blog_assets/Keras_Cheat_Sheet_Python.pdf
# Plot learning curve
# How to get the accuracy metrics:
#
# model.compile(..., metrics=['accuracy'])
# history = model.fit(...)
# ...
# loss = history.history['loss']
# val_loss = history.history['val_loss']
print(history.history.keys())
print(history.history['val_loss'])
# Use matplotlib to plot 'val_loss' and 'loss' vs. number of epochs
# Your code here
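One possible way to complete the plotting step is sketched below. The `plot_learning_curve` helper is hypothetical. It takes a dictionary shaped like `history.history` and draws both loss curves against the epoch number. To keep the example self-contained, it is called with the first five epochs copied from the training log; in the notebook you would pass `history.history` instead.

```python
import matplotlib.pyplot as plt


def plot_learning_curve(history_dict):
    """Plot training and validation loss against the epoch number."""
    epochs = range(1, len(history_dict["loss"]) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, history_dict["loss"], label="loss")
    ax.plot(epochs, history_dict["val_loss"], label="val_loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    return ax


# First five epochs copied from the training log.
logged = {
    "loss": [1.2983, 0.7178, 0.5879, 0.5260, 0.4882],
    "val_loss": [0.8124, 0.6065, 0.5249, 0.4793, 0.4497],
}
ax = plot_learning_curve(logged)
```

In a Jupyter notebook the figure is displayed when the cell finishes; in a plain script, call `plt.show()` afterwards.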
Material | Read it for | URL |
---|---|---|
Lecture 1: Deep Learning Challenge. Is There Theory? | Intro to Deep Learning | https://stats385.github.io/lecture_slides (lecture 1) |
Lecture 2: Overview of Deep Learning from a Practical Point of View | More background on Neural Nets | https://stats385.github.io/lecture_slides (lecture 2) |
Neural Networks and Deep Learning, Chapter 2 | Understanding Back Propagation | http://neuralnetworksanddeeplearning.com/chap2.html |
Guide to the Sequential Model | Basic usage of Keras for neural net training | https://keras.io/getting-started/sequential-model-guide/ |