请选择 进入手机版 | 继续访问电脑版

Hello Mat

 找回密码
 立即注册
查看: 6360|回复: 0

基于多隐藏层的BP神经网络

[复制链接]

1278

主题

1504

帖子

90

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
22549
发表于 2017-3-26 13:06:45 | 显示全部楼层 |阅读模式
基于多隐藏层的BP神经网络:主要结构如下:
  1. % 多隐藏层
  2. net=newff(inputn,output_train,[10,15],{'logsig','tansig'},'trainlm');
复制代码

主要代码如下:
  1. %% 该代码为基于多隐藏层的BP网络的语言识别
  2. %% 清空环境变量
  3. clc,clear,close all
  4. warning off
  5. format shortG
  6. % 训练数据预测数据提取及归一化
  7. %下载四类语音信号
  8. load data1 c1
  9. load data2 c2
  10. load data3 c3
  11. load data4 c4
  12. %四个特征信号矩阵合成一个矩阵
  13. data(1:500,:)=c1(1:500,:);
  14. data(501:1000,:)=c2(1:500,:);
  15. data(1001:1500,:)=c3(1:500,:);
  16. data(1501:2000,:)=c4(1:500,:);

  17. %从1到2000间随机排序
  18. k=rand(1,2000);
  19. [m,n]=sort(k);

  20. %输入输出数据
  21. input=data(:,2:25);
  22. output1 =data(:,1);

  23. %把输出从1维变成4维
  24. output=zeros(2000,4);
  25. for i=1:2000
  26.     switch output1(i)
  27.         case 1
  28.             output(i,:)=[1 0 0 0];
  29.         case 2
  30.             output(i,:)=[0 1 0 0];
  31.         case 3
  32.             output(i,:)=[0 0 1 0];
  33.         case 4
  34.             output(i,:)=[0 0 0 1];
  35.     end
  36. end

  37. %随机提取1500个样本为训练样本,500个样本为预测样本
  38. input_train=input(n(1:1500),:)';
  39. output_train=output(n(1:1500),:)';
  40. input_test=input(n(1501:2000),:)';
  41. output_test=output(n(1501:2000),:)';

  42. %输入数据归一化
  43. [inputn,inputps]=mapminmax(input_train);
  44. %% 网络结构初始化
  45. innum=24;
  46. midnum=25;
  47. outnum=4;
  48. %初始化网络结构
  49. nntwarn off
  50. % 多隐藏层
  51. net=newff(inputn,output_train,[10,15],{'logsig','tansig'},'trainlm');

  52. net.trainParam.epochs=100;
  53. net.trainParam.lr=0.1;
  54. net.trainParam.goal=0.00004;
  55. % 网络训练
  56. net=train(net,inputn,output_train);

  57. % BP网络预测
  58. % 训练样本--预测输出
  59. yc_train = sim(net,inputn);

  60. % 测试样本--预测输出
  61. inputn_test=mapminmax('apply',input_test,inputps);
  62. yc_test = sim(net,inputn_test);

  63. %% 结果分析
  64. % 根据网络输出找出数据属于哪类
  65. % 训练样本
  66. output_train_yc=zeros(1,1500);
  67. for i=1:1500
  68.     output_train_yc(i)=find(yc_train(:,i)==max(yc_train(:,i)));
  69. end
  70. % 训练样本预测误差
  71. error_train = output_train_yc - output1(n(1:1500))';
  72. [eTa,eTb] = find(error_train==0);
  73. disp(['训练样本预测正确率为:', num2str( length(eTb)/length(error_train) )])

  74. % 测试样本
  75. output_test_yc=zeros(1,500);
  76. for i=1:500
  77.     output_test_yc(i)=find(yc_test(:,i)==max(yc_test(:,i)));
  78. end
  79. % 测试样本预测误差
  80. error_test = output_test_yc - output1(n(1501:2000))';

  81. % 画出预测语音种类和实际语音种类的分类图
  82. figure(1)
  83. plot(output_test_yc,'r')
  84. hold on
  85. plot(output1(n(1501:2000))','b')
  86. legend('预测语音类别','实际语音类别')
  87. hold off

  88. % 画出误差图
  89. figure(2)
  90. plot(error_test)
  91. title('BP网络分类误差','fontsize',12)
  92. xlabel('语音信号','fontsize',12)
  93. ylabel('分类误差','fontsize',12)

  94. % 找出判断错误的分类属于哪一类
  95. k=zeros(1,4);  
  96. for i=1:500
  97.     if error_test(i)~=0
  98.         [b,c]=max(output_test(:,i));
  99.         switch c
  100.             case 1
  101.                 k(1)=k(1)+1;
  102.             case 2
  103.                 k(2)=k(2)+1;
  104.             case 3
  105.                 k(3)=k(3)+1;
  106.             case 4
  107.                 k(4)=k(4)+1;
  108.         end
  109.     end
  110. end

  111. % 找出每类的个体和
  112. kk=zeros(1,4);
  113. for i=1:500
  114.     [b,c]=max(output_test(:,i));
  115.     switch c
  116.         case 1
  117.             kk(1)=kk(1)+1;
  118.         case 2
  119.             kk(2)=kk(2)+1;
  120.         case 3
  121.             kk(3)=kk(3)+1;
  122.         case 4
  123.             kk(4)=kk(4)+1;
  124.     end
  125. end

  126. % 正确率
  127. rightridio=(kk-k)./kk;
  128. disp('测试样本4类预测正确率分别为:')
  129. disp(rightridio);
复制代码







算法QQ  3283892722
群智能算法链接http://halcom.cn/forum.php?mod=forumdisplay&fid=73
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-3-29 03:14 , Processed in 0.242441 second(s), 25 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表