Commit c070db83 authored by LI Yunxiang, committed by Hongsheng Zeng

add sac (#188)

* add sac
Parent 5054efed
# requirements for unittest
paddlepaddle-gpu==1.5.1.post97
paddlepaddle-gpu==1.6.1.post97
gym
details
parameterized
......
......@@ -76,8 +76,10 @@ pip install parl
- [PPO](examples/PPO/)
- [IMPALA](examples/IMPALA/)
- [A2C](examples/A2C/)
- [GA3C](examples/GA3C/)
- [TD3](examples/TD3/)
- [SAC](examples/SAC/)
- [冠军解决方案:NIPS2018强化学习假肢挑战赛](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
- [冠军解决方案:NIPS2019强化学习仿生人控制赛事](examples/NeurIPS2019-Learn-to-Move-Challenge/)
<img src=".github/NeurlIPS2018.gif" width = "300" height ="200" alt="NeurlIPS2018"/> <img src=".github/Half-Cheetah.gif" width = "300" height ="200" alt="Half-Cheetah"/> <img src=".github/Breakout.gif" width = "200" height ="200" alt="Breakout"/>
<br>
......
......@@ -79,8 +79,10 @@ pip install parl
- [PPO](examples/PPO/)
- [IMPALA](examples/IMPALA/)
- [A2C](examples/A2C/)
- [GA3C](examples/GA3C/)
- [TD3](examples/TD3/)
- [SAC](examples/SAC/)
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
- [Winning Solution for NIPS2019: Learn to Move Challenge](examples/NeurIPS2019-Learn-to-Move-Challenge/)
<img src=".github/NeurlIPS2018.gif" width = "300" height ="200" alt="NeurlIPS2018"/> <img src=".github/Half-Cheetah.gif" width = "300" height ="200" alt="Half-Cheetah"/> <img src=".github/Breakout.gif" width = "200" height ="200" alt="Breakout"/>
<br>
......
## Reproduce SAC with PARL
Based on PARL, we reproduce the SAC algorithm of deep reinforcement learning, matching the results reported in the paper on the Mujoco benchmarks.
It includes the following approaches:
+ DDPG-style training with a stochastic policy
+ Maximum entropy (the objective is sketched below)
> SAC in
[Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor](https://arxiv.org/abs/1801.01290)
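
For reference, a sketch of the maximum-entropy objective SAC optimizes, following the paper (`alpha` is the temperature, exposed as `--alpha` in `train.py`):
```
J(\pi) = \sum_t \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}
         \left[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right]
```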
### Mujoco games introduction
Please see [here](https://github.com/openai/mujoco-py) to learn more about Mujoco games.
### Benchmark result
<img src=".benchmark/merge.png" width = "1500" height ="260" alt="Performance" />
## How to use
### Dependencies:
+ python3.5+
+ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
+ mujoco-py>=1.50.1.0
### Start Training:
```
# To train an agent for HalfCheetah-v2 game
python train.py
# To train for different games
# python train.py --env [ENV_NAME]
```
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl
from parl import layers
from paddle import fluid
class MujocoAgent(parl.Agent):
def __init__(self, algorithm, obs_dim, act_dim):
assert isinstance(obs_dim, int)
assert isinstance(act_dim, int)
self.obs_dim = obs_dim
self.act_dim = act_dim
super(MujocoAgent, self).__init__(algorithm)
# Note: fully synchronize the target model at the start (decay=0 copies the weights).
self.alg.sync_target(decay=0)
def build_program(self):
self.pred_program = fluid.Program()
self.sample_program = fluid.Program()
self.learn_program = fluid.Program()
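# The agent builds three fluid programs on top of the same algorithm:
#   pred_program   -- deterministic action (used by predict() for evaluation)
#   sample_program -- stochastic action from the policy (used by sample() for exploration)
#   learn_program  -- one SAC update step (used by learn(), returns critic/actor costs)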
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
self.pred_act = self.alg.predict(obs)
with fluid.program_guard(self.sample_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
self.sample_act, _ = self.alg.sample(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
act = layers.data(
name='act', shape=[self.act_dim], dtype='float32')
reward = layers.data(name='reward', shape=[], dtype='float32')
next_obs = layers.data(
name='next_obs', shape=[self.obs_dim], dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
self.critic_cost, self.actor_cost = self.alg.learn(
obs, act, reward, next_obs, terminal)
def predict(self, obs):
obs = np.expand_dims(obs, axis=0)
act = self.fluid_executor.run(
self.pred_program, feed={'obs': obs},
fetch_list=[self.pred_act])[0]
return act
def sample(self, obs):
obs = np.expand_dims(obs, axis=0)
act = self.fluid_executor.run(
self.sample_program,
feed={'obs': obs},
fetch_list=[self.sample_act])[0]
return act
def learn(self, obs, act, reward, next_obs, terminal):
feed = {
'obs': obs,
'act': act,
'reward': reward,
'next_obs': next_obs,
'terminal': terminal
}
[critic_cost, actor_cost] = self.fluid_executor.run(
self.learn_program,
feed=feed,
fetch_list=[self.critic_cost, self.actor_cost])
self.alg.sync_target()
return critic_cost[0], actor_cost[0]
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl
from parl import layers
LOG_SIG_MAX = 2.0
LOG_SIG_MIN = -20.0
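# The actor outputs the mean and log-std of a diagonal Gaussian; the log-std is
# clipped to [LOG_SIG_MIN, LOG_SIG_MAX] for numerical stability, as in common SAC implementations.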
class ActorModel(parl.Model):
def __init__(self, act_dim):
hid1_size = 400
hid2_size = 300
self.fc1 = layers.fc(size=hid1_size, act='relu')
self.fc2 = layers.fc(size=hid2_size, act='relu')
self.mean_linear = layers.fc(size=act_dim)
self.log_std_linear = layers.fc(size=act_dim)
def policy(self, obs):
hid1 = self.fc1(obs)
hid2 = self.fc2(hid1)
means = self.mean_linear(hid2)
log_std = self.log_std_linear(hid2)
log_std = layers.clip(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
return means, log_std
class CriticModel(parl.Model):
def __init__(self):
hid1_size = 400
hid2_size = 300
self.fc1 = layers.fc(size=hid1_size, act='relu')
self.fc2 = layers.fc(size=hid2_size, act='relu')
self.fc3 = layers.fc(size=1, act=None)
self.fc4 = layers.fc(size=hid1_size, act='relu')
self.fc5 = layers.fc(size=hid2_size, act='relu')
self.fc6 = layers.fc(size=1, act=None)
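# Two independent Q-networks (fc1-fc3 -> Q1, fc4-fc6 -> Q2); SAC takes the minimum
# of the two estimates (clipped double-Q) to reduce overestimation.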
def value(self, obs, act):
hid1 = self.fc1(obs)
concat1 = layers.concat([hid1, act], axis=1)
Q1 = self.fc2(concat1)
Q1 = self.fc3(Q1)
Q1 = layers.squeeze(Q1, axes=[1])
hid2 = self.fc4(obs)
concat2 = layers.concat([hid2, act], axis=1)
Q2 = self.fc5(concat2)
Q2 = self.fc6(Q2)
Q2 = layers.squeeze(Q2, axes=[1])
return Q1, Q2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Refer to https://github.com/pranz24/pytorch-soft-actor-critic
import argparse
import gym
import numpy as np
import time
import parl
from mujoco_agent import MujocoAgent
from mujoco_model import ActorModel, CriticModel
from parl.utils import logger, tensorboard, action_mapping, ReplayMemory
ACTOR_LR = 1e-3
CRITIC_LR = 1e-3
GAMMA = 0.99
TAU = 0.005
MEMORY_SIZE = int(1e6)
WARMUP_SIZE = 1e4
BATCH_SIZE = 256
ENV_SEED = 1
def run_train_episode(env, agent, rpm):
obs = env.reset()
total_reward = 0
steps = 0
while True:
steps += 1
batch_obs = np.expand_dims(obs, axis=0)
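# Warm-up phase: act randomly until the replay memory holds WARMUP_SIZE
# transitions, then switch to sampling stochastic actions from the policy.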
if rpm.size() < WARMUP_SIZE:
action = env.action_space.sample()
else:
action = agent.sample(batch_obs.astype('float32'))
action = np.squeeze(action)
next_obs, reward, done, info = env.step(action)
rpm.append(obs, action, reward, next_obs, done)
if rpm.size() > WARMUP_SIZE:
batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = rpm.sample_batch(
BATCH_SIZE)
agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs,
batch_terminal)
obs = next_obs
total_reward += reward
if done:
break
return total_reward, steps
def run_evaluate_episode(env, agent):
obs = env.reset()
total_reward = 0
while True:
batch_obs = np.expand_dims(obs, axis=0)
action = agent.predict(batch_obs.astype('float32'))
action = np.squeeze(action)
next_obs, reward, done, info = env.step(action)
obs = next_obs
total_reward += reward
if done:
break
return total_reward
def main():
env = gym.make(args.env)
env.seed(ENV_SEED)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
actor = ActorModel(act_dim)
critic = CriticModel()
algorithm = parl.algorithms.SAC(
actor,
critic,
max_action=max_action,
alpha=args.alpha,
gamma=GAMMA,
tau=TAU,
actor_lr=ACTOR_LR,
critic_lr=CRITIC_LR)
agent = MujocoAgent(algorithm, obs_dim, act_dim)
rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim)
test_flag = 0
total_steps = 0
while total_steps < args.train_total_steps:
train_reward, steps = run_train_episode(env, agent, rpm)
total_steps += steps
logger.info('Steps: {} Reward: {}'.format(total_steps, train_reward))
tensorboard.add_scalar('train/episode_reward', train_reward,
total_steps)
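# Run an evaluation whenever total_steps crosses a multiple of test_every_steps;
# the inner loop catches up if one training episode spanned several intervals.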
if total_steps // args.test_every_steps >= test_flag:
while total_steps // args.test_every_steps >= test_flag:
test_flag += 1
evaluate_reward = run_evaluate_episode(env, agent)
logger.info('Steps {}, Evaluate reward: {}'.format(
total_steps, evaluate_reward))
tensorboard.add_scalar('eval/episode_reward', evaluate_reward,
total_steps)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--env', help='Mujoco environment name', default='HalfCheetah-v2')
parser.add_argument(
'--train_total_steps',
type=int,
default=int(1e6),
help='maximum training steps')
parser.add_argument(
'--test_every_steps',
type=int,
default=int(1e4),
help='the step interval between two consecutive evaluations')
parser.add_argument(
'--alpha',
type=float,
default=0.2,
help='Temperature parameter α determines the relative importance of the \
entropy term against the reward (default: 0.2)')
args = parser.parse_args()
main()
......@@ -19,4 +19,5 @@ from parl.algorithms.fluid.ddqn import *
from parl.algorithms.fluid.policy_gradient import *
from parl.algorithms.fluid.ppo import *
from parl.algorithms.fluid.td3 import *
from parl.algorithms.fluid.sac import *
from parl.algorithms.fluid.impala.impala import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.core.fluid import layers
from copy import deepcopy
import numpy as np
from paddle import fluid
from paddle.fluid.layers import Normal
from parl.core.fluid.algorithm import Algorithm
epsilon = 1e-6
__all__ = ['SAC']
class SAC(Algorithm):
def __init__(self,
actor,
critic,
max_action,
alpha=0.2,
gamma=None,
tau=None,
actor_lr=None,
critic_lr=None):
""" SAC algorithm
Args:
actor (parl.Model): forward network of actor.
critic (parl.Model): forward network of the critic.
max_action (float): the largest value an action can take, i.e. env.action_space.high[0].
alpha (float): temperature parameter that determines the relative importance of the entropy term against the reward.
gamma (float): discount factor for reward computation.
tau (float): decay coefficient used when updating the weights of self.target_critic with self.critic.
actor_lr (float): learning rate of the actor model
critic_lr (float): learning rate of the critic model
"""
assert isinstance(gamma, float)
assert isinstance(tau, float)
assert isinstance(actor_lr, float)
assert isinstance(critic_lr, float)
assert isinstance(alpha, float)
self.max_action = max_action
self.gamma = gamma
self.tau = tau
self.actor_lr = actor_lr
self.critic_lr = critic_lr
self.alpha = alpha
self.actor = actor
self.critic = critic
self.target_critic = deepcopy(critic)
def predict(self, obs):
""" use actor model of self.policy to predict the action
"""
mean, _ = self.actor.policy(obs)
mean = layers.tanh(mean) * self.max_action
return mean
def sample(self, obs):
mean, log_std = self.actor.policy(obs)
std = layers.exp(log_std)
normal = Normal(mean, std)
x_t = normal.sample([1])[0]
y_t = layers.tanh(x_t)
action = y_t * self.max_action
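# Log-prob of the squashed action: start from the Gaussian log-density of x_t,
# then subtract the log-determinant of the tanh (and scaling) transform,
# log(max_action * (1 - tanh(x_t)^2) + epsilon), summed over action dimensions.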
log_prob = normal.log_prob(x_t)
log_prob -= layers.log(self.max_action * (1 - layers.pow(y_t, 2)) +
epsilon)
log_prob = layers.reduce_sum(log_prob, dim=1, keep_dim=True)
log_prob = layers.squeeze(log_prob, axes=[1])
return action, log_prob
def learn(self, obs, action, reward, next_obs, terminal):
actor_cost = self.actor_learn(obs)
critic_cost = self.critic_learn(obs, action, reward, next_obs,
terminal)
return critic_cost, actor_cost
def actor_learn(self, obs):
action, log_pi = self.sample(obs)
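# Actor loss: E[ alpha * log_pi(a|s) - min(Q1, Q2) ], i.e. maximize the
# entropy-regularized soft Q-value of the sampled action.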
qf1_pi, qf2_pi = self.critic.value(obs, action)
min_qf_pi = layers.elementwise_min(qf1_pi, qf2_pi)
cost = log_pi * self.alpha - min_qf_pi
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.AdamOptimizer(self.actor_lr)
optimizer.minimize(cost, parameter_list=self.actor.parameters())
return cost
def critic_learn(self, obs, action, reward, next_obs, terminal):
next_state_action, next_state_log_pi = self.sample(next_obs)
qf1_next_target, qf2_next_target = self.target_critic.value(
next_obs, next_state_action)
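# Soft Bellman target with clipped double-Q: take the minimum of the two target
# critics and subtract alpha * log_pi of the next action (the entropy bonus).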
min_qf_next_target = layers.elementwise_min(
qf1_next_target, qf2_next_target) - next_state_log_pi * self.alpha
terminal = layers.cast(terminal, dtype='float32')
target_Q = reward + (1.0 - terminal) * self.gamma * min_qf_next_target
target_Q.stop_gradient = True
current_Q1, current_Q2 = self.critic.value(obs, action)
cost = layers.square_error_cost(current_Q1,
target_Q) + layers.square_error_cost(
current_Q2, target_Q)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.AdamOptimizer(self.critic_lr)
optimizer.minimize(cost)
return cost
def sync_target(self, decay=None):
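# Soft update of the target critic. PARL's sync_weights_to keeps `decay` of the
# target and mixes in `1 - decay` of the source, so decay = 1 - tau moves the
# target a fraction tau towards the critic; decay=0 (used at init) copies weights exactly.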
if decay is None:
decay = 1.0 - self.tau
self.critic.sync_weights_to(self.target_critic, decay=decay)
......@@ -219,11 +219,14 @@ def start_worker(address, cpu_num):
@click.command("stop", help="Exit the cluster.")
def stop():
command = ("pkill -f remote/start.py")
command = (
"ps aux | grep remote/start.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
command = ("pkill -f remote/job.py")
command = (
"ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
command = ("pkill -f remote/monitor.py")
command = (
"ps aux | grep remote/monitor.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
......
......@@ -70,7 +70,8 @@ class TestJobAlone(unittest.TestCase):
time.sleep(1)
self.assertEqual(master.cpu_num, 4)
print("We are going to kill all the jobs.")
command = ("pkill -f remote/job.py")
command = (
"ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
parl.connect('localhost:1334')
actor = Actor()
......