tensorflow实战google深度学习框架阅读笔记——保存,读取model(ckpt文件)

时间:2024-04-10 19:50:56

最近在阅读《tensorflow实战google深度学习框架》,对里面讲到的内容,重点部分做下摘抄和笔记,以备后面查阅。部分内容为本人个人理解,如果错误,请指正,如果侵权,请联系删除,谢谢。转载请注明出处,谢谢。


将模型保存为ckpt文件

    首先,创建一个saver对象:saver=tf.train.Saver(max_to_keep = 5) 注意,这句话要写在创建graph的代码中,在图创建完成并且初始化variable后,再调用。max_to_keep代表需要保存的模型的个数,默认为5,如果只需要保存最新的模型,设置为1即可

    然后,saver.save(sess,'ckpt/mnist.ckpt',global_step=step,write_meta_graph=False)  ,其中,第一个参数为sess,第二个参数设定保存的路径和名字,这里可以用os.path.join(opts.save_path,"model.ckpt")来组合,第一个为路径,第二个为名字。第三个参数将训练的次数作为后缀加入到模型名字中去(改参数可以不加),第四个参数为是否保存.meta,meta中保存的是模型的图,不需要每次都保存,所以可以设置为false,默认为true,这里可以这么写:

saver.save(sess, 'model/model.ckpt', global_step=step, write_meta_graph=False)

if not os.path.exists(' model/model.meta'):

   saver.export_meta_graph(metagraph_filename)


从ckpt文件中读取参数,恢复模型

有两种方法:

方法一:不恢复模型,直接从ckpt文件中读取参数:

tensorflow实战google深度学习框架阅读笔记——保存,读取model(ckpt文件)

方法二:恢复模型,然后调用sess.run获取参数:

tensorflow实战google深度学习框架阅读笔记——保存,读取model(ckpt文件)

注意:这里面,tf.train.import_meta_graph是从meta文件中读取图,如果图中存在自定义的op,是行不通的,所以【4】这个操作是加载自定义op

另外,这里参数的名字可以通过方式一获取,然后这里的要用  参数名:0 的方式获取参数。