Toggle navigation
首页
问答
文章
积分商城
专家
专区
更多专区...
文档中心
返回主站
搜索
提问
会员
中心
登录
注册
AI_人工智能
RT-AK_AI-Kit_人工智能
CIFAR10分类CNN模型训练加保存
发布于 2021-06-16 15:11:23 浏览:732
订阅该版
[tocm] CIFAR10分类CNN模型搭建&数据集处理&训练&保存. 模型简单可以尝试将其移植在嵌入式端实验. 附件提供jupyter notebook文件供实验. ``` #!/usr/bin/env python # coding: utf-8 # In[1]: #cifar import tensorflow as tf from tensorflow import keras from tensorflow.keras import backend from tensorflow.keras import layers import os import numpy as np from matplotlib import pyplot as plt # os.environ["CUDA_VISIBLE_DEVICES"] = "1" get_ipython().run_line_magic('matplotlib', 'inline') # In[2]: # create CNN def CNNmodel(input_shape,filters=64, kernel=(3,3),size=4,dropout=0.2,**kwargs): _inputs = layers.Input(shape=input_shape) x = layers.Conv2D(8,(3,3),padding='same',use_bias=False,strides=(2,2), name='conv_0')(_inputs) x = layers.BatchNormalization(axis=-1, name='conv_0_bn')(x) x = layers.ReLU(6., name='conv_0_relu')(x) x = layers.Conv2D(16,(3,3),padding='same',use_bias=False,strides=(2,2), name='conv_1')(_inputs) x = layers.BatchNormalization(axis=-1, name='conv_1_bn')(x) x = layers.ReLU(6., name='conv_1_relu')(x) for block_id in range(2,size+2): x = layers.Conv2D(filters,kernel,padding='same',use_bias=False,strides=(1,1), name='conv_%d'%block_id)(x) x = layers.BatchNormalization(axis=-1, name='conv_%d_bn'%block_id)(x) x = layers.ReLU(6., name='conv_%d_relu'%block_id)(x) x = layers.GlobalAveragePooling2D()(x) x = layers.Dropout(dropout, name='dropout')(x) x = layers.Dense(10)(x) x = layers.Softmax()(x) return keras.Model(inputs=_inputs,outputs=x) # In[3]: # preprocess input def preprocess_input(inputs, std=255. ,mean=0., expand_dims=None): inputs = tf.cast(inputs,tf.float32) inputs = (inputs - mean) / std if expand_dims is not None: np.expand_dims(inputs,expand_dims) return inputs # dataset aug def img_aug_fun(elem): elem = tf.image.random_flip_left_right(elem)#左右翻转 elem = tf.image.random_brightness(elem, max_delta=0.5)#调亮度 elem = tf.image.random_contrast(elem, lower=0.5, upper=1.5)#调对比度 elem = preprocess_input(elem) return elem # load CIFAR10 dataset, size(32,32,3) (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() assert x_train.shape == (50000, 32, 32, 3) assert x_test.shape == (10000, 32, 32, 3) assert y_train.shape == (50000, 1) assert y_test.shape == (10000, 1) x_test = preprocess_input(x_test) x_train_ds = tf.data.Dataset.from_tensor_slices(x_train).map(img_aug_fun) y_train_ds = tf.data.Dataset.from_tensor_slices(y_train) x_y_train_ds = tf.data.Dataset.zip((x_train_ds,y_train_ds)) x_y_train_ds = x_y_train_ds.batch(128) # In[4]: reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='accuracy', factor=0.5, patience=4, min_lr=0.0001,verbose=1) earlystop = keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=8,verbose=1) model = CNNmodel(input_shape=(32,32,3),filters=64, kernel=(3,3),size=9) model.compile(optimizer='SGD',loss='sparse_categorical_crossentropy',metrics=['accuracy']) history = model.fit(x_y_train_ds,validation_data=(x_test,y_test),callbacks=[reduce_lr,earlystop],verbose=2,epochs=500) # In[5]: plt.plot(history.history['val_accuracy'],label='val_acc') plt.legend() plt.xlabel('Epochs') plt.ylabel('Acc') plt.show() # In[6]: model.save('checkpoint/Cifar10_CNN_%.3f'%history.history['val_accuracy'][-1]+'.h5') [cifar_cnn.ipynb](https://oss-club.rt-thread.org/uploads/20210616/8ed570671e76b0075cd78b42a36f8545.ipynb) ```
0
条评论
默认排序
按发布时间排序
登录
注册新账号
关于作者
霍格沃茨的小学生
这家伙很懒,什么也没写!
文章
3
回答
6
被采纳
3
关注TA
发私信
相关文章
1
RT-Thread AI Kit 相关资料和教程在哪里?
2
20号的开发者大会上,人脸识别AI
3
2020 开发者大会演示的 AI 套件什么时候开源?
4
请问RT-AK有没有部署MobileNet的例子呢
5
请问RT-AK会支持paddlepaddle的模型吗
6
嵌入式比赛中要求的 RT-Thread ai toolkit 能介绍一下吗?
7
使用rt_ai_tools转换模型时报错
8
求一个识别人有没有带口罩的模型?
9
RT-AK的人物检测例子下载后编译失败
10
cube ai部署后报错
推荐文章
1
RT-Thread应用项目汇总
2
玩转RT-Thread系列教程
3
国产MCU移植系列教程汇总,欢迎查看!
4
机器人操作系统 (ROS2) 和 RT-Thread 通信
5
五分钟玩转RT-Thread新社区
6
【技术三千问】之《玩转ART-Pi》,看这篇就够了!干货汇总
7
关于STM32H7开发板上使用SDIO接口驱动SD卡挂载文件系统的问题总结
8
STM32的“GPU”——DMA2D实例详解
9
RT-Thread隐藏的宝藏之completion
10
【ART-PI】RT-Thread 开启RTC 与 Alarm组件
热门标签
RT-Thread Studio
串口
Env
LWIP
SPI
AT
Bootloader
Hardfault
CAN总线
FinSH
ART-Pi
DMA
USB
文件系统
RT-Thread
SCons
RT-Thread Nano
线程
MQTT
STM32
RTC
rt-smart
FAL
I2C_IIC
UART
ESP8266
cubemx
WIZnet_W5500
ota在线升级
PWM
BSP
flash
freemodbus
packages_软件包
潘多拉开发板_Pandora
定时器
ADC
GD32
flashDB
socket
编译报错
中断
Debug
rt_mq_消息队列_msg_queue
keil_MDK
ulog
SFUD
msh
C++_cpp
MicroPython
本月问答贡献
RTT_逍遥
10
个答案
3
次被采纳
xiaorui
3
个答案
2
次被采纳
winfeng
2
个答案
2
次被采纳
三世执戟
8
个答案
1
次被采纳
KunYi
8
个答案
1
次被采纳
本月文章贡献
catcatbing
3
篇文章
5
次点赞
lizimu
2
篇文章
9
次点赞
swet123
1
篇文章
4
次点赞
Days
1
篇文章
4
次点赞
YZRD
1
篇文章
2
次点赞
回到
顶部
发布
问题
投诉
建议
回到
底部