首页 > 技术知识 > 正文

原文:https://blog.csdn.net/brightming/article/details/50895356

5 完整示例

5.1 二维数据的训练与预测

5.1.1 训练二维数据

以 y=kx 直线进行划分:在直线以下(即 kx > y)的点为类别 0,其他为类别 1。训练时指定 k,产生训练数据,同时将其中一部分作为测试数据。

5.1.1.1 训练入口

/**

以斜率=slope 来做分界,训练一个mlp模型 */ extern “C” void train_2_class_slope(float slope){//(int useExistModel,float from_x,float end_x,float from_y,float end_y,float x_step,float y_step){

CvANN_MLP annMlp; int outputClassCnt=2; bool loadModelFromFile=false;

Mat training_datas; Mat trainClasses; Mat oriTrainDatas;

generateFix2ClassSlopeTrainData(slope,training_datas,trainClasses); training_datas.convertTo(training_datas,CV_32FC1); trainClasses.convertTo(trainClasses,CV_32FC1);

cout<<“training_datas=\n”<<training_datas<<“,oridata=”<<oriTrainDatas<<endl; cout<<“trainClasses=\n”<<trainClasses<<endl;

//创建mlp Mat layers(1, 3, CV_32SC1); layers.at(0) = training_datas.row(0).cols; cout<<“————————trainAnnModel.input sample cnt:”<<training_datas.rows<<“,input layer features:”<<layers.at(0,0)<<endl; layers.at(1)=3; layers.at(2) = outputClassCnt;//输出

cout<<“outputClassCnt=”<<outputClassCnt<<endl; annMlp.create(layers, CvANN_MLP::SIGMOID_SYM, 0.6667f, 1.7159f);

//——–训练mlp———–// // Set up BPNetwork‘s parameters CvANN_MLP_TrainParams params; params.train_method = CvANN_MLP_TrainParams::BACKPROP; params.bp_dw_scale = 0.001; params.bp_moment_scale = 0.0;

CvTermCriteria criteria; criteria.max_iter = 300; criteria.epsilon = 9.999999e-06; criteria.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS; params.term_crit = criteria;

annMlp.train(training_datas, trainClasses, Mat(), Mat(), params); cout<<“train finished”<<endl;

char _dstPath[256]; sprintf(_dstPath,”data/my/my_simple_2_class_20160307slope%.2f.xml”,slope); string dstPath(_dstPath);//=”data/my/my_simple_2_class_20160307_slope_1.xml”; annMlp.save(dstPath.c_str()); cout<<“save model finished.model file=”<<dstPath<<“\n”;

//预测 Mat test_datas; Mat testClasses; int testCount=1;//每个象限的测试图片数量 Mat oriTestData; generate2ClassSlopeTestData(slope,test_datas,testClasses,oriTestData);//,from_x,end_x,from_y,end_y); test_datas.convertTo(test_datas,CV_32FC1); testClasses.convertTo(testClasses,CV_32FC1); cout<<“test_datas=\n”<<test_datas<<“,oridata=”<<oriTestData<<endl; // cout<<“testClasses=\n”<<testClasses<<endl;

int correctCount=0; int errorCount=0; cout<<“test_datas size=”<<test_datas.rows<<endl; int totalTestSize=test_datas.rows; bool right=false;

// TestData_2Feature* cur=testDataHead; int expected_idx=0; for(int i=0;i<totalTestSize;i++){ Mat predict_result(1, outputClassCnt, CV_32FC1); annMlp.predict(test_datas.row(i), predict_result); Point maxLoc; double maxVal; minMaxLoc(predict_result, 0, &maxVal, 0, &maxLoc);

right=false; if(test_datas.row(i).at(0,0)*slope > test_datas.row(i).at(0,1)){ expected_idx=0; }else{ expected_idx=1; } if(expected_idx==maxLoc.x){ ++correctCount; right=true; }else { ++errorCount; } cout<<“data:”<<test_datas.row(i)<<“(“<<oriTestData.row(i)<<“),predict_result=”<<predict_result<<“,maxVal=”<<maxVal<<“,maxLoc.x=”<<maxLoc.x<<“,right?”<<right<<endl;

// cur=cur->next; }

cout<<“total test data count=”<<totalTestSize<<“,correct count=”<<correctCount<<“,error count=”<<errorCount<<“,accurate=”<<(correctCount)*1.0f/(totalTestSize)<<endl;

}

