Toggle navigation
首页
问答
文章
积分商城
专家
专区
更多专区...
文档中心
返回主站
搜索
提问
会员
中心
登录
注册
AI_人工智能
TFLite
模型量化 - tflite量化(详细 | 附源码)
发布于 2021-07-13 09:47:52 浏览:2343
订阅该版
[tocm] # 模型转 tflite July 1, 2021 - lebhoryi@gmail.com 火灾烟雾检测二分类模型 [github](https://github.com/Lebhoryi/FireNet-LightWeight-Network-for-Fire-Detection/tree/master/Codes) ![image.png](https://oss-club.rt-thread.org/uploads/20210713/75554d07872d69697b5566a162bb124f.png) 由于模型过小,它的推理时间一直在变,1ms~7ms之间的值都有。 总的来说,数据还是和官网提供的表格数据保持一致。 为什么要量化? 因为嵌入式端对 Int 友好,有硬件资源限制... Tensorflow Lite 官网提供了三种量化方式: ![image.png](https://oss-club.rt-thread.org/uploads/20210713/d07c22986d484a422c4cca183db2bf26.png) ![image.png](https://oss-club.rt-thread.org/uploads/20210713/89109686985a88c4fb475794af8de1f0.png.webp) # 0x00 检查模型是否正常 数据准备: ```python import cv2 import numpy as np # 单个测试样本数据 test_path = "test.jpg" image = cv2.imread(test_path) image = cv2.resize(image, (64, 64)) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image_bn = image.astype("float32") / 255.0 image = np.expand_dims(image, axis=0) image_bn = np.expand_dims(image_bn, axis=0) ``` ![image.png](https://oss-club.rt-thread.org/uploads/20210713/e362f8f943f17c121de6da65b0b1a8e9.png.webp) ```python import os import time import tensorflow as tf # 恢复 keras 模型,并预测 keras_file = '../Models/20210701.h5' model = tf.keras.models.load_model(keras_file) # model.summary() # tf.autograph.set_verbosity(0) start_time = time.time() pred = model.predict(image_bn) stop_time = time.time() print(f"prediction: {pred}") print('time: {:.3f}ms'.format((stop_time - start_time) * 1000)) print("model size: {:.2f} MB".format(os.path.getsize(keras_file)/1024/1024)) # prediction: [[0.9730365 0.02696355]] # time: 126.018ms # model size: 7.46 MB ``` # 0x01 (未量化) 模型转 tflite 格式 ```python # 恢复 keras 模型 keras_file = '../Models/20210701.h5' model = tf.keras.models.load_model(keras_file) # 最直接的保存为tflite,没有任何量化 converter = tf.lite.TFLiteConverter.from_keras_model(model) tflite_model = converter.convert() tflite_file = Path("../Models/Tflites/fire_raw.tflite") tflite_file.write_bytes(tflite_model) # 2591588 ``` tflite 模型推理代码: ```python tflite_file = Path("../Models/Tflites/fire_raw.tflite") # tflite 模型推理 interpreter = tf.lite.Interpreter(model_path=str(tflite_file)) interpreter.allocate_tensors() # Get input and output tensors. input_details = interpreter.get_input_details()[0] output_details = interpreter.get_output_details()[0] interpreter.set_tensor(input_details['index'], image_bn) start_time = time.time() interpreter.invoke() stop_time = time.time() output_data = interpreter.get_tensor(output_details['index']) print(f"prediction: {output_data}") print('time: {:.3f}ms'.format((stop_time - start_time) * 1000)) print("model size: {:.2f} MB".format(os.path.getsize(tflite_file)/1024/1024)) ``` 结果如下 ``` prediction: [[0.9730364 0.02696356]] time: 1.637ms model size: 2.47 MB ``` # 0x02 动态范围量化 只能用于CPU加速, “动态范围”:根据激活函数的范围动态的将其转换为8bit整数 仅量化权重,从float32量化为int8,激活保持不变,模型减小了3/4 。 在推理的时候,把int8转回fp32,**输入和输出都是浮点数** ```python # 动态量化 dynamic range quantization converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert() tflite_file = Path("../Models/Tflites/fire_dynamic.tflite") tflite_file.write_bytes(tflite_model) ``` 模型推理用的还是上面的推理代码,结果如下 ``` prediction: [[0.9738594 0.02614057]] time: 17.465ms model size: 641.66 KB ``` # 0x03 Float16 量化 将权重与激活函数均转换为16位浮点数。 - 模型减小1/2。 - 量化中精度损失最少 缺点:float16 量化模型在 CPU 上运行时会将权重值“反量化”为 float32。 ```python # float16 range quantization converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.target_spec.supported_types = [tf.float16] tflite_model = converter.convert() tflite_file = Path("../Models/Tflites/fire_float16.tflite") tflite_file.write_bytes(tflite_model) ``` 模型推理用的还是上面的推理代码,结果如下 ``` prediction: [[0.97292536 0.02707472]] time: 5.738ms model size: 1.24 MB ``` # 0x04 Int 量化 需要校准或估计模型中所有浮点张量的范围,即 (min, max),所以需要一部分的数据集。 查看模型输入和输出的类型的函数 ```python interpreter = tf.lite.Interpreter(model_path=tflite_model_path) interpreter = tf.lite.Interpreter(model_content=tflite_model) input_type = interpreter.get_input_details()[0]['dtype'] print('input: ', input_type) output_type = interpreter.get_output_details()[0]['dtype'] print('output: ', output_type) ``` 1. 输入和输出还是浮点,将权重和偏置进行整型量化 ```python # quantize int def representative_data_gen(): for input_value in X[:100]: input_value = np.expand_dims(input_value, axis=0) input_value = input_value.astype(np.float32) yield [input_value] converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.representative_dataset = representative_data_gen # Ensure that if any ops can't be quantized, the converter throws an error tflite_model = converter.convert() tflite_file = Path("../Models/Tflites/fire_int_half.tflite") tflite_file.write_bytes(tflite_model) ``` 模型推理用的还是上面的推理代码,结果如下 ``` prediction: [[0.97265625 0.02734375]] time: 3.416ms model size: 0.63 MB ``` 2. 输入和输出是整型 ```python # quantize int def representative_data_gen(): for input_value in X[:100]: input_value = np.expand_dims(input_value, axis=0) input_value = input_value.astype(np.float32) yield [input_value] converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.representative_dataset = representative_data_gen # Ensure that if any ops can't be quantized, the converter throws an error converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] # Set the input and output tensors to uint8 (APIs added in r2.3) converter.inference_input_type = tf.uint8 converter.inference_output_type = tf.uint8 tflite_model = converter.convert() tflite_file = Path("../Models/Tflites/fire_int.tflite") tflite_file.write_bytes(tflite_model) ``` - 模型推理(只要将输入改为整形输入即可) ```python # tflite 模型推理 interpreter = tf.lite.Interpreter(model_path=str(tflite_file)) interpreter.allocate_tensors() # Get input and output tensors. input_details = interpreter.get_input_details()[0] output_details = interpreter.get_output_details()[0] # 更改输入即可 interpreter.set_tensor(input_details['index'], image) start_time = time.time() interpreter.invoke() stop_time = time.time() output_data = interpreter.get_tensor(output_details['index']) print(f"prediction: {output_data}") print('time: {:.3f}ms'.format((stop_time - start_time) * 1000)) print("model size: {:.2f} MB".format(os.path.getsize(tflite_file)/1024/1024)) ``` 结果如下: ``` prediction: [[249 7]] time: 1.019ms mod size: 0.63 MB ```
0
条评论
默认排序
按发布时间排序
登录
注册新账号
关于作者
lebhoryi
这家伙很懒,什么也没写!
文章
30
回答
6
被采纳
1
关注TA
发私信
相关文章
1
RT-Thread AI Kit 相关资料和教程在哪里?
2
20号的开发者大会上,人脸识别AI
3
2020 开发者大会演示的 AI 套件什么时候开源?
4
请问RT-AK有没有部署MobileNet的例子呢
5
请问RT-AK会支持paddlepaddle的模型吗
6
嵌入式比赛中要求的 RT-Thread ai toolkit 能介绍一下吗?
7
使用rt_ai_tools转换模型时报错
8
求一个识别人有没有带口罩的模型?
9
RT-AK的人物检测例子下载后编译失败
10
cube ai部署后报错
推荐文章
1
RT-Thread应用项目汇总
2
玩转RT-Thread系列教程
3
国产MCU移植系列教程汇总,欢迎查看!
4
机器人操作系统 (ROS2) 和 RT-Thread 通信
5
五分钟玩转RT-Thread新社区
6
【技术三千问】之《玩转ART-Pi》,看这篇就够了!干货汇总
7
关于STM32H7开发板上使用SDIO接口驱动SD卡挂载文件系统的问题总结
8
STM32的“GPU”——DMA2D实例详解
9
RT-Thread隐藏的宝藏之completion
10
【ART-PI】RT-Thread 开启RTC 与 Alarm组件
热门标签
RT-Thread Studio
串口
Env
LWIP
SPI
AT
Bootloader
Hardfault
CAN总线
FinSH
ART-Pi
USB
DMA
文件系统
RT-Thread
SCons
RT-Thread Nano
线程
MQTT
STM32
RTC
FAL
rt-smart
ESP8266
I2C_IIC
UART
WIZnet_W5500
ota在线升级
freemodbus
PWM
flash
cubemx
packages_软件包
BSP
潘多拉开发板_Pandora
定时器
ADC
GD32
flashDB
socket
中断
编译报错
Debug
rt_mq_消息队列_msg_queue
SFUD
msh
keil_MDK
ulog
C++_cpp
MicroPython
本月问答贡献
出出啊
1517
个答案
342
次被采纳
小小李sunny
1444
个答案
289
次被采纳
张世争
809
个答案
175
次被采纳
crystal266
547
个答案
161
次被采纳
whj467467222
1222
个答案
148
次被采纳
本月文章贡献
catcatbing
3
篇文章
5
次点赞
qq1078249029
2
篇文章
2
次点赞
xnosky
2
篇文章
1
次点赞
Woshizhapuren
1
篇文章
5
次点赞
YZRD
1
篇文章
2
次点赞
回到
顶部
发布
问题
投诉
建议
回到
底部