CNN on the MNIST dataset with Keras

Date: 2022-02-03 13:07:13

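Keras is a deep learning framework that supports both Theano and TensorFlow as backends. This example uses Theano as the backend to build a simple CNN, which also gives a feel for how powerful CNNs are.
First install Keras 1.0.2 and Python 2.7, and download the MNIST dataset to a local file (downloading it online kept failing).
The main program is as follows: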

import numpy as np

np.random.seed(1337) # for reproducibility

import os
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.utils import np_utils

batch_size = 128
nb_classes = 10
nb_epoch = 12

# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 32
# size of pooling area for max pooling
nb_pool = 2
# convolution kernel size
nb_conv = 3

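# the data, shuffled and split into train, validation and test sets
# (three splits, as returned by the modified mnist.load_data() below; the
# stock Keras loader returns only (X_train, y_train), (X_test, y_test))
(X_train, y_train), (X_val, y_val), (X_test, y_test) = mnist.load_data()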

# Add the depth in the input. Only grayscale so depth is only one
# see http://cs231n.github.io/convolutional-networks/#overview
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)

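# Make the values floats in [0;1] instead of ints in [0;255].
# Only divide when the pixels are still on the 0-255 scale: a pickle that
# already stores floats in [0;1] would otherwise be scaled down twice.
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
if X_train.max() > 1:
    X_train /= 255
    X_test /= 255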

# Display the shapes to check if everything's ok
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# convert class vectors to binary class matrices (ie one-hot vectors)
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

##############################################################################################
model = Sequential()
# For an explanation on conv layers see http://cs231n.github.io/convolutional-networks/#conv
# By default the stride/subsample is 1
# border_mode "valid" means no zero-padding.
# If you want zero-padding add a ZeroPadding layer or, if stride is 1 use border_mode="same"
model.add(Convolution2D(nb_filters, nb_conv, nb_conv, border_mode='valid',
                        input_shape=(1, img_rows, img_cols), dim_ordering='th'))
model.add(Activation('relu'))

model.add(Convolution2D(nb_filters, nb_conv, nb_conv))
model.add(Activation('relu'))

# For an explanation on pooling layers see http://cs231n.github.io/convolutional-networks/#pool
model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
model.add(Dropout(0.25))

# Flatten the 3D output to 1D tensor for a fully connected layer to accept the input
model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))

model.add(Dropout(0.5))
model.add(Dense(nb_classes)) # Last layer with one output per class
model.add(Activation('softmax')) # We want a score similar to a probability for each class
##############################################################################################

# The function to optimize is the cross entropy between the true label and the output (softmax) of the model
# We will use adadelta to do the gradient descent see http://cs231n.github.io/neural-networks-3/#ada
model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=["accuracy"])

# Make the model learn
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          verbose=1, validation_data=(X_test, Y_test))

# Evaluate how the model does on the test set
score = model.evaluate(X_test, Y_test, verbose=0)

print('Test score:', score[0])
print('Test accuracy:', score[1])
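As a worked check of what the layer stack does, the sketch below traces a 28x28 input through the model: each 3x3 'valid' convolution trims one pixel from every border (28 -> 26 -> 24), the 2x2 max pooling halves that to 12, so Flatten hands 32 x 12 x 12 = 4608 values to Dense(128). It assumes the hyperparameters from the listing; the helper names are made up for illustration, and the parameter total is computed by hand rather than read from Keras (model.summary() is the place to confirm it).

```python
def conv_valid(size, kernel):
    """Output width of a stride-1 convolution without zero padding ('valid')."""
    return size - kernel + 1


def pool(size, pool_size):
    """Output width of non-overlapping max pooling (stride == pool size)."""
    return size // pool_size


def count_params(img_rows=28, img_cols=28, nb_filters=32, nb_conv=3,
                 nb_pool=2, hidden=128, nb_classes=10):
    """Trace feature-map sizes and trainable weights of the model above."""
    # conv 1: one grayscale input channel -> nb_filters maps, plus biases
    rows, cols = conv_valid(img_rows, nb_conv), conv_valid(img_cols, nb_conv)
    conv1 = nb_filters * (1 * nb_conv * nb_conv) + nb_filters
    # conv 2: nb_filters -> nb_filters maps, also 'valid'
    rows, cols = conv_valid(rows, nb_conv), conv_valid(cols, nb_conv)
    conv2 = nb_filters * (nb_filters * nb_conv * nb_conv) + nb_filters
    # max pooling shrinks each side; dropout and activations add no weights
    rows, cols = pool(rows, nb_pool), pool(cols, nb_pool)
    flat = nb_filters * rows * cols
    dense1 = flat * hidden + hidden            # Dense(128) weights + biases
    dense2 = hidden * nb_classes + nb_classes  # output layer weights + biases
    return {'feature_map': (nb_filters, rows, cols), 'flatten': flat,
            'params': conv1 + conv2 + dense1 + dense2}


if __name__ == '__main__':
    info = count_params()
    print(info['feature_map'])  # (32, 12, 12)
    print(info['flatten'])      # 4608
    print(info['params'])       # 600810
```

Almost all of the roughly 600k weights sit in the first Dense layer, which is why shrinking the feature maps with pooling before Flatten matters.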

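One more change is needed: modify the mnist.load_data() function (keras/datasets/mnist.py inside the Keras installation) so that it opens the dataset file differently and reads it from a local path: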

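import gzip
import sys
from ..utils.data_utils import get_file
from six.moves import cPickle


def load_data(path='C:/Users/123/Desktop/mnist.pkl'):
    # Online download kept failing, so read the local copy instead:
    # path = get_file(path, origin='https://s3.amazonaws.com/img-datasets/mnist.pkl.gz')

    # Open gzip-compressed and plain pickles alike
    if path.endswith('.gz'):
        f = gzip.open(path, 'rb')
    else:
        f = open(path, 'rb')

    if sys.version_info < (3,):
        data = cPickle.load(f)
    else:
        # pickles written by Python 2 need an explicit encoding on Python 3
        data = cPickle.load(f, encoding='latin1')
    f.close()
    # for the three-split mnist.pkl used here: (train, valid, test)
    return data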
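If you would rather not patch the Keras source, a standalone loader in your own script can do the same job. This is a sketch under the assumption that the local pickle is either the two-split layout the stock Keras loader expects or the three-split (train, valid, test) layout used in this post; load_mnist_pickle is a hypothetical name, not part of Keras.

```python
import gzip
import pickle
import sys


def load_mnist_pickle(path):
    """Load a local MNIST pickle and always return ((X, y) train, (X, y) test).

    Hypothetical helper: accepts either
    ((X_train, y_train), (X_test, y_test)) or
    ((X_train, y_train), (X_val, y_val), (X_test, y_test)).
    """
    opener = gzip.open if path.endswith('.gz') else open
    with opener(path, 'rb') as f:
        if sys.version_info < (3,):
            data = pickle.load(f)
        else:
            # pickles written by Python 2 need latin1 to decode their buffers
            data = pickle.load(f, encoding='latin1')
    if len(data) == 3:
        train, _val, test = data  # the validation split is dropped here
    elif len(data) == 2:
        train, test = data
    else:
        raise ValueError('unexpected MNIST pickle layout: %d splits' % len(data))
    return train, test
```

In the main program you would then write `(X_train, y_train), (X_test, y_test) = load_mnist_pickle('C:/Users/123/Desktop/mnist.pkl')` instead of calling mnist.load_data().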

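By the third epoch the accuracy was already above 80%. I had set 12 epochs, but because the program takes a lot of memory I only ran 7, and the accuracy ended up at around 92%.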
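To put the epoch counts in perspective: one epoch is one full pass over the training set in mini-batches of batch_size=128. The sketch below counts the gradient updates; the training-set size is an assumption (50,000 images for a three-split pickle, 60,000 for the standard split), since the post does not show the printed sample count.

```python
import math


def updates_per_epoch(n_samples, batch_size=128):
    """Number of mini-batches (gradient steps) in one pass over the data."""
    # the last batch is smaller when n_samples is not a multiple of batch_size
    return int(math.ceil(n_samples / float(batch_size)))


for n in (50000, 60000):  # assumed training-set sizes
    steps = updates_per_epoch(n)
    # steps per epoch, after the 7 epochs actually run, after the planned 12
    print(n, steps, steps * 7, steps * 12)
```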