热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

LibTorch之优化器

LibTorch之优化器SGDtorch::optim::SGDoptimizer(net-parameters(),*lr*0.01);官方案例使用#include

LibTorch之优化器


SGD

torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

官方案例使用

#include
// Use one of many "standard library" modules.
torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};// Define a new Module.
struct Net : torch::nn::Module {Net() {// Construct and register two Linear submodules.fc1 &#61; register_module("fc1", torch::nn::Linear(784, 64));fc2 &#61; register_module("fc2", torch::nn::Linear(64, 32));fc3 &#61; register_module("fc3", torch::nn::Linear(32, 10));}// Implement the Net&#39;s algorithm.torch::Tensor forward(torch::Tensor x) {// Use one of many tensor manipulation functions.x &#61; torch::relu(fc1->forward(x.reshape({x.size(0), 784})));x &#61; torch::dropout(x, /*p&#61;*/0.5, /*train&#61;*/is_training());x &#61; torch::relu(fc2->forward(x));x &#61; torch::log_softmax(fc3->forward(x), /*dim&#61;*/1);return x;}};int main() {// Create a new Net.auto net &#61; std::make_shared<Net>();// Create a multi-threaded data loader for the MNIST dataset.auto data_loader &#61; torch::data::make_data_loader(torch::data::datasets::MNIST("./data").map(torch::data::transforms::Stack<>()),/*batch_size&#61;*/64);// Instantiate an SGD optimization algorithm to update our Net&#39;s parameters.torch::optim::SGD optimizer(net->parameters(), /*lr&#61;*/0.01);for (size_t epoch &#61; 1; epoch <&#61; 10; &#43;&#43;epoch) {size_t batch_index &#61; 0;// Iterate the data loader to yield batches from the dataset.for (auto& batch : *data_loader) {// Reset gradients.optimizer.zero_grad();// Execute the model on the input data.torch::Tensor prediction &#61; net->forward(batch.data);// Compute a loss value to judge the prediction of our model.torch::Tensor loss &#61; torch::nll_loss(prediction, batch.target);// Compute gradients of the loss w.r.t. the parameters of our model.loss.backward();// Update the parameters based on the calculated gradients.optimizer.step();// Output the loss and checkpoint every 100 batches.if (&#43;&#43;batch_index % 100 &#61;&#61; 0) {std::cout << "Epoch: " << epoch << " | Batch: " << batch_index<< " | Loss: " << loss.item<float>() << std::endl;// Serialize your model periodically as a checkpoint.torch::save(net, "net.pt");}}}
}


推荐阅读
  • 本题探讨了在大数据结构背景下,如何通过整体二分和CDQ分治等高级算法优化处理复杂的时间序列问题。题目设定包括节点数量、查询次数和权重限制,并详细分析了解决方案中的关键步骤。 ... [详细]
  • 在编译BSP包过程中,遇到了一个与 'gets' 函数相关的编译错误。该问题通常发生在较新的编译环境中,由于 'gets' 函数已被弃用并视为安全漏洞。本文将详细介绍如何通过修改源代码和配置文件来解决这一问题。 ... [详细]
  • 丽江客栈选择问题
    本文介绍了一道经典的算法题,题目涉及在丽江河边的n家特色客栈中选择住宿方案。两位游客希望住在色调相同的两家客栈,并在晚上选择一家最低消费不超过p元的咖啡店小聚。我们将详细探讨如何计算满足条件的住宿方案总数。 ... [详细]
  • 在进行QT交叉编译时,可能会遇到与目标架构不匹配的宏定义问题。例如,当为ARM或MIPS架构编译时,需要确保使用正确的宏(如QT_ARCH_ARM或QT_ARCH_MIPS),而不是默认的QT_ARCH_I386。本文将详细介绍如何正确配置编译环境以避免此类错误。 ... [详细]
  • 本文介绍了如何在多线程环境中实现异步任务的事务控制,确保任务执行的一致性和可靠性。通过使用计数器和异常标记字段,系统能够准确判断所有异步线程的执行结果,并根据结果决定是否回滚或提交事务。 ... [详细]
  • 深入解析Java枚举及其高级特性
    本文详细介绍了Java枚举的概念、语法、使用规则和应用场景,并探讨了其在实际编程中的高级应用。所有相关内容已收录于GitHub仓库[JavaLearningmanual](https://github.com/Ziphtracks/JavaLearningmanual),欢迎Star并持续关注。 ... [详细]
  • 本题来自WC2014,题目编号为BZOJ3435、洛谷P3920和UOJ55。该问题描述了一棵不断生长的带权树及其节点上小精灵之间的友谊关系,要求实时计算每次新增节点后树上所有可能的朋友对数。 ... [详细]
  • 本文探讨了如何在 F# Interactive (FSI) 中通过 AddPrinter 和 AddPrintTransformer 方法自定义类型(尤其是集合类型)的输出格式,提供了详细的指南和示例代码。 ... [详细]
  • 本文介绍如何从字符串中移除大写、小写、特殊、数字和非数字字符,并提供了多种编程语言的实现示例。 ... [详细]
  • Linux环境下C语言实现定时向文件写入当前时间
    本文介绍如何在Linux系统中使用C语言编程,实现在每秒钟向指定文件中写入当前时间戳。通过此示例,读者可以了解基本的文件操作、时间处理以及循环控制。 ... [详细]
  • 在高并发需求的C++项目中,我们最初选择了JsonCpp进行JSON解析和序列化。然而,在处理大数据量时,JsonCpp频繁抛出异常,尤其是在多线程环境下问题更为突出。通过分析发现,旧版本的JsonCpp存在多线程安全性和性能瓶颈。经过评估,我们最终选择了RapidJSON作为替代方案,并实现了显著的性能提升。 ... [详细]
  • 深入解析Spring启动过程
    本文详细介绍了Spring框架的启动流程,帮助开发者理解其内部机制。通过具体示例和代码片段,解释了Bean定义、工厂类、读取器以及条件评估等关键概念,使读者能够更全面地掌握Spring的初始化过程。 ... [详细]
  • 本文介绍了如何在 C# 和 XNA 框架中实现一个自定义的 3x3 矩阵类(MMatrix33),旨在深入理解矩阵运算及其应用场景。该类参考了 AS3 Starling 和其他相关资源,以确保算法的准确性和高效性。 ... [详细]
  • 在尝试使用C# Windows Forms客户端通过SignalR连接到ASP.NET服务器时,遇到了内部服务器错误(500)。本文将详细探讨问题的原因及解决方案。 ... [详细]
  • Linux环境下进程间通信:深入解析信号机制
    本文详细探讨了Linux系统中信号的生命周期,从信号生成到处理函数执行完毕的全过程,并介绍了信号编程中的注意事项和常见应用实例。通过分析信号在进程中的注册、注销及处理过程,帮助读者理解如何高效利用信号进行进程间通信。 ... [详细]
author-avatar
明睿崇
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有