Transfer Learning with a Pretrained VGG16 in Keras: Dogs vs. Cats Classification [Training & Prediction in 65 Lines of Code]

Posted: 2024-05-23 11:19:44





I. Introduction

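  1. Transfer learning: a very powerful deep learning technique with a simple motivation, "standing on the shoulders of giants". A model trained on one task is reused as the starting point for another, much as learning Spanish is easier and faster if you already know English.
  2. VGG16: a model proposed by the University of Oxford (Visual Geometry Group) in 2014. It is simple and practical, and is widely used for image classification and as a feature extractor for object detection.
  3. Dogs vs. Cats dataset: released by Kaggle in 2013. The training set contains 25,000 labeled images (12,500 cats and 12,500 dogs), and a separate test set contains 12,500 unlabeled images. https://www.kaggle.com/c/dogs-vs-cats/data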

II. Training code

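Notes:

  1. Running the code step by step in IPython is recommended.

  2. The training data lives in a custom directory; change the paths to match your own.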
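    The code expects data/dogs_and_cats/train and data/dogs_and_cats/validation to each contain one subfolder per class (for example cats/ and dogs/), because flow_from_directory takes the labels from the subfolder names.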

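  3. The numbers of training and validation samples can be adjusted to suit your hardware and time budget. After downloading the dataset from the official site, unzip train.zip and distribute as many images as you like into the corresponding folders.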
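The manual split described in note 3 can be scripted. A minimal sketch, assuming the Kaggle naming scheme cat.<n>.jpg / dog.<n>.jpg inside the unzipped train/ folder; the function name split_kaggle_train, the cats/ and dogs/ subfolder names and the per-class validation size are illustrative choices, not the author's. With n_val_per_class=100 it would yield 24,800 training and 200 validation images, the same totals as the log in section III.

```python
import os
import random
import shutil


def split_kaggle_train(src_dir, dst_root, n_val_per_class=100, seed=0):
    """Copy Kaggle files named like 'cat.0.jpg' / 'dog.0.jpg' into
    dst_root/{train,validation}/{cats,dogs}/ (hypothetical layout).

    Returns a dict mapping (split, class_dir) -> number of files copied.
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    counts = {}
    for prefix, class_dir in (("cat", "cats"), ("dog", "dogs")):
        # Collect this class's files and shuffle them before splitting
        files = sorted(f for f in os.listdir(src_dir) if f.startswith(prefix + "."))
        rng.shuffle(files)
        splits = {"validation": files[:n_val_per_class],
                  "train": files[n_val_per_class:]}
        for split, names in splits.items():
            out_dir = os.path.join(dst_root, split, class_dir)
            os.makedirs(out_dir, exist_ok=True)
            for name in names:
                shutil.copy(os.path.join(src_dir, name), os.path.join(out_dir, name))
            counts[(split, class_dir)] = len(names)
    return counts
```

Usage (paths assumed): split_kaggle_train('train', 'data/dogs_and_cats', n_val_per_class=100). Copying rather than moving keeps the unzipped originals intact, at the cost of extra disk space.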

from keras import optimizers
from keras import applications
from keras.models import Sequential, Model
from keras.callbacks import ModelCheckpoint
from keras.layers import Dropout, Flatten, Dense
from keras.preprocessing.image import ImageDataGenerator

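# Dataset settings
img_height, img_width = 256, 256  # input image height and width
batch_size = 2  # batch size
epochs = 50  # number of epochs
train_data_dir = 'data/dogs_and_cats/train'  # training set directory
validation_data_dir = 'data/dogs_and_cats/validation'  # validation set directory
OUT_CATEGORIES = 1  # output units: a single sigmoid unit for binary (cat/dog) classification
nb_train_samples = 2000  # training images drawn per epoch
nb_validation_samples = 200  # validation images per evaluation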

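# Define the model
base_model = applications.VGG16(weights="imagenet", include_top=False,
                                input_shape=(img_height, img_width, 3))  # pretrained VGG16 without its top (classifier) layers
base_model.summary()  # summary() prints the table itself and returns None

# layers[0] is the InputLayer, so layers[:15] freezes everything up to block4_pool; block5 stays trainable
for layer in base_model.layers[:15]:
    layer.trainable = False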

top_model = Sequential()  # custom top (classifier) network
top_model.add(Flatten(input_shape=base_model.output_shape[1:]))  # flatten the 8x8x512 VGG16 feature maps
top_model.add(Dense(256, activation='relu'))  # fully connected layer with 256 units
top_model.add(Dropout(0.5))  # dropout rate 0.5
top_model.add(Dense(OUT_CATEGORIES, activation='sigmoid'))  # output layer for binary classification

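# top_model.load_weights("")  # optionally load weights from a separately trained top network

model = Model(inputs=base_model.input, outputs=top_model(base_model.output))  # new network = pretrained base + custom top

model.compile(loss='binary_crossentropy', optimizer=optimizers.SGD(learning_rate=0.0001, momentum=0.9),
              metrics=['accuracy'])  # binary cross-entropy loss, SGD with momentum (Keras < 2.3 spells it lr=)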

train_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True)  # training preprocessing: scale to [0, 1], random horizontal flips
test_datagen = ImageDataGenerator(rescale=1. / 255)  # validation preprocessing: scaling only
train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(img_height, img_width),
                                                    batch_size=batch_size, class_mode='binary')  # training data generator
validation_generator = test_datagen.flow_from_directory(validation_data_dir, target_size=(img_height, img_width),
                                                        batch_size=batch_size, class_mode='binary',
                                                        shuffle=False)  # validation data generator
checkpointer = ModelCheckpoint(filepath='dogcatmodel.h5', verbose=1, save_best_only=True)  # keep only the best model (lowest val_loss by default)

# Train & evaluate (TF2-era Keras: fit() accepts generators directly; fit_generator is deprecated)
model.fit(train_generator, steps_per_epoch=nb_train_samples // batch_size, epochs=epochs,
          validation_data=validation_generator, validation_steps=nb_validation_samples // batch_size,
          verbose=2, workers=12, callbacks=[checkpointer])  # one output line per epoch; up to 12 data-loading workers
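Because base_model.layers starts with the InputLayer, the slice [:15] ends at block4_pool rather than at a convolution. The plain-Python sketch below uses the layer names from the model summary in section III (no Keras needed; the input layer's name varies between Keras versions) to make that boundary explicit. In a live session the same check would be [l.name for l in base_model.layers if l.trainable].

```python
# Layer names of keras.applications.VGG16(include_top=False), in order,
# as listed in the model summary (index 0 is the InputLayer).
VGG16_LAYERS = (
    ["input_1"]
    + ["block1_conv1", "block1_conv2", "block1_pool"]
    + ["block2_conv1", "block2_conv2", "block2_pool"]
    + ["block3_conv1", "block3_conv2", "block3_conv3", "block3_pool"]
    + ["block4_conv1", "block4_conv2", "block4_conv3", "block4_pool"]
    + ["block5_conv1", "block5_conv2", "block5_conv3", "block5_pool"]
)

frozen = VGG16_LAYERS[:15]     # same slice as base_model.layers[:15]
trainable = VGG16_LAYERS[15:]  # layers left trainable

print("last frozen layer:", frozen[-1])  # block4_pool
print("still trainable:", trainable)
```

So only the three block5 convolutions (plus the custom top network) are fine-tuned; the lower blocks keep their ImageNet weights.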

III. Training results

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 256, 256, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 256, 256, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 256, 256, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 128, 128, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 128, 128, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 128, 128, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 64, 64, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 64, 64, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 64, 64, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 64, 64, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 32, 32, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 32, 32, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 32, 32, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 32, 32, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 16, 16, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 16, 16, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 8, 8, 512)         0         
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________
None
Found 24800 images belonging to 2 classes.
Found 200 images belonging to 2 classes.

(output for the remaining epochs omitted)

Epoch 49/50
 - 47s - loss: 0.0412 - acc: 0.9840 - val_loss: 0.0477 - val_acc: 0.9850

Epoch 00049: val_loss improved from 0.05592 to 0.04766, saving model to dogcatmodel.h5
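The summary above was printed before the freezing loop ran, which is why it reports Non-trainable params: 0. Each count follows from the Conv2D formula k·k·C_in·C_out + C_out, e.g. block1_conv1 = 3·3·3·64 + 64 = 1792. The sketch below recomputes the base total and also estimates the split after freezing, including the custom top (Flatten of 8x8x512, Dense(256), Dense(1)); those post-freeze figures are computed here, not taken from the original run.

```python
def conv2d_params(k, c_in, c_out):
    """Parameters of a k x k Conv2D with bias: k*k*c_in*c_out weights + c_out biases."""
    return k * k * c_in * c_out + c_out


def dense_params(n_in, n_out):
    """Parameters of a Dense layer with bias."""
    return n_in * n_out + n_out


# (in_channels, out_channels) of every 3x3 conv in VGG16, grouped by block
blocks = [
    [(3, 64), (64, 64)],
    [(64, 128), (128, 128)],
    [(128, 256), (256, 256), (256, 256)],
    [(256, 512), (512, 512), (512, 512)],
    [(512, 512), (512, 512), (512, 512)],
]
block_params = [sum(conv2d_params(3, i, o) for i, o in b) for b in blocks]
base_total = sum(block_params)  # pooling layers have no parameters

frozen = sum(block_params[:4])  # layers[:15] freezes blocks 1-4
top = dense_params(8 * 8 * 512, 256) + dense_params(256, 1)  # Flatten has no parameters

print("VGG16 base:", base_total)                        # 14714688, as in the summary
print("non-trainable after freezing:", frozen)          # 7635264
print("trainable after freezing:", block_params[4] + top)  # 15468545
```

Note that the Dense(256) layer on the flattened 32,768-dimensional feature vector alone holds about 8.4 million parameters, more than the entire frozen part of the base.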

Training produces the model file dogcatmodel.h5 (the checkpoint with the lowest validation loss).

IV. Prediction code

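import numpy as np
from keras.models import load_model
from keras.preprocessing import image

# Image preprocessing (must match training: same size, same 1/255 scaling)
path = 'data/dogs_and_cats/test1/1.jpg'  # image path
img_height, img_width = 256, 256  # image height and width
x = image.load_img(path, target_size=(img_height, img_width))  # load and resize the image
x = image.img_to_array(x) / 255.  # convert to a float array and rescale like ImageDataGenerator(rescale=1. / 255)
x = np.expand_dims(x, axis=0)  # add the batch dimension: (1, 256, 256, 3)

# Predict
model = load_model('dogcatmodel.h5')  # load the saved model
p = model.predict(x)[0][0]  # sigmoid output: estimated probability of class 1
y = 'cat' if p < 0.5 else 'dog'  # 0 = cat, 1 = dog (assumes the class subfolders sort as cats before dogs)
print(y)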
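Two details in the prediction step are easy to get wrong: the image must be scaled by 1/255 as in training, and the sigmoid output is a probability, so it should be compared with a threshold (0.5 here, an assumption) rather than tested for equality with 0. The label mapping comes from train_generator.class_indices, which flow_from_directory builds from the subfolder names in alphabetical order. A framework-free sketch; decode_prediction is a hypothetical helper, not part of Keras, and the folder names are assumed.

```python
def decode_prediction(prob, class_indices, threshold=0.5):
    """Map a sigmoid output to a class name.

    class_indices has the same shape as train_generator.class_indices,
    e.g. {'cats': 0, 'dogs': 1} when the subfolders are named cats/ and dogs/.
    """
    index_to_class = {v: k for k, v in class_indices.items()}
    return index_to_class[1 if prob >= threshold else 0]


class_indices = {"cats": 0, "dogs": 1}  # hypothetical folder names
print(decode_prediction(0.03, class_indices))  # cats
print(decode_prediction(0.97, class_indices))  # dogs
print(0.03 == 0)  # False: an exact "== 0" test almost never matches a sigmoid output
```

Passing the mapping in explicitly, instead of hard-coding 'cat'/'dog', keeps the decoding correct if the folders are renamed or sorted differently.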

V. Prediction results

Note: the labels in the figure were added by hand from the printed output; they were not drawn by the program.
So here's the question: are my two chubby cats actually cats?