提交 50bd0556 编写于 作者: V Varuna Jayasiri

ppo colab

上级 ac40d0a7
此差异已折叠。
......@@ -75,9 +75,10 @@
<h1>Generalized Advantage Estimation (GAE)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of paper
<a href="https://arxiv.org/abs/1506.02438">Generalized Advantage Estimation</a>.</p>
<p>You can find an experiment that uses it <a href="experiment.html">here</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre></div>
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
......@@ -88,7 +89,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">16</span><span class="k">class</span> <span class="nc">GAE</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">GAE</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
......@@ -99,11 +100,11 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">17</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_workers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">worker_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">gamma</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">lambda_</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">18</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">=</span> <span class="n">lambda_</span>
<span class="lineno">19</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">=</span> <span class="n">gamma</span>
<span class="lineno">20</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span> <span class="o">=</span> <span class="n">worker_steps</span>
<span class="lineno">21</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">=</span> <span class="n">n_workers</span></pre></div>
<div class="highlight"><pre><span class="lineno">19</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_workers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">worker_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">gamma</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">lambda_</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">20</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">=</span> <span class="n">lambda_</span>
<span class="lineno">21</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">=</span> <span class="n">gamma</span>
<span class="lineno">22</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span> <span class="o">=</span> <span class="n">worker_steps</span>
<span class="lineno">23</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">=</span> <span class="n">n_workers</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
......@@ -142,7 +143,7 @@ $\hat{A_t}$</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">23</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">done</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">rewards</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">values</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">25</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">done</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">rewards</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">values</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
......@@ -153,8 +154,8 @@ $\hat{A_t}$</p>
<p>advantages table</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">advantages</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">57</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">advantages</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="lineno">59</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
......@@ -165,9 +166,9 @@ $\hat{A_t}$</p>
<p>$V(s_{t+1})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="lineno">61</span>
<span class="lineno">62</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">)):</span></pre></div>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="lineno">63</span>
<span class="lineno">64</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">worker_steps</span><span class="p">)):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
......@@ -178,9 +179,9 @@ $\hat{A_t}$</p>
<p>mask if episode completed after step $t$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">mask</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">done</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span>
<span class="lineno">65</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">last_value</span> <span class="o">*</span> <span class="n">mask</span>
<span class="lineno">66</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="n">last_advantage</span> <span class="o">*</span> <span class="n">mask</span></pre></div>
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">mask</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">done</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span>
<span class="lineno">67</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">last_value</span> <span class="o">*</span> <span class="n">mask</span>
<span class="lineno">68</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="n">last_advantage</span> <span class="o">*</span> <span class="n">mask</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
......@@ -191,7 +192,7 @@ $\hat{A_t}$</p>
<p>$\delta_t$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">delta</span> <span class="o">=</span> <span class="n">rewards</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="n">last_value</span> <span class="o">-</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">delta</span> <span class="o">=</span> <span class="n">rewards</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="n">last_value</span> <span class="o">-</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
......@@ -202,7 +203,7 @@ $\hat{A_t}$</p>
<p>$\hat{A_t} = \delta_t + \gamma \lambda \hat{A_{t+1}}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="n">delta</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">*</span> <span class="n">last_advantage</span></pre></div>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">last_advantage</span> <span class="o">=</span> <span class="n">delta</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">*</span> <span class="n">last_advantage</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
......@@ -219,11 +220,11 @@ The performance of the model was improving
probably because the samples are similar.</em></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">advantages</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">last_advantage</span>
<span class="lineno">81</span>
<span class="lineno">82</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span>
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">advantages</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">last_advantage</span>
<span class="lineno">83</span>
<span class="lineno">84</span> <span class="k">return</span> <span class="n">advantages</span></pre></div>
<span class="lineno">84</span> <span class="n">last_value</span> <span class="o">=</span> <span class="n">values</span><span class="p">[:,</span> <span class="n">t</span><span class="p">]</span>
<span class="lineno">85</span>
<span class="lineno">86</span> <span class="k">return</span> <span class="n">advantages</span></pre></div>
</div>
</div>
</div>
......
......@@ -85,12 +85,14 @@ It does so by clipping gradient flow if the updated policy
is not close to the policy used to sample the data.</p>
<p>You can find an experiment that uses it <a href="experiment.html">here</a>.
The experiment uses <a href="gae.html">Generalized Advantage Estimation</a>.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">26</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">27</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
<div class="highlight"><pre><span class="lineno">29</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">30</span>
<span class="lineno">31</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">labml_nn.rl.ppo.gae</span> <span class="kn">import</span> <span class="n">GAE</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
......@@ -195,7 +197,7 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">ClippedPPOLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">ClippedPPOLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
......@@ -206,8 +208,8 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">134</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">137</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
......@@ -218,8 +220,8 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">137</span> <span class="n">advantage</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_log_pi</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">140</span> <span class="n">advantage</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
......@@ -231,7 +233,7 @@ J(\pi_\theta) - J(\pi_{\theta_{OLD}})
<em>this is different from rewards</em> $r_t$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">ratio</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">-</span> <span class="n">sampled_log_pi</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">ratio</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_pi</span> <span class="o">-</span> <span class="n">sampled_log_pi</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
......@@ -267,14 +269,14 @@ Large deviation can cause performance collapse;
but it reduces variance a lot.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">clipped_ratio</span> <span class="o">=</span> <span class="n">ratio</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">clip</span><span class="p">,</span>
<span class="lineno">170</span> <span class="nb">max</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">clip</span><span class="p">)</span>
<span class="lineno">171</span> <span class="n">policy_reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">,</span>
<span class="lineno">172</span> <span class="n">clipped_ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">)</span>
<span class="lineno">173</span>
<span class="lineno">174</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_fraction</span> <span class="o">=</span> <span class="p">(</span><span class="nb">abs</span><span class="p">((</span><span class="n">ratio</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">))</span> <span class="o">&gt;</span> <span class="n">clip</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="lineno">175</span>
<span class="lineno">176</span> <span class="k">return</span> <span class="o">-</span><span class="n">policy_reward</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">clipped_ratio</span> <span class="o">=</span> <span class="n">ratio</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">clip</span><span class="p">,</span>
<span class="lineno">173</span> <span class="nb">max</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">clip</span><span class="p">)</span>
<span class="lineno">174</span> <span class="n">policy_reward</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">,</span>
<span class="lineno">175</span> <span class="n">clipped_ratio</span> <span class="o">*</span> <span class="n">advantage</span><span class="p">)</span>
<span class="lineno">176</span>
<span class="lineno">177</span> <span class="bp">self</span><span class="o">.</span><span class="n">clip_fraction</span> <span class="o">=</span> <span class="p">(</span><span class="nb">abs</span><span class="p">((</span><span class="n">ratio</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">))</span> <span class="o">&gt;</span> <span class="n">clip</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="lineno">178</span>
<span class="lineno">179</span> <span class="k">return</span> <span class="o">-</span><span class="n">policy_reward</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
......@@ -300,7 +302,7 @@ V^{\pi_\theta}_{CLIP}(s_t)
significantly from $V_{\theta_{OLD}}$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span><span class="k">class</span> <span class="nc">ClippedValueFunctionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">182</span><span class="k">class</span> <span class="nc">ClippedValueFunctionLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
......@@ -311,10 +313,10 @@ V^{\pi_\theta}_{CLIP}(s_t)
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_return</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">201</span> <span class="n">clipped_value</span> <span class="o">=</span> <span class="n">sampled_value</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_value</span><span class="p">)</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">clip</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">clip</span><span class="p">)</span>
<span class="lineno">202</span> <span class="n">vf_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">clipped_value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">203</span> <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">vf_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">sampled_return</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">clip</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">204</span> <span class="n">clipped_value</span> <span class="o">=</span> <span class="n">sampled_value</span> <span class="o">+</span> <span class="p">(</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_value</span><span class="p">)</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">clip</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">clip</span><span class="p">)</span>
<span class="lineno">205</span> <span class="n">vf_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">clipped_value</span> <span class="o">-</span> <span class="n">sampled_return</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<span class="lineno">206</span> <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">vf_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
</div>
......
......@@ -85,6 +85,8 @@ It does so by clipping gradient flow if the updated policy
is not close to the policy used to sample the data.</p>
<p>You can find an experiment that uses it <a href="https://nn.labml.ai/rl/ppo/experiment.html">here</a>.
The experiment uses <a href="https://nn.labml.ai/rl/ppo/gae.html">Generalized Advantage Estimation</a>.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
......
......@@ -699,6 +699,13 @@
</url>
<url>
<loc>https://nn.labml.ai/rl/ppo/experiment.html</loc>
<lastmod>2021-03-30T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/rl/ppo/index.html</loc>
<lastmod>2021-03-27T16:30:00+00:00</lastmod>
......@@ -722,7 +729,7 @@
<url>
<loc>https://nn.labml.ai/rl/ppo/experiment.html</loc>
<lastmod>2021-03-27T16:30:00+00:00</lastmod>
<lastmod>2021-03-30T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
......
......@@ -21,6 +21,9 @@ is not close to the policy used to sample the data.
You can find an experiment that uses it [here](experiment.html).
The experiment uses [Generalized Advantage Estimation](gae.html).
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f)
"""
import torch
......
......@@ -86,7 +86,10 @@
"id": "-OnHLi626tJt"
},
"source": [
"Configurations"
"### Configurations\n",
"\n",
"`IntDynamicHyperParam` and `FloatDynamicHyperParam` are dynamic hyper parameters\n",
"that you can change while the experiment is running."
]
},
{
......@@ -223,4 +226,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
\ No newline at end of file
......@@ -8,6 +8,9 @@ summary: Annotated implementation to train a PPO agent on Atari Breakout game.
This experiment trains Proximal Policy Optimization (PPO) agent Atari Breakout game on OpenAI Gym.
It runs the [game environments on multiple processes](../game.html) to sample efficiently.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f)
"""
from typing import Dict
......@@ -354,23 +357,31 @@ def main():
experiment.create(name='ppo')
# Configurations
configs = {
# number of updates
# Number of updates
'updates': 10000,
# number of epochs to train the model with sampled data
# ⚙️ Number of epochs to train the model with sampled data.
# You can change this while the experiment is running.
# [![Example](https://img.shields.io/badge/example-hyperparams-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f/hyper_params)
'epochs': IntDynamicHyperParam(8),
# number of worker processes
# Number of worker processes
'n_workers': 8,
# number of steps to run on each process for a single update
# Number of steps to run on each process for a single update
'worker_steps': 128,
# number of mini batches
# Number of mini batches
'batches': 4,
# Value loss coefficient
# ⚙️ Value loss coefficient.
# You can change this while the experiment is running.
# [![Example](https://img.shields.io/badge/example-hyperparams-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f/hyper_params)
'value_loss_coef': FloatDynamicHyperParam(0.5),
# Entropy bonus coefficient
# ⚙️ Entropy bonus coefficient.
# You can change this while the experiment is running.
# [![Example](https://img.shields.io/badge/example-hyperparams-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f/hyper_params)
'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
# Clip range
# ⚙️ Clip range.
'clip_range': FloatDynamicHyperParam(0.1),
# Learning rate
# You can change this while the experiment is running.
# [![Example](https://img.shields.io/badge/example-hyperparams-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f/hyper_params)
# ⚙️ Learning rate.
'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
}
......
......@@ -8,6 +8,8 @@ summary: A PyTorch implementation/tutorial of Generalized Advantage Estimation (
This is a [PyTorch](https://pytorch.org) implementation of paper
[Generalized Advantage Estimation](https://arxiv.org/abs/1506.02438).
You can find an experiment that uses it [here](experiment.html).
"""
import numpy as np
......
......@@ -13,4 +13,7 @@ It does so by clipping gradient flow if the updated policy
is not close to the policy used to sample the data.
You can find an experiment that uses it [here](https://nn.labml.ai/rl/ppo/experiment.html).
The experiment uses [Generalized Advantage Estimation](https://nn.labml.ai/rl/ppo/gae.html).
\ No newline at end of file
The experiment uses [Generalized Advantage Estimation](https://nn.labml.ai/rl/ppo/gae.html).
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/rl/ppo/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/6eff28a0910e11eb9b008db315936e2f)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册