Sample Classification Code of CIFAR-10 in Torch

时间:2021-11-22 15:06:57

Sample Classification Code of CIFAR-10 in Torch

from: http://torch.ch/blog/2015/07/30/cifar.html

require 'xlua'
require 'optim'
require 'nn'
require 'image'
local c = require 'trepl.colorize'

-- Command-line options, parsed by lapp (a global provided by xlua).
-- `opt` is intentionally left global, as in the original script.
-- NOTE(review): integer literals in the code below this option block were
-- missing from the source and have been reconstructed from the upstream
-- cifar.torch train.lua -- confirm against that file.
opt = lapp[[
-s,--save (default "logs") subdirectory to save logs
-b,--batchSize (default 128) batch size
-r,--learningRate (default 1) learning rate
--learningRateDecay (default 1e-7) learning rate decay
--weightDecay (default 0.0005) weightDecay
-m,--momentum (default 0.9) momentum
--epoch_step (default 25) epoch step
--model (default vgg_bn_drop) model name
--max_epoch (default 300) maximum number of iterations
--backend (default nn) backend
--type (default cuda) cuda/float/cl
]]

print(opt)

do -- data augmentation module
   -- nn.BatchFlip: horizontally flips about half of the samples of each
   -- batch while the module is in training mode; acts as identity otherwise.
   local BatchFlip, parent = torch.class('nn.BatchFlip', 'nn.Module')

   function BatchFlip:__init()
      parent.__init(self)
      self.train = true
   end

   --- Forward pass.
   -- @param input batch tensor whose first dimension indexes the samples
   -- @return self.output, which shares storage with `input`
   function BatchFlip:updateOutput(input)
      if self.train then
         local bs = input:size(1)
         -- a random permutation of 1..bs thresholded at bs/2 gives a mask
         -- that selects (roughly) half of the samples
         local flip_mask = torch.randperm(bs):le(bs/2)
         for i = 1, input:size(1) do
            -- source and destination are the same tensor: the flip is in place
            if flip_mask[i] == 1 then image.hflip(input[i], input[i]) end
         end
      end
      self.output:set(input)
      return self.output
   end
end

--- Convert a tensor or module to the type selected with --type.
-- @param t object exposing :cuda() / :float() / :cl()
-- @return the converted object
-- Raises an error for any other value of opt.type.
local function cast(t)
   if opt.type == 'cuda' then
      require 'cunn'
      return t:cuda()
   elseif opt.type == 'float' then
      return t:float()
   elseif opt.type == 'cl' then
      require 'clnn'
      return t:cl()
   else
      error('Unknown type '..opt.type)
   end
end

print(c.blue '==>' ..' configuring model')
local model = nn.Sequential()
-- 1: augmentation, always executed on float tensors
model:add(nn.BatchFlip():float())
-- 2: copy from FloatTensor to whatever tensor type cast() produces
model:add(cast(nn.Copy('torch.FloatTensor', torch.type(cast(torch.Tensor())))))
-- 3: the network itself, loaded from models/<name>.lua
model:add(cast(dofile('models/'..opt.model..'.lua')))
-- the Copy layer's updateGradInput is replaced by a no-op, so no gradient
-- with respect to the input batch is produced
model:get(2).updateGradInput = function(input) return end

if opt.backend == 'cudnn' then
   require 'cudnn'
   cudnn.benchmark = true
   -- only the network (third module) is converted to cudnn
   cudnn.convert(model:get(3), cudnn)
end

print(model)

print(c.blue '==>' ..' loading data')

-------------------------------------------------------------------------------------------
---------------------------- Load the Train and Test data -------------------------------
-------------------------------------------------------------------------------------------
-- NOTE(review): every integer literal in this section was missing from the
-- source and has been reconstructed from the upstream cifar.torch
-- provider.lua (5 training batches of 10000 images, 3x32x32 = 3072 values
-- per image) -- confirm against that file.

local trsize = 50000
local tesize = 10000

-- load dataset
-- trainData / testData are kept as globals (with local aliases), as in the
-- original script.
trainData = {
   data = torch.Tensor(50000, 3072),
   labels = torch.Tensor(50000),
   size = function() return trsize end
}
local trainData = trainData
for i = 0, 4 do
   local subset = torch.load('cifar-10-batches-t7/data_batch_' .. (i+1) .. '.t7', 'ascii')
   -- subset.data is transposed so that each row holds one image
   trainData.data[{ {i*10000+1, (i+1)*10000} }] = subset.data:t()
   trainData.labels[{ {i*10000+1, (i+1)*10000} }] = subset.labels
