热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用Pytorch简单实现混合密度网络(MixtureDensityNetwork,MDN)

本文主要参考自:https:github.comsksq96pytorch-mdnblobmastermdn.ipynbhttps:blog.otoro.net201

本文主要参考自:
https://github.com/sksq96/pytorch-mdn/blob/master/mdn.ipynb
https://blog.otoro.net/2015/11/24/mixture-density-networks-with-tensorflow/?tdsourcetag=s_pctim_aiomsg


引言

我们知道,神经网络具有很强的拟合能力。比方说,假设我们要拟合如下一个带噪声的函数:y=7.0sin⁡(0.75x)+0.5x+ϵy=7.0 \sin (0.75 x)+0.5 x+\epsilon y=7.0sin(0.75x)+0.5x+ϵ

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchsummary import summaryimport numpy as np
import matplotlib.pyplot as plt
n_samples = 1000
epsilon = torch.randn(n_samples)
x_data = torch.linspace(-10, 10, n_samples)
y_data = 7*np.sin(0.75*x_data) + 0.5*x_data + epsilon
y_data, x_data = y_data.view(-1, 1), x_data.view(-1, 1)
plt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.4)
plt.show()

在这里插入图片描述

拟合这个显然是很容易的,随便构建一个由两个全连接层组成的网络就具有非线性的拟合能力:

n_input = 1
n_hidden = 20
n_output = 1
model = nn.Sequential(nn.Linear(n_input, n_hidden),nn.Tanh(),nn.Linear(n_hidden, n_output))
loss_fn = nn.MSELoss()
optimizer = torch.optim.RMSprop(model.parameters())

然后编写代码进行训练:

for epoch in range(3000):y_pred = model(x_data)loss = loss_fn(y_pred, y_data)optimizer.zero_grad()loss.backward()optimizer.step()if epoch % 1000 == 0:print(loss.data.tolist())x_test = torch.linspace(-15, 15, n_samples).view(-1, 1)
y_pred = model(x_test).data
plt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.4)
plt.scatter(x_test, y_pred, alpha=0.4, color='red')
plt.show()

结果如下:
在这里插入图片描述
可以看到,给定一个输入x,要求预测一个y的话,这种一对一的建模能力神经网络是很擅长的。但是实际工程中我们也很容易遇到这种情况,在这里,我们将x与y调换,模拟一对多的情形:x=7.0sin⁡(0.75y)+0.5y+εx=7.0 \sin (0.75 y)+0.5 y+\varepsilon x=7.0sin(0.75y)+0.5y+ε

y_data, x_data = x_data.view(-1, 1), y_data.view(-1, 1)

重新训练进行预测,结果如下:
在这里插入图片描述
可以看到网络在试图拟合同一x下各y的平均值。那么,有没有一种网络能够去拟合这种一对多的情况呢?


MDN

传统的神经网络:对于单一输入x,给出一个单一预测y
MDN:对于单一输入x,预测y的概率分布

具体来说,对于输入x,MDN的输出为服从混合高斯分布(Mixture Gaussian distributions),具体的输出值被建模为多个高斯随机值的和,这几个高斯分布的均值和标准差是不同的。形式化地,有:P(Y=y∣X=x)=∑k=0K−1Πk(x)ϕ(y,μk(x),σk(x)),∑k=0K−1Πk(x)=1P(Y=y \mid X=x)=\sum_{k=0}^{K-1} \Pi_{k}(x) \phi\left(y, \mu_{k}(x), \sigma_{k}(x)\right), \sum_{k=0}^{K-1} \Pi_{k}(x)=1 P(Y=yX=x)=k=0K1Πk(x)ϕ(y,μk(x),σk(x)),k=0K1Πk(x)=1 需要注意的是,这里各个高斯分布的参数Πk(x),μk(x),σk(x)\Pi_{k}(x), \mu_{k}(x), \sigma_{k}(x)Πk(x),μk(x),σk(x)是取决于输入xxx的,也就是要通过网络训练去预测得到。实际上,这依然可以通过全连接层来搞定,接下来我们去介绍怎么实现MDN。


实现

class MDN(nn.Module):def __init__(self, n_hidden, n_gaussians):super(MDN, self).__init__()self.z_h = nn.Sequential(nn.Linear(1, n_hidden),nn.Tanh())self.z_pi = nn.Linear(n_hidden, n_gaussians)self.z_mu = nn.Linear(n_hidden, n_gaussians)self.z_sigma = nn.Linear(n_hidden, n_gaussians)def forward(self, x):z_h = self.z_h(x)pi = F.softmax(self.z_pi(z_h), -1)mu = self.z_mu(z_h)sigma = torch.exp(self.z_sigma(z_h))return pi, mu, sigma

混个高斯分布由多少个子高斯分布构成属于超参数,这里我们设计为5个:

model = MDN(n_hidden=20, n_gaussians=5)

然后是损失函数的设计。由于输出本质上是概率分布,因此不能采用诸如L1损失、L2损失的硬损失函数。这里我们采用了对数似然损失(和交叉熵类似):CostFunction⁡(y∣x)=−log⁡[∑kKΠk(x)ϕ(y,μ(x),σ(x))]\operatorname{CostFunction}(y \mid x)=-\log \left[\sum_{k}^{K} \Pi_{k}(x) \phi(y, \mu(x), \sigma(x))\right]CostFunction(yx)=log[kKΠk(x)ϕ(y,μ(x),σ(x))]

