热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

float32精度_模型压缩系列方法——混合精度计算与量化压缩(3)

摘要移动端应用以及服务端节约空间都需要对当前的大模型进行适当压缩。本文继续介绍一种模型压缩方法。实际除了各种形式的distilling方式,混合精度计算与量化压缩方法
f40112ccd6b8240062757c4c3778aa66.png

摘要

移动端应用以及服务端节约空间都需要对当前的大模型进行适当压缩。本文继续介绍一种模型压缩方法。实际除了各种形式的distilling方式,混合精度计算与量化压缩方法也是非常常用的。

一、methodology

1.1 混合精度

实际在TensorFlow矩阵计算中,大多数是使用float32进行计算和存储的,但实际在可接受小幅精度损失的情况下,其中一部分变量可以采用float16进行变量申明和存储,仅仅在计算时候cast成为float32,也就形成了float32和float16混合的情景。

这样能压缩一部分空间;同时由于直接进行训练的缘故,效果偏差可控。

1.2 量化压缩

google 在官方网页中https://tensorflow.google.cn/api_docs/python/tf/lite/ 开源了量化压缩方法实现 8bit压缩。经过转换后,输入输出依旧是float,只不过中间的计算是用过8 bit来计算存储的。

对量化的实现是通过把常见操作转换为等价的八位版本达到的。涉及的操作包括卷积,矩阵乘法,激活函数,池化操作,以及拼接。转换脚本先把每个已知的操作替换为等价的量化版本。然后在操作的前后加上含有转换函数的子图,将input从浮点数转换成8 bit,再把output从8 bit转回浮点数。下面是 ReLu 的例子,input(float)==>relu==>output(float)

经过转换后,如下图所示:

ac06f77084e11d7cc82707d86b8ea8f5.png

quantize取input中的min和max,分别对应被量化的input中的最小值(0)和最大值(255),把[min, max]这个区间均匀分成255个小区间,把input中的值对应到对应的区间中。反量化操作则是把上述操作反向执行。

经过量化操作,可以有效提高点乘的计算效率。但当前google开源的tflite只对部分基础AIP有效,新出的很多高阶API尚不支持,期待后续开发。

二、data&实现

注意自行标记输入输出点:

from __future__ import print_functionimport os,sys
import time
from datetime import timedelta
import numpy as np
import tensorflow as tf
#from create_tf_record import *
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import tensorflow.contrib.slim as slim
from tensorflow.python.framework import graph_utildef freeze_graph(input_checkpoint,output_graph):''':param input_checkpoint::param output_graph: PBmodel path:return:'''# checkpoint = tf.train.get_checkpoint_state(model_folder) ## input_checkpoint = checkpoint.model_checkpoint_path ##output_node_names = "score_teacher/output_teacher"output_node_names = "score_student/output_student"#saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)graph = tf.get_default_graph()#input_graph_def = graph.as_graph_def()#with tf.Session() as sess:saver.restore(sess, input_checkpoint) #output_graph_def = graph_util.convert_variables_to_constants( # sess=sess,input_graph_def=input_graph_def,# :sess.graph_defoutput_node_names=output_node_names.split(","),variable_names_whitelist=None,variable_names_blacklist=None)#with tf.gfile.GFile(output_graph, "wb") as f: #f.write(output_graph_def.SerializeToString()) #print("%d ops in the final graph." % len(output_graph_def.node))
#
input_checkpoint='/data/liuyuanlin/push_project/push_model/push_student_model_topk_v2.0_20190910_1/best_validation'
out_pb_path='/data/liuyuanlin/push_project/push_model/push_student_model_topk_v2.0_20190910_1/pbmodel/IASv2.0.pb'
freeze_graph(input_checkpoint, out_pb_path)#=====================简单转换为 tensorflow lite格式 不压缩==================#
import tensorflow as tf
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
input_arrays = ["input_x"]
output_arrays = ["cnn_student_1/output_student"]
#converter = tf.lite.TFLiteConverter.from_frozen_graph("/data/liuyuanlin/push_project/push_model/push_student_model_topk_20190819_1/pbmodel/frozen_model_for_best_validation.pb", input_arrays, output_arrays)
converter = tf.contrib.lite.TocoConverter.from_frozen_graph("/data/liuyuanlin/push_project/push_model/push_student_model_topk_20190819_1/pbmodel/frozen_model_for_best_validation.pb",input_arrays, output_arrays)print("start convert..")
tflite_model = converter.convert()
print("convert ok and write the tflite model...")
open("/data/liuyuanlin/push_project/push_model/push_student_model_topk_20190819_1/pbmodel/converted_model.tflite", "wb").write(tflite_model)
#============================================================================##======================================================================================================#
#需要tf 1.14进行量化压缩
# default 默认压缩
import tensorflow as tf
in_tensors = ["input_x"]
out_tensors = ["score_student/output_student"]
graph_def_file = './push_student_model_topk_20190813_1/frozen_model_for_best_validation.pb'
converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, in_tensors, out_tensors)
#converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
#converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_LATENCY]#OPTIMIZE_FOR_SIZE
converter.optimizations = [tf.lite.Optimize.DEFAULT]#tf.lite.Optimize下有DEFAULT,OPTIMIZE_FOR_LATENCY,OPTIMIZE_FOR_SIZE
tflite_model = converter.convert()
open("quantify_default_model.tflite", "wb").write(tflite_model)

