利用传感器采集轴承振动数据，结合卷积神经网络进行轴承故障诊断。我在知乎上看到一篇相关文章，并根据文章提供的数据和代码训练了一个模型。
以下是模型的训练及量化代码。（我不是这方面的专业人士，如果有更好的量化方式，还请指出模型量化的改进方法。）
import tensorflow
import glob
import numpy as np
import pandas as pd
import math
import os
import sys
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.datasets import *
from pathlib import Path
# Training CSV located next to the launched script (sys.path[0] is that
# script's directory). os.path.join inserts the platform's separator instead
# of a hard-coded Windows backslash, and yields a plain relative "train.csv"
# when sys.path[0] is empty.
MANIFEST_DIR = os.path.join(sys.path[0], "train.csv")
Batch_size = 20  # rows per batch, for both training and validation
Long = 792  # assumed total number of CSV rows; (Long - Lens) is used as the validation size
Lens = 640  # the first Lens rows are the training split, the remainder is validation
def convert2oneHot(index, Lens):
    """Encode a class label as a one-hot vector.

    Args:
        index: the class label; any value ``int()`` accepts (e.g. ``3`` or ``3.0``).
        Lens: length of the returned vector, i.e. the number of classes.

    Returns:
        A float64 NumPy array of shape ``(Lens,)`` that is zero everywhere
        except for a 1 at position ``int(index)``.
    """
    label_position = int(index)
    encoded = np.zeros(shape=Lens, dtype=np.float64)
    encoded[label_position] = 1.0
    return encoded
def xs_gen(path=MANIFEST_DIR, batch_size=Batch_size, train=True, Lens=Lens,
           num_classes=10):
    """Yield ``(batch_x, batch_y)`` batches from the CSV, looping forever.

    Columns are sliced as: first column skipped (id), middle columns are the
    features, last column is the class label.

    Args:
        path: CSV file to read.
        batch_size: rows per yielded batch; the last batch of each pass may
            be smaller.
        train: if True use rows ``[0, Lens)`` and reshuffle them at the start
            of every pass; otherwise use rows ``[Lens, end)`` in file order.
        Lens: number of leading rows that form the training split.
        num_classes: length of the one-hot label vectors.

    Yields:
        batch_x: float32 array of shape ``(rows, n_features)``.
        batch_y: float64 array of shape ``(rows, num_classes)``.

    Raises:
        ValueError: if the selected split contains no rows.
    """
    rows = np.array(pd.read_csv(path))
    if train:
        rows = rows[:Lens]
        print(rows.shape)
        print("Found %s train items." % len(rows))
    else:
        rows = rows[Lens:]
        print("Found %s test items." % len(rows))
    if len(rows) == 0:
        raise ValueError("xs_gen: no %s rows in %s after splitting at Lens=%d"
                         % ("train" if train else "test", path, Lens))
    print("list 1 is", rows[0, -1])
    steps = math.ceil(len(rows) / batch_size)  # batches per pass over the split
    while True:
        # Shuffle the whole training split once per pass so batch composition
        # changes between epochs (permuting rows inside one batch does not
        # change what a step trains on). Validation keeps a fixed order.
        epoch_rows = rows[np.random.permutation(len(rows))] if train else rows
        for i in range(steps):
            batch_list = epoch_rows[i * batch_size:(i + 1) * batch_size]
            # Explicit float32: avoids object arrays if the CSV mixes types.
            batch_x = batch_list[:, 1:-1].astype(np.float32)
            batch_y = np.array([convert2oneHot(label, num_classes)
                                for label in batch_list[:, -1]])
            yield batch_x, batch_y
TIME_PERIODS = 6000  # samples per input signal; must equal the number of feature columns per CSV row


