Tutorial on Image Augmentation Using Keras Preprocessing Layers

When working on an image-related machine learning problem, collecting images as training data is often not enough; we also use augmentation to create variations of those images. This is particularly true for more difficult object recognition problems.

Image augmentation can be accomplished in a variety of ways. You can use external libraries or write your own functions. TensorFlow and Keras also provide some augmentation modules. This post will show you how to use the Keras preprocessing layers and the tf.image module in TensorFlow for image augmentation.

You will understand the following after reading this post:

  1. What are Keras preprocessing layers and how do they work?
  2. What image augmentation functions does the tf.image module provide?
  3. How to use augmentation in conjunction with the tf.data dataset

Overview

This article is divided into five sections, which are as follows:

  1. Obtaining Images
  2. Visualizing Images
  3. Keras Preprocessing Layers
  4. Using tf.image API for Augmentation
  5. Using Preprocessing Layers in Neural Networks

 

  1. Obtaining Images

We need some images before we can see how augmentation works. Ultimately, the images must be represented as arrays, such as HxWx3 arrays of 8-bit integers for the RGB pixel values. There are numerous ways to obtain images. Some are available as ZIP files. If you’re using TensorFlow, you can get some image datasets from the tensorflow_datasets library.

We will use the citrus leaves images in this tutorial, which is a small dataset of less than 100MB. It is available for download from tensorflow_datasets as follows:

import tensorflow_datasets as tfds
ds, meta = tfds.load('citrus_leaves', with_info=True, split='train', shuffle_files=True)

When you run this code for the first time, it will download the image dataset into your computer and produce the following output:

Downloading and preparing dataset 63.87 MiB (download: 63.87 MiB, generated: 37.89 MiB, total: 101.76 MiB) to ~/tensorflow_datasets/citrus_leaves/0.1.2...
Extraction completed...: 100%|██████████████████████████████| 1/1 [00:06<00:00, 6.54s/ file]
Dl Size...: 100%|██████████████████████████████████████████| 63/63 [00:06<00:00, 9.63 MiB/s]
Dl Completed...: 100%|███████████████████████████████████████| 1/1 [00:06<00:00, 6.54s/ url]
Dataset citrus_leaves downloaded and prepared to ~/tensorflow_datasets/citrus_leaves/0.1.2. Subsequent calls will reuse this data.

The function above returns the images as a tf.data dataset object, together with the metadata. This is a classification dataset. We can print the names of the class labels as follows:

...
for i in range(meta.features['label'].num_classes):
    print(meta.features['label'].int2str(i))

and this prints:

Black spot
canker
greening
healthy

Another way to load the downloaded images into a tf.data dataset is to use the image_dataset_from_directory() function.

The dataset is downloaded into the directory ~/tensorflow_datasets as shown in the screen output above. When you look at the directory, you will notice the following directory structure:

.../Citrus/Leaves
├── Black spot
├── Melanose
├── canker
├── greening
└── healthy

The directories are the labels, and the images are files stored in their respective directories. We can instruct the function to read the directory into a dataset recursively:

import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory

# set to fixed image size 256x256
PATH = ".../Citrus/Leaves"
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="bilinear",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

If you do not want the dataset to be batched, set batch_size=None. Typically, we want the dataset to be batched before training a neural network model.
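
How subdirectory names turn into labels can be sketched with the standard library alone. The snippet below is only an illustration: it builds a temporary directory with empty class folders (a hypothetical stand-in for the real download) and assumes that class names are the subdirectory names in sorted order, with each integer label being the position in that list.

```python
import os
import tempfile

# Hypothetical stand-in for the ".../Citrus/Leaves" directory: empty class folders only
class_dirs = ["healthy", "canker", "Black spot", "greening", "Melanose"]

with tempfile.TemporaryDirectory() as root:
    for name in class_dirs:
        os.makedirs(os.path.join(root, name))

    # Assumption: class names are the subdirectory names in sorted order,
    # and each integer label is the index of the class in that list
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    label_of = {name: index for index, name in enumerate(class_names)}

print(class_names)
print(label_of["canker"])
```

Under that assumption, the order matches the directory listing shown earlier, with the capitalized names sorting first.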

  2. Visualizing Images

It is critical to visualize the augmentation result so that we can confirm that it is what we want. Matplotlib can be used for this.

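The imshow() function in matplotlib is used to display an image. For an RGB image to be displayed correctly, however, it must be either an array of 8-bit unsigned integers (uint8) in the range 0 to 255 or an array of floats in the range 0 to 1. The dataset yields floats in the range 0 to 255, so we convert them to uint8 before display.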
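
To make the data type requirement concrete, here is a small sketch using NumPy only. The 2x2 array is a made-up example image stored as float32 in the 0 to 255 range; the two conversions shown are the usual ways to get it into a form that imshow() treats as RGB data.

```python
import numpy as np

# Hypothetical 2x2 RGB image stored as float32 in the 0-255 range
float_image = np.array(
    [[[0.0, 127.6, 255.0], [12.3, 45.6, 78.9]],
     [[255.0, 255.0, 255.0], [0.0, 0.0, 0.0]]],
    dtype=np.float32,
)

# Option 1: cast to uint8 (the fractional part is truncated)
as_uint8 = float_image.astype("uint8")

# Option 2: keep floats but scale the values into the range 0 to 1
as_unit_float = float_image / 255.0

print(as_uint8.dtype, as_uint8[0, 0])
print(as_unit_float.min(), as_unit_float.max())
```

The examples in this post use the first option because it is a single call on the array returned by numpy().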

Given that we’ve created a dataset with image_dataset_from_directory(), we can get the first batch (of 32 images) and display a few of them with imshow(), as shown below:

...
import matplotlib.pyplot as plt

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(5,5))

for images, labels in ds.take(1):
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(images[i*3+j].numpy().astype("uint8"))
            ax[i][j].set_title(ds.class_names[labels[i*3+j]])
plt.show()

We display 9 images in a grid and use ds.class_names to label each image with its classification label. For display, the images are converted to NumPy arrays of uint8. This code generates the following image:

The entire code from image loading to display is shown below.

from tensorflow.keras.utils import image_dataset_from_directory
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves' # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

# Take one batch from dataset and display the images
fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(5,5))

for images, labels in ds.take(1):
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(images[i*3+j].numpy().astype("uint8"))
            ax[i][j].set_title(ds.class_names[labels[i*3+j]])
plt.show()

If you use tensorflow_datasets to get the image, the samples are presented as a dictionary rather than a tuple of values (image, label). You should modify your code slightly to look like this:

import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# use tfds.load() or image_dataset_from_directory() to load images
ds, meta = tfds.load('citrus_leaves', with_info=True, split='train', shuffle_files=True)
ds = ds.batch(32)

# Take one batch from dataset and display the images
fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(5,5))

for sample in ds.take(1):
    images, labels = sample["image"], sample["label"]
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(images[i*3+j].numpy().astype("uint8"))
            ax[i][j].set_title(meta.features['label'].int2str(labels[i*3+j]))
plt.show()

In the following sections, we will assume that the dataset was created with image_dataset_from_directory(). If your dataset was created differently, you may need to modify the code slightly.
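
One possible modification is to convert the dictionary samples into (image, label) tuples once, so that the rest of the code can stay unchanged. The sketch below uses a small synthetic dataset as a stand-in for what tfds.load() returns; the helper name to_tuple and the tensor sizes are made up for illustration.

```python
import tensorflow as tf

# Hypothetical stand-in for a tfds dataset: each sample is a dict with
# "image" and "label" keys (synthetic data, not the citrus leaves images)
dict_ds = tf.data.Dataset.from_tensor_slices({
    "image": tf.zeros([4, 8, 8, 3], dtype=tf.uint8),
    "label": tf.constant([0, 1, 2, 3], dtype=tf.int64),
})

def to_tuple(sample):
    # Cast to float32 so the images have the same dtype as those
    # produced by image_dataset_from_directory()
    return tf.cast(sample["image"], tf.float32), sample["label"]

tuple_ds = dict_ds.map(to_tuple).batch(2)

for images, labels in tuple_ds.take(1):
    print(images.shape, images.dtype, labels.numpy())
```

Note that real images in a tfds dataset may differ in size, in which case they would need to be resized before batching.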

  3. Keras Preprocessing Layers

Keras includes many neural network layers, such as convolution layers, whose parameters are trained. There are also layers with no trainable parameters, such as flatten layers, which convert an array, such as an image, into a vector.

Keras’ preprocessing layers are specifically designed for use in the early stages of a neural network. They can be used for image preprocessing, such as resizing or rotating images, or adjusting brightness and contrast. While preprocessing layers are intended to be part of a larger neural network, they can also be used as functions. The following is an example of how we can use the resizing layer as a function to transform some images and display them alongside the original:

...

# create a resizing layer
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)

# show original vs resized
fig, ax = plt.subplots(2, 3, figsize=(6,4))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        ax[1][i].imshow(resize(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
plt.show()

Our images are 256×256 pixels in size, and the resizing layer will reduce them to 256×128 pixels (256 wide by 128 high). The following is the output of the above code:

Because the resizing layer can be called like a function, we can apply it to the dataset with map(). As an example,

...
def augment(image, label):
    return resize(image), label

resized_ds = ds.map(augment)

for image, label in resized_ds:
    ...

The dataset ds contains samples in the form of (image, label). As a result, we wrote a function that accepts such a tuple and preprocesses the image with the resizing layer. We passed this function as the argument to the dataset's map() method. The image is transformed when we draw a sample from the new dataset created with map().
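
To see the effect of map() without downloading any images, the same pattern can be tried on a synthetic dataset. This is a minimal sketch: demo_ds and its random tensors are made up here as a stand-in for ds, and only the batch shapes are inspected.

```python
import tensorflow as tf

# Synthetic stand-in for ds: 6 random "images" of 256x256x3 with integer labels
images = tf.random.uniform([6, 256, 256, 3], maxval=255.0)
labels = tf.constant([0, 1, 2, 3, 4, 0])
demo_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(3)

resize = tf.keras.layers.Resizing(128, 256)

def augment(image, label):
    # Only the image is transformed; the label passes through untouched
    return resize(image), label

resized_ds = demo_ds.map(augment)

for image_batch, label_batch in resized_ds:
    print(image_batch.shape, label_batch.shape)
```

Each batch drawn from resized_ds should have height 128 and width 256, while demo_ds itself still yields 256×256 images.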

More preprocessing layers are available. Some examples are provided below.

We can resize the image, as we saw above. We can also randomly stretch or shrink the height or width of an image, and we can zoom in or out on it. The following is an example of how to manipulate the image size in various ways, with a maximum increase or decrease of 30%:

...

# Create preprocessing layers
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)
height = tf.keras.layers.RandomHeight(0.3)
width = tf.keras.layers.RandomWidth(0.3)
zoom = tf.keras.layers.RandomZoom(0.3)

# Visualize images and augmentations
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        ax[1][i].imshow(resize(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # height
        ax[2][i].imshow(height(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("height")
        # width
        ax[3][i].imshow(width(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("width")
        # zoom
        ax[4][i].imshow(zoom(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("zoom")
plt.show()

This code displays the following images:

While resize has a fixed dimension, other augmentations have a random amount of manipulation.

Using preprocessing layers, we can also perform flipping, rotation, cropping, and geometric translation:

...
# Create preprocessing layers
flip = tf.keras.layers.RandomFlip("horizontal_and_vertical") # or "horizontal", "vertical"
rotate = tf.keras.layers.RandomRotation(0.2)
crop = tf.keras.layers.RandomCrop(out_height, out_width)
translation = tf.keras.layers.RandomTranslation(height_factor=0.2, width_factor=0.2)

# Visualize augmentations
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        ax[1][i].imshow(flip(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("flip")
        # crop
        ax[2][i].imshow(crop(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # translation
        ax[3][i].imshow(translation(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("translation")
        # rotate
        ax[4][i].imshow(rotate(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("rotate")
plt.show()

The following images are displayed by this code:

Finally, we can do augmentations on color adjustments as well:

...
brightness = tf.keras.layers.RandomBrightness([-0.8,0.8])
contrast = tf.keras.layers.RandomContrast(0.2)

# Visualize augmentation
fig, ax = plt.subplots(3, 3, figsize=(6,7))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        ax[1][i].imshow(brightness(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(contrast(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
plt.show()

This displays the following images:

To be complete, here is the code for displaying the results of various augmentations:

from tensorflow.keras.utils import image_dataset_from_directory
import tensorflow as tf
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves' # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

# Create preprocessing layers
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)
height = tf.keras.layers.RandomHeight(0.3)
width = tf.keras.layers.RandomWidth(0.3)
zoom = tf.keras.layers.RandomZoom(0.3)

flip = tf.keras.layers.RandomFlip("horizontal_and_vertical")
rotate = tf.keras.layers.RandomRotation(0.2)
crop = tf.keras.layers.RandomCrop(out_height, out_width)
translation = tf.keras.layers.RandomTranslation(height_factor=0.2, width_factor=0.2)

brightness = tf.keras.layers.RandomBrightness([-0.8,0.8])
contrast = tf.keras.layers.RandomContrast(0.2)

# Visualize images and augmentations
fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        ax[1][i].imshow(resize(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # height
        ax[2][i].imshow(height(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("height")
        # width
        ax[3][i].imshow(width(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("width")
        # zoom
        ax[4][i].imshow(zoom(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("zoom")
plt.show()

fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        ax[1][i].imshow(flip(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("flip")
        # crop
        ax[2][i].imshow(crop(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # translation
        ax[3][i].imshow(translation(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("translation")
        # rotate
        ax[4][i].imshow(rotate(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("rotate")
plt.show()

fig, ax = plt.subplots(3, 3, figsize=(6,7))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        ax[1][i].imshow(brightness(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(contrast(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
plt.show()

Finally, it is worth noting that most neural network models perform better when the input images are scaled. While we typically use an 8-bit unsigned integer for the pixel values in an image (for example, for display using imshow()), neural networks prefer the pixel values to be between 0 and 1, or between -1 and +1. This is also possible with preprocessing layers. Here’s how we can update one of our previous examples to include the scaling layer in the augmentation:

...
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)
rescale = tf.keras.layers.Rescaling(1/127.5, offset=-1) # rescale pixel values to [-1,1]

def augment(image, label):
    return rescale(resize(image)), label

rescaled_resized_ds = ds.map(augment)

for image, label in rescaled_resized_ds:
    ...
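
The arithmetic behind the rescaling layer is simply output = input * scale + offset. As a quick check, the sketch below feeds three sample pixel values (the minimum, the midpoint, and the maximum of the 8-bit range) through two Rescaling layers; the variable names are chosen here for illustration.

```python
import numpy as np
import tensorflow as tf

pixels = tf.constant([0.0, 127.5, 255.0])

# Map the range 0-255 to [-1, 1]: x / 127.5 - 1
signed_rescale = tf.keras.layers.Rescaling(1/127.5, offset=-1)
to_signed = np.asarray(signed_rescale(pixels))

# Map the range 0-255 to [0, 1]: x / 255
unit_rescale = tf.keras.layers.Rescaling(1/255)
to_unit = np.asarray(unit_rescale(pixels))

print(to_signed)
print(to_unit)
```

Which of the two ranges to use depends on what the rest of the network expects; the example in this post uses [-1, 1].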

  4. Using tf.image API for Augmentation

Aside from the preprocessing layers, the tf.image module also includes some augmentation functions. Unlike the preprocessing layers, these functions are intended to be used in a user-defined function that is applied to a dataset using map(), as we saw above.

The functions provided by tf.image are not duplicates of the preprocessing layers, although there is some overlap. The following is an example of how to use the tf.image functions to resize and crop images:

...

fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        # original
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        h = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        w = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        ax[1][i].imshow(tf.image.resize(images[i], [h,w]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # crop
        y, x, h, w = (128 * tf.random.uniform((4,))).numpy().astype("uint8")
        ax[2][i].imshow(tf.image.crop_to_bounding_box(images[i], y, x, h, w).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # central crop
        x = tf.random.uniform([], minval=0.4, maxval=1.0)
        ax[3][i].imshow(tf.image.central_crop(images[i], x).numpy().astype("uint8"))
        ax[3][i].set_title("central crop")
        # crop to (h,w) at random offset
        h, w = (256 * tf.random.uniform((2,))).numpy().astype("uint8")
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[4][i].imshow(tf.image.stateless_random_crop(images[i], [h,w,3], seed).numpy().astype("uint8"))
        ax[4][i].set_title("random crop")
plt.show()

The following is the output of the preceding code:

While the image display matches what we would expect from the code, the use of tf.image functions differs significantly from that of the preprocessing layers. Each tf.image function has its own calling convention. For example, the crop_to_bounding_box() function requires pixel coordinates, whereas the central_crop() function requires a fraction as its argument.
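
The difference between a fraction argument and pixel coordinates can be illustrated by producing the same crop both ways. In this sketch the image is random synthetic data, and the pixel offsets are worked out by hand for a 256×256 input: a central fraction of 0.5 keeps the middle 128 pixels in each dimension, which starts at offset 64.

```python
import tensorflow as tf

# Synthetic 256x256 RGB image as a stand-in for one sample of the dataset
image = tf.random.uniform([256, 256, 3], maxval=255.0)

# Fraction-based: keep the central 50% along height and width
by_fraction = tf.image.central_crop(image, 0.5)

# Pixel-based: offset_height, offset_width, target_height, target_width
by_pixels = tf.image.crop_to_bounding_box(image, 64, 64, 128, 128)

same_crop = bool(tf.reduce_all(by_fraction == by_pixels))
print(by_fraction.shape, by_pixels.shape, same_crop)
```

The two calls describe the same region here only because the offsets were computed for this particular image size.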

These functions also differ in how they handle randomness. Some of them are not random at all. For example, before calling the resize function, the exact output size must be generated separately with a random number generator. Other functions, such as stateless_random_crop(), can perform augmentation at random, but a seed made of a pair of int32 values must be specified explicitly.
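
A consequence of the explicit seed is that a stateless function returns the same result whenever it is called with the same seed. The following sketch demonstrates this with stateless_random_crop() on a small synthetic image; the seed values are arbitrary.

```python
import tensorflow as tf

# Small synthetic image whose pixel values are all distinct
image = tf.reshape(tf.range(16 * 16 * 3, dtype=tf.float32), [16, 16, 3])

seed = (1, 2)  # arbitrary pair of integers used as the seed
crop_a = tf.image.stateless_random_crop(image, size=[8, 8, 3], seed=seed)
crop_b = tf.image.stateless_random_crop(image, size=[8, 8, 3], seed=seed)

# The same seed reproduces exactly the same crop
same_seed_equal = bool(tf.reduce_all(crop_a == crop_b))
print(crop_a.shape, same_seed_equal)
```

This is why the examples in this post draw a fresh seed with tf.random.uniform() before each call: reusing one fixed seed would apply an identical "random" transformation every time.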

To continue the example, the following functions are available for flipping an image and extracting the Sobel edges:

...
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_flip_left_right(images[i], seed).numpy().astype("uint8"))
        ax[1][i].set_title("flip left-right")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[2][i].imshow(tf.image.stateless_random_flip_up_down(images[i], seed).numpy().astype("uint8"))
        ax[2][i].set_title("flip up-down")
        # sobel edge
        sobel = tf.image.sobel_edges(images[i:i+1])
        ax[3][i].imshow(sobel[0, ..., 0].numpy().astype("uint8"))
        ax[3][i].set_title("sobel y")
        # sobel edge
        ax[4][i].imshow(sobel[0, ..., 1].numpy().astype("uint8"))
        ax[4][i].set_title("sobel x")
plt.show()

which demonstrates the following:

And here are the functions for adjusting the brightness, contrast, and colors:

...
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_brightness(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(tf.image.stateless_random_contrast(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
        # saturation
        ax[3][i].imshow(tf.image.stateless_random_saturation(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[3][i].set_title("saturation")
        # hue
        ax[4][i].imshow(tf.image.stateless_random_hue(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[4][i].set_title("hue")
plt.show()

This code demonstrates the following:

The complete code for displaying all of the above is provided below:

from tensorflow.keras.utils import image_dataset_from_directory
import tensorflow as tf
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves' # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

# Visualize tf.image augmentations

fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        # original
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        h = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        w = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        ax[1][i].imshow(tf.image.resize(images[i], [h,w]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # crop
        y, x, h, w = (128 * tf.random.uniform((4,))).numpy().astype("uint8")
        ax[2][i].imshow(tf.image.crop_to_bounding_box(images[i], y, x, h, w).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # central crop
        x = tf.random.uniform([], minval=0.4, maxval=1.0)
        ax[3][i].imshow(tf.image.central_crop(images[i], x).numpy().astype("uint8"))
        ax[3][i].set_title("central crop")
        # crop to (h,w) at random offset
        h, w = (256 * tf.random.uniform((2,))).numpy().astype("uint8")
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[4][i].imshow(tf.image.stateless_random_crop(images[i], [h,w,3], seed).numpy().astype("uint8"))
        ax[4][i].set_title("random crop")
plt.show()

fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_flip_left_right(images[i], seed).numpy().astype("uint8"))
        ax[1][i].set_title("flip left-right")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[2][i].imshow(tf.image.stateless_random_flip_up_down(images[i], seed).numpy().astype("uint8"))
        ax[2][i].set_title("flip up-down")
        # sobel edge
        sobel = tf.image.sobel_edges(images[i:i+1])
        ax[3][i].imshow(sobel[0, ..., 0].numpy().astype("uint8"))
        ax[3][i].set_title("sobel y")
        # sobel edge
        ax[4][i].imshow(sobel[0, ..., 1].numpy().astype("uint8"))
        ax[4][i].set_title("sobel x")
plt.show()

fig, ax = plt.subplots(5, 3, figsize=(6,14))
for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_brightness(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(tf.image.stateless_random_contrast(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
        # saturation
        ax[3][i].imshow(tf.image.stateless_random_saturation(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[3][i].set_title("saturation")
        # hue
        ax[4][i].imshow(tf.image.stateless_random_hue(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[4][i].set_title("hue")
plt.show()

These augmentation functions should be sufficient for the majority of users. However, if you have a specific idea for augmentation, you will most likely need a more capable image processing library. OpenCV and Pillow are two well-known and powerful image-transformation libraries.
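
One way to plug a function from such a library into a tf.data pipeline is to wrap it with tf.numpy_function(). The sketch below is an assumption-laden illustration rather than a recipe: rotate90_np is a hypothetical helper that uses plain NumPy as a stand-in for a call into OpenCV or Pillow, and the dataset is synthetic.

```python
import numpy as np
import tensorflow as tf

def rotate90_np(image):
    # Hypothetical helper: plain NumPy standing in for an OpenCV or Pillow call.
    # It receives and returns a NumPy array of the same dtype.
    return np.ascontiguousarray(np.rot90(image))

def augment(image, label):
    rotated = tf.numpy_function(rotate90_np, [image], tf.float32)
    # tf.numpy_function() drops static shape information, so restore the rank
    rotated.set_shape([None, None, 3])
    return rotated, label

# Synthetic unbatched dataset: 4 images of height 32 and width 48
images = tf.random.uniform([4, 32, 48, 3], maxval=255.0)
labels = tf.constant([0, 1, 2, 3])
demo_ds = tf.data.Dataset.from_tensor_slices((images, labels))
rotated_ds = demo_ds.map(augment)

for image, label in rotated_ds.take(1):
    print(image.shape)
```

The wrapped function runs as ordinary Python, so it is usually slower than native TensorFlow operations and is best reserved for transformations that tf.image and the preprocessing layers do not offer.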

  5. Using Preprocessing Layers in Neural Networks

In the preceding examples, we used Keras preprocessing layers as functions. They can, however, also be used as layers in a neural network, and doing so is simple. Here’s an example of incorporating preprocessing layers into a classification network and training it with a dataset:

from tensorflow.keras.utils import image_dataset_from_directory
import tensorflow as tf
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves' # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

AUTOTUNE = tf.data.AUTOTUNE
ds = ds.cache().prefetch(buffer_size=AUTOTUNE)

num_classes = 5
model = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.Rescaling(1/127.5, offset=-1), # rescale pixel values to [-1,1]
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(num_classes)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(ds, epochs=3)

Running this code produces the following results:

Found 609 files belonging to 5 classes.
Using 488 files for training.
Epoch 1/3
16/16 [==============================] - 5s 253ms/step - loss: 1.4114 - accuracy: 0.4283
Epoch 2/3
16/16 [==============================] - 4s 259ms/step - loss: 0.8101 - accuracy: 0.6475
Epoch 3/3
16/16 [==============================] - 4s 267ms/step - loss: 0.7015 - accuracy: 0.7111
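
The numbers in this log can be cross-checked with a little arithmetic. The sketch below assumes that the validation subset takes the split fraction of the files rounded down, and that the last, partially filled batch still counts as a training step; both are assumptions about the library's behavior, although they reproduce the figures above.

```python
import math

total_files = 609        # "Found 609 files belonging to 5 classes."
validation_split = 0.2
batch_size = 32

# Assumption: the validation count is the split fraction of the total, rounded down
num_validation = int(validation_split * total_files)
num_training = total_files - num_validation

# Assumption: a final partial batch still counts as one step
steps_per_epoch = math.ceil(num_training / batch_size)

print(num_training, steps_per_epoch)
```

This gives 488 training files and 16 steps per epoch, matching "Using 488 files for training" and the "16/16" progress counter.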

We created the dataset with cache() and prefetch() in the preceding code. This is a performance optimization that lets the dataset prepare data asynchronously while the neural network is being trained. It matters most when the dataset has other augmentations applied via the map() function.
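
When augmentation is applied through map(), the order of the pipeline steps matters. A minimal sketch, using a synthetic dataset and a made-up augment() function, is shown below: the un-augmented batches are cached first, then the random augmentation is mapped on, and prefetching comes last.

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

# Synthetic stand-in for ds: 8 random images in 2 batches of 4
images = tf.random.uniform([8, 64, 64, 3], maxval=255.0)
labels = tf.constant([0, 1, 2, 3, 4, 0, 1, 2])
demo_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

def augment(image, label):
    # Illustrative augmentation on float images in the 0-255 range
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=20.0)
    return image, label

train_ds = (
    demo_ds
    .cache()                                    # store the un-augmented batches
    .map(augment, num_parallel_calls=AUTOTUNE)  # re-drawn at random every epoch
    .prefetch(buffer_size=AUTOTUNE)             # overlap preparation with training
)

for image_batch, label_batch in train_ds:
    print(image_batch.shape, label_batch.shape)
```

Placing cache() after the map() step instead would store the augmented images, so every epoch would see the same variations.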

If you remove the RandomFlip and RandomRotation layers, you will see an improvement in accuracy because you have made the problem easier. However, because we want the network to predict well across a wide range of image quality and properties, we use augmentation to make the resulting network more robust.
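
It is also worth knowing that, as far as the Keras documentation describes, the random preprocessing layers only alter their input in training mode, so they should not distort images at prediction time. The sketch below checks this for RandomFlip on a tiny synthetic image by calling the layer with training=False.

```python
import numpy as np
import tensorflow as tf

flip = tf.keras.layers.RandomFlip("horizontal")

# Tiny synthetic batch of one 4x4 RGB image with distinct pixel values
image = tf.reshape(tf.range(4 * 4 * 3, dtype=tf.float32), [1, 4, 4, 3])

# In inference mode the layer is expected to pass its input through unchanged
inference_out = flip(image, training=False)
unchanged = np.array_equal(np.asarray(inference_out), np.asarray(image))
print(unchanged)
```

During model.fit() the layers run in training mode, while model.predict() and model.evaluate() run them in inference mode.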

Summary

You’ve seen how to use the tf.data dataset with image augmentation functions from Keras and TensorFlow in this post.

You specifically learned:

  1. How to use Keras’ preprocessing layers as a function and as part of a neural network
  2. How to write your own image augmentation function and apply it to a dataset with the map() function
  3. How to use the image augmentation functions provided by the tf.image module
