torch 深度学习(5)

摘要:
现在看代码1_数据预处理主要包括数据加载、集中化和标准化。require'torch'mnist=require('mnist')--列表大小28*28--初始化数据设置={data=mnist.traindataset().data:
torch 深度学习(5)
mnist
torch
siamese
deep-learning

这篇文章主要是想使用torch学习并理解如何构建siamese network。

siamese network的结构如下:

blob:http://markdown.xiaoshujiang.com/aa0f13b0-909e-4c87-8706-5d686c425b87

1486455020988.jpg

使用的数据集:mnist 手写数据集
实验目的:通过孪生网络使得同一类的尽可能的靠近,不同类的尽可能不同。

命令行:

sudo luarocks install mnist

主要涉及的torch/nn中Container包括Sequential和ParallelTable,具体参见Docs » Modules » Containers

OK,现在来看代码

1_data 数据预处理

主要在于数据的加载和中心化以及归一化处理

require 'torch'
mnist = require('mnist')
-- the size of mnist is 28*28

-- initialize the dataset

train={
data = mnist.traindataset().data:type('torch.FloatTensor'),  -- traindata
label = mnist.traindataset().label, -- train label
size=function() 
return mnist.traindataset().data:size(1) end
}
test={
data = mnist.testdataset().data:type('torch.FloatTensor'),
label = mnist.testdataset().label,
size=function() 
return mnist.testdataset().data:size(1) end
}

local meanV = train.data:mean()
local stdV = train.data:std()

train.data = train.data:csub(meanV)
train.data = train.data:div(stdV)

test.data = test.data:add(-meanV)
test.data = test.data:mul(1.0/stdV)

mnist数据集中图像的大小是$28 imes 28$的,训练样本有60000张,测试样本有10000张

2_model 构建模型

首先孪生网络包括两个子网络,这两个子网络包含在ParallelTable中,而每一个单独的子网络又是在一个Sequential容器内,所以

require 'nn'

cnn=nn.Sequential()
-- stage 1
cnn:add(nn.SpatialConvolution(1,8,3,3,1,1,1)) -- 28*28
-- nn.SpationConvolution(nInputPlane,nOutputPlane,kW,kH,dW,dH,padW,padH)
cnn:add(nn.ReLU())
cnn:add(nn.SpatialMaxPooling(2,2)) -- 14*14
-- stage 2
cnn:add(nn.SpatialConvolution(8,16,3,3,1,1,1)) -- 14*14
cnn:add(nn.ReLU())
cnn:add(nn.SpatialMaxPooling(2,2)) -- 7*7
-- stage 3
cnn:add(nn.SpatialConvolution(16,32,3,3,1,1,1))
cnn:add(nn.ReLU())
cnn:add(nn.SpatialMaxPooling(2,2)) -- 3*3
-- stage 4
cnn:add(nn.Reshape(32*3*3))
cnn:add(nn.Linear(32*3*3,256))
cnn:add(nn.ReLU())
-- stage 5
cnn:add(nn.Linear(256,2))

parallel_model = nn.ParallelTable()
parallel_model:add(cnn)
parallel_model:add(cnn:clone('weight','bias','gradWeight','gradBias'))
--这里,孪生网络要求两个子网络共享参数,所以要分享权重和梯度变化 

model = nn.Sequential()
model:add(nn.SplitTable(1))
model:add(parallel_model)
model:add(nn.PairwiseDistance(2)) -- L2距离
--print(model)

构造的模型如下:

enter description here

1486455042581.jpg

为什么最终每一个子网络输出维度为2?这是因为我们希望之后能够在二维上显示的观察结果

nn.SplitTable(ndim): 将该层输入在第ndim上划分成table,在代码中就是将model的输入样本沿着第1维保存成table,table每一个元素对应这ParallelTable中的一个子网络,
所以model的输入应该是$2 imes 1 imes 28 imes 28$的torch.Tensor

3_loss 损失函数

这里使用的损失函数为 HingeEmbeddingCriterion,具体定义参见HingeEmbeddingCriterion
其形式:loss(x,y) = forward(x,y) = x, if y=1 = max(0,margin - x), if y=-1

$$
loss(x,y)=egin{cases}
x,	ext{ if}quad  y=1\
max(0,margin-x), if y=-1
end{cases}
$$
criterion=nn.HingeEmbeddingCriterion()

4_train 模型训练

在所有的步骤中,我觉得训练这一步相对来说是比较复杂的。
首先要定义数据的batch处理方式,然后定义优化方法调用的函数feval,这个函数使用BP算法更新了模型的参数,所以在整个文件之前要通过model.getPatameters()获得模型参数的引用。
最后就是调用optim中的优化方法对模型进行不断的优化了。

require 'nn'
require 'optim'
require 'xlua'

if model then 
	parameters,gradParameters=model:getParameters()
end
batchSize = 100
learningRate = 0.01
function training()
	epoch=epoch or 1
	time = sys.clock()
	shuffer = torch.randperm(train:size())
	print ">>>>>>>>>>>>>>>>>>>>>> doing epoch on training data: >>>>>>>>>>>>>>>>>>>>>"
	print("=======> online epoch # " .. epoch .. '[batchSize = ' .. batchSize .. ']')
	for t=1,train:size(),batchSize do
		xlua.progress(t,train:size())
		
		batchData = {}
		batchLabel = {}
		
		for i=t,math.min(t+batchSize-1,train:size()) do
			local input=torch.Tensor(2,1,28,28) --注意这里,每个样本是28*28的tensor,但是模型中cnn的输入要求是1*28*28的所以应该存成2*1*28*28的tensor
			input[1]=train.data[i]
			input[2]=train.data[shuffer[i]]
			if train.label[i] == train.label[shuffer[i]] then
				target = 1
			else
				target = -1
			end
			table.insert(batchData,input)
			table.insert(batchLabel,target)
		end
		local feval = function(x)
			if x~= parameters then
				parameters:copy(x)
			end
			
			model:zeroGradParameters()
			
			local f=0
			for i=1,#batchData do
			--print(#batchData[i])
				local output = model:forward(batchData[i])
				local err = criterion:forward(output,batchLabel[i])
				f=f+err
				
				local df_do = criterion:backward(output,batchLabel[i])
				model:backward(batchData[i],df_do)
			end
			
			gradParameters:div(#batchData)
			f=f/#batchData
			return f, gradParameters
		end
		optimState = {leraningRate=learningRate}
		optim.adam(feval,parameters,optimState)
	end
		
	time = sys.clock()-time
	time=time/train:size()
	
	print('=================> time to learn one smaple = ' .. (time*1000) .. 'ms')
	epoch =epoch+1
end			

5_Test 模型测试

这里我只是测试了模型了输出误差,其实评价该模型可以通过confusion矩阵实现,偷了个懒,后面可视化的时候也可以看到分类结果

require 'xlua'
function testing()
	print '======> testing:' 
	local time=sys.clock()
	local shuffer = torch.randperm(test:size())
	err=0
	for t=1,test:size() do
		xlua.progress(t,test:size())
		local input=torch.Tensor(2,1,28,28)
		input[1]=test.data[t]
		input[2]=test.data[shuffer[t]]
		if test.label[t]==test.label[shuffer[t]] then
			target = 1
		else
			target = -1
		end
		
		output=model:forward(input)
		f=criterion(output,target)
		
		err=err+f
	end
	
	time=sys.clock()-time
	time = time/test:size()
	print('=======> time to test each sample = ' .. (time*1000) .. 'ms')
	print('=======> average error is ' .. err/test:size())
end

6_visualization 结果可视化

这里我使用了itorch:Plot()的功能,折腾了很久ipython-notebook还是没装好,只是装好的itorch,参见官网

results={}
for i=1,10 do 
	table.insert(results,{x={},y={}})
end

for t=1,5000 do   -- 这里我们验证了5000个样本,如果绘制10000个样本的话实在太密集了
	local idx=test.label[t]
	local data=torch.Tensor(1,28,28)
	data[1]=test.data[t]
	local pos = cnn:forward(data)
	if idx==0 then 
		idx=10
	end
	
	table.insert(results[idx].x,pos[1])
	table.insert(results[idx].y,pos[2]) 
end

Plot =require'itorch.Plot'
plot=Plot():circle(results[1].x,results[1].y,'red','1'):draw()
plot:circle(results[2].x,results[2].y,'green','2'):redraw()
plot:circle(results[3].x,results[3].y,'blue','3'):redraw()
plot:circle(results[4].x,results[4].y,'black','4'):redraw()
plot:circle(results[5].x,results[5].y,'orange','5'):redraw()
plot:triangle(results[6].x,results[6].y,'red','6'):redraw()
plot:triangle(results[7].x,results[7].y,'green','7'):redraw()
plot:triangle(results[8].x,results[8].y,'blue','8'):redraw()
plot:triangle(results[9].x,results[9].y,'black','9'):redraw()
plot:triangle(results[10].x,results[10].y,'orange','10'):redraw()
plot:title('样本降维到2维时的分布'):redraw()
plot:xaxis('x1'):yaxis('x2'):redraw()
plot:legend(true)
plot:redraw()
plot:save('out.html') --只能保存成html之后再人工保存成png图像

这个模型有点类似于使用FDA找到两个主方向

7_doall 统一执行文件

dofile '1_data.lua'
dofile '2_model.lua'
dofile '3_loss.lua'
dofile '4_train.lua'
dofile '5_test.lua'

k=1
while k<30 do
	training()
	k=k+1
end
testing()
dofile '6_visualization.lua'

结果

enter description here

idx.png

参考资料:
Teaonly/easylearning.io/siamese_network
深度学习实验: Siamese network
facebook/iTorch

免责声明:文章转载自《torch 深度学习(5)》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇常用统计图的对比kafka 消息队列下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

Vue2.0进阶组件 短信倒计时组件

原本我想隔个几天再发文章,刚好今天项目上线,环境有问题,导致只有干等,刚好要为公司打造一套属于公司自己的一系列功能组件,这个使命就交给我了,大家也一直叫我来点干货,说实话我只是一个湿货,肚子里干一点就给你们出点货,那从今天开始不看岛国片系列教程视频,不但自撸,还教你撸............你懂的!!最强vue组件 写之前我只想说如果看到错别字,就别给我点...

python 识别登录验证码图片功能的实现代码(完整代码)

在编写自动化测试用例的时候,每次登录都需要输入验证码,后来想把让python自己识别图片里的验证码,不需要自己手动登陆,所以查了一下识别功能怎么实现,做一下笔记。 首选导入一些用到的库,re、Image、pytesseract、selenium、time import re # 用于正则 from PIL import Image # 用于打开图片和对图片...

数据可视化基础专题(44):NUMPY基础(9)数组操作(1)修改数组形状/翻转数组

1 修改数组形状 函数 描述 reshape 不改变数据的条件下修改形状 flat 数组元素迭代器 flatten 返回一份数组拷贝,对拷贝所做的修改不会影响原始数组 ravel 返回展开数组 numpy.reshape numpy.reshape 函数可以在不改变数据的条件下修改形状,格式如下: numpy.reshape(a...

Python源码.py文件打包为.whl文件

1 python源码.py文件打包  1.1 安装工具包 python源文件打包需要用到setuptools和wheel工具包:  1.2建立python项目源文件   建立一个名称为hello的项目包和setup.py文件    其中hello项目包中有一个hello_world.py文件和一个__init__.py文件 hello_world.py...

提高iOS开发效率的第三方框架等--不断更新中。。。

1. Mantle Mantle 让我们能简化 Cocoa 和 Cocoa Touch 应用的 model 层。简单点说,程序中经常要进行网络请求,请求到得一般是 json 字符串,我们一般会建一个 Model 类来存放这些数据。这就要求我们编写一系列的序列化代码,来把 json 转换为 Model 。这很费时间,容易错,不容易修改。 Mantle 很好...

snprintf()函数使用方法

众所周知,sprintf不能检查目标字符串的长度,可能造成众多安全问题,所以都会推荐使用snprintf. 自从snprintf代替了sprintf,相信大家对snprintf的使用都不会少,函数定义如下: int snprintf(char*str, size_t size,constchar*format, ...); 函数说明: 最多从源串中拷贝s...