If you are viewing this file in preview mode, some links won't work. Find the fully featured Jupyter Notebook file on the website of Prof. Jens Flemming at Zwickau University of Applied Sciences. This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.
So far we only considered the basics of CNNs. Now we discuss techniques for improving prediction quality and for decreasing training times. First we introduce the ideas, then we implement all techniques to improve our cats and dogs classifier.
Prediction accuracy heavily depends on the amount and variety of data available for training. Collecting more data is expensive, so we could instead generate synthetic data from existing data. In the case of image data we may rotate, scale, translate, or distort the images to get new images showing identical objects in slightly different ways. This idea is known as data augmentation and increases the amount of data as well as its variety.
Keras' ImageDataGenerator class provides several types of data augmentation (rotation, zoom, pan, brightness, flip, and some more). Activating this feature yields a stream of augmented images.
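A minimal sketch of activating several augmentation types (the concrete parameter values here are illustrative assumptions, not tuned settings):
import tensorflow.keras as keras
aug_generator = keras.preprocessing.image.ImageDataGenerator(
    rotation_range=15,            # rotate by up to 15 degrees
    width_shift_range=0.1,        # pan horizontally by up to 10 percent
    height_shift_range=0.1,       # pan vertically by up to 10 percent
    zoom_range=0.2,               # zoom in or out by up to 20 percent
    brightness_range=(0.8, 1.2),  # randomly darken or brighten
    horizontal_flip=True          # random mirroring
)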
CNNs have two major components: the feature extraction stack (convolutional and pooling layers) and the decision stack (dense layers for classification or regression). The task of the feature extraction stack is to automatically preprocess images, resulting in a set of feature maps containing higher-level information than just colored pixels. Based on this higher-level information the decision stack predicts the targets.
With this two-step approach in mind we may use more powerful feature extraction. The feature extraction part is more or less the same for all object classification problems in image processing. Thus, we might use a feature extraction stack trained on much larger data sets and with much more computational resources. Such pre-trained CNNs are available on the internet, and Keras ships with some, too. See Keras Applications for a list of pre-trained CNNs in Keras.
In Keras' documentation the feature extraction stack is called the convolutional base and the decision stack is the head of the CNN. When loading a pre-trained model we have to decide whether to load the full model or only the convolutional base. If we do not use the pre-trained head, we have to specify the input shape for the network. This sounds a bit strange, but the convolutional base works for arbitrary input shapes and specifying a concrete shape fixes the output shape of the convolutional base. If we use the pre-trained head, then the output shape of the convolutional base has to fit the input shape of the head. Thus, the head determines the input shape of the CNN.
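A small sketch of both variants (weights default to ImageNet-trained values); below we will use the second variant:
import tensorflow.keras as keras
# full model with pre-trained head; the head fixes the input shape
full_model = keras.applications.Xception(include_top=True)
# convolutional base only; now we have to specify the input shape ourselves
conv_base = keras.applications.Xception(
    include_top=False,
    input_shape=(128, 128, 3)
)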
Up to now we only considered simple gradient descent. But there are much better algorithms for minimizing loss functions. Keras implements some of them, and we should use them even though at the moment we do not know what those algorithms do in detail. We will have a look at advanced minimization techniques next semester in the lecture series on numerical methods.
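For the moment it suffices to know how to select such an algorithm via the optimizer argument of compile. A minimal sketch (the toy model and the choice of Adam with this learning rate are illustrative assumptions, not a recommendation):
import tensorflow.keras as keras
# toy model, only to demonstrate the compile call
toy_model = keras.models.Sequential([keras.layers.Dense(2, input_shape=(4,))])
toy_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),  # an advanced minimizer
    loss='mean_squared_error'
)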
Loading images from disk and preprocessing them during training may slow down training. One solution is to load all images (including augmentation) into memory before training, but this requires lots of memory. Another solution is to load and preprocess data asynchronously. That is, while the GPU does some calculations, the CPU loads and preprocesses images. Keras and TensorFlow support such advanced techniques, but we will not cover them here.
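As an illustration only (we will not use it below), TensorFlow's tf.data API can overlap data preparation and training. The directory path here is a placeholder, and image_dataset_from_directory requires a reasonably recent TensorFlow 2 version:
import tensorflow as tf
# stream batches from disk; labels are inferred from the directory structure
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'path/to/images',
    image_size=(128, 128),
    batch_size=32
)
# prefetch: the CPU prepares the next batches while the GPU is busy training
# (on older TensorFlow versions use tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)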
We consider object classification with cats and dogs again.
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.keras as keras
# Optional GPU configuration (uncomment and adapt to select a specific GPU
# and to enable memory growth):
#import tensorflow as tf
#physical_devices = tf.config.list_physical_devices('GPU')
#tf.config.experimental.set_memory_growth(physical_devices[0], True)
#gpu_idx = 2
#gpu_list = tf.config.experimental.list_physical_devices(device_type='GPU')
#tf.config.experimental.set_visible_devices(gpu_list[gpu_idx], 'GPU')
data_path = '/home/ZW.FH-ZWICKAU.DE/all_users/cats_and_dogs/'
We load a pre-trained convolutional base.
img_size = 128 # width and height of images
conv_base = keras.applications.Xception(
include_top=False,
input_shape=(img_size, img_size, 3)
)
conv_base.summary()
Model: "xception" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 128, 128, 3) 0 __________________________________________________________________________________________________ block1_conv1 (Conv2D) (None, 63, 63, 32) 864 input_1[0][0] __________________________________________________________________________________________________ block1_conv1_bn (BatchNormaliza (None, 63, 63, 32) 128 block1_conv1[0][0] __________________________________________________________________________________________________ block1_conv1_act (Activation) (None, 63, 63, 32) 0 block1_conv1_bn[0][0] __________________________________________________________________________________________________ block1_conv2 (Conv2D) (None, 61, 61, 64) 18432 block1_conv1_act[0][0] __________________________________________________________________________________________________ block1_conv2_bn (BatchNormaliza (None, 61, 61, 64) 256 block1_conv2[0][0] __________________________________________________________________________________________________ block1_conv2_act (Activation) (None, 61, 61, 64) 0 block1_conv2_bn[0][0] __________________________________________________________________________________________________ block2_sepconv1 (SeparableConv2 (None, 61, 61, 128) 8768 block1_conv2_act[0][0] __________________________________________________________________________________________________ block2_sepconv1_bn (BatchNormal (None, 61, 61, 128) 512 block2_sepconv1[0][0] __________________________________________________________________________________________________ block2_sepconv2_act (Activation (None, 61, 61, 128) 0 block2_sepconv1_bn[0][0] __________________________________________________________________________________________________ block2_sepconv2 (SeparableConv2 (None, 61, 61, 128) 17536 block2_sepconv2_act[0][0] __________________________________________________________________________________________________ block2_sepconv2_bn (BatchNormal (None, 61, 61, 128) 512 block2_sepconv2[0][0] __________________________________________________________________________________________________ conv2d (Conv2D) (None, 31, 31, 128) 8192 block1_conv2_act[0][0] __________________________________________________________________________________________________ block2_pool (MaxPooling2D) (None, 31, 31, 128) 0 block2_sepconv2_bn[0][0] __________________________________________________________________________________________________ batch_normalization (BatchNorma (None, 31, 31, 128) 512 conv2d[0][0] __________________________________________________________________________________________________ add (Add) (None, 31, 31, 128) 0 block2_pool[0][0] batch_normalization[0][0] __________________________________________________________________________________________________ block3_sepconv1_act (Activation (None, 31, 31, 128) 0 add[0][0] __________________________________________________________________________________________________ block3_sepconv1 (SeparableConv2 (None, 31, 31, 256) 33920 block3_sepconv1_act[0][0] __________________________________________________________________________________________________ block3_sepconv1_bn (BatchNormal (None, 31, 31, 256) 1024 block3_sepconv1[0][0] __________________________________________________________________________________________________ block3_sepconv2_act (Activation 
(None, 31, 31, 256) 0 block3_sepconv1_bn[0][0] __________________________________________________________________________________________________ block3_sepconv2 (SeparableConv2 (None, 31, 31, 256) 67840 block3_sepconv2_act[0][0] __________________________________________________________________________________________________ block3_sepconv2_bn (BatchNormal (None, 31, 31, 256) 1024 block3_sepconv2[0][0] __________________________________________________________________________________________________ conv2d_1 (Conv2D) (None, 16, 16, 256) 32768 add[0][0] __________________________________________________________________________________________________ block3_pool (MaxPooling2D) (None, 16, 16, 256) 0 block3_sepconv2_bn[0][0] __________________________________________________________________________________________________ batch_normalization_1 (BatchNor (None, 16, 16, 256) 1024 conv2d_1[0][0] __________________________________________________________________________________________________ add_1 (Add) (None, 16, 16, 256) 0 block3_pool[0][0] batch_normalization_1[0][0] __________________________________________________________________________________________________ block4_sepconv1_act (Activation (None, 16, 16, 256) 0 add_1[0][0] __________________________________________________________________________________________________ block4_sepconv1 (SeparableConv2 (None, 16, 16, 728) 188672 block4_sepconv1_act[0][0] __________________________________________________________________________________________________ block4_sepconv1_bn (BatchNormal (None, 16, 16, 728) 2912 block4_sepconv1[0][0] __________________________________________________________________________________________________ block4_sepconv2_act (Activation (None, 16, 16, 728) 0 block4_sepconv1_bn[0][0] __________________________________________________________________________________________________ block4_sepconv2 (SeparableConv2 (None, 16, 16, 728) 536536 block4_sepconv2_act[0][0] __________________________________________________________________________________________________ block4_sepconv2_bn (BatchNormal (None, 16, 16, 728) 2912 block4_sepconv2[0][0] __________________________________________________________________________________________________ conv2d_2 (Conv2D) (None, 8, 8, 728) 186368 add_1[0][0] __________________________________________________________________________________________________ block4_pool (MaxPooling2D) (None, 8, 8, 728) 0 block4_sepconv2_bn[0][0] __________________________________________________________________________________________________ batch_normalization_2 (BatchNor (None, 8, 8, 728) 2912 conv2d_2[0][0] __________________________________________________________________________________________________ add_2 (Add) (None, 8, 8, 728) 0 block4_pool[0][0] batch_normalization_2[0][0] __________________________________________________________________________________________________ block5_sepconv1_act (Activation (None, 8, 8, 728) 0 add_2[0][0] __________________________________________________________________________________________________ block5_sepconv1 (SeparableConv2 (None, 8, 8, 728) 536536 block5_sepconv1_act[0][0] __________________________________________________________________________________________________ block5_sepconv1_bn (BatchNormal (None, 8, 8, 728) 2912 block5_sepconv1[0][0] __________________________________________________________________________________________________ block5_sepconv2_act (Activation (None, 8, 8, 728) 0 block5_sepconv1_bn[0][0] 
__________________________________________________________________________________________________ block5_sepconv2 (SeparableConv2 (None, 8, 8, 728) 536536 block5_sepconv2_act[0][0] __________________________________________________________________________________________________ block5_sepconv2_bn (BatchNormal (None, 8, 8, 728) 2912 block5_sepconv2[0][0] __________________________________________________________________________________________________ block5_sepconv3_act (Activation (None, 8, 8, 728) 0 block5_sepconv2_bn[0][0] __________________________________________________________________________________________________ block5_sepconv3 (SeparableConv2 (None, 8, 8, 728) 536536 block5_sepconv3_act[0][0] __________________________________________________________________________________________________ block5_sepconv3_bn (BatchNormal (None, 8, 8, 728) 2912 block5_sepconv3[0][0] __________________________________________________________________________________________________ add_3 (Add) (None, 8, 8, 728) 0 block5_sepconv3_bn[0][0] add_2[0][0] __________________________________________________________________________________________________ block6_sepconv1_act (Activation (None, 8, 8, 728) 0 add_3[0][0] __________________________________________________________________________________________________ block6_sepconv1 (SeparableConv2 (None, 8, 8, 728) 536536 block6_sepconv1_act[0][0] __________________________________________________________________________________________________ block6_sepconv1_bn (BatchNormal (None, 8, 8, 728) 2912 block6_sepconv1[0][0] __________________________________________________________________________________________________ block6_sepconv2_act (Activation (None, 8, 8, 728) 0 block6_sepconv1_bn[0][0] __________________________________________________________________________________________________ block6_sepconv2 (SeparableConv2 (None, 8, 8, 728) 536536 block6_sepconv2_act[0][0] __________________________________________________________________________________________________ block6_sepconv2_bn (BatchNormal (None, 8, 8, 728) 2912 block6_sepconv2[0][0] __________________________________________________________________________________________________ block6_sepconv3_act (Activation (None, 8, 8, 728) 0 block6_sepconv2_bn[0][0] __________________________________________________________________________________________________ block6_sepconv3 (SeparableConv2 (None, 8, 8, 728) 536536 block6_sepconv3_act[0][0] __________________________________________________________________________________________________ block6_sepconv3_bn (BatchNormal (None, 8, 8, 728) 2912 block6_sepconv3[0][0] __________________________________________________________________________________________________ add_4 (Add) (None, 8, 8, 728) 0 block6_sepconv3_bn[0][0] add_3[0][0] __________________________________________________________________________________________________ block7_sepconv1_act (Activation (None, 8, 8, 728) 0 add_4[0][0] __________________________________________________________________________________________________ block7_sepconv1 (SeparableConv2 (None, 8, 8, 728) 536536 block7_sepconv1_act[0][0] __________________________________________________________________________________________________ block7_sepconv1_bn (BatchNormal (None, 8, 8, 728) 2912 block7_sepconv1[0][0] __________________________________________________________________________________________________ block7_sepconv2_act (Activation (None, 8, 8, 728) 0 block7_sepconv1_bn[0][0] 
__________________________________________________________________________________________________ block7_sepconv2 (SeparableConv2 (None, 8, 8, 728) 536536 block7_sepconv2_act[0][0] __________________________________________________________________________________________________ block7_sepconv2_bn (BatchNormal (None, 8, 8, 728) 2912 block7_sepconv2[0][0] __________________________________________________________________________________________________ block7_sepconv3_act (Activation (None, 8, 8, 728) 0 block7_sepconv2_bn[0][0] __________________________________________________________________________________________________ block7_sepconv3 (SeparableConv2 (None, 8, 8, 728) 536536 block7_sepconv3_act[0][0] __________________________________________________________________________________________________ block7_sepconv3_bn (BatchNormal (None, 8, 8, 728) 2912 block7_sepconv3[0][0] __________________________________________________________________________________________________ add_5 (Add) (None, 8, 8, 728) 0 block7_sepconv3_bn[0][0] add_4[0][0] __________________________________________________________________________________________________ block8_sepconv1_act (Activation (None, 8, 8, 728) 0 add_5[0][0] __________________________________________________________________________________________________ block8_sepconv1 (SeparableConv2 (None, 8, 8, 728) 536536 block8_sepconv1_act[0][0] __________________________________________________________________________________________________ block8_sepconv1_bn (BatchNormal (None, 8, 8, 728) 2912 block8_sepconv1[0][0] __________________________________________________________________________________________________ block8_sepconv2_act (Activation (None, 8, 8, 728) 0 block8_sepconv1_bn[0][0] __________________________________________________________________________________________________ block8_sepconv2 (SeparableConv2 (None, 8, 8, 728) 536536 block8_sepconv2_act[0][0] __________________________________________________________________________________________________ block8_sepconv2_bn (BatchNormal (None, 8, 8, 728) 2912 block8_sepconv2[0][0] __________________________________________________________________________________________________ block8_sepconv3_act (Activation (None, 8, 8, 728) 0 block8_sepconv2_bn[0][0] __________________________________________________________________________________________________ block8_sepconv3 (SeparableConv2 (None, 8, 8, 728) 536536 block8_sepconv3_act[0][0] __________________________________________________________________________________________________ block8_sepconv3_bn (BatchNormal (None, 8, 8, 728) 2912 block8_sepconv3[0][0] __________________________________________________________________________________________________ add_6 (Add) (None, 8, 8, 728) 0 block8_sepconv3_bn[0][0] add_5[0][0] __________________________________________________________________________________________________ block9_sepconv1_act (Activation (None, 8, 8, 728) 0 add_6[0][0] __________________________________________________________________________________________________ block9_sepconv1 (SeparableConv2 (None, 8, 8, 728) 536536 block9_sepconv1_act[0][0] __________________________________________________________________________________________________ block9_sepconv1_bn (BatchNormal (None, 8, 8, 728) 2912 block9_sepconv1[0][0] __________________________________________________________________________________________________ block9_sepconv2_act (Activation (None, 8, 8, 728) 0 block9_sepconv1_bn[0][0] 
__________________________________________________________________________________________________ block9_sepconv2 (SeparableConv2 (None, 8, 8, 728) 536536 block9_sepconv2_act[0][0] __________________________________________________________________________________________________ block9_sepconv2_bn (BatchNormal (None, 8, 8, 728) 2912 block9_sepconv2[0][0] __________________________________________________________________________________________________ block9_sepconv3_act (Activation (None, 8, 8, 728) 0 block9_sepconv2_bn[0][0] __________________________________________________________________________________________________ block9_sepconv3 (SeparableConv2 (None, 8, 8, 728) 536536 block9_sepconv3_act[0][0] __________________________________________________________________________________________________ block9_sepconv3_bn (BatchNormal (None, 8, 8, 728) 2912 block9_sepconv3[0][0] __________________________________________________________________________________________________ add_7 (Add) (None, 8, 8, 728) 0 block9_sepconv3_bn[0][0] add_6[0][0] __________________________________________________________________________________________________ block10_sepconv1_act (Activatio (None, 8, 8, 728) 0 add_7[0][0] __________________________________________________________________________________________________ block10_sepconv1 (SeparableConv (None, 8, 8, 728) 536536 block10_sepconv1_act[0][0] __________________________________________________________________________________________________ block10_sepconv1_bn (BatchNorma (None, 8, 8, 728) 2912 block10_sepconv1[0][0] __________________________________________________________________________________________________ block10_sepconv2_act (Activatio (None, 8, 8, 728) 0 block10_sepconv1_bn[0][0] __________________________________________________________________________________________________ block10_sepconv2 (SeparableConv (None, 8, 8, 728) 536536 block10_sepconv2_act[0][0] __________________________________________________________________________________________________ block10_sepconv2_bn (BatchNorma (None, 8, 8, 728) 2912 block10_sepconv2[0][0] __________________________________________________________________________________________________ block10_sepconv3_act (Activatio (None, 8, 8, 728) 0 block10_sepconv2_bn[0][0] __________________________________________________________________________________________________ block10_sepconv3 (SeparableConv (None, 8, 8, 728) 536536 block10_sepconv3_act[0][0] __________________________________________________________________________________________________ block10_sepconv3_bn (BatchNorma (None, 8, 8, 728) 2912 block10_sepconv3[0][0] __________________________________________________________________________________________________ add_8 (Add) (None, 8, 8, 728) 0 block10_sepconv3_bn[0][0] add_7[0][0] __________________________________________________________________________________________________ block11_sepconv1_act (Activatio (None, 8, 8, 728) 0 add_8[0][0] __________________________________________________________________________________________________ block11_sepconv1 (SeparableConv (None, 8, 8, 728) 536536 block11_sepconv1_act[0][0] __________________________________________________________________________________________________ block11_sepconv1_bn (BatchNorma (None, 8, 8, 728) 2912 block11_sepconv1[0][0] __________________________________________________________________________________________________ block11_sepconv2_act (Activatio (None, 8, 8, 728) 0 block11_sepconv1_bn[0][0] 
__________________________________________________________________________________________________ block11_sepconv2 (SeparableConv (None, 8, 8, 728) 536536 block11_sepconv2_act[0][0] __________________________________________________________________________________________________ block11_sepconv2_bn (BatchNorma (None, 8, 8, 728) 2912 block11_sepconv2[0][0] __________________________________________________________________________________________________ block11_sepconv3_act (Activatio (None, 8, 8, 728) 0 block11_sepconv2_bn[0][0] __________________________________________________________________________________________________ block11_sepconv3 (SeparableConv (None, 8, 8, 728) 536536 block11_sepconv3_act[0][0] __________________________________________________________________________________________________ block11_sepconv3_bn (BatchNorma (None, 8, 8, 728) 2912 block11_sepconv3[0][0] __________________________________________________________________________________________________ add_9 (Add) (None, 8, 8, 728) 0 block11_sepconv3_bn[0][0] add_8[0][0] __________________________________________________________________________________________________ block12_sepconv1_act (Activatio (None, 8, 8, 728) 0 add_9[0][0] __________________________________________________________________________________________________ block12_sepconv1 (SeparableConv (None, 8, 8, 728) 536536 block12_sepconv1_act[0][0] __________________________________________________________________________________________________ block12_sepconv1_bn (BatchNorma (None, 8, 8, 728) 2912 block12_sepconv1[0][0] __________________________________________________________________________________________________ block12_sepconv2_act (Activatio (None, 8, 8, 728) 0 block12_sepconv1_bn[0][0] __________________________________________________________________________________________________ block12_sepconv2 (SeparableConv (None, 8, 8, 728) 536536 block12_sepconv2_act[0][0] __________________________________________________________________________________________________ block12_sepconv2_bn (BatchNorma (None, 8, 8, 728) 2912 block12_sepconv2[0][0] __________________________________________________________________________________________________ block12_sepconv3_act (Activatio (None, 8, 8, 728) 0 block12_sepconv2_bn[0][0] __________________________________________________________________________________________________ block12_sepconv3 (SeparableConv (None, 8, 8, 728) 536536 block12_sepconv3_act[0][0] __________________________________________________________________________________________________ block12_sepconv3_bn (BatchNorma (None, 8, 8, 728) 2912 block12_sepconv3[0][0] __________________________________________________________________________________________________ add_10 (Add) (None, 8, 8, 728) 0 block12_sepconv3_bn[0][0] add_9[0][0] __________________________________________________________________________________________________ block13_sepconv1_act (Activatio (None, 8, 8, 728) 0 add_10[0][0] __________________________________________________________________________________________________ block13_sepconv1 (SeparableConv (None, 8, 8, 728) 536536 block13_sepconv1_act[0][0] __________________________________________________________________________________________________ block13_sepconv1_bn (BatchNorma (None, 8, 8, 728) 2912 block13_sepconv1[0][0] __________________________________________________________________________________________________ block13_sepconv2_act (Activatio (None, 8, 8, 728) 0 block13_sepconv1_bn[0][0] 
__________________________________________________________________________________________________ block13_sepconv2 (SeparableConv (None, 8, 8, 1024) 752024 block13_sepconv2_act[0][0] __________________________________________________________________________________________________ block13_sepconv2_bn (BatchNorma (None, 8, 8, 1024) 4096 block13_sepconv2[0][0] __________________________________________________________________________________________________ conv2d_3 (Conv2D) (None, 4, 4, 1024) 745472 add_10[0][0] __________________________________________________________________________________________________ block13_pool (MaxPooling2D) (None, 4, 4, 1024) 0 block13_sepconv2_bn[0][0] __________________________________________________________________________________________________ batch_normalization_3 (BatchNor (None, 4, 4, 1024) 4096 conv2d_3[0][0] __________________________________________________________________________________________________ add_11 (Add) (None, 4, 4, 1024) 0 block13_pool[0][0] batch_normalization_3[0][0] __________________________________________________________________________________________________ block14_sepconv1 (SeparableConv (None, 4, 4, 1536) 1582080 add_11[0][0] __________________________________________________________________________________________________ block14_sepconv1_bn (BatchNorma (None, 4, 4, 1536) 6144 block14_sepconv1[0][0] __________________________________________________________________________________________________ block14_sepconv1_act (Activatio (None, 4, 4, 1536) 0 block14_sepconv1_bn[0][0] __________________________________________________________________________________________________ block14_sepconv2 (SeparableConv (None, 4, 4, 2048) 3159552 block14_sepconv1_act[0][0] __________________________________________________________________________________________________ block14_sepconv2_bn (BatchNorma (None, 4, 4, 2048) 8192 block14_sepconv2[0][0] __________________________________________________________________________________________________ block14_sepconv2_act (Activatio (None, 4, 4, 2048) 0 block14_sepconv2_bn[0][0] ================================================================================================== Total params: 20,861,480 Trainable params: 20,806,952 Non-trainable params: 54,528 __________________________________________________________________________________________________
We see that there are new layer types: separable convolutions and batch normalization. Separable convolutions are a special case of the usual convolution that allows for more efficient computation by restricting to specially structured filters. Batch normalization rescales layer outputs. The more important observation is the output shape: 4x4x2048. That is, we obtain 2048 feature maps, each of size 4x4. This is where we connect our decision stack.
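The efficiency gain of separable convolutions can be made concrete by comparing parameter counts (a small sketch; the input shape is chosen arbitrarily):
import tensorflow.keras as keras
inputs = keras.layers.Input(shape=(61, 61, 32))
# usual convolution: 3*3*32*64 weights + 64 biases = 18496 parameters
usual = keras.models.Model(inputs, keras.layers.Conv2D(64, 3)(inputs))
# separable: 3*3*32 depthwise + 32*64 pointwise + 64 biases = 2400 parameters
separable = keras.models.Model(inputs, keras.layers.SeparableConv2D(64, 3)(inputs))
print(usual.count_params(), separable.count_params())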
Models in Keras behave like layers (the Model class inherits from Layer). Thus, we may create a new model with the pre-trained convolutional base as one layer.
model = keras.models.Sequential()
model.add(conv_base)
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10, activation='relu', name='dense1'))
model.add(keras.layers.Dense(10, activation='relu', name='dense2'))
model.add(keras.layers.Dense(2, activation='sigmoid', name='out'))
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= xception (Model) (None, 4, 4, 2048) 20861480 _________________________________________________________________ flatten_1 (Flatten) (None, 32768) 0 _________________________________________________________________ dense1 (Dense) (None, 10) 327690 _________________________________________________________________ dense2 (Dense) (None, 10) 110 _________________________________________________________________ out (Dense) (None, 2) 22 ================================================================= Total params: 21,189,302 Trainable params: 327,822 Non-trainable params: 20,861,480 _________________________________________________________________
Before we start training we have to tell Keras to keep the weights of the convolutional base constant. We simply have to set the layer's trainable attribute to False:
model.get_layer('xception').trainable = False
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= xception (Model) (None, 4, 4, 2048) 20861480 _________________________________________________________________ flatten_1 (Flatten) (None, 32768) 0 _________________________________________________________________ dense1 (Dense) (None, 10) 327690 _________________________________________________________________ dense2 (Dense) (None, 10) 110 _________________________________________________________________ out (Dense) (None, 2) 22 ================================================================= Total params: 21,189,302 Trainable params: 327,822 Non-trainable params: 20,861,480 _________________________________________________________________
For training we use Keras' default optimizer, RMSprop.
model.compile(loss='mean_squared_error', metrics=['categorical_accuracy'])
To speed up training we would like to have all data in memory. Images have $128^2=16384$ pixels, each taking 3 bytes for the colors (one byte per channel) if color values are integers. For colors scaled to $[0,1]$ we need 4 bytes per channel with np.float32 as data type. Thus, we need 196608 bytes per image, say 200 kB. That is 5 images per MB or 5000 images per GB. Our data set has 25000 images, and we could increase it to arbitrary size by data augmentation. Note that data augmentation is only useful for training data. Validation and test data should not be augmented. To save memory we do augmentation in real time, that is, we only keep original training images in memory and generate batches of augmented images as needed.
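This estimate is easy to verify in code:
import numpy as np
img = np.zeros((128, 128, 3), dtype=np.float32)  # one image scaled to [0, 1]
print(img.nbytes)                  # 196608 bytes per image
print(25000 * img.nbytes / 10**9)  # about 4.9 GB for all 25000 images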
To implement data augmentation we simply have to pass corresponding arguments to ImageDataGenerator.
Since we do not want to augment validation and test images, we use a two-step approach. We first load all images into memory. Then we use a second ImageDataGenerator object to create an iterator yielding augmented training images.
img_generator = keras.preprocessing.image.ImageDataGenerator(
validation_split=0.25,
dtype=np.float32
)
orig_train_iterator = img_generator.flow_from_directory(
data_path + 'labeled/train',
subset='training',
target_size=(img_size, img_size),
batch_size=15000
)
val_iterator = img_generator.flow_from_directory(
data_path + 'labeled/train',
subset='validation',
target_size=(img_size, img_size),
batch_size=5000
)
test_iterator = img_generator.flow_from_directory(
data_path + 'labeled/test',
target_size=(img_size, img_size),
batch_size=5000
)
orig_train_images, orig_train_labels = next(orig_train_iterator)
val_images, val_labels = next(val_iterator)
test_images, test_labels = next(test_iterator)
Found 15000 images belonging to 2 classes.
Found 5000 images belonging to 2 classes.
Found 5000 images belonging to 2 classes.
When using pre-trained models, data preprocessing has to be done in exactly the same way as during the original training. For each pre-trained model in Keras there is a preprocess_input function doing the necessary preprocessing. If images are provided to Model.fit by an iterator, we have to tell the iterator to apply the preprocessing function before yielding an image. For this purpose ImageDataGenerator accepts the preprocessing_function argument.
The ImageDataGenerator.flow function streams images from memory while augmenting them.
aug_img_generator = keras.preprocessing.image.ImageDataGenerator(
dtype=np.float32,
rotation_range=10,
horizontal_flip=True,
preprocessing_function=keras.applications.xception.preprocess_input
)
train_iterator = aug_img_generator.flow(
x=orig_train_images,
y=orig_train_labels,
batch_size=32
)
val_images = keras.applications.xception.preprocess_input(val_images)
test_images = keras.applications.xception.preprocess_input(test_images)
Now training can be started. Since augmentation yields an infinite stream of training data, we have to tell fit the length of an epoch by providing the number of batches per epoch via the steps_per_epoch argument.
# collect metrics over possibly repeated calls of fit
loss = []
val_loss = []
acc = []
val_acc = []
history = model.fit(
train_iterator,
epochs=20,
steps_per_epoch=100,
validation_data=(val_images, val_labels)
)
loss.extend(history.history['loss'])
val_loss.extend(history.history['val_loss'])
acc.extend(history.history['categorical_accuracy'])
val_acc.extend(history.history['val_categorical_accuracy'])
Epoch 1/20 100/100 [==============================] - 17s 168ms/step - loss: 0.0321 - categorical_accuracy: 0.9588 - val_loss: 0.0316 - val_categorical_accuracy: 0.9564 Epoch 2/20 100/100 [==============================] - 17s 172ms/step - loss: 0.0305 - categorical_accuracy: 0.9603 - val_loss: 0.0364 - val_categorical_accuracy: 0.9518 Epoch 3/20 100/100 [==============================] - 16s 165ms/step - loss: 0.0294 - categorical_accuracy: 0.9647 - val_loss: 0.0302 - val_categorical_accuracy: 0.9600 Epoch 4/20 100/100 [==============================] - 17s 172ms/step - loss: 0.0293 - categorical_accuracy: 0.9597 - val_loss: 0.0303 - val_categorical_accuracy: 0.9592 Epoch 5/20 100/100 [==============================] - 16s 164ms/step - loss: 0.0305 - categorical_accuracy: 0.9613 - val_loss: 0.0371 - val_categorical_accuracy: 0.9526 Epoch 6/20 100/100 [==============================] - 17s 168ms/step - loss: 0.0329 - categorical_accuracy: 0.9581 - val_loss: 0.0324 - val_categorical_accuracy: 0.9560 Epoch 7/20 100/100 [==============================] - 17s 168ms/step - loss: 0.0331 - categorical_accuracy: 0.9568 - val_loss: 0.0314 - val_categorical_accuracy: 0.9590 Epoch 8/20 100/100 [==============================] - 17s 170ms/step - loss: 0.0280 - categorical_accuracy: 0.9641 - val_loss: 0.0330 - val_categorical_accuracy: 0.9554 Epoch 9/20 100/100 [==============================] - 17s 166ms/step - loss: 0.0304 - categorical_accuracy: 0.9594 - val_loss: 0.0317 - val_categorical_accuracy: 0.9600 Epoch 10/20 100/100 [==============================] - 17s 168ms/step - loss: 0.0335 - categorical_accuracy: 0.9541 - val_loss: 0.0304 - val_categorical_accuracy: 0.9562 Epoch 11/20 100/100 [==============================] - 16s 165ms/step - loss: 0.0264 - categorical_accuracy: 0.9644 - val_loss: 0.0312 - val_categorical_accuracy: 0.9572 Epoch 12/20 100/100 [==============================] - 16s 164ms/step - loss: 0.0294 - categorical_accuracy: 0.9616 - val_loss: 0.0320 - val_categorical_accuracy: 0.9574 Epoch 13/20 100/100 [==============================] - 17s 169ms/step - loss: 0.0306 - categorical_accuracy: 0.9606 - val_loss: 0.0324 - val_categorical_accuracy: 0.9554 Epoch 14/20 100/100 [==============================] - 17s 168ms/step - loss: 0.0295 - categorical_accuracy: 0.9575 - val_loss: 0.0302 - val_categorical_accuracy: 0.9608 Epoch 15/20 100/100 [==============================] - 18s 176ms/step - loss: 0.0294 - categorical_accuracy: 0.9615 - val_loss: 0.0311 - val_categorical_accuracy: 0.9580 Epoch 16/20 100/100 [==============================] - 17s 170ms/step - loss: 0.0315 - categorical_accuracy: 0.9569 - val_loss: 0.0305 - val_categorical_accuracy: 0.9604 Epoch 17/20 100/100 [==============================] - 17s 165ms/step - loss: 0.0282 - categorical_accuracy: 0.9624 - val_loss: 0.0335 - val_categorical_accuracy: 0.9552 Epoch 18/20 100/100 [==============================] - 17s 166ms/step - loss: 0.0275 - categorical_accuracy: 0.9650 - val_loss: 0.0316 - val_categorical_accuracy: 0.9576 Epoch 19/20 100/100 [==============================] - 16s 165ms/step - loss: 0.0264 - categorical_accuracy: 0.9681 - val_loss: 0.0330 - val_categorical_accuracy: 0.9576 Epoch 20/20 100/100 [==============================] - 17s 166ms/step - loss: 0.0298 - categorical_accuracy: 0.9574 - val_loss: 0.0334 - val_categorical_accuracy: 0.9568
fig, ax = plt.subplots()
ax.plot(loss, '-b', label='training loss')
ax.plot(val_loss, '-r', label='validation loss')
ax.legend()
plt.show()
fig, ax = plt.subplots()
ax.plot(acc, '-b', label='training accuracy')
ax.plot(val_acc, '-r', label='validation accuracy')
ax.legend()
plt.show()
model.save('anncnnimprove/model')
INFO:tensorflow:Assets written to: anncnnimprove/model/assets
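The saved model can be restored later without retraining, for instance in a separate notebook; a minimal sketch:
import tensorflow.keras as keras
# load the model saved above; predictions work as before
restored_model = keras.models.load_model('anncnnimprove/model')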
test_loss, test_metric = model.evaluate(x=test_images, y=test_labels)
157/157 [==============================] - 3s 21ms/step - loss: 0.0390 - categorical_accuracy: 0.9512
print(test_metric)
0.951200008392334