提交 86ec4895 编写于 作者: W wizardforcel

2021-01-19 11:32:28

上级 e9c1289b
......@@ -6,7 +6,7 @@
此实现使用 PyTorch 张量上的运算来计算正向传播,并使用 PyTorch autograd 来计算梯度。
在此实现中,我们实现了自己的自定义 autograd 函数来执行\(P_3'(x)\)。 通过数学,\(P_3'(x)= rac {3} {2} \ left(5x ^ 2-1 ight)\)
在此实现中,我们实现了自己的自定义 autograd 函数来执行`P'[3](x)`。 通过数学,\(P_3'(x)= rac {3} {2} \ left(5x ^ 2-1 ight)\)
```py
import torch
......
......@@ -107,13 +107,13 @@ Q 学习的主要思想是,如果我们有一个函数\(Q ^ *:State \ time
\[\pi^*(s) = \arg\!\max_a \ Q^*(s, a)\]
但是,我们对世界一无所知,因此无法访问\(Q ^ * \)。 但是,由于神经网络是通用函数逼近器,因此我们可以简单地创建一个并将其训练为类似于\(Q ^ * \)的函数。
但是,我们对世界一无所知,因此无法访问`Q*`。 但是,由于神经网络是通用函数逼近器,因此我们可以简单地创建一个并将其训练为类似于`Q*`的函数。
对于我们的训练更新规则,我们将使用一个事实,即某些策略的每个`Q`函数都遵循 Bellman 方程:
\[Q^{\pi}(s, a) = r + \gamma Q^{\pi}(s', \pi(s'))\]
等式两侧之间的差异称为时间差异误差\(\ delta \)
等式两侧之间的差异称为时间差异误差`delta`
\[\delta = Q(s, a) - (r + \gamma \max_a Q(s', a))\]
......
......@@ -311,9 +311,9 @@ class Mario(Mario): # subclassing for continuity
### 学习
Mario 在后台使用 [DDQN 算法](https://arxiv.org/pdf/1509.06461)。 DDQN 使用两个 ConvNet-\(Q_ {online} \)\(Q_ {target} \)-独立地逼近最佳作用值函数。
Mario 在后台使用 [DDQN 算法](https://arxiv.org/pdf/1509.06461)。 DDQN 使用两个 ConvNet-`Q_online``Q_target`-独立地逼近最佳作用值函数。
在我们的实现中,我们在\(Q_ {online} \)\(Q_ {target} \)之间共享特征生成器`features`,但是为每个特征维护单独的 FC 分类器。 `θ_target`\(Q_ {target} \)的参数)被冻结,以防止反向传播进行更新。 而是定期与`θ_online`同步(稍后会对此进行详细介绍)。
在我们的实现中,我们在`Q_online``Q_target`之间共享特征生成器`features`,但是为每个特征维护单独的 FC 分类器。 `θ_target``Q_target`的参数)被冻结,以防止反向传播进行更新。 而是定期与`θ_online`同步(稍后会对此进行详细介绍)。
#### 神经网络
......@@ -363,15 +363,15 @@ class MarioNet(nn.Module):
学习涉及两个值:
**TD 估计**-给定状态`s`的预测最佳\(Q ^ * \)
**TD 估计**-给定状态`s`的预测最佳`Q*`
\[{TD}_e = Q_{online}^*(s,a)\]
**TD 目标**-当前奖励和下一状态`s'`中的估计\(Q ^ * \)的汇总
**TD 目标**-当前奖励和下一状态`s'`中的估计`Q*`的汇总
\[a' = argmax_{a} Q_{online}(s', a)\] \[{TD}_t = r + \gamma Q_{target}^*(s',a')\]
由于我们不知道下一个动作`a'`是什么,因此我们在下一个状态`s'`中使用动作`a'`最大化\(Q_ {online} \)
由于我们不知道下一个动作`a'`是什么,因此我们在下一个状态`s'`中使用动作`a'`最大化`Q_online`
请注意,我们在`td_target()`上使用了 [@ torch.no_grad()](https://pytorch.org/docs/stable/generated/torch.no_grad.html#no-grad)装饰器来禁用梯度计算(因为我们无需在`θ_target`上进行反向传播。)
......@@ -400,7 +400,7 @@ class Mario(Mario):
#### 更新模型
当 Mario 从其重播缓冲区中采样输入时,我们计算`TD_t``TD_e`并反向传播该损耗\(Q_ {online} \)以更新其参数`θ_online`\ \ alpha \)是传递给`optimizer`的学习率`lr`
当 Mario 从其重播缓冲区中采样输入时,我们计算`TD_t``TD_e`并反向传播该损耗`Q_online`以更新其参数`θ_online`\ \ alpha \)是传递给`optimizer`的学习率`lr`
\[\theta_{online} \leftarrow \theta_{online} + \alpha \nabla(TD_e - TD_t)\]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册