end
-- shift labels by one so that classes are 1-based
trainData.labels = trainData.labels + 1

local subset = torch.load('cifar-10-batches-t7/test_batch.t7', 'ascii')
testData = {
   data = subset.data:t():double(),
   labels = subset.labels[1]:double(),
   size = function() return tesize end
}
local testData = testData
testData.labels = testData.labels + 1

-- resize dataset (if using small version)
trainData.data = trainData.data[{ {1,trsize} }]
trainData.labels = trainData.labels[{ {1,trsize} }]

testData.data = testData.data[{ {1,tesize} }]
testData.labels = testData.labels[{ {1,tesize} }]

-- reshape data
trainData.data = trainData.data:reshape(trsize,3,32,32)
testData.data = testData.data:reshape(tesize,3,32,32)

----------------------------------------------------------------------------------
----------------------------------------------------------------------------------
-- preprocessing data (color space + normalization)
----------------------------------------------------------------------------------
----------------------------------------------------------------------------------
print '<trainer> preprocessing data (color space + normalization)'
collectgarbage()

-- preprocess trainSet
local normalization = nn.SpatialContrastiveNormalization(1, image.gaussian1D(7))
for i = 1,trainData:size() do
   xlua.progress(i, trainData:size())
   -- rgb -> yuv
   local rgb = trainData.data[i]
   local yuv = image.rgb2yuv(rgb)
   -- normalize y locally:
   yuv[1] = normalization(yuv[{{1}}])
   trainData.data[i] = yuv
end
-- normalize u globally:
local mean_u = trainData.data:select(2,2):mean()
local std_u = trainData.data:select(2,2):std()
trainData.data:select(2,2):add(-mean_u)
trainData.data:select(2,2):div(std_u)
-- normalize v globally:
local mean_v = trainData.data:select(2,3):mean()
local std_v = trainData.data:select(2,3):std()
trainData.data:select(2,3):add(-mean_v)
trainData.data:select(2,3):div(std_v)

-- keep the training-set statistics alongside the data
trainData.mean_u = mean_u
trainData.std_u = std_u
trainData.mean_v = mean_v
trainData.std_v = std_v

-- preprocess testSet
for i = 1,testData:size() do
   xlua.progress(i, testData:size())
   -- rgb -> yuv
   local rgb = testData.data[i]
   local yuv = image.rgb2yuv(rgb)
   -- normalize y locally:
   yuv[{1}] = normalization(yuv[{{1}}])
   testData.data[i] = yuv
end
-- the test set is normalized with the statistics of the training set
-- normalize u globally:
testData.data:select(2,2):add(-mean_u)
testData.data:select(2,2):div(std_u)
-- normalize v globally:
testData.data:select(2,3):add(-mean_v)
testData.data:select(2,3):div(std_v)

----------------------------------------------------------------------------------
----------------------------- END ---------------------------------------------
-- NOTE(review): integer literals in this section were missing from the
-- source and have been reconstructed from the upstream cifar.torch
-- train.lua -- confirm against that file.

-- the model's first module (nn.BatchFlip) runs on float tensors
trainData.data = trainData.data:float()
testData.data = testData.data:float()

-- confusion matrix over the 10 CIFAR-10 classes
confusion = optim.ConfusionMatrix(10)

print('Will save at '..opt.save)
paths.mkdir(opt.save)
testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
testLogger:setNames{'% mean class accuracy (train set)', '% mean class accuracy (test set)'}
testLogger.showPlot = false

-- flattened views of all learnable parameters and of their gradients
parameters,gradParameters = model:getParameters()

print(c.blue'==>' ..' setting criterion')
criterion = cast(nn.CrossEntropyCriterion())

print(c.blue'==>' ..' configuring optimizer')
optimState = {
   learningRate = opt.learningRate,
   weightDecay = opt.weightDecay,
   momentum = opt.momentum,
   learningRateDecay = opt.learningRateDecay,
}

