热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

DRL实战:用PyTorch150行代码实现AdvantageActorCritic玩CartPole

1前言今天我们来用Pytorch实现一下用AdvantageActor-Critic也就是A3C的非异步版本A2C玩CartPole。https:www.zhihu.comvide

1 前言

今天我们来用Pytorch实现一下用Advantage Actor-Critic 也就是A3C的非异步版本A2C玩CartPole。

《DRL实战:用PyTorch 150行代码实现Advantage Actor-Critic玩CartPole》https://www.zhihu.com/video/868815239284682752

2 前提条件

要理解今天的这个DRL实战,需要具备以下条件:

  • 理解Advantage Actor-Critic算法
  • 熟悉Python
  • 一定程度了解PyTorch
  • 安装了OpenAI Gym的环境

3 Advantage Actor-Critic 算法简介

《DRL实战:用PyTorch 150行代码实现Advantage Actor-Critic玩CartPole》
《DRL实战:用PyTorch 150行代码实现Advantage Actor-Critic玩CartPole》

这里直接引用David Silver的Talk课件。

我们要构造两个网络:Actor Network和Value Network

其中Actor Network的更新使用Policy Gradient,而Value Network的更新使用MSELoss。

关于Policy Gradient方法不了解的童鞋可以参考一下专栏之前的Blog。

4 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import math
import random
import os
import gym
# Hyper Parameters
STATE_DIM = 4
ACTION_DIM = 2
STEP = 2000
SAMPLE_NUMS = 30
class ActorNetwork(nn.Module):
def __init__(self,input_size,hidden_size,action_size):
super(ActorNetwork, self).__init__()
self.fc1 = nn.Linear(input_size,hidden_size)
self.fc2 = nn.Linear(hidden_size,hidden_size)
self.fc3 = nn.Linear(hidden_size,action_size)
def forward(self,x):
out = F.relu(self.fc1(x))
out = F.relu(self.fc2(out))
out = F.log_softmax(self.fc3(out))
return out
class ValueNetwork(nn.Module):
def __init__(self,input_size,hidden_size,output_size):
super(ValueNetwork, self).__init__()
self.fc1 = nn.Linear(input_size,hidden_size)
self.fc2 = nn.Linear(hidden_size,hidden_size)
self.fc3 = nn.Linear(hidden_size,output_size)
def forward(self,x):
out = F.relu(self.fc1(x))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
def roll_out(actor_network,task,sample_nums,value_network,init_state):
#task.reset()
states = []
actiOns= []
rewards = []
is_dOne= False
final_r = 0
state = init_state
for j in range(sample_nums):
states.append(state)
log_softmax_action = actor_network(Variable(torch.Tensor([state])))
softmax_action = torch.exp(log_softmax_action)
action = np.random.choice(ACTION_DIM,p=softmax_action.cpu().data.numpy()[0])
one_hot_action = [int(k == action) for k in range(ACTION_DIM)]
next_state,reward,done,_ = task.step(action)
#fix_reward = -10 if done else 1
actions.append(one_hot_action)
rewards.append(reward)
final_state = next_state
state = next_state
if done:
is_dOne= True
state = task.reset()
break
if not is_done:
final_r = value_network(Variable(torch.Tensor([final_state]))).cpu().data.numpy()
return states,actions,rewards,final_r,state
def discount_reward(r, gamma,final_r):
discounted_r = np.zeros_like(r)
running_add = final_r
for t in reversed(range(0, len(r))):
running_add = running_add * gamma + r[t]
discounted_r[t] = running_add
return discounted_r
def main():
# init a task generator for data fetching
task = gym.make("CartPole-v0")
init_state = task.reset()
# init value network
value_network = ValueNetwork(input_size = STATE_DIM,hidden_size = 40,output_size = 1)
value_network_optim = torch.optim.Adam(value_network.parameters(),lr=0.01)
# init actor network
actor_network = ActorNetwork(STATE_DIM,40,ACTION_DIM)
actor_network_optim = torch.optim.Adam(actor_network.parameters(),lr = 0.01)
steps =[]
task_episodes =[]
test_results =[]
for step in range(STEP):
states,actions,rewards,final_r,current_state = roll_out(actor_network,task,SAMPLE_NUMS,value_network,init_state)
init_state = current_state
actions_var = Variable(torch.Tensor(actions).view(-1,ACTION_DIM))
states_var = Variable(torch.Tensor(states).view(-1,STATE_DIM))
# train actor network
actor_network_optim.zero_grad()
log_softmax_actiOns= actor_network(states_var)
vs = value_network(states_var).detach()
# calculate qs
qs = Variable(torch.Tensor(discount_reward(rewards,0.99,final_r)))
advantages = qs - vs
actor_network_loss = - torch.mean(torch.sum(log_softmax_actions*actions_var,1)* advantages)
actor_network_loss.backward()
torch.nn.utils.clip_grad_norm(actor_network.parameters(),0.5)
actor_network_optim.step()
# train value network
value_network_optim.zero_grad()
target_values = qs
values = value_network(states_var)
criterion = nn.MSELoss()
value_network_loss = criterion(values,target_values)
value_network_loss.backward()
torch.nn.utils.clip_grad_norm(value_network.parameters(),0.5)
value_network_optim.step()
# Testing
if (step + 1) % 50== 0:
result = 0
test_task = gym.make("CartPole-v0")
for test_epi in range(10):
state = test_task.reset()
for test_step in range(200):
softmax_action = torch.exp(actor_network(Variable(torch.Tensor([state]))))
#print(softmax_action.data)
action = np.argmax(softmax_action.data.numpy()[0])
next_state,reward,done,_ = test_task.step(action)
result += reward
state = next_state
if done:
break
print("step:",step+1,"test result:",result/10.0)
steps.append(step+1)
test_results.append(result/10)
if __name__ == '__main__':
main()

