pytorch 中的 torch.nn.RNN 的参数

时间:2024-03-14 08:09:38

1、定义RNN的网络结构的参数(类似于CNN中定义 in_channel,out_channel,kernel_size等等)

         input_size   输入x的特征大小(以mnist图像为例,特征大小为28*28 = 784)
         hidden_size   隐藏层h的特征大小
         num_layers    循环层的数量(RNN中重复的部分)
         nonlinearity   **函数 默认为tanh,可以设置为relu
         bias   是否设置偏置,默认为True
         batch_first   默认为false, 设置为True之后,输入输出为(batch_size, seq_len, input_size)
         dropout   默认为0
         bidirectional   默认为False,True设置为RNN为双向

【注】下图红色的部分为num_layers的个数(num_layers = 2 )

pytorch 中的 torch.nn.RNN 的参数

 

 

2、输入RNN网络与输出的参数

(1)输入:input:(seq_len,batch_size,input_size)    #(序列长度,batch_size,特征大小(数量))

                    h0:(num_layers*directions,batch_size,hidden_size)

(2)输出:hn:(num_layers*directions,batch_size,hidden_size)

                    output:(seq_len,batch_size,hidden_size*directions)

【注】bidirectional为Ture,则 directions=2,否则 directions=1 。

 

RNN的一个解析:https://www.jianshu.com/p/298116084ec7