diff --git a/.teamcity/requirements.txt b/.teamcity/requirements.txt
index 0b2d8e3ebe2cfd15cb5b2bc63cda5c9add38dd1a..354e3632e02ce8e678df2024a6d16657281c1a0e 100644
--- a/.teamcity/requirements.txt
+++ b/.teamcity/requirements.txt
@@ -1,5 +1,5 @@
# requirements for unittest
-paddlepaddle-gpu==1.5.1.post97
+paddlepaddle-gpu==1.6.1.post97
gym
details
parameterized
diff --git a/README.cn.md b/README.cn.md
index 8a6c9fb4fe423ed5f12bd58a264a66f776ca30f1..09f1df56a90bcc36dd0971038dfd15de501034ec 100644
--- a/README.cn.md
+++ b/README.cn.md
@@ -76,8 +76,10 @@ pip install parl
- [PPO](examples/PPO/)
- [IMPALA](examples/IMPALA/)
- [A2C](examples/A2C/)
-- [GA3C](examples/GA3C/)
+- [TD3](examples/TD3/)
+- [SAC](examples/SAC/)
- [冠军解决方案:NIPS2018强化学习假肢挑战赛](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
+- [冠军解决方案:NIPS2019强化学习仿生人控制赛事](examples/NeurIPS2019-Learn-to-Move-Challenge/)
diff --git a/README.md b/README.md
index 29c87fa1fb228446e6145aea40da32b1c34efd6b..5245c349951b28a1a6c74e0a15cfc89d22edfaf3 100644
--- a/README.md
+++ b/README.md
@@ -79,8 +79,10 @@ pip install parl
- [PPO](examples/PPO/)
- [IMPALA](examples/IMPALA/)
- [A2C](examples/A2C/)
-- [GA3C](examples/GA3C/)
+- [TD3](examples/TD3/)
+- [SAC](examples/SAC/)
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
+- [Winning Solution for NIPS2019: Learn to Move Challenge](examples/NeurIPS2019-Learn-to-Move-Challenge/)
diff --git a/examples/SAC/.benchmark/merge.png b/examples/SAC/.benchmark/merge.png
new file mode 100644
index 0000000000000000000000000000000000000000..95a74a5fd047b11fecf4d176ef3a0d688eb73850
Binary files /dev/null and b/examples/SAC/.benchmark/merge.png differ
diff --git a/examples/SAC/README.md b/examples/SAC/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c9a02209d663556900842317aa5f1ab987e14af3
--- /dev/null
+++ b/examples/SAC/README.md
@@ -0,0 +1,32 @@
+## Reproduce SAC with PARL
+Based on PARL, the SAC algorithm of deep reinforcement learning has been reproduced, reaching the same level of performance as reported in the original paper on the MuJoCo benchmarks.
+
+The implementation includes the following approaches:
++ DDPG-style training with a stochastic policy
++ Maximum entropy regularization
+
+> SAC in
+[Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor](https://arxiv.org/abs/1801.01290)
+
+### Mujoco games introduction
+Please see [here](https://github.com/openai/mujoco-py) to learn more about Mujoco games.
+
+### Benchmark result
+<img src=".benchmark/merge.png" alt="SAC results on MuJoCo benchmarks"/>
+
+## How to use
+### Dependencies:
++ python3.5+
++ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
++ [parl](https://github.com/PaddlePaddle/PARL)
++ gym
++ mujoco-py>=1.50.1.0
+
+### Start Training:
+```
+# To train an agent for HalfCheetah-v2 game
+python train.py
+
+# To train for different games
+# python train.py --env [ENV_NAME]
+```
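+
+### How the components fit together
+The snippet below is a condensed sketch of how `train.py` assembles the models, the SAC algorithm, the agent and the replay memory, using the default hyperparameters of this example. See `train.py` for the full training and evaluation loop.
+```
+import gym
+import parl
+from parl.utils import ReplayMemory
+from mujoco_agent import MujocoAgent
+from mujoco_model import ActorModel, CriticModel
+
+env = gym.make('HalfCheetah-v2')
+obs_dim = env.observation_space.shape[0]
+act_dim = env.action_space.shape[0]
+max_action = float(env.action_space.high[0])
+
+# SAC wraps the actor/critic models; hyperparameters match the train.py defaults.
+algorithm = parl.algorithms.SAC(
+    ActorModel(act_dim),
+    CriticModel(),
+    max_action=max_action,
+    gamma=0.99,
+    tau=0.005,
+    actor_lr=1e-3,
+    critic_lr=1e-3)
+agent = MujocoAgent(algorithm, obs_dim, act_dim)
+
+# Off-policy replay buffer holding up to 1e6 transitions.
+rpm = ReplayMemory(int(1e6), obs_dim, act_dim)
+```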
diff --git a/examples/SAC/mujoco_agent.py b/examples/SAC/mujoco_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4c1ccec4bc68e47b51bb41c63b120ec79d3ffd9
--- /dev/null
+++ b/examples/SAC/mujoco_agent.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import parl
+from parl import layers
+from paddle import fluid
+
+
+class MujocoAgent(parl.Agent):
+ def __init__(self, algorithm, obs_dim, act_dim):
+ assert isinstance(obs_dim, int)
+ assert isinstance(act_dim, int)
+ self.obs_dim = obs_dim
+ self.act_dim = act_dim
+ super(MujocoAgent, self).__init__(algorithm)
+
+        # Note: fully synchronize the target critic (hard copy) at initialization.
+ self.alg.sync_target(decay=0)
+
+ def build_program(self):
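+        # Build three separate fluid programs:
+        #   pred_program   -- deterministic action (tanh of the policy mean) for evaluation,
+        #   sample_program -- stochastic action sampled from the squashed Gaussian policy,
+        #   learn_program  -- one SAC update step for both the critic and the actor.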
+ self.pred_program = fluid.Program()
+ self.sample_program = fluid.Program()
+ self.learn_program = fluid.Program()
+
+ with fluid.program_guard(self.pred_program):
+ obs = layers.data(
+ name='obs', shape=[self.obs_dim], dtype='float32')
+ self.pred_act = self.alg.predict(obs)
+
+ with fluid.program_guard(self.sample_program):
+ obs = layers.data(
+ name='obs', shape=[self.obs_dim], dtype='float32')
+ self.sample_act, _ = self.alg.sample(obs)
+
+ with fluid.program_guard(self.learn_program):
+ obs = layers.data(
+ name='obs', shape=[self.obs_dim], dtype='float32')
+ act = layers.data(
+ name='act', shape=[self.act_dim], dtype='float32')
+ reward = layers.data(name='reward', shape=[], dtype='float32')
+ next_obs = layers.data(
+ name='next_obs', shape=[self.obs_dim], dtype='float32')
+ terminal = layers.data(name='terminal', shape=[], dtype='bool')
+ self.critic_cost, self.actor_cost = self.alg.learn(
+ obs, act, reward, next_obs, terminal)
+
+ def predict(self, obs):
+ obs = np.expand_dims(obs, axis=0)
+ act = self.fluid_executor.run(
+ self.pred_program, feed={'obs': obs},
+ fetch_list=[self.pred_act])[0]
+ return act
+
+ def sample(self, obs):
+ obs = np.expand_dims(obs, axis=0)
+ act = self.fluid_executor.run(
+ self.sample_program,
+ feed={'obs': obs},
+ fetch_list=[self.sample_act])[0]
+ return act
+
+ def learn(self, obs, act, reward, next_obs, terminal):
+ feed = {
+ 'obs': obs,
+ 'act': act,
+ 'reward': reward,
+ 'next_obs': next_obs,
+ 'terminal': terminal
+ }
+ [critic_cost, actor_cost] = self.fluid_executor.run(
+ self.learn_program,
+ feed=feed,
+ fetch_list=[self.critic_cost, self.actor_cost])
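+        # Soft-update the target critic after every learning step (decay = 1 - tau).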
+ self.alg.sync_target()
+ return critic_cost[0], actor_cost[0]
diff --git a/examples/SAC/mujoco_model.py b/examples/SAC/mujoco_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4d9ee39df3e3ce2e2718ab436638836a49fe7b2
--- /dev/null
+++ b/examples/SAC/mujoco_model.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import parl
+from parl import layers
+
+LOG_SIG_MAX = 2.0
+LOG_SIG_MIN = -20.0
+
+
+class ActorModel(parl.Model):
+ def __init__(self, act_dim):
+ hid1_size = 400
+ hid2_size = 300
+
+ self.fc1 = layers.fc(size=hid1_size, act='relu')
+ self.fc2 = layers.fc(size=hid2_size, act='relu')
+ self.mean_linear = layers.fc(size=act_dim)
+ self.log_std_linear = layers.fc(size=act_dim)
+
+ def policy(self, obs):
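+        # Return the mean and (clipped) log std of a diagonal Gaussian;
+        # the tanh squashing is applied later by the SAC algorithm.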
+ hid1 = self.fc1(obs)
+ hid2 = self.fc2(hid1)
+ means = self.mean_linear(hid2)
+ log_std = self.log_std_linear(hid2)
+ log_std = layers.clip(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
+
+ return means, log_std
+
+
+class CriticModel(parl.Model):
+ def __init__(self):
+ hid1_size = 400
+ hid2_size = 300
+
+ self.fc1 = layers.fc(size=hid1_size, act='relu')
+ self.fc2 = layers.fc(size=hid2_size, act='relu')
+ self.fc3 = layers.fc(size=1, act=None)
+
+ self.fc4 = layers.fc(size=hid1_size, act='relu')
+ self.fc5 = layers.fc(size=hid2_size, act='relu')
+ self.fc6 = layers.fc(size=1, act=None)
+
+ def value(self, obs, act):
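+        # Twin Q-networks: the SAC algorithm takes the element-wise minimum of
+        # Q1 and Q2 to reduce overestimation bias.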
+ hid1 = self.fc1(obs)
+ concat1 = layers.concat([hid1, act], axis=1)
+ Q1 = self.fc2(concat1)
+ Q1 = self.fc3(Q1)
+ Q1 = layers.squeeze(Q1, axes=[1])
+
+ hid2 = self.fc4(obs)
+ concat2 = layers.concat([hid2, act], axis=1)
+ Q2 = self.fc5(concat2)
+ Q2 = self.fc6(Q2)
+ Q2 = layers.squeeze(Q2, axes=[1])
+
+ return Q1, Q2
diff --git a/examples/SAC/train.py b/examples/SAC/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a88260245880a39738f931573dd0b183487722df
--- /dev/null
+++ b/examples/SAC/train.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Refer to https://github.com/pranz24/pytorch-soft-actor-critic
+
+import argparse
+import gym
+import numpy as np
+import time
+import parl
+from mujoco_agent import MujocoAgent
+from mujoco_model import ActorModel, CriticModel
+from parl.utils import logger, tensorboard, ReplayMemory
+
+ACTOR_LR = 1e-3
+CRITIC_LR = 1e-3
+GAMMA = 0.99
+TAU = 0.005
+MEMORY_SIZE = int(1e6)
+WARMUP_SIZE = 1e4
+BATCH_SIZE = 256
+ENV_SEED = 1
+
+
+def run_train_episode(env, agent, rpm):
+ obs = env.reset()
+ total_reward = 0
+ steps = 0
+ while True:
+ steps += 1
+ batch_obs = np.expand_dims(obs, axis=0)
+
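+        # Act randomly until the replay memory holds WARMUP_SIZE transitions,
+        # then switch to sampling from the stochastic policy.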
+ if rpm.size() < WARMUP_SIZE:
+ action = env.action_space.sample()
+ else:
+ action = agent.sample(batch_obs.astype('float32'))
+ action = np.squeeze(action)
+
+ next_obs, reward, done, info = env.step(action)
+
+ rpm.append(obs, action, reward, next_obs, done)
+
+ if rpm.size() > WARMUP_SIZE:
+ batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = rpm.sample_batch(
+ BATCH_SIZE)
+ agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs,
+ batch_terminal)
+
+ obs = next_obs
+ total_reward += reward
+
+ if done:
+ break
+ return total_reward, steps
+
+
+def run_evaluate_episode(env, agent):
+ obs = env.reset()
+ total_reward = 0
+ while True:
+ batch_obs = np.expand_dims(obs, axis=0)
+ action = agent.predict(batch_obs.astype('float32'))
+ action = np.squeeze(action)
+
+ next_obs, reward, done, info = env.step(action)
+
+ obs = next_obs
+ total_reward += reward
+
+ if done:
+ break
+ return total_reward
+
+
+def main():
+ env = gym.make(args.env)
+ env.seed(ENV_SEED)
+
+ obs_dim = env.observation_space.shape[0]
+ act_dim = env.action_space.shape[0]
+ max_action = float(env.action_space.high[0])
+
+ actor = ActorModel(act_dim)
+ critic = CriticModel()
+    algorithm = parl.algorithms.SAC(
+        actor,
+        critic,
+        max_action=max_action,
+        alpha=args.alpha,
+        gamma=GAMMA,
+        tau=TAU,
+        actor_lr=ACTOR_LR,
+        critic_lr=CRITIC_LR)
+ agent = MujocoAgent(algorithm, obs_dim, act_dim)
+
+ rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim)
+
+ test_flag = 0
+ total_steps = 0
+ while total_steps < args.train_total_steps:
+ train_reward, steps = run_train_episode(env, agent, rpm)
+ total_steps += steps
+ logger.info('Steps: {} Reward: {}'.format(total_steps, train_reward))
+ tensorboard.add_scalar('train/episode_reward', train_reward,
+ total_steps)
+
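+        # Run an evaluation episode every test_every_steps environment steps
+        # (the inner loop catches up if one episode spans several intervals).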
+ if total_steps // args.test_every_steps >= test_flag:
+ while total_steps // args.test_every_steps >= test_flag:
+ test_flag += 1
+ evaluate_reward = run_evaluate_episode(env, agent)
+ logger.info('Steps {}, Evaluate reward: {}'.format(
+ total_steps, evaluate_reward))
+ tensorboard.add_scalar('eval/episode_reward', evaluate_reward,
+ total_steps)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--env', help='Mujoco environment name', default='HalfCheetah-v2')
+ parser.add_argument(
+ '--train_total_steps',
+ type=int,
+ default=int(1e6),
+ help='maximum training steps')
+ parser.add_argument(
+ '--test_every_steps',
+ type=int,
+ default=int(1e4),
+ help='the step interval between two consecutive evaluations')
+ parser.add_argument(
+ '--alpha',
+ type=float,
+ default=0.2,
+ help='Temperature parameter α determines the relative importance of the \
+ entropy term against the reward (default: 0.2)')
+
+ args = parser.parse_args()
+
+ main()
diff --git a/parl/algorithms/fluid/__init__.py b/parl/algorithms/fluid/__init__.py
index 468be6f249fe82a52760d7872517b493a6ed5f99..3f005ac623ed5cf7fc95b2f9242c962dbfbf3302 100644
--- a/parl/algorithms/fluid/__init__.py
+++ b/parl/algorithms/fluid/__init__.py
@@ -19,4 +19,5 @@ from parl.algorithms.fluid.ddqn import *
from parl.algorithms.fluid.policy_gradient import *
from parl.algorithms.fluid.ppo import *
from parl.algorithms.fluid.td3 import *
+from parl.algorithms.fluid.sac import *
from parl.algorithms.fluid.impala.impala import *
diff --git a/parl/algorithms/fluid/sac.py b/parl/algorithms/fluid/sac.py
new file mode 100644
index 0000000000000000000000000000000000000000..cec92c98568905af7bce64252e9f3ff0531da039
--- /dev/null
+++ b/parl/algorithms/fluid/sac.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from parl.core.fluid import layers
+from copy import deepcopy
+import numpy as np
+from paddle import fluid
+from paddle.fluid.layers import Normal
+from parl.core.fluid.algorithm import Algorithm
+
+epsilon = 1e-6
+
+__all__ = ['SAC']
+
+
+class SAC(Algorithm):
+ def __init__(self,
+ actor,
+ critic,
+ max_action,
+ alpha=0.2,
+ gamma=None,
+ tau=None,
+ actor_lr=None,
+ critic_lr=None):
+ """ SAC algorithm
+
+ Args:
+ actor (parl.Model): forward network of actor.
+            critic (parl.Model): forward network of the critic.
+            max_action (float): the largest value an action can take, i.e. env.action_space.high[0].
+            alpha (float): temperature parameter that determines the relative importance of the entropy term against the reward.
+            gamma (float): discount factor for reward computation.
+            tau (float): decay coefficient used when updating the weights of self.target_critic with self.critic.
+            actor_lr (float): learning rate of the actor model.
+            critic_lr (float): learning rate of the critic model.
+ """
+ assert isinstance(gamma, float)
+ assert isinstance(tau, float)
+ assert isinstance(actor_lr, float)
+ assert isinstance(critic_lr, float)
+ assert isinstance(alpha, float)
+ self.max_action = max_action
+ self.gamma = gamma
+ self.tau = tau
+ self.actor_lr = actor_lr
+ self.critic_lr = critic_lr
+
+ self.alpha = alpha
+
+ self.actor = actor
+ self.critic = critic
+ self.target_critic = deepcopy(critic)
+
+ def predict(self, obs):
+        """ Use the actor network to predict a deterministic action (the tanh-squashed policy mean, scaled to the action range).
+ """
+ mean, _ = self.actor.policy(obs)
+ mean = layers.tanh(mean) * self.max_action
+ return mean
+
+ def sample(self, obs):
+ mean, log_std = self.actor.policy(obs)
+ std = layers.exp(log_std)
+ normal = Normal(mean, std)
+ x_t = normal.sample([1])[0]
+ y_t = layers.tanh(x_t)
+ action = y_t * self.max_action
+ log_prob = normal.log_prob(x_t)
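+        # Correct the Gaussian log-density for the tanh squashing and the
+        # scaling by max_action (change of variables), as in the SAC paper.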
+ log_prob -= layers.log(self.max_action * (1 - layers.pow(y_t, 2)) +
+ epsilon)
+ log_prob = layers.reduce_sum(log_prob, dim=1, keep_dim=True)
+ log_prob = layers.squeeze(log_prob, axes=[1])
+ return action, log_prob
+
+ def learn(self, obs, action, reward, next_obs, terminal):
+ actor_cost = self.actor_learn(obs)
+ critic_cost = self.critic_learn(obs, action, reward, next_obs,
+ terminal)
+ return critic_cost, actor_cost
+
+ def actor_learn(self, obs):
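+        # Actor loss: E[ alpha * log_pi(a|s) - min(Q1, Q2) ], i.e. maximize the
+        # entropy-regularized soft Q-value of actions sampled from the policy.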
+ action, log_pi = self.sample(obs)
+ qf1_pi, qf2_pi = self.critic.value(obs, action)
+ min_qf_pi = layers.elementwise_min(qf1_pi, qf2_pi)
+ cost = log_pi * self.alpha - min_qf_pi
+ cost = layers.reduce_mean(cost)
+ optimizer = fluid.optimizer.AdamOptimizer(self.actor_lr)
+ optimizer.minimize(cost, parameter_list=self.actor.parameters())
+
+ return cost
+
+ def critic_learn(self, obs, action, reward, next_obs, terminal):
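+        # Soft Bellman target: r + gamma * (min(Q1', Q2') - alpha * log_pi(a'|s')),
+        # computed with the target critic and a fresh sample from the current policy.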
+ next_state_action, next_state_log_pi = self.sample(next_obs)
+ qf1_next_target, qf2_next_target = self.target_critic.value(
+ next_obs, next_state_action)
+ min_qf_next_target = layers.elementwise_min(
+ qf1_next_target, qf2_next_target) - next_state_log_pi * self.alpha
+
+ terminal = layers.cast(terminal, dtype='float32')
+ target_Q = reward + (1.0 - terminal) * self.gamma * min_qf_next_target
+ target_Q.stop_gradient = True
+
+ current_Q1, current_Q2 = self.critic.value(obs, action)
+ cost = layers.square_error_cost(current_Q1,
+ target_Q) + layers.square_error_cost(
+ current_Q2, target_Q)
+ cost = layers.reduce_mean(cost)
+ optimizer = fluid.optimizer.AdamOptimizer(self.critic_lr)
+ optimizer.minimize(cost)
+ return cost
+
+ def sync_target(self, decay=None):
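+        # decay=None -> soft update with coefficient tau (target = (1 - tau) * target + tau * critic);
+        # decay=0 (used once at initialization) copies the critic weights entirely.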
+ if decay is None:
+ decay = 1.0 - self.tau
+ self.critic.sync_weights_to(self.target_critic, decay=decay)
diff --git a/parl/remote/scripts.py b/parl/remote/scripts.py
index fd76419684bce3c8fad63a8099ea86dca8c7f88b..71677d692878eef63f65b0ff1054cb6233b0d7a5 100644
--- a/parl/remote/scripts.py
+++ b/parl/remote/scripts.py
@@ -219,11 +219,14 @@ def start_worker(address, cpu_num):
@click.command("stop", help="Exit the cluster.")
def stop():
- command = ("pkill -f remote/start.py")
+ command = (
+ "ps aux | grep remote/start.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
- command = ("pkill -f remote/job.py")
+ command = (
+ "ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
- command = ("pkill -f remote/monitor.py")
+ command = (
+ "ps aux | grep remote/monitor.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
diff --git a/parl/remote/tests/reset_job_test_alone.py b/parl/remote/tests/reset_job_test_alone.py
index 7ca5969658548dfc977a42ce0f5350f90b7a4ea5..81cc2fe77a102521c0dc0633d215821a2a5d991c 100644
--- a/parl/remote/tests/reset_job_test_alone.py
+++ b/parl/remote/tests/reset_job_test_alone.py
@@ -70,7 +70,8 @@ class TestJobAlone(unittest.TestCase):
time.sleep(1)
self.assertEqual(master.cpu_num, 4)
print("We are going to kill all the jobs.")
- command = ("pkill -f remote/job.py")
+ command = (
+ "ps aux | grep remote/job.py | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
parl.connect('localhost:1334')
actor = Actor()