Hello Mat

 找回密码
 立即注册
查看: 6442|回复: 9

Pytorch负反馈多层神经网络

[复制链接]

84

主题

115

帖子

731

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
1467
发表于 2020-4-19 12:27:16 | 显示全部楼层 |阅读模式
Pytorch负反馈多层神经网络:
模型如下:
  1. class module_net(nn.Module):
  2.     def __init__(self, num_input, num_hidden, num_output):
  3.         super(module_net, self).__init__()
  4.         self.layer1 = nn.Linear(num_input, num_hidden)
  5.         self.layer2 = nn.Tanh()
  6.         self.layer3 = nn.Linear(num_hidden, num_hidden)
  7.         self.layer4 = nn.Tanh()
  8.         self.layer5 = nn.Linear(num_hidden, num_hidden)
  9.         self.layer6 = nn.Tanh()
  10.         self.layer7 = nn.Linear(num_hidden, num_output)

  11.     def forward(self, x):
  12.         x = self.layer1(x)
  13.         x = self.layer2(x)
  14.         x = self.layer3(x)
  15.         x = self.layer4(x)
  16.         x = self.layer5(x)
  17.         x = self.layer6(x)
  18.         x = self.layer7(x)
  19.         return x
复制代码
简单的模型如下:
  1. class module_net(nn.Module):
  2.     def __init__(self, num_input, num_hidden, num_output):
  3.         super(module_net, self).__init__()
  4.         self.layer1 = nn.Linear(num_input, num_hidden)
  5.         self.layer2 = nn.ReLU()
  6.         self.layer3 = nn.Linear(num_hidden, num_output)
  7.         self.layer4 = nn.ReLU()

  8.     def forward(self, x):
  9.         x = self.layer1(x)
  10.         x = self.layer2(x)
  11.         x = self.layer3(x)
  12.         x = self.layer4(x)
  13.         return x
复制代码
当然也可以加阈值:
  1. class module_net(nn.Module):
  2.     def __init__(self, num_input, num_hidden, num_output):
  3.         super(module_net, self).__init__()
  4.         self.layer1 = nn.Linear(num_input, num_hidden)
  5.         self.bn1 = nn.BatchNorm1d(num_input)
  6.         self.layer2 = nn.Tanh()
  7.         self.layer3 = nn.Linear(num_hidden, num_hidden)
  8.         self.bn2 = nn.BatchNorm1d(num_hidden)
  9.         self.layer4 = nn.Tanh()
  10.         self.layer5 = nn.Linear(num_hidden, num_hidden)
  11.         self.bn3 = nn.BatchNorm1d(num_hidden)
  12.         self.layer6 = nn.Tanh()
  13.         self.layer7 = nn.Linear(num_hidden, num_output)
  14.         self.bn4 = nn.BatchNorm1d(num_output)

  15.     def forward(self, x):
  16.         x = self.layer1(x)
  17.         x = self.bn1(x)
  18.         x = self.layer2(x)
  19.         x = self.layer3(x)
  20.         x = self.bn2(x)
  21.         x = self.layer4(x)
  22.         x = self.layer5(x)
  23.         x = self.bn3(x)
  24.         x = self.layer6(x)
  25.         x = self.layer7(x)
  26.         x = self.bn4(x)
  27.         return x
复制代码

参考链接:
【1】【PyTorch 深度学习】4.用PyTorch实现多层网络
【2】基于多隐藏层的BP神经网络








回复

使用道具 举报

11

主题

13

帖子

40

金钱

版主

Rank: 7Rank: 7Rank: 7

积分
84
发表于 2025-7-23 14:53:41 | 显示全部楼层
  1. // 2. 转换为float*
  2.         float* floatArray = convertToFloatArray(numbers);
  3.         //
  4.         *outputObjs = new float[ObjLen];
  5.         //
  6.         torch::NoGradGuard nograd;
  7.         std::vector<torch::jit::IValue> input_vars;

  8.         // transform_images_normalized
  9.         auto image_tensor = torch::from_blob(floatArray, { 1, inputLen }, torch::kFloat32).clone();  // wav byte2字节,转short
  10.         delete[] floatArray;  // 不能在debug下运行

  11.         auto image_var = torch::autograd::make_variable(image_tensor, false);
  12.         //image_var = image_var.unsqueeze(0);  // 增加第1维维度
  13.         image_var = image_var.squeeze(0);  // 移除第1维维度
  14.         // call AI
  15.         input_vars.push_back(image_var);

  16.         torch::Tensor output = this->CNN_module.forward(input_vars).toTensor();
  17.         input_vars.pop_back();
  18.         // core calculation
  19.         //auto cpuTensor = outputs[2].to(torch::kInt64).to(at::kCPU).cpu();
  20.         auto accr = output.accessor<float, 2>();  // 输出为二位数组,维度 = 1 x 1000
  21.         for (int i = 0; i < ObjLen; i++)
  22.         {
  23.                 (*outputObjs)[i] = static_cast<float>(accr[0][i]);
  24.         }
复制代码
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2025-8-30 17:51 , Processed in 0.186209 second(s), 24 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表