【caffe学习笔记之4】利用MATLAB接口运行cifar数据集

时间:2022-12-28 12:11:49

【前期准备工作】

参考上篇帖子:http://write.blog.csdn.net/postedit/53964874

1. 确保模型训练成功,生成模型文件:cifar10_quick_iter_4000.caffemodel及均值文件:mean.binaryproto。注意,此处一定是生成caffemodel格式的模型文件,而非.h5模型文件,否则会导致Matlab运行崩溃。如何生成caffemodel文件请参考上篇帖子。

也可以利用Matlab生成cifar10_quick_iter_4000.caffemodel,方法是进入caffe根目录,例如我的电脑上为D:\caffe-master\caffe-master,然后在matlab中执行以下命令,即可对模型进行训练:

solver = caffe.Solver('./examples/cifar10/cifar10_quick_solver.prototxt');
solver.solve()

2. 在caffe-master\matlab路径下新建cifar文件夹用于案例调试

3. 拷贝classification_demo.m文件到cifar文件夹下,并更名为classification_cifar.m

【基于mean.binaryproto文件生成.mat 文件】

在matlab command line中输入以下命令,对mean.binaryproto文件进行转换:

mean_file = 'D:\caffe-master\caffe-master\examples\cifar10\test\mean.binaryproto';
image_mean = caffe_('read_mean', mean_file);
save 'D:\caffe-master\caffe-master\matlab\cifar\image_mean.mat' image_mean
于是在matlab/cifar文件夹下生成了image_mean.mat文件

【对classification_cifar.m文件进行修改】

1. 修改dir路径、model路径和weight路径:

【caffe学习笔记之4】利用MATLAB接口运行cifar数据集

2. 修改prepare.image()函数

【caffe学习笔记之4】利用MATLAB接口运行cifar数据集

修改后的classification_cifar.m文件代码:

function [scores, maxlabel] = classification_cifar(im, use_gpu)

% Add caffe/matlab to you Matlab search PATH to use matcaffe
if exist('../+caffe', 'dir')
addpath('..');
else
error('Please run this demo from caffe/matlab/demo');
end

% Set caffe mode
if exist('use_gpu', 'var') && use_gpu
caffe.set_mode_gpu();
gpu_id = 0; % we will use the first gpu in this demo
caffe.set_device(gpu_id);
else
caffe.set_mode_cpu();
end

% Initialize the network using BVLC CaffeNet for image classification
% Weights (parameter) file needs to be downloaded from Model Zoo.
model_dir = '../../examples/cifar10/';
net_model = [model_dir 'cifar10_quick.prototxt'];
net_weights = [model_dir 'cifar10_quick_iter_4000.caffemodel'];
phase = 'test'; % run with phase test (so that dropout isn't applied)
if ~exist(net_weights, 'file')
error('Please download CaffeNet from Model Zoo before you run this demo');
end

% Initialize a network
net = caffe.Net(net_model, net_weights, phase);

if nargin < 1
% For demo purposes we will use the cat image
fprintf('using caffe/examples/images/cat.jpg as input image\n');
im = imread('../../examples/images/cat.jpg');
end

% prepare oversampled input
% input_data is Height x Width x Channel x Num
tic;
input_data = {prepare_image(im)};
toc;

% do forward pass to get scores
% scores are now Channels x Num, where Channels == 1000
tic;
% The net forward function. It takes in a cell array of N-D arrays
% (where N == 4 here) containing data of input blob(s) and outputs a cell
% array containing data from output blob(s)
scores = net.forward(input_data);
toc;

scores = scores{1};
scores = mean(scores, 2); % take average scores over 10 crops

[~, maxlabel] = max(scores);

% call caffe.reset_all() to reset caffe
caffe.reset_all();

% ------------------------------------------------------------------------
function im_data = prepare_image(im)
% ------------------------------------------------------------------------
% caffe/matlab/+caffe/imagenet/ilsvrc_2012_mean.mat contains mean_data that
% is already in W x H x C with BGR channels
d = load('D:\caffe-master\caffe-master\matlab\cifar\image_mean.mat');
mean_data = d.mean_data;
IMAGE_DIM = 32;

% Convert an image returned by Matlab's imread to im_data in caffe's data
% format: W x H x C with BGR channels
im_data = im(:, :, [3, 2, 1]); % permute channels from RGB to BGR
im_data = permute(im_data, [2, 1, 3]); % flip width and height
im_data = single(im_data); % convert from uint8 to single
im_data = imresize(im_data, [IMAGE_DIM IMAGE_DIM], 'bilinear'); % resize im_data
im_data = im_data - mean_data; % subtract mean_data (already in W x H x C, BGR)

【模型测试】

编写test.m文件,用于模型测试,test.m文件代码:

clear;clc  

im = imread('D:\caffe-master\caffe-master\examples\images\cat.jpg');
[scores, maxlabel] = classification_cifar(im,0)
index = importdata('synset_words.txt');
name = index(maxlabel);

figure;imshow(im);
str=strcat('分类结果:',name,'   得分:',num2str(max(scores)));
title(str);

使用上述命令完成模型测试,并对猫做出了正确分类:

【caffe学习笔记之4】利用MATLAB接口运行cifar数据集

【文件下载】

上述文件夹中的4个文件:classification.m、test.m、image_mean.mat、synset_words.txt打包下载地址:

点击打开链接

训练的cifar10_quick_iter_4000.caffemodel文件下载地址:

点击打开链接