PaddlePaddle / PARL
Unverified commit a9159021, authored on Jun 10, 2020 by rical730, committed via GitHub on Jun 10, 2020.
Parent: 533d4b2c

upgrade DQN's lr interface compatibility (#291)

* upgrade DQN's lr interface compatibility
* yapf
* update example DQN
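In short: the learning rate can now be given once when the algorithm is constructed, and learn() calls no longer have to carry it, while the old per-call learning_rate argument is still accepted. A minimal sketch of both call patterns, using the names from the CartPole example changed below (hyperparameter values and the import path are illustrative assumptions, not part of this commit):

    import parl
    from cartpole_model import CartpoleModel   # model from examples/DQN (import path assumed)

    LEARNING_RATE = 1e-3   # illustrative
    GAMMA = 0.99           # illustrative
    action_dim = 2         # CartPole has two discrete actions

    model = CartpoleModel(act_dim=action_dim)

    # New interface: pass lr once at construction; learn() can then omit it.
    algorithm = parl.algorithms.DQN(
        model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)

    # Old interface still works: construct without lr and keep passing
    # learning_rate to every learn() call, as the example did before this commit.
    # algorithm = parl.algorithms.DQN(model, act_dim=action_dim, gamma=GAMMA)
    # cost = algorithm.learn(obs, action, reward, next_obs, terminal,
    #                        learning_rate=LEARNING_RATE)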
Showing 5 changed files with 47 additions and 21 deletions (+47, -21)
examples/DQN/cartpole_agent.py    +2  -6
examples/DQN/train.py             +3  -3
parl/algorithms/fluid/ddqn.py     +25 -9
parl/algorithms/fluid/dqn.py      +16 -2
parl/utils/utils.py               +1  -1
examples/DQN/cartpole_agent.py

@@ -54,10 +54,7 @@ class CartpoleAgent(parl.Agent):
             next_obs = layers.data(
                 name='next_obs', shape=[self.obs_dim], dtype='float32')
             terminal = layers.data(name='terminal', shape=[], dtype='bool')
-            lr = layers.data(
-                name='lr', shape=[1], dtype='float32', append_batch_size=False)
-            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal,
-                                       lr)
+            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)

     def sample(self, obs):
         sample = np.random.rand()
@@ -78,7 +75,7 @@ class CartpoleAgent(parl.Agent):
         act = np.argmax(pred_Q)
         return act

-    def learn(self, obs, act, reward, next_obs, terminal, lr):
+    def learn(self, obs, act, reward, next_obs, terminal):
         if self.global_step % self.update_target_steps == 0:
             self.alg.sync_target()
         self.global_step += 1
@@ -90,7 +87,6 @@ class CartpoleAgent(parl.Agent):
             'reward': reward,
             'next_obs': next_obs.astype('float32'),
             'terminal': terminal,
-            'lr': np.float32([lr]),
         }
         cost = self.fluid_executor.run(
             self.learn_program, feed=feed, fetch_list=[self.cost])[0]
examples/DQN/train.py

@@ -45,8 +45,7 @@ def run_episode(agent, env, rpm):
             (batch_obs, batch_action, batch_reward, batch_next_obs,
              batch_isOver) = rpm.sample(BATCH_SIZE)
             train_loss = agent.learn(batch_obs, batch_action, batch_reward,
-                                     batch_next_obs, batch_isOver,
-                                     LEARNING_RATE)
+                                     batch_next_obs, batch_isOver)

         total_reward += reward
         obs = next_obs
@@ -80,7 +79,8 @@ def main():
     rpm = ReplayMemory(MEMORY_SIZE)

     model = CartpoleModel(act_dim=action_dim)
-    algorithm = parl.algorithms.DQN(model, act_dim=action_dim, gamma=GAMMA)
+    algorithm = parl.algorithms.DQN(
+        model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)
     agent = CartpoleAgent(
         algorithm,
         obs_dim=obs_shape[0],
parl/algorithms/fluid/ddqn.py

@@ -21,19 +21,17 @@ import paddle.fluid as fluid
 from parl.core.fluid.algorithm import Algorithm
 from parl.core.fluid import layers

 __all__ = ['DDQN']


 class DDQN(Algorithm):
-    def __init__(self, model, act_dim=None, gamma=None,):
+    def __init__(self, model, act_dim=None, gamma=None, lr=None):
         """ Double DQN algorithm

         Args:
-            model (parl.Model): model defining forward network of Q function.
+            model (parl.Model): model defining forward network of Q function
             act_dim (int): dimension of the action space
             gamma (float): discounted factor for reward computation.
+            lr (float): learning rate.
         """
         self.model = model
         self.target_model = copy.deepcopy(model)
@@ -43,11 +41,29 @@ class DDQN(Algorithm):
         self.act_dim = act_dim
         self.gamma = gamma
+        self.lr = lr

     def predict(self, obs):
         """ use value model self.model to predict the action value
         """
         return self.model.value(obs)

-    def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
+    def learn(self, obs, action, reward, next_obs, terminal,
+              learning_rate=None):
         """ update value model self.model with DQN algorithm
         """
+        # Support the modification of learning_rate
+        if learning_rate is None:
+            assert isinstance(
+                self.lr,
+                float), "Please set the learning rate of DQN in initializaion."
+            learning_rate = self.lr
+
         pred_value = self.model.value(obs)
         action_onehot = layers.one_hot(action, self.act_dim)
         action_onehot = layers.cast(action_onehot, dtype='float32')
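Both DDQN.learn above and DQN.learn in the next file resolve the learning rate the same way: a value passed to learn() wins, otherwise the lr stored at construction is used, and an assertion fires if neither was ever provided. Distilled into a standalone sketch (the helper name is ours, not part of the commit):

    def resolve_learning_rate(call_value, init_value):
        # A per-call learning_rate keeps the old interface working.
        if call_value is not None:
            return call_value
        # Otherwise fall back to the lr given to __init__; fail loudly if unset.
        assert isinstance(init_value, float), \
            "Please set the learning rate of DQN in initialization."
        return init_value

    resolve_learning_rate(None, 1e-3)    # -> 0.001 (constructor value)
    resolve_learning_rate(5e-4, 1e-3)    # -> 0.0005 (per-call override)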
parl/algorithms/fluid/dqn.py

@@ -24,7 +24,7 @@ __all__ = ['DQN']


 class DQN(Algorithm):
-    def __init__(self, model, act_dim=None, gamma=None):
+    def __init__(self, model, act_dim=None, gamma=None, lr=None):
         """ DQN algorithm

         Args:
@@ -38,17 +38,31 @@ class DQN(Algorithm):
         assert isinstance(act_dim, int)
         assert isinstance(gamma, float)

         self.act_dim = act_dim
         self.gamma = gamma
+        self.lr = lr

     def predict(self, obs):
         """ use value model self.model to predict the action value
         """
         return self.model.value(obs)

-    def learn(self, obs, action, reward, next_obs, terminal, learning_rate):
+    def learn(self, obs, action, reward, next_obs, terminal,
+              learning_rate=None):
         """ update value model self.model with DQN algorithm
         """
+        # Support the modification of learning_rate
+        if learning_rate is None:
+            assert isinstance(
+                self.lr,
+                float), "Please set the learning rate of DQN in initializaion."
+            learning_rate = self.lr
+
         pred_value = self.model.value(obs)
         next_pred_value = self.target_model.value(next_obs)
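For context on why both interfaces can coexist: fluid optimizers accept learning_rate either as a Python float or as a Variable, so the code before this commit could feed an 'lr' tensor every step, while the new code simply bakes the resolved float into the program when the learn program is built. A self-contained toy sketch of the float case under paddle 1.6+ (the network and values are illustrative, not PARL code):

    import numpy as np
    import paddle.fluid as fluid

    main_prog, startup_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        x = fluid.layers.data(name='x', shape=[4], dtype='float32')
        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
        pred = fluid.layers.fc(input=x, size=1)
        cost = fluid.layers.reduce_mean(fluid.layers.square_error_cost(pred, y))
        # A plain float is enough; no per-step 'lr' feed variable is required.
        fluid.optimizer.Adam(learning_rate=1e-3).minimize(cost)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)
    loss, = exe.run(main_prog,
                    feed={'x': np.random.rand(8, 4).astype('float32'),
                          'y': np.random.rand(8, 1).astype('float32')},
                    fetch_list=[cost])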
parl/utils/utils.py

@@ -89,7 +89,7 @@ MAX_INT32 = 0x7fffffff
 try:
     from paddle import fluid
     fluid_version = get_fluid_version()
-    assert fluid_version >= 161, "PARL requires paddle>=1.6.1"
+    assert fluid_version >= 161 or fluid_version == 0, "PARL requires paddle>=1.6.1"
     _HAS_FLUID = True
 except ImportError:
     _HAS_FLUID = False
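The extra "or fluid_version == 0" presumably whitelists develop builds of paddle, whose version string reduces to 0 once the dots are stripped. get_fluid_version is defined elsewhere in utils.py and is not shown in this diff; it is assumed to work roughly like this:

    def get_fluid_version():
        import paddle
        # '1.6.1' -> 161, '1.8.0' -> 180, '0.0.0' (develop build) -> 0
        return int(paddle.__version__.replace('.', ''))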