TypechoJoeTheme

半醉残影

统计

TensorFlow模型Graph修改插入节点

2018-09-29
/
0 评论
/
2,325 阅读
/
正在检测是否收录...
09/29

背景

在优化一个深度学习模型的时候,客户给的模型为已经冻结过得pb文件,而这个pb文件没有input(Placeholder)节点,把特定的数据输入这个模型无法做到,而且对于测试性能和数据的准确性非常不便。然后我就想,用什么办法可以修改已经训练好模型的结构呢?最后找到一种方法,修改计算图结构。

正文

假如原topo.pb结构如下形式,而我们想要在 Conv2D_1前面加上某个节点,这里就假如需要加上Placeholder

Conv2D_1 -> Relu -> Conv2D_2 -> Softmax

步骤

生成pbtxt文件

对于pb文件,肉眼几乎是不可读的,用vim编辑器打开如下:

^Mconv_1/Conv2D^R^FConv2D^Z^Dconv^Z^Tconv_1/conv2d_params*^R
^Gpadding^R^G^R^EVALID*^S
^Gstrides^R^H
^F^Z^D^A^A^A^A*^V
^Puse_cudnn_on_gpu^R^B(^A*^U
^Kdata_format^R^F^R^DNHWC*^G
^AT^R^B0^A
Æ^A
^Uconv_1/batchnorm/beta^R^EConst*^K
^Edtype^R^B0^A*<98>^A
^Evalue^R<8e>^AB<8b>^A^H^A^R^D^R^B^H "<80>^AÜÆé¿^Mé=¿^W<80><94>¾<87><8f>4?w¨<8c>?^TM<9b>¾x\<8a>>¸ñ<8b>>R<9c><8f>>Z^[ë>'±<9a>>%ÓA@^]ã ¾÷çB?Jjâ?ò^^j?!&¶¿Í<9e>^M@¢|ñ¿!»ê¿´ü^W@ÎI]¿³<8a>G@^Ft­¿{æB@^X<94>8?8<8b><8f>¾Ä<90>d¿h=<8c>¾Êy^B@YxÛ?/<95><9c>>
Ç^A
^Vconv_1/batchnorm/gamma^R^EConst*^K
^Edtype^R^B0^A*<98>^A

这里我把它转换为goole probuf 文本文件,代码如下:

# make_pbtxt.py

# -*- coding: utf-8 -*-
import tensorflow as tf

## 加载模型
def load_frozenpb(file_name):
    with tf.gfile.GFile(file_name, "rb") as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(g.read())
        _ = tf.import_graph_def(graph_def, name="")
    
    sess = tf.Session()
    return sess

tf.reset_default_graph()
sess = load_frozenpb("./test.pb")
tf.train.write_graph(sess.graph, './','test.pbtxt')
$ python make_pbtxt.py

修改pbtxt

经过上边的命令,会在当前目录生成test.pbtxt文件,使用vim打开查看格式,这个时候因该是个可读性很强的内容,基本看一眼就知道是干什么的了。

node {
  name: "moments/Squeeze"
  op: "Squeeze"
  input: "moments/mean"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "squeeze_dims"
    value {
      list {
        i: 0
        i: 1
        i: 2
      }
    }
  }
}
node {
  name: "moments/Squeeze_1"
  op: "Squeeze"
  input: "moments/variance"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "squeeze_dims"
    value {
      list {
        i: 0
        i: 1
        i: 2
      }
    }
  }
}

这个时候就可以非常方便的修改该文件,但需要注意的是按照原有格式规范修改。

把需要插入的节点放在pbtxt文件的最开头,比如我插入一个Placeholder。如下内容,可以修改他的shape等属性适应业务需求

# 下面的结构生成的方法还是用上一步生成pbtxt的方法生成

node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 1
        }
        dim {
          size: 299
        }
        dim {
          size: 299
        }
        dim {
          size: 3
        }
      }
    }
  }
}

经过这一步后,新的节点已经放进去了,但是还没有连接到原有计算图上,所以还需要连接起来

连接

