diff --git a/.scripts/update_readme_paddle_version.py b/.scripts/update_readme_paddle_version.py
index 56d56914c65956a2bb753bc58269d59034766b1c..901d2d672d9f3eff1021241ac80b6e9f75d0886a 100644
--- a/.scripts/update_readme_paddle_version.py
+++ b/.scripts/update_readme_paddle_version.py
@@ -37,7 +37,8 @@ if __name__ == '__main__':
exclude_examples = [
'NeurIPS2019-Learn-to-Move-Challenge',
- 'NeurIPS2018-AI-for-Prosthetics-Challenge', 'EagerMode'
+ 'NeurIPS2018-AI-for-Prosthetics-Challenge', 'LiftSim_baseline',
+ 'EagerMode'
]
for example in os.listdir('../examples/'):
if example not in exclude_examples:
diff --git a/examples/A2C/README.md b/examples/A2C/README.md
index d38a5d153b3ab39c59b851775567d90edcdda4fb..2328a3ee350851370c5c44bbed6b4e2daae27512 100755
--- a/examples/A2C/README.md
+++ b/examples/A2C/README.md
@@ -20,7 +20,7 @@ Performance of A2C on various environments
## How to use
### Dependencies
+ [paddlepaddle>=1.6.1](https://github.com/PaddlePaddle/Paddle)
-+ [parl](https://github.com/PaddlePaddle/PARL)
++ [parl>=1.2.1](https://github.com/PaddlePaddle/PARL)
+ gym==0.12.1
+ atari-py==0.1.7
diff --git a/examples/A2C/train.py b/examples/A2C/train.py
index 777a22849afcb3ad5b1e237d7e3d0ae9b39fa871..1daba4736f8e19e9f50d425c12868a43ee0933db 100755
--- a/examples/A2C/train.py
+++ b/examples/A2C/train.py
@@ -55,11 +55,6 @@ class Learner(object):
assert get_gpu_count() == 1, 'Only support training in single GPU,\
Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` .'
- else:
- cpu_num = os.environ.get('CPU_NUM')
- assert cpu_num is not None and cpu_num == '1', 'Only support training in single CPU,\
- Please set environment variable: `export CPU_NUM=1`.'
-
#========== Learner ==========
self.total_loss_stat = WindowStat(100)
diff --git a/examples/LiftSim_baseline/A2C/README.md b/examples/LiftSim_baseline/A2C/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..69019aaa22393194e110b72072d16e97d6d6dd3b
--- /dev/null
+++ b/examples/LiftSim_baseline/A2C/README.md
@@ -0,0 +1,47 @@
+# LiftSim baseline
+
+## Introduction
+
+An A2C baseline implemented with the PARL library, applied to [LiftSim][liftsim], the elevator-dispatch simulation environment in the [RLSchool][rlschool] library.
+
+## Dependencies
+
++ [paddlepaddle>=1.6.1](https://github.com/PaddlePaddle/Paddle)
++ [parl>=1.2.1](https://github.com/PaddlePaddle/PARL)
++ [rlschool>=0.1.1][rlschool]
+
+
+## Distributed training
+
+First, start a local cluster with 5 CPUs:
+
+```bash
+xparl start --port 8010 --cpu_num 5
+```
+
+> Note: if you have already started a cluster, you do not need to run the command above again. See the [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html) for more information about PARL clusters.
+
+Then launch distributed training with:
+
+```bash
+python train.py
+```
+
+
+## Evaluation
+A saved model can be evaluated with:
+```bash
+python evaluate.py --model_path saved_models/[FILENAME]
+```
+
+TensorBoard event files and logs are saved under `./train_log/train/`; run `tensorboard --logdir .` from that directory to open the TensorBoard dashboard.
+
+## Convergence
+After about 30 hours of training, the evaluation metric reaches roughly -120 (the reward accumulated over one simulated day in the LiftSim environment).
+
+
+## Visualization
+
+
+[rlschool]: https://github.com/PaddlePaddle/RLSchool
+[liftsim]: https://github.com/PaddlePaddle/RLSchool/tree/master/rlschool/liftsim
diff --git a/examples/LiftSim_baseline/A2C/a2c_config.py b/examples/LiftSim_baseline/A2C/a2c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3e776b2f1033bc6bd911ed929437f91c44bf938
--- /dev/null
+++ b/examples/LiftSim_baseline/A2C/a2c_config.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+config = {
+ #========== remote config ==========
+ 'master_address': 'localhost:8010',
+
+ #========== actor config ==========
+ 'actor_num': 5,
+ 'env_num': 5,
+ 'sample_batch_steps': 5,
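+    # each Learner.step() consumes roughly
+    # actor_num * env_num * (number of elevators) * sample_batch_steps samples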
+
+ #========== learner config ==========
+ 'max_sample_steps': int(1e10),
+ 'gamma': 0.998,
+ 'lambda': 1.0, # GAE
+
+ # start learning rate
+ 'start_lr': 1.0e-4,
+
+ # coefficient of policy entropy adjustment schedule: (train_step, coefficient)
+ 'entropy_coeff_scheduler': [(0, -2.0e-4)],
+ 'vf_loss_coeff': 0.5,
+ 'get_remote_metrics_interval': 100,
+ 'log_metrics_interval_s': 60,
+}
diff --git a/examples/LiftSim_baseline/A2C/actor.py b/examples/LiftSim_baseline/A2C/actor.py
new file mode 100644
index 0000000000000000000000000000000000000000..4286e7f9c4dd64b444c2ca34d5ebffe870b2769f
--- /dev/null
+++ b/examples/LiftSim_baseline/A2C/actor.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import parl
+from collections import defaultdict
+from env_wrapper import ObsProcessWrapper, ActionProcessWrapper, RewardWrapper, MetricsWrapper
+from parl.utils.rl_utils import calc_gae
+from parl.env.vector_env import VectorEnv
+from rlschool import LiftSim
+from copy import deepcopy
+from lift_model import LiftModel
+from lift_agent import LiftAgent
+
+
+@parl.remote_class
+class Actor(object):
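+    """Remote A2C actor: runs `env_num` wrapped LiftSim envs and, on every call
+    to `sample()`, collects `sample_batch_steps` steps for each elevator."""
+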
+ def __init__(self, config):
+ self.config = config
+ self.env_num = config['env_num']
+
+ self.envs = []
+ for _ in range(self.env_num):
+ env = LiftSim()
+ env = RewardWrapper(env)
+ env = ActionProcessWrapper(env)
+ env = ObsProcessWrapper(env)
+ env = MetricsWrapper(env)
+ self.envs.append(env)
+ self.vector_env = VectorEnv(self.envs)
+
+ # number of elevators
+ self.ele_num = self.envs[0].mansion_attr.ElevatorNumber
+
+ act_dim = self.envs[0].act_dim
+ self.obs_dim = self.envs[0].obs_dim
+ self.config['obs_dim'] = self.obs_dim
+
+ # nested list of shape (env_num, ele_num, obs_dim)
+ self.obs_batch = self.vector_env.reset()
+ # (env_num * ele_num, obs_dim)
+ self.obs_batch = np.array(self.obs_batch).reshape(
+ [self.env_num * self.ele_num, self.obs_dim])
+
+ model = LiftModel(act_dim)
+ algorithm = parl.algorithms.A3C(
+ model, vf_loss_coeff=config['vf_loss_coeff'])
+ self.agent = LiftAgent(algorithm, config)
+
+ def sample(self):
+ sample_data = defaultdict(list)
+
+ env_sample_data = {}
+ # treat each elevator in Liftsim as an independent env
+ for env_id in range(self.env_num * self.ele_num):
+ env_sample_data[env_id] = defaultdict(list)
+
+ for i in range(self.config['sample_batch_steps']):
+ actions_batch, values_batch = self.agent.sample(self.obs_batch)
+
+ vector_actions = np.array_split(actions_batch, self.env_num)
+ assert len(vector_actions[-1]) == self.ele_num
+ next_obs_batch, reward_batch, done_batch, info_batch = \
+ self.vector_env.step(vector_actions)
+
+ # (env_num, ele_num, obs_dim) -> (env_num * ele_num, obs_dim)
+ next_obs_batch = np.array(next_obs_batch).reshape(
+ [self.env_num * self.ele_num, self.obs_dim])
+            # repeat reward and done ele_num times
+ # (env_num) -> (env_num, ele_num) -> (env_num * ele_num)
+ reward_batch = np.repeat(reward_batch, self.ele_num)
+ done_batch = np.repeat(done_batch, self.ele_num)
+
+ for env_id in range(self.env_num * self.ele_num):
+ env_sample_data[env_id]['obs'].append(self.obs_batch[env_id])
+ env_sample_data[env_id]['actions'].append(
+ actions_batch[env_id])
+ env_sample_data[env_id]['rewards'].append(reward_batch[env_id])
+ env_sample_data[env_id]['dones'].append(done_batch[env_id])
+ env_sample_data[env_id]['values'].append(values_batch[env_id])
+
+ # Calculate advantages when the episode is done or reaches max sample steps.
+ if done_batch[env_id] or i + 1 == self.config[
+ 'sample_batch_steps']: # reach max sample steps
+ next_value = 0
+ if not done_batch[env_id]:
+ next_obs = np.expand_dims(next_obs_batch[env_id], 0)
+ next_value = self.agent.value(next_obs)
+
+ values = env_sample_data[env_id]['values']
+ rewards = env_sample_data[env_id]['rewards']
+ advantages = calc_gae(rewards, values, next_value,
+ self.config['gamma'],
+ self.config['lambda'])
+ target_values = advantages + values
+
+ sample_data['obs'].extend(env_sample_data[env_id]['obs'])
+ sample_data['actions'].extend(
+ env_sample_data[env_id]['actions'])
+ sample_data['advantages'].extend(advantages)
+ sample_data['target_values'].extend(target_values)
+
+ env_sample_data[env_id] = defaultdict(list)
+
+ self.obs_batch = deepcopy(next_obs_batch)
+
+ # size of sample_data[key]: env_num * ele_num * sample_batch_steps
+ for key in sample_data:
+ sample_data[key] = np.stack(sample_data[key])
+
+ return sample_data
+
+ def get_metrics(self):
+ metrics = defaultdict(list)
+ for metrics_env in self.envs:
+ assert isinstance(
+ metrics_env,
+                MetricsWrapper), "MetricsWrapper must be the outermost wrapper"
+ for env_reward_1h, env_reward_24h in metrics_env.next_episode_results(
+ ):
+ metrics['env_reward_1h'].append(env_reward_1h)
+ metrics['env_reward_24h'].append(env_reward_24h)
+ return metrics
+
+ def set_weights(self, params):
+ self.agent.set_weights(params)
diff --git a/examples/LiftSim_baseline/A2C/effect.gif b/examples/LiftSim_baseline/A2C/effect.gif
new file mode 100644
index 0000000000000000000000000000000000000000..b796acbf3c872f0e960e731fb99848c774446165
Binary files /dev/null and b/examples/LiftSim_baseline/A2C/effect.gif differ
diff --git a/examples/LiftSim_baseline/A2C/env_wrapper.py b/examples/LiftSim_baseline/A2C/env_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8d08fac40d1cc6fff6f46e97e2bac649a79bc1b
--- /dev/null
+++ b/examples/LiftSim_baseline/A2C/env_wrapper.py
@@ -0,0 +1,301 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+import numpy as np
+from utils import discretize, linear_discretize
+from rlschool import LiftSim
+
+
+class BaseWrapper(object):
+ def __init__(self, env):
+ self.env = env
+ self._mansion = env._mansion
+ self.mansion_attr = self._mansion.attribute
+
+ @property
+ def obs_dim(self):
+ if hasattr(self.env, 'obs_dim'):
+ return self.env.obs_dim
+ else:
+ return None
+
+ @property
+ def act_dim(self):
+ if hasattr(self.env, 'act_dim'):
+ return self.env.act_dim
+ else:
+ return None
+
+ def seed(self, seed=None):
+ return self.env.seed(seed)
+
+ def step(self, action):
+ return self.env.step(action)
+
+ def reset(self):
+ return self.env.reset()
+
+ def render(self):
+ return self.env.render()
+
+ def close(self):
+ return self.env.close()
+
+
+class ObsProcessWrapper(BaseWrapper):
+ """Extract features of each elevator in LiftSim env
+ """
+
+ def __init__(self, env, hour_distize_num=6):
+ super(ObsProcessWrapper, self).__init__(env)
+ self.hour_distize_num = hour_distize_num
+ self.total_steps = 0
+
+ @property
+ def obs_dim(self):
+ """
+ NOTE:
+            Keep obs_dim consistent with the size of the feature vector returned by `_mansion_state_process`
+ """
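+        # Per-observation feature layout (matches `_mansion_state_process`):
+        #   ele_dim features for each of the ElevatorNumber elevators,
+        #   an ElevatorNumber-dim elevator-id one-hot,
+        #   2 * NumberOfFloor binaries for floors requesting up/down service,
+        #   and an hour_distize_num-dim time-of-day one-hot.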
+ ele_dim = self.mansion_attr.NumberOfFloor * 3 + 34
+ obs_dim = (ele_dim + 1) * self.mansion_attr.ElevatorNumber + \
+ self.mansion_attr.NumberOfFloor * 2
+ obs_dim += self.hour_distize_num
+ return obs_dim
+
+ def reset(self):
+ """
+        Returns:
+            obs(np.array): array of shape (mansion_attr.ElevatorNumber, self.obs_dim),
+                the feature array of every elevator
+ """
+ obs = self.env.reset()
+ self.total_steps = 0
+ obs = self._mansion_state_process(obs)
+ return obs
+
+ def step(self, action):
+ """
+ Returns:
+            obs(np.array): array of shape (mansion_attr.ElevatorNumber, self.obs_dim),
+                the feature array of every elevator
+            reward(float): returned by self.env
+ done(bool): returned by self.env
+ info(dict): returned by self.env
+ """
+ obs, reward, done, info = self.env.step(action)
+ self.total_steps += 1
+ obs = self._mansion_state_process(obs)
+ return obs, reward, done, info
+
+ def _mansion_state_process(self, mansion_state):
+ """Extract features of env
+ """
+ ele_features = list()
+ for ele_state in mansion_state.ElevatorStates:
+ ele_features.append(self._ele_state_process(ele_state))
+ max_floor = ele_state.MaximumFloor
+
+ target_floor_binaries_up = [0.0 for i in range(max_floor)]
+ target_floor_binaries_down = [0.0 for i in range(max_floor)]
+ for floor in mansion_state.RequiringUpwardFloors:
+ target_floor_binaries_up[floor - 1] = 1.0
+ for floor in mansion_state.RequiringDownwardFloors:
+ target_floor_binaries_down[floor - 1] = 1.0
+ target_floor_binaries = target_floor_binaries_up + target_floor_binaries_down
+
+        raw_time = self.total_steps * 0.5  # one step lasts 0.5 second
+        time_id = int(raw_time % 86400)  # seconds elapsed in the current day
+        time_id = time_id // (24 / self.hour_distize_num * 3600)  # time-of-day bucket index
+ time_id_vec = discretize(time_id + 1, self.hour_distize_num, 1,
+ self.hour_distize_num)
+
+ man_features = list()
+ for idx in range(len(mansion_state.ElevatorStates)):
+ elevator_id_vec = discretize(idx + 1,
+ len(mansion_state.ElevatorStates), 1,
+ len(mansion_state.ElevatorStates))
+ idx_array = list(range(len(mansion_state.ElevatorStates)))
+ idx_array.remove(idx)
+ man_features.append(ele_features[idx])
+ for left_idx in idx_array:
+ man_features[idx] = man_features[idx] + ele_features[left_idx]
+ man_features[idx] = man_features[idx] + \
+ elevator_id_vec + target_floor_binaries
+ man_features[idx] = man_features[idx] + time_id_vec
+ return np.asarray(man_features, dtype='float32')
+
+ def _ele_state_process(self, ele_state):
+ """Extract features of elevator
+ """
+ ele_feature = []
+
+ # add floor information
+ ele_feature.extend(
+ linear_discretize(ele_state.Floor, ele_state.MaximumFloor, 1.0,
+ ele_state.MaximumFloor))
+
+ # add velocity information
+ ele_feature.extend(
+ linear_discretize(ele_state.Velocity, 21, -ele_state.MaximumSpeed,
+ ele_state.MaximumSpeed))
+
+ # add door information
+ ele_feature.append(ele_state.DoorState)
+ ele_feature.append(float(ele_state.DoorIsOpening))
+ ele_feature.append(float(ele_state.DoorIsClosing))
+
+ # add direction information
+ ele_feature.extend(discretize(ele_state.Direction, 3, -1, 1))
+
+ # add load weight information
+ ele_feature.extend(
+ linear_discretize(ele_state.LoadWeight / ele_state.MaximumLoad, 5,
+ 0.0, 1.0))
+
+ # add other information
+ target_floor_binaries = [0.0 for i in range(ele_state.MaximumFloor)]
+ for target_floor in ele_state.ReservedTargetFloors:
+ target_floor_binaries[target_floor - 1] = 1.0
+ ele_feature.extend(target_floor_binaries)
+
+ dispatch_floor_binaries = [
+ 0.0 for i in range(ele_state.MaximumFloor + 1)
+ ]
+ dispatch_floor_binaries[ele_state.CurrentDispatchTarget] = 1.0
+ ele_feature.extend(dispatch_floor_binaries)
+ ele_feature.append(ele_state.DispatchTargetDirection)
+
+ return ele_feature
+
+
+class ActionProcessWrapper(BaseWrapper):
+ def __init__(self, env):
+ """Map action id predicted by model to action of LiftSim
+
+ """
+ super(ActionProcessWrapper, self).__init__(env)
+
+ @property
+ def act_dim(self):
+ """
+ NOTE:
+            Keep act_dim consistent with the mapping in `_action_idx_to_action`
+
+ Returns:
+            int: NumberOfFloor * 2 (one up and one down dispatch per floor) + 1 (DispatchTarget = 0) + 1 (DispatchTarget = -1)
+ """
+ return self.mansion_attr.NumberOfFloor * 2 + 2
+
+ def step(self, action):
+ """
+ Args:
+ action(list): action_id of all elevators (length = mansion_attr.ElevatorNumber)
+ """
+ ele_actions = []
+ for action_id in action:
+ ele_actions.extend(self._action_idx_to_action(action_id))
+
+        # ele_actions: list, the flattened action expected by the LiftSim env (length = 2 * mansion_attr.ElevatorNumber)
+ return self.env.step(ele_actions)
+
+ def _action_idx_to_action(self, action_idx):
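+        # Mapping from the model's action id to the (dispatch target, direction)
+        # pair consumed by LiftSim, with N = NumberOfFloor:
+        #   0    .. N-1   -> dispatch to floor 1..N, direction +1 (up)
+        #   N    .. 2N-1  -> dispatch to floor 1..N, direction -1 (down)
+        #   2N            -> (0, 1),  i.e. DispatchTarget = 0
+        #   2N+1          -> (-1, 1), i.e. DispatchTarget = -1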
+ action_idx = int(action_idx)
+ realdim = self.act_dim - 2
+ if (action_idx == realdim):
+ return (0, 1) # mapped to DispatchTarget=0
+ elif (action_idx == realdim + 1):
+ return (-1, 1) # mapped to DispatchTarget=-1
+ action = action_idx
+ if (action_idx < realdim / 2):
+ direction = 1 # up direction
+ action += 1
+ else:
+ direction = -1 # down direction
+ action -= int(realdim / 2)
+ action += 1
+ return (action, direction)
+
+
+class RewardWrapper(BaseWrapper):
+ def __init__(self, env):
+ """Design reward of LiftSim env.
+ """
+ super(RewardWrapper, self).__init__(env)
+ self.ele_num = self.mansion_attr.ElevatorNumber
+
+ def step(self, action):
+        """Here we return the same shaped reward for every elevator;
+        you can also design a different reward for each elevator.
+
+ Returns:
+ obs: returned by self.env
+ reward: shaping reward
+ done: returned by self.env
+ info: returned by self.env
+ """
+ obs, origin_reward, done, info = self.env.step(action)
+
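+        # Shaped reward: a weighted penalty on time consumption, energy
+        # consumption and the number of passengers who gave up waiting,
+        # scaled by 1e-3 and shared evenly among the elevators.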
+ reward = -(30 * info['time_consume'] + 0.01 * info['energy_consume'] +
+ 100 * info['given_up_persons']) * 1.0e-3 / self.ele_num
+
+ info['origin_reward'] = origin_reward
+
+ return obs, reward, done, info
+
+
+class MetricsWrapper(BaseWrapper):
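+    """Track the raw env reward accumulated over every hour and every 24 hours,
+    so the learner can log them as evaluation metrics."""
+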
+ def __init__(self, env):
+ super(MetricsWrapper, self).__init__(env)
+
+ self._total_steps = 0
+ self._env_reward_1h = 0
+ self._env_reward_24h = 0
+
+ self._num_returned = 0
+ self._episode_result = []
+
+ def reset(self):
+ self._total_steps = 0
+ self._env_reward_1h = 0
+ self._env_reward_24h = 0
+ return self.env.reset()
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ self._total_steps += 1
+
+ self._env_reward_1h += info['origin_reward']
+ self._env_reward_24h += info['origin_reward']
+
+ # Treat 1h in LiftSim env as an episode (1step = 0.5s)
+ if self._total_steps % (3600 * 2) == 0: # 1h
+ episode_env_reward_1h = self._env_reward_1h
+ self._env_reward_1h = 0
+
+ episode_env_reward_24h = None
+ if self._total_steps % (24 * 3600 * 2) == 0: # 24h
+ episode_env_reward_24h = self._env_reward_24h
+ self._env_reward_24h = 0
+
+ self._episode_result.append(
+ [episode_env_reward_1h, episode_env_reward_24h])
+
+ return obs, reward, done, info
+
+ def next_episode_results(self):
+ for i in range(self._num_returned, len(self._episode_result)):
+ yield self._episode_result[i]
+ self._num_returned = len(self._episode_result)
diff --git a/examples/LiftSim_baseline/A2C/evaluate.py b/examples/LiftSim_baseline/A2C/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..d004aa4255e7f54d123e2f11de80158e60319e1c
--- /dev/null
+++ b/examples/LiftSim_baseline/A2C/evaluate.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import parl
+from parl.utils import logger
+from env_wrapper import ObsProcessWrapper, ActionProcessWrapper, RewardWrapper
+from rlschool import LiftSim
+from lift_model import LiftModel
+from lift_agent import LiftAgent
+from a2c_config import config
+
+
+def evaluate_one_day(model_path):
+ env = LiftSim()
+ env = ActionProcessWrapper(env)
+ env = ObsProcessWrapper(env)
+ act_dim = env.act_dim
+ obs_dim = env.obs_dim
+ config['obs_dim'] = obs_dim
+
+ model = LiftModel(act_dim)
+ algorithm = parl.algorithms.A3C(
+ model, vf_loss_coeff=config['vf_loss_coeff'])
+ agent = LiftAgent(algorithm, config)
+ agent.restore(model_path)
+
+ reward_24h = 0
+ obs = env.reset()
+ for i in range(24 * 3600 * 2): # 24h, 1step = 0.5s
+        action, _ = agent.sample(obs)  # one action id per elevator
+ obs, reward, done, info = env.step(action)
+ reward_24h += reward
+ if (i + 1) % (3600 * 2) == 0:
+ logger.info('hour {}, total_reward: {}'.format(
+ (i + 1) // (3600 * 2), reward_24h))
+
+ logger.info('model_path: {}, 24h reward: {}'.format(
+ model_path, reward_24h))
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--model_path', type=str, help='path of the model to evaluate.')
+ args = parser.parse_args()
+
+ evaluate_one_day(args.model_path)
diff --git a/examples/LiftSim_baseline/A2C/lift_agent.py b/examples/LiftSim_baseline/A2C/lift_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c2d64d64c0f89dc939a2458d1b2e9fedbb5242a
--- /dev/null
+++ b/examples/LiftSim_baseline/A2C/lift_agent.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import parl
+import paddle.fluid as fluid
+import numpy as np
+from parl import layers
+from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
+
+
+class LiftAgent(parl.Agent):
+ def __init__(self, algorithm, config):
+ """
+ Args:
+ algorithm (`parl.Algorithm`): algorithm to be used in this agent.
+            config (dict): config describing the training hyper-parameters (see a2c_config.py)
+ """
+ self.obs_dim = config['obs_dim']
+ super(LiftAgent, self).__init__(algorithm)
+
+ self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
+ config['max_sample_steps'])
+ self.entropy_coeff_scheduler = PiecewiseScheduler(
+ config['entropy_coeff_scheduler'])
+
+ def build_program(self):
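+        # Build four separate fluid programs: stochastic sampling, greedy
+        # prediction, value estimation and learning.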
+ self.sample_program = fluid.Program()
+ self.predict_program = fluid.Program()
+ self.value_program = fluid.Program()
+ self.learn_program = fluid.Program()
+
+ with fluid.program_guard(self.sample_program):
+ obs = layers.data(
+ name='obs', shape=[self.obs_dim], dtype='float32')
+ sample_actions, values = self.alg.sample(obs)
+ self.sample_outputs = [sample_actions, values]
+
+ with fluid.program_guard(self.predict_program):
+ obs = layers.data(
+ name='obs', shape=[self.obs_dim], dtype='float32')
+ self.predict_actions = self.alg.predict(obs)
+
+ with fluid.program_guard(self.value_program):
+ obs = layers.data(
+ name='obs', shape=[self.obs_dim], dtype='float32')
+ self.values = self.alg.value(obs)
+
+ with fluid.program_guard(self.learn_program):
+ obs = layers.data(
+ name='obs', shape=[self.obs_dim], dtype='float32')
+ actions = layers.data(name='actions', shape=[], dtype='int32')
+ advantages = layers.data(
+ name='advantages', shape=[], dtype='float32')
+ target_values = layers.data(
+ name='target_values', shape=[], dtype='float32')
+ lr = layers.data(
+ name='lr', shape=[1], dtype='float32', append_batch_size=False)
+ entropy_coeff = layers.data(
+ name='entropy_coeff', shape=[], dtype='float32')
+
+ total_loss, pi_loss, vf_loss, entropy = self.alg.learn(
+ obs, actions, advantages, target_values, lr, entropy_coeff)
+ self.learn_outputs = [total_loss, pi_loss, vf_loss, entropy]
+ self.learn_program = parl.compile(self.learn_program, total_loss)
+
+ def sample(self, obs_np):
+ """
+ Args:
+ obs_np: a numpy float32 array of shape (B, obs_dim).
+
+ Returns:
+            sample_actions: a numpy int64 array of shape [B]
+ values: a numpy float32 array of shape [B]
+ """
+ obs_np = obs_np.astype('float32')
+
+ sample_actions, values = self.fluid_executor.run(
+ self.sample_program,
+ feed={'obs': obs_np},
+ fetch_list=self.sample_outputs)
+ return sample_actions, values
+
+ def predict(self, obs_np):
+ """
+ Args:
+ obs_np: a numpy float32 array of shape (B, obs_dim).
+
+ Returns:
+ predict_actions: a numpy int64 array of shape [B]
+ """
+ obs_np = obs_np.astype('float32')
+
+ predict_actions = self.fluid_executor.run(
+ self.predict_program,
+ feed={'obs': obs_np},
+ fetch_list=[self.predict_actions])[0]
+ return predict_actions
+
+ def value(self, obs_np):
+ """
+ Args:
+ obs_np: a numpy float32 array of shape (B, obs_dim).
+
+ Returns:
+ values: a numpy float32 array of shape [B]
+ """
+ obs_np = obs_np.astype('float32')
+
+ values = self.fluid_executor.run(
+ self.value_program, feed={'obs': obs_np},
+ fetch_list=[self.values])[0]
+ return values
+
+ def learn(self, obs_np, actions_np, advantages_np, target_values_np):
+ """
+ Args:
+ obs_np: a numpy float32 array of shape (B, obs_dim).
+ actions_np: a numpy int64 array of shape [B]
+ advantages_np: a numpy float32 array of shape [B]
+ target_values_np: a numpy float32 array of shape [B]
+ """
+
+ obs_np = obs_np.astype('float32')
+ actions_np = actions_np.astype('int64')
+ advantages_np = advantages_np.astype('float32')
+ target_values_np = target_values_np.astype('float32')
+
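+        # decay the learning rate linearly with the number of consumed samples
+        # and fetch the current entropy coefficient from its schedule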
+ lr = self.lr_scheduler.step(step_num=obs_np.shape[0])
+ entropy_coeff = self.entropy_coeff_scheduler.step()
+
+ total_loss, pi_loss, vf_loss, entropy = self.fluid_executor.run(
+ self.learn_program,
+ feed={
+ 'obs': obs_np,
+ 'actions': actions_np,
+ 'advantages': advantages_np,
+ 'target_values': target_values_np,
+ 'lr': np.array([lr], dtype='float32'),
+ 'entropy_coeff': np.array([entropy_coeff], dtype='float32')
+ },
+ fetch_list=self.learn_outputs)
+ return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff
diff --git a/examples/LiftSim_baseline/A2C/lift_model.py b/examples/LiftSim_baseline/A2C/lift_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da63723b5ab058ef7333c590725f90268ce52b6
--- /dev/null
+++ b/examples/LiftSim_baseline/A2C/lift_model.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import parl
+import paddle.fluid as fluid
+from parl import layers
+
+
+class LiftModel(parl.Model):
+ def __init__(self, act_dim):
+ self.act_dim = act_dim
+ self.fc_1 = layers.fc(size=512, act='relu')
+ self.fc_2 = layers.fc(size=256, act='relu')
+ self.fc_3 = layers.fc(size=128, act='tanh')
+
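+        # the three fully connected layers above are shared by the policy head
+        # and the value head defined below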
+ self.value_fc = layers.fc(size=1)
+ self.policy_fc = layers.fc(size=act_dim)
+
+ def policy(self, obs):
+ """
+ Args:
+            obs(float32 tensor): shape of (B, obs_dim)
+
+        Returns:
+            policy_logits(float32 tensor): shape of (B, act_dim)
+ """
+ h_1 = self.fc_1(obs)
+ h_2 = self.fc_2(h_1)
+ h_3 = self.fc_3(h_2)
+ policy_logits = self.policy_fc(h_3)
+ return policy_logits
+
+ def value(self, obs):
+ """
+ Args:
+            obs(float32 tensor): shape of (B, obs_dim)
+
+ Returns:
+ values(float32 tensor): shape of (B,)
+ """
+ h_1 = self.fc_1(obs)
+ h_2 = self.fc_2(h_1)
+ h_3 = self.fc_3(h_2)
+ values = self.value_fc(h_3)
+ values = layers.squeeze(values, axes=[1])
+ return values
+
+ def policy_and_value(self, obs):
+ """
+ Args:
+            obs(float32 tensor): shape of (B, obs_dim)
+
+ Returns:
+            policy_logits(float32 tensor): shape of (B, act_dim)
+ values(float32 tensor): shape of (B,)
+ """
+ h_1 = self.fc_1(obs)
+ h_2 = self.fc_2(h_1)
+ h_3 = self.fc_3(h_2)
+ policy_logits = self.policy_fc(h_3)
+ values = self.value_fc(h_3)
+ values = layers.squeeze(values, axes=[1])
+
+ return policy_logits, values
diff --git a/examples/LiftSim_baseline/A2C/performance.png b/examples/LiftSim_baseline/A2C/performance.png
new file mode 100644
index 0000000000000000000000000000000000000000..153da4eb12bd3219ed5030516bcc001188c980b2
Binary files /dev/null and b/examples/LiftSim_baseline/A2C/performance.png differ
diff --git a/examples/LiftSim_baseline/A2C/train.py b/examples/LiftSim_baseline/A2C/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3663cb841f676efcc157fccd15fa816cdaae662c
--- /dev/null
+++ b/examples/LiftSim_baseline/A2C/train.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import os
+import parl
+import queue
+import six
+import time
+import threading
+
+from actor import Actor
+from collections import defaultdict
+from env_wrapper import ObsProcessWrapper, ActionProcessWrapper
+from parl.utils import logger, get_gpu_count, tensorboard, machine_info
+from parl.utils.scheduler import PiecewiseScheduler
+from parl.utils.time_stat import TimeStat
+from parl.utils.window_stat import WindowStat
+from rlschool import LiftSim
+from lift_model import LiftModel
+from lift_agent import LiftAgent
+
+
+class Learner(object):
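+    """A2C learner: pushes the latest parameters to the remote actors, gathers
+    their sample batches and updates the model."""
+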
+ def __init__(self, config):
+ self.config = config
+
+ #=========== Create Agent ==========
+ env = LiftSim()
+ env = ActionProcessWrapper(env)
+ env = ObsProcessWrapper(env)
+
+ obs_dim = env.obs_dim
+ act_dim = env.act_dim
+ self.config['obs_dim'] = obs_dim
+
+ model = LiftModel(act_dim)
+ algorithm = parl.algorithms.A3C(
+ model, vf_loss_coeff=config['vf_loss_coeff'])
+ self.agent = LiftAgent(algorithm, config)
+
+ if machine_info.is_gpu_available():
+ assert get_gpu_count() == 1, 'Only support training in single GPU,\
+ Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` .'
+
+ #========== Learner ==========
+
+ self.entropy_stat = WindowStat(100)
+ self.target_values = None
+
+ self.learn_time_stat = TimeStat(100)
+ self.start_time = None
+
+ #========== Remote Actor ===========
+ self.remote_count = 0
+ self.sample_data_queue = queue.Queue()
+
+ self.remote_metrics_queue = queue.Queue()
+ self.sample_total_steps = 0
+
+ self.params_queues = []
+ self.create_actors()
+
+ self.log_steps = 0
+
+ def create_actors(self):
+        """ Connect to the cluster and start sampling on the remote actors.
+ """
+ parl.connect(self.config['master_address'])
+
+ logger.info('Waiting for {} remote actors to connect.'.format(
+ self.config['actor_num']))
+
+ for i in six.moves.range(self.config['actor_num']):
+ params_queue = queue.Queue()
+ self.params_queues.append(params_queue)
+
+ self.remote_count += 1
+ logger.info('Remote actor count: {}'.format(self.remote_count))
+
+ remote_thread = threading.Thread(
+ target=self.run_remote_sample, args=(params_queue, ))
+ remote_thread.setDaemon(True)
+ remote_thread.start()
+
+ self.start_time = time.time()
+
+ def run_remote_sample(self, params_queue):
+ """ Sample data from remote actor and update parameters of remote actor.
+ """
+ remote_actor = Actor(self.config)
+
+ cnt = 0
+ while True:
+ latest_params = params_queue.get()
+ remote_actor.set_weights(latest_params)
+ batch = remote_actor.sample()
+
+ self.sample_data_queue.put(batch)
+
+ cnt += 1
+ if cnt % self.config['get_remote_metrics_interval'] == 0:
+ metrics = remote_actor.get_metrics()
+ if metrics:
+ self.remote_metrics_queue.put(metrics)
+
+ def step(self):
+ """
+ 1. kick off all actors to synchronize parameters and sample data;
+ 2. collect sample data of all actors;
+ 3. update parameters.
+ """
+ latest_params = self.agent.get_weights()
+ for params_queue in self.params_queues:
+ params_queue.put(latest_params)
+
+ train_batch = defaultdict(list)
+ for i in range(self.config['actor_num']):
+ sample_data = self.sample_data_queue.get()
+ for key, value in sample_data.items():
+ train_batch[key].append(value)
+
+ self.sample_total_steps += sample_data['obs'].shape[0]
+
+ for key, value in train_batch.items():
+ train_batch[key] = np.concatenate(value)
+
+ with self.learn_time_stat:
+ total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = self.agent.learn(
+ obs_np=train_batch['obs'],
+ actions_np=train_batch['actions'],
+ advantages_np=train_batch['advantages'],
+ target_values_np=train_batch['target_values'])
+
+ self.entropy_stat.add(entropy)
+ self.target_values = np.mean(train_batch['target_values'])
+
+ tensorboard.add_scalar('model/entropy', entropy,
+ self.sample_total_steps)
+ tensorboard.add_scalar('model/q_value', self.target_values,
+ self.sample_total_steps)
+
+ def log_metrics(self):
+ """ Log metrics of learner and actors
+ """
+ if self.start_time is None:
+ return
+
+ metrics = []
+ while True:
+ try:
+ metric = self.remote_metrics_queue.get_nowait()
+ metrics.append(metric)
+ except queue.Empty:
+ break
+
+ env_reward_1h, env_reward_24h = [], []
+ for x in metrics:
+ env_reward_1h.extend(x['env_reward_1h'])
+ env_reward_24h.extend(x['env_reward_24h'])
+ env_reward_1h = [x for x in env_reward_1h if x is not None]
+ env_reward_24h = [x for x in env_reward_24h if x is not None]
+
+ mean_reward_1h, mean_reward_24h = None, None
+ if env_reward_1h:
+ mean_reward_1h = np.mean(np.array(env_reward_1h).flatten())
+ tensorboard.add_scalar('performance/env_rewards_1h',
+ mean_reward_1h, self.sample_total_steps)
+ if env_reward_24h:
+ mean_reward_24h = np.mean(np.array(env_reward_24h).flatten())
+ tensorboard.add_scalar('performance/env_rewards_24h',
+ mean_reward_24h, self.sample_total_steps)
+
+ metric = {
+ 'Sample steps': self.sample_total_steps,
+ 'env_reward_1h': mean_reward_1h,
+ 'env_reward_24h': mean_reward_24h,
+ 'target_values': self.target_values,
+ 'entropy': self.entropy_stat.mean,
+ 'learn_time_s': self.learn_time_stat.mean,
+ 'elapsed_time_s': int(time.time() - self.start_time),
+ }
+ logger.info(metric)
+
+ self.log_steps += 1
+ save_interval_step = 7200 // max(1,
+ self.config['log_metrics_interval_s'])
+ if self.log_steps % save_interval_step == 0:
+ self.save_model() # save model every 2h
+
+ def should_stop(self):
+ return self.sample_total_steps >= self.config['max_sample_steps']
+
+ def save_model(self):
+ time_str = time.strftime(".%Y%m%d_%H%M%S", time.localtime())
+ self.agent.save(os.path.join('saved_models', 'model.ckpt' + time_str))
+
+
+if __name__ == '__main__':
+ from a2c_config import config
+
+ learner = Learner(config)
+ assert config['log_metrics_interval_s'] > 0
+
+ while not learner.should_stop():
+ start = time.time()
+ while time.time() - start < config['log_metrics_interval_s']:
+ learner.step()
+ learner.log_metrics()
diff --git a/examples/LiftSim_baseline/A2C/utils.py b/examples/LiftSim_baseline/A2C/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c081eaa39e4a2366ebf986092c4a9ec11a2c2f
--- /dev/null
+++ b/examples/LiftSim_baseline/A2C/utils.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def discretize(value, n_dim, min_val, max_val):
+ '''
+    Discretize a value into an n_dim-dimensional one-hot vector,
+ with the value below min_val being [1, 0, 0, ..., 0]
+ and the value above max_val being [0, 0, ..., 0, 1]
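+    e.g. discretize(2, 3, 1, 3) returns [0, 1.0, 0]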
+ '''
+ assert n_dim > 0
+ if (n_dim == 1):
+ return [1]
+ delta = (max_val - min_val) / float(n_dim - 1)
+ active_pos = int((value - min_val) / delta + 0.5)
+ active_pos = min(n_dim - 1, active_pos)
+ active_pos = max(0, active_pos)
+ ret_array = [0 for i in range(n_dim)]
+ ret_array[active_pos] = 1.0
+ return ret_array
+
+
+def linear_discretize(value, n_dim, min_val, max_val):
+ '''
+    Discretize a value into an n_dim-dimensional vector by linear interpolation between the two nearest anchors,
+ with the value below min_val being [1, 0, 0, ..., 0]
+ and the value above max_val being [0, 0, ..., 0, 1]
+ e.g. if n_dim = 2, min_val = 1.0, max_val = 2.0
+ if value = 1.5 returns [0.5, 0.5], if value = 1.8 returns [0.2, 0.8]
+ '''
+ assert n_dim > 0
+ if (n_dim == 1):
+ return [1]
+ delta = (max_val - min_val) / float(n_dim - 1)
+ active_pos = int((value - min_val) / delta + 0.5)
+ active_pos = min(n_dim - 2, active_pos)
+ active_pos = max(0, active_pos)
+ anchor_pt = active_pos * delta + min_val
+ if (anchor_pt > value and anchor_pt > min_val + 0.5 * delta):
+ anchor_pt -= delta
+ active_pos -= 1
+ weight = (value - anchor_pt) / delta
+ weight = min(1.0, max(0.0, weight))
+ ret_array = [0 for i in range(n_dim)]
+ ret_array[active_pos] = 1.0 - weight
+ ret_array[active_pos + 1] = weight
+ return ret_array
diff --git a/examples/LiftSim_baseline/README.md b/examples/LiftSim_baseline/DQN/README.md
similarity index 78%
rename from examples/LiftSim_baseline/README.md
rename to examples/LiftSim_baseline/DQN/README.md
index bfc903402d2665fb00e518ae1df77a1b8c88dae5..d75b549f8f3b1c433a915f025ab676115749ddd3 100644
--- a/examples/LiftSim_baseline/README.md
+++ b/examples/LiftSim_baseline/DQN/README.md
@@ -6,11 +6,9 @@
## 依赖库
-- paddlepaddle >= 1.5.1
-- parl >= 1.1.2
-- rlschool >= 0.0.1
-
-Windows版本仅支持Python3.5及以上版本。
++ [paddlepaddle==1.5.1](https://github.com/PaddlePaddle/Paddle)
++ [parl==1.1.2](https://github.com/PaddlePaddle/PARL)
++ [rlschool>=0.0.1](https://github.com/PaddlePaddle/RLSchool)
## 运行
diff --git a/examples/LiftSim_baseline/__init__.py b/examples/LiftSim_baseline/DQN/__init__.py
similarity index 100%
rename from examples/LiftSim_baseline/__init__.py
rename to examples/LiftSim_baseline/DQN/__init__.py
diff --git a/examples/LiftSim_baseline/demo.py b/examples/LiftSim_baseline/DQN/demo.py
similarity index 100%
rename from examples/LiftSim_baseline/demo.py
rename to examples/LiftSim_baseline/DQN/demo.py
diff --git a/examples/LiftSim_baseline/rl_10.png b/examples/LiftSim_baseline/DQN/rl_10.png
similarity index 100%
rename from examples/LiftSim_baseline/rl_10.png
rename to examples/LiftSim_baseline/DQN/rl_10.png
diff --git a/examples/LiftSim_baseline/rl_benchmark/__init__.py b/examples/LiftSim_baseline/DQN/rl_benchmark/__init__.py
similarity index 100%
rename from examples/LiftSim_baseline/rl_benchmark/__init__.py
rename to examples/LiftSim_baseline/DQN/rl_benchmark/__init__.py
diff --git a/examples/LiftSim_baseline/rl_benchmark/agent.py b/examples/LiftSim_baseline/DQN/rl_benchmark/agent.py
similarity index 100%
rename from examples/LiftSim_baseline/rl_benchmark/agent.py
rename to examples/LiftSim_baseline/DQN/rl_benchmark/agent.py
diff --git a/examples/LiftSim_baseline/rl_benchmark/dispatcher.py b/examples/LiftSim_baseline/DQN/rl_benchmark/dispatcher.py
similarity index 100%
rename from examples/LiftSim_baseline/rl_benchmark/dispatcher.py
rename to examples/LiftSim_baseline/DQN/rl_benchmark/dispatcher.py
diff --git a/examples/LiftSim_baseline/rl_benchmark/model.py b/examples/LiftSim_baseline/DQN/rl_benchmark/model.py
similarity index 100%
rename from examples/LiftSim_baseline/rl_benchmark/model.py
rename to examples/LiftSim_baseline/DQN/rl_benchmark/model.py
diff --git a/examples/LiftSim_baseline/wrapper.py b/examples/LiftSim_baseline/DQN/wrapper.py
similarity index 100%
rename from examples/LiftSim_baseline/wrapper.py
rename to examples/LiftSim_baseline/DQN/wrapper.py
diff --git a/examples/LiftSim_baseline/wrapper_utils.py b/examples/LiftSim_baseline/DQN/wrapper_utils.py
similarity index 100%
rename from examples/LiftSim_baseline/wrapper_utils.py
rename to examples/LiftSim_baseline/DQN/wrapper_utils.py