热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch实战3

1、tensornumpyimporttorchimportnumpyasnpxnp.array([[1,2],[3,4]])#将numpy换成torchtensor

1、tensor<--->numpy

import torch
import numpy as npx &#61; np.array([[1,2],[3,4]])# 将numpy换成torch tensor
y &#61; torch.from_numpy(x)# 将torch tensor换成numpy
z &#61; y.numpy()


2、数据下载

import torch
import torchvision
from torchvision import transforms# 下载CIFAR-10数据集
train_dataset &#61; torchvision.datasets.CIFAR10(root&#61;&#39;../../data/&#39;,train&#61;True,transform&#61;transforms.ToTensor(),download&#61;True
)# 取其中一组数据
image,label &#61; train_dataset[0]
print(image.size())
print(label)train_loader &#61; torch.utils.data.DataLoader(dataset &#61; train_dataset,batch_size &#61; 64,shuffle &#61; True
)data_iter &#61; iter(train_loader)images,labels &#61; data_iter.next()for images,labels in train_loader:pass


3、梯度理解案例1

import torch
from torch.autograd import Variable# 创建tensor
x &#61; torch.tensor(1.,requires_grad&#61;True)
w &#61; torch.tensor(2.,requires_grad&#61;True)
b &#61; torch.tensor(3.,requires_grad&#61;True)# 创建一个计算图
y &#61; w * x &#43; b# 计算梯度
y.backward()# 输出梯度
print(w.grad) # --> x&#61;1
print(x.grad) # --> w&#61;2
print(b.grad) # --> 1&#39;&#39;&#39;
计算谁的梯度就是将谁看作变量&#xff0c;其他全部看成常量
&#39;&#39;&#39;


4、梯度理解案例2

import torch
from torch.autograd import Variable# 输入
x &#61; torch.randn(10,3)
# 输出
y &#61; torch.randn(10,2)# 线性回归
linear &#61; torch.nn.Linear(3,2)
print(linear.weight)
print(linear.bias)# 损失
loss_fn &#61; torch.nn.MSELoss()# 学习率
learning_rate &#61; 1e-3# 训练次数
epoch_n &#61; 100# 优化器
optimiazer &#61; torch.optim.Adam(linear.parameters(),learning_rate)for epoch in range(epoch_n):y_pred &#61; linear(x)loss &#61; loss_fn(y,y_pred)print("Eopch:{},Loss:{:.4f}".format(epoch,loss))linear.zero_grad()loss.backward()optimiazer.step()


5、预训练模型

import torchvision
import torch# 下载并加载预训练的ResNet-18
resnet &#61; torchvision.models.resnet18(pretrained&#61;True)# 如果你只想对模型的顶层进行微调&#xff0c;请设置如下
for param in resnet.parameters():param.requires_grad &#61; False# 替换顶层进行微调
resnet.fc &#61; torch.nn.Linear(resnet.fc.in_features,100) # 100是个例子images &#61; torch.randn(64,3,224,224)
outputs &#61; resnet(images)print(outputs.size())# 保存以及载入模型
torch.save(resnet,&#39;model.ckpt&#39;)
model &#61; torch.load(&#39;model.ckpt&#39;)# 仅仅保存和载入模型参数
torch.save(resnet.state_dict(),&#39;params.ckpt&#39;)
resnet.load_state_dict(torch.load(&#39;params.ckpt&#39;))


 


推荐阅读
author-avatar
手机用户2502901575_836
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有