那棵树看起来生气了
TensorFlow模型Graph修改插入节点
背景
在优化一个深度学习模型的时候,客户给的模型为已经冻结过得pb文件,而这个pb文件没有input(Placeholder)节点,把特定的数据输入这个模型无法做到,而且对于测试性能和数据的准确性非常不便。然后我就想,用什么办法可以修改已经训练好模型的结构呢?最后找到一种方法,修改计算图结构。
正文
假如原topo.pb结构如下形式,而我们想要在 Conv2D_1前面加上某个节点,这里就假如需要加上Placeholder
Conv2D_1 -> Relu -> Conv2D_2 -> Softmax
步骤
生成pbtxt文件
对于pb文件,肉眼几乎是不可读的,用vim编辑器打开如下:
^Mconv_1/Conv2D^R^FConv2D^Z^Dconv^Z^Tconv_1/conv2d_params*^R
^Gpadding^R^G^R^EVALID*^S
^Gstrides^R^H
^F^Z^D^A^A^A^A*^V
^Puse_cudnn_on_gpu^R^B(^A*^U
^Kdata_format^R^F^R^DNHWC*^G
^AT^R^B0^A
Æ^A
^Uconv_1/batchnorm/beta^R^EConst*^K
^Edtype^R^B0^A*<98>^A
^Evalue^R<8e>^AB<8b>^A^H^A^R^D^R^B^H "<80>^AÜÆé¿^Mé=¿^W<80><94>¾<87><8f>4?w¨<8c>?^TM<9b>¾x\<8a>>¸ñ<8b>>R<9c><8f>>Z^[ë>'±<9a>>%ÓA@^]ã ¾÷çB?Jjâ?ò^^j?!&¶¿Í<9e>^M@¢|ñ¿!»ê¿´ü^W@ÎI]¿³<8a>G@^Ft¿{æB@^X<94>8?8<8b><8f>¾Ä<90>d¿h=<8c>¾Êy^B@YxÛ?/<95><9c>>
Ç^A
^Vconv_1/batchnorm/gamma^R^EConst*^K
^Edtype^R^B0^A*<98>^A
这里我把它转换为goole probuf 文本文件,代码如下:
# make_pbtxt.py
# -*- coding: utf-8 -*-
import tensorflow as tf
## 加载模型
def load_frozenpb(file_name):
with tf.gfile.GFile(file_name, "rb") as g:
graph_def = tf.GraphDef()
graph_def.ParseFromString(g.read())
_ = tf.import_graph_def(graph_def, name="")
sess = tf.Session()
return sess
tf.reset_default_graph()
sess = load_frozenpb("./test.pb")
tf.train.write_graph(sess.graph, './','test.pbtxt')
$ python make_pbtxt.py
修改pbtxt
经过上边的命令,会在当前目录生成test.pbtxt文件,使用vim打开查看格式,这个时候因该是个可读性很强的内容,基本看一眼就知道是干什么的了。
node {
name: "moments/Squeeze"
op: "Squeeze"
input: "moments/mean"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "squeeze_dims"
value {
list {
i: 0
i: 1
i: 2
}
}
}
}
node {
name: "moments/Squeeze_1"
op: "Squeeze"
input: "moments/variance"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "squeeze_dims"
value {
list {
i: 0
i: 1
i: 2
}
}
}
}
这个时候就可以非常方便的修改该文件,但需要注意的是按照原有格式规范修改。
把需要插入的节点放在pbtxt文件的最开头,比如我插入一个Placeholder。如下内容,可以修改他的shape等属性适应业务需求
# 下面的结构生成的方法还是用上一步生成pbtxt的方法生成
node {
name: "Placeholder"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
dim {
size: 299
}
dim {
size: 299
}
dim {
size: 3
}
}
}
}
}
经过这一步后,新的节点已经放进去了,但是还没有连接到原有计算图上,所以还需要连接起来
连接
找到需要插入的位置(即新的节点作为哪个节点的输入),这里我找到了我要插入的位置 name: "conv/Conv2D",然后修改该节点的input为新节点的name,保存退出即可
# 这里我找到了我要插入的位置 name: "conv/Conv2D",然后修改该节点的input为新节点的name,保存退出即可
node {
name: "conv/Conv2D"
op: "Conv2D"
input: "Placeholder"
input: "conv/conv2d_params"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "data_format"
value {
s: "NHWC"
}
}
attr {
key: "dilations"
value {
list {
i: 1
i: 1
i: 1
i: 1
}
}
}
attr {
key: "padding"
value {
s: "VALID"
}
}
attr {
key: "strides"
value {
list {
i: 1
i: 2
i: 2
i: 1
}
}
}
attr {
key: "use_cudnn_on_gpu"
value {
b: true
}
}
}
重新生成pb文件
首先加载新的pbtxt文件到tensorflow,然后冻结模型就可以了
加载
# -*- coding: utf-8 -*-
# 注意和加载pb文件的代码区别,虽然大致相同,但关键点不同哦
import tensorflow as tf
from google.protobuf import text_format
tf.reset_default_graph()
def load_frozenpb(file_name):
with tf.gfile.GFile(file_name, "rb") as g:
graph_def = tf.GraphDef()
text_format.Merge(g.read(), graph_def)
_ = tf.import_graph_def(graph_def, name="")
sess = tf.Session()
return sess
sess = load_frozenpb("./test.pbtxt")
然后冻结模型
gd = sess.graph.as_graph_def()
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, gd, output_node_names=["final_result"])
with tf.gfile.FastGFile( "test_new.pb", mode='wb') as f:
f.write(output_graph_def.SerializeToString())
print("save model success.")
# tf.train.write_graph( graph_def , './' , 'protobuf.pb' , as_text = False )
INFO:tensorflow:Froze 0 variables.
INFO:tensorflow:Converted 0 variables to const ops.
save model success.
到此新的pb文件生成。最后用TensorFlow加载模型测试一下哦
结尾
这个方法还是要了解TensorFlow节点的工程师才好操作,所以要是发现什么更加简单的方法,欢迎留言告知!
三合一收款
下面三种方式都支持哦
《青春援助交际国语》剧情片高清在线免费观看:https://www.jgz518.com/xingkong/133521.html
你的文章总是能给我带来欢乐,谢谢你! http://www.55baobei.com/jjrxdpa8Xs.html
《麒麟剧社陶阳夏昀帆《宋江杀惜》全本2023.10.21三庆园(独家版)》大陆综艺高清在线免费观看:https://www.jgz518.com/xingkong/142746.html
你的文章总是能给我带来欢乐,谢谢你! http://www.55baobei.com/SK4nTZpQzJ.html