TF RNN: an example of using scope.reuse_variables() to tell TF to reuse the RNN's parameters - Jason niu

Date: 2021-12-05 01:58:11
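In TensorFlow 1.x, `tf.get_variable` (which `_weight_variable` and `_bias_variable` call in the script below) looks a variable up by its full scope name. With reuse off it must create a new variable and refuses an existing name; with reuse on it must return an existing variable and refuses a missing one. The pure-Python toy below mimics only that contract. The `VariableStore` class and `build_model` function are hypothetical and are not TensorFlow code. The toy shows why building the test model in a different scope gives it fresh weights, while calling `reuse_variables()` inside the same scope hands it the training model's weights.

```python
class VariableStore:
    """Hypothetical toy store keyed by full scope name (not TensorFlow's implementation)."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, init, reuse):
        if reuse:
            # reuse on: the variable must already exist; return the very same object
            if name not in self._vars:
                raise ValueError("Variable %s does not exist" % name)
            return self._vars[name]
        # reuse off: the name must be new; a second creation is an error
        if name in self._vars:
            raise ValueError("Variable %s already exists" % name)
        self._vars[name] = [init]  # a list, so `is` checks reveal sharing
        return self._vars[name]

    def names(self):
        return sorted(self._vars)


def build_model(store, scope, reuse=False):
    """Hypothetical stand-in for RNN(...): creates or fetches two input-layer parameters."""
    w = store.get_variable(scope + "/input_layer/weights", 0.5, reuse)
    b = store.get_variable(scope + "/input_layer/biases", 0.1, reuse)
    return w, b


store = VariableStore()

# "Wrong" way: two different scopes give two independent parameter sets
train1 = build_model(store, "train_rnn")
test1 = build_model(store, "test_rnn")

# "Right" way: one scope; the second model is built after reuse is switched on
train2 = build_model(store, "rnn")
test2 = build_model(store, "rnn", reuse=True)

print(train1[0] is test1[0])  # False: separate weights
print(train2[0] is test2[0])  # True: shared weights
print(len(store.names()))     # 6
```

Building `"rnn"` a second time without `reuse=True` raises `ValueError` in this toy. TF 1.x behaves the same way, which is why the script calls `scope.reuse_variables()` before constructing the test RNN.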
```python
# 22 scope (name_scope / variable_scope)
# Written for TensorFlow 1.x (tf.placeholder, tf.Session, tf.contrib); it does not run on TF 2.x as-is.
from __future__ import print_function  # must be the first statement in the file

import tensorflow as tf


class TrainConfig:
    batch_size = 20
    time_steps = 20
    input_size = 10
    output_size = 2
    cell_size = 11
    learning_rate = 0.01
    is_training = True   # only the training model builds an optimizer


class TestConfig(TrainConfig):
    time_steps = 1       # only the sequence length changes; every weight shape stays the same
    is_training = False


class RNN(object):

    def __init__(self, config):
        self._batch_size = config.batch_size
        self._time_steps = config.time_steps
        self._input_size = config.input_size
        self._output_size = config.output_size
        self._cell_size = config.cell_size
        self._lr = config.learning_rate
        self._is_training = config.is_training
        self._built_RNN()

    def _built_RNN(self):
        with tf.variable_scope('inputs'):
            self._xs = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._input_size], name='xs')
            self._ys = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._output_size], name='ys')
        with tf.name_scope('RNN'):
            with tf.variable_scope('input_layer'):
                l_in_x = tf.reshape(self._xs, [-1, self._input_size], name='2_2D')  # (batch*n_step, in_size)
                # Ws (in_size, cell_size)
                Wi = self._weight_variable([self._input_size, self._cell_size])
                print(Wi.name)  # prints the same name for train and test when parameters are shared
                # bs (cell_size, )
                bi = self._bias_variable([self._cell_size, ])
                # l_in_y = (batch * n_steps, cell_size)
                with tf.name_scope('Wx_plus_b'):
                    l_in_y = tf.matmul(l_in_x, Wi) + bi
                l_in_y = tf.reshape(l_in_y, [-1, self._time_steps, self._cell_size], name='2_3D')

            with tf.variable_scope('cell'):
                cell = tf.contrib.rnn.BasicLSTMCell(self._cell_size)
                with tf.name_scope('initial_state'):
                    self._cell_initial_state = cell.zero_state(self._batch_size, dtype=tf.float32)

                self.cell_outputs = []
                cell_state = self._cell_initial_state
                for t in range(self._time_steps):
                    # after step 0, reuse the LSTM weights created at step 0
                    if t > 0:
                        tf.get_variable_scope().reuse_variables()
                    cell_output, cell_state = cell(l_in_y[:, t, :], cell_state)
                    self.cell_outputs.append(cell_output)
                self._cell_final_state = cell_state

            with tf.variable_scope('output_layer'):
                # cell_outputs_reshaped (BATCH*TIME_STEP, CELL_SIZE)
                cell_outputs_reshaped = tf.reshape(tf.concat(self.cell_outputs, 1), [-1, self._cell_size])
                Wo = self._weight_variable((self._cell_size, self._output_size))
                bo = self._bias_variable((self._output_size,))
                product = tf.matmul(cell_outputs_reshaped, Wo) + bo
                # _pred shape (batch*time_step, output_size)
                self._pred = tf.nn.relu(product)  # for displacement

        with tf.name_scope('cost'):
            _pred = tf.reshape(self._pred, [self._batch_size, self._time_steps, self._output_size])
            mse = self.ms_error(_pred, self._ys)
            mse_ave_across_batch = tf.reduce_mean(mse, 0)
            mse_sum_across_time = tf.reduce_sum(mse_ave_across_batch, 0)
            self._cost = mse_sum_across_time  # shape (output_size,)
            self._cost_ave_time = self._cost / self._time_steps

        # Only the training model needs an optimizer. Creating Adam's slot variables
        # inside the reused scope can fail in later TF 1.x releases
        # ("Variable ... does not exist"), and the test model never trains anyway.
        if self._is_training:
            with tf.variable_scope('train'):
                self._lr = tf.convert_to_tensor(self._lr)
                self.train_op = tf.train.AdamOptimizer(self._lr).minimize(self._cost)

    @staticmethod
    def ms_error(y_target, y_pre):
        return tf.square(tf.subtract(y_target, y_pre))

    @staticmethod
    def _weight_variable(shape, name='weights'):
        initializer = tf.random_normal_initializer(mean=0., stddev=0.5, )
        return tf.get_variable(shape=shape, initializer=initializer, name=name)

    @staticmethod
    def _bias_variable(shape, name='biases'):
        initializer = tf.constant_initializer(0.1)
        return tf.get_variable(name=name, shape=shape, initializer=initializer)


if __name__ == '__main__':
    train_config = TrainConfig()  # define train_config
    test_config = TestConfig()

    # The wrong way to reuse the train RNN's parameters (two scopes, two sets of weights):
    # with tf.variable_scope('train_rnn'):
    #     train_rnn1 = RNN(train_config)
    # with tf.variable_scope('test_rnn'):
    #     test_rnn1 = RNN(test_config)

    # The right way to reuse the train RNN's parameters.
    # Goal: the train RNN creates the parameters; sharing one variable_scope
    # then lets the test RNN pick up exactly the same parameters.
    with tf.variable_scope('rnn') as scope:
        sess = tf.Session()
        train_rnn2 = RNN(train_config)
        scope.reuse_variables()  # tell TF we want to reuse the RNN's parameters
        test_rnn2 = RNN(test_config)
        # tf.initialize_all_variables() is deprecated (after 2017-03-02) in TF >= 0.12;
        # this script already needs TF >= 1.0, so use the replacement directly
        init = tf.global_variables_initializer()
        sess.run(init)
```
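The `for t in range(self._time_steps)` loop unrolls the cell by hand, and `reuse_variables()` after step 0 makes every time step use the same LSTM weights. Because no weight shape depends on `time_steps`, the 20-step training graph and the 1-step test graph can share one parameter set. The NumPy sketch below illustrates this under stated simplifications: it swaps `BasicLSTMCell` for a plain tanh RNN cell, and `make_params` and `unroll` are hypothetical helpers. Sharing is made explicit by passing one `params` dict, where the TF script gets the same effect from `tf.get_variable` returning the existing variables.

```python
import numpy as np


def make_params(input_size, cell_size, seed=0):
    """Hypothetical helper: shapes depend only on layer sizes, never on time_steps."""
    rng = np.random.default_rng(seed)
    return {
        "W_x": rng.normal(0.0, 0.5, (input_size, cell_size)),
        "W_h": rng.normal(0.0, 0.5, (cell_size, cell_size)),
        "b": np.full(cell_size, 0.1),
    }


def unroll(params, xs, h0):
    """Run a simple tanh RNN cell (stand-in for the LSTM) over xs of shape (batch, time, input)."""
    h = h0
    outputs = []
    for t in range(xs.shape[1]):
        # every step reads the SAME params: weight sharing across time
        h = np.tanh(xs[:, t, :] @ params["W_x"] + h @ params["W_h"] + params["b"])
        outputs.append(h)
    return np.stack(outputs, axis=1), h


batch, input_size, cell_size = 20, 10, 11
params = make_params(input_size, cell_size)  # created once, by the "train" model
h0 = np.zeros((batch, cell_size))

xs_train = np.random.default_rng(1).normal(size=(batch, 20, input_size))
train_out, _ = unroll(params, xs_train, h0)           # time_steps = 20
test_out, _ = unroll(params, xs_train[:, :1, :], h0)  # time_steps = 1, same params

print(train_out.shape, test_out.shape)                     # (20, 20, 11) (20, 1, 11)
print(np.allclose(train_out[:, 0, :], test_out[:, 0, :]))  # True
```

Given the same input and initial state, the one-step "test" run reproduces the first step of the 20-step "train" run exactly. This identical result is what sharing the parameters is meant to guarantee.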
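The `output_layer` and `cost` scopes depend on two details that are easy to miss. First, `tf.concat(self.cell_outputs, 1)` followed by a reshape to `[-1, cell_size]` lays the rows out batch-major (row `b * time_steps + t` is batch item `b` at step `t`), which is the order that the later reshape to `(batch, time_steps, output_size)` expects. Second, `self._cost` is a vector of shape `(output_size,)`, not a scalar. In TF 1.x, `minimize()` goes through `tf.gradients`, which differentiates the sum of a non-scalar target, so the optimizer effectively minimizes the sum of those entries. The NumPy sketch below mirrors both steps with small, made-up sizes; it reproduces the TF ops rather than calling them.

```python
import numpy as np

batch, time_steps, cell_size, output_size = 2, 3, 4, 2
rng = np.random.default_rng(0)

# One (batch, cell_size) output per time step, as collected in self.cell_outputs
cell_outputs = [rng.normal(size=(batch, cell_size)) for _ in range(time_steps)]

# Mirrors tf.reshape(tf.concat(cell_outputs, 1), [-1, cell_size])
flat = np.concatenate(cell_outputs, axis=1).reshape(-1, cell_size)

# Same layout as stacking along a time axis and flattening batch-major
stacked = np.stack(cell_outputs, axis=1)  # (batch, time, cell)
same_layout = np.array_equal(flat, stacked.reshape(-1, cell_size))

# Mirrors the 'cost' scope: squared error, mean over batch, sum over time
pred = rng.normal(size=(batch, time_steps, output_size))
ys = rng.normal(size=(batch, time_steps, output_size))
mse = np.square(ys - pred)           # (batch, time, output)
cost = mse.mean(axis=0).sum(axis=0)  # (output,): still a vector
cost_ave_time = cost / time_steps

print(flat.shape, cost.shape, same_layout)  # (6, 4) (2,) True
```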
