A sensor collects bearing vibration data, and a convolutional neural network is used for bearing fault diagnosis. I came across an article on Zhihu and built a model from the data and code it provides.
Below is the training code and the TFLite export for the model. (There must be better ways to quantize it, but this isn't my field, so please suggest how the model quantization could be improved; a rough post-training quantization sketch is included after the script.)
import tensorflow
import glob
import numpy as np
import pandas as pd
import math
import os
import sys
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.datasets import *
from pathlib import Path
MANIFEST_DIR = os.path.join(sys.path[0], "train.csv")
Batch_size = 20
Long = 792   # total number of samples in train.csv
Lens = 640   # number of samples used for training; the rest are used for validation
# convert a label index into a one-hot vector
def convert2oneHot(index, Lens):
    hot = np.zeros((Lens,))
    hot[int(index)] = 1
    return hot
def xs_gen(path=MANIFEST_DIR, batch_size=Batch_size, train=True, Lens=Lens):
    img_list = pd.read_csv(path)
    if train:
        img_list = np.array(img_list)[:Lens]
        print(img_list.shape)
        print("Found %s train items." % len(img_list))
        print("list 1 is", img_list[0, -1])
        steps = math.ceil(len(img_list) / batch_size)  # number of batches per epoch
    else:
        img_list = np.array(img_list)[Lens:]
        print("Found %s test items." % len(img_list))
        print("list 1 is", img_list[0, -1])
        steps = math.ceil(len(img_list) / batch_size)  # number of batches per epoch
    while True:
        for i in range(steps):
            batch_list = img_list[i * batch_size : i * batch_size + batch_size]
            np.random.shuffle(batch_list)
            # columns 1..-2 are the vibration samples, the last column is the label
            batch_x = np.array([file for file in batch_list[:, 1:-1]])
            batch_y = np.array([convert2oneHot(label, 10) for label in batch_list[:, -1]])
            yield batch_x, batch_y
TIME_PERIODS = 6000
def build_model(input_shape=(TIME_PERIODS,), num_classes=10):
    model = Sequential()
    model.add(Reshape((TIME_PERIODS, 1), input_shape=input_shape))
    model.add(Conv1D(16, 8, strides=2, activation='relu', input_shape=(TIME_PERIODS, 1)))
    model.add(Conv1D(16, 8, strides=2, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(64, 4, strides=2, activation='relu', padding="same"))
    model.add(Conv1D(64, 4, strides=2, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(256, 4, strides=2, activation='relu', padding="same"))
    model.add(Conv1D(256, 4, strides=2, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(512, 2, strides=1, activation='relu', padding="same"))
    model.add(Conv1D(512, 2, strides=1, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(GlobalAveragePooling1D())
    model.add(Dropout(0.3))
    model.add(Dense(num_classes, activation='softmax'))
    return model
if __name__ == "__main__":
    train_iter = xs_gen()
    val_iter = xs_gen(train=False)
    # save the best model (by validation loss) after each epoch
    ckpt = tensorflow.keras.callbacks.ModelCheckpoint(
        filepath=r"C:\Users\years\Desktop\learn_some\model\best_model.{epoch:02d}-{val_loss:.4f}.h5",
        monitor='val_loss', save_best_only=True, verbose=1)
    model = build_model()
    opt = Adam(0.0002)
    model.compile(loss='categorical_crossentropy',
                  optimizer=opt, metrics=['accuracy'])
    print(model.summary())
    model.fit(
        x=train_iter,
        steps_per_epoch=Lens // Batch_size,
        epochs=30,
        initial_epoch=0,
        validation_data=val_iter,
        validation_steps=(Long - Lens) // Batch_size,
        callbacks=[ckpt],
    )
    keras_file = r"C:\Users\years\Desktop\learn_some\model\finishModel.h5"
    model.save(keras_file, save_format="h5")

    # reload the trained model, fix the batch dimension to 1 and convert to TFLite
    model = tensorflow.keras.models.load_model(keras_file)
    model.input.set_shape(1 + model.input.shape[1:])
    converter = tensorflow.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    tflite_file = Path("C:/Users/years/Desktop/learn_some/model/finishModel.tflite")
    tflite_file.write_bytes(tflite_model)
    print("tflite export finished")
With the trained model in hand, I used RT_AK to deploy it to the embedded board. Since I don't yet understand how to write the application code for it, I used the official MNIST example source as-is, shown below:
/*
* Copyright (c) 2006-2021, RT-Thread Development Team
*
* SPDX-License-Identifier: Apache-2.0
*
* Change Logs:
* Date Author Notes
* 2021-07-02 Lebhoryi first version
*/
#include <rt_ai_network_model.h>
#include <rt_ai.h>
#include <rt_ai_log.h>
static rt_ai_t model = NULL;
#define MNIST_0_7 {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.34058376931019385, 0.5534434156016326, 0.5159157133469962, 0.47675838274876575, 0.16790986331319932, 0.06389560571085397, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9001142474626551, 0.7598628516908901, 0.8241672401895409, 0.8019644319085198, 0.7108184213592105, 0.4277455826754391, 0.3146021401174672, 0.2991960804831432, 0.35451095188003034, 0.35818466685703043, 0.34876617933839493, 0.3362681728806376, 0.3496743610673831, 0.3351779954733201, 0.37058414837972664, 0.2825753120721564, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.27165610171170224, 0.3410408074518168, 0.23362220981750767, 0.35993679227390263, 0.45615512866752483, 0.40289729156566256, 0.40358052318099324, 0.33999554600357185, 0.4547766756440793, 0.45948942111962493, 0.44740711894925406, 0.4245810263644414, 0.4044213569920744, 0.4299758123748652, 0.553696315814415, 0.7607796863481134, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.030172924919014375, 0.10486738003915572, 0.021155278418000027, 0.11996077664627289, 0.12120390242131839, 0.11801683846299221, 0.10020112222200817, 0.03708667465866184, 0.39950508551365427, 0.553696315814415, 0.5760189053778574, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1465806665080444, 0.42828299421590904, 0.45560051183154626, 0.0978145311019003, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03736313032007084, 0.4114854854984861, 0.4316686305338214, 0.1809322606795136, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.21908380960405174, 0.44857216015714796, 0.4028907218315666, 0.09591589722769395, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10392527566144091, 0.4228827022589836, 0.44857216015714796, 0.10495472585528207, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23427223157579052, 0.43137432278627247, 0.33024800767475065, 0.008464090794780811, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01628112122077411, 0.3610962967897523, 0.42118437815352583, 0.1024298633429708, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.22793569709083755, 0.44740711894925406, 0.30909498719331335, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13428445146970847, 0.4540623807127002, 0.4227468840465393, 0.09680447401109264, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02871073499585718, 0.3956915169974076, 0.45948942111962493, 0.2923999281321897, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.004766699092688897, 0.3067515370610004, 0.4547766756440793, 0.39617394970550335, 0.06165058725678698, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06037818850739269, 0.38381719415514337, 0.4547766756440793, 0.1392940371110674, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05502121602879092, 0.35591353225410427, 0.38381719415514337, 0.2059028255868863, 0.0018090134689749012, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2360587655428772, 0.40358052318099324, 0.38381719415514337, 0.0931038863523312, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17070836103508596, 0.42952046061185173, 0.40358052318099324, 0.38381719415514337, 0.0931038863523312, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3386182243482853, 0.45081899584880303, 0.40358052318099324, 0.33092899811014326, 0.07161837411717785, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3386182243482853, 0.45081899584880303, 0.32890223739553387, 0.02719964368028575, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}
const static float input_data0[] = MNIST_0_7;
void ai_run_complete(void *arg){
    *(int *)arg = 1;
}

int mnist_app(void){
    rt_err_t result = RT_EOK;
    int ai_run_complete_flag = 0;

    rt_ai_buffer_t *work_buffer = rt_malloc(RT_AI_NETWORK_WORK_BUFFER_BYTES +
                                            RT_AI_NETWORK_IN_TOTAL_SIZE_BYTES +
                                            RT_AI_NETWORK_OUT_TOTAL_SIZE_BYTES);

    // find the registered model handle
    model = rt_ai_find(RT_AI_NETWORK_MODEL_NAME);
    if (!model) {
        rt_kprintf("ai model find err\r\n");
        return -1;
    }

    // init the model and allocate the working memory
    result = rt_ai_init(model, work_buffer);
    if (result != 0) {
        rt_kprintf("ai init err\r\n");
        return -1;
    }

    // prepare the input data
    rt_memcpy(model->input[0], input_data0, RT_AI_NETWORK_IN_1_SIZE_BYTES);

    // run inference; ai_run_complete sets the flag when the run finishes
    result = rt_ai_run(model, ai_run_complete, &ai_run_complete_flag);
    if (result != 0) {
        rt_kprintf("ai model run err\r\n");
        return -1;
    }

    // get the output and post-process it (argmax over the 10 class scores)
    int pred_num = 0;
    if (ai_run_complete_flag) {
        float *out = (float *)rt_ai_output(model, 0);
        for (int i = 1; i < 10; i++) {
            if (out[i] > out[pred_num]) {
                pred_num = i;
            }
        }
        AI_LOG("The Mnist prediction is : %d\n", pred_num);
    }

    rt_free(work_buffer);
    return 0;
}
MSH_CMD_EXPORT(mnist_app, mnist classification demo);
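Note that input_data0 above is still an MNIST digit image, while the converted bearing network expects a 6000-sample vibration window, so the input buffer contents and sizes need to be adapted once the demo is switched to the bearing model. Before digging into the board-side error below, the exported finishModel.tflite can also be run on the PC with the TFLite interpreter to confirm the converted model itself behaves; a minimal sketch, assuming the output path from the conversion step (the all-zero input is only a placeholder for a real vibration window from train.csv):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="C:/Users/years/Desktop/learn_some/model/finishModel.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
print(inp["shape"], inp["dtype"])                 # expected [1 6000] float32 for the float model

x = np.zeros(inp["shape"], dtype=inp["dtype"])    # placeholder; use a real vibration window here
interpreter.set_tensor(inp["index"], x)
interpreter.invoke()
print(interpreter.get_tensor(out["index"]))       # the 10 class probabilities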
Error: