热门标签 | HotTags
当前位置:  开发笔记 > 人工智能 > 正文

pytorch实现mnist数据集的图像可视化及保存

今天小编就为大家分享一篇pytorch实现mnist数据集的图像可视化及保存,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

如何将pytorch中mnist数据集的图像可视化及保存

导出一些库

import torch
import torchvision 
import torch.utils.data as Data 
import scipy.misc
import os
import matplotlib.pyplot as plt   
BATCH_SIZE = 50  
DOWNLOAD_MNIST = True 

数据集的准备

#训练集测试集的准备

train_data = torchvision.datasets.MNIST(root='./mnist/', train=True,transform=torchvision.transforms.ToTensor(),              
  download=DOWNLOAD_MNIST, )
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)

将训练及测试集利用dataloader进行迭代

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), requires_grad=True).type(torch.FloatTensor)[:20]/255 
test_y = test_data.test_labels[:20]#前两千张
 #具体查看图像形式为:
 
a_data, a_label = train_data[0]
print(type(a_data))#tensor 类型
#print(a_data)
print(a_label)

#把原始图片保存至MNIST_data/raw/下
save_dir="mnist/raw/"
if os.path.exists(save_dir)is False:
 os.makedirs(save_dir)
 
for i in range(20):
 image_array,_=train_data[i]#打印第i个
 image_array=image_array.resize(28,28)
 filename=save_dir + 'mnist_train_%d.jpg' % i#保存文件的格式
 print(filename)
 print(train_data.train_labels[i])#打印出标签
 scipy.misc.toimage(image_array,cmin=0.0,cmax=1.0).save(filename)#保存图像

以上这篇pytorch实现mnist数据集的图像可视化及保存就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


推荐阅读
  • 在 PyTorch 的 `CrossEntropyLoss` 函数中,当目标标签 `target` 为类别 ID 时,实际上会进行 one-hot 编码处理。例如,假设总共有三个类别,其中一个类别的 ID 为 2,则该标签会被转换为 `[0, 0, 1]`。这一过程简化了多分类任务中的损失计算,使得模型能够更高效地进行训练和评估。此外,`CrossEntropyLoss` 还结合了 softmax 激活函数和负对数似然损失,进一步提高了模型的性能和稳定性。 ... [详细]
  • 本文探讨了BERT模型在自然语言处理领域的应用与实践。详细介绍了Transformers库(曾用名pytorch-transformers和pytorch-pretrained-bert)的使用方法,涵盖了从模型加载到微调的各个环节。此外,还分析了BERT在文本分类、情感分析和命名实体识别等任务中的性能表现,并讨论了其在实际项目中的优势和局限性。 ... [详细]
  • 利用 PyTorch 实现 Python 中的高效矩阵运算 ... [详细]
  • 本文介绍了一款高效的开源OCR文本识别模型,结合了TextBoxes++和RetinaNet的优势。该模型在文本检测方面表现出色,适用于多种场景。项目代码已托管至GitHub,方便研究人员和开发者使用和改进。 ... [详细]
  • 在上一节中,我们完成了网络的前向传播实现。本节将重点探讨如何为检测输出设定目标置信度阈值,并应用非极大值抑制技术以提高检测精度。为了更好地理解和实践这些内容,建议读者已经完成本系列教程的前三部分,并具备一定的PyTorch基础知识。此外,我们将详细介绍这些技术的原理及其在实际应用中的重要性,帮助读者深入理解目标检测算法的核心机制。 ... [详细]
  • 在 PyTorch 中,`pin_memory` 技术用于锁定页面内存。当在创建 `DataLoader` 时将 `pin_memory` 参数设置为 `True`,这意味着生成的 Tensor 数据最初会被存储在锁定的内存中。这一技术能够显著提高数据从 CPU 到 GPU 的传输效率,从而加快训练速度。通过合理利用 `pin_memory`,可以有效减少数据加载的瓶颈,提升整体性能。 ... [详细]
  • 谷歌工程师:TensorFlow已重获新生;网友:我还是用PyTorch
    乾明发自凹非寺量子位报道|公众号QbitAI道友留步!TensorFlow已重获新生。在“PyTorch真香”的潮流中,有人站出来为TensorFlow说话了。这次来自谷歌的工程师 ... [详细]
  • 1.如何进行迁移 使用Pytorch写的模型: 对模型和相应的数据使用.cuda()处理。通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去。 ... [详细]
  • 5.Numpy 索引(一维索引/二维索引)
    本文内容是根据莫烦Python网站的视频整理的笔记,笔记中对代码的注释更加清晰明了,同时根据所有笔记还整理了精简版的思维导图,可在此专栏查看,想观看视频可直接去他的网 ... [详细]
  • python教程分享Pytorchmlu 实现添加逐层算子方法详解
    目录1、注册算子2、算子分发3、修改opmethods基类4、下发算子5、添加wrapper6、添加wrapper7、算子测试本教程分享了在寒武纪设备上pytorch-mlu中添加 ... [详细]
  • [TensorFlow系列3]:初学者是选择Tensorflow2.x还是1.x? 2.x与1.x的主要区别?
    作者主页(文火冰糖的硅基工坊):https:blog.csdn.netHiWangWenBing本文网址:https:blog.csdn.netHiW ... [详细]
  • pytorch(网络模型训练)
    上一篇目录标题网络模型训练小插曲训练模型数据训练GPU训练第一种方式方式二:查看GPU信息完整模型验证网络模型训练小插曲区别importtorchatorch ... [详细]
  • 一、Transorboard使用(可视化工具)(观察模型不同阶段的数据状况)fromtorch.utils.tensorboardimportSummaryWriterfromPI ... [详细]
  • CBAM:卷积块注意模块
    CBAM:ConvolutionalBlockAttentionModule论文地址:https:arxiv.orgabs1807.06521简介:我们提出了 ... [详细]
  • 1.首先在终端中输入python进入python交互式环境2.接着输入 importtorch print(torch.__version__)#注意是 ... [详细]
author-avatar
拍友2502935047
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有