直接贴了代码,来自songrotek/a2c_cartpole_pytorch

代码可以直接运行。一共150行。代码应该比较清晰,这里也就不过多介绍。一些小技巧:

  • 使用了clip gradient防止梯度过大
  • 通过roll out和discount reward两个函数来获取样本及计算n step 的Q值
  • roll out部分并不每次重启游戏,而是done了才重启游戏

5 小结

相信研究DRL的童鞋理解了以上代码也就理解了Advantage Actor-Critic的算法了。


推荐阅读
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 本文介绍了贝叶斯垃圾邮件分类的机器学习代码,代码来源于https://www.cnblogs.com/huangyc/p/10327209.html,并对代码进行了简介。朴素贝叶斯分类器训练函数包括求p(Ci)和基于词汇表的p(w|Ci)。 ... [详细]
  • 十大经典排序算法动图演示+Python实现
    本文介绍了十大经典排序算法的原理、演示和Python实现。排序算法分为内部排序和外部排序,常见的内部排序算法有插入排序、希尔排序、选择排序、冒泡排序、归并排序、快速排序、堆排序、基数排序等。文章还解释了时间复杂度和稳定性的概念,并提供了相关的名词解释。 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • 使用Ubuntu中的Python获取浏览器历史记录原文: ... [详细]
  • 本文介绍了P1651题目的描述和要求,以及计算能搭建的塔的最大高度的方法。通过动态规划和状压技术,将问题转化为求解差值的问题,并定义了相应的状态。最终得出了计算最大高度的解法。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • FeatureRequestIsyourfeaturerequestrelatedtoaproblem?Please ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • EzPP 0.2发布,新增YAML布局渲染功能
    EzPP发布了0.2.1版本,新增了YAML布局渲染功能,可以将YAML文件渲染为图片,并且可以复用YAML作为模版,通过传递不同参数生成不同的图片。这个功能可以用于绘制Logo、封面或其他图片,让用户不需要安装或卸载Photoshop。文章还提供了一个入门例子,介绍了使用ezpp的基本渲染方法,以及如何使用canvas、text类元素、自定义字体等。 ... [详细]
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
author-avatar
手机用户2502863643
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有