Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
87e68119
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
68
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
87e68119
编写于
5月 07, 2020
作者:
L
LI Yunxiang
提交者:
GitHub
5月 07, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update ddpg (#260)
* update ddpg * Update train.py
上级
e5e1685a
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
18 addition
and
29 deletion
+18
-29
examples/DDPG/mujoco_agent.py
examples/DDPG/mujoco_agent.py
+1
-0
examples/DDPG/mujoco_model.py
examples/DDPG/mujoco_model.py
+0
-1
examples/DDPG/train.py
examples/DDPG/train.py
+17
-28
未找到文件。
examples/DDPG/mujoco_agent.py
浏览文件 @
87e68119
...
@@ -55,6 +55,7 @@ class MujocoAgent(parl.Agent):
...
@@ -55,6 +55,7 @@ class MujocoAgent(parl.Agent):
act
=
self
.
fluid_executor
.
run
(
act
=
self
.
fluid_executor
.
run
(
self
.
pred_program
,
feed
=
{
'obs'
:
obs
},
self
.
pred_program
,
feed
=
{
'obs'
:
obs
},
fetch_list
=
[
self
.
pred_act
])[
0
]
fetch_list
=
[
self
.
pred_act
])[
0
]
act
=
np
.
squeeze
(
act
)
return
act
return
act
def
learn
(
self
,
obs
,
act
,
reward
,
next_obs
,
terminal
):
def
learn
(
self
,
obs
,
act
,
reward
,
next_obs
,
terminal
):
...
...
examples/DDPG/mujoco_model.py
浏览文件 @
87e68119
...
@@ -45,7 +45,6 @@ class ActorModel(parl.Model):
...
@@ -45,7 +45,6 @@ class ActorModel(parl.Model):
hid1
=
self
.
fc1
(
obs
)
hid1
=
self
.
fc1
(
obs
)
hid2
=
self
.
fc2
(
hid1
)
hid2
=
self
.
fc2
(
hid1
)
means
=
self
.
fc3
(
hid2
)
means
=
self
.
fc3
(
hid2
)
means
=
means
return
means
return
means
...
...
examples/DDPG/train.py
浏览文件 @
87e68119
...
@@ -21,14 +21,12 @@ from mujoco_agent import MujocoAgent
...
@@ -21,14 +21,12 @@ from mujoco_agent import MujocoAgent
from
mujoco_model
import
MujocoModel
from
mujoco_model
import
MujocoModel
from
parl.utils
import
logger
,
action_mapping
,
ReplayMemory
from
parl.utils
import
logger
,
action_mapping
,
ReplayMemory
MAX_EPISODES
=
5000
TEST_EVERY_EPISODES
=
20
ACTOR_LR
=
1e-4
ACTOR_LR
=
1e-4
CRITIC_LR
=
1e-3
CRITIC_LR
=
1e-3
GAMMA
=
0.99
GAMMA
=
0.99
TAU
=
0.001
TAU
=
0.001
MEMORY_SIZE
=
int
(
1e6
)
MEMORY_SIZE
=
int
(
1e6
)
M
IN_LEARN
_SIZE
=
1e4
M
EMORY_WARMUP
_SIZE
=
1e4
BATCH_SIZE
=
128
BATCH_SIZE
=
128
REWARD_SCALE
=
0.1
REWARD_SCALE
=
0.1
ENV_SEED
=
1
ENV_SEED
=
1
...
@@ -37,12 +35,9 @@ ENV_SEED = 1
...
@@ -37,12 +35,9 @@ ENV_SEED = 1
def
run_train_episode
(
env
,
agent
,
rpm
):
def
run_train_episode
(
env
,
agent
,
rpm
):
obs
=
env
.
reset
()
obs
=
env
.
reset
()
total_reward
=
0
total_reward
=
0
steps
=
0
while
True
:
while
True
:
steps
+=
1
batch_obs
=
np
.
expand_dims
(
obs
,
axis
=
0
)
batch_obs
=
np
.
expand_dims
(
obs
,
axis
=
0
)
action
=
agent
.
predict
(
batch_obs
.
astype
(
'float32'
))
action
=
agent
.
predict
(
batch_obs
.
astype
(
'float32'
))
action
=
np
.
squeeze
(
action
)
# Add exploration noise, and clip to [-1.0, 1.0]
# Add exploration noise, and clip to [-1.0, 1.0]
action
=
np
.
clip
(
np
.
random
.
normal
(
action
,
1.0
),
-
1.0
,
1.0
)
action
=
np
.
clip
(
np
.
random
.
normal
(
action
,
1.0
),
-
1.0
,
1.0
)
...
@@ -53,7 +48,7 @@ def run_train_episode(env, agent, rpm):
...
@@ -53,7 +48,7 @@ def run_train_episode(env, agent, rpm):
rpm
.
append
(
obs
,
action
,
REWARD_SCALE
*
reward
,
next_obs
,
done
)
rpm
.
append
(
obs
,
action
,
REWARD_SCALE
*
reward
,
next_obs
,
done
)
if
rpm
.
size
()
>
M
IN_LEARN
_SIZE
:
if
rpm
.
size
()
>
M
EMORY_WARMUP
_SIZE
:
batch_obs
,
batch_action
,
batch_reward
,
batch_next_obs
,
batch_terminal
=
rpm
.
sample_batch
(
batch_obs
,
batch_action
,
batch_reward
,
batch_next_obs
,
batch_terminal
=
rpm
.
sample_batch
(
BATCH_SIZE
)
BATCH_SIZE
)
agent
.
learn
(
batch_obs
,
batch_action
,
batch_reward
,
batch_next_obs
,
agent
.
learn
(
batch_obs
,
batch_action
,
batch_reward
,
batch_next_obs
,
...
@@ -64,7 +59,7 @@ def run_train_episode(env, agent, rpm):
...
@@ -64,7 +59,7 @@ def run_train_episode(env, agent, rpm):
if
done
:
if
done
:
break
break
return
total_reward
,
steps
return
total_reward
def
run_evaluate_episode
(
env
,
agent
):
def
run_evaluate_episode
(
env
,
agent
):
...
@@ -73,7 +68,6 @@ def run_evaluate_episode(env, agent):
...
@@ -73,7 +68,6 @@ def run_evaluate_episode(env, agent):
while
True
:
while
True
:
batch_obs
=
np
.
expand_dims
(
obs
,
axis
=
0
)
batch_obs
=
np
.
expand_dims
(
obs
,
axis
=
0
)
action
=
agent
.
predict
(
batch_obs
.
astype
(
'float32'
))
action
=
agent
.
predict
(
batch_obs
.
astype
(
'float32'
))
action
=
np
.
squeeze
(
action
)
action
=
action_mapping
(
action
,
env
.
action_space
.
low
[
0
],
action
=
action_mapping
(
action
,
env
.
action_space
.
low
[
0
],
env
.
action_space
.
high
[
0
])
env
.
action_space
.
high
[
0
])
...
@@ -101,19 +95,19 @@ def main():
...
@@ -101,19 +95,19 @@ def main():
rpm
=
ReplayMemory
(
MEMORY_SIZE
,
obs_dim
,
act_dim
)
rpm
=
ReplayMemory
(
MEMORY_SIZE
,
obs_dim
,
act_dim
)
test_flag
=
0
while
rpm
.
size
()
<
MEMORY_WARMUP_SIZE
:
total_steps
=
0
run_train_episode
(
env
,
agent
,
rpm
)
while
total_steps
<
args
.
train_total_steps
:
train_reward
,
steps
=
run_train_episode
(
env
,
agent
,
rpm
)
episode
=
0
total_steps
+=
steps
while
episode
<
args
.
train_total_episode
:
logger
.
info
(
'Steps: {} Reward: {}'
.
format
(
total_steps
,
train_reward
))
for
i
in
range
(
50
):
train_reward
=
run_train_episode
(
env
,
agent
,
rpm
)
episode
+=
1
logger
.
info
(
'Episode: {} Reward: {}'
.
format
(
episode
,
train_reward
))
if
total_steps
//
args
.
test_every_steps
>=
test_flag
:
while
total_steps
//
args
.
test_every_steps
>=
test_flag
:
test_flag
+=
1
evaluate_reward
=
run_evaluate_episode
(
env
,
agent
)
evaluate_reward
=
run_evaluate_episode
(
env
,
agent
)
logger
.
info
(
'Steps
{}, Evaluate reward: {}'
.
format
(
logger
.
info
(
'Episode
{}, Evaluate reward: {}'
.
format
(
total_steps
,
evaluate_reward
))
episode
,
evaluate_reward
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
@@ -121,15 +115,10 @@ if __name__ == '__main__':
...
@@ -121,15 +115,10 @@ if __name__ == '__main__':
parser
.
add_argument
(
parser
.
add_argument
(
'--env'
,
help
=
'Mujoco environment name'
,
default
=
'HalfCheetah-v2'
)
'--env'
,
help
=
'Mujoco environment name'
,
default
=
'HalfCheetah-v2'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--train_total_steps'
,
'--train_total_episode'
,
type
=
int
,
default
=
int
(
1e7
),
help
=
'maximum training steps'
)
parser
.
add_argument
(
'--test_every_steps'
,
type
=
int
,
type
=
int
,
default
=
int
(
1e4
),
default
=
int
(
1e4
),
help
=
'
the step interval between two consecutive evaluation
s'
)
help
=
'
maximum training episode
s'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录