RL 实践(7)—— CartPole【TPRO & PPO】
  VFpNeRYlMszB 2023年11月02日 68 0


  • 本文介绍 PPO 这个 online RL 的经典算法,并在 CartPole-V0 上进行测试。由于 PPO 是源自 TPRO 的,因此也会在原理部分介绍 TPRO
  • 参考:张伟楠《动手学强化学习》、王树森《深度强化学习》
  • 完整代码下载:8_[Gym] CartPole-V0 (PPO)


文章目录

  • 1. TPRO(置信域策略优化)方法
  • 1.1 朴素策略梯度方法的问题
  • 1.2 置信域优化法
  • 1.3 TPRO 公式推导
  • 1.3.1 做近似
  • 1.3.2 最大化
  • 1.4 小结
  • 2. PPO(近端策略优化)方法
  • 2.1 PPO 公式推导
  • 2.1.1 做近似
  • 2.1.2 最大化
  • 2.2 伪代码
  • 2.3 用 PPO 方法解决 CartPole 问题
  • 3. 总结


1. TPRO(置信域策略优化)方法

  • 置信域策略优化 (Trust Region Policy Optimization, TRPO) 是一种策略学习方法,跟朴素的策略梯度方法相比有两个优势:
  1. TRPO表现更稳定,收敛曲线不会剧烈波动,而且对学习率不敏感
  2. TRPO 用更少的经验数据(transition 四元组)就能达到与策略梯度方法相同的表现

1.1 朴素策略梯度方法的问题

  • 前文已经介绍了 policy gradient 方法 REINFORCE & Actor-Critic 以及其带 baseline 的改进版本 REINFORCE with baseline & A2C。这些方法的核心思想都是:参数化 agent 策略 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch,设计衡量策略好坏的目标函数 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_02通过梯度上升的方法找出最大化这个目标函数的策略参数 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_03,从而得到最优策略 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_04
  • 但是这种算法有一个明显的缺点:注意到在环境中 rollout 时,策略 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_05 会被重复使用,即使策略只有微小的改变,也可能导致最终收益的巨大变化。当策略网络是深度模型时这种特性尤其明显,因此在沿着策略梯度方向更新参数时
    RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_06 很有可能由于步长 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_07
  • 针对以上问题,TPRO 的思想是在更新时找到一块信任区域(trust region),认为在这个区域上更新策略时能够得到某种策略性能的安全性保证,从而避免策略崩溃

    为了实现这种安全性保证,我们必须舍弃掉随机梯度上升而改用其他的优化算法,TPRO 选择了置信域方法 (Trust Region Methods)

1.2 置信域优化法

  • 置信域优化法是数值最优化领域中一类经典的算法,历史至少可以追溯到 1970 年。其出发点是:如果对目标函数 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_08 进行优化过于困难,不妨构造一个替代函数 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_09,要求替代函在 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_10 的当前值 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_11 的邻域 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_12 内和 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_08 十分相似的,通过在这个局部范围内最优化 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_09 来更新一次 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_10

    其中 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_12 就被称作置信域,顾名思义,在 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_11 的邻域上我们可以信任 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_09,可以拿它来替代目标函数 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_08

    具体而言每轮迭代可以分成两步

    1. 做近似:给定 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_11,构造函数 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_21,使得对于所有的 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_22(置信域内取值),函数值 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_21 与原优化目标 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_08
    2. 最大化: 在置信域 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_12 中寻找变量 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_10 的值,使得替代函数 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_27 的值最大化。即求 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_28
  • 注意每一轮迭代中,我们都在构造并求解一个小的约束优化问题,可以如下图示
  • RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_29

  • 注意到置信域半径控制着每一轮迭代中 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_30 变化的上限,我们通常会让这个半径随优化过程不断减小来避免 overstep
  • RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_31

  • 置信域方法是一种算法框架而非一个具体的算法。有很多种方式实现实现置信域方法:

    1. 第一步做近似的方法有多种多样,比如蒙特卡洛、二阶泰勒展开等
    2. 第二步解一个约束最大化问题的方法也很多,包括梯度投影算法、拉格朗日法等
    3. 置信域 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_12