def build_model(input_shape=(TIME_PERIODS,), num_classes=10):
    """Build the 1-D CNN classifier.

    Args:
        input_shape: shape of one sample without the batch axis, a 1-tuple
            ``(length,)``.
        num_classes: number of softmax outputs.

    Returns:
        An uncompiled ``Sequential`` model mapping ``(batch, length)`` inputs
        to ``(batch, num_classes)`` class probabilities.
    """
    model = Sequential()
    # Add a channel axis, (length,) -> (length, 1). The length comes from
    # input_shape so non-default signal lengths also build correctly.
    model.add(Reshape((input_shape[0], 1), input_shape=input_shape))
    # Only the first layer of a Sequential model needs input_shape; later
    # layers infer their input from the previous layer.
    model.add(Conv1D(16, 8, strides=2, activation='relu'))
    model.add(Conv1D(16, 8, strides=2, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(64, 4, strides=2, activation='relu', padding="same"))
    model.add(Conv1D(64, 4, strides=2, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(256, 4, strides=2, activation='relu', padding="same"))
    model.add(Conv1D(256, 4, strides=2, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(512, 2, strides=1, activation='relu', padding="same"))
    model.add(Conv1D(512, 2, strides=1, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    # Average over the remaining time steps -> one 512-wide vector per sample.
    model.add(GlobalAveragePooling1D())
    model.add(Dropout(0.3))
    model.add(Dense(num_classes, activation='softmax'))
    return model
if __name__ == "__main__":
    # Output directory for checkpoints, the .h5 model and the .tflite model.
    model_dir = Path("C:/Users/years/Desktop/learn_some/model")
    model_dir.mkdir(parents=True, exist_ok=True)

    # xs_gen yields ceil(rows / batch_size) batches per pass (the last may be
    # partial). fit() must consume exactly that many per epoch, otherwise the
    # generator drifts and each epoch sees a different subset of rows.
    # Assumes train.csv has exactly Long rows.
    train_steps = math.ceil(Lens / Batch_size)
    val_steps = math.ceil((Long - Lens) / Batch_size)

    train_iter = xs_gen()
    val_iter = xs_gen(train=False)
    # Keras fills {epoch} and {val_loss} into the file name.
    ckpt = tensorflow.keras.callbacks.ModelCheckpoint(
        filepath=str(model_dir / "best_model.{epoch:02d}-{val_loss:.4f}.h5"),
        monitor='val_loss', save_best_only=True, verbose=1)
    model = build_model()
    opt = Adam(0.0002)
    model.compile(loss='categorical_crossentropy',
                  optimizer=opt, metrics=['accuracy'])
    model.summary()  # prints the table itself and returns None
    model.fit(
        x=train_iter,
        steps_per_epoch=train_steps,
        epochs=30,
        initial_epoch=0,
        validation_data=val_iter,
        validation_steps=val_steps,
        callbacks=[ckpt],
    )
    # NOTE: this saves the weights after the last epoch, not the best
    # checkpoint written by ckpt.
    keras_file = model_dir / "finishModel.h5"
    model.save(str(keras_file), save_format="h5")

    # Reload and fix the batch dimension to 1, giving the converted model a
    # fully static input shape (1, n_features).
    model = tensorflow.keras.models.load_model(str(keras_file))
    model.input.set_shape(1 + model.input.shape[1:])
    converter = tensorflow.lite.TFLiteConverter.from_keras_model(model)

    # Optional post-training quantization, off by default because the
    # on-device code reads the output as float. With Optimize.DEFAULT plus a
    # representative dataset, weights and activations are quantized to int8
    # while the model keeps float32 input/output (float fallback for ops
    # without an integer kernel).
    quantize = False
    if quantize:
        def representative_dataset():
            # Single-row batches match the (1, n_features) input fixed above.
            rep_iter = xs_gen(batch_size=1)
            for _ in range(100):
                sample_x, _label = next(rep_iter)
                yield [sample_x.astype(np.float32)]

        converter.optimizations = [tensorflow.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset

    tflite_model = converter.convert()
    tflite_file = model_dir / "finishModel.tflite"
    tflite_file.write_bytes(tflite_model)
    print("tflite is finish")
得到训练好的模型后，使用 RT-AK 将其部署到嵌入式设备上。由于不了解具体的应用代码写法，这里直接采用了官方提供的 MNIST 示例源码，如下：
/*
* Copyright (c) 2006-2021, RT-Thread Development Team
*
* SPDX-License-Identifier: Apache-2.0
*
* Change Logs:
* Date Author Notes
* 2021-07-02 Lebhoryi first version
*/
#include <rt_ai_network_model.h>
#include <rt_ai.h>
#include <rt_ai_log.h>
/* Handle of the registered network; set by rt_ai_find() in mnist_app(). */
static rt_ai_t model = NULL;

/*
 * One input sample to feed the network. Replace {0} with the sample values as
 * a comma-separated float list, e.g. {0.12f, -0.03f, ...}. It must hold exactly
 * as many values as the model input (RT_AI_NETWORK_IN_1_SIZE_BYTES / sizeof(float);
 * presumably 6000 for the bearing model, matching TIME_PERIODS in the training
 * script). The training script feeds raw, unnormalized CSV feature values, so
 * use the same units here.
 * An empty {} is not valid ISO C for an array of unspecified size, hence {0}.
 */
#define MNIST_0_7 {0}
static const float input_data0[] = MNIST_0_7;
/*
 * Completion callback handed to rt_ai_run(). arg points at the caller's int
 * flag; writing 1 to it tells the caller the run has finished.
 */
void ai_run_complete(void *arg){
    int *done_flag = (int *)arg;
    *done_flag = 1;
}
/*
 * MSH command: run one inference on input_data0 and log the index of the
 * highest output score.
 *
 * Returns 0 once inference has run, -1 on any error. The work buffer is freed
 * on every path after it has been allocated.
 */
int mnist_app(void){
    rt_err_t result = RT_EOK;
    int ai_run_complete_flag = 0;
    int ret = -1;
    int pred_num = 0;
    /* Number of output scores; the training script builds the model with
     * num_classes=10 (confirm it matches the deployed model). */
    const int num_classes = 10;

    /* The sample must be exactly the model's input size; a smaller array would
     * make the copy into model->input[0] read past its end. */
    if ((unsigned long)sizeof(input_data0) != (unsigned long)RT_AI_NETWORK_IN_1_SIZE_BYTES) {
        rt_kprintf("ai input size err: have %d bytes, model needs %d\r\n",
                   (int)sizeof(input_data0), (int)RT_AI_NETWORK_IN_1_SIZE_BYTES);
        return -1;
    }

    /* One allocation for the network's work memory plus its input and output
     * areas, handed to rt_ai_init(). */
    rt_ai_buffer_t *work_buffer = rt_malloc(RT_AI_NETWORK_WORK_BUFFER_BYTES +
                                            RT_AI_NETWORK_IN_TOTAL_SIZE_BYTES +
                                            RT_AI_NETWORK_OUT_TOTAL_SIZE_BYTES);
    if (!work_buffer) {
        rt_kprintf("ai work buffer alloc err\r\n");
        return -1;
    }

    /* find a registered model handle */
    model = rt_ai_find(RT_AI_NETWORK_MODEL_NAME);
    if (!model) {
        rt_kprintf("ai model find err\r\n");
        goto cleanup;
    }

    /* init the model with the work buffer */
    result = rt_ai_init(model, work_buffer);
    if (result != 0) {
        rt_kprintf("ai init err\r\n");
        goto cleanup;
    }

    /* prepare input data */
    rt_memcpy(model->input[0], input_data0, RT_AI_NETWORK_IN_1_SIZE_BYTES);
    result = rt_ai_run(model, ai_run_complete, &ai_run_complete_flag);
    if (result != 0) {
        rt_kprintf("ai model run err\r\n");
        goto cleanup;
    }

    /* post-process: argmax over the class scores */
    if (ai_run_complete_flag) {
        float *out = (float *)rt_ai_output(model, 0);
        for (int i = 1; i < num_classes; i++) {
            if (out[i] > out[pred_num]) {
                pred_num = i;
            }
        }
        AI_LOG("The Mnist prediction is : %d\n", pred_num);
    }
    ret = 0;

cleanup:
    rt_free(work_buffer);
    return ret;
}
MSH_CMD_EXPORT(mnist_app, mnist classification demo);
运行后报错如下：