参考文献

[1]

TensorFlow Lite | 适用于移动设备和边缘设备的机器学习技术​tensorflow.google.cn
1b985bc78b87d861d9fa2538ab110513.png

[2] https://www.tensorflow.org/lite/performance/post_training_quantization



推荐阅读
  • 在尝试加载支持推送通知的iOS应用程序的Ad Hoc构建时,遇到了‘no valid aps-environment entitlement found for application’的错误提示。本文将探讨此错误的原因及多种可能的解决方案。 ... [详细]
  • 长期从事ABAP开发工作的专业人士,在面对行业新趋势时,往往需要重新审视自己的发展方向。本文探讨了几位资深专家对ABAP未来走向的看法,以及开发者应如何调整技能以适应新的技术环境。 ... [详细]
  • 本文探讨了如何将个人经历,特别是非传统的职业路径,转化为职业生涯中的优势。通过作者的亲身经历,展示了舞蹈生涯对商业思维的影响。 ... [详细]
  • 本题要求计算一组正整数的最小公倍数(LCM)。输入包括多组测试数据,每组数据首先给出一个正整数n,随后是n个正整数。 ... [详细]
  • 在1995年,Simon Plouffe 发现了一种特殊的求和方法来表示某些常数。两年后,Bailey 和 Borwein 在他们的论文中发表了这一发现,这种方法被命名为 Bailey-Borwein-Plouffe (BBP) 公式。该问题要求计算圆周率 π 的第 n 个十六进制数字。 ... [详细]
  • 洛谷 P4009 汽车加油行驶问题 解析
    探讨了经典算法题目——汽车加油行驶问题,通过网络流和费用流的视角,深入解析了该问题的解决方案。本文将详细阐述如何利用最短路径算法解决这一问题,并提供详细的代码实现。 ... [详细]
  • Irish budget airline Ryanair announced plans to significantly increase its route network from Frankfurt Airport, marking a direct challenge to Lufthansa, Germany's leading carrier. ... [详细]
  • 从理想主义者的内心深处萌发的技术信仰,推动了云原生技术在全球范围内的快速发展。本文将带你深入了解阿里巴巴在开源领域的贡献与成就。 ... [详细]
  • 本文详细介绍了如何正确设置Shadowsocks公共代理,包括调整超时设置、检查系统限制、防止滥用及遵守DMCA法规等关键步骤。 ... [详细]
  • 理解浏览器历史记录(2)hashchange、pushState
    阅读目录1.hashchange2.pushState本文也是一篇基础文章。继上文之后,本打算去研究pushState,偶然在一些信息中发现了锚点变 ... [详细]
  • Jupyter Notebook多语言环境搭建指南
    本文详细介绍了如何在Linux环境下为Jupyter Notebook配置Python、Python3、R及Go四种编程语言的环境,包括必要的软件安装和配置步骤。 ... [详细]
  • 本文详细介绍了如何搭建一个高可用的MongoDB集群,包括环境准备、用户配置、目录创建、MongoDB安装、配置文件设置、集群组件部署等步骤。特别关注分片、读写分离及负载均衡的实现。 ... [详细]
  • C# 中创建和执行存储过程的方法
    本文详细介绍了如何使用 C# 创建和调用 SQL Server 存储过程,包括连接数据库、定义命令类型、设置参数等步骤。 ... [详细]
  • H5技术实现经典游戏《贪吃蛇》
    本文将分享一个使用HTML5技术实现的经典小游戏——《贪吃蛇》。通过H5技术,我们将探讨如何构建这款游戏的两种主要玩法:积分闯关和无尽模式。 ... [详细]
  • 如何从BAM文件绘制ATAC-seq插入片段长度分布图?
    在ATAC-seq数据处理中,插入片段长度的分布图是一个重要的质量控制指标,它能反映出核小体的周期性排列。本文将详细介绍如何从BAM文件中提取并绘制这些数据。 ... [详细]
author-avatar
粪青12_601
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有