热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【Pytorch神经网络实战案例】19神经网络实现估计互信息的功能

1案例说明(实现MINE正方法的功能)定义两组具有不同分布的模拟数据,使用神经网络的MINE的方法计算两个数据分布之间的互信息2代码编


1 案例说明(实现MINE正方法的功能)

定义两组具有不同分布的模拟数据,使用神经网络的MINE的方法计算两个数据分布之间的互信息


2 代码编写


2.1 代码实战:准备样本数据

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' # 可能是由于是MacOS系统的原因### 本例实现了用神经网络计算互信息的功能。这是一个简单的例子,目的在于帮助读者更好地理MNE方法。# 1.1 准备样本数据:定义两个数据生成函数gen_x()、gen_y()。函数gen_x()用于生成1或-1,函数gen_y()在此基础上为其再加上一个符合高斯分布的随机值。
# 生成模拟数据
data_size = 1000
def gen_x():return np.sign(np.random.normal(0.0,1.0,[data_size,1]))def gen_y(x):return x + np.random.normal(0.0,0.5,[data_size,1])def show_data():x_sample = gen_x()y_sample = gen_y(x_sample)plt.scatter(np.arange(len(x_sample)), x_sample, s=10,c='b',marker='o')plt.scatter(np.arange(len(y_sample)), y_sample, s=10,c='y',marker='o')plt.show() # 两条横线部分是样本数据x中的点,其他部分是样本数据y。


2.2 代码实战:定义神经网络模型

# 1.2 定义神经网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(1,10)self.fc2 = nn.Linear(1,10)self.fc3 = nn.Linear(10,1)def forward(self,x,y):h1 = F.relu(self.fc1(x) + self.fc2(y))h2 = self.fc3(h1)return h2

2.3 代码实战:利用MINE方法训练模型并输出结果

# 1.3 利用MINE方法训练模型并输出结果
if __name__ == '__main__':show_data()# 显示数据model = Net() # 实例化模型optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # 使用Adam优化器并设置学习率为0.01n_epoch = 500plot_loss = []# MiNE方法主要用于模型的训练阶段for epoch in tqdm(range(n_epoch)):x_sample = gen_x() # 调用gen_x()函数生成样本x_Sample。X_sample代表X的边缘分布P(X)y_sample = gen_y(x_sample) # 将生成的×_sample样本放到gen_x()函数中,生成样本y_sample。y_sample代表条件分布P(Y|X)。y_shuffle = np.random.permutation(y_sample) # )将 y_sample按照批次维度打乱顺序得到y_shuffle,y_shuffle是Y的经验分布,近似于Y的边缘分布P(Y)。# 转化为张量x_sample = torch.from_numpy(x_sample).type(torch.FloatTensor)y_sample = torch.from_numpy(y_sample).type(torch.FloatTensor)y_shuffle = torch.from_numpy(y_shuffle).type(torch.FloatTensor)model.zero_grad()pred_xy = model(x_sample, y_sample) # 式(8-49)中的第一项联合分布的期望:将x_sample和y_sample放到模型中,得到联合概率(P(X,Y)=P(Y|X)P(X))关于神经网络的期望值pred_xy。pred_x_y = model(x_sample, y_shuffle) # 式(8-49)中的第二项边缘分布的期望:将x_sample和y_shuffle放到模型中,得到边缘概率关于神经网络的期望值pred_x_y 。ret = torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y))) # 将pred_xy和pred_x_y代入式(8-49)中,得到互信息ret。loss = - ret # 最大化互信息:在训练过程中,因为需要将模型权重向着互信息最大的方向优化,所以对互信息取反,得到最终的loss值。plot_loss.append(loss.data) # 收集损失值loss.backward() # 反向传播:在得到loss值之后,便可以进行反向传播并调用优化器进行模型优化。optimizer.step() # 调用优化器plot_y = np.array(plot_loss).reshape(-1, ) # 可视化plt.plot(np.arange(len(plot_loss)), -plot_y, 'r') # 直接将|oss值取反,得到最大化互信息的值。plt.show()