5.1.1.2 训练与测试数据产生方法

/**

y=x,划分, */ void generateFix2ClassSlopeTrainData(float slope,Mat& mat,Mat& labels){ vector dataVec; vector labVec;

float tmp1=0,tmp2=0; printf(“generateFix2ClassSlopeTrainData begin\n”);

int multi=1; float x_step=16; float y_step=16; int needTestSize=10; int nowTestSize=0;

int loopcnt=0; ostringstream os;

Int end_x=255; Int end_y=255; int getDataInterval=((end_x-0)/x_step (end_y-0)/y_step)/needTestSize; printf(“getDataInterval=%d,totalTrainSize=%d\n”,(deltaX/x_step deltaY/y_step));

for(int x=0;x<end_x;x+=x_step){ for(int y=0;y<end_y;y+=x_step){ ++loopcnt; dataVec.clear(); multi=-1; tmp1=multi(float)x;///255; dataVec.push_back(tmp1); tmp2=multi*(float)y;///255; dataVec.push_back(tmp2);

// printf(“tmp1=%f\n”,tmp1); // Mat tpmat=Mat(dataVec).reshape(1,1).clone(); mat.push_back(Mat(dataVec).reshape(1,1).clone());

labVec.clear(); if(tmp1*slope>tmp2){// x> 为类0 labVec.push_back(1.0f); labVec.push_back(0.0f); labels.push_back(Mat(labVec).reshape(1,1).clone());

if(loopcnt%getDataInterval==0){ os<<“0:”; } }else{ labVec.push_back(0.0f); labVec.push_back(1.0f); labels.push_back(Mat(labVec).reshape(1,1).clone());

if(loopcnt%getDataInterval==0){ os<<“1:”; } } if(loopcnt%getDataInterval==0){ os<<x<<” “<<y<<endl; }

} }

//输出一部分作为测试文件 system(“rm data/my/test2classdata_slope.list”); fstream ftxt; string testfile=”data/my/test2classdata_slope.list”; ftxt.open(testfile.c_str(), ios::out | ios::app); if (ftxt.fail()) { cout << “创建文件:”<<testfile<<” 失败!” <<endl; getchar(); } ftxt << os.str(); ftxt.close(); }

5.1.2 海思预测二维数据样本的所属类别 5.1.2.1 预测入口 /**

测a试?y=kx的?分?类え?情é况? / HI_VOID SAMPLE_IVE_Ann_predict_2class_slope(float slope){ // HI_CHAR pchBinFileName; int height,width,image_type; char pchBinFileName[256]; sprintf(pchBinFileName,”./data/my/my_simple_2_class_20160307slope%.2f.bin”,slope); // pchBinFileName = “./data/my/my_simple_2_class_20160307_slope_3.00.bin”; height=1; width=2; image_type=IVE_IMAGE_TYPE_S32C1;

HI_S32 s32Ret; SAMPLE_IVE_ANN_INFO_S stAnnInfo;

printf(“use model bin file:%s\n”,pchBinFileName); SAMPLE_COMM_IVE_CheckIveMpiInit();

s32Ret=SAMPLE_IVE_Ann_Mlp_2Class_Slope_Init(&stAnnInfo, pchBinFileName,image_type,height,width); if (HI_SUCCESS != s32Ret) { SAMPLE_PRT(“SAMPLE_IVE_Ann_Mlp__2Class_Init fail\n”); goto ANN_FAIL; } // predict2ClassData(&stAnnInfo,slope); predict2ClassSlopeData(&stAnnInfo,slope);

//uninit SAMPLE_IVE_Ann_Mlp_Uninit(&stAnnInfo);

ANN_FAIL: SAMPLE_COMM_IVE_IveMpiExit(); }

5.1.2.2 初始化

/**
 * Initialise the IVE MLP predictor state in `pstAnnInfo`:
 *  1. the activation lookup table: input range [0, 1] with 8-bit precision, i.e.
 *     (1 - 0) << 8 = 256 HI_U16 entries, filled by SAMPLE_IVE_Ann_Mlp_CreateTable
 *     with alpha = 0.6667 and beta = 1.7159;
 *  2. the MLP model loaded from `pchBinFileName`;
 *  3. an input buffer holding one HI_S16Q16 per input-layer neuron;
 *  4. an output buffer holding one HI_S16Q16 per output-layer neuron.
 *
 * On failure, everything allocated so far is released with SAMPLE_IVE_Ann_Mlp_Uninit.
 *
 * @param pstAnnInfo     [out] zeroed, then filled with table, model and buffers.
 * @param pchBinFileName path of the converted model file (.bin).
 * @param image_type     not used by this function.
 * @param height         not used by this function.
 * @param width          not used by this function.
 * @return HI_SUCCESS, or the error code of the first failing call.
 */
static HI_S32 SAMPLE_IVE_Ann_Mlp_2Class_Slope_Init(SAMPLE_IVE_ANN_INFO_S *pstAnnInfo, HI_CHAR *pchBinFileName, int image_type, int height, int width)
{
    HI_S32 s32Ret = HI_SUCCESS;
    HI_U32 u32Size;

    (void)image_type;
    (void)height;
    (void)width;

    SAMPLE_PRT("SAMPLE_IVE_Ann_Mlp_2Class_Slope_Init.....\n");
    memset(pstAnnInfo, 0, sizeof(SAMPLE_IVE_ANN_INFO_S));

    /* 1. Activation lookup table: [0, 1] split into 1 << 8 = 256 segments. */
    pstAnnInfo->stTable.s32TabInLower = 0;
    pstAnnInfo->stTable.s32TabInUpper = 1;
    pstAnnInfo->stTable.u8TabInPreci = 8;
    pstAnnInfo->stTable.u8TabOutNorm = 2;
    pstAnnInfo->stTable.u16ElemNum = (pstAnnInfo->stTable.s32TabInUpper - pstAnnInfo->stTable.s32TabInLower) << pstAnnInfo->stTable.u8TabInPreci;
    u32Size = pstAnnInfo->stTable.u16ElemNum * sizeof(HI_U16);
    s32Ret = SAMPLE_COMM_IVE_CreateMemInfo(&(pstAnnInfo->stTable.stTable), u32Size);
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("SAMPLE_COMM_IVE_CreateMemInfo fail\n");
        goto ANN_INIT_FAIL;
    }

    /* Same alpha/beta as CvANN_MLP::create(..., SIGMOID_SYM, 0.6667f, 1.7159f) at training time. */
    s32Ret = SAMPLE_IVE_Ann_Mlp_CreateTable(&(pstAnnInfo->stTable), 0.6667f, 1.7159f);
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("SAMPLE_IVE_Ann_Mlp_CreateTable fail\n");
        goto ANN_INIT_FAIL;
    }

    /* 2. Model. */
    SAMPLE_PRT("begin to load model:%s\n", pchBinFileName);
    s32Ret = HI_MPI_IVE_ANN_MLP_LoadModel(pchBinFileName, &(pstAnnInfo->stAnnModel));
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("HI_MPI_IVE_ANN_MLP_LoadModel fail,Error(%#x)\n", s32Ret);
        goto ANN_INIT_FAIL;
    }
    printf("finish load model:%s\n", pchBinFileName);

    /* 3. Input buffer: number of input-layer neurons * size of one element. */
    u32Size = pstAnnInfo->stAnnModel.au16LayerCount[0] * sizeof(HI_S16Q16);
    printf("allocate memory for input,size=%u\n", u32Size);
    s32Ret = SAMPLE_COMM_IVE_CreateMemInfo(&(pstAnnInfo->stSrc), u32Size);
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("SAMPLE_COMM_IVE_CreateMemInfo fail\n");
        goto ANN_INIT_FAIL;
    }

    /* 4. Output buffer: number of output classes * size of one class value. */
    u32Size = pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1] * sizeof(HI_S16Q16);
    printf("allocate memory for output,size=%u\n", u32Size);
    s32Ret = SAMPLE_COMM_IVE_CreateMemInfo(&(pstAnnInfo->stDst), u32Size);
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("SAMPLE_COMM_IVE_CreateMemInfo fail\n");
        goto ANN_INIT_FAIL;
    }

