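运行以下代码,程序会在 pdb 断点处停下;随后单步进入 `~/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py` 和 `~/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell.py` 追踪调试(若 `rnn_cell.py` 中只有导入语句,请继续进入它所导入的同目录下的 `rnn_cell_impl.py`)。
调试过程中,可在 pdb 提示符下执行 `import tensorflow as tf`,再用 `tf.Session().run(tensor)` 打印张量(或已初始化变量)的取值。
通过这种方式查看 `BasicRNNCell` 和 `dynamic_rnn` 的具体实现。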
# -*- coding: utf-8 -*-
"""Step-through debugging script for learning TensorFlow's RNN implementation.

Run the script; when pdb stops, step (``s``) into ``BasicRNNCell`` and
``tf.nn.dynamic_rnn`` to trace how they are built. Uses the TensorFlow 1.x
graph/session API (``tf.nn.rnn_cell``, ``tf.Session``).
"""
__author = "buyizhiyou"
__date = "2017-11-20"

import pdb

import numpy as np
import tensorflow as tf

# Input dimensions. X is batch-major because dynamic_rnn below is called with
# time_major=False, i.e. X has shape (batch_size, time_steps, input_dim).
BATCH_SIZE = 2
TIME_STEPS = 3  # sequence length
INPUT_DIM = 4  # size of the feature vector at each time step
NUM_UNITS = 10  # num_units passed to the cell

X = tf.random_normal(
    shape=[BATCH_SIZE, TIME_STEPS, INPUT_DIM], dtype=tf.float32
)

# Stop here, before the cell and the RNN graph are built, so the constructor
# and dynamic_rnn calls below can be stepped into.
pdb.set_trace()

# Can be swapped for tf.nn.rnn_cell.GRUCell, LSTMCell or BasicLSTMCell.
cell = tf.nn.rnn_cell.BasicRNNCell(NUM_UNITS)
# The initial state's batch size must match the batch dimension of X.
state = cell.zero_state(BATCH_SIZE, tf.float32)
output, state = tf.nn.dynamic_rnn(
    cell, X, initial_state=state, time_major=False
)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Static (graph-construction-time) shape; nothing is evaluated here.
    print(output.get_shape())
    # Each sess.run call draws a fresh sample for X from tf.random_normal,
    # so values fetched by separate run calls come from different inputs.
    print(sess.run(state))