热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch实现LSTMAutoEncoder与案例

文章目录LSTMAutoEncoder简介LSTMAutoEncoder实现基本LSTMAutoEncoder网络结构LSTMFcAutoEncoder网络结构案例代码LSTMAu


文章目录

  • LSTM AutoEncoder简介
  • LSTM AutoEncoder 实现
    • 基本LSTM AutoEncoder网络结构
    • LSTM+Fc AutoEncoder网络结构
  • 案例代码


LSTM AutoEncoder简介

基础的AutoEncoder可以参考:https://blog.csdn.net/weixin_35757704/article/details/118457110

LSTM AutoEncoder是将原始的全连接变成了LSTM,然后构造出来的AutoEncoder模型,输入与输出是一样的数据为最佳


LSTM AutoEncoder 实现

博主发现网上对于LSTM AutoEncoder的版本都不一样,通常来讲有:


  1. encoder与decoder都是:lstm
  2. encoder是 lstm + fc ; decoder是 fc + lstm

以下是两种网络架构:


基本LSTM AutoEncoder网络结构

这个结构比较简单,就是encoder的时候过一个lstm,decoder的时候再过一个lstm

class LstmAutoEncoder(nn.Module):def __init__(self, input_layer=300, hidden_layer=100, batch_size=20):super(LstmAutoEncoder, self).__init__()self.input_layer = input_layerself.hidden_layer = hidden_layerself.batch_size = batch_sizeself.encoder_lstm = nn.LSTM(self.input_layer, self.hidden_layer, batch_first=True)self.decoder_lstm = nn.LSTM(self.hidden_layer, self.input_layer, batch_first=True)def forward(self, input_x):input_x = input_x.view(len(input_x), 1, -1)# encoderencoder_lstm, (n, c) = self.encoder_lstm(input_x,(torch.zeros(1, self.batch_size, self.hidden_layer),torch.zeros(1, self.batch_size, self.hidden_layer)))# decoderdecoder_lstm, (n, c) = self.decoder_lstm(encoder_lstm,(torch.zeros(1, self.batch_size, self.input_layer),torch.zeros(1, self.batch_size, self.input_layer)))return decoder_lstm.squeeze()

LSTM+Fc AutoEncoder网络结构

这个网络结构就是:


  • 在encoder的时候过一个lstm,然后接一个全连接,最后用relu激活函数;
  • 在decoder的时候先过全连接,然后用relu的激活函数,最后接lstm

class LstmFcAutoEncoder(nn.Module):def __init__(self, input_layer=300, hidden_layer=100, batch_size=20):super(LstmFcAutoEncoder, self).__init__()self.input_layer = input_layerself.hidden_layer = hidden_layerself.batch_size = batch_sizeself.encoder_lstm = nn.LSTM(self.input_layer, self.hidden_layer, batch_first=True)self.encoder_fc = nn.Linear(self.hidden_layer, self.hidden_layer)self.decoder_lstm = nn.LSTM(self.hidden_layer, self.input_layer, batch_first=True)self.decoder_fc = nn.Linear(self.hidden_layer, self.hidden_layer)self.relu = nn.ReLU()def forward(self, input_x):input_x = input_x.view(len(input_x), 1, -1)# encoderencoder_lstm, (n, c) = self.encoder_lstm(input_x,# shape: (n_layers, batch, hidden_size)(torch.zeros(1, self.batch_size, self.hidden_layer),torch.zeros(1, self.batch_size, self.hidden_layer)))encoder_fc = self.encoder_fc(encoder_lstm)encoder_out = self.relu(encoder_fc)# decoderdecoder_fc = self.relu(self.decoder_fc(encoder_out))decoder_lstm, (n, c) = self.decoder_lstm(decoder_fc,(torch.zeros(1, 20, self.input_layer),torch.zeros(1, 20, self.input_layer)))return decoder_lstm.squeeze()

案例代码