--- Run one training epoch with SGD over shuffled mini-batches.
-- Reads and updates the globals `epoch`, `optimState`, `confusion` and
-- writes the epoch's training accuracy (in percent) to `train_acc`.
function train()
   model:training()
   epoch = epoch or 1

   -- drop learning rate every "epoch_step" epochs
   if epoch % opt.epoch_step == 0 then optimState.learningRate = optimState.learningRate/2 end

   print(c.blue '==>'.." online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ']')

   local targets = cast(torch.FloatTensor(opt.batchSize))
   local indices = torch.randperm(trainData.data:size(1)):long():split(opt.batchSize)
   -- remove last element so that all the batches have equal size
   indices[#indices] = nil

   local tic = torch.tic()
   for t,v in ipairs(indices) do
      xlua.progress(t, #indices)

      local inputs = trainData.data:index(1,v)
      targets:copy(trainData.labels:index(1,v))

      -- closure handed to optim.sgd: returns the loss and the gradient
      -- of the loss with respect to the flattened parameters
      local feval = function(x)
         if x ~= parameters then parameters:copy(x) end
         gradParameters:zero()

         local outputs = model:forward(inputs)
         local f = criterion:forward(outputs, targets)
         local df_do = criterion:backward(outputs, targets)
         model:backward(inputs, df_do)

         confusion:batchAdd(outputs, targets)

         return f,gradParameters
      end
      optim.sgd(feval, parameters, optimState)
   end

   confusion:updateValids()
   print(('Train accuracy: '..c.cyan'%.2f'..' %%\t time: %.2f s'):format(
         confusion.totalValid * 100, torch.toc(tic)))

   train_acc = confusion.totalValid * 100

   confusion:zero()
   epoch = epoch + 1
end

--- Evaluate on the test set, log accuracies, write an HTML report and
-- periodically save the network to <opt.save>/model.net.
function test()
   -- disable flips, dropouts and batch normalization
   model:evaluate()
   print(c.blue '==>'.." testing")
   -- evaluation batch size
   local bs = 125
   for i=1,testData.data:size(1),bs do
      local outputs = model:forward(testData.data:narrow(1,i,bs))
      confusion:batchAdd(outputs, testData.labels:narrow(1,i,bs))
   end

   confusion:updateValids()
   print('Test accuracy:', confusion.totalValid * 100)

   if testLogger then
      paths.mkdir(opt.save)
      testLogger:add{train_acc, confusion.totalValid * 100}
      testLogger:style{'-','-'}
      testLogger:plot()

      if paths.filep(opt.save..'/test.log.eps') then
         local base64im
         do
            -- NOTE(review): opt.save is interpolated into shell commands
            -- unescaped; only safe while it comes from a trusted command line.
            os.execute(('convert -density 200 %s/test.log.eps %s/test.png'):format(opt.save,opt.save))
            os.execute(('openssl base64 -in %s/test.png -out %s/test.base64'):format(opt.save,opt.save))
            local f = io.open(opt.save..'/test.base64')
            if f then
               base64im = f:read'*all'
               -- close the handle instead of leaking it every epoch
               f:close()
            end
         end

         -- fail with the OS error message if the report cannot be created
         local file = assert(io.open(opt.save..'/report.html','w'))
         -- base64im is nil when the conversion commands failed; fall back
         -- to an empty image rather than passing nil to format
         file:write(([[
<!DOCTYPE html>
<html>
<body>
<title>%s - %s</title>
<img src="data:image/png;base64,%s">
<h4>optimState:</h4>
<table>
]]):format(opt.save,epoch,base64im or ''))
         for k,v in pairs(optimState) do
            if torch.type(v) == 'number' then
               file:write('<tr><td>'..k..'</td><td>'..v..'</td></tr>\n')
            end
         end
         file:write'</table><pre>\n'
         file:write(tostring(confusion)..'\n')
         file:write(tostring(model)..'\n')
         file:write'</pre></body></html>'
         file:close()
      end
   end

   -- save model every 50 epochs
   if epoch % 50 == 0 then
      local filename = paths.concat(opt.save, 'model.net')
      print('==> saving model to '..filename)
      -- only the network (third module) is saved, without BatchFlip/Copy
      torch.save(filename, model:get(3):clearState())
   end

   confusion:zero()
end

-- main loop: one training pass followed by one evaluation per epoch
for i=1,opt.max_epoch do
   train()
   test()
end

The original version of the code:

Sample Classification Code of CIFAR-10 in Torch

Why was it written like this?

It cannot run ...

Sample Classification Code of CIFAR-10 in Torch

Sample Classification Code of CIFAR-10 in Torch的更多相关文章

  1. 【翻译】TensorFlow卷积神经网络识别CIFAR 10 Convolutional Neural Network (CNN) | CIFAR 10 TensorFlow

    原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...

  2. code::blocks(版本10.05) 配置opencv2.4.3

    (1)首先下载opencv2.4.3, 解压缩到D:下: (2)配置code::blocks, 具体操作如下: 第一步, 配置compiler, 操作步骤为Settings  -> Compil ...

  3. code::blocks(版本号10.05) 配置opencv2.4.3

    (1)首先下载opencv2.4.3, 解压缩到D:下: (2)配置code::blocks, 详细操作例如以下: 第一步, 配置compiler, 操作步骤为Settings  -> Comp ...

  4. DL Practice:Cifar 10分类

    Step 1:数据加载和处理 一般使用深度学习框架会经过下面几个流程: 模型定义(包括损失函数的选择)——>数据处理和加载——>训练(可能包括训练过程可视化)——>测试 所以自己写代 ...

  5. 【神经网络与深度学习】基于Windows+Caffe的Minst和CIFAR—10训练过程说明

    Minst训练 我的路径:G:\Caffe\Caffe For Windows\examples\mnist  对于新手来说,初步完成环境的配置后,一脸茫然.不知如何跑Demo,有么有!那么接下来的教 ...

  6. Oracle Applications Multiple Organizations Access Control for Custom Code

    档 ID 420787.1 White Paper Oracle Applications Multiple Organizations Access Control for Custom Code ...

  7. UWP深入学习六:Build better apps: Windows 10 by 10 development series

    Promotion in the Windows Store  In this article, I walk through how to Give your Store listing a mak ...

  8. Removing Columns 分类: 贪心 CF 2015-08-08 16:10 10人阅读 评论(0) 收藏

    Removing Columns time limit per test 2 seconds memory limit per test 256 megabytes input standard in ...

  9. CV code references

    转:http://www.sigvc.org/bbs/thread-72-1-1.html 一.特征提取Feature Extraction:   SIFT [1] [Demo program][SI ...

随机推荐

  1. Node.js Express 框架 POST方法

    POST 方法 以下实例演示了在表单中通过 POST 方法提交两个参数,我们可以使用 server.js 文件内的 process_post 路由器来处理输入: index.htm 文件代码修改如下: ...

  2. SpringMVC 邮件发送

    <!--邮件发送实现类--> <bean id="javaMailSender" class="org.springframework.mail.jav ...

  3. BZOJ1701 : [Usaco2007 Jan]Cow School牛学校

    枚举剩下的分数个数$k$,设最高的$k$个分数和的分子分母分别为$U$和$D$. 那么在选了的里面找到$A=\min(Dt[x]-Up[x])$,没选的里面找到$B=\max(Dt[x]-Up[x]) ...

  4. (转)Android中的Shape使用总结

    http://blog.csdn.net/bear_huangzhen/article/details/24488337 在Android程序开发中,我们经常会去用到Shape这个东西去定义各种各样的 ...

  5. 理解Socket编程【转载】

    “一切皆Socket!” 话虽些许夸张,但是事实也是,现在的网络编程几乎都是用的socket. ——有感于实际编程和开源项目研究. 我们深谙信息交流的价值,那网络中进程之间如何通信,如我们每天打开浏览 ...

  6. 曾经的10道JAVA面试题

    1.HashMap和Hashtable的区别. 都属于Map接口的类,实现了将惟一键映射到特定的值上.HashMap 类没有分类或者排序.它允许一个null 键和多个null 值.Hashtable ...

  7. RabbitMQ 保证消息不丢失

    参考:https://www.imooc.com/article/49814 发送消息的时候,加上messageId字段,数据库记录消息日志表 ,插入的时候为发送中 当收到消息的时候,更改为已发送 , ...

  8. JavaScript中8个常见的陷阱

    译者按: 漫漫编程路,总有一些坑让你泪流满面. 原文: Who said javascript was easy ? 译者: Fundebug 为了保证可读性,本文采用意译而非直译.另外,本文版权归原 ...

  9. Linux发布WebApi

    一:WebApi 使用Owin来做  http://www.cnblogs.com/xiaoyaodijun/category/666029.html 二:安装最新版的Jexus服务 https:// ...

  10. Linux中rsync备份服务部署

    rsync介绍 rsync是一款开源的.快速的.多功能的.可实现全量及增量的本地或远程数据同步备份工具 在常驻模式(daemon mode)下,rsync默认监听TCP端口873,以原生rsync传输 ...