那棵树看起来生气了
TensorFlow模型加载,用于恢复保存的模型
09/17
前言
在 TensorFlow 的使用中,常常需要保存训练好的模型或中间模型,下次使用时直接加载,即可恢复到上次训练的状态。
TensorFlow 保存模型的格式有很多种,下面介绍生产过程中常用的几种模型的保存与加载方法。
正文
加载 .pb 模型
# -*- coding: utf-8 -*-
import tensorflow as tf
import argparse
def load_frozenpb(file_name):
    """Import a frozen binary GraphDef (.pb) and print every node.

    The graph is imported into the current default graph with no name
    prefix (``name=""``), so node names are kept exactly as stored in the file.

    Args:
        file_name: path to a binary-serialized ``GraphDef`` file.

    Returns:
        A ``tf.Session`` on the default graph that now holds the imported
        nodes. The caller owns it and should close it when finished.
    """
    with tf.gfile.GFile(file_name, "rb") as g:
        graph_def = tf.GraphDef()
        # Binary protobuf: parse the raw bytes directly.
        graph_def.ParseFromString(g.read())
    tf.import_graph_def(graph_def, name="")
    sess = tf.Session()
    # Print (op type, node name) for each node of the session's graph.
    for node in sess.graph.as_graph_def().node:
        print(node.op, node.name)
    # Return the session instead of dropping it, so it can be used to run
    # the graph or be closed explicitly.
    return sess
if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--name", type=str, default="resnet", help="model folder name")
    parser.add_argument("--home", type=str, default="/home/dyb/save_model/",
                        help="root directory containing the model folders")
    args = parser.parse_args()
    # Expected layout: <home>/<name>/<name>.pb. os.path.join also works when
    # --home is given without a trailing slash.
    name_path = os.path.join(args.home, args.name, "{}.pb".format(args.name))
    load_frozenpb(name_path)
结果如下:
(u'Conv2D', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv1/Conv2D')
(u'Relu', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv1/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv2/Conv2D')
(u'Relu', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv2/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv3/Conv2D')
(u'Add', u'resnet_v1_50/block4/unit_2/bottleneck_v1/add')
(u'Relu', u'resnet_v1_50/block4/unit_2/bottleneck_v1/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv1/Conv2D')
(u'Relu', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv1/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv2/Conv2D')
(u'Relu', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv2/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv3/Conv2D')
(u'Add', u'resnet_v1_50/block4/unit_3/bottleneck_v1/add')
(u'Relu', u'resnet_v1_50/block4/unit_3/bottleneck_v1/Relu')
(u'Mean', u'resnet_v1_50/pool5')
(u'Conv2D', u'resnet_v1_50/logits/Conv2D')
(u'BiasAdd', u'resnet_v1_50/logits/BiasAdd')
(u'Squeeze', u'resnet_v1_50/SpatialSqueeze')
(u'Shape', u'resnet_v1_50/predictions/Shape')
(u'Reshape', u'resnet_v1_50/predictions/Reshape')
(u'Softmax', u'resnet_v1_50/predictions/Softmax')
(u'Reshape', u'resnet_v1_50/predictions/Reshape_1')
保存pb模型
# Freeze the graph held by `sess`: variables are converted to constants with
# their current values. "final_result" is this example's output node name --
# replace it with your model's output node(s).
gd = sess.graph.as_graph_def()
output_graph_def = tf.graph_util.convert_variables_to_constants(
    sess, gd, output_node_names=["final_result"])
# tf.gfile.FastGFile is deprecated; tf.gfile.GFile is the supported API.
with tf.gfile.GFile("test_new.pb", mode="wb") as f:
    f.write(output_graph_def.SerializeToString())
print("save model success.")
# Alternative way to write the same binary GraphDef:
# tf.train.write_graph(output_graph_def, './', 'test_new.pb', as_text=False)
加载pbtxt模型
# -*- coding: utf-8 -*-
# Note the difference from the .pb loader: the overall flow is nearly the
# same, but the key step differs -- the GraphDef is parsed from protobuf
# *text* format with text_format.Merge instead of GraphDef.ParseFromString.
import tensorflow as tf
from google.protobuf import text_format

# Start from an empty default graph so the import below does not clash with
# nodes created earlier in the process.
tf.reset_default_graph()


def load_frozenpb(file_name):
    """Import a text-format GraphDef (.pbtxt) into the default graph.

    Args:
        file_name: path to a GraphDef serialized in protobuf text format.

    Returns:
        A new ``tf.Session`` on the default graph that contains the imported
        nodes.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(file_name, "rb") as handle:
        text_format.Merge(handle.read(), graph_def)
    tf.import_graph_def(graph_def, name="")
    return tf.Session()


sess = load_frozenpb("./test.pbtxt")
保存pbtxt模型
# Write the GraphDef in human-readable protobuf text format. Use a .pbtxt
# extension: with as_text=True the file is text, not a binary .pb, and a .pb
# name would mislead binary loaders that call ParseFromString on it.
tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True)
加载ckpt
import tensorflow as tf
from google.protobuf import text_format

# Rebuild the graph structure from a text-format GraphDef, then restore the
# variable values from the latest checkpoint in the current directory.
with tf.gfile.GFile("graph.pbtxt", "r") as f:
    graph_def = tf.GraphDef()
    text_format.Merge(f.read(), graph_def)
tf.import_graph_def(graph_def, name="")
graph = tf.get_default_graph()

# allow_soft_placement lets ops pinned to a device that does not exist on
# this machine fall back to an available one.
config = tf.ConfigProto(allow_soft_placement=True)
session = tf.Session(config=config)

# Collect the variable nodes of the imported graph so a Saver can be built
# without the original Python Variable objects. "Variable" is the legacy
# ref-variable op, "VariableV2" the current one.
# NOTE(review): resource variables (VarHandleOp) are not collected here.
variable_ops = ("Variable", "VariableV2")
variables_list = [
    graph.get_tensor_by_name(node.name + ":0")
    for node in graph_def.node
    if node.op in variable_ops
]

ckpt = tf.train.get_checkpoint_state("./")
if ckpt is None or not ckpt.model_checkpoint_path:
    # get_checkpoint_state returns None when there is no "checkpoint" file;
    # fail with a clear message instead of an AttributeError further down.
    raise IOError("No checkpoint found in ./")
# Alternative: rebuild graph and Saver from the .meta file instead of
# graph.pbtxt:
# saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
saver = tf.train.Saver(variables_list)
tf.logging.info("Reading model parameters from %s", ckpt.model_checkpoint_path)
# restore() assigns the saved values, so running
# tf.global_variables_initializer() first is unnecessary.
saver.restore(session, ckpt.model_checkpoint_path)
保存ckpt
# Save the session's variables to a checkpoint. model_path and model_name are
# defined by the surrounding training code; they are concatenated as-is, so
# model_path should end with a path separator when it names a directory.
saver = tf.train.Saver()
checkpoint_prefix = model_path + model_name
saver.save(sess, checkpoint_prefix)
加载saved_model
SavedModel 主要是 TensorFlow Serving 使用的模型存储格式,便于快速搭建 TensorFlow 服务。
# ./sa_model is assumed to have been exported earlier with, e.g.:
#   builder = tf.saved_model.builder.SavedModelBuilder("./sa_model")
#   builder.add_meta_graph_and_variables(session, [tf.saved_model.tag_constants.SERVING])
#   builder.save()
# Load the MetaGraphDef tagged SERVING into a session on a fresh graph; the
# loader also restores the saved variable values.
session = tf.Session(graph=tf.Graph())
meta_graph_def = tf.saved_model.loader.load(
    session, [tf.saved_model.tag_constants.SERVING], "./sa_model")
# meta_graph_def.signature_def maps signature names to their input/output
# tensor names, which can then be fed/fetched with session.run(...).
保存saved_model
import tensorflow as tf

export_dir = "tmp2/MetaGraphDir"
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

# First MetaGraphDef: the training graph, saved together with its variables.
with tf.Session(graph=tf.Graph()) as sess:
    # ... build the training graph in sess.graph here ...
    # foo_signatures (dict of SignatureDefs) and foo_assets (asset collection)
    # are placeholders from the TensorFlow documentation example -- define
    # them for your model or drop these two arguments.
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.TRAINING],
        signature_def_map=foo_signatures,
        assets_collection=foo_assets,
    )

# Add a second MetaGraphDef for inference. Use the standard SERVING tag
# ("serve"): TensorFlow Serving loads the "serve" tag by default, and tags are
# exact strings, so a custom "SERVING" tag would not be found.
with tf.Session(graph=tf.Graph()) as sess:
    # ... build the inference graph in sess.graph here ...
    builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])

builder.save()
保存Tensorboard图
# Write the session graph as an event file in the current directory so it can
# be viewed with `tensorboard --logdir ./`. Close the writer so the buffered
# event file is flushed to disk.
writer = tf.summary.FileWriter("./", session.graph)
writer.close()
三合一收款
下面三种方式都支持哦