import torch
import torch.nn as nn
import torch.utils.data as Datadef get_train_data():"""得到训练数据,这里使用随机数生成训练数据,由此导致最终结果并不好"""def get_tensor_from_pd(dataframe_series) -> torch.Tensor:return torch.tensor(data=dataframe_series.values)import numpy as npimport pandas as pdfrom sklearn import preprocessing# 生成训练数据x并做归一化后,构造成dataframe格式,再转换为tensor格式df = pd.DataFrame(data=preprocessing.MinMaxScaler().fit_transform(np.random.randint(0, 10, size=(2000, 300))))y = pd.Series(np.random.randint(0, 2, 2000))return get_tensor_from_pd(df).float(), get_tensor_from_pd(y).float()class LstmAutoEncoder(nn.Module):def __init__(self, input_layer=300, hidden_layer=100, batch_size=20):super(LstmAutoEncoder, self).__init__()self.input_layer = input_layerself.hidden_layer = hidden_layerself.batch_size = batch_sizeself.encoder_lstm = nn.LSTM(self.input_layer, self.hidden_layer, batch_first=True)self.decoder_lstm = nn.LSTM(self.hidden_layer, self.input_layer, batch_first=True)def forward(self, input_x):input_x = input_x.view(len(input_x), 1, -1)# encoderencoder_lstm, (n, c) = self.encoder_lstm(input_x,(torch.zeros(1, self.batch_size, self.hidden_layer),torch.zeros(1, self.batch_size, self.hidden_layer)))# decoderdecoder_lstm, (n, c) = self.decoder_lstm(encoder_lstm,(torch.zeros(1, self.batch_size, self.input_layer),torch.zeros(1, self.batch_size, self.input_layer)))return decoder_lstm.squeeze()class LstmFcAutoEncoder(nn.Module):def __init__(self, input_layer=300, hidden_layer=100, batch_size=20):super(LstmFcAutoEncoder, self).__init__()self.input_layer = input_layerself.hidden_layer = hidden_layerself.batch_size = batch_sizeself.encoder_lstm = nn.LSTM(self.input_layer, self.hidden_layer, batch_first=True)self.encoder_fc = nn.Linear(self.hidden_layer, self.hidden_layer)self.decoder_lstm = nn.LSTM(self.hidden_layer, self.input_layer, batch_first=True)self.decoder_fc = nn.Linear(self.hidden_layer, self.hidden_layer)self.relu = nn.ReLU()def forward(self, input_x):input_x = input_x.view(len(input_x), 1, -1)# encoderencoder_lstm, (n, c) = self.encoder_lstm(input_x,# shape: (n_layers, batch, hidden_size)(torch.zeros(1, self.batch_size, self.hidden_layer),torch.zeros(1, self.batch_size, self.hidden_layer)))encoder_fc = self.encoder_fc(encoder_lstm)encoder_out = self.relu(encoder_fc)# decoderdecoder_fc = self.relu(self.decoder_fc(encoder_out))decoder_lstm, (n, c) = self.decoder_lstm(decoder_fc,(torch.zeros(1, 20, self.input_layer),torch.zeros(1, 20, self.input_layer)))return decoder_lstm.squeeze()if __name__ == '__main__':# 得到数据x, y = get_train_data()train_loader = Data.DataLoader(dataset=Data.TensorDataset(x, y), # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=20, # 每块的大小shuffle=True, # 要不要打乱数据 (打乱比较好)num_workers=2, # 多进程(multiprocess)来读数据)# 建模三件套:loss,优化,epochsmodel = LstmAutoEncoder() # lstm# model = LstmFcAutoEncoder() # lstm+fc模型loss_function = nn.MSELoss() # lossoptimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 优化器epochs = 150# 开始训练model.train()for i in range(epochs):for seq, labels in train_loader:optimizer.zero_grad()y_pred = model(seq).squeeze() # 压缩维度:得到输出,并将维度为1的去除single_loss = loss_function(y_pred, seq)# 若想要获得类别,二分类问题使用四舍五入的方法即可:print(torch.round(y_pred))single_loss.backward()optimizer.step()print("Train Step:", i, " loss: ", single_loss)# 每20次,输出一次前20个的结果,对比一下效果if i % 20 == 0:test_data = x[:20]y_pred = model(test_data).squeeze() # 压缩维度:得到输出,并将维度为1的去除print("TEST: ", test_data)print("PRED: ", y_pred)print("LOSS: ", loss_function(y_pred, test_data))

推荐阅读
  • 语义分割系列3SegNet(pytorch实现)
    SegNet手稿最早是在2015年12月投出,和FCN属于同时期作品。稍晚于FCN,既然属于后来者,又是与FCN同属于语义分割网络 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • 很多时候在注册一些比较重要的帐号,或者使用一些比较重要的接口的时候,需要使用到随机字符串,为了方便,我们设计这个脚本需要注意 ... [详细]
  • 本文介绍了九度OnlineJudge中的1002题目“Grading”的解决方法。该题目要求设计一个公平的评分过程,将每个考题分配给3个独立的专家,如果他们的评分不一致,则需要请一位裁判做出最终决定。文章详细描述了评分规则,并给出了解决该问题的程序。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 逻辑回归_训练二元分类器#训练一个二元分类器fromsklearn.linear_modelimportLogisticRegressionfromsklearnimport ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • 本文介绍了计算机网络的定义和通信流程,包括客户端编译文件、二进制转换、三层路由设备等。同时,还介绍了计算机网络中常用的关键词,如MAC地址和IP地址。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 【shell】网络处理:判断IP是否在网段、两个ip是否同网段、IP地址范围、网段包含关系
    本文介绍了使用shell脚本判断IP是否在同一网段、判断IP地址是否在某个范围内、计算IP地址范围、判断网段之间的包含关系的方法和原理。通过对IP和掩码进行与计算,可以判断两个IP是否在同一网段。同时,还提供了一段用于验证IP地址的正则表达式和判断特殊IP地址的方法。 ... [详细]
  • 网址:https:vue.docschina.orgv2guideforms.html表单input绑定基础用法可以通过使用v-model指令,在 ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
  • Python 可视化 | Seaborn5 分钟入门 (六)——heatmap 热力图
    微信公众号:「Python读财」如有问题或建议,请公众号留言Seaborn是基于matplotlib的Python可视化库。它提供了一个高级界面来绘制有吸引力的统计图形。Seabo ... [详细]
  • LwebandStringTimeLimit:20001000MS(JavaOthers)MemoryLimit:6553665536K(JavaO ... [详细]
author-avatar
轰炸籹厕所744
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有