Hello Mat

 找回密码
 立即注册
查看: 4394|回复: 2

Pytorch深度学习模型写入供C++调用

[复制链接]

84

主题

115

帖子

731

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
1467
发表于 2019-5-8 08:23:20 | 显示全部楼层 |阅读模式
Pytorch深度学习模型写入供C++调用
  1. import torch
  2. from Models.MobileNetv2 import mobilenetv2

  3. model = mobildnetv2(pretrained)
  4. example = torch.rand(1, 3, 224, 224).cuda() # 注意,我这里导出的是CUDA版的模型,因为我的模型是在GPU中进行训练的
  5. model = model.eval()

  6. traced_script_module = torch.jit.trace(model, example)
  7. output = traced_script_module(torch.ones(1,3,224,224).cuda())
  8. traced_script_module.save('mobilenetv2-trace.pt')
  9. print(output)
复制代码


参考:
【1】https://zhuanlan.zhihu.com/p/52154049
【2】https://blog.csdn.net/IAMoldpan/article/details/85057238
【3】https://blog.csdn.net/IAMoldpan/article/details/86604302


回复

使用道具 举报

84

主题

115

帖子

731

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
1467
 楼主| 发表于 2019-5-8 23:11:17 | 显示全部楼层
  1. import torch
  2. import XXnet_path
  3. net = XXnet()
  4. net.cuda(0)        # GPU
  5. net.load_state_dict(torch.load('xx.pth')['state_dict'])
  6. net.eval()
  7. example = torch.rand(1,3,256,256).type(torch.FloatTensor).cuda(0)
  8. model = torch.jit.trace(net, example)
  9. model.save( 'xx.pt' )
复制代码
回复 支持 反对

使用道具 举报

84

主题

115

帖子

731

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
1467
 楼主| 发表于 2019-5-8 23:11:50 | 显示全部楼层
  1. import torch
  2. import XXnet_path
  3. net = XXnet()
  4. net.load_state_dict(torch.load('xx.pth')['state_dict'])
  5. net.eval()
  6. example = torch.rand(1,3,256,256)
  7. model = torch.jit.trace(net, example)      # CPU
  8. model.save( 'xx.pt' )
复制代码
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-11-22 23:08 , Processed in 0.223047 second(s), 21 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表