热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

LibTorch之图像分类

LibTorch之图像分类LibTorch之图像分类LibTorch之图像分类数据集地址:https:download.pytorch.orgtutorialhym

LibTorch之图像分类LibTorch之图像分类LibTorch之图像分类


数据集地址:https://download.pytorch.org/tutorial/hymenoptera_data.zip


LibTorch之全连接层(torch::nn::Linear)使用


卷积层


LibTorch实现MLP(多层感知机)


LibTorch实现LeNet


训练

#include
#include
#include
#include
using namespace std;
namespace fs &#61; std::filesystem;vector<pair<string, int>> get_imgs_labels(const std::string& data_dir, map<string, int> dict_label)
{// 1.定义标签//map dict_label;//dict_label.insert(pair("ants", 0));//dict_label.insert(pair("bees", 1));// 2.定义存储图像路径和标签的vectorvector<pair<string, int>> data_info;// 3.读取图像和对应label放入data_info// 遍历字典,读取图像路径和对应labelfor (map<string, int>::iterator it &#61; dict_label.begin(); it !&#61; dict_label.end(); it&#43;&#43;){// 遍历目录查找for (const auto& file_path : fs::directory_iterator(data_dir)){if (file_path.path().filename() &#61;&#61; it->first) {// 遍历所有图像路径for (const auto& img_path : fs::directory_iterator(data_dir &#43; "\\" &#43; it->first)){//std::cout <data_info.push_back(pair<string, int>(img_path.path().string(), it->second));}}//std::cout <}//printVector(data_info);}return data_info;
}///


/// 数据集处理模块类
///

class MyDataset :public torch::data::Dataset<MyDataset> {
private:vector<pair<string, int>> data_info;torch::Tensor imgs, labels;
public:// 构造器&#xff1a;一般用于确定数据集和预处理形式MyDataset(const std::string& data_dir,std::map<string,int> dict_label);// get_item数据处理&#xff1a;对读取的数据进行预处理torch::data::Example<> get(size_t index) override;// 返回数据数量torch::optional<size_t> size() const override {return data_info.size();};
};///
/// 根据数据集路径和对应的标签列表&#xff0c;配对训练数据
///

