Learning MXNet: A First Look at NiN

Posted: 2023-02-23 12:57:51


The underlying theory and background are covered in the book, so I won't restate them here; this post just records and shares the code.

The code here is not directly runnable on its own. Because the same functions recur throughout this learning series, I collected the common helpers into a utility package; see the companion post MxNet学习——自定义工具包 (MxNet Learning: Custom Utility Package).

Combine the two, and the code below can be run.

# -------------------------------------------------------------------------------
# Description: NiN (Network in Network)
# Description: NiN idea: build a deep network by stringing together multiple small
# Description: networks, each made of a convolutional layer plus "fully connected" layers
# Description: What LeNet, AlexNet, and VGG share in their design:
# Description: 1. modules of convolutional layers first extract spatial features
# Description: 2. modules of fully connected layers then output the classification result
# Description: NiN uses 1x1 convolutional layers in place of fully connected layers,
# Description: so spatial information propagates naturally to the later layers
# Reference:
# Author: Sophia
# Date: 2021/3/11
# -------------------------------------------------------------------------------
from IPython import display
from mxnet import autograd, nd, init, gluon
from mxnet.gluon import data as gdata, loss as gloss, nn
import random, sys, time, matplotlib.pyplot as plt, mxnet as mx, os
from plt_so import *  # custom utility package (provides try_gpu, load_data_fashion_mnist_ch5, train_ch5, ...)
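
# (Illustration added here, not part of the original post.) The key NiN idea:
# a 1x1 convolution is a fully connected layer applied independently at every
# spatial position; the output channels at each pixel are a linear combination
# of the input channels at that pixel. A minimal sketch:
#
# conv1x1 = nn.Conv2D(channels=4, kernel_size=1)
# conv1x1.initialize()
# X = nd.random.uniform(shape=(1, 3, 5, 5))          # (batch, channels, H, W)
# Y = conv1x1(X)                                     # shape: (1, 4, 5, 5)
# # the same computation as a matrix product over the channel dimension:
# W = conv1x1.weight.data().reshape(4, 3)            # (out_channels, in_channels)
# b = conv1x1.bias.data().reshape(1, 4, 1, 1)
# Y_manual = nd.dot(W, X.reshape(3, -1)).reshape(1, 4, 5, 5) + b
# print((Y - Y_manual).abs().max())                  # ~0, up to float error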

'''
NiN block
'''
def nin_block(num_channels, kernel_size, strides, padding):
    # one freely sized convolution, followed by two 1x1 convolutions that
    # act as per-pixel fully connected layers
    blk = nn.Sequential()
    blk.add(nn.Conv2D(num_channels, kernel_size, strides, padding, activation='relu'),
            nn.Conv2D(num_channels, kernel_size=1, activation='relu'),
            nn.Conv2D(num_channels, kernel_size=1, activation='relu'))
    return blk
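
# (Sanity check added here, not part of the original post.) A single NiN block
# keeps the spatial size set by its first convolution; the two 1x1 layers
# change nothing spatially:
#
# blk = nin_block(96, kernel_size=11, strides=4, padding=0)
# blk.initialize()
# print(blk(nd.random.uniform(shape=(1, 1, 224, 224))).shape)  # (1, 96, 54, 54)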

'''
NiN model
'''
net = nn.Sequential()
net.add(nin_block(96, kernel_size=11, strides=4, padding=0),
        nn.MaxPool2D(pool_size=3, strides=2),
        nin_block(256, kernel_size=5, strides=1, padding=2),
        nn.MaxPool2D(pool_size=3, strides=2),
        nin_block(384, kernel_size=3, strides=1, padding=1),
        nn.MaxPool2D(pool_size=3, strides=2), nn.Dropout(0.5),
        # there are 10 label classes
        nin_block(10, kernel_size=3, strides=1, padding=1),
        # global average pooling automatically sets the window shape
        # to the input's height and width
        nn.GlobalAvgPool2D(),
        # convert the 4-D output to 2-D with shape (batch size, 10)
        nn.Flatten())

# X = nd.random.uniform(shape=(1, 1, 224, 224))
# net.initialize()
# print(net)

# Output:
# Sequential(
#   (0): Sequential(
#     (0): Conv2D(None -> 96, kernel_size=(11, 11), stride=(4, 4), Activation(relu))
#     (1): Conv2D(None -> 96, kernel_size=(1, 1), stride=(1, 1), Activation(relu))
#     (2): Conv2D(None -> 96, kernel_size=(1, 1), stride=(1, 1), Activation(relu))
#   )
#   (1): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
#   (2): Sequential(
#     (0): Conv2D(None -> 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), Activation(relu))
#     (1): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), Activation(relu))
#     (2): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1), Activation(relu))
#   )
#   (3): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
#   (4): Sequential(
#     (0): Conv2D(None -> 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
#     (1): Conv2D(None -> 384, kernel_size=(1, 1), stride=(1, 1), Activation(relu))
#     (2): Conv2D(None -> 384, kernel_size=(1, 1), stride=(1, 1), Activation(relu))
#   )
#   (5): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
#   (6): Dropout(p = 0.5, axes=())
#   (7): Sequential(
#     (0): Conv2D(None -> 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
#     (1): Conv2D(None -> 10, kernel_size=(1, 1), stride=(1, 1), Activation(relu))
#     (2): Conv2D(None -> 10, kernel_size=(1, 1), stride=(1, 1), Activation(relu))
#   )
#   (8): GlobalAvgPool2D(size=(1, 1), stride=(1, 1), padding=(0, 0), ceil_mode=True, global_pool=True, pool_type=avg, layout=NCHW)
#   (9): Flatten
# )


# for layer in net:
#     X = layer(X)
#     print(layer.name, 'output shape:\t', X.shape)


# Output:
# sequential1 output shape: (1, 96, 54, 54)
# pool0 output shape: (1, 96, 26, 26)
# sequential2 output shape: (1, 256, 26, 26)
# pool1 output shape: (1, 256, 12, 12)
# sequential3 output shape: (1, 384, 12, 12)
# pool2 output shape: (1, 384, 5, 5)
# dropout0 output shape: (1, 384, 5, 5)
# sequential4 output shape: (1, 10, 5, 5)
# pool3 output shape: (1, 10, 1, 1)
# flatten0 output shape: (1, 10)
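
# (Illustration added here, not part of the original post.) The last two shapes
# above come from plain averaging: GlobalAvgPool2D reduces each of the ten 5x5
# channel maps to its mean, i.e. (1, 10, 5, 5) -> (1, 10, 1, 1), and Flatten
# then drops the trailing 1x1 dimensions:
#
# Z = nd.random.uniform(shape=(1, 10, 5, 5))
# print(nn.GlobalAvgPool2D()(Z).shape)  # (1, 10, 1, 1)
# print(Z.mean(axis=(2, 3)).shape)      # (1, 10): the same means, already flat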

'''
Train the model
'''
lr, num_epochs, batch_size, ctx = 0.1, 5, 128, try_gpu()
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train_iter, test_iter = load_data_fashion_mnist_ch5(batch_size, resize=224)
train_ch5(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)
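
# (Sketch added here, not part of the original post.) If the utility package is
# unavailable, the train_ch5 call above corresponds roughly to the standard
# Gluon training loop below; the toolkit's version presumably also reports
# loss and accuracy per epoch:
#
# loss = gloss.SoftmaxCrossEntropyLoss()
# for epoch in range(num_epochs):
#     for X, y in train_iter:
#         X, y = X.as_in_context(ctx), y.as_in_context(ctx)
#         with autograd.record():
#             l = loss(net(X), y).sum()
#         l.backward()
#         trainer.step(batch_size)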