ANN_INIT_FAIL:
    if (HI_SUCCESS != s32Ret)
    {
        SAMPLE_IVE_Ann_Mlp_Uninit(pstAnnInfo);
    }

    return s32Ret;
}

5.1.2.3 预测

/**

预¤测ay=kx的?分?类え? / void predict2ClassSlopeData(SAMPLE_IVE_ANN_INFO_S pstAnnInfo,float slope){ char contFile=”data/my/test2classdata_slope_eq_1.list”; printf(“try to get file info:%s\n”,contFile); TestData_2Feature head=get2FeatureData(contFile); printf(“after read file:%s,head=%p\n”,contFile,head); if(!head){ printf(“fail to read contFile:%s\n”,contFile); return; }

// printStringNode(head,”1″); // printStringNode(head,”2″);

HI_S32 i, k; HI_S32 s32Ret; HI_S32 s32ResponseCls; HI_U16 u16LayerCount; HI_S16Q16 ps16q16Dst; HI_S16Q16 s16q16Response; HI_BOOL bInstant = HI_TRUE; HI_BOOL bFinish; HI_BOOL bBlock = HI_TRUE; // HI_CHAR achFileName[IVE_FILE_NAME_LEN]; FILE pFp = HI_NULL; IVE_HANDLE iveHandle;

int xs[3]={-5,-4,3}; int ys[3]={99,-10,10};

srand(time(NULL));

int totalCount=0; int correctCount=0; TestData_2Feature* cur=head;

int cnt=0; int expected_idx=0; while(cur!=NULL){ // printf(“flag=%d,filePath=%s,filenName=%s –>\n “,cur->flag,cur->fileFullPath,cur->fileName); ps16q16Dst = (HI_S16Q16*)pstAnnInfo->stDst.pu8VirAddr; s16q16Response = 0; s32ResponseCls = -1;

HI_S16Q16 stSrc=(HI_S16Q16)pstAnnInfo->stSrc.pu8VirAddr; stSrc[0]=changeFloatToS16Q16(cur->x1);//转换为以s16q16表示的数据 stSrc[1]=changeFloatToS16Q16(cur->x2);

s32Ret = HI_MPI_IVE_ANN_MLP_Predict(&iveHandle, &(pstAnnInfo->stSrc), \ & (pstAnnInfo->stTable), &(pstAnnInfo->stAnnModel), &(pstAnnInfo->stDst), bInstant); if (s32Ret != HI_SUCCESS) { SAMPLE_PRT(“HI_MPI_IVE_ANN_MLP_Predict fail,Error(%#x)\n”, s32Ret); break; } s32Ret = HI_MPI_IVE_Query(iveHandle, &bFinish, bBlock); while (HI_ERR_IVE_QUERY_TIMEOUT == s32Ret) { usleep(100); s32Ret = HI_MPI_IVE_Query(iveHandle, &bFinish, bBlock); } if (HI_SUCCESS != s32Ret) { SAMPLE_PRT(“HI_MPI_IVE_Query fail,Error(%#x)\n”, s32Ret); break; } u16LayerCount = pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum – 1]; // SAMPLE_PRT(“pstAnnInfo->CstAnnModel.u8LayerNum=%d,pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum – 1]=%d\n”,pstAnnInfo->stAnnModel.u8LayerNum,pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum – 1]); SAMPLE_PRT(” \n–predict2ClassSlopeData–Begin– x1=%f(s16q16=%d),x2=%f(s16q16=%d)\n”,cur->x1,changeFloatToS16Q16(cur->x1),cur->x2,changeFloatToS16Q16(cur->x2)); ++totalCount; for (k = 0; k < u16LayerCount; k++) { printf(” ps16q16Dst[%d]=%d,H16Q16=%f\n”, k,ps16q16Dst[k],calculateS16Q16_c(ps16q16Dst[k])); if (s16q16Response < ps16q16Dst[k]) { s16q16Response = ps16q16Dst[k]; s32ResponseCls = k; } }

if(cur->x1*slope>cur->x2){ expected_idx=0; }else{ expected_idx=1; } SAMPLE_PRT(” –predict2ClassSlopeData–End– result:%s,flag:%d,class:%d ——\n\n”,(expected_idx==s32ResponseCls?”right”:”wrong”),expected_idx,s32ResponseCls);

cur=cur->next; }

freeTestData_2FeatureNode(head); } 附上斜率为的预测结果输出: [predict2ClassSlopeData]-830:

--predict2ClassSlopeData--Begin-- x1=0.220000(s16q16=14417),x2=0.100000(s16q16=6553)
 ps16q16Dst[0]=46174,H16Q16=0.704559
 ps16q16Dst[1]=20098,H16Q16=0.306671
[predict2ClassSlopeData]-847:  --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------

[predict2ClassSlopeData]-830:

--predict2ClassSlopeData--Begin-- x1=-1.000000(s16q16=-65536),x2=-3.000000(s16q16=-196608)
 ps16q16Dst[0]=48919,H16Q16=0.746445
 ps16q16Dst[1]=15412,H16Q16=0.235168
[predict2ClassSlopeData]-847:  --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------

[predict2ClassSlopeData]-830:

--predict2ClassSlopeData--Begin-- x1=1.000000(s16q16=65536),x2=0.200000(s16q16=13107)
 ps16q16Dst[0]=48919,H16Q16=0.746445
 ps16q16Dst[1]=15412,H16Q16=0.235168
[predict2ClassSlopeData]-847:  --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------

[predict2ClassSlopeData]-830:

--predict2ClassSlopeData--Begin-- x1=0.200000(s16q16=13107),x2=0.700000(s16q16=45875)
 ps16q16Dst[0]=16687,H16Q16=0.254623
 ps16q16Dst[1]=51450,H16Q16=0.785065
[predict2ClassSlopeData]-847:  --predict2ClassSlopeData--End-- result:right,flag:1,class:1 ------

[predict2ClassSlopeData]-830:

--predict2ClassSlopeData--Begin-- x1=0.400000(s16q16=26214),x2=0.900000(s16q16=58982)
 ps16q16Dst[0]=16830,H16Q16=0.256805
 ps16q16Dst[1]=51033,H16Q16=0.778702
[predict2ClassSlopeData]-847:  --predict2ClassSlopeData--End-- result:right,flag:1,class:1 ------

[predict2ClassSlopeData]-830:

--predict2ClassSlopeData--Begin-- x1=0.690196(s16q16=45232),x2=0.062745(s16q16=4112)
 ps16q16Dst[0]=48919,H16Q16=0.746445
 ps16q16Dst[1]=15412,H16Q16=0.235168
[predict2ClassSlopeData]-847:  --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------

[predict2ClassSlopeData]-830:

--predict2ClassSlopeData--Begin-- x1=224.000000(s16q16=14680064),x2=80.000000(s16q16=5242880)
 ps16q16Dst[0]=17622,H16Q16=0.268890
 ps16q16Dst[1]=45294,H16Q16=0.691132
[predict2ClassSlopeData]-847:  --predict2ClassSlopeData--End-- result:wrong,flag:0,class:1 ------

[predict2ClassSlopeData]-830:

--predict2ClassSlopeData--Begin-- x1=-224.000000(s16q16=-14680064),x2=80.000000(s16q16=5242880)
 ps16q16Dst[0]=17117,H16Q16=0.261185
 ps16q16Dst[1]=51728,H16Q16=0.789307
[predict2ClassSlopeData]-847:  --predict2ClassSlopeData--End-- result:right,flag:1,class:1 ------

猜你喜欢