pytorch之 RNN 参数解释

摘要:
上次通过pytorch实现了RNN模型,简易的完成了使用RNN完成mnist的手写数字识别,但是里面的参数有点不了解,所以对问题进行总结归纳来解决。当然这是是对于RNN某一个节点而言的,那么如何规定RNN的节点个数呢?output的size,那么就是[batch_size,seq_len,n_hidden],对于分类任务如果要取得最后一个output,只需添加下标[:,-1,:]看图找答案:hn就是RNN的最后一个隐含状态,output就是RNN最终得到的结果。

上次通过pytorch实现了RNN模型,简易的完成了使用RNN完成mnist的手写数字识别,但是里面的参数有点不了解,所以对问题进行总结归纳来解决。

总述:
第一次看到这个函数时,脑袋有点懵,总结了下总共有五个问题:

1.这个input_size是啥?要输入啥?feature num又是啥?

2.这个hidden_size是啥?要输入啥?feature num又是啥?

3.不是说RNN会有很多个节点连在一起的吗?这怎么定义连接的节点数呢?

4.num_layer中说的stack是怎么stack的?

5.怎么输出会有两个东西呀output,hn

pytorch中RNN的一些参数,并且解决以上五个问题

1.Pytorch中的RNN

pytorch之 RNN 参数解释第1张

2.input_size是啥?
说白了input_size无非就是你输入RNN的维度,比如说NLP中你需要把一个单词输入到RNN中,这个单词的编码是300维的,那么这个input_size就是300.这里的input_size其实就是规定了你的输入变量的维度。用f(wX+b)来类比的话,这里输入的就是X的维度。

3.hidden_size是啥?
和最简单的BP网络一样的,每个RNN的节点实际上就是一个BP嘛,包含输入层,隐含层,输出层。这里的hidden_size呢,你可以看做是隐含层中,隐含节点的个数。

pytorch之 RNN 参数解释第2张

那个输入层的三个节点代表输入维度为3,也就是input_size=3,然后这个hidden_size就是5了。当然这是是对于RNN某一个节点而言的,那么如何规定RNN的节点个数呢?

4.如何规定节点个数?

事实上,节点个数并不需要规定,你的输入序列是这样子的,[x1,x2,x3,x4,x5],那么input_size呢就是你的xi的维度,而你的RNN的节点数呢,就是由你的序列长度决定的,在这里我们的序列长度是5,所以会有5个节点。那么问题来了,我咋知道你的序列长度呢?pytorch里面不是只有input_size的参数吗?实际上,你声明RNN是这样声明的

self.encoder = nn.RNN(input_size=300,hidden_size=128,dropout=0.5)
但是你用的时候;

output,hn = self.encoder(encoder_input,encoder_hidden)
你会把你的数据丢进去吧,也就是你把encoder_input这一整个序列丢进去了,那么序列长度他不就知道了?

5.num_layers是啥?
一开始你是不是以为这个就是RNN的节点数呀,hhh,然而并不是:),如果num_layer=2的话,表示两个RNN堆叠在一起。那么怎么堆叠的呢?

如果是num_layer==1的话:

pytorch之 RNN 参数解释第3张

如果num_layer==2的话:

pytorch之 RNN 参数解释第4张

ok了~最后再来看看最后一个问题

6.hn,output分别是啥?

hidden的输出size为[ num_layers* num_directions, batch_size, n_hidden].

说白了,hidden就是每个方向,每个层的 隐藏单元的输出,所以是n_hidden个。

output的size(如果RNN设定的batch_first=True),那么就是[batch_size,seq_len,n_hidden],对于分类任务如果要取得最后一个output,只需添加下标 [ :,-1,:]

看图找答案:

pytorch之 RNN 参数解释第5张

hn就是RNN的最后一个隐含状态,output就是RNN最终得到的结果。

免责声明:文章转载自《pytorch之 RNN 参数解释》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇Oracle Dedicated server 和 Shared server(专用模式 和 共享模式) 说明03.pandas数据DataFrame下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

JS实现刷新iframe的方法

<iframe src="http://t.zoukankan.com/1.htm" name="ifrmname" id="ifrmid"></iframe> 方案一:用iframe的name属性定位 <input type="button" name="Button" value="Button"onclick="docu...

layer弹窗在IOS上,被软键盘挤到上边的解决方法

就像这种情况,经过多番请教跟尝试,找到一个能解决这个问题的方法,但可能有点笨重。就是在当前弹框里,设置offset的值,里边的值可以随意写,然后再下边给弹框追加一个样式即可。 <!DOCTYPE html> <html> <head> <meta charset="UTF-8"&g...

【转载】Iptables详解

参考链接:http://blog.csdn.net/reyleon/article/details/12976341 Iptabels是与Linux内核集成的包过滤防火墙系统,几乎所有的linux发行版本都会包含Iptables的功能。如果 Linux 系统连接到因特网或 LAN、服务器或连接 LAN 和因特网的代理服务器, 则Iptables有利于在 L...

input 只能输入数字、字母、汉字等

1.文本框只能输入数字代码(小数点也不能输入) <input onkeyup="this.value=this.value.replace(/D/g,'')"onafterpaste="this.value=this.value.replace(/D/g,'')" /> 2.只能输入数字,能输小数点. <input onkeyup="if...

(二)vue数据处理

1:计算属性和监视   计算属性 1) 在 computed 属性对象中定义计算属性的方法 2) 在页面中使用{{方法名}}来显示计算的结果  2:监视属性 1) 通过通过 vm 对象的$watch()或 watch 配置来监视指定的属性 2) 当属性变化时, 回调函数自动调用, 在函数内部进行计算 3: 计算属性高级 1) 通过 getter/sette...

vue 文件上传

  学习参考地址: http://www.cnblogs.com/zhengweijie/p/6922808.html#3920491 依赖js文件: http://files.cnblogs.com/files/zhengweijie/jquery.form.rar HTML 文本内容: <template>   <div id="ac...