热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【Seaborn绘图】深度强化学习实验中的paper绘图方法

来源:知乎(zhuanlan.zhihu.comp75477750)编辑:DeepRL强化学习实验中的绘图技巧-使用seaborn绘制paper中的图片,

来源:知乎(zhuanlan.zhihu.com/p/75477750)

编辑: DeepRL

强化学习实验中的绘图技巧-使用seaborn绘制paper中的图片,使用seaborn绘制折线图时参数数据可以传递ndarray或者pandas,不同的源数据对应的其他参数也略有不同.

1. ndarray

先看一个小例子

def getdata(): basecond = [[18, 20, 19, 18, 13, 4, 1], [20, 17, 12, 9, 3, 0, 0], [20, 20, 20, 12, 5, 3, 0]]cond1 = [[18, 19, 18, 19, 20, 15, 14], [19, 20, 18, 16, 20, 15, 9], [19, 20, 20, 20, 17, 10, 0], [20, 20, 20, 20, 7, 9, 1]]cond2 = [[20, 20, 20, 20, 19, 17, 4], [20, 20, 20, 20, 20, 19, 7], [19, 20, 20, 19, 19, 15, 2]]cond3 = [[20, 20, 20, 20, 19, 17, 12], [18, 20, 19, 18, 13, 4, 1], [20, 19, 18, 17, 13, 2, 0], [19, 18, 20, 20, 15, 6, 0]]return basecond, cond1, cond2, cond3

数据维度都为(3,7)或(4, 7) 

第一个维度表示每个时间点采样不同数目的数据(可认为是每个x对应多个不同y值) 第二个维度表示不同的时间点(可认为是x轴对应的x值)

data = getdata()
fig = plt.figure()
xdata = np.array([0, 1, 2, 3, 4, 5, 6])/5
linestyle = ['-', '--', ':', '-.']
color = ['r', 'g', 'b', 'k']
label = ['algo1', 'algo2', 'algo3', 'algo4']for i in range(4): sns.tsplot(time=xdata, data=data[i], color=color[i], linestyle=linestyle[i], condition=label[i])

sns.tsplot 用来画时间序列图

time参数表示对应的时间轴(ndarray),即x轴,data即要求绘制的数据,上述例子为(3, 7)或(4, 7),color为每条线的颜色,linestyle为每条线的样式,condition为每条线的标记.

plt.ylabel("Success Rate", fontsize=25)
plt.xlabel("Iteration Number", fontsize=25)
plt.title("Awesome Robot Performance", fontsize=30)
plt.show()

1.2 绘图建议

  • 你的程序代码需要使用一个额外的文件记录结果,例如csv或pkl文件,而不是直接产生最终的绘图结果.这种方式下,你能运行程序代码一次,然后以不同的方式去绘制结果,记录超出您认为严格必要的内容可能是一个好主意,因为您永远不知道哪些信息对于了解发生的事情最有用.注意文件的大小,但通常最好记录以下内容:每次迭代的平均reward或loss,一些采样的轨迹,有用的辅助指标(如贝尔曼误差和梯度)

  • 你需要有一个单独的脚本去加载一个或多个记录文件来绘制图像,如果你使用不同的超参数或随机种子运行算法多次,一起加载所有的数据(也许来自不同的文件)并画在一起是个好主意,使用自动生成的图例和颜色模式使分辨不同的方法变得容易.

  • 深度强化学习方法,往往在不同的运行中有巨大的变化,因此使用不同的随机种子运行多次是一个好主意,在绘制多次运行的结果时,在一张图上绘制不同运行次的结果,通过使用不同粗细和颜色的线来分辨.在绘制不同的方法时,你将发现将他们总结为均值和方差图是容易的,然而分布并不总是遵循正态曲线,所以至少在初始时有明显的感觉对比不同随机种子的性能.

1.3 实验绘图流程

下面以模仿学习的基础实验为例

