建立与读取.pb文件

时间:2021-04-13 13:51:18
#coding=utf-8
import tensorflow as tf
from tensorflow.python.framework import graph_util

# Build a minimal graph computing `mul = x * v1`, freeze it, and save it to ./model.pb.
x = tf.placeholder(shape=[1], dtype=tf.float32, name='x')

# NOTE: the misspelled name is kept so anything referring to it keeps working.
varibale_1 = tf.get_variable('v1', [1], tf.float32, initializer=tf.random_normal_initializer(mean=1))

output = tf.multiply(x, varibale_1, name='mul')

initial_op = tf.global_variables_initializer()
with tf.Session() as sess:
    # Variables must be initialized before their values can be frozen.
    sess.run(initial_op)

    # Take the definition of the default graph.
    graph_def = tf.get_default_graph().as_graph_def()

    # Convert the variables needed to compute 'mul' into constants.
    out_graph = graph_util.convert_variables_to_constants(sess, graph_def, ['mul'])

    print(sess.run(output, {x: [5]}))
    print(sess.run(varibale_1))

    # Serialize the frozen graph definition and write it to the .pb file.
    with tf.gfile.GFile('./model.pb', 'wb') as f:
        f.write(out_graph.SerializeToString())
结果:
(原文此处为运行结果截图。程序依次打印 x=[5] 时 mul 的输出,即 5 与 v1 的乘积,以及变量 v1 的取值。)


读取.pb文件:

#coding=utf-8
import tensorflow as tf
from tensorflow.python.platform import gfile

# Load the graph stored in model.pb into the default graph and inspect it.

# A node created before the import; it appears in the graph printed below
# next to the imported nodes.
k = tf.constant([1, 2, 3], dtype=tf.float32)


with tf.Session() as sess:
    model_filename = 'model.pb'
    # Open the .pb file. GFile replaces the deprecated FastGFile.
    with gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()  # create a graph definition object
        print(graph_def)  # printed before parsing, i.e. before the file is loaded
        # Fill the graph definition with the contents of the .pb file.
        graph_def.ParseFromString(f.read())

    # Import the graph definition and fetch the tensor of interest.
    # import_graph_def returns a list matching return_elements, so the single
    # requested tensor is unpacked; calling `.name` on the list itself fails.
    v1, = tf.import_graph_def(graph_def, return_elements=['v1:0'])
    print(tf.get_default_graph().as_graph_def())
    # This lookup uses the 'import/' prefix on the imported tensor's name.
    print(tf.get_default_graph().get_tensor_by_name('import/x:0'))
    print(v1.name)