提交 2a8e7f2c 编写于 作者: W wanghaoshuang

Add README

上级 bbeb1e98
# Policy Gradient RL by PaddlePaddle
本文介绍了如何使用PaddlePaddle通过policy-based的强化学习方法来训练一个player(actor model), 我们希望这个player可以完成简单的走阶梯任务。
内容分为:
- 任务描述
- 模型
- 策略(目标函数)
- 算法(Gradient ascent)
- PaddlePaddle实现
## 1. 任务描述
假设有一个阶梯,连接A、B点,player从A点出发,每一步只能向前走一步或向后走一步,到达B点即为完成任务。我们希望训练一个聪明的player,它知道怎么最快的从A点到达B点。
我们在命令行以下边的形式模拟任务:
```
A - O - - - - - B
```
一个‘-'代表一个阶梯,A点在行头,B点在行末,O代表player当前在的位置。
## 2. Policy Gradient
### 2.1 模型
#### inputyer
模型的输入是player观察到的当前阶梯的状态$S$, 要包含阶梯的长度和player当前的位置信息。
在命令行模拟的情况下,player的位置和阶梯长度连个变量足以表示当前的状态,但是我们为了便于将这个demo推广到更复杂的任务场景,我们这里用一个向量来表示游戏状态$S$.
向量$S$的长度为阶梯的长度,每一维代表一个阶梯,player所在的位置为1,其它位置为0.
下边是一个例子:
```
S = [0, 1, 0, 0] // 阶梯长度为4,player在第二个阶梯上。
```
#### hidden layer
隐藏层采用两个全连接layer `FC_1``FC_2`, 其中`FC_1` 的size为10, `FC_2`的size为2.
#### output layer
我们使用softmax将`FC_2`的output映射为所有可能的动作(前进或后退)的概率分布(Probability of taking the action),即为一个二维向量`act_probs`, 其中,`act_probs[0]` 为后退的概率,`act_probs[1]`为前进的概率。
#### 模型表示
我将我们的player模型(actor)形式化表示如下:
$$a = \pi_\theta(s)$$
其中$\theta$表示模型的参数,$s$是输入状态。
### 2.2 策略(目标函数)
我们怎么评估一个player(模型)的好坏呢?首先我们定义几个术语:
我们让$\pi_\theta(s)$来玩一局游戏,$s_t$表示第$t$时刻的状态,$a_t$表示在状态$s_t$做出的动作,$r_t$表示做过动作$a_t$后得到的奖赏。
一局游戏的过程可以表示如下:
$$\tau = [s_1, a_1, r_1, s_2, a_2, r_2 ... s_T, a_T, r_T] \tag{1}$$
一局游戏的奖励表示如下:
$$R(\tau) = \sum_{t=1}^Tr_t$$
player玩一局游戏,可能会出现多种操作序列$\tau$ ,某个$\tau$出现的概率是依赖于player model的$\theta$, 记做:
$$P(\tau | \theta)$$
那么,给定一个$\theta$(player model), 玩一局游戏,期望得到的奖励是:
$$\overline {R}_\theta = \sum_\tau R(\tau)\sum_\tau R(\tau) P(\tau|\theta)$$
大多数情况,我们无法穷举出所有的$\tau$,所以我们就抽取N个$\tau$来计算近似的期望:
$$\overline {R}_\theta = \sum_\tau R(\tau) P(\tau|\theta) \approx \frac{1}{N} \sum_{n=1}^N R(\tau^n)$$
$\overline {R}_\theta$就是我们需要的目标函数,它表示了一个参数为$\theta$的player玩一局游戏得分的期望,这个期望越大,代表这个player能力越强。
### 2.3 算法(Gradient ascent)
我们的目标函数是$\overline {R}_\theta$, 我们训练的任务就是, 我们训练的任务就是:
$$\theta^* = \arg\max_\theta \overline {R}_\theta$$
为了找到理想的$\theta$,我们使用Gradient ascent方法不断在$\overline {R}_\theta$的梯度方向更新$\theta$,可表示如下:
$$\theta' = \theta + \eta * \bigtriangledown \overline {R}_\theta$$
$$ \bigtriangledown \overline {R}_\theta = \sum_\tau R(\tau) \bigtriangledown P(\tau|\theta)\\
= \sum_\tau R(\tau) P(\tau|\theta) \frac{\bigtriangledown P(\tau|\theta)}{P(\tau|\theta)} \\
=\sum_\tau R(\tau) P(\tau|\theta) {\bigtriangledown \log P(\tau|\theta)} $$
$$P(\tau|\theta) = P(s_1)P(a_1|s_1,\theta)P(s_2, r_1|s_1,a_1)P(a_2|s_2,\theta)P(s_3,r_2|s_2,a_2)...P(a_t|s_t,\theta)P(s_{t+1}, r_t|s_t,a_t)\\
=P(s_1) \sum_{t=1}^T P(a_t|s_t,\theta)P(s_{t+1}, r_t|s_t,a_t)$$
$$\log P(\tau|\theta) = \log P(s_1) + \sum_{t=1}^T [\log P(a_t|s_t,\theta) + \log P(s_{t+1}, r_t|s_t,a_t)]$$
$$ \bigtriangledown \log P(\tau|\theta) = \sum_{t=1}^T \bigtriangledown \log P(a_t|s_t,\theta)$$
$$ \bigtriangledown \overline {R}_\theta = \sum_\tau R(\tau) P(\tau|\theta) {\bigtriangledown \log P(\tau|\theta)} \\
\approx \frac{1}{N} \sum_{n=1}^N R(\tau^n) {\bigtriangledown \log P(\tau|\theta)} \\
= \frac{1}{N} \sum_{n=1}^N R(\tau^n) {\sum_{t=1}^T \bigtriangledown \log P(a_t|s_t,\theta)} \\
= \frac{1}{N} \sum_{n=1}^N \sum_{t=1}^T R(\tau^n) { \bigtriangledown \log P(a_t|s_t,\theta)} \tag{11}$$
#### 2.3.2 导数解释
在使用深度学习框架进行训练求解时,一般用梯度下降方法,所以我们把Gradient ascent转为Gradient
descent, 重写等式$(5)(6)$为:
$$\theta^* = \arg\min_\theta (-\overline {R}_\theta \tag{13}$$
$$\theta' = \theta - \eta * \bigtriangledown (-\overline {R}_\theta)) \tag{14}$$
根据上一节的推导,$ (-\bigtriangledown \overline {R}_\theta) $结果如下:
$$ -\bigtriangledown \overline {R}_\theta
= \frac{1}{N} \sum_{n=1}^N \sum_{t=1}^T R(\tau^n) { \bigtriangledown -\log P(a_t|s_t,\theta)} \tag{15}$$
根据等式(14), 我们的player的模型可以设计为:
![图片](http://bos.nj.bpc.baidu.com/v1/agroup/5f762f001d4a421bc06964d39cc78859e1a1e331)
图 1
用户的在一局游戏中的一次操作可以用元组$(s_t, a_t)$, 就是在状态$s_t$状态下做了动作$a_t$, 我们通过图(1)中的前向网络计算出来cross entropy cost为$−\log P(a_t|s_t,\theta)$, 恰好是等式(15)中我们需要微分的一项。
图1是我们需要的player模型,我用这个网络的前向计算可以预测任何状态下该做什么动作。但是怎么去训练学习这个网络呢?在等式(15)中还有一项$R(\tau^n)$, 我做反向梯度传播的时候要加上这一项,所以我们需要在图1基础上再加上$R(\tau^n)$, 如 图2 所示:
![图片](http://bos.nj.bpc.baidu.com/v1/agroup/b639162977cc9c1f612be8fdf31ec99d73630f97)
图2
图2就是我们最终的网络结构。
#### 2.3.3 直观理解
对于等式(15),我只看游戏中的一步操作,也就是这一项: $R(\tau^n) { \bigtriangledown -\log P(a_t|s_t,\theta)}$, 我们可以简单的认为我们训练的目的是让 $R(\tau^n) {[ -\log P(a_t|s_t,\theta)]}$尽可能的小,也就是$R(\tau^n) \log P(a_t|s_t,\theta)$尽可能的大。
- 如果我们当前游戏局的奖励$R(\tau^n)$为正,那么我们希望当前操作的出现的概率$P(a_t|s_t,\theta)$尽可能大。
- 如果我们当前游戏局的奖励$R(\tau^n)$为负,那么我们希望当前操作的出现的概率$P(a_t|s_t,\theta)$尽可能小。
#### 2.3.4 一个问题
一人犯错,诛连九族。一人得道,鸡犬升天。如果一局游戏得到奖励,我们希望帮助获得奖励的每一次操作都被重视;否则,导致惩罚的操作都要被冷落一次。
是不是很有道理的样子?但是,如果有些游戏场景只有奖励,没有惩罚,怎么办?也就是所有的$R(\tau^n)$都为正。
针对不同的游戏场景,我们有不同的解决方案:
1. 每局游戏得分不一样:将每局的得分减去一个bias,结果就有正有负了。
2. 每局游戏得分一样:把完成一局的时间作为计分因素,并减去一个bias.
我们在第一章描述的游戏场景,需要用第二种 ,player每次到达终点都会收到1分的奖励,我们可以按完成任务所用的步数来定义奖励R.
更进一步,我们认为一局游戏中每步动作对结局的贡献是不同的,有聪明的动作,也有愚蠢的操作。直观的理解,一般是靠前的动作是愚蠢的,靠后的动作是聪明的。既然有了这个价值观,那么我们拿到1分的奖励,就不能平均分给每个动作了。
如图3所示,让所有动作按先后排队,从后往前衰减地给每个动作奖励,然后再每个动作的奖励再减去所有动作奖励的平均值:
![图片](https://github.com/PaddlePaddle/models/blob/develop/policy_gradient/images/PG_3.svg)
## 3. 训练效果
demo运行训练效果如下,经过1000轮尝试,我们的player就学会了如何有效的完成任务了:
```
---------O epoch: 0; steps: 42
---------O epoch: 1; steps: 77
---------O epoch: 2; steps: 82
---------O epoch: 3; steps: 64
---------O epoch: 4; steps: 79
---------O epoch: 501; steps: 19
---------O epoch: 1001; steps: 9
---------O epoch: 1501; steps: 9
---------O epoch: 2001; steps: 11
---------O epoch: 2501; steps: 9
---------O epoch: 3001; steps: 9
---------O epoch: 3002; steps: 9
---------O epoch: 3003; steps: 9
---------O epoch: 3004; steps: 9
---------O epoch: 3005; steps: 9
---------O epoch: 3006; steps: 9
---------O epoch: 3007; steps: 9
---------O epoch: 3008; steps: 9
---------O epoch: 3009; steps: 9
---------O epoch: 3010; steps: 11
---------O epoch: 3011; steps: 9
---------O epoch: 3012; steps: 9
---------O epoch: 3013; steps: 9
---------O epoch: 3014; steps: 9
```
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xl="http://www.w3.org/1999/xlink" version="1.1" viewBox="162 59 594 567" width="594pt" height="567pt" xmlns:dc="http://purl.org/dc/elements/1.1/"><metadata> Produced by OmniGraffle 6.0.5 <dc:date>2017-12-01 08:39Z</dc:date></metadata><defs><font-face font-family="STIXGeneral" font-size="20" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-816.5001" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><marker orient="auto" overflow="visible" markerUnits="strokeWidth" id="FilledArrow_Marker" viewBox="-1 -3 6 6" markerWidth="6" markerHeight="6" color="#9a9a9a"><g><path d="M 3.7333333 0 L 0 -1.4 L 0 1.4 Z" fill="currentColor" stroke="currentColor" stroke-width="1"/></g></marker><font-face font-family="Helvetica Neue" font-size="20" panose-1="2 0 8 3 0 0 0 9 0 4" units-per-em="1000" underline-position="-100" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="bold"><font-face-src><font-face-name name="HelveticaNeue-Bold"/></font-face-src></font-face><font-face font-family="Helvetica Neue" font-size="21" panose-1="2 0 5 3 0 0 0 2 0 4" units-per-em="1000" underline-position="-100" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="951.99585" descent="-212.99744" font-weight="500"><font-face-src><font-face-name name="HelveticaNeue"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="21" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-777.61913" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-859.4738" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="66" slope="0" x-height="450" cap-height="662" ascent="1055.00214" descent="-455.00092" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Regular"/></font-face-src></font-face></defs><g stroke="none" stroke-opacity="1" stroke-dasharray="none" fill="none" fill-opacity="1"><title>神经网络</title><g><title>Layer 1</title><circle cx="312.32677" cy="437.85433" r="32.500052" fill="#d5c0ff"/><circle cx="312.32677" cy="437.85433" r="32.500052" stroke="#695f7e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><circle cx="312.32677" cy="581.58662" r="32.500052" fill="#bfeaff"/><circle cx="312.32677" cy="581.58662" r="32.500052" stroke="#5f747e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(288.32677 566.58662)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">x_2</tspan></text><circle cx="312.32677" cy="232.15354" r="32.500052" fill="#ffd6d8"/><circle cx="312.32677" cy="232.15354" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(288.32677 217.15354)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">y_2</tspan></text><circle cx="207.32677" cy="232.15354" r="32.500052" fill="#ffd6d8"/><circle cx="207.32677" cy="232.15354" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(183.32677 217.15354)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">y_0</tspan></text><circle cx="207.32677" cy="581.58662" r="32.500052" fill="#bfeaff" fill-opacity=".8"/><circle cx="207.32677" cy="581.58662" r="32.500052" stroke="#5f747e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(183.32677 566.58662)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">x_1</tspan></text><circle cx="207.32677" cy="437.85433" r="32.500052" fill="#d5c0ff" fill-opacity=".8"/><circle cx="207.32677" cy="437.85433" r="32.500052" stroke="#695f7e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><circle cx="473.78347" cy="581.58662" r="32.500052" fill="#bfeaff"/><circle cx="473.78347" cy="581.58662" r="32.500052" stroke="#5f747e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(449.78347 566.58662)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="9.56" y="21" textLength="28.88">x_n</tspan></text><rect x="180.33071" y="313.22835" width="322.95276" height="44.811023" fill="#ffec8a"/><rect x="180.33071" y="313.22835" width="322.95276" height="44.811023" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(182.33071 320.63386)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="127.24638" y="21" textLength="64.46">Softmax</tspan></text><circle cx="467.02756" cy="232.06693" r="32.500052" fill="#ffd6d8"/><circle cx="467.02756" cy="232.06693" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(443.02756 217.06693)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="7.34" y="21" textLength="33.32">y_m</tspan></text><circle cx="473.32284" cy="437.85433" r="32.500052" fill="#d5c0ff"/><circle cx="473.32284" cy="437.85433" r="32.500052" stroke="#695f7e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><line x1="207.53024" y1="404.85494" x2="207.72086" y2="373.93907" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="207.32676" y1="548.5866" x2="207.32676" y2="486.75435" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="312.32676" y1="548.5866" x2="312.32676" y2="486.75435" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="473.68325" y1="548.58675" x2="473.49546" y2="486.75404" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="473.11448" y1="404.85497" x2="472.91929" y2="373.93906" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="312.1168" y1="404.85498" x2="311.92007" y2="373.93905" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="205.8189" y1="312.03937" x2="206.40393" y2="281.04489" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="311.82677" y1="312.57482" x2="312.02275" y2="281.05262" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="468.52756" y1="312.57482" x2="467.9385" y2="280.9585" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="292.85776" y1="554.9359" x2="236.17501" y2="477.34407" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="444.7321" y1="565.9156" x2="250.37242" y2="461.07324" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="336.94445" y1="559.60872" x2="436.84423" y2="470.4213" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="226.79579" y1="554.9359" x2="283.47854" y2="477.34407" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="236.36684" y1="565.89467" x2="430.29435" y2="461.105" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="449.13498" y1="559.64322" x2="348.85352" y2="470.36736" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><circle cx="310.83071" cy="103.452757" r="32.500052" fill="#c2ffc4"/><circle cx="310.83071" cy="103.452757" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(286.83071 88.452757)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="8.99" y="21" textLength="30.02">a_2</tspan></text><circle cx="205.83071" cy="103.452757" r="32.500052" fill="#c2ffc4"/><circle cx="205.83071" cy="103.452757" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(181.83071 88.452757)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="8.99" y="21" textLength="30.02">a_0</tspan></text><circle cx="465.5315" cy="103.36614" r="32.500052" fill="#c2ffc4"/><circle cx="465.5315" cy="103.36614" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(441.5315 88.36614)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="6.77" y="21" textLength="34.46">a_m</tspan></text><text transform="translate(371.41733 91.03937)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="bold" x="3.2771656" y="20" textLength="27.8">. . .</tspan></text><text transform="translate(372.5 219.59843)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="bold" x="3.2771656" y="20" textLength="27.8">. . .</tspan></text><text transform="translate(376.08662 421.10237)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="bold" x="3.2771656" y="20" textLength="27.8">. . .</tspan></text><text transform="translate(375.87796 569.08662)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="bold" x="3.2771656" y="20" textLength="27.8">. . .</tspan></text><text transform="translate(589.35434 572.08662)" fill="black"><tspan font-family="Helvetica Neue" font-size="21" font-weight="500" x=".1925" y="20" textLength="27.615">s_t</tspan></text><text transform="translate(597.35434 499.73623)" fill="#262626"><tspan font-family="STIXGeneral" font-size="21" font-style="italic" font-weight="500" fill="#262626" x="0" y="22" textLength="10.08">θ</tspan></text><text transform="translate(542 222.3504)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="57.152">y_t = P</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="57.152" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="63.479" y="20" textLength="61.199">a_t | s_t</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="124.678" y="20" textLength="9.5">, </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="134.178" y="20" textLength="9.12">θ</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="143.298" y="20" textLength="6.327">)</tspan></text><text transform="translate(501.97638 152.90158)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="131.024">-log(y_t) = -logP</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="131.024" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="137.351" y="20" textLength="61.199">a_t | s_t</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="198.55" y="20" textLength="9.5">, </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="208.05" y="20" textLength="9.12">θ</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="217.17" y="20" textLength="6.327">)</tspan></text><text transform="translate(271.4567 154.73622)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="171.171">CROSS ENTROPY =</tspan></text><path d="M 510.65355 515.65355 L 528.9232 507.65355 L 528.9232 512.30024 L 566.08468 512.30024 L 566.08468 507.65355 L 584.35434 515.65355 L 566.08468 523.65355 L 566.08468 519.00685 L 528.9232 519.00685 L 528.9232 523.65355 Z" fill="white"/><path d="M 510.65355 515.65355 L 528.9232 507.65355 L 528.9232 512.30024 L 566.08468 512.30024 L 566.08468 507.65355 L 584.35434 515.65355 L 566.08468 523.65355 L 566.08468 519.00685 L 528.9232 519.00685 L 528.9232 523.65355 Z" stroke="black" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><path d="M 510.4882 586.26772 L 528.75785 578.26772 L 528.75785 582.9144 L 565.91932 582.9144 L 565.91932 578.26772 L 584.18898 586.26772 L 565.91932 594.26772 L 565.91932 589.62103 L 528.75785 589.62103 L 528.75785 594.26772 Z" fill="white"/><path d="M 510.4882 586.26772 L 528.75785 578.26772 L 528.75785 582.9144 L 565.91932 582.9144 L 565.91932 578.26772 L 584.18898 586.26772 L 565.91932 594.26772 L 565.91932 589.62103 L 528.75785 589.62103 L 528.75785 594.26772 Z" stroke="black" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/></g></g></svg>
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xl="http://www.w3.org/1999/xlink" version="1.1" viewBox="80 136 614 415" width="614pt" height="415pt" xmlns:dc="http://purl.org/dc/elements/1.1/"><metadata> Produced by OmniGraffle 6.0.5 <dc:date>2017-12-01 08:39Z</dc:date></metadata><defs><font-face font-family="Helvetica Neue" font-size="20" panose-1="2 11 6 4 2 2 2 2 2 4" units-per-em="1000" underline-position="-75" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="600"><font-face-src><font-face-name name="HelveticaNeue-Medium"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="20" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-816.5001" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><marker orient="auto" overflow="visible" markerUnits="strokeWidth" id="FilledArrow_Marker" viewBox="-1 -3 6 6" markerWidth="6" markerHeight="6" color="#9a9a9a"><g><path d="M 3.7333333 0 L 0 -1.4 L 0 1.4 Z" fill="currentColor" stroke="currentColor" stroke-width="1"/></g></marker><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-859.4738" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="66" slope="0" x-height="450" cap-height="662" ascent="1055.00214" descent="-455.00092" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Regular"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="21" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-777.61913" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="Helvetica Neue" font-size="18" panose-1="2 11 6 4 2 2 2 2 2 4" units-per-em="1000" underline-position="-75" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="600"><font-face-src><font-face-name name="HelveticaNeue-Medium"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="18" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-907.2223" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="12" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-1360.8335" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="14" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-1166.4287" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face></defs><g stroke="none" stroke-opacity="1" stroke-dasharray="none" fill="none" fill-opacity="1"><title>神经网络 2</title><g><title>Layer 1</title><circle cx="170.05906" cy="507.30315" r="32.500052" fill="#bfeaff"/><circle cx="170.05906" cy="507.30315" r="32.500052" stroke="#5f747e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(146.05906 494.80315)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="600" x="10.48" y="20" textLength="27.04">s_t</tspan></text><circle cx="169.59842" cy="238.01181" r="32.500052" fill="#ffd6d8"/><circle cx="169.59842" cy="238.01181" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(145.598425 225.51181)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="600" x="17.52" y="20" textLength="12.96">Y</tspan></text><circle cx="169.59842" cy="367.57087" r="32.500052" fill="#d5c0ff"/><circle cx="169.59842" cy="367.57087" r="32.500052" stroke="#695f7e" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(145.598425 352.57087)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" fill="black" x="11.22" y="21" textLength="25.56">FC</tspan></text><line x1="169.95027" y1="474.30332" x2="169.75962" y2="416.47062" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="169.59843" y1="334.57085" x2="169.59843" y2="286.91183" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><circle cx="304.46063" cy="507.30315" r="32.500052" fill="#c2ffc4"/><circle cx="304.46063" cy="507.30315" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(280.46063 494.80315)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="600" x="10.11" y="20" textLength="27.78">a_t</tspan></text><text transform="translate(226.54725 151.48425)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="42.218">-logP</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="42.218" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="48.545" y="20" textLength="61.199">a_t | s_t</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="109.744" y="20" textLength="9.5">, </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="119.244" y="20" textLength="9.12">θ</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="128.364" y="20" textLength="6.327">)</tspan></text><text transform="translate(94.125985 292.63386)" fill="#262626"><tspan font-family="STIXGeneral" font-size="21" font-style="italic" font-weight="500" fill="#262626" x="0" y="22" textLength="67.683">Softmax</tspan></text><circle cx="430.8189" cy="507.30315" r="32.500052" fill="#c2ffc4"/><circle cx="430.8189" cy="507.30315" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(406.8189 493.80315)" fill="black"><tspan font-family="Helvetica Neue" font-size="18" font-weight="600" x=".672" y="19" textLength="17.676">R(</tspan><tspan font-family="STIXGeneral" font-size="18" font-style="italic" font-weight="500" fill="#262626" x="18.348" y="19" textLength="14.976">τ^</tspan><tspan font-family="STIXGeneral" font-size="18" font-style="italic" font-weight="500" fill="#262626" x="33.324" y="19" textLength="9">n</tspan><tspan font-family="Helvetica Neue" font-size="18" font-weight="600" x="42.324" y="19" textLength="5.004">)</tspan></text><circle cx="300.20866" cy="238.01181" r="32.500052" fill="#c2ffc4"/><circle cx="300.20866" cy="238.01181" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(276.20866 220.01181)" fill="black"><tspan font-family="STIXGeneral" font-size="12" font-style="italic" font-weight="500" x="9.996" y="13" textLength="31.008">Cross </tspan><tspan font-family="STIXGeneral" font-size="12" font-style="italic" font-weight="500" x="4.644" y="31" textLength="38.712">Entropy</tspan></text><line x1="303.93964" y1="474.30723" x2="300.98067" y2="286.90576" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="202.59844" y1="238.01183" x2="251.30865" y2="238.01183" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><circle cx="430.8189" cy="238.01181" r="32.500052" fill="#c2ffc4"/><circle cx="430.8189" cy="238.01181" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(406.8189 223.01181)" fill="black"><tspan font-family="STIXGeneral" font-size="20" font-style="italic" font-weight="500" x="7.89" y="21" textLength="32.22">Mul</tspan></text><line x1="333.20868" y1="238.01182" x2="381.91889" y2="238.01182" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="430.81892" y1="474.30314" x2="430.81892" y2="286.91183" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><text transform="translate(488.3937 223.51181)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="6.327">-</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="6.327" y="20" textLength="11.609">R</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="17.936" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="24.263" y="20" textLength="15.808">τ^</tspan><tspan font-family="STIXGeneral" font-size="14" font-style="italic" font-weight="500" fill="#262626" x="40.071" y="20" textLength="7">n</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="47.071" y="20" textLength="6.327">)</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="53.398" y="20" textLength="35.891">logP</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="89.289" y="20" textLength="6.327">(</tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="95.616" y="20" textLength="61.199">a_t | s_t</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="156.815" y="20" textLength="9.5">, </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="166.315" y="20" textLength="9.12">θ</tspan><tspan font-family="STIXGeneral" font-size="19" font-weight="500" fill="#262626" x="175.435" y="20" textLength="6.327">)</tspan></text><text transform="translate(131.72441 424.19292)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="9.12">θ</tspan></text></g></g></svg>
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xl="http://www.w3.org/1999/xlink" version="1.1" viewBox="51 454 689 160" width="689pt" height="160pt" xmlns:dc="http://purl.org/dc/elements/1.1/"><metadata> Produced by OmniGraffle 6.0.5 <dc:date>2017-12-01 09:42Z</dc:date></metadata><defs><font-face font-family="Helvetica Neue" font-size="20" panose-1="2 11 6 4 2 2 2 2 2 4" units-per-em="1000" underline-position="-75" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="600"><font-face-src><font-face-name name="HelveticaNeue-Medium"/></font-face-src></font-face><font-face font-family="Helvetica Neue" font-size="16" panose-1="2 11 6 4 2 2 2 2 2 4" units-per-em="1000" underline-position="-75" underline-thickness="50" slope="0" x-height="517" cap-height="714" ascent="975.0061" descent="-216.99524" font-weight="600"><font-face-src><font-face-name name="HelveticaNeue-Medium"/></font-face-src></font-face><font-face font-family="STIXGeneral" font-size="19" units-per-em="1000" underline-position="-227" underline-thickness="31" slope="-859.4738" x-height="428" cap-height="653" ascent="1055.00214" descent="-455.00092" font-style="italic" font-weight="500"><font-face-src><font-face-name name="STIXGeneral-Italic"/></font-face-src></font-face><marker orient="auto" overflow="visible" markerUnits="strokeWidth" id="FilledArrow_Marker" viewBox="-1 -3 6 6" markerWidth="6" markerHeight="6" color="#9a9a9a"><g><path d="M 3.7333333 0 L 0 -1.4 L 0 1.4 Z" fill="currentColor" stroke="currentColor" stroke-width="1"/></g></marker><font-face font-family="Helvetica" font-size="19" units-per-em="1000" underline-position="-75.683594" underline-thickness="49.316406" slope="0" x-height="522.94922" cap-height="717.28516" ascent="770.01953" descent="-229.98047" font-weight="500"><font-face-src><font-face-name name="Helvetica"/></font-face-src></font-face></defs><g stroke="none" stroke-opacity="1" stroke-dasharray="none" fill="none" fill-opacity="1"><title>神经网络 3</title><g><title>Layer 1</title><circle cx="695.8071" cy="498.79922" r="32.500052" fill="#ffd6d8"/><circle cx="695.8071" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(671.8071 486.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="20" font-weight="600" x="16.96" y="20" textLength="14.08">R</tspan></text><circle cx="232.22048" cy="498.79922" r="32.500052" fill="#c2ffc4"/><circle cx="232.22048" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(208.22048 489.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="16" font-weight="600" x="11.104" y="16" textLength="25.792">a_2</tspan></text><text transform="translate(581.49607 471.46063)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="60.325">= 0.9 * </tspan></text><line x1="652" y1="498" x2="569.0881" y2="498.50272" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><line x1="487.18897" y1="498.79923" x2="420.7819" y2="498.79923" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><text transform="translate(280.70866 487.29922)" fill="#262626"><tspan font-family="Helvetica" font-size="19" font-weight="500" fill="#262626" x="0" y="19" textLength="34.836426"></tspan></text><circle cx="94.862205" cy="498.79922" r="32.500052" fill="#c2ffc4"/><circle cx="94.862205" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(70.862205 489.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="16" font-weight="600" x="11.104" y="16" textLength="25.792">a_1</tspan></text><circle cx="371.8819" cy="498.79922" r="32.500052" fill="#c2ffc4"/><circle cx="371.8819" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(347.8819 489.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="16" font-weight="600" x=".88" y="16" textLength="46.24">a_(t-1)</tspan></text><circle cx="520.18898" cy="498.79922" r="32.500052" fill="#c2ffc4"/><circle cx="520.18898" cy="498.79922" r="32.500052" stroke="#5f7e69" stroke-linecap="round" stroke-linejoin="round" stroke-width="1"/><text transform="translate(496.18898 489.29922)" fill="black"><tspan font-family="Helvetica Neue" font-size="16" font-weight="600" x="12.888" y="16" textLength="22.224">a_t</tspan></text><line x1="199.22046" y1="498.79923" x2="143.76222" y2="498.79923" marker-end="url(#FilledArrow_Marker)" stroke="#9a9a9a" stroke-linecap="round" stroke-linejoin="round" stroke-width="3"/><text transform="translate(414.6063 469.12993)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="63.593">= 0.9^2 </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="63.593" y="20" textLength="9.5">*</tspan></text><text transform="translate(129.30709 468.79134)" fill="#262626"><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="0" y="20" textLength="59.375">= 0.9^t </tspan><tspan font-family="STIXGeneral" font-size="19" font-style="italic" font-weight="500" fill="#262626" x="59.375" y="20" textLength="9.5">*</tspan></text><text transform="translate(194.84252 576.43308)" fill="#262626"><tspan font-family="Helvetica" font-size="19" font-weight="500" fill="#262626" x="0" y="19" textLength="218.0918"> -= mean(a_1, a_2 … a_t)</tspan></text></g></g></svg>
...@@ -15,7 +15,7 @@ if __name__ == "__main__": ...@@ -15,7 +15,7 @@ if __name__ == "__main__":
done = False done = False
for epoch in range(epoches): for epoch in range(epoches):
if epoch % 100 == 1: if (epoch % 500 == 1) or epoch < 5 or epoch > 3000:
e.render = True e.render = True
else: else:
e.render = False e.render = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册