TypechoJoeTheme

半醉残影

统计

TensorFlow模型Graph修改插入节点

2018-09-29
/
0 评论
/
2,419 阅读
/
正在检测是否收录...
09/29

背景

在优化一个深度学习模型的时候,客户给的模型为已经冻结过得pb文件,而这个pb文件没有input(Placeholder)节点,把特定的数据输入这个模型无法做到,而且对于测试性能和数据的准确性非常不便。然后我就想,用什么办法可以修改已经训练好模型的结构呢?最后找到一种方法,修改计算图结构。

正文

假如原topo.pb结构如下形式,而我们想要在 Conv2D_1前面加上某个节点,这里就假如需要加上Placeholder

Conv2D_1 -> Relu -> Conv2D_2 -> Softmax

步骤

生成pbtxt文件

对于pb文件,肉眼几乎是不可读的,用vim编辑器打开如下:

^Mconv_1/Conv2D^R^FConv2D^Z^Dconv^Z^Tconv_1/conv2d_params*^R
^Gpadding^R^G^R^EVALID*^S
^Gstrides^R^H
^F^Z^D^A^A^A^A*^V
^Puse_cudnn_on_gpu^R^B(^A*^U
^Kdata_format^R^F^R^DNHWC*^G
^AT^R^B0^A
Æ^A
^Uconv_1/batchnorm/beta^R^EConst*^K
^Edtype^R^B0^A*<98>^A
^Evalue^R<8e>^AB<8b>^A^H^A^R^D^R^B^H "<80>^AÜÆé¿^Mé=¿^W<80><94>¾<87><8f>4?w¨<8c>?^TM<9b>¾x\<8a>>¸ñ<8b>>R<9c><8f>>Z^[ë>'±<9a>>%ÓA@^]ã ¾÷çB?Jjâ?ò^^j?!&¶¿Í<9e>^M@¢|ñ¿!»ê¿´ü^W@ÎI]¿³<8a>G@^Ft­¿{æB@^X<94>8?8<8b><8f>¾Ä<90>d¿h=<8c>¾Êy^B@YxÛ?/<95><9c>>
Ç^A
^Vconv_1/batchnorm/gamma^R^EConst*^K
^Edtype^R^B0^A*<98>^A

这里我把它转换为goole probuf 文本文件,代码如下:

# make_pbtxt.py

# -*- coding: utf-8 -*-
import tensorflow as tf

## 加载模型
def load_frozenpb(file_name):
    with tf.gfile.GFile(file_name, "rb") as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(g.read())
        _ = tf.import_graph_def(graph_def, name="")
    
    sess = tf.Session()
    return sess

tf.reset_default_graph()
sess = load_frozenpb("./test.pb")
tf.train.write_graph(sess.graph, './','test.pbtxt')
$ python make_pbtxt.py

修改pbtxt

经过上边的命令,会在当前目录生成test.pbtxt文件,使用vim打开查看格式,这个时候因该是个可读性很强的内容,基本看一眼就知道是干什么的了。

node {
  name: "moments/Squeeze"
  op: "Squeeze"
  input: "moments/mean"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "squeeze_dims"
    value {
      list {
        i: 0
        i: 1
        i: 2
      }
    }
  }
}
node {
  name: "moments/Squeeze_1"
  op: "Squeeze"
  input: "moments/variance"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "squeeze_dims"
    value {
      list {
        i: 0
        i: 1
        i: 2
      }
    }
  }
}

这个时候就可以非常方便的修改该文件,但需要注意的是按照原有格式规范修改。

把需要插入的节点放在pbtxt文件的最开头,比如我插入一个Placeholder。如下内容,可以修改他的shape等属性适应业务需求

# 下面的结构生成的方法还是用上一步生成pbtxt的方法生成

node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 1
        }
        dim {
          size: 299
        }
        dim {
          size: 299
        }
        dim {
          size: 3
        }
      }
    }
  }
}

经过这一步后,新的节点已经放进去了,但是还没有连接到原有计算图上,所以还需要连接起来

连接

找到需要插入的位置(即新的节点作为哪个节点的输入),这里我找到了我要插入的位置 name: "conv/Conv2D",然后修改该节点的input为新节点的name,保存退出即可

# 这里我找到了我要插入的位置 name: "conv/Conv2D",然后修改该节点的input为新节点的name,保存退出即可
node {
  name: "conv/Conv2D"
  op: "Conv2D"
  input: "Placeholder"
  input: "conv/conv2d_params"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "data_format"
    value {
      s: "NHWC"
    }
  }
  attr {
    key: "dilations"
    value {
      list {
        i: 1
        i: 1
        i: 1
        i: 1
      }
    }
  }
  attr {
    key: "padding"
    value {
      s: "VALID"
    }
  }
  attr {
    key: "strides"
    value {
      list {
        i: 1
        i: 2
        i: 2
        i: 1
      }
    }
  }
  attr {
    key: "use_cudnn_on_gpu"
    value {
      b: true
    }
  }
}

重新生成pb文件

首先加载新的pbtxt文件到tensorflow,然后冻结模型就可以了

加载

# -*- coding: utf-8 -*-
# 注意和加载pb文件的代码区别,虽然大致相同,但关键点不同哦
import tensorflow as tf
from google.protobuf import text_format
tf.reset_default_graph()

def load_frozenpb(file_name):
    with tf.gfile.GFile(file_name, "rb") as g:
        graph_def = tf.GraphDef()
        text_format.Merge(g.read(), graph_def)
        _ = tf.import_graph_def(graph_def, name="")
    
    sess = tf.Session()
    return sess



sess = load_frozenpb("./test.pbtxt")

然后冻结模型

gd = sess.graph.as_graph_def()
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, gd, output_node_names=["final_result"])
with tf.gfile.FastGFile( "test_new.pb", mode='wb') as f:
    f.write(output_graph_def.SerializeToString())
    print("save model success.")
# tf.train.write_graph( graph_def , './' , 'protobuf.pb' , as_text = False )

INFO:tensorflow:Froze 0 variables.
INFO:tensorflow:Converted 0 variables to const ops.
save model success.

到此新的pb文件生成。最后用TensorFlow加载模型测试一下哦

结尾

这个方法还是要了解TensorFlow节点的工程师才好操作,所以要是发现什么更加简单的方法,欢迎留言告知!

TensorFlow深度学习
朗读
赞(0)
赞赏
感谢您的支持,我会继续努力哒!

三合一收款

下面三种方式都支持哦

微信
QQ
支付宝
打开支付宝/微信/QQ扫一扫,即可进行扫码打赏哦
版权属于:

半醉残影

本文链接:

https://blog.dengyb.com/archives/30/(转载时请注明本文出处及文章链接)

评论 (0)