网上关于tensorflow模型文件ckpt格式转pb文件的帖子很多,本人几乎尝试了所有方法,最后终于成功了,现总结如下。方法无外乎下面两种:
- 使用tensorflow.python.tools.freeze_graph.freeze_graph
- 使用graph_util.convert_variables_to_constants
1、tensorflow模型的文件解读
使用tensorflow训练好的模型会自动保存为四个文件,如下
checkpoint:记录近几次训练好的模型结果(名称)。
xxx.data-00000-of-00001: 模型的所有变量的值(weights, biases, placeholders,gradients, hyper-parameters etc),也就是模型训练好参数和其他值。
xxx.index :模型的元数据,二进制或者其他格式,不可直接查看 。是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和一些辅助数据等。
xxx.meta:模型的meta数据 ,二进制或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息,通俗地讲就是神经网络的网络结构。
2、最常见的ckpt转pb文件的方法
2、ckpt转pb文件(freeze_graph.freeze_graph)
此种方法尝试成功,虽然不知道输出节点名,但是只要模型代码还在就可以操作,直接上代码。
import tensorflow as tf import os from tensorflow.python.tools import freeze_graph from model import network # network是你们自己定义的模型结构(代码结构) # egs: # def network(input): # return tf.layers.softmax(input) model_path = "model.ckpt-0000" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前 def main(): tf.reset_default_graph() # 设置输入网络的数据维度,根据训练时的模型输入数据的维度自行修改 input_node = tf.placeholder(tf.float32, shape=(None, None, 200)) output_node = network(input_node) # 神经网络的输出 # 设置输出数据类型(特别注意,这里必须要跟输出网络参数的数据格式保持一致,不然会导致模型预测 精度或者预测能力的丢失)以及重新定义输出节点的名字(这样在后面保存pb文件以及之后使用pb文件时直接使用重新定义的节点名字即可) flow = tf.cast(output_node , tf.float16, 'the_outputs') saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, model_path) #保存模型图(结构),为一个json文件 tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb') #将模型参数与模型图结合,并保存为pb文件 freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'the_outputs','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "") print("done") if __name__ == '__main__': main()
2、ckpt转pb文件(graph_util.convert_variables_to_constants)
没有成功,因为不知道输出节点的名字,使用该方法保存后的pb文件只有几十k,无法使用,写在这里主要是为了总结。直接上代码,代码里面没有的库(函数),按提示自行import。
def freeze_graph(input_checkpoint,output_graph): ''' :param input_checkpoint: :param output_graph: PB模型保存路径 :return: ''' # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点 output_node_names = "InceptionV3/Logits/SpatialSqueeze" saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) graph = tf.get_default_graph() # 获得默认的图 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图 with tf.Session() as sess: saver.restore(sess, input_checkpoint) #恢复图并得到数据 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定 sess=sess, input_graph_def=input_graph_def,# 等于:sess.graph_def output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型 f.write(output_graph_def.SerializeToString()) #序列化输出 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点 # for op in graph.get_operations(): # print(op.name, op.values()) if __name__ == '__main__': # 输入ckpt模型路径 input_checkpoint='models/model.ckpt-10000' # 输出pb模型的路径 out_pb_path="models/pb/frozen_model.pb" # 调用freeze_graph将ckpt转为pb freeze_graph(input_checkpoint,out_pb_path)
参考链接:
https://www.jb51.net/article/185209.htm
https://www.jb51.net/article/185206.htm
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
更新日志
- 小骆驼-《草原狼2(蓝光CD)》[原抓WAV+CUE]
- 群星《欢迎来到我身边 电影原声专辑》[320K/MP3][105.02MB]
- 群星《欢迎来到我身边 电影原声专辑》[FLAC/分轨][480.9MB]
- 雷婷《梦里蓝天HQⅡ》 2023头版限量编号低速原抓[WAV+CUE][463M]
- 群星《2024好听新歌42》AI调整音效【WAV分轨】
- 王思雨-《思念陪着鸿雁飞》WAV
- 王思雨《喜马拉雅HQ》头版限量编号[WAV+CUE]
- 李健《无时无刻》[WAV+CUE][590M]
- 陈奕迅《酝酿》[WAV分轨][502M]
- 卓依婷《化蝶》2CD[WAV+CUE][1.1G]
- 群星《吉他王(黑胶CD)》[WAV+CUE]
- 齐秦《穿乐(穿越)》[WAV+CUE]
- 发烧珍品《数位CD音响测试-动向效果(九)》【WAV+CUE】
- 邝美云《邝美云精装歌集》[DSF][1.6G]
- 吕方《爱一回伤一回》[WAV+CUE][454M]