means = []
stds = []
#使用不同的随机种子表示运行多次实验
for seed in range(SEED_NUM): tf.set_random_seed(seed*10) np.random.seed(seed*10) mean = [] std = []#构建神经网络模型model = tf.keras.Sequential() model.add(layers.Dense(64, activation="relu")) model.add(layers.Dense(64, activation="relu")) model.add(layers.Dense(act_dim, activation="tanh")) model.compile(optimizer=tf.train.AdamOptimizer(0.0001), loss="mse", metrics=['mae']) #迭代次数for iter in range(ITERATION): print("iter:", iter) #训练模型model.fit(train, label, batch_size=BATCH_SIZE, epochs=EPOCHS)#测试,通过与环境交互n次而成,即n趟轨迹roll_reward = [] for roll in range(NUM_ROLLOUTS): s = env.reset() done = False reward = 0 step = 0 #以下循环表示一趟轨迹while not done: a = model.predict(s[np.newaxis, :]) s, r, done, _ = env.step(a) reward += r step += 1 if step >= max_steps: break#记录每一趟的总回报值roll_reward.append(reward) #n趟回报的平均值和方差作为这次迭代的结果记录mean.append(np.mean(roll_reward)) std.append(np.std(roll_reward)) #记录每一次实验,矩阵的一行表示一次实验每次迭代结果means.append(mean) stds.append(std)

接着需要保存数据为pkl文件

d = {"mean": means, "std": stds}
with open(os.path.join("test_data", "behavior_cloning_" + ENV_NAME+".pkl"), "wb") as f:pickle.dump(d, f, pickle.HIGHEST_PROTOCOL)

绘图的程序代码比较简单

file = "behavior_cloning_" + ENV_NAME+".pkl"with open(os.path.join("test_data", file), "rb") as f:data = pickle.load(f)x1 = data["mean"]file = "dagger_" + ENV_NAME+".pkl"with open(os.path.join("test_data", file), "rb") as f:data = pickle.load(f)x2 = data["mean"]time = range(10)sns.set(style="darkgrid", font_scale=1.5)sns.tsplot(time=time, data=x1, color="r", condition="behavior_cloning")sns.tsplot(time=time, data=x2, color="b", condition="dagger")plt.ylabel("Reward")plt.xlabel("Iteration Number")plt.title("Imitation Learning")plt.show()

有时我们需要对曲线进行平滑

def smooth(data, sm=1):if sm > 1:smooth_data = []for d in data:y = np.ones(sm)*1.0/smd = np.convolve(y, d, "same")smooth_data.append(d)return smooth_data

sm表示滑动窗口大小,为2*k+1,

smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k])

2.pandas

sns.tsplot可以使用pandas源数据作为数据输入,当使用pandas作为数据时,time,value,condition,unit选项将为pandas数据的列名.

其中time选项给出使用该列Series作为x轴数据,value选项表示使用该Series作为y轴数据,用unit来分辨这些数据是哪一次采样(每个x对应多个y),用condition选项表示这些数据来自哪一条曲线.

在openai 的spinning up中,将每次迭代的数据保存到了txt文件中,类似如下:

可以使用pd.read_table读取这个以"\t"分割的文件形成pandas

algo = ["ddpg_" + ENV, "td3_" + ENV, "ppo_" + ENV, "trpo_" + ENV, "vpg_" + ENV, "sac_" + ENV]data = []for i in range(len(algo)):for seed in range(SEED_NUM):file = os.path.join(os.path.join(algo[i], algo[i] + "_s" + str(seed*10)), "progress.txt")pd_data = pd.read_table(file)pd_data.insert(len(pd_data.columns), "Unit", seed)pd_data.insert(len(pd_data.columns), "Condition", algo[i])data.append(pd_data)data = pd.concat(data, ignore_index=True)sns.set(style="darkgrid", font_scale=1.5)sns.tsplot(data=data, time="TotalEnvInteracts", value="AverageEpRet", condition="Condition", unit="Unit")#数据大时使用科学计数法xscale = np.max(data["TotalEnvInteracts"]) > 5e3if xscale:plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))plt.legend(loc='best').set_draggable(True)plt.tight_layout(pad=0.5)plt.show()