找到需要插入的位置(即新的节点作为哪个节点的输入),这里我找到了我要插入的位置 name: "conv/Conv2D",然后修改该节点的input为新节点的name,保存退出即可

# 这里我找到了我要插入的位置 name: "conv/Conv2D",然后修改该节点的input为新节点的name,保存退出即可
node {
  name: "conv/Conv2D"
  op: "Conv2D"
  input: "Placeholder"
  input: "conv/conv2d_params"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "data_format"
    value {
      s: "NHWC"
    }
  }
  attr {
    key: "dilations"
    value {
      list {
        i: 1
        i: 1
        i: 1
        i: 1
      }
    }
  }
  attr {
    key: "padding"
    value {
      s: "VALID"
    }
  }
  attr {
    key: "strides"
    value {
      list {
        i: 1
        i: 2
        i: 2
        i: 1
      }
    }
  }
  attr {
    key: "use_cudnn_on_gpu"
    value {
      b: true
    }
  }
}

重新生成pb文件

首先加载新的pbtxt文件到tensorflow,然后冻结模型就可以了

加载

# -*- coding: utf-8 -*-
# 注意和加载pb文件的代码区别,虽然大致相同,但关键点不同哦
import tensorflow as tf
from google.protobuf import text_format
tf.reset_default_graph()

def load_frozenpb(file_name):
    with tf.gfile.GFile(file_name, "rb") as g:
        graph_def = tf.GraphDef()
        text_format.Merge(g.read(), graph_def)
        _ = tf.import_graph_def(graph_def, name="")
    
    sess = tf.Session()
    return sess



sess = load_frozenpb("./test.pbtxt")

然后冻结模型

gd = sess.graph.as_graph_def()
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, gd, output_node_names=["final_result"])
with tf.gfile.FastGFile( "test_new.pb", mode='wb') as f:
    f.write(output_graph_def.SerializeToString())
    print("save model success.")
# tf.train.write_graph( graph_def , './' , 'protobuf.pb' , as_text = False )

INFO:tensorflow:Froze 0 variables.
INFO:tensorflow:Converted 0 variables to const ops.
save model success.

到此新的pb文件生成。最后用TensorFlow加载模型测试一下哦

结尾

这个方法还是要了解TensorFlow节点的工程师才好操作,所以要是发现什么更加简单的方法,欢迎留言告知!

TensorFlow深度学习
朗读
赞(0)
赞赏
感谢您的支持,我会继续努力哒!

三合一收款

下面三种方式都支持哦

微信
QQ
支付宝
打开支付宝/微信/QQ扫一扫,即可进行扫码打赏哦
版权属于:

半醉残影

本文链接:

https://blog.dengyb.com/archives/30/(转载时请注明本文出处及文章链接)

评论 (0)

人生倒计时

今日已经过去小时
这周已经过去
本月已经过去
今年已经过去个月