def mdn_loss_fn(y, mu, sigma, pi):m = torch.distributions.Normal(loc=mu, scale=sigma)loss = torch.exp(m.log_prob(y))loss = torch.sum(loss * pi, dim=1)loss = -torch.log(loss)return torch.mean(loss)

接下来则是训练的过程:

for epoch in range(10000):pi, mu, sigma = model(x_data)loss = mdn_loss_fn(y_data, mu, sigma, pi)optimizer.zero_grad()loss.backward()optimizer.step()if epoch % 1000 == 0:print(loss.data.tolist())

最后是推理的过程。需要注意的是,MDN学到的只是若干个高斯分布:

pi, mu, sigma = model(x_test)

因此我们还要手动去从高斯分布中采样获得具体的值:

k = torch.multinomial(pi, 1).view(-1)
y_pred = torch.normal(mu, sigma)[np.arange(n_samples), k].data
plt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.4)
plt.scatter(x_test, y_pred, alpha=0.4, color='red')
plt.show()

最后结果如下:
在这里插入图片描述


完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as pltn_samples = 1000epsilon = torch.randn(n_samples)
x_data = torch.linspace(-10, 10, n_samples)
y_data = 7*np.sin(0.75*x_data) + 0.5*x_data + epsilon# y_data, x_data = y_data.view(-1, 1), x_data.view(-1, 1),
y_data, x_data = x_data.view(-1, 1), y_data.view(-1, 1)plt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.4)
plt.show()n_input = 1
n_hidden = 20
n_output = 1
model = nn.Sequential(nn.Linear(n_input, n_hidden),nn.Tanh(),nn.Linear(n_hidden, n_output))
loss_fn = nn.MSELoss()
optimizer = torch.optim.RMSprop(model.parameters())for epoch in range(3000):y_pred = model(x_data)loss = loss_fn(y_pred, y_data)optimizer.zero_grad()loss.backward()optimizer.step()if epoch % 1000 == 0:print(loss.data.tolist())x_test = torch.linspace(-15, 15, n_samples).view(-1, 1)
y_pred = model(x_test).data
plt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.4)
plt.scatter(x_test, y_pred, alpha=0.4, color='red')
plt.show()class MDN(nn.Module):def __init__(self, n_hidden, n_gaussians):super(MDN, self).__init__()self.z_h = nn.Sequential(nn.Linear(1, n_hidden),nn.Tanh())self.z_pi = nn.Linear(n_hidden, n_gaussians)self.z_mu = nn.Linear(n_hidden, n_gaussians)self.z_sigma = nn.Linear(n_hidden, n_gaussians)def forward(self, x):z_h = self.z_h(x)pi = F.softmax(self.z_pi(z_h), -1)mu = self.z_mu(z_h)sigma = torch.exp(self.z_sigma(z_h))return pi, mu, sigmamodel = MDN(n_hidden=20, n_gaussians=5)
optimizer = torch.optim.Adam(model.parameters())
def mdn_loss_fn(y, mu, sigma, pi):m = torch.distributions.Normal(loc=mu, scale=sigma)loss = torch.exp(m.log_prob(y))loss = torch.sum(loss * pi, dim=1)loss = -torch.log(loss)return torch.mean(loss)for epoch in range(10000):pi, mu, sigma = model(x_data)loss = mdn_loss_fn(y_data, mu, sigma, pi)optimizer.zero_grad()loss.backward()optimizer.step()if epoch % 1000 == 0:print(loss.data.tolist())pi, mu, sigma = model(x_test)k = torch.multinomial(pi, 1).view(-1)
y_pred = torch.normal(mu, sigma)[np.arange(n_samples), k].dataplt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.4)
plt.scatter(x_test, y_pred, alpha=0.4, color='red')
plt.show()

推荐阅读
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 语义分割系列3SegNet(pytorch实现)
    SegNet手稿最早是在2015年12月投出,和FCN属于同时期作品。稍晚于FCN,既然属于后来者,又是与FCN同属于语义分割网络 ... [详细]
  • 上一章讲了如何制作数据集,接下来我们使用mmcls来实现多标签分类。 ... [详细]
  • 都会|可能会_###haohaohao###图神经网络之神器——PyTorch Geometric 上手 & 实战
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了###haohaohao###图神经网络之神器——PyTorchGeometric上手&实战相关的知识,希望对你有一定的参考价值。 ... [详细]
  • 「爆干7天7夜」入门AI人工智能学习路线一条龙,真的不能再透彻了
    前言应广大粉丝要求,今天迪迦来和大家讲解一下如何去入门人工智能,也算是迪迦对自己学习人工智能这么多年的一个总结吧,本条学习路线并不会那么 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 本文介绍了在开发Android新闻App时,搭建本地服务器的步骤。通过使用XAMPP软件,可以一键式搭建起开发环境,包括Apache、MySQL、PHP、PERL。在本地服务器上新建数据库和表,并设置相应的属性。最后,给出了创建new表的SQL语句。这个教程适合初学者参考。 ... [详细]
  • 本文分享了一个关于在C#中使用异步代码的问题,作者在控制台中运行时代码正常工作,但在Windows窗体中却无法正常工作。作者尝试搜索局域网上的主机,但在窗体中计数器没有减少。文章提供了相关的代码和解决思路。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
author-avatar
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有