1.3 TPRO 公式推导

  • TPRO 是一种将置信域优化方法应用到策略学习中的 Online RL 方法。回顾 policy gradient 算法,优化目标为最大化
    RL 实践(7)—— CartPole【TPRO & PPO】_PPO_33 其中 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_34 是策略 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_05 诱导的状态分布。考虑置信域优化法的迭代过程,每一步我们要构造优化问题:基于当前的参数 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_36 优化 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_30,故在式1中引入 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_38
    RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_39 注意 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_40 是关于 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_41 的函数,含有 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_38 的成分都可以看做常数,故以上是一个恒等变换。下面开始推导每轮迭代的两个关键步骤

1.3.1 做近似

  • 原始优化目标 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_40RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_34RL 实践(7)—— CartPole【TPRO & PPO】_最优化_45 都不知道,无法直接优化,需要进行三步近似来构造替代函数
    1. 用当前策略 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_46 诱导的状态分布 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_47 近似 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_48,原始优化目标近似为
      RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_49
    2. 用 MC 近似消去上式中的两个期望。具体而言,先用当前策略 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_46 和环境交互收集一条轨迹
      RL 实践(7)—— CartPole【TPRO & PPO】_PPO_51 此轨迹满足 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_52,故每个 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_53 二元组都能构造一个无偏 MC 估计
      RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_54 用这些无偏估计的期望(均值)来近似原始优化目标,得到
      RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_55
    3. 用真实 return 对 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_56 进行 MC 近似,具体而言
      RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_57
  • 综上得到对优化目标 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_40 的近似
    RL 实践(7)—— CartPole【TPRO & PPO】_最优化_59 注意近似过程中假设了RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_60RL 实践(7)—— CartPole【TPRO & PPO】_最优化_61 极其接近,以至于可以认为二者诱导的状态分布一致,这样就能完全避免策略优化后进入坏状态引发 1.1 节的 overstep 问题。因此需要强调置信域:只有 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_30 靠近 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_36

1.3.2 最大化

  • 每轮迭代中,求解以下约束优化问题
    RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_64 我们认为在置信域 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_65RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_66 近似 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_67,这个约束越紧,就越能避免 1.1 节的 overstep 问题

  • 邻域(置信域)RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_65

    1. 简单地设置一个关于参数的欧式距离的阈值 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_69,即 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_70 这时置信域是以 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_11 为球心,RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_69
    2. 另一种方式是设置一个关于策略的 KL 散度的阈值 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_69,即
      RL 实践(7)—— CartPole【TPRO & PPO】_最优化_74 此 KL 散度同样用 1.3.1 节中 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_46 交互得到的轨迹来做 MC 近似计算,即
      RL 实践(7)—— CartPole【TPRO & PPO】_最优化_76 这种做法可以直接约束策略的变化程度。实践表明这种置信域设定表现较好,对于 RL 来说,约束 “行为上的距离” 可能比约束 “参数上的距离” 更加合适
  • 综上得到每轮迭代的约束优化问题为
    RL 实践(7)—— CartPole【TPRO & PPO】_PPO_77

    1. 对优化目标在 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_11
    2. 对约束函数在 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_11
    3. 用拉格朗日乘子法转换为无约束优化问题,通过 KKT 条件得到 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_10

    其中二阶泰勒展开带来的黑塞矩阵尺寸很大,编程时要使用共轭梯度法进行处理;另外由于泰勒展开近似得不到精确解,还要用线性搜索来确保约束条件满足,这些问题导致 TPRO 实现复杂,没有大规模流行

