tf.Variable

时间:2023-03-09 17:55:03
tf.Variable

tf.Variable

__init__(
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None
)

功能说明:

维护图在执行过程中的状态信息,例如神经网络权重值的变化。

参数列表:

参数名 类型 说明
initial_value 张量 Variable 类的初始值,这个变量必须指定 shape 信息,否则后面 validate_shape 需设为 False
trainable Boolean 是否把变量添加到 collection GraphKeys.TRAINABLE_VARIABLES 中(collection 是一种全局存储,不受变量名生存空间影响,一处保存,到处可取)
collections Graph collections 全局存储,默认是 GraphKeys.GLOBAL_VARIABLES
validate_shape Boolean 是否允许被未知维度的 initial_value 初始化
caching_device string 指明哪个 device 用来缓存变量
name string 变量名
dtype dtype 如果被设置,初始化的值就会按照这个类型初始化
expected_shape TensorShape 要是设置了,那么初始的值会是这种维度

示例代码:

import tensorflow as tf
initial= tf.truncated_normal(shape=[10,10],mean=0,stddev=1)
W=tf.Variable(initial)
list=[[1.,1.],[2.,2.]]
X=tf.Variable(list,dtype=tf.float32)
ini_op=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(ini_op)
print(sess.run(W[:2,:2])) op=W[:2,:2].assign(22.*tf.ones((2,2)))
print(sess.run(op))
print (W.eval()) #Usage with the default session
print ("#####################(6)#############")
print (W.dtype)
print (sess.run(W.initial_value))
print (sess.run(W.op))
print (W.shape)
print ("###################(7)###############")
print (sess.run(X))