libtorch: examples of commonly used API functions
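Taking the argmax of a segmentation output and displaying it as an OpenCV mask:

```cpp
#include <torch/torch.h>
#include <opencv2/opencv.hpp>

// output_1: per-pixel class scores; dim 2 is treated as the class dimension (e.g. [H, W, C])
// argmax returns int64 indices, so convert to uint8 before wrapping the buffer as CV_8UC1
torch::Tensor b = torch::argmax(output_1, 2).to(torch::kU8).cpu().contiguous();
// std::cout << b << std::endl;
b.print();  // prints only the tensor type and sizes, not the values
// cv::Mat wraps b's memory without copying, so b must outlive mask
cv::Mat mask(T_height, T_width, CV_8UC1, b.data_ptr<uint8_t>());
cv::imshow("mask", mask * 255);  // scale class indices 0/1 to 0/255 for display
cv::waitKey(0);
```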
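To make the shape bookkeeping concrete, here is a plain C++ sketch (no libtorch or OpenCV) of what `torch::argmax(output_1, 2)` computes, assuming `output_1` is a contiguous row-major float buffer of shape [H, W, C] with the class scores on dim 2. The helper `argmax_last_dim` is hypothetical; it writes `uint8_t` directly because that is the element type a `CV_8UC1` `cv::Mat` expects.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: for a row-major [H, W, C] score buffer, return an
// H*W mask holding the index of the largest score for each pixel.
// uint8_t is enough for up to 256 classes and matches CV_8UC1.
std::vector<std::uint8_t> argmax_last_dim(const std::vector<float>& scores,
                                          std::size_t H, std::size_t W, std::size_t C) {
    assert(C > 0 && scores.size() == H * W * C);
    std::vector<std::uint8_t> mask(H * W, 0);
    for (std::size_t i = 0; i < H * W; ++i) {
        const float* px = scores.data() + i * C;  // the C class scores of pixel i
        std::size_t best = 0;
        for (std::size_t c = 1; c < C; ++c) {
            if (px[c] > px[best]) best = c;  // strict '>' keeps the first maximum on ties
        }
        mask[i] = static_cast<std::uint8_t>(best);
    }
    return mask;
}
```

The real `torch::argmax` returns int64 indices, which is why the libtorch version still needs a conversion such as `.to(torch::kU8)` before its buffer can be reinterpreted as an 8-bit single-channel image.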
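Replacing the elements that satisfy a condition with `masked_fill_`:

```cpp
torch::Tensor a = torch::rand({2,3});
torch::Tensor aa = a.clone();       // deep copy, so the in-place fill below leaves a unchanged
aa.masked_fill_(aa > 0.5, -2);      // in place: every element greater than 0.5 becomes -2
std::cout << a << std::endl;
std::cout << aa << std::endl;
```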
Output:

```
 0.8803  0.2387  0.8577
 0.8166  0.0730  0.4682
[ Variable[CPUFloatType]{2,3} ]
-2.0000  0.2387 -2.0000
-2.0000  0.0730  0.4682
[ Variable[CPUFloatType]{2,3} ]
```
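The element-wise rule behind `clone()` plus `masked_fill_` can be sketched in plain C++ (no libtorch) on a flat buffer. `greater_than` and `masked_fill_inplace` are hypothetical helpers, and copying a `std::vector` stands in for `clone()`. The trailing underscore in `masked_fill_` marks an in-place operation, which is why the example clones `a` first and only `aa` changes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: builds the boolean mask that `aa > 0.5` produces in libtorch.
std::vector<bool> greater_than(const std::vector<float>& t, float threshold) {
    std::vector<bool> mask(t.size());
    for (std::size_t i = 0; i < t.size(); ++i) {
        mask[i] = t[i] > threshold;
    }
    return mask;
}

// Hypothetical helper: overwrite t[i] with value wherever mask[i] is true, in place.
// Unlike libtorch, there is no broadcasting: mask and data must have the same length.
void masked_fill_inplace(std::vector<float>& t, const std::vector<bool>& mask, float value) {
    assert(t.size() == mask.size());
    for (std::size_t i = 0; i < t.size(); ++i) {
        if (mask[i]) t[i] = value;
    }
}
```

Copying the six values of `a` printed above into a vector `aa` and calling `masked_fill_inplace(aa, greater_than(aa, 0.5f), -2.0f)` leaves `a` untouched and turns `aa` into -2, 0.2387, -2, -2, 0.0730, 0.4682, the same pattern as the libtorch output.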
References:
[1] libtorch 常用api函数示例 (libtorch common API function examples): https://blog.csdn.net/yang332233/article/details/106199180