diff --git a/.scripts/update_readme_paddle_version.py b/.scripts/update_readme_paddle_version.py index 56d56914c65956a2bb753bc58269d59034766b1c..901d2d672d9f3eff1021241ac80b6e9f75d0886a 100644 --- a/.scripts/update_readme_paddle_version.py +++ b/.scripts/update_readme_paddle_version.py @@ -37,7 +37,8 @@ if __name__ == '__main__': exclude_examples = [ 'NeurIPS2019-Learn-to-Move-Challenge', - 'NeurIPS2018-AI-for-Prosthetics-Challenge', 'EagerMode' + 'NeurIPS2018-AI-for-Prosthetics-Challenge', 'LiftSim_baseline', + 'EagerMode' ] for example in os.listdir('../examples/'): if example not in exclude_examples: diff --git a/examples/A2C/README.md b/examples/A2C/README.md index d38a5d153b3ab39c59b851775567d90edcdda4fb..2328a3ee350851370c5c44bbed6b4e2daae27512 100755 --- a/examples/A2C/README.md +++ b/examples/A2C/README.md @@ -20,7 +20,7 @@ Performance of A2C on various envrionments ## How to use ### Dependencies + [paddlepaddle>=1.6.1](https://github.com/PaddlePaddle/Paddle) -+ [parl](https://github.com/PaddlePaddle/PARL) ++ [parl>=1.2.1](https://github.com/PaddlePaddle/PARL) + gym==0.12.1 + atari-py==0.1.7 diff --git a/examples/A2C/train.py b/examples/A2C/train.py index 777a22849afcb3ad5b1e237d7e3d0ae9b39fa871..1daba4736f8e19e9f50d425c12868a43ee0933db 100755 --- a/examples/A2C/train.py +++ b/examples/A2C/train.py @@ -55,11 +55,6 @@ class Learner(object): assert get_gpu_count() == 1, 'Only support training in single GPU,\ Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` .' - else: - cpu_num = os.environ.get('CPU_NUM') - assert cpu_num is not None and cpu_num == '1', 'Only support training in single CPU,\ - Please set environment variable: `export CPU_NUM=1`.' - #========== Learner ========== self.total_loss_stat = WindowStat(100) diff --git a/examples/LiftSim_baseline/A2C/README.md b/examples/LiftSim_baseline/A2C/README.md new file mode 100644 index 0000000000000000000000000000000000000000..69019aaa22393194e110b72072d16e97d6d6dd3b --- /dev/null +++ b/examples/LiftSim_baseline/A2C/README.md @@ -0,0 +1,47 @@ +# LiftSim Baseline + +## Introduction + +An A2C implementation based on the PARL library, applied to [LiftSim][liftsim], the elevator-dispatching simulation environment in the [RLSchool][rlschool] library. + +## Dependencies + ++ [paddlepaddle>=1.6.1](https://github.com/PaddlePaddle/Paddle) ++ [parl>=1.2.1](https://github.com/PaddlePaddle/PARL) ++ [rlschool>=0.1.1][rlschool] + + +## Distributed Training + +First, start a local cluster with 5 CPUs: + +```bash +xparl start --port 8010 --cpu_num 5 +``` + +> Note: if you have already started a cluster, there is no need to run the command above again. For more information about PARL clusters, please refer to the [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html). + +Then start distributed training by running: + +```bash +python train.py +``` + + +## Evaluation +A saved model can be evaluated with the following command: +```bash +python evaluate.py --model_path saved_models/[FILENAME] +``` + +Tensorboard and log files are saved to `./train_log/train/`; run `tensorboard --logdir .` there to open the tensorboard dashboard. + +## Convergence +After roughly 30 hours of training, the evaluation metric reaches about -120 (the reward of running the LiftSim environment for one day). + + +## Visualization + + +[rlschool]: https://github.com/PaddlePaddle/RLSchool +[liftsim]: https://github.com/PaddlePaddle/RLSchool/tree/master/rlschool/liftsim diff --git a/examples/LiftSim_baseline/A2C/a2c_config.py b/examples/LiftSim_baseline/A2C/a2c_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e776b2f1033bc6bd911ed929437f91c44bf938 --- /dev/null +++ b/examples/LiftSim_baseline/A2C/a2c_config.py @@ -0,0 +1,37 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +config = { + #========== remote config ========== + 'master_address': 'localhost:8010', + + #========== actor config ========== + 'actor_num': 5, + 'env_num': 5, + 'sample_batch_steps': 5, + + #========== learner config ========== + 'max_sample_steps': int(1e10), + 'gamma': 0.998, + 'lambda': 1.0, # GAE + + # start learning rate + 'start_lr': 1.0e-4, + + # coefficient of policy entropy adjustment schedule: (train_step, coefficient) + 'entropy_coeff_scheduler': [(0, -2.0e-4)], + 'vf_loss_coeff': 0.5, + 'get_remote_metrics_interval': 100, + 'log_metrics_interval_s': 60, +} diff --git a/examples/LiftSim_baseline/A2C/actor.py b/examples/LiftSim_baseline/A2C/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..4286e7f9c4dd64b444c2ca34d5ebffe870b2769f --- /dev/null +++ b/examples/LiftSim_baseline/A2C/actor.py @@ -0,0 +1,137 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
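The two schedule entries in `a2c_config.py` above are consumed by `LinearDecayScheduler` and `PiecewiseScheduler` in `lift_agent.py`. As a rough mental model only — a minimal pure-Python sketch of the assumed semantics, not PARL's actual implementation — the learning rate decays linearly with the number of consumed samples, and the entropy coefficient takes the value of the last `(train_step, coefficient)` pair already reached:

```python
# Minimal sketch of the assumed scheduler semantics (not PARL's implementation).

def linear_decay(start_lr, max_sample_steps, consumed_samples):
    """Learning rate after `consumed_samples` samples, decayed linearly towards 0."""
    frac = min(consumed_samples, max_sample_steps) / float(max_sample_steps)
    return start_lr * (1.0 - frac)


def piecewise_value(schedule, train_step):
    """Value of the last (step, value) pair whose step has been reached."""
    value = schedule[0][1]
    for step, v in schedule:
        if train_step >= step:
            value = v
    return value


print(linear_decay(1.0e-4, int(1e10), int(5e9)))   # halfway through training -> 5e-05
print(piecewise_value([(0, -2.0e-4)], 10000))      # single entry -> constant -0.0002
```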
+ +import numpy as np +import parl +from collections import defaultdict +from env_wrapper import ObsProcessWrapper, ActionProcessWrapper, RewardWrapper, MetricsWrapper +from parl.utils.rl_utils import calc_gae +from parl.env.vector_env import VectorEnv +from rlschool import LiftSim +from copy import deepcopy +from lift_model import LiftModel +from lift_agent import LiftAgent + + +@parl.remote_class +class Actor(object): + def __init__(self, config): + self.config = config + self.env_num = config['env_num'] + + self.envs = [] + for _ in range(self.env_num): + env = LiftSim() + env = RewardWrapper(env) + env = ActionProcessWrapper(env) + env = ObsProcessWrapper(env) + env = MetricsWrapper(env) + self.envs.append(env) + self.vector_env = VectorEnv(self.envs) + + # number of elevators + self.ele_num = self.envs[0].mansion_attr.ElevatorNumber + + act_dim = self.envs[0].act_dim + self.obs_dim = self.envs[0].obs_dim + self.config['obs_dim'] = self.obs_dim + + # nested list of shape (env_num, ele_num, obs_dim) + self.obs_batch = self.vector_env.reset() + # (env_num * ele_num, obs_dim) + self.obs_batch = np.array(self.obs_batch).reshape( + [self.env_num * self.ele_num, self.obs_dim]) + + model = LiftModel(act_dim) + algorithm = parl.algorithms.A3C( + model, vf_loss_coeff=config['vf_loss_coeff']) + self.agent = LiftAgent(algorithm, config) + + def sample(self): + sample_data = defaultdict(list) + + env_sample_data = {} + # treat each elevator in Liftsim as an independent env + for env_id in range(self.env_num * self.ele_num): + env_sample_data[env_id] = defaultdict(list) + + for i in range(self.config['sample_batch_steps']): + actions_batch, values_batch = self.agent.sample(self.obs_batch) + + vector_actions = np.array_split(actions_batch, self.env_num) + assert len(vector_actions[-1]) == self.ele_num + next_obs_batch, reward_batch, done_batch, info_batch = \ + self.vector_env.step(vector_actions) + + # (env_num, ele_num, obs_dim) -> (env_num * ele_num, obs_dim) + next_obs_batch = np.array(next_obs_batch).reshape( + [self.env_num * self.ele_num, self.obs_dim]) + # repeat reward and done to ele_num times + # (env_num) -> (env_num, ele_num) -> (env_num * ele_num) + reward_batch = np.repeat(reward_batch, self.ele_num) + done_batch = np.repeat(done_batch, self.ele_num) + + for env_id in range(self.env_num * self.ele_num): + env_sample_data[env_id]['obs'].append(self.obs_batch[env_id]) + env_sample_data[env_id]['actions'].append( + actions_batch[env_id]) + env_sample_data[env_id]['rewards'].append(reward_batch[env_id]) + env_sample_data[env_id]['dones'].append(done_batch[env_id]) + env_sample_data[env_id]['values'].append(values_batch[env_id]) + + # Calculate advantages when the episode is done or reaches max sample steps. 
+ if done_batch[env_id] or i + 1 == self.config[ + 'sample_batch_steps']: # reach max sample steps + next_value = 0 + if not done_batch[env_id]: + next_obs = np.expand_dims(next_obs_batch[env_id], 0) + next_value = self.agent.value(next_obs) + + values = env_sample_data[env_id]['values'] + rewards = env_sample_data[env_id]['rewards'] + advantages = calc_gae(rewards, values, next_value, + self.config['gamma'], + self.config['lambda']) + target_values = advantages + values + + sample_data['obs'].extend(env_sample_data[env_id]['obs']) + sample_data['actions'].extend( + env_sample_data[env_id]['actions']) + sample_data['advantages'].extend(advantages) + sample_data['target_values'].extend(target_values) + + env_sample_data[env_id] = defaultdict(list) + + self.obs_batch = deepcopy(next_obs_batch) + + # size of sample_data[key]: env_num * ele_num * sample_batch_steps + for key in sample_data: + sample_data[key] = np.stack(sample_data[key]) + + return sample_data + + def get_metrics(self): + metrics = defaultdict(list) + for metrics_env in self.envs: + assert isinstance( + metrics_env, + MetricsWrapper), "Put the MetricsWrapper in the last wrapper" + for env_reward_1h, env_reward_24h in metrics_env.next_episode_results( + ): + metrics['env_reward_1h'].append(env_reward_1h) + metrics['env_reward_24h'].append(env_reward_24h) + return metrics + + def set_weights(self, params): + self.agent.set_weights(params) diff --git a/examples/LiftSim_baseline/A2C/effect.gif b/examples/LiftSim_baseline/A2C/effect.gif new file mode 100644 index 0000000000000000000000000000000000000000..b796acbf3c872f0e960e731fb99848c774446165 Binary files /dev/null and b/examples/LiftSim_baseline/A2C/effect.gif differ diff --git a/examples/LiftSim_baseline/A2C/env_wrapper.py b/examples/LiftSim_baseline/A2C/env_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d08fac40d1cc6fff6f46e97e2bac649a79bc1b --- /dev/null +++ b/examples/LiftSim_baseline/A2C/env_wrapper.py @@ -0,0 +1,301 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
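The `Actor` above depends on a specific wrapper order, defined by `env_wrapper.py` below: `RewardWrapper` must wrap the raw env first so that `info['origin_reward']` is available to `MetricsWrapper`, and `MetricsWrapper` must be outermost (the `isinstance` check in `Actor.get_metrics` assumes this). A usage sketch, assuming `rlschool` is installed and this example's modules are on the import path:

```python
# Sketch of the wrapper stack used in actor.py (same order, nothing new).
from rlschool import LiftSim
from env_wrapper import (RewardWrapper, ActionProcessWrapper,
                         ObsProcessWrapper, MetricsWrapper)

env = LiftSim()
env = RewardWrapper(env)         # shapes the reward, keeps info['origin_reward']
env = ActionProcessWrapper(env)  # one action id per elevator -> (DispatchTarget, Direction)
env = ObsProcessWrapper(env)     # mansion state -> one feature row per elevator
env = MetricsWrapper(env)        # accumulates 1h / 24h env rewards for logging

obs = env.reset()                # shape: (ElevatorNumber, obs_dim)
print(env.obs_dim, env.act_dim)
```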
+ +from copy import deepcopy +import numpy as np +from utils import discretize, linear_discretize +from rlschool import LiftSim + + +class BaseWrapper(object): + def __init__(self, env): + self.env = env + self._mansion = env._mansion + self.mansion_attr = self._mansion.attribute + + @property + def obs_dim(self): + if hasattr(self.env, 'obs_dim'): + return self.env.obs_dim + else: + return None + + @property + def act_dim(self): + if hasattr(self.env, 'act_dim'): + return self.env.act_dim + else: + return None + + def seed(self, seed=None): + return self.env.seed(seed) + + def step(self, action): + return self.env.step(action) + + def reset(self): + return self.env.reset() + + def render(self): + return self.env.render() + + def close(self): + return self.env.close() + + +class ObsProcessWrapper(BaseWrapper): + """Extract features of each elevator in LiftSim env + """ + + def __init__(self, env, hour_distize_num=6): + super(ObsProcessWrapper, self).__init__(env) + self.hour_distize_num = hour_distize_num + self.total_steps = 0 + + @property + def obs_dim(self): + """ + NOTE: + Keep obs_dim to the return size of function `_mansion_state_process` + """ + ele_dim = self.mansion_attr.NumberOfFloor * 3 + 34 + obs_dim = (ele_dim + 1) * self.mansion_attr.ElevatorNumber + \ + self.mansion_attr.NumberOfFloor * 2 + obs_dim += self.hour_distize_num + return obs_dim + + def reset(self): + """ + + Returns: + obs(list): [[self.obs_dim]] * mansion_attr.ElevatorNumber, features array of all elevators + """ + obs = self.env.reset() + self.total_steps = 0 + obs = self._mansion_state_process(obs) + return obs + + def step(self, action): + """ + Returns: + obs(list): nested list, shape of [mansion_attr.ElevatorNumber, self.obs_dim], + features array of all elevators + reward(int): returned by self.env + done(bool): returned by self.env + info(dict): returned by self.env + """ + obs, reward, done, info = self.env.step(action) + self.total_steps += 1 + obs = self._mansion_state_process(obs) + return obs, reward, done, info + + def _mansion_state_process(self, mansion_state): + """Extract features of env + """ + ele_features = list() + for ele_state in mansion_state.ElevatorStates: + ele_features.append(self._ele_state_process(ele_state)) + max_floor = ele_state.MaximumFloor + + target_floor_binaries_up = [0.0 for i in range(max_floor)] + target_floor_binaries_down = [0.0 for i in range(max_floor)] + for floor in mansion_state.RequiringUpwardFloors: + target_floor_binaries_up[floor - 1] = 1.0 + for floor in mansion_state.RequiringDownwardFloors: + target_floor_binaries_down[floor - 1] = 1.0 + target_floor_binaries = target_floor_binaries_up + target_floor_binaries_down + + raw_time = self.total_steps * 0.5 # timestep seconds + time_id = int(raw_time % 86400) + time_id = time_id // (24 / self.hour_distize_num * 3600) + time_id_vec = discretize(time_id + 1, self.hour_distize_num, 1, + self.hour_distize_num) + + man_features = list() + for idx in range(len(mansion_state.ElevatorStates)): + elevator_id_vec = discretize(idx + 1, + len(mansion_state.ElevatorStates), 1, + len(mansion_state.ElevatorStates)) + idx_array = list(range(len(mansion_state.ElevatorStates))) + idx_array.remove(idx) + man_features.append(ele_features[idx]) + for left_idx in idx_array: + man_features[idx] = man_features[idx] + ele_features[left_idx] + man_features[idx] = man_features[idx] + \ + elevator_id_vec + target_floor_binaries + man_features[idx] = man_features[idx] + time_id_vec + return np.asarray(man_features, dtype='float32') + + def 
_ele_state_process(self, ele_state): + """Extract features of an elevator + """ + ele_feature = [] + + # add floor information + ele_feature.extend( + linear_discretize(ele_state.Floor, ele_state.MaximumFloor, 1.0, + ele_state.MaximumFloor)) + + # add velocity information + ele_feature.extend( + linear_discretize(ele_state.Velocity, 21, -ele_state.MaximumSpeed, + ele_state.MaximumSpeed)) + + # add door information + ele_feature.append(ele_state.DoorState) + ele_feature.append(float(ele_state.DoorIsOpening)) + ele_feature.append(float(ele_state.DoorIsClosing)) + + # add direction information + ele_feature.extend(discretize(ele_state.Direction, 3, -1, 1)) + + # add load weight information + ele_feature.extend( + linear_discretize(ele_state.LoadWeight / ele_state.MaximumLoad, 5, + 0.0, 1.0)) + + # add other information + target_floor_binaries = [0.0 for i in range(ele_state.MaximumFloor)] + for target_floor in ele_state.ReservedTargetFloors: + target_floor_binaries[target_floor - 1] = 1.0 + ele_feature.extend(target_floor_binaries) + + dispatch_floor_binaries = [ + 0.0 for i in range(ele_state.MaximumFloor + 1) + ] + dispatch_floor_binaries[ele_state.CurrentDispatchTarget] = 1.0 + ele_feature.extend(dispatch_floor_binaries) + ele_feature.append(ele_state.DispatchTargetDirection) + + return ele_feature + + +class ActionProcessWrapper(BaseWrapper): + def __init__(self, env): + """Map the action id predicted by the model to a LiftSim action + + """ + super(ActionProcessWrapper, self).__init__(env) + + @property + def act_dim(self): + """ + NOTE: + keep act_dim in line with function `_action_idx_to_action` + + Returns: + int: NumberOfFloor * (2 directions) + (-1 DispatchTarget) + (0 DispatchTarget) + """ + return self.mansion_attr.NumberOfFloor * 2 + 2 + + def step(self, action): + """ + Args: + action(list): action_id of all elevators (length = mansion_attr.ElevatorNumber) + """ + ele_actions = [] + for action_id in action: + ele_actions.extend(self._action_idx_to_action(action_id)) + + # ele_actions: list, formatted action for the LiftSim env (length = 2 * mansion_attr.ElevatorNumber) + return self.env.step(ele_actions) + + def _action_idx_to_action(self, action_idx): + action_idx = int(action_idx) + realdim = self.act_dim - 2 + if (action_idx == realdim): + return (0, 1) # mapped to DispatchTarget=0 + elif (action_idx == realdim + 1): + return (-1, 1) # mapped to DispatchTarget=-1 + action = action_idx + if (action_idx < realdim / 2): + direction = 1 # up direction + action += 1 + else: + direction = -1 # down direction + action -= int(realdim / 2) + action += 1 + return (action, direction) + + +class RewardWrapper(BaseWrapper): + def __init__(self, env): + """Design reward of LiftSim env. + """ + super(RewardWrapper, self).__init__(env) + self.ele_num = self.mansion_attr.ElevatorNumber + + def step(self, action): + """Here we return the same reward for each elevator; + you can also design a different reward for each elevator.
+ + Returns: + obs: returned by self.env + reward: shaping reward + done: returned by self.env + info: returned by self.env + """ + obs, origin_reward, done, info = self.env.step(action) + + reward = -(30 * info['time_consume'] + 0.01 * info['energy_consume'] + + 100 * info['given_up_persons']) * 1.0e-3 / self.ele_num + + info['origin_reward'] = origin_reward + + return obs, reward, done, info + + +class MetricsWrapper(BaseWrapper): + def __init__(self, env): + super(MetricsWrapper, self).__init__(env) + + self._total_steps = 0 + self._env_reward_1h = 0 + self._env_reward_24h = 0 + + self._num_returned = 0 + self._episode_result = [] + + def reset(self): + self._total_steps = 0 + self._env_reward_1h = 0 + self._env_reward_24h = 0 + return self.env.reset() + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self._total_steps += 1 + + self._env_reward_1h += info['origin_reward'] + self._env_reward_24h += info['origin_reward'] + + # Treat 1h in LiftSim env as an episode (1step = 0.5s) + if self._total_steps % (3600 * 2) == 0: # 1h + episode_env_reward_1h = self._env_reward_1h + self._env_reward_1h = 0 + + episode_env_reward_24h = None + if self._total_steps % (24 * 3600 * 2) == 0: # 24h + episode_env_reward_24h = self._env_reward_24h + self._env_reward_24h = 0 + + self._episode_result.append( + [episode_env_reward_1h, episode_env_reward_24h]) + + return obs, reward, done, info + + def next_episode_results(self): + for i in range(self._num_returned, len(self._episode_result)): + yield self._episode_result[i] + self._num_returned = len(self._episode_result) diff --git a/examples/LiftSim_baseline/A2C/evaluate.py b/examples/LiftSim_baseline/A2C/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..d004aa4255e7f54d123e2f11de80158e60319e1c --- /dev/null +++ b/examples/LiftSim_baseline/A2C/evaluate.py @@ -0,0 +1,61 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
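To make the shaped reward in `RewardWrapper.step` above concrete, here is a small worked example; the coefficients are the ones from the wrapper, while the `info` values and the elevator count are made up purely for illustration:

```python
# Worked example of the reward shaping in RewardWrapper.step.
# The info values and ele_num below are illustrative, not taken from a real run.
ele_num = 4
info = {'time_consume': 2.0, 'energy_consume': 100.0, 'given_up_persons': 1}

reward = -(30 * info['time_consume'] + 0.01 * info['energy_consume'] +
           100 * info['given_up_persons']) * 1.0e-3 / ele_num
print(reward)  # about -0.04: waiting time, energy and abandoned passengers are all penalized
```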
+ +import numpy as np +import parl +from parl.utils import logger +from env_wrapper import ObsProcessWrapper, ActionProcessWrapper, RewardWrapper +from rlschool import LiftSim +from lift_model import LiftModel +from lift_agent import LiftAgent +from a2c_config import config + + +def evaluate_one_day(model_path): + env = LiftSim() + env = ActionProcessWrapper(env) + env = ObsProcessWrapper(env) + act_dim = env.act_dim + obs_dim = env.obs_dim + config['obs_dim'] = obs_dim + + model = LiftModel(act_dim) + algorithm = parl.algorithms.A3C( + model, vf_loss_coeff=config['vf_loss_coeff']) + agent = LiftAgent(algorithm, config) + agent.restore(model_path) + + reward_24h = 0 + obs = env.reset() + for i in range(24 * 3600 * 2): # 24h, 1step = 0.5s + action, _ = agent.sample(obs) + #print(action) + obs, reward, done, info = env.step(action) + reward_24h += reward + if (i + 1) % (3600 * 2) == 0: + logger.info('hour {}, total_reward: {}'.format( + (i + 1) // (3600 * 2), reward_24h)) + + logger.info('model_path: {}, 24h reward: {}'.format( + model_path, reward_24h)) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + '--model_path', type=str, help='path of the model to evaluate.') + args = parser.parse_args() + + evaluate_one_day(args.model_path) diff --git a/examples/LiftSim_baseline/A2C/lift_agent.py b/examples/LiftSim_baseline/A2C/lift_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..9c2d64d64c0f89dc939a2458d1b2e9fedbb5242a --- /dev/null +++ b/examples/LiftSim_baseline/A2C/lift_agent.py @@ -0,0 +1,153 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import paddle.fluid as fluid +import numpy as np +from parl import layers +from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler + + +class LiftAgent(parl.Agent): + def __init__(self, algorithm, config): + """ + Args: + algorithm (`parl.Algorithm`): algorithm to be used in this agent. 
+ config (dict): config describing the training hyper-parameters(see a2c_config.py) + """ + self.obs_dim = config['obs_dim'] + super(LiftAgent, self).__init__(algorithm) + + self.lr_scheduler = LinearDecayScheduler(config['start_lr'], + config['max_sample_steps']) + self.entropy_coeff_scheduler = PiecewiseScheduler( + config['entropy_coeff_scheduler']) + + def build_program(self): + self.sample_program = fluid.Program() + self.predict_program = fluid.Program() + self.value_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.sample_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + sample_actions, values = self.alg.sample(obs) + self.sample_outputs = [sample_actions, values] + + with fluid.program_guard(self.predict_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.predict_actions = self.alg.predict(obs) + + with fluid.program_guard(self.value_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.values = self.alg.value(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + actions = layers.data(name='actions', shape=[], dtype='int32') + advantages = layers.data( + name='advantages', shape=[], dtype='float32') + target_values = layers.data( + name='target_values', shape=[], dtype='float32') + lr = layers.data( + name='lr', shape=[1], dtype='float32', append_batch_size=False) + entropy_coeff = layers.data( + name='entropy_coeff', shape=[], dtype='float32') + + total_loss, pi_loss, vf_loss, entropy = self.alg.learn( + obs, actions, advantages, target_values, lr, entropy_coeff) + self.learn_outputs = [total_loss, pi_loss, vf_loss, entropy] + self.learn_program = parl.compile(self.learn_program, total_loss) + + def sample(self, obs_np): + """ + Args: + obs_np: a numpy float32 array of shape (B, obs_dim). + + Returns: + sample_ids: a numpy int64 array of shape [B] + values: a numpy float32 array of shape [B] + """ + obs_np = obs_np.astype('float32') + + sample_actions, values = self.fluid_executor.run( + self.sample_program, + feed={'obs': obs_np}, + fetch_list=self.sample_outputs) + return sample_actions, values + + def predict(self, obs_np): + """ + Args: + obs_np: a numpy float32 array of shape (B, obs_dim). + + Returns: + predict_actions: a numpy int64 array of shape [B] + """ + obs_np = obs_np.astype('float32') + + predict_actions = self.fluid_executor.run( + self.predict_program, + feed={'obs': obs_np}, + fetch_list=[self.predict_actions])[0] + return predict_actions + + def value(self, obs_np): + """ + Args: + obs_np: a numpy float32 array of shape (B, obs_dim). + + Returns: + values: a numpy float32 array of shape [B] + """ + obs_np = obs_np.astype('float32') + + values = self.fluid_executor.run( + self.value_program, feed={'obs': obs_np}, + fetch_list=[self.values])[0] + return values + + def learn(self, obs_np, actions_np, advantages_np, target_values_np): + """ + Args: + obs_np: a numpy float32 array of shape (B, obs_dim). 
+ actions_np: a numpy int64 array of shape [B] + advantages_np: a numpy float32 array of shape [B] + target_values_np: a numpy float32 array of shape [B] + """ + + obs_np = obs_np.astype('float32') + actions_np = actions_np.astype('int64') + advantages_np = advantages_np.astype('float32') + target_values_np = target_values_np.astype('float32') + + lr = self.lr_scheduler.step(step_num=obs_np.shape[0]) + entropy_coeff = self.entropy_coeff_scheduler.step() + + total_loss, pi_loss, vf_loss, entropy = self.fluid_executor.run( + self.learn_program, + feed={ + 'obs': obs_np, + 'actions': actions_np, + 'advantages': advantages_np, + 'target_values': target_values_np, + 'lr': np.array([lr], dtype='float32'), + 'entropy_coeff': np.array([entropy_coeff], dtype='float32') + }, + fetch_list=self.learn_outputs) + return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff diff --git a/examples/LiftSim_baseline/A2C/lift_model.py b/examples/LiftSim_baseline/A2C/lift_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6da63723b5ab058ef7333c590725f90268ce52b6 --- /dev/null +++ b/examples/LiftSim_baseline/A2C/lift_model.py @@ -0,0 +1,75 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
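The docstrings of `LiftAgent.learn` above pin down the batch layout the learner feeds in, and `Actor.sample` produces exactly these keys, with `target_values = advantages + values`. A dtype/shape sketch with dummy NumPy data (the sizes are placeholders and no Paddle program is run here):

```python
import numpy as np

B, obs_dim = 8, 128                                 # placeholder sizes for illustration
values = np.random.randn(B).astype('float32')       # critic estimates gathered while sampling
advantages = np.random.randn(B).astype('float32')   # GAE, computed in Actor.sample
batch = {
    'obs': np.random.randn(B, obs_dim).astype('float32'),
    'actions': np.random.randint(0, 20, size=B).astype('int64'),  # 20 = placeholder act_dim
    'advantages': advantages,
    'target_values': (advantages + values).astype('float32'),
}
# The learner then calls:
# agent.learn(batch['obs'], batch['actions'], batch['advantages'], batch['target_values'])
```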
+ +import parl +import paddle.fluid as fluid +from parl import layers + + +class LiftModel(parl.Model): + def __init__(self, act_dim): + self.act_dim = act_dim + self.fc_1 = layers.fc(size=512, act='relu') + self.fc_2 = layers.fc(size=256, act='relu') + self.fc_3 = layers.fc(size=128, act='tanh') + + self.value_fc = layers.fc(size=1) + self.policy_fc = layers.fc(size=act_dim) + + def policy(self, obs): + """ + Args: + obs(float32 tensor): shape of (B * obs_dim) + + Returns: + policy_logits(float32 tensor): shape of (B * act_dim) + """ + h_1 = self.fc_1(obs) + h_2 = self.fc_2(h_1) + h_3 = self.fc_3(h_2) + policy_logits = self.policy_fc(h_3) + return policy_logits + + def value(self, obs): + """ + Args: + obs(float32 tensor): shape of (B * obs_dim) + + Returns: + values(float32 tensor): shape of (B,) + """ + h_1 = self.fc_1(obs) + h_2 = self.fc_2(h_1) + h_3 = self.fc_3(h_2) + values = self.value_fc(h_3) + values = layers.squeeze(values, axes=[1]) + return values + + def policy_and_value(self, obs): + """ + Args: + obs(float32 tensor): shape (B * obs_dim) + + Returns: + policy_logits(float32 tensor): shape of (B * act_dim) + values(float32 tensor): shape of (B,) + """ + h_1 = self.fc_1(obs) + h_2 = self.fc_2(h_1) + h_3 = self.fc_3(h_2) + policy_logits = self.policy_fc(h_3) + values = self.value_fc(h_3) + values = layers.squeeze(values, axes=[1]) + + return policy_logits, values diff --git a/examples/LiftSim_baseline/A2C/performance.png b/examples/LiftSim_baseline/A2C/performance.png new file mode 100644 index 0000000000000000000000000000000000000000..153da4eb12bd3219ed5030516bcc001188c980b2 Binary files /dev/null and b/examples/LiftSim_baseline/A2C/performance.png differ diff --git a/examples/LiftSim_baseline/A2C/train.py b/examples/LiftSim_baseline/A2C/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3663cb841f676efcc157fccd15fa816cdaae662c --- /dev/null +++ b/examples/LiftSim_baseline/A2C/train.py @@ -0,0 +1,220 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import os +import parl +import queue +import six +import time +import threading + +from actor import Actor +from collections import defaultdict +from env_wrapper import ObsProcessWrapper, ActionProcessWrapper +from parl.utils import logger, get_gpu_count, tensorboard, machine_info +from parl.utils.scheduler import PiecewiseScheduler +from parl.utils.time_stat import TimeStat +from parl.utils.window_stat import WindowStat +from rlschool import LiftSim +from lift_model import LiftModel +from lift_agent import LiftAgent + + +class Learner(object): + def __init__(self, config): + self.config = config + + #=========== Create Agent ========== + env = LiftSim() + env = ActionProcessWrapper(env) + env = ObsProcessWrapper(env) + + obs_dim = env.obs_dim + act_dim = env.act_dim + self.config['obs_dim'] = obs_dim + + model = LiftModel(act_dim) + algorithm = parl.algorithms.A3C( + model, vf_loss_coeff=config['vf_loss_coeff']) + self.agent = LiftAgent(algorithm, config) + + if machine_info.is_gpu_available(): + assert get_gpu_count() == 1, 'Only support training in single GPU,\ + Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` .' + + #========== Learner ========== + + self.entropy_stat = WindowStat(100) + self.target_values = None + + self.learn_time_stat = TimeStat(100) + self.start_time = None + + #========== Remote Actor =========== + self.remote_count = 0 + self.sample_data_queue = queue.Queue() + + self.remote_metrics_queue = queue.Queue() + self.sample_total_steps = 0 + + self.params_queues = [] + self.create_actors() + + self.log_steps = 0 + + def create_actors(self): + """ Connect to the cluster and start sampling of the remote actor. + """ + parl.connect(self.config['master_address']) + + logger.info('Waiting for {} remote actors to connect.'.format( + self.config['actor_num'])) + + for i in six.moves.range(self.config['actor_num']): + params_queue = queue.Queue() + self.params_queues.append(params_queue) + + self.remote_count += 1 + logger.info('Remote actor count: {}'.format(self.remote_count)) + + remote_thread = threading.Thread( + target=self.run_remote_sample, args=(params_queue, )) + remote_thread.setDaemon(True) + remote_thread.start() + + self.start_time = time.time() + + def run_remote_sample(self, params_queue): + """ Sample data from remote actor and update parameters of remote actor. + """ + remote_actor = Actor(self.config) + + cnt = 0 + while True: + latest_params = params_queue.get() + remote_actor.set_weights(latest_params) + batch = remote_actor.sample() + + self.sample_data_queue.put(batch) + + cnt += 1 + if cnt % self.config['get_remote_metrics_interval'] == 0: + metrics = remote_actor.get_metrics() + if metrics: + self.remote_metrics_queue.put(metrics) + + def step(self): + """ + 1. kick off all actors to synchronize parameters and sample data; + 2. collect sample data of all actors; + 3. update parameters. 
+ """ + latest_params = self.agent.get_weights() + for params_queue in self.params_queues: + params_queue.put(latest_params) + + train_batch = defaultdict(list) + for i in range(self.config['actor_num']): + sample_data = self.sample_data_queue.get() + for key, value in sample_data.items(): + train_batch[key].append(value) + + self.sample_total_steps += sample_data['obs'].shape[0] + + for key, value in train_batch.items(): + train_batch[key] = np.concatenate(value) + + with self.learn_time_stat: + total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = self.agent.learn( + obs_np=train_batch['obs'], + actions_np=train_batch['actions'], + advantages_np=train_batch['advantages'], + target_values_np=train_batch['target_values']) + + self.entropy_stat.add(entropy) + self.target_values = np.mean(train_batch['target_values']) + + tensorboard.add_scalar('model/entropy', entropy, + self.sample_total_steps) + tensorboard.add_scalar('model/q_value', self.target_values, + self.sample_total_steps) + + def log_metrics(self): + """ Log metrics of learner and actors + """ + if self.start_time is None: + return + + metrics = [] + while True: + try: + metric = self.remote_metrics_queue.get_nowait() + metrics.append(metric) + except queue.Empty: + break + + env_reward_1h, env_reward_24h = [], [] + for x in metrics: + env_reward_1h.extend(x['env_reward_1h']) + env_reward_24h.extend(x['env_reward_24h']) + env_reward_1h = [x for x in env_reward_1h if x is not None] + env_reward_24h = [x for x in env_reward_24h if x is not None] + + mean_reward_1h, mean_reward_24h = None, None + if env_reward_1h: + mean_reward_1h = np.mean(np.array(env_reward_1h).flatten()) + tensorboard.add_scalar('performance/env_rewards_1h', + mean_reward_1h, self.sample_total_steps) + if env_reward_24h: + mean_reward_24h = np.mean(np.array(env_reward_24h).flatten()) + tensorboard.add_scalar('performance/env_rewards_24h', + mean_reward_24h, self.sample_total_steps) + + metric = { + 'Sample steps': self.sample_total_steps, + 'env_reward_1h': mean_reward_1h, + 'env_reward_24h': mean_reward_24h, + 'target_values': self.target_values, + 'entropy': self.entropy_stat.mean, + 'learn_time_s': self.learn_time_stat.mean, + 'elapsed_time_s': int(time.time() - self.start_time), + } + logger.info(metric) + + self.log_steps += 1 + save_interval_step = 7200 // max(1, + self.config['log_metrics_interval_s']) + if self.log_steps % save_interval_step == 0: + self.save_model() # save model every 2h + + def should_stop(self): + return self.sample_total_steps >= self.config['max_sample_steps'] + + def save_model(self): + time_str = time.strftime(".%Y%m%d_%H%M%S", time.localtime()) + self.agent.save(os.path.join('saved_models', 'model.ckpt' + time_str)) + + +if __name__ == '__main__': + from a2c_config import config + + learner = Learner(config) + assert config['log_metrics_interval_s'] > 0 + + while not learner.should_stop(): + start = time.time() + while time.time() - start < config['log_metrics_interval_s']: + learner.step() + learner.log_metrics() diff --git a/examples/LiftSim_baseline/A2C/utils.py b/examples/LiftSim_baseline/A2C/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05c081eaa39e4a2366ebf986092c4a9ec11a2c2f --- /dev/null +++ b/examples/LiftSim_baseline/A2C/utils.py @@ -0,0 +1,58 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def discretize(value, n_dim, min_val, max_val): + ''' + discretize a value into a vector of n_dim dimension 1-hot representation + with the value below min_val being [1, 0, 0, ..., 0] + and the value above max_val being [0, 0, ..., 0, 1] + ''' + assert n_dim > 0 + if (n_dim == 1): + return [1] + delta = (max_val - min_val) / float(n_dim - 1) + active_pos = int((value - min_val) / delta + 0.5) + active_pos = min(n_dim - 1, active_pos) + active_pos = max(0, active_pos) + ret_array = [0 for i in range(n_dim)] + ret_array[active_pos] = 1.0 + return ret_array + + +def linear_discretize(value, n_dim, min_val, max_val): + ''' + discretize a value into a vector of n_dim dimensional representation + with the value below min_val being [1, 0, 0, ..., 0] + and the value above max_val being [0, 0, ..., 0, 1] + e.g. if n_dim = 2, min_val = 1.0, max_val = 2.0 + if value = 1.5 returns [0.5, 0.5], if value = 1.8 returns [0.2, 0.8] + ''' + assert n_dim > 0 + if (n_dim == 1): + return [1] + delta = (max_val - min_val) / float(n_dim - 1) + active_pos = int((value - min_val) / delta + 0.5) + active_pos = min(n_dim - 2, active_pos) + active_pos = max(0, active_pos) + anchor_pt = active_pos * delta + min_val + if (anchor_pt > value and anchor_pt > min_val + 0.5 * delta): + anchor_pt -= delta + active_pos -= 1 + weight = (value - anchor_pt) / delta + weight = min(1.0, max(0.0, weight)) + ret_array = [0 for i in range(n_dim)] + ret_array[active_pos] = 1.0 - weight + ret_array[active_pos + 1] = weight + return ret_array diff --git a/examples/LiftSim_baseline/README.md b/examples/LiftSim_baseline/DQN/README.md similarity index 78% rename from examples/LiftSim_baseline/README.md rename to examples/LiftSim_baseline/DQN/README.md index bfc903402d2665fb00e518ae1df77a1b8c88dae5..d75b549f8f3b1c433a915f025ab676115749ddd3 100644 --- a/examples/LiftSim_baseline/README.md +++ b/examples/LiftSim_baseline/DQN/README.md @@ -6,11 +6,9 @@ ## 依赖库 -- paddlepaddle >= 1.5.1 -- parl >= 1.1.2 -- rlschool >= 0.0.1 - -Windows版本仅支持Python3.5及以上版本。 ++ [paddlepaddle==1.5.1](https://github.com/PaddlePaddle/Paddle) ++ [parl==1.1.2](https://github.com/PaddlePaddle/PARL) ++ [rlschool>=0.0.1](https://github.com/PaddlePaddle/RLSchool) ## 运行 diff --git a/examples/LiftSim_baseline/__init__.py b/examples/LiftSim_baseline/DQN/__init__.py similarity index 100% rename from examples/LiftSim_baseline/__init__.py rename to examples/LiftSim_baseline/DQN/__init__.py diff --git a/examples/LiftSim_baseline/demo.py b/examples/LiftSim_baseline/DQN/demo.py similarity index 100% rename from examples/LiftSim_baseline/demo.py rename to examples/LiftSim_baseline/DQN/demo.py diff --git a/examples/LiftSim_baseline/rl_10.png b/examples/LiftSim_baseline/DQN/rl_10.png similarity index 100% rename from examples/LiftSim_baseline/rl_10.png rename to examples/LiftSim_baseline/DQN/rl_10.png diff --git a/examples/LiftSim_baseline/rl_benchmark/__init__.py b/examples/LiftSim_baseline/DQN/rl_benchmark/__init__.py similarity index 100% rename from examples/LiftSim_baseline/rl_benchmark/__init__.py rename to examples/LiftSim_baseline/DQN/rl_benchmark/__init__.py diff --git
a/examples/LiftSim_baseline/rl_benchmark/agent.py b/examples/LiftSim_baseline/DQN/rl_benchmark/agent.py similarity index 100% rename from examples/LiftSim_baseline/rl_benchmark/agent.py rename to examples/LiftSim_baseline/DQN/rl_benchmark/agent.py diff --git a/examples/LiftSim_baseline/rl_benchmark/dispatcher.py b/examples/LiftSim_baseline/DQN/rl_benchmark/dispatcher.py similarity index 100% rename from examples/LiftSim_baseline/rl_benchmark/dispatcher.py rename to examples/LiftSim_baseline/DQN/rl_benchmark/dispatcher.py diff --git a/examples/LiftSim_baseline/rl_benchmark/model.py b/examples/LiftSim_baseline/DQN/rl_benchmark/model.py similarity index 100% rename from examples/LiftSim_baseline/rl_benchmark/model.py rename to examples/LiftSim_baseline/DQN/rl_benchmark/model.py diff --git a/examples/LiftSim_baseline/wrapper.py b/examples/LiftSim_baseline/DQN/wrapper.py similarity index 100% rename from examples/LiftSim_baseline/wrapper.py rename to examples/LiftSim_baseline/DQN/wrapper.py diff --git a/examples/LiftSim_baseline/wrapper_utils.py b/examples/LiftSim_baseline/DQN/wrapper_utils.py similarity index 100% rename from examples/LiftSim_baseline/wrapper_utils.py rename to examples/LiftSim_baseline/DQN/wrapper_utils.py
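Finally, the two helpers in `utils.py` are easiest to understand from their outputs; the expected values below follow from the code and its docstrings (a small sanity check, assuming `utils.py` from this example is on the import path):

```python
from utils import discretize, linear_discretize

# One-hot bucketing over [1, 6] with 6 buckets: the value 3 activates the third bucket.
print(discretize(3, 6, 1, 6))                # [0, 0, 1.0, 0, 0, 0]

# Linear interpolation between the two nearest anchors, as in the docstring example.
print(linear_discretize(1.5, 2, 1.0, 2.0))   # [0.5, 0.5]
print(linear_discretize(1.8, 2, 1.0, 2.0))   # ~[0.2, 0.8] (up to float rounding)
```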