3 代码总览

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' # 可能是由于是MacOS系统的原因### 本例实现了用神经网络计算互信息的功能。这是一个简单的例子,目的在于帮助读者更好地理MNE方法。# 1.1 准备样本数据:定义两个数据生成函数gen_x()、gen_y()。函数gen_x()用于生成1或-1,函数gen_y()在此基础上为其再加上一个符合高斯分布的随机值。
# 生成模拟数据
data_size = 1000
def gen_x():return np.sign(np.random.normal(0.0,1.0,[data_size,1]))def gen_y(x):return x + np.random.normal(0.0,0.5,[data_size,1])def show_data():x_sample = gen_x()y_sample = gen_y(x_sample)plt.scatter(np.arange(len(x_sample)), x_sample, s=10,c='b',marker='o')plt.scatter(np.arange(len(y_sample)), y_sample, s=10,c='y',marker='o')plt.show() # 两条横线部分是样本数据x中的点,其他部分是样本数据y。# 1.2 定义神经网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(1,10)self.fc2 = nn.Linear(1,10)self.fc3 = nn.Linear(10,1)def forward(self,x,y):h1 = F.relu(self.fc1(x) + self.fc2(y))h2 = self.fc3(h1)return h2# 1.3 利用MINE方法训练模型并输出结果
if __name__ == '__main__':show_data()# 显示数据model = Net() # 实例化模型optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # 使用Adam优化器并设置学习率为0.01n_epoch = 500plot_loss = []# MiNE方法主要用于模型的训练阶段for epoch in tqdm(range(n_epoch)):x_sample = gen_x() # 调用gen_x()函数生成样本x_Sample。X_sample代表X的边缘分布P(X)y_sample = gen_y(x_sample) # 将生成的×_sample样本放到gen_x()函数中,生成样本y_sample。y_sample代表条件分布P(Y|X)。y_shuffle = np.random.permutation(y_sample) # )将 y_sample按照批次维度打乱顺序得到y_shuffle,y_shuffle是Y的经验分布,近似于Y的边缘分布P(Y)。# 转化为张量x_sample = torch.from_numpy(x_sample).type(torch.FloatTensor)y_sample = torch.from_numpy(y_sample).type(torch.FloatTensor)y_shuffle = torch.from_numpy(y_shuffle).type(torch.FloatTensor)model.zero_grad()pred_xy = model(x_sample, y_sample) # 式(8-49)中的第一项联合分布的期望:将x_sample和y_sample放到模型中,得到联合概率(P(X,Y)=P(Y|X)P(X))关于神经网络的期望值pred_xy。pred_x_y = model(x_sample, y_shuffle) # 式(8-49)中的第二项边缘分布的期望:将x_sample和y_shuffle放到模型中,得到边缘概率关于神经网络的期望值pred_x_y 。ret = torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y))) # 将pred_xy和pred_x_y代入式(8-49)中,得到互信息ret。loss = - ret # 最大化互信息:在训练过程中,因为需要将模型权重向着互信息最大的方向优化,所以对互信息取反,得到最终的loss值。plot_loss.append(loss.data) # 收集损失值loss.backward() # 反向传播:在得到loss值之后,便可以进行反向传播并调用优化器进行模型优化。optimizer.step() # 调用优化器plot_y = np.array(plot_loss).reshape(-1, ) # 可视化plt.plot(np.arange(len(plot_loss)), -plot_y, 'r') # 直接将|oss值取反,得到最大化互信息的值。plt.show()


推荐阅读
  • 颜色迁移(reinhard VS welsh)
    不要谈什么天分,运气,你需要的是一个截稿日,以及一个不交稿就能打爆你狗头的人,然后你就会被自己的才华吓到。------ ... [详细]
  • 本文介绍了[从头学数学]中第101节关于比例的相关问题的研究和修炼过程。主要内容包括[机器小伟]和[工程师阿伟]一起研究比例的相关问题,并给出了一个求比例的函数scale的实现。 ... [详细]
  • 本文介绍了闭包的定义和运转机制,重点解释了闭包如何能够接触外部函数的作用域中的变量。通过词法作用域的查找规则,闭包可以访问外部函数的作用域。同时还提到了闭包的作用和影响。 ... [详细]
  • 本文介绍了一个Java猜拳小游戏的代码,通过使用Scanner类获取用户输入的拳的数字,并随机生成计算机的拳,然后判断胜负。该游戏可以选择剪刀、石头、布三种拳,通过比较两者的拳来决定胜负。 ... [详细]
  • Commit1ced2a7433ea8937a1b260ea65d708f32ca7c95eintroduceda+Clonetraitboundtom ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • 展开全部下面的代码是创建一个立方体Thisexamplescreatesanddisplaysasimplebox.#Thefirstlineloadstheinit_disp ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文介绍了PE文件结构中的导出表的解析方法,包括获取区段头表、遍历查找所在的区段等步骤。通过该方法可以准确地解析PE文件中的导出表信息。 ... [详细]
  • C++中的三角函数计算及其应用
    本文介绍了C++中的三角函数的计算方法和应用,包括计算余弦、正弦、正切值以及反三角函数求对应的弧度制角度的示例代码。代码中使用了C++的数学库和命名空间,通过赋值和输出语句实现了三角函数的计算和结果显示。通过学习本文,读者可以了解到C++中三角函数的基本用法和应用场景。 ... [详细]
  • 本文讨论了在openwrt-17.01版本中,mt7628设备上初始化启动时eth0的mac地址总是随机生成的问题。每次随机生成的eth0的mac地址都会写到/sys/class/net/eth0/address目录下,而openwrt-17.01原版的SDK会根据随机生成的eth0的mac地址再生成eth0.1、eth0.2等,生成后的mac地址会保存在/etc/config/network下。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 怎么在PHP项目中实现一个HTTP断点续传功能发布时间:2021-01-1916:26:06来源:亿速云阅读:96作者:Le ... [详细]
  • STL迭代器的种类及其功能介绍
    本文介绍了标准模板库(STL)定义的五种迭代器的种类和功能。通过图表展示了这几种迭代器之间的关系,并详细描述了各个迭代器的功能和使用方法。其中,输入迭代器用于从容器中读取元素,输出迭代器用于向容器中写入元素,正向迭代器是输入迭代器和输出迭代器的组合。本文的目的是帮助读者更好地理解STL迭代器的使用方法和特点。 ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
author-avatar
茨冈人686
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有