1.4 小结

  • 置信域方法指的是一大类数值优化算法,通常用于求解非凸问题。对于一个最大化问题,算法重复两个步骤——做近似、最大化——直到算法收敛
  • 置信域策略优化(TRPO)是一种利用置信域算法优化策略的 On-policy Online RL 方法,它的优化目标和策略梯度方法相同,每次策略训练仅使用上一轮策略采样的数据,是 policy-based 类算法中十分有代表性的工作之一。直觉性地理解,TRPO 给出的观点是:由于策略的改变导致数据分布的改变,这大大影响深度模型实现的策略网络的学习效果,所以通过划定一个可信任的策略学习区域,保证策略学习的稳定性和有效性
  • TRPO中有两个需要调的超参数:一个是置信域的半径 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_81,另一个是求解最大化问题的数值算法的学习率。通常来说, RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_81 在算法的运行过程中要逐渐缩小。虽然TRPO需要调参,但是TRPO对超参数的设置并不敏感,即使超参数设置不够好,TRPO的表现也不会太差。相比之下,策略梯度算法对超参数更敏感
  • TPRO 的优势在于更好的稳定性和更高的样本效率;缺点在于每步迭代求解约束优化问题的过程繁琐,算法实现复杂 ,其后续工作 PPO 很好地解决了此问题,成为了非常流行的 Online RL 方法

2. PPO(近端策略优化)方法
  • PPO 基于 TRPO 的思想,但是其算法实现更加简单。大量的实验结果表明,PPO 能和 TRPO 学习得一样好且收敛更快,这使得 PPO 和 SAC、TD3 一起成为三大最流行的强化学习算法。如果我们想要尝试在一个新的环境中使用强化学习,可以首先尝试这三个算法
  • PPO 算法框架和 TPRO 无异,其核心思想在于将 “最大化” 操作中的约束优化问题转换为无约束优化来简化问题

2.1 PPO 公式推导

  • 前文 1.3 节推 TPRO 优化目标时是从 policy gradient 法的原始优化目标开始推导的,那样推比较简单,得到优化目标为
    RL 实践(7)—— CartPole【TPRO & PPO】_最优化_83 但 TPRO 和 PPO 的原始论文中使用了另一种推导方法,最后得到的优化目标略有不同,为
    RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_84 其中 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_85 函数是前文 RL 实践(6)—— CartPole【REINFORCE with baseline & A2C】 中介绍的优势函数 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_86 这两种优化目标都是可行的,由于 TPRO 和 PPO 的论文都用了后者,这里也推导一下这个目标
  • 推导的出发点是希望借助当前参数 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_38 推导出新的 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_41 可以使得 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_89。这里优化目标设定为在初始状态分布 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_90 下的状态价值期望 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_91,有
    RL 实践(7)—— CartPole【TPRO & PPO】_PPO_92 考虑到策略诱导的状态分布和初始分布 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_93,这样我们可以推导新旧策略的目标函数之间的差距
    RL 实践(7)—— CartPole【TPRO & PPO】_最优化_94 故只要能找到一个新策略 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_05 使得 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_96,就能保证策略性能单调递增 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_97。去掉其中常数部分再用重要度采样改为用 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_98 采样动作,就得到了 TPRO/PPO 的优化目标函数
    RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_84

2.1.1 做近似

  • 得到替代函数的方法完全类似 1.3.1 节,进行三次近似即可。具体而言
    1. 用当前策略 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_100 诱导的状态分布 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_101 近似 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_102,原始优化目标近似为
      RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_103

    2. 用 MC 近似消去上式中的两个期望。具体而言,先用当前策略 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_100 和环境交互收集一条轨迹
      RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_105 此轨迹满足 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_106,故每个 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_107 二元组都能构造一个无偏 MC 估计
      RL 实践(7)—— CartPole【TPRO & PPO】_最优化_108 用这些无偏估计的期望(均值)来近似原始优化目标,得到
      RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_109

    3. 最后我们考虑如何估计优势函数 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_110。目前比较常用的方法是 广义优势估计(Generalized Advantage Estimation,GAE),先简介一下 GAE

      首先将 TD Error 表示为 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_111,其中 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_112 是一个已经学习的状态价值函数,根据多步 TD 思想有
      RL 实践(7)—— CartPole【TPRO & PPO】_最优化_113 GAE 将这些不同步数的优势估计进行指数加权平均:
      RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_114 其中,RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_115

      1. RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_116 时,RL 实践(7)—— CartPole【TPRO & PPO】_最优化_117
      2. RL 实践(7)—— CartPole【TPRO & PPO】_PPO_118 时,RL 实践(7)—— CartPole【TPRO & PPO】_最优化_119

      利用 GAE 估计优势函数 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_110 时,需要计算 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_121 交互得到的轨迹每个 timestep的 TD error RL 实践(7)—— CartPole【TPRO & PPO】_最优化_122,为此需要引入价值网络(critic)来估计 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_123,得到所有 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_124 后直接代入 GAE 公式 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_125