///
///
MyDataset::MyDataset(const std::string& data_dir, std::map<string, int> dict_label) {// 获取训练数据data_info &#61; get_imgs_labels(data_dir, dict_label);}///
/// 对数据进行预处理&#xff0c;并返回成对的实例Example{data,label}
///

///
///
torch::data::Example<> MyDataset::get(size_t index)
{// 获取图像路径auto img_path &#61; data_info[index].first;// 确定labelauto label &#61; data_info[index].second;// opencv根据图像路径读取图像auto image &#61; cv::imread(img_path);cout << image.size() << endl;//获取通道数int channels &#61; image.channels();cout<<"channels:" <<channels << endl;// resize图像大小cv::resize(image, image, cv::Size(224, 224));// mat转tensorauto input_tensor &#61; torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 }).to(torch::kFloat32) / 225.0;cout << input_tensor.sizes() << endl;// int转tensortorch::Tensor label_tensor &#61; torch::tensor(label);return {input_tensor,label_tensor };}///
/// LeNet实现类
///

class LeNet :public torch::nn::Module {
public:// 构造器LeNet(int num_classes, int num_linear);// 前向传播torch::Tensor forward(torch::Tensor x);
private:// 具体实现放到构造器实现中torch::nn::Conv2d conv1{ nullptr };torch::nn::Conv2d conv2{ nullptr };torch::nn::Linear fc1{ nullptr };torch::nn::Linear fc2{ nullptr };torch::nn::Linear fc3{ nullptr };
};LeNet::LeNet(int num_classes, int num_linear)
{conv1 &#61; register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 6, 5)));conv2 &#61; register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, 5)));fc1 &#61; register_module("fc1", torch::nn::Linear(torch::nn::LinearOptions(num_linear, 128)));fc2 &#61; register_module("fc2", torch::nn::Linear(torch::nn::LinearOptions(128, 32)));fc3 &#61; register_module("fc3", torch::nn::Linear(torch::nn::LinearOptions(32, num_classes)));
}torch::Tensor LeNet::forward(torch::Tensor x)
{auto out &#61; torch::relu(conv1->forward(x));out &#61; torch::max_pool2d(out, 2);out &#61; torch::relu(conv2(out));out &#61; torch::max_pool2d(out, 2);out &#61; out.view({ 1, -1 });out &#61; torch::relu(fc1(out));out &#61; torch::relu(fc2(out));out &#61; fc3(out);return out;
}int main()
{try{map<string, int> dict_label;dict_label.insert(pair<string, int>("ants", 0));dict_label.insert(pair<string, int>("bees", 1));// 设置datasetauto dataset_train &#61; MyDataset("D:\\dataset\\hymenoptera_data\\train", dict_label).map(torch::data::transforms::Stack<>());// batchszieint batchSize &#61; 1;// 设置dataloaderauto dataLoader &#61; torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset_train), batchSize);// 打印//for (auto& batch : * dataLoader) {// auto data &#61; batch.data;// auto target &#61; batch.target;// std::cout <// //std::cout <// //std::cout <// std::cout <// int ssss;// cin >> ssss;//}//auto net &#61; LeNet(5, 44944);std::shared_ptr<LeNet> net &#61; std::make_shared<LeNet>(2, 44944);// 优化器torch::optim::SGD optimizer(net->parameters(), /*lr&#61;*/0.01);for (size_t epoch &#61; 1; epoch <&#61; 10; &#43;&#43;epoch) {size_t batch_index &#61; 0;// 遍历数据集for (auto& batch : *dataLoader) {// 梯度清零.optimizer.zero_grad();// 前向传播torch::Tensor prediction &#61; net->forward(batch.data);cout << "prediction:" << prediction << endl;cout << "target:" << batch.target << endl;// 计算损失torch::Tensor loss &#61; torch::nll_loss(prediction, batch.target);cout <<"loss:" << loss << endl;// 反向传播loss.backward();// 更新梯度optimizer.step();// 间隔 x batch 进行loss打印和模型保存if (&#43;&#43;batch_index % 20 &#61;&#61; 0) {std::cout << "Epoch: " << epoch << " | Batch: " << batch_index<< " | Loss: " << loss << std::endl;// 保存模型torch::save(net, "net.pt");cout << net->parameters() << endl;}}}}catch (const std::exception& e){// step5:打印报错cout << e.what() << endl;}return 0;
}

在这里插入图片描述


推荐阅读
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 本文介绍了Redis的基础数据结构string的应用场景,并以面试的形式进行问答讲解,帮助读者更好地理解和应用Redis。同时,描述了一位面试者的心理状态和面试官的行为。 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 个人学习使用:谨慎参考1Client类importcom.thoughtworks.gauge.Step;importcom.thoughtworks.gauge.T ... [详细]
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • springmvc学习笔记(十):控制器业务方法中通过注解实现封装Javabean接收表单提交的数据
    本文介绍了在springmvc学习笔记系列的第十篇中,控制器的业务方法中如何通过注解实现封装Javabean来接收表单提交的数据。同时还讨论了当有多个注册表单且字段完全相同时,如何将其交给同一个控制器处理。 ... [详细]
  • 本文介绍了南邮ctf-web的writeup,包括签到题和md5 collision。在CTF比赛和渗透测试中,可以通过查看源代码、代码注释、页面隐藏元素、超链接和HTTP响应头部来寻找flag或提示信息。利用PHP弱类型,可以发现md5('QNKCDZO')='0e830400451993494058024219903391'和md5('240610708')='0e462097431906509019562988736854'。 ... [详细]
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 本文详细介绍了如何使用MySQL来显示SQL语句的执行时间,并通过MySQL Query Profiler获取CPU和内存使用量以及系统锁和表锁的时间。同时介绍了效能分析的三种方法:瓶颈分析、工作负载分析和基于比率的分析。 ... [详细]
  • Ihavethefollowingonhtml我在html上有以下内容<html><head><scriptsrc..3003_Tes ... [详细]
  • Imtryingtofigureoutawaytogeneratetorrentfilesfromabucket,usingtheAWSSDKforGo.我正 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • 网址:https:vue.docschina.orgv2guideforms.html表单input绑定基础用法可以通过使用v-model指令,在 ... [详细]
author-avatar
jackystorm岁月_657
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有