最新回复

  1. 搭建自用导航网站 R11; JKblog
    2022-03-10
  2. JK
    2022-01-13

    {!{data:image/webp;base64,UklGRhwNAABXRUJQVlA4WAoAAAAwAAAAAQMAKwEASUNDUBgCAAAAAAIYAAAAAAQwAABtbnRyUkdCIFhZWiAAAAAAAAAAAAAAAABhY3NwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAA9tYAAQAAAADTLQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAlkZXNjAAAA8AAAAHRyWFlaAAABZAAAABRnWFlaAAABeAAAABRiWFlaAAABjAAAABRyVFJDAAABoAAAAChnVFJDAAABoAAAAChiVFJDAAABoAAAACh3dHB0AAAByAAAABRjcHJ0AAAB3AAAADxtbHVjAAAAAAAAAAEAAAAMZW5VUwAAAFgAAAAcAHMAUgBHAEIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAFhZWiAAAAAAAABvogAAOPUAAAOQWFlaIAAAAAAAAGKZAAC3hQAAGNpYWVogAAAAAAAAJKAAAA+EAAC2z3BhcmEAAAAAAAQAAAACZmYAAPKnAAANWQAAE9AAAApbAAAAAAAAAABYWVogAAAAAAAA9tYAAQAAAADTLW1sdWMAAAAAAAAAAQAAAAxlblVTAAAAIAAAABwARwBvAG8AZwBsAGUAIABJAG4AYwAuACAAMgAwADEANkFMUEgDCQAADbBG///6tNUbKWwk9S3VXZn76LUaHey4L5y71NsURplkbkGqa08JVFOXlBW6YjlbJXMNlTl31VyvnVWAtjPoSpA8+P0j5MH/wffKPyImAMP/hv8N/xv+N/xv+N/wv+F/w/+G/w3/Gy3s/+Gwd6efkH3r4N5Jk0Xf5cMQ/7MAGkWfLU2plXw2G0DomsF+sWe/HzWhZHjWR5MbZZ6bCO/YMK9cljy316+LXibii1c2FwiSMgfp5sU1egucfQvEiAvgDzdU6J/LUb+4vE8iiYkAzsNeIeJC8/cN+meWcnzc4R+6GOMzAdSZcxolyIRfaXnRvbY05TnU+rUeE8C9KxaUC5DFaD+ZPblR39jvB+j6RANfldsCYHZT1Cg9Lif8HWsXlOsaN+pOwvqW1FpQc52jzsqOocsjwOx+z6tjPKhNc8MRcMzOSAbovf/XoiP3ZSKvO/qKbrH1UlofINKAI2m7CcA0Za3gsN2DGmrulZSUCDB9g16ZkI3qIMr6l8vGJAFrTswXG3OuR7N+7E+dGTUWoI/lHzplMerBhmgIODo+vAiY9/ucRplxtQ21Y/dcoH7OJwnAyqf0yeWoZxzE8POtZQD3rp5XLjJWojZN7EQNBNOBO28v0CWlyo8PEdMV91wFcPHK5gKB4UI9+whh5z92KeA8v0J/mD2os4ixdXNPVOdhr7hYjtr0OOHf3LoZ4NVZumPEh6ihTbEK5LoykgHqzk8TFnPQnEqk73/qAvosytMbfjT3EfOAI2m7CcBeMPqcpLjaplFA5KU5FuDR89X6ogjNE0/FDurXui0A1+4pWicoVqIeeCcKxtUCLH5RX8zSKF1Bt/qW1FpQ1/5QICZcqC15ROu/5EZg0MWHdcSiJ1BfoLsDc5YMVXAe9gqJ5ahNOUQ/rjIVmPWqjpiG2rSl26ifPLlEoc7rFRG2gRpTiWFgUxlgv65AP/TXGEc8uv66Jg2g7nCBgHBdjlpATF+zJwLO/q/ohRTUf30XF1S+s9wC4OxbIB5sl6MeeCc2P397K4C9Qi8Ef+4PbU8Rp74lK28AcB72CocHpqKeeJEYO8pMgKVJLzArZ+DpFcRtYM5KC0Cd1ysbalGbJhPrwLxqAFeBXli96DTxHFjiNgHUHS6QDI+jNhURe999SYCzb4FOiHtfldsC4OxbIBcevkc5+wjdWZ8J4Oz/ii7Dt2TlDQDOMeNbpEI5wIU8urXykTSA/PX6jMCclRaA+2u8uTLBg/oK3Vvx7kYT0LuwQJ8RWOI2oU6sEgm2Xsq5Ld2Er6oawNm3QJ/hq/KYlNRQUCC4Lkf9M93uswwHcA54QZ/hq/KYgDMIRBeaB890H2vsaQDPX/2EPsP3Xv4NnNwoEFxonnEQhxXvbjQB5Lym06hYegiJaB+o0WQjLn1VHhOQGgrqNKnoRm19ljj1VW1OhlOI8o2orXnErc87e/DpLaJsUUgJ2onj6pknkeXTUPOR/P2VUKXkS0H9Gskf/Lk/tL0p+piVM/D0CmT/6kWnMfxv+N/wv+F/0W8ecqJF+DlKGFy4SvbNgITi+aIvpR8w8FxQ8gUPDYNvkf1PWi8/vE34fTH7MP/N2jzkRIvwc5QwuHCV7JsBCcXzRV9KP2DguaDkCx4aBt8i+5+0Xn54m/D7YvZhDP8b/jf8rxuf2+sXfh+PIN28uEbwVWcD/OGGCrFndaH9+waxl0lYL2LfqtW+6Eq558kDQrsrEfx7jmX0+rzu38j2rHRIPbm1JVb+Dbt/Qqybswe3ceK5UaiDs725MRLt6yZCKkMIP7FKyfJrmLMHty2Tb5OINjUUZPy4Pj/U1jDgrWshNafyGchK35u+1y/UUnpEdQZYBNZrXbuvRO09Ya21FNJJH/2+v3MvkL5MlgXPDgoT6iKBrpMbKXsa4JZ9o9Du0Z6JZo+7s5LSAXI+2tdW0SjGmP+oJbGrIxj8qnwrIQAXmp1mwpvv0Iqw9x23pNqPjBRjq+c3dRFpVj7a9aT20Aq11uRFozn8SzEW5aOLCHuhDveTPQgl0PEme45l9Gs3JSdGkWBqEWkL0O7aswVedrdDCMC/Yfc5Wk0vT+6X9F1L+y0pGh0XEOdZw3vfgXbrpHPE8sfS1891AulT05Muom0t0txaSoS/5ALmol3vRhO+wrO38wICPZPwLbt2AB8NZ8wNzprYiHVruIMjAYZeDfzqtooostKXiTZPntYv0wCmP4v6+4aIrGWQ4w/8mNJW0SjT9hy7t09CqOMvm1AXoukl4kygd9b1vUm1Hxkp0vwb3v0hRFgbmse/jMxKhMO/FGmR2+/X2L2ByD15ESSYWqSbG3V3IVHu+evUlDAdFxDuHtQz24jWvyVh6u9TEkkIta1FuG9Cbb2PGFZ4GoJdhJDuZU+j6UDqZ81A+2CD0LOWEvaMA6GfSdhz2Uh9q1bXnq2IfU+e8suEc8j9Pcfu6BXavwnJ79/w0U8Y/jf8b/jf8L/h///amIVf+FlnQNZHO3fhjzdz9tYWKZYJcMetY8ja8mJjPHnGke3NFWJWRXPSvM2z48cFMLFKiHnywpH20oRXfHHiQk0NBWXYnl3jhyZrwbVld3/gXxwHLjRPIcT979cML74i8XzvJNSUu7JyitZ11yo0g4VSDHhj7r+7+O22oQrQb0Fa/oXuMHvQbLUjzfc+OLlEC6YsLdn4Q8xGfIhmax4C3XVqyg2piQqmuQ9Wlp+KkR/NCzmI9OWbD1302JxUVMszue8VNsbCg+aJJxDrP66pWzIyFc2c4gXl0dl6aZSuQLIHHEft+YMVzO73vNG4Lkd9Ael+ar3vgRd7otZ5vZG50Dz4oXiDU+ur1mUkA9SlzW6IwHa5xhkHIv7UxKQqC8D9c3/Y6zqjUfYb1CYbUr5+zsobUG9Pf9q9VLGjtj6LnA/MWX2JAqZ5n3sBi9L0LJI+8MSqGzWgzrxhK5q3IesD43pNmNBf4d7pGTs0AsIOPquuzi82oV76kNKJxC/Z+pZF0a4XeQQcszOSw5x5R+YRcAwovkkrE7H/xtzXb9Y4Lvfg2bl9gK7zks9fnz+Ijq2I/mVlpzD8b/jf8L/hf8P/hv8N/xv+/6/7AFZQOCDSAQAAUDUAnQEqAgMsAT9xuNlltK8rpyAIApAuCWlu4XdhG0AJ7APfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJyHvtk5D32ych77ZOQ99snIe+2TkPfbJwoAAP7/rR4AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==}!}

  3. 搭建自用导航网站 R11; JKblog
    2022-01-12
  4. 搭建自用导航网站 R11; JKblog
    2022-01-12
  5. MrGao
    2019-09-03

标签云