Training a Fashion MNIST Image Recognition Model on Google's TPU

Date: 2023-02-25 13:55:32




Author | 张强

Today we use the Keras framework to train a Fashion MNIST image recognition model. Fashion MNIST has the same number of classes as MNIST.

MNIST has ten classes: the digits 0 through 9.
Fashion MNIST has these ten classes: 't_shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boots'

AI models are usually trained on CPU or GPU servers. In this example, we instead use tf.keras to train a model on the Fashion MNIST dataset on a Google Cloud TPU. Training for 1 epoch on a Cloud TPU takes about 2 minutes.

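Learning objectives

In this Jupyter Notebook, we will learn to:

  • Build a standard convolutional network in Keras with 3 convolutional layers, using Dropout and batch normalization between the layers.
  • Use yield to create a data generator and fit_generator to train the model.
  • Run model predictions to see how the model classifies fashion categories, and display the results.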

TPUs are located in Google Cloud, so for best performance data can be read directly from Google Cloud Storage (GCS). Let's look at how the code is written.

Getting the data
First, download the Fashion MNIST dataset with tf.keras.datasets, as shown below:

import tensorflow as tf
import numpy as np

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# add empty color dimension
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
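
As a quick check of the reshaping step above, the following sketch shows what np.expand_dims does to the array shape. It uses a small all-zero array as a stand-in for the real images (the array contents and the batch of 8 are assumptions for illustration only), so it runs without downloading the dataset.

```python
import numpy as np

# Stand-in for a small batch of 28x28 grayscale images (hypothetical data,
# not the real Fashion MNIST arrays).
images = np.zeros((8, 28, 28), dtype=np.uint8)

# Append a trailing channel axis so each image becomes 28 x 28 x 1,
# which is the layout the Conv2D layers expect.
images_with_channel = np.expand_dims(images, -1)

print(images.shape)               # (8, 28, 28)
print(images_with_channel.shape)  # (8, 28, 28, 1)
```

The real x_train goes from shape (60000, 28, 28) to (60000, 28, 28, 1) in the same way.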

Defining the model

The following example uses a standard conv-net in which every block has Dropout and batch normalization:

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.BatchNormalization(input_shape=x_train.shape[1:]))
model.add(tf.keras.layers.Conv2D(64, (5, 5), padding='same', activation='elu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2,2)))
model.add(tf.keras.layers.Dropout(0.25))

model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Conv2D(128, (5, 5), padding='same', activation='elu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Dropout(0.25))

model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Conv2D(256, (5, 5), padding='same', activation='elu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2,2)))
model.add(tf.keras.layers.Dropout(0.25))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(256))
model.add(tf.keras.layers.Activation('elu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10))
model.add(tf.keras.layers.Activation('softmax'))
model.summary()

The model summary output is as follows:

Layer (type)                 Output Shape              Param #   
=================================================================
batch_normalization_v1 (Batc (None, 28, 28, 1)         4
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        1664
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 64)        0
_________________________________________________________________
batch_normalization_v1_1 (Ba (None, 14, 14, 64)        256
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 128)       204928
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 128)         0
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 128)         0
_________________________________________________________________
batch_normalization_v1_2 (Ba (None, 7, 7, 128)         512
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 256)         819456
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 3, 3, 256)         0
_________________________________________________________________
dropout_2 (Dropout)          (None, 3, 3, 256)         0
_________________________________________________________________
flatten (Flatten)            (None, 2304)              0
_________________________________________________________________
dense (Dense)                (None, 256)               590080
_________________________________________________________________
activation (Activation)      (None, 256)               0
_________________________________________________________________
dropout_3 (Dropout)          (None, 256)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2570
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0
=================================================================
Total params: 1,619,470
Trainable params: 1,619,084
Non-trainable params: 386
_________________________________________________________________
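
The numbers in the summary can be reproduced by hand. The sketch below is plain Python arithmetic rather than Keras code; the helper names conv2d_params, dense_params, and batch_norm_params are hypothetical and exist only for this illustration. It recomputes the flattened size and the parameter totals from the layer definitions.

```python
def conv2d_params(kernel, in_channels, filters):
    # One kernel x kernel x in_channels weight tensor per filter, plus one bias each.
    return kernel * kernel * in_channels * filters + filters

def dense_params(inputs, units):
    # Full weight matrix plus one bias per unit.
    return inputs * units + units

def batch_norm_params(channels):
    # gamma and beta are trainable; moving mean and variance are not.
    return 4 * channels

# 'same' convolutions keep the spatial size; each 2x2 pooling halves it, rounding down.
size = 28
for _ in range(3):
    size //= 2                      # 28 -> 14 -> 7 -> 3
flat = size * size * 256            # input size of the first Dense layer

counts = [
    batch_norm_params(1), conv2d_params(5, 1, 64),
    batch_norm_params(64), conv2d_params(5, 64, 128),
    batch_norm_params(128), conv2d_params(5, 128, 256),
    dense_params(flat, 256), dense_params(256, 10),
]
total = sum(counts)
non_trainable = 2 * (1 + 64 + 128)  # moving statistics of the three BatchNormalization layers

print(flat, total, total - non_trainable, non_trainable)
# 2304 1619470 1619084 386
```

These values match the Flatten output shape and the three totals reported at the bottom of the summary.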

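Training on the TPU

To start training, build the model on the TPU and then compile it.

The following code trains the model with a generator function and fit_generator. Alternatively, you could pass x_train and y_train directly to tpu_model.fit().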

import os

# Note: tf.contrib exists only in TensorFlow 1.x; it was removed in TensorFlow 2.x.
# Convert the Keras model to a TPU model. COLAB_TPU_ADDR is read from the environment.
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    )
)
tpu_model.compile(
    optimizer=tf.train.AdamOptimizer(learning_rate=1e-3),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['sparse_categorical_accuracy']
)

def train_gen(batch_size):
    # Yield random contiguous slices of the training set forever.
    while True:
        offset = np.random.randint(0, x_train.shape[0] - batch_size)
        yield x_train[offset:offset + batch_size], y_train[offset:offset + batch_size]


tpu_model.fit_generator(
    train_gen(1024),
    epochs=1,
    steps_per_epoch=1000,
    validation_data=(x_test, y_test),
)
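
To make the generator pattern concrete without a TPU, here is a minimal pure-Python sketch of the same idea. The function name window_batches and the toy lists are assumptions for illustration; only the sampling logic mirrors train_gen. It also works out how many samples one "epoch" covers with the settings used in the training call.

```python
import random

def window_batches(samples, labels, batch_size):
    """Yield random contiguous windows forever, mirroring train_gen."""
    while True:
        # randrange excludes the upper bound, like np.random.randint.
        offset = random.randrange(0, len(samples) - batch_size)
        yield samples[offset:offset + batch_size], labels[offset:offset + batch_size]

# Toy data standing in for x_train and y_train (hypothetical values).
samples = list(range(100))
labels = [s % 10 for s in samples]

gen = window_batches(samples, labels, batch_size=16)
batch_x, batch_y = next(gen)

# With steps_per_epoch=1000 and a batch size of 1024, one "epoch" draws
# far more samples than the 60,000 training images contain.
steps_per_epoch, batch_size, train_size = 1000, 1024, 60000
samples_seen = steps_per_epoch * batch_size

print(len(batch_x), samples_seen, round(samples_seen / train_size, 1))
# 16 1024000 17.1
```

Because the generator never terminates, steps_per_epoch is what defines the length of an epoch. Note also that each batch is a contiguous window, so samples inside a batch keep their original order.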

Test results (inference)

Now that training is finished, see how the model predicts fashion categories with the following code:

LABEL_NAMES = ['t_shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boots']

# Copy the trained weights back into a CPU model.
cpu_model = tpu_model.sync_to_cpu()

from matplotlib import pyplot
%matplotlib inline

def plot_predictions(images, predictions):
    n = images.shape[0]
    nc = int(np.ceil(n / 4))
    # squeeze=False keeps axes two-dimensional even when there is a single row.
    f, axes = pyplot.subplots(nc, 4, squeeze=False)
    for i in range(nc * 4):
        y = i // 4
        x = i % 4
        # axes has shape (rows, columns), so index it as [row, column].
        axes[y, x].axis('off')

        # Leave unused cells blank instead of indexing past the last image.
        if i >= n:
            continue
        label = LABEL_NAMES[np.argmax(predictions[i])]
        confidence = np.max(predictions[i])
        axes[y, x].imshow(images[i])
        axes[y, x].text(0.5, 0.5, label + '\n%.3f' % confidence, fontsize=14)

    pyplot.gcf().set_size_inches(8, 8)

plot_predictions(np.squeeze(x_test[:16]),
                 tpu_model.predict(x_test[:16]))
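
The label and confidence shown on each image come from the 10-element softmax output: the index of the largest probability selects the class name, and the probability itself is the confidence. The sketch below illustrates this in plain Python; the helper decode_prediction and the probability vector are made up for the example and are not model output.

```python
LABEL_NAMES = ['t_shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boots']

def decode_prediction(probabilities):
    # Index of the largest probability, equivalent to np.argmax for a 1-D vector.
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return LABEL_NAMES[best], probabilities[best]

# Hypothetical softmax output for a single image (values chosen to sum to 1).
example = [0.01, 0.02, 0.02, 0.05, 0.03, 0.02, 0.05, 0.10, 0.05, 0.65]

label, confidence = decode_prediction(example)
print(label, '%.3f' % confidence)
# ankle_boots 0.650
```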

The prediction output is shown in the figure:

[Figure: predicted labels and confidences for the first 16 test images]

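Further reading

On Google Cloud Platform, in addition to the GPUs and TPUs available on pre-configured deep learning virtual machines, you will find AutoML (beta) for training custom models without writing code, and Cloud ML Engine, which lets you run parallel training and hyperparameter tuning of your custom models on powerful distributed hardware.

The code has been uploaded to GitHub:
https://github.com/VictorZhang2014/paddle_vs_tensorflow_vs_keras/tree/master/fashion_mnist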