LibTorch Optimizers

SGD

torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);
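
To make the effect of optimizer.step() concrete, here is a minimal, dependency-free C++ sketch of one SGD update. It is not LibTorch code. The names SgdConfig, SgdState and sgd_step are hypothetical, and the arithmetic follows my reading of the update rule documented for PyTorch's SGD: weight decay is added to the gradient, the momentum buffer starts as the first gradient, and Nesterov is an optional correction. In LibTorch itself these knobs are set through torch::optim::SGDOptions (for example SGDOptions(0.01).momentum(0.9)). With only lr set, as in the line above, the step reduces to param -= lr * grad.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical config mirroring the hyperparameters of torch::optim::SGDOptions.
struct SgdConfig {
  double lr = 0.01;
  double momentum = 0.0;
  double dampening = 0.0;
  double weight_decay = 0.0;
  bool nesterov = false;
};

// Per-optimizer state: one momentum buffer entry per parameter.
struct SgdState {
  std::vector<double> momentum_buffer;  // empty until the first step
};

// One SGD update on a flat parameter vector, given dLoss/dParam in grads.
void sgd_step(std::vector<double>& params, const std::vector<double>& grads,
              const SgdConfig& cfg, SgdState& state) {
  assert(params.size() == grads.size());
  const bool first_step = state.momentum_buffer.empty();
  if (first_step) state.momentum_buffer.assign(params.size(), 0.0);

  for (std::size_t i = 0; i < params.size(); ++i) {
    double d = grads[i];
    // L2 weight decay is folded into the gradient.
    if (cfg.weight_decay != 0.0) d += cfg.weight_decay * params[i];
    if (cfg.momentum != 0.0) {
      double& buf = state.momentum_buffer[i];
      // The buffer starts as the first gradient, then decays by `momentum`.
      buf = first_step ? d : cfg.momentum * buf + (1.0 - cfg.dampening) * d;
      d = cfg.nesterov ? d + cfg.momentum * buf : buf;
    }
    params[i] -= cfg.lr * d;  // move against the gradient
  }
}
```

Because the momentum buffer lives in SgdState, repeated calls must reuse the same state object. This is also why the LibTorch optimizer is constructed once, outside the training loop, rather than per batch.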

Using it in the official example:

#include <torch/torch.h>
#include <iostream>

// Define a new Module.
struct Net : torch::nn::Module {
  Net() {
    // Construct and register three Linear submodules.
    fc1 = register_module("fc1", torch::nn::Linear(784, 64));
    fc2 = register_module("fc2", torch::nn::Linear(64, 32));
    fc3 = register_module("fc3", torch::nn::Linear(32, 10));
  }

  // Implement the Net's algorithm.
  torch::Tensor forward(torch::Tensor x) {
    // Use one of many tensor manipulation functions.
    x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
    x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
    x = torch::relu(fc2->forward(x));
    x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
    return x;
  }

  // Use one of many "standard library" modules.
  // The layers are members of Net, so each Net instance owns its own weights.
  torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};

int main() {
  // Create a new Net.
  auto net = std::make_shared<Net>();

  // Create a multi-threaded data loader for the MNIST dataset.
  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::MNIST("./data").map(
          torch::data::transforms::Stack<>()),
      /*batch_size=*/64);

  // Instantiate an SGD optimization algorithm to update our Net's parameters.
  torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

  for (size_t epoch = 1; epoch <= 10; ++epoch) {
    size_t batch_index = 0;
    // Iterate the data loader to yield batches from the dataset.
    for (auto& batch : *data_loader) {
      // Reset gradients.
      optimizer.zero_grad();
      // Execute the model on the input data.
      torch::Tensor prediction = net->forward(batch.data);
      // Compute a loss value to judge the prediction of our model.
      torch::Tensor loss = torch::nll_loss(prediction, batch.target);
      // Compute gradients of the loss w.r.t. the parameters of our model.
      loss.backward();
      // Update the parameters based on the calculated gradients.
      optimizer.step();
      // Output the loss and checkpoint every 100 batches.
      if (++batch_index % 100 == 0) {
        std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                  << " | Loss: " << loss.item<float>() << std::endl;
        // Serialize your model periodically as a checkpoint.
        torch::save(net, "net.pt");
      }
    }
  }
}
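
The loop in the official example follows a fixed order: zero_grad() clears stale gradients, forward() and nll_loss() compute the loss, backward() fills the gradients, and step() applies the update. The sketch below repeats that order without LibTorch, fitting y = w * x + b by mean squared error with hand-derived gradients standing in for autograd. LinearModel, mse_loss_and_backward, sgd_update and train are hypothetical names, and a single fixed batch stands in for the MNIST data loader.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a model with two trainable parameters.
struct LinearModel {
  double w = 0.0, b = 0.0;            // parameters
  double grad_w = 0.0, grad_b = 0.0;  // accumulated gradients

  double forward(double x) const { return w * x + b; }

  // Mirrors optimizer.zero_grad(): clear gradients left by the previous step.
  void zero_grad() { grad_w = grad_b = 0.0; }
};

// Mean squared error over a batch. It also accumulates dL/dw and dL/db
// by hand, standing in for loss.backward().
double mse_loss_and_backward(LinearModel& m, const std::vector<double>& xs,
                             const std::vector<double>& ys) {
  assert(xs.size() == ys.size() && !xs.empty());
  const double n = static_cast<double>(xs.size());
  double loss = 0.0;
  for (std::size_t i = 0; i < xs.size(); ++i) {
    const double err = m.forward(xs[i]) - ys[i];
    loss += err * err / n;
    m.grad_w += 2.0 * err * xs[i] / n;
    m.grad_b += 2.0 * err / n;
  }
  return loss;
}

// Plain SGD update, standing in for optimizer.step().
void sgd_update(LinearModel& m, double lr) {
  m.w -= lr * m.grad_w;
  m.b -= lr * m.grad_b;
}

// Same zero_grad -> forward/loss -> backward -> step order as the
// official loop, repeated on one fixed batch.
LinearModel train(const std::vector<double>& xs, const std::vector<double>& ys,
                  double lr, int epochs) {
  LinearModel m;
  for (int epoch = 0; epoch < epochs; ++epoch) {
    m.zero_grad();
    mse_loss_and_backward(m, xs, ys);
    sgd_update(m, lr);
  }
  return m;
}
```

For example, train({0, 1, 2, 3}, {1, 3, 5, 7}, 0.05, 2000) should recover w close to 2 and b close to 1. Removing the zero_grad() call would make each step use the sum of all past gradients, because the sketch accumulates with +=; LibTorch also accumulates gradients across backward() calls, which is why the official loop resets them before every batch.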


Author: 明睿崇
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved