# 图像分类训练设计 (image classification training design)
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
# Per-channel normalization statistics passed to transforms.Normalize below.
# NOTE(review): presumably computed from the training split — recompute if the
# dataset changes.
norm_mean = [0.33424968, 0.33424437, 0.33428448]
norm_std = [0.24796878, 0.24796101, 0.24801227]

# Training-time preprocessing: resize every image to 32x32, convert it to a
# tensor, then normalize each channel with the statistics above.
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Root directory of the training split. Defined once and reused below so the
# two uses of this path cannot drift apart.
train_data_path = r"D:\PycharmProjects\AI_Easy_Demo\MyData\split_data\train"

from MyDataset.Cifar10_Dataset import LoadDataset

# Project-local dataset; the training loop unpacks each batch it yields as
# (inputs, labels).
train_dataset = LoadDataset(data_dir=train_data_path, transform=train_transform)
# Mini-batches of 8 samples, reshuffled every epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
from MyNet.ResNet import ResNet34

# Model under training (10 output classes).
# NOTE(review): the meaning of num_linear is defined in MyNet.ResNet, which is
# not visible here — confirm.
net = ResNet34(num_classes=10, num_linear=512)

# The training loop applies F.log_softmax to the network output before this
# loss, so NLLLoss on those log-probabilities amounts to cross-entropy.
criterion = nn.NLLLoss()

# Plain SGD with momentum over every model parameter.
optimizer = optim.SGD(params=net.parameters(), lr=0.01, momentum=0.9)

# Multiply the learning rate by gamma=0.1 once every step_size=10 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)
MAX_EPOCH = 200

# NOTE(review): `best` is not read anywhere in the visible code — presumably
# intended for tracking the best result; confirm or remove.
best = [0]

# Explicitly put the model in training mode (affects layers such as
# BatchNorm/Dropout, if the model contains any).
net.train()
for i in range(MAX_EPOCH):
    print("当前轮转次数:", i + 1)  # "current epoch:" (1-based)
    for inputs, labels in train_loader:
        outputs = net(inputs)
        # NLLLoss expects log-probabilities, so convert the network output
        # (assumed to be unnormalized scores — confirm in MyNet.ResNet).
        outputs = F.log_softmax(outputs, dim=1)
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        # .item() prints the scalar loss rather than the tensor repr with grad_fn.
        print(loss.item())
        optimizer.step()
    # Advance the StepLR schedule once per epoch, after the optimizer steps.
    # Without this call the scheduler is dead code and the LR never decays.
    # NOTE(review): with step_size=10, gamma=0.1 and 200 epochs the LR falls
    # to ~1e-21 by the last epoch — consider fewer epochs or a larger step_size.
    scheduler.step()