热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

tensorflow自定义op_TensorRT加速tensorflow模型

一.动机目前最新版本的TensorRT已经支持tensorflow1.x和tensorflow2.x版本的模型,由于个人之前的模型是基于tf1.14训练的࿰

一. 动机

目前最新版本的TensorRT已经支持tensorflow1.x和tensorflow2.x版本的模型,由于个人之前的模型是基于tf1.14训练的,为了进一步对模型进行加速,因此本文主要对基于1.14的模型使用TensorRT进行加速。

二. Tensorflow的优势

目前tensorflow中已经继承了tensorrt模块,因此只要有savedmodel就可以用TensorRT进行加速,而不需要像Pytorch需要转格式之后再用TensorRT加速,同时基于tensorflow内部的tensorrt模块,可以避免写plugin来支持对应的算子,当一个op无法被TensorRT进行转换时,模型会仍旧使用tensorflow原有的算子,需要进一步加速时才需要添加自定义算子,当对应的op不是十分耗时的情况时,可以使用tensorflow的算子进行计算,转换后依旧可以被保存为saved_model格式,从而使用tfserving进行部署。

三.转换

首先需要有一个转换好的saved_model(可以见上一篇讲述tensorflow-serving的文章https://zhuanlan.zhihu.com/p/104960285), 如以下格式

0cf8c554e92b7d37bd8c86278ae9ed44.png

import tensorflow as tf
import cv2
import base64
from tensorflow.python.compiler.tensorrt import trt_convert as trtsaved_model_dir = "./export_model_0126/1581080318"
output_saved_model_dir = "./convert_INT8_export_model"
fetch_names = ["strided_slice_256:0", "cond/Merge:0", "strided_slice_1:0", "ExpandDims:0","Const_39:0", "strided_slice_258:0", "Shape:0", "Cast:0","strided_slice_260:0", "combined_non_max_suppression/CombinedNonMaxSuppression:3","cond/Merge_1:0", "strided_slice_2:0"]class feed_dict_input_fn():def __init__(self, filename):self.filename = filenameself.content = []with open(self.filename) as f:for line in f:self.content.append(line.strip())self.index = 0def __call__(self, *args, **kwargs):data = open(self.content[self.index], 'rb').read()encode = base64.urlsafe_b64encode(data)encode = str(encode, encoding='utf-8')image = {"input:0": encode}# value = {"inputs": image}self.index += 1return imageconverter = trt.TrtGraphConverter(input_saved_model_dir=saved_model_dir,precision_mode=trt.TrtPrecisionMode.INT8,use_calibration=True, is_dynamic_op=True, maximum_cached_engines=3)
feet_dict_input = feed_dict_input_fn("/home/admin-seu/TempData/sss/Master_work/data/test.list")
converter.convert()
converter.calibrate(fetch_names=fetch_names, num_runs=100, feed_dict_fn=feet_dict_input)
converter.save(output_saved_model_dir)

上述代码块使用了tensorflow内部的tensorrt模块来对模型加速,其中的fetch_names同样可以参考上一篇文章中的saved_model_cli工具获取模型的输出tensor的名字,calibrate函数是用一串输入数据集对模型进行校准,这是由于TensorRT的INT8需要对数据进行归一化,因此校准是必要的。转换完成后,会得到下图中的INT8模型,同样可以对应生成FP16和FP32的模型。

7e946d73407507168df738489444674e.png

四.测试

使用如下代码块进行简单测试:

port tensorflow as tf
import numpy as np
import base64
import time# output_saved_model_dir = "./convert_export_model"
# output_saved_model_dir = "./export_model_0126/1581080318"
output_saved_model_dir = "./convert_INT8_export_model"
# output_saved_model_dir = "./convert_FP32_export_model"data = open("/home/admin-seu/TempData/test2017/000000258074.jpg", 'rb').read()
encode = base64.urlsafe_b64encode(data)
encode = str(encode, encoding='utf-8')with tf.Session() as sess:tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],output_saved_model_dir)cur_graph = sess.graphnode_names = [tensor.name for tensor in sess.graph_def.node]output_tensors = []input_tensor = cur_graph.get_tensor_by_name("input:0")output_tensors.append(cur_graph.get_tensor_by_name("strided_slice_256:0"))output_tensors.append(cur_graph.get_tensor_by_name("strided_slice_260:0"))# for node_name in node_names:# if "input" in node_name:# print(node_name)# input_tensor = cur_graph.get_tensor_by_name(node_name)# if "scores_1" in node_name:# print(node_name)# output_tensors.append(cur_graph.get_tensor_by_name(node_name))# if "labels" in node_name:# print(node_name)# output_tensors.append(cur_graph.get_tensor_by_name(node_name))# if "all_ids" in node_name:# print(node_name)# output_tensors.append(cur_graph.get_tensor_by_name(node_name))# if "boxes_1" in node_name:# print(node_name)# output_tensors.append(cur_graph.get_tensor_by_name(node_name))output = sess.run(output_tensors, feed_dict={input_tensor: encode})print(np.shape(output[0]))for i in range(10):output = sess.run(output_tensors, feed_dict={input_tensor: encode})start = time.time()print(start)for i in range(100):output = sess.run(output_tensors, feed_dict={input_tensor: encode})end = time.time()print(end)print(end - start)

最终可以得到加速后的模型的速度,由于本文使用的检测模型使用了大量TensorRT不支持的算子,因此加速效果比较有限,大概能比原先模型提升10%的速度。在一些更为简单的任务上相信模型能得到更大的加速比。

5.总结

可见使用TensorRT对tensorflow的模型加速是十分简单的,基本开箱即用,代码已放置在https://github.com/smallsunsun1/Cascade-RCNN,最后总结一下就是Tensorflow Yes, 后续可能记录一下鸽了很久的对TensorFlow源码部分的阅读笔记了只剩,立一个Flag,希望有时间有空自己补上0.0



推荐阅读
author-avatar
手机用户2502857335
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有