那棵树看起来生气了
TensorFlow模型量化 压缩 Int8推理
前言
在工业生产过程中,对于深度学习模型,企业比较关注的是成本和性能,怎样让我们的深度学习模型消耗最少的能量,最少的时间,获得最大的准确率(或者说获得不影响业务使用的最佳性能),工业中会用到模型量化,减少模型储存空间,减少模型计算量从而达到性能优化的目的,下面就介绍两种优化的使用方法。
正文
第一种优化方案transform_graph
1.编译transform_graph
$ bazel build tensorflow/tools/graph_transforms:transform_graph
2.生成量化模型
$ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=./model.pb \
--out_graph=model_quanlized.pb \
--inputs='Placeholder' \
--outputs='Softmax' \
--transforms=' \
strip_unused_nodes(type=float, shape="1,224,224,3") \
fold_constants(ignore_errors=true) \
fold_batch_norms \
fold_old_batch_norms \
quantize_nodes \
quantize_weights(minimum_size=0)'
得到如下所示的模型,model_quanlized.pb便是优化后的模型,可以看到优化后模型大小只有原有的1/4
-rw-r--r--. 1 dyb dyb 93M 10月 30 09:39 model.pb
-rw-rw-r--. 1 dyb dyb 25M 11月 23 03:45 model_quanlized.pb
3.得到的pb文件可以使用如下量化方法进行量化。
import tensorflow as tf
import numpy as np
import time
import argparse
import cv2
import os
np.set_printoptions(threshold='nan')
output_graph_path = "./model_quanlized.pb"
with tf.gfile.GFile(output_graph_path, "rb") as f:
g_def = tf.GraphDef()
g_def.ParseFromString(f.read())
_ = tf.import_graph_def(g_def, name="")
with tf.Session() as sess:
input_node = sess.graph.get_tensor_by_name("Placeholder:0")
ipt_shape = input_node.shape.as_list()
ipt_shape = [1,224,224,3]
img = cv2.imread('test.jpg', cv2.IMREAD_COLOR)
img.resize(1, ipt_shape[1], ipt_shape[2], ipt_shape[3])
img = img.astype('float32')
loop = 100
print 'loop:', loop
start_time = time.time()
for _ in range(loop):
predictions = sess.run('Sotfmax:0', feed_dict = {input_node: img})
end_time = time.time()
esp_time = (end_time - start_time)/float(loop)
esp_ms_time = round(esp_time * 1000, 2)
print "[TF] per loop is : %s ms" % (esp_ms_time)
第二种优化方案tflite
Tenserflow lite是谷歌在2017年11月推出的轻量级移动端预测框架,目前已经支持Android和iOS,并且支持iOS的CoreML。TensorFlow是针对手机和嵌入式设备提供的轻量级解决方案。它提供了低延迟和小体积的端侧机器学习预测能力。TensorFlow Lite同时支持通过Android Neural Networks API硬件加速。
1.生成tflite
转换pb文件为更加简洁高效的*.tflite,直接上代码:
import argparse
import sys
if sys.version_info.major >= 3:
import pathlib
else:
import pathlib2 as pathlib
import tensorflow as tf
import numpy as np
import cv2
import time
Save_Home = "/home/save_model"
save_model = "model"
input_name = "Placeholder"
output_name = "Softmax"
isQuantize = True #根据设置的不同isQuantize,决定生成的文件是否需要量化
graph_def_file = pathlib.Path(Save_Home)/save_model/("model.pb")
print(graph_def_file)
input_arrays = [input_name]
output_arrays = [output_name]
converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
str(graph_def_file), input_arrays, output_arrays, input_shapes={input_name:[1,224,224,3]})
converter.post_training_quantize = isQuantize
if isQuantize:
resnet_tflite_file = graph_def_file.parent/"model_quantized.tflite"
else:
resnet_tflite_file = graph_def_file.parent/"model.tflite"
resnet_tflite_file.write_bytes(converter.convert())
print("general_site success {}".format(resnet_tflite_file))
根据设置的不同isQuantize,决定生成的文件是否需要量化
2.运行测试,代码如下
import argparse
import sys
if sys.version_info.major >= 3:
import pathlib
else:
import pathlib2 as pathlib
import tensorflow as tf
import numpy as np
import cv2
import time
Save_Home = "/home/save_model"
isQuantize = True
img_path="./test.jpg"
batch_size=1
loop=100
if isQuantize:
tflite_file = pathlib.Path(Save_Home)/save_model/"model_quantized.tflite"
else:
tflite_file = pathlib.Path(Save_Home)/save_model/"model.tflite"
tf.logging.set_verbosity(tf.logging.DEBUG)
interpreter_quant = tf.contrib.lite.Interpreter(model_path=str(tflite_file))
interpreter_quant.allocate_tensors()
input_index = interpreter_quant.get_input_details()[0]["index"]
output_index = interpreter_quant.get_output_details()[0]["index"]
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = img.astype('float32')
img = np.tile(img, batch_size)
img.resize(1, 224, 224, 3)
interpreter_quant.set_tensor(input_index, img)
start_time = time.time()
for _ in range(loop):
interpreter_quant.invoke()
predictions = interpreter_quant.get_tensor(output_index)
end_time = time.time()
esp_time = (end_time - start_time) / float(loop)
esp_ms_time = round(esp_time * 1000, 2)
print("TF time used per loop is: %s ms" % esp_ms_time)
结尾


三合一收款
下面三种方式都支持哦