2.1.2 最大化

  • 最大化这一步是 PPO 和 TPRO 唯一的区别,首先二者的置信域约束优化问题均可表示为
    RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_126
    1. TRPO 使用泰勒展开近似、共轭梯度、线性搜索等方法直接求解约束优化问题
    2. PPO 使用拉格朗日乘子法、限制目标函数等方式去除约束,然后就可以直接用梯度下降简单地求解无约束最优化问题
  • 具体来说,PPO 有两种形式,一是 PPO-惩罚,二是 PPO-截断,我们接下来对这两种形式进行介绍:
    1. PPO-惩罚:用拉格朗日乘数法直接将 KL 散度的限制放进了目标函数中,将原问题转换为无约束优化问题。迭代过程中根据真实的 KL 散度值(约束效果)不断更新 KL 散度前的拉格朗日乘数(调节约束强度)。第 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_127 轮优化函数为:
      RL 实践(7)—— CartPole【TPRO & PPO】_PPO_128RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_129RL 实践(7)—— CartPole【TPRO & PPO】_最优化_130 的更新规则如下
      RL 实践(7)—— CartPole【TPRO & PPO】_PPO_131 其中 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_132
    2. PPO-截断:直接在目标函数中进行限制,以保证新的参数和旧的参数的差距不会太大。第 RL 实践(7)—— CartPole【TPRO & PPO】_强化学习_127 轮优化函数为:
      RL 实践(7)—— CartPole【TPRO & PPO】_PPO_134 其中 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_135,即把 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_136 限制在 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_137 内,上式中 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_138 是一个超参数,表示进行截断的范围。注意 min 操作中的两个选择,后者就是把前者 clip 到 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_139 而已。直接将两个系数的曲线如下画出来
    3. RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_140

    4. 其中绿色虚线是 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_141,蓝色虚线是 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_142,红色实线是优势函数 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_143 不同取值时 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_144 操作选出的系数。以左图 RL 实践(7)—— CartPole【TPRO & PPO】_PPO_145 的情况为例分析
      1. RL 实践(7)—— CartPole【TPRO & PPO】_PPO_146意味着状态 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_147 处动作 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_148 带来了好处,所以为了鼓励 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_148 出现系数应尽量大,但是不要超过 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_150 (就是说 RL 实践(7)—— CartPole【TPRO & PPO】_TPRO_151 处选择 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_152 的概率不要比现在高超过 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_150
      2. 系数小于 1 时说明网络还处于欠拟合状态,并没有学到此时应在 RL 实践(7)—— CartPole【TPRO & PPO】_最优化_147 位置鼓励选择动作 RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_148,这时就不用限制了。所以注意到红色线有上限而无下限
  • 大量实验表明,PPO-截断总是比 PPO-惩罚表现得更好

2.2 伪代码

  • PPO-截断的伪代码如下
  • RL 实践(7)—— CartPole【TPRO & PPO】_最优化_156


2.3 用 PPO 方法解决 CartPole 问题

  • 本节实验使用 gym 自带的 CartPole-V0 环境。这是一个经典的一阶倒立摆控制问题,agent 的任务是通过左右移动保持车上的杆竖直,若杆的倾斜度数过大,或者车子离初始位置左右的偏离程度过大,或者坚持时间到达 200 帧,则游戏结束
  • RL 实践(7)—— CartPole【TPRO & PPO】_pytorch_157

  • 关于此环境动作状态空间、奖励函数及初始状态分布等的详细说明请参考 CartPole-V0
  • 下面给出完整代码
  • import gym
    import torch
    import random
    import torch.nn.functional as F
    import numpy as np
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    from gym.utils.env_checker import check_env
    from gym.wrappers import TimeLimit 
    
    class PolicyNet(torch.nn.Module):
        ''' 策略网络是一个两层 MLP '''
        def __init__(self, input_dim, hidden_dim, output_dim):
            super(PolicyNet, self).__init__()
            self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
            self.fc2 = torch.nn.Linear(hidden_dim, output_dim)
    
        def forward(self, x):
            x = F.relu(self.fc1(x))             # (1, hidden_dim)
            x = F.softmax(self.fc2(x), dim=1)   # (1, output_dim)
            return x
    
    class VNet(torch.nn.Module):
        ''' 价值网络是一个两层 MLP '''
        def __init__(self, input_dim, hidden_dim):
            super(VNet, self).__init__()
            self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
            self.fc2 = torch.nn.Linear(hidden_dim, 1)
    
        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return x
    
    class PPO(torch.nn.Module):
        def __init__(self, state_dim, hidden_dim, action_range, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device):
            super().__init__()
            self.actor = PolicyNet(state_dim, hidden_dim, action_range).to(device)
            self.critic = VNet(state_dim, hidden_dim).to(device) 
            self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
            self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
            
            self.device = device
            self.gamma = gamma
            self.lmbda = lmbda      # GAE 参数
            self.epochs = epochs    # 一条轨迹数据用来训练的轮数
            self.eps = eps          # PPO 中截断范围的参数
            self.device = device        
    
        def take_action(self, state):
            state = torch.tensor(state, dtype=torch.float).to(self.device)
            state = state.unsqueeze(0)
            probs = self.actor(state)
            action_dist = torch.distributions.Categorical(probs)
            action = action_dist.sample()
            return action.item()
    
        def compute_advantage(self, gamma, lmbda, td_delta):
            ''' 广义优势估计 GAE '''
            td_delta = td_delta.detach().numpy()
            advantage_list = []
            advantage = 0.0
            for delta in td_delta[::-1]:
                advantage = gamma * lmbda * advantage + delta
                advantage_list.append(advantage)
            advantage_list.reverse()
            return torch.tensor(np.array(advantage_list), dtype=torch.float)
    
        def update(self, transition_dict):
            states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)
            actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
            rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
            next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)
            dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)
    
            td_target = rewards + self.gamma * self.critic(next_states) * (1-dones)
            td_delta = td_target - self.critic(states)
            advantage = self.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
            old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
    
            # 用刚采集的一条轨迹数据训练 epochs 轮
            for _ in range(self.epochs):
                log_probs = torch.log(self.actor(states).gather(1, actions))
                ratio = torch.exp(log_probs - old_log_probs)
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # 截断
                actor_loss = torch.mean(-torch.min(surr1, surr2))                   # PPO损失函数
                critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
                
                # 更新网络参数
                self.actor_optimizer.zero_grad()
                self.critic_optimizer.zero_grad()
                actor_loss.backward()
                critic_loss.backward()
                self.actor_optimizer.step()
                self.critic_optimizer.step()
    
    if __name__ == "__main__":
        def moving_average(a, window_size):
            ''' 生成序列 a 的滑动平均序列 '''
            cumulative_sum = np.cumsum(np.insert(a, 0, 0)) 
            middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
            r = np.arange(1, window_size-1, 2)
            begin = np.cumsum(a[:window_size-1])[::2] / r
            end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
            return np.concatenate((begin, middle, end))
    
        def set_seed(env, seed=42):
            ''' 设置随机种子 '''
            env.action_space.seed(seed)
            env.reset(seed=seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
    
        state_dim = 4               # 环境观测维度
        action_range = 2            # 环境动作空间大小
        actor_lr = 1e-3
        critic_lr = 1e-2
        num_episodes = 500
        hidden_dim = 128
        gamma = 0.98
        lmbda = 0.95
        epochs = 10
        eps = 0.2
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    
        # build environment
        env_name = 'CartPole-v0'
        env = gym.make(env_name, render_mode='rgb_array')
        check_env(env.unwrapped)    # 检查环境是否符合 gym 规范
        set_seed(env, 0)
    
        # build agent
        agent = PPO(state_dim, hidden_dim, action_range, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)
    
        # start training
        return_list = []
        for i in range(10):
            with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
                for i_episode in range(int(num_episodes / 10)):
                    episode_return = 0
                    transition_dict = {
                        'states': [],
                        'actions': [],
                        'next_states': [],
                        'next_actions': [],
                        'rewards': [],
                        'dones': []
                    }
                    state, _ = env.reset()
    
                    # 以当前策略交互得到一条轨迹
                    while True:
                        action = agent.take_action(state)
                        next_state, reward, terminated, truncated, _ = env.step(action)
                        next_action = agent.take_action(next_state)
                        transition_dict['states'].append(state)
                        transition_dict['actions'].append(action)
                        transition_dict['next_states'].append(next_state)
                        transition_dict['next_actions'].append(next_action)
                        transition_dict['rewards'].append(reward)
                        transition_dict['dones'].append(terminated or truncated)
                        state = next_state
                        episode_return += reward
                                            
                        if terminated or truncated:
                            break
                        #env.render()
    
                    # 用当前策略收集的数据进行 on-policy 更新
                    agent.update(transition_dict)
    
                    # 更新进度条
                    return_list.append(episode_return)
                    pbar.set_postfix({
                        'episode':
                        '%d' % (num_episodes / 10 * i + i_episode + 1),
                        'return':
                        '%.3f' % episode_return,
                        'ave return':
                        '%.3f' % np.mean(return_list[-10:])
                    })
                    pbar.update(1)
    
        # show policy performence
        mv_return_list = moving_average(return_list, 29)
        episodes_list = list(range(len(return_list)))
        plt.figure(figsize=(12,8))
        plt.plot(episodes_list, return_list, label='raw', alpha=0.5)
        plt.plot(episodes_list, mv_return_list, label='moving ave')
        plt.xlabel('Episodes')
        plt.ylabel('Returns')
        plt.title(f'{agent._get_name()} on CartPole-V0')
        plt.legend()
        plt.savefig(f'./result/{agent._get_name()}.png')
        plt.show()
  • 收敛曲线如下所示
  • RL 实践(7)—— CartPole【TPRO & PPO】_最优化_158

  • 可见 PPO 的收敛速度和稳定性都比 前文 介绍的 REINFORCE with baseline 和 A2C 方法好得多
3. 总结
  • 置信域策略优化(TRPO)是一种利用置信域算法优化策略的 On-policy Online RL 方法,它的优化目标和策略梯度方法相同,每次策略训练仅使用上一轮策略采样的数据,是 policy-based 类算法中十分有代表性的工作之一。直觉性地理解,TRPO 给出的观点是:由于策略的改变导致数据分布的改变,这大大影响深度模型实现的策略网络的学习效果,所以通过划定一个可信任的策略学习区域,保证策略学习的稳定性和有效性
  • 近端策略优化 (PPO) 是 TRPO 的一种改进算法,它在实现上简化了 TRPO 中的复杂计算,并且它在实验中的性能大多数情况下会比 TRPO 更好,因此目前常被用作一种常用的基准算法。需要注意的是,TRPO 和 PPO 都属于在线策略学习算法,即使优化目标中包含重要性采样的过程,但其只是用到了上一轮策略的数据,而不是过去所有策略的数据




【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

上一篇: 5134. 简单判断 下一篇: docker 安装mysql
  1. 分享:
最后一次编辑于 2023年11月08日 0

暂无评论

VFpNeRYlMszB