LibTorch Optimizers
SGD

The simplest way to construct an SGD optimizer is to pass it the model's parameters and a learning rate:

torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);
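If you need momentum or weight decay, the constructor also accepts a torch::optim::SGDOptions object. Below is a minimal, self-contained sketch; the momentum and weight-decay values are illustrative assumptions, not taken from the official example:

#include <torch/torch.h>

int main() {
  // A throwaway linear layer, just so there are parameters to optimize.
  torch::nn::Linear net(784, 10);

  // Sketch: configuring SGD via SGDOptions. The momentum and
  // weight-decay values here are illustrative, not from the original.
  torch::optim::SGD optimizer(
      net->parameters(),
      torch::optim::SGDOptions(/*lr=*/0.01).momentum(0.9).weight_decay(1e-4));
}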
Usage in the official example

The end-to-end MNIST example from the PyTorch C++ frontend documentation shows the optimizer inside a complete training loop:
#include <torch/torch.h>

#include <iostream>
#include <memory>

// Define a new Module.
struct Net : torch::nn::Module {
  Net() {
    // Construct and register three Linear submodules.
    fc1 = register_module("fc1", torch::nn::Linear(784, 64));
    fc2 = register_module("fc2", torch::nn::Linear(64, 32));
    fc3 = register_module("fc3", torch::nn::Linear(32, 10));
  }

  // Implement the Net's forward pass.
  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
    x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
    x = torch::relu(fc2->forward(x));
    x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
    return x;
  }

  torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};

int main() {
  // Create a new Net.
  auto net = std::make_shared<Net>();

  // Create a data loader for the MNIST dataset.
  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::MNIST("./data").map(
          torch::data::transforms::Stack<>()),
      /*batch_size=*/64);

  // Instantiate an SGD optimizer to update the Net's parameters.
  torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

  for (size_t epoch = 1; epoch <= 10; ++epoch) {
    size_t batch_index = 0;
    // Iterate the data loader to yield batches from the dataset.
    for (auto& batch : *data_loader) {
      // Reset gradients.
      optimizer.zero_grad();
      // Execute the model on the input data.
      torch::Tensor prediction = net->forward(batch.data);
      // Compute a loss value to judge the prediction.
      torch::Tensor loss = torch::nll_loss(prediction, batch.target);
      // Compute gradients of the loss w.r.t. the model parameters.
      loss.backward();
      // Update the parameters based on the calculated gradients.
      optimizer.step();
      // Output the loss and save a checkpoint every 100 batches.
      if (++batch_index % 100 == 0) {
        std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                  << " | Loss: " << loss.item<float>() << std::endl;
        torch::save(net, "net.pt");
      }
    }
  }
}
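Other optimizers in torch::optim follow the same parameters-plus-options pattern, so swapping one in is a one-line change to the program above. As a sketch, here is the SGD line replaced with Adam; the learning rate is an illustrative choice, not from the official example:

// Drop-in replacement for the SGD line above; AdamOptions mirrors SGDOptions.
// The lr value is an illustrative assumption.
torch::optim::Adam optimizer(
    net->parameters(), torch::optim::AdamOptions(/*lr=*/1e-3));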