TypechoJoeTheme

半醉残影

统计

TensorFlow模型加载,用于恢复保存的模型

2018-09-17
/
0 评论
/
568 阅读
/
正在检测是否收录...
09/17

前言

在 TensorFlow 的使用中,常常需要保存训练好的模型或中间模型,下次使用时直接加载,即可从上次训练的状态继续。
TensorFlow 保存模型的格式有很多种,下面介绍生产中常用的几种格式的保存与加载方法。

正文

加载 .pb 模型

# -*- coding: utf-8 -*-
import tensorflow as tf
import argparse

def load_frozenpb(file_name):
    """Load a frozen binary GraphDef (.pb) and print each node's op type and name.

    The graph is imported into the current default graph with ``name=""``,
    so node names are not prefixed with an extra scope.

    Args:
        file_name: Path to the serialized (binary) ``GraphDef`` file.
    """
    with tf.gfile.GFile(file_name, "rb") as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(g.read())
        _ = tf.import_graph_def(graph_def, name="")

    # The session is only needed to read back the imported graph; the
    # context manager closes it so it is not leaked.
    with tf.Session() as sess:
        gd = sess.graph.as_graph_def()
    for node in gd.node:
        print(node.op, node.name)


if __name__ == '__main__':
    import os  # only the command-line entry point needs path helpers

    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--name", type=str, default="resnet", help="model folder name")
    parser.add_argument("--home", type=str, default="/home/dyb/save_model/",
                        help="root directory containing the model folders")
    args = parser.parse_args()
    # Expected layout: <home>/<name>/<name>.pb. os.path.join adds the
    # separator even when --home is given without a trailing '/'.
    name_path = os.path.join(args.home, args.name, "{}.pb".format(args.name))
    load_frozenpb(name_path)

结果如下:

(u'Conv2D', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv1/Conv2D')
(u'Relu', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv1/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv2/Conv2D')
(u'Relu', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv2/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_2/bottleneck_v1/conv3/Conv2D')
(u'Add', u'resnet_v1_50/block4/unit_2/bottleneck_v1/add')
(u'Relu', u'resnet_v1_50/block4/unit_2/bottleneck_v1/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv1/Conv2D')
(u'Relu', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv1/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv2/Conv2D')
(u'Relu', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv2/Relu')
(u'Conv2D', u'resnet_v1_50/block4/unit_3/bottleneck_v1/conv3/Conv2D')
(u'Add', u'resnet_v1_50/block4/unit_3/bottleneck_v1/add')
(u'Relu', u'resnet_v1_50/block4/unit_3/bottleneck_v1/Relu')
(u'Mean', u'resnet_v1_50/pool5')
(u'Conv2D', u'resnet_v1_50/logits/Conv2D')
(u'BiasAdd', u'resnet_v1_50/logits/BiasAdd')
(u'Squeeze', u'resnet_v1_50/SpatialSqueeze')
(u'Shape', u'resnet_v1_50/predictions/Shape')
(u'Reshape', u'resnet_v1_50/predictions/Reshape')
(u'Softmax', u'resnet_v1_50/predictions/Softmax')
(u'Reshape', u'resnet_v1_50/predictions/Reshape_1')

保存pb模型

# Freeze the graph: convert_variables_to_constants replaces variables with
# constants holding their current values, so the .pb file is self-contained.
# NOTE: assumes `sess` is an open tf.Session holding the trained graph and that
# "final_result" is its output node name -- adjust both for your model.
gd = sess.graph.as_graph_def()
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, gd, output_node_names=["final_result"])
# tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its documented replacement.
with tf.gfile.GFile("test_new.pb", mode='wb') as f:
    f.write(output_graph_def.SerializeToString())
# Report success only after the with-block has closed (and flushed) the file.
print("save model success.")
# Alternative when the graph has no variables to freeze:
#   tf.train.write_graph(graph_def, './', 'protobuf.pb', as_text=False)

加载pbtxt模型

# -*- coding: utf-8 -*-
# Note how this differs from the .pb loader: the code is nearly the same, but
# the key step is different -- a .pbtxt file holds a text-format GraphDef, so it
# is parsed with text_format.Merge instead of GraphDef.ParseFromString.
import tensorflow as tf
from google.protobuf import text_format
# Clear the global default graph so the graph imported below starts clean.
tf.reset_default_graph()

def load_frozenpb(file_name):
    """Parse a text-format GraphDef (.pbtxt) into the default graph.

    Nodes are imported with ``name=""`` so they keep the names stored in
    the file.

    Args:
        file_name: Path to the text-format ``GraphDef`` file.

    Returns:
        A new ``tf.Session`` on the default graph. The caller owns it and
        should close it when done.
    """
    with tf.gfile.GFile(file_name, "rb") as pbtxt_file:
        raw_text = pbtxt_file.read()

    parsed_graph = tf.GraphDef()
    # Text protos need text_format; ParseFromString only handles binary.
    text_format.Merge(raw_text, parsed_graph)
    tf.import_graph_def(parsed_graph, name="")

    return tf.Session()



# Load the text-format graph; the returned session stays open for later use.
sess = load_frozenpb(file_name="./test.pbtxt")

保存pbtxt模型

# Write the graph as a human-readable text proto. Give it a .pbtxt extension:
# a .pb name suggests binary format, and GraphDef.ParseFromString cannot parse
# text-format content.
# NOTE: assumes `graph_def` is the GraphDef to export (e.g. sess.graph_def).
tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True)

加载ckpt

import tensorflow as tf
from google.protobuf import text_format

# Rebuild the graph structure from its text-format GraphDef.
with tf.gfile.GFile("graph.pbtxt", "r") as f:
    graph_def = tf.GraphDef()
    text_format.Merge(f.read(), graph_def)

tf.import_graph_def(graph_def, name="")
graph = tf.get_default_graph()

# Let ops fall back to an available device if the recorded one is missing.
config = tf.ConfigProto(allow_soft_placement=True)

session = tf.Session(config=config)
# The GraphDef carries variables only as ops; collect the output tensor of each
# VariableV2 op so the Saver can restore checkpoint values into it.
# NOTE(review): resource variables (VarHandleOp) are not collected here.
variables_list = [
    graph.get_tensor_by_name(node.name + ":0")
    for node in graph_def.node
    if node.op == "VariableV2"
]
ckpt = tf.train.get_checkpoint_state("./")
# get_checkpoint_state returns None when no checkpoint file is found; fail with
# a clear message instead of an AttributeError on ckpt.model_checkpoint_path.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError("No checkpoint found in './'")
# Alternative: recover the graph and a Saver together from the .meta file via
# tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta').
saver = tf.train.Saver(variables_list)
# No tf.global_variables_initializer() is needed: restore() assigns a value to
# every variable in variables_list.
tf.logging.info("Reading model parameters from %s", ckpt.model_checkpoint_path)
saver.restore(session, ckpt.model_checkpoint_path)

保存ckpt

import os

# Save all saveable variables of the current graph.
# NOTE: assumes `sess`, `model_path` (a directory) and `model_name` (the
# checkpoint prefix) are defined by the surrounding training code.
saver = tf.train.Saver()
# os.path.join inserts the separator when model_path lacks a trailing '/',
# which plain string concatenation would silently omit.
save_path = saver.save(sess, os.path.join(model_path, model_name))

加载saved_model

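saved_model 是 TensorFlow Serving 主要使用的模型存储格式,便于快速搭建 TF 服务。注意:本节与下一节的代码演示的都是导出(保存)saved_model;加载时可使用 tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)。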

# Export the session's graph and variables as a SavedModel tagged for serving.
# NOTE(review): no signature_def_map is passed, so the export carries no
# SignatureDefs; serving via TF Serving usually requires one -- confirm.
builder = tf.saved_model.builder.SavedModelBuilder(export_dir="./sa_model")
builder.add_meta_graph_and_variables(
    sess=session,
    tags=[tf.saved_model.tag_constants.SERVING],
)
builder.save()

保存saved_model

import tensorflow as tf
export_dir = "tmp2/MetaGraphDir"
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
# NOTE: foo_signatures (a dict of SignatureDefs) and foo_assets (an assets
# collection) are placeholders -- define them for your model before running.
with tf.Session(graph=tf.Graph()) as sess:
    # ... build the training graph here ...
    # Use the tag_constants values ("train" / "serve"), not the literal strings
    # "TRAINING" / "SERVING": loaders such as TF Serving look MetaGraphDefs up
    # by the constant tag values.
    builder.add_meta_graph_and_variables(sess,
                                         [tf.saved_model.tag_constants.TRAINING],
                                         signature_def_map=foo_signatures,
                                         assets_collection=foo_assets
                                         )
# Add a second MetaGraphDef for inference; it shares the variables saved with
# the first MetaGraphDef.
with tf.Session(graph=tf.Graph()) as sess:
    # ... build the inference graph here ...
    builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder.save()

保存Tensorboard图

# Write the session's graph for TensorBoard (view with `tensorboard --logdir ./`).
# Close the writer so the buffered graph event is flushed to disk even if the
# process exits right afterwards.
writer = tf.summary.FileWriter("./", session.graph)
writer.close()
TensorFlow深度学习机器学习
朗读
赞(0)
赞赏
感谢您的支持,我会继续努力哒!

三合一收款

下面三种方式都支持哦

微信
QQ
支付宝
打开支付宝/微信/QQ扫一扫,即可进行扫码打赏哦
版权属于:

半醉残影

本文链接:

https://blog.dengyb.com/archives/26/(转载时请注明本文出处及文章链接)

评论 (0)