程序参考了spinning up 的代码逻辑github.com/openai/spinn

绘制效果如下:

完整代码:https://github.com/feidieufo/homework/tree/master/hw1


推荐阅读
  • 深入解析Android自定义View面试题
    本文探讨了Android Launcher开发中自定义View的重要性,并通过一道经典的面试题,帮助开发者更好地理解自定义View的实现细节。文章不仅涵盖了基础知识,还提供了实际操作建议。 ... [详细]
  • 2023年京东Android面试真题解析与经验分享
    本文由一位拥有6年Android开发经验的工程师撰写,详细解析了京东面试中常见的技术问题。涵盖引用传递、Handler机制、ListView优化、多线程控制及ANR处理等核心知识点。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 探索电路与系统的起源与发展
    本文回顾了电路与系统的发展历程,从电的早期发现到现代电子器件的应用。文章不仅涵盖了基础理论和关键发明,还探讨了这一学科对计算机、人工智能及物联网等领域的深远影响。 ... [详细]
  • 优化ListView性能
    本文深入探讨了如何通过多种技术手段优化ListView的性能,包括视图复用、ViewHolder模式、分批加载数据、图片优化及内存管理等。这些方法能够显著提升应用的响应速度和用户体验。 ... [详细]
  • 本文探讨了卷积神经网络(CNN)中感受野的概念及其与锚框(anchor box)的关系。感受野定义了特征图上每个像素点对应的输入图像区域大小,而锚框则是在每个像素中心生成的多个不同尺寸和宽高比的边界框。两者在目标检测任务中起到关键作用。 ... [详细]
  • 卷积神经网络(CNN)基础理论与架构解析
    本文介绍了卷积神经网络(CNN)的基本概念、常见结构及其各层的功能。重点讨论了LeNet-5、AlexNet、ZFNet、VGGNet和ResNet等经典模型,并详细解释了输入层、卷积层、激活层、池化层和全连接层的工作原理及优化方法。 ... [详细]
  • 尽管深度学习带来了广泛的应用前景,其训练通常需要强大的计算资源。然而,并非所有开发者都能负担得起高性能服务器或专用硬件。本文探讨了如何在有限的硬件条件下(如ARM CPU)高效运行深度神经网络,特别是通过选择合适的工具和框架来加速模型推理。 ... [详细]
  • 本文作者分享了在阿里巴巴获得实习offer的经历,包括五轮面试的详细内容和经验总结。其中四轮为技术面试,一轮为HR面试,涵盖了大量的Java技术和项目实践经验。 ... [详细]
  • 本文详细探讨了KMP算法中next数组的构建及其应用,重点分析了未改良和改良后的next数组在字符串匹配中的作用。通过具体实例和代码实现,帮助读者更好地理解KMP算法的核心原理。 ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • PHP 5.5.0rc1 发布:深入解析 Zend OPcache
    2013年5月9日,PHP官方发布了PHP 5.5.0rc1和PHP 5.4.15正式版,这两个版本均支持64位环境。本文将详细介绍Zend OPcache的功能及其在Windows环境下的配置与测试。 ... [详细]
  • 网易严选Java开发面试:MySQL索引深度解析
    本文详细记录了网易严选Java开发岗位的面试经验,特别针对MySQL索引相关的技术问题进行了深入探讨。通过本文,读者可以了解面试官常问的索引问题及其背后的原理。 ... [详细]
  • 探讨如何从数据库中按分组获取最大N条记录的方法,并分享新年祝福。本文提供多种解决方案,适用于不同数据库系统,如MySQL、Oracle等。 ... [详细]
  • 本文探讨了如何在iOS开发环境中,特别是在Xcode 6.1中,设置和应用自定义文本样式。我们将详细介绍实现方法,并提供一些实用的技巧。 ... [详细]
author-avatar
圣换少爷
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有