Commit bbcb707b (unverified)
Authored on Mar 03, 2020 by Hongsheng Zeng; committed via GitHub on Mar 03, 2020
Parent: 9216d941

torch benchmark policy gradient (#203)

* torch benchmark policy gradient
* refine comments and use native api
Showing 6 changed files with 294 additions and 0 deletions (+294, -0)
benchmark/torch/QuickStart/README.md          +28  -0
benchmark/torch/QuickStart/cartpole_agent.py  +79  -0
benchmark/torch/QuickStart/cartpole_model.py  +38  -0
benchmark/torch/QuickStart/train.py           +76  -0
parl/algorithms/torch/__init__.py             +1   -0
parl/algorithms/torch/policy_gradient.py      +72  -0
benchmark/torch/QuickStart/README.md (new file, mode 100644)
## PyTorch benchmark Quick Start
Train an agent with PARL to solve the CartPole problem, a classical benchmark in RL.
## How to use
### Dependencies:
+ [parl](https://github.com/PaddlePaddle/PARL)
+ torch
+ gym
### Start Training:
```
# Install dependencies
pip install torch torchvision gym
git clone https://github.com/PaddlePaddle/PARL.git
cd PARL
pip install .
# Train model
cd benchmark/torch/QuickStart
python train.py
```
### Expected Result
<img src="https://github.com/PaddlePaddle/PARL/blob/develop/examples/QuickStart/performance.gif" width="300" height="200" alt="result"/>
The agent can get around 200 points in a few minutes.
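To verify this after training, one option is to roll out a few greedy episodes and average the scores. The snippet below is only an illustrative sketch that reuses `env`, `agent`, and `run_episode` from `train.py` (added later in this commit); CartPole-v0 caps episodes at 200 steps, which is why the score plateaus around 200.

```
# Illustrative evaluation sketch; assumes `env`, `agent`, and `run_episode`
# from train.py are already constructed and the agent has been trained.
import numpy as np

scores = []
for _ in range(10):
    _, _, rewards = run_episode(env, agent, train_or_test='test')  # greedy actions
    scores.append(sum(rewards))
print("mean test reward over 10 episodes:", np.mean(scores))  # ~200 when solved
```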
benchmark/torch/QuickStart/cartpole_agent.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import torch
import numpy as np


class CartpoleAgent(parl.Agent):
    """Agent of Cartpole env.

    Args:
        algorithm(parl.Algorithm): algorithm used to solve the problem.
    """

    def __init__(self, algorithm):
        self.algorithm = algorithm
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    def sample(self, obs):
        """Sample an action when given an observation

        Args:
            obs(np.float32): shape of (obs_dim,)

        Returns:
            action(int)
        """
        obs = torch.tensor(obs, device=self.device, dtype=torch.float)
        prob = self.algorithm.predict(obs)
        prob = prob.data.numpy()
        action = np.random.choice(len(prob), 1, p=prob)[0]
        return action

    def predict(self, obs):
        """Predict an action when given an observation

        Args:
            obs(np.float32): shape of (obs_dim,)

        Returns:
            action(int)
        """
        obs = torch.tensor(obs, device=self.device, dtype=torch.float)
        prob = self.algorithm.predict(obs)
        _, action = prob.max(-1)
        return action.item()

    def learn(self, obs, action, reward):
        """Update model with an episode data

        Args:
            obs(np.float32): shape of (batch_size, obs_dim)
            action(np.int64): shape of (batch_size)
            reward(np.float32): shape of (batch_size)

        Returns:
            loss(float)
        """
        obs = torch.tensor(obs, device=self.device, dtype=torch.float)
        action = torch.tensor(action, device=self.device, dtype=torch.long)
        reward = torch.tensor(reward, device=self.device, dtype=torch.float)
        loss = self.algorithm.learn(obs, action, reward)
        return loss.item()
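The agent exposes two action-selection methods: `sample` draws a stochastic action from the policy's probability distribution (used during training for exploration), while `predict` takes the argmax (used for evaluation). A minimal sketch of the difference, assuming the files from this commit are importable and a CPU-only machine (sample() calls .numpy() on the probability tensor, which requires it to live on the CPU):

```
# Minimal sketch, not part of the commit: contrast sample() vs predict().
import numpy as np
import parl
from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent

model = CartpoleModel(obs_dim=4, act_dim=2)
agent = CartpoleAgent(parl.algorithms.PolicyGradient(model, 1e-3))

obs = np.zeros(4, dtype=np.float32)   # dummy CartPole observation
print(agent.sample(obs))    # stochastic: repeated calls may return 0 or 1
print(agent.predict(obs))   # greedy: always the highest-probability action
```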
benchmark/torch/QuickStart/cartpole_model.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import parl


class CartpoleModel(parl.Model):
    """ Linear network to solve Cartpole problem.

    Args:
        obs_dim (int): Dimension of observation space.
        act_dim (int): Dimension of action space.
    """

    def __init__(self, obs_dim, act_dim):
        super(CartpoleModel, self).__init__()
        hid1_size = act_dim * 10
        self.fc1 = nn.Linear(obs_dim, hid1_size)
        self.fc2 = nn.Linear(hid1_size, act_dim)

    def forward(self, x):
        out = torch.tanh(self.fc1(x))
        prob = F.softmax(self.fc2(out), dim=-1)
        return prob
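The model is a two-layer MLP: a tanh hidden layer of size act_dim * 10 followed by a softmax output, so `forward` returns a probability distribution over actions rather than raw logits. A quick shape check, as an illustrative sketch:

```
# Illustrative sketch, not part of the commit: the forward pass returns
# action probabilities that sum to 1.
import torch
from cartpole_model import CartpoleModel

model = CartpoleModel(obs_dim=4, act_dim=2)   # hidden size = act_dim * 10 = 20
prob = model(torch.zeros(4))                  # tensor of shape (2,)
print(prob, prob.sum())                       # two probabilities summing to 1
```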
benchmark/torch/QuickStart/train.py (new file, mode 100644)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import numpy as np
import parl
from parl.utils import logger

from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent

OBS_DIM = 4
ACT_DIM = 2
LEARNING_RATE = 1e-3


def run_episode(env, agent, train_or_test='train'):
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        if train_or_test == 'train':
            action = agent.sample(obs)
        else:
            action = agent.predict(obs)
        action_list.append(action)

        obs, reward, done, _ = env.step(action)
        reward_list.append(reward)

        if done:
            break
    return obs_list, action_list, reward_list


def calc_reward_to_go(reward_list):
    for i in range(len(reward_list) - 2, -1, -1):
        reward_list[i] += reward_list[i + 1]
    return np.array(reward_list)


def main():
    env = gym.make('CartPole-v0')
    model = CartpoleModel(obs_dim=OBS_DIM, act_dim=ACT_DIM)
    alg = parl.algorithms.PolicyGradient(model, LEARNING_RATE)
    agent = CartpoleAgent(alg)

    for i in range(1000):  # 1000 episodes
        obs_list, action_list, reward_list = run_episode(env, agent)
        if i % 10 == 0:
            logger.info("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))

        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_reward_to_go(reward_list)

        agent.learn(batch_obs, batch_action, batch_reward)
        if (i + 1) % 100 == 0:
            _, _, reward_list = run_episode(env, agent, train_or_test='test')
            total_reward = np.sum(reward_list)
            logger.info('Test reward: {}'.format(total_reward))


if __name__ == '__main__':
    main()
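A note on `calc_reward_to_go`: it accumulates rewards from the end of the episode backwards, so each timestep is weighted by the undiscounted total reward collected from that step onward. A small worked example, assuming `train.py` from this commit is importable:

```
# Worked example of calc_reward_to_go (undiscounted return-to-go).
from train import calc_reward_to_go

print(calc_reward_to_go([1.0, 1.0, 1.0, 1.0]))   # -> [4. 3. 2. 1.]
```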
parl/algorithms/torch/__init__.py
@@ -16,3 +16,4 @@ from parl.algorithms.torch.ddqn import *
from parl.algorithms.torch.dqn import *
from parl.algorithms.torch.a2c import *
from parl.algorithms.torch.td3 import *
from parl.algorithms.torch.policy_gradient import *
parl/algorithms/torch/policy_gradient.py (new file, mode 100644)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.optim as optim
import parl
from torch.distributions import Categorical

__all__ = ['PolicyGradient']


class PolicyGradient(parl.Algorithm):
    def __init__(self, model, lr):
        """Policy gradient algorithm

        Args:
            model (parl.Model): model defining forward network of policy.
            lr (float): learning rate.
        """
        assert isinstance(lr, float)

        self.model = model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def predict(self, obs):
        """Predict the probability of actions

        Args:
            obs (torch.tensor): shape of (obs_dim,)

        Returns:
            prob (torch.tensor): shape of (action_dim,)
        """
        prob = self.model(obs)
        return prob

    def learn(self, obs, action, reward):
        """Update model with policy gradient algorithm

        Args:
            obs (torch.tensor): shape of (batch_size, obs_dim)
            action (torch.tensor): shape of (batch_size, 1)
            reward (torch.tensor): shape of (batch_size, 1)

        Returns:
            loss (torch.tensor): shape of (1)
        """
        prob = self.model(obs)
        log_prob = Categorical(prob).log_prob(action)
        loss = torch.mean(-1 * log_prob * reward)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss
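The `learn` method is the REINFORCE update: it minimizes `mean(-log_prob * reward)`, i.e. it increases the log-probability of each taken action in proportion to the return-to-go supplied by `train.py`. A minimal sketch of a single update on dummy data, assuming the classes from this commit and a CPU-only machine (the algorithm moves the model to CUDA when available but does not move the input tensors):

```
# Minimal sketch, not part of the commit: one PolicyGradient.learn() step.
import torch
from cartpole_model import CartpoleModel
from parl.algorithms.torch.policy_gradient import PolicyGradient

alg = PolicyGradient(CartpoleModel(obs_dim=4, act_dim=2), lr=1e-3)

obs = torch.zeros(5, 4)                    # batch of 5 dummy observations
action = torch.zeros(5, dtype=torch.long)  # action 0 at every step
reward = torch.ones(5)                     # return-to-go for each step

loss = alg.learn(obs, action, reward)      # one Adam step on mean(-log_prob * reward)
print(loss.item())
```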