diff --git a/examples/NeurIPS2018-AI-for-Prosthetics-Challenge/README.md b/examples/NeurIPS2018-AI-for-Prosthetics-Challenge/README.md index f79a58c8a0b5e4d68ff26e72c601eddabc0ea1a6..be7766d8b66f1d159d7e0006ec099aa7b1421087 100644 --- a/examples/NeurIPS2018-AI-for-Prosthetics-Challenge/README.md +++ b/examples/NeurIPS2018-AI-for-Prosthetics-Challenge/README.md @@ -19,7 +19,7 @@ For more technical details about our solution, we provide: ## Dependencies - python3.6 -- [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle) +- [paddlepaddle==1.5.1](https://github.com/PaddlePaddle/Paddle) - [osim-rl](https://github.com/stanfordnmbl/osim-rl) - [grpcio==1.12.1](https://grpc.io/docs/quickstart/python.html) - tqdm diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/README.md b/examples/NeurIPS2019-Learn-to-Move-Challenge/README.md index 8468e942fca77e48890fabd344c022cd621123fd..9bc97422a7235577e48f753f261d98416ffba800 100644 --- a/examples/NeurIPS2019-Learn-to-Move-Challenge/README.md +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/README.md @@ -1,8 +1,13 @@ - +The **PARL** team gets the first place in NeurIPS reinforcement learning competition, again! This folder contains our final submitted model and the code relative to the training process. + +

+<p align="center"><img src="image/performance.gif" alt="PARL"/></p> +

+ ## Dependencies - python3.6 -- [paddlepaddle>=1.5.2](https://github.com/PaddlePaddle/Paddle) -- [parl>=1.2](https://github.com/PaddlePaddle/PARL) +- [paddlepaddle==1.5.2](https://github.com/PaddlePaddle/Paddle) +- [parl>=1.2.1](https://github.com/PaddlePaddle/PARL) - [osim-rl==3.0.11](https://github.com/stanfordnmbl/osim-rl) @@ -11,8 +16,63 @@ - How to Run 1. Enter the sub-folder `final_submit` - 2. Download the model file from online stroage service: [Baidu Pan](https://pan.baidu.com/s/12LIPspckCT8-Q5U1QX69Fg) (password: `b5ck`) or [Google Drive](https://drive.google.com/file/d/1jJtOcOVJ6auz3s-TyWgUJvofPXI94yxy/view?usp=sharing) - 3. Unpack the file: + 2. Download the model file from online stroage service: [Baidu Pan](https://pan.baidu.com/s/12LIPspckCT8-Q5U1QX69Fg) (password: `b5ck`) or [Google Drive](https://drive.google.com/file/d/1jJtOcOVJ6auz3s-TyWgUJvofPXI94yxy/view?usp=sharing) + 3. Unpack the file: `tar zxvf saved_models.tar.gz` - 4. Launch the test script: + 4. Launch the test script: `python test.py` + + +## Part2: Curriculum learning + +#### 1. Run as fast as possible -> run at 3.0 m/s -> walk at 2.0 m/s -> walk slowly at 1.3 m/s +The curriculum learning pipeline to get a walking slowly model is the same pipeline in [our winning solution in NeurIPS 2018: AI for Prosthetics Challenge](https://github.com/PaddlePaddle/PARL/tree/develop/examples/NeurIPS2018-AI-for-Prosthetics-Challenge). You can get a walking slowly model by following the [guide](https://github.com/PaddlePaddle/PARL/tree/develop/examples/NeurIPS2018-AI-for-Prosthetics-Challenge#part2-curriculum-learning). + +We also provide a pre-trained model that walk naturally at ~1.3m/s. You can download the model file (naming `low_speed_model`) from online stroage service: [Baidu Pan](https://pan.baidu.com/s/1Mi_6bD4QxLWLdyLYe2GRFw) (password: `q9vj`) or [Google Drive](https://drive.google.com/file/d/1_cz6Cg3DAT4u2a5mxk2vP9u8nDWOE7rW/view?usp=sharing). + +#### 2. difficulty=1 +> We built our distributed training agent based on PARL cluster. To start a PARL cluster, we can execute the following two xparl commands: +> +> +>```bash +># starts a master node to manage computation resources and adds the local CPUs to the cluster. +>xparl start --port 8010 +>``` +> +>```bash +># if necessary, adds more CPUs (computation resources) in other machine to the cluster. +>xparl connect --address [CLUSTER_IP]:8010 +>``` +> +> For more information of xparl, please visit the [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html). + +In this example, we can start a local cluster with 300 CPUs by running: + +```bash +xparl start --port 8010 --cpu_num 300 +``` + +Then, we can start the distributed training by running: +```bash +sh scripts/train_difficulty1.sh ./low_speed_model +``` + +Optionally, you can start the distributed evaluating by running: +```bash +sh scripts/eval_difficulty1.sh +``` + +#### 3. difficulty=2 +```bash +sh scripts/train_difficulty2.sh [DIFFICULTY=1 MODEL] +``` + +#### 4. difficulty=3, first target +```bash +sh scripts/train_difficulty3_first_target.sh [DIFFICULTY=2 MODEL] +``` + +#### 5. 
difficulty=3 +```bash +sh scripts/train_difficulty3.sh [DIFFICULTY=3 FIRST TARGET MODEL] +``` diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/actor.py b/examples/NeurIPS2019-Learn-to-Move-Challenge/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..211fb4dfcf6fa14c99cff6e94479007c99abfd45 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/actor.py @@ -0,0 +1,56 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import numpy as np +from osim.env import L2M2019Env +from env_wrapper import FrameSkip, ActionScale, OfficialObs, FinalReward, FirstTarget + + +@parl.remote_class +class Actor(object): + def __init__(self, + difficulty, + vel_penalty_coeff, + muscle_penalty_coeff, + penalty_coeff, + only_first_target=False): + + random_seed = np.random.randint(int(1e9)) + + env = L2M2019Env( + difficulty=difficulty, visualize=False, seed=random_seed) + max_timelimit = env.time_limit + + env = FinalReward( + env, + max_timelimit=max_timelimit, + vel_penalty_coeff=vel_penalty_coeff, + muscle_penalty_coeff=muscle_penalty_coeff, + penalty_coeff=penalty_coeff) + + if only_first_target: + assert difficulty == 3, "argument `only_first_target` is available only in `difficulty=3`." + env = FirstTarget(env) + + env = FrameSkip(env) + env = ActionScale(env) + self.env = OfficialObs(env, max_timelimit=max_timelimit) + + def reset(self): + observation = self.env.reset(project=False) + return observation + + def step(self, action): + return self.env.step(action, project=False) diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/env_wrapper.py b/examples/NeurIPS2019-Learn-to-Move-Challenge/env_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..2290fc60f7f91857404338da0859229da681c939 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/env_wrapper.py @@ -0,0 +1,482 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
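+# Environment wrappers used by the actor: reward shaping (RewardShaping /
+# FinalReward), optional early stop after the first target (FirstTarget),
+# frame skipping, action rescaling, and the official observation processing
+# (OfficialObs) with feature normalization.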
+ +import copy +import numpy as np +import gym +import abc +import math +from osim.env import L2M2019Env +from parl.utils import logger + + +class FirstTarget(gym.Wrapper): + def __init__(self, env): + assert (isinstance(env, RewardShaping)), type(env) + gym.Wrapper.__init__(self, env) + + def step(self, action, **kwargs): + obs, r, done, info = self.env.step(action, **kwargs) + # early stop condition + if info['target_changed']: + info['timeout'] = True + done = True + logger.warning( + '[FirstTarget Wrapper] early stop since first target is finished.' + ) + return obs, r, done, info + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + +class ActionScale(gym.Wrapper): + def __init__(self, env): + gym.Wrapper.__init__(self, env) + + def step(self, action, **kwargs): + action = (np.copy(action) + 1.0) * 0.5 + action = np.clip(action, 0.0, 1.0) + return self.env.step(action, **kwargs) + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + +class FrameSkip(gym.Wrapper): + def __init__(self, env, skip_num=4): + gym.Wrapper.__init__(self, env) + self.skip_num = skip_num + self.frame_count = 0 + + def step(self, action, **kwargs): + r = 0.0 + merge_info = {} + for k in range(self.skip_num): + self.frame_count += 1 + obs, reward, done, info = self.env.step(action, **kwargs) + r += reward + + for key in info.keys(): + if 'reward' in key: + merge_info[key] = merge_info.get(key, 0.0) + info[key] + else: + merge_info[key] = info[key] + + if info['target_changed']: + logger.warning( + "[FrameSkip Wrapper] early break since target is changed") + break + + if done: + break + merge_info['frame_count'] = self.frame_count + return obs, r, done, merge_info + + def reset(self, **kwargs): + self.frame_count = 0 + return self.env.reset(**kwargs) + + +class RewardShaping(gym.Wrapper): + """ A wrapper for reward shaping, note this wrapper must be the first wrapper """ + + def __init__(self, env, max_timelimit): + logger.info("[RewardShaping]type:{}, max_timelimit: {}".format( + type(env), max_timelimit)) + + self.max_timelimit = max_timelimit + + self.step_count = 0 + self.pre_state_desc = None + self.last_target_vel = None + gym.Wrapper.__init__(self, env) + + @abc.abstractmethod + def reward_shaping(self, state_desc, reward, done, action, info): + """define your own reward computation function + Args: + state_desc(dict): state description for current model + reward(scalar): generic reward generated by env + done(bool): generic done flag generated by env + info(dict): generic info generated by env + """ + pass + + def step(self, action, **kwargs): + self.step_count += 1 + obs, r, done, info = self.env.step(action, **kwargs) + + info = self.reward_shaping(obs, r, done, action, info) + + target_vel = np.linalg.norm( + [obs['v_tgt_field'][0][5][5], obs['v_tgt_field'][1][5][5]]) + info['target_changed'] = False + if self.last_target_vel is not None: + if np.abs(target_vel - self.last_target_vel) > 0.2: + info['target_changed'] = True + self.last_target_vel = target_vel + + assert 'shaping_reward' in info + timeout = False + if self.step_count >= self.max_timelimit: + timeout = True + + info['timeout'] = timeout + self.pre_state_desc = obs + return obs, r, done, info + + def reset(self, **kwargs): + self.step_count = 0 + self.last_target_vel = None + obs = self.env.reset(**kwargs) + self.pre_state_desc = obs + return obs + + +class FinalReward(RewardShaping): + """ A reward shaping wrapper""" + + def __init__(self, env, max_timelimit, vel_penalty_coeff, + muscle_penalty_coeff, 
penalty_coeff): + RewardShaping.__init__(self, env, max_timelimit=max_timelimit) + + self.vel_penalty_coeff = vel_penalty_coeff + self.muscle_penalty_coeff = muscle_penalty_coeff + self.penalty_coeff = penalty_coeff + + def reward_shaping(self, state_desc, env_reward, done, action, info): + # Reward for not falling down + reward = 10.0 + + yaw = state_desc['joint_pos']['ground_pelvis'][2] + current_v_x, current_v_z = rotate_frame( + state_desc['body_vel']['pelvis'][0], + state_desc['body_vel']['pelvis'][2], yaw) + # leftward (Attention!!!) + current_v_z = -current_v_z + + # current relative target theta + target_v_x, target_v_z = state_desc['v_tgt_field'][0][5][ + 5], state_desc['v_tgt_field'][1][5][5] + + vel_penalty = np.linalg.norm( + [target_v_x - current_v_x, target_v_z - current_v_z]) + + muscle_penalty = 0 + for muscle in sorted(state_desc['muscles'].keys()): + muscle_penalty += np.square( + state_desc['muscles'][muscle]['activation']) + + ret_r = reward - (vel_penalty * self.vel_penalty_coeff + muscle_penalty + * self.muscle_penalty_coeff) * self.penalty_coeff + + info = { + 'shaping_reward': ret_r, + 'env_reward': env_reward, + } + return info + + +class ObsTranformerBase(gym.Wrapper): + def __init__(self, env, max_timelimit, skip_num=4): + gym.Wrapper.__init__(self, env) + self.max_timelimit = max_timelimit + self.skip_num = skip_num + + self.step_fea = self.max_timelimit + + def get_observation(self, state_desc): + obs = self._get_observation(state_desc) + return obs + + @abc.abstractmethod + def _get_observation(self, state_desc): + pass + + def feature_normalize(self, obs, mean, std, duplicate_id): + scaler_len = mean.shape[0] + assert obs.shape[0] >= scaler_len + obs[:scaler_len] = (obs[:scaler_len] - mean) / std + final_obs = [] + for i in range(obs.shape[0]): + if i not in duplicate_id: + final_obs.append(obs[i]) + return np.array(final_obs) + + def step(self, action, **kwargs): + obs, r, done, info = self.env.step(action, **kwargs) + + if info['target_changed']: + # reset step_fea when change target + self.step_fea = self.max_timelimit + + self.step_fea -= self.skip_num + + obs = self.get_observation(obs) + return obs, r, done, info + + def reset(self, **kwargs): + obs = self.env.reset(**kwargs) + self.step_fea = self.max_timelimit + obs = self.get_observation(obs) + return obs + + +class OfficialObs(ObsTranformerBase): + """Basically same feature processing as official. 
+ + Reference: https://github.com/stanfordnmbl/osim-rl/blob/master/osim/env/osim.py + """ + + MASS = 75.16460000000001 # 11.777 + 2*(9.3014 + 3.7075 + 0.1 + 1.25 + 0.2166) + 34.2366 + G = 9.80665 # from gait1dof22muscle + + LENGTH0 = 1 # leg lengt + + Fmax = { + 'r_leg': { + 'HAB': 4460.290481, + 'HAD': 3931.8, + 'HFL': 2697.344262, + 'GLU': 3337.583607, + 'HAM': 4105.465574, + 'RF': 2191.74098360656, + 'VAS': 9593.95082, + 'BFSH': 557.11475409836, + 'GAS': 4690.57377, + 'SOL': 7924.996721, + 'TA': 2116.818162 + }, + 'l_leg': { + 'HAB': 4460.290481, + 'HAD': 3931.8, + 'HFL': 2697.344262, + 'GLU': 3337.583607, + 'HAM': 4105.465574, + 'RF': 2191.74098360656, + 'VAS': 9593.95082, + 'BFSH': 557.11475409836, + 'GAS': 4690.57377, + 'SOL': 7924.996721, + 'TA': 2116.818162 + } + } + lopt = { + 'r_leg': { + 'HAB': 0.0845, + 'HAD': 0.087, + 'HFL': 0.117, + 'GLU': 0.157, + 'HAM': 0.069, + 'RF': 0.076, + 'VAS': 0.099, + 'BFSH': 0.11, + 'GAS': 0.051, + 'SOL': 0.044, + 'TA': 0.068 + }, + 'l_leg': { + 'HAB': 0.0845, + 'HAD': 0.087, + 'HFL': 0.117, + 'GLU': 0.157, + 'HAM': 0.069, + 'RF': 0.076, + 'VAS': 0.099, + 'BFSH': 0.11, + 'GAS': 0.051, + 'SOL': 0.044, + 'TA': 0.068 + } + } + + def __init__(self, env, max_timelimit, skip_num=4): + ObsTranformerBase.__init__(self, env, max_timelimit, skip_num) + data = np.load('./official_obs_scaler.npz') + self.mean, self.std, self.duplicate_id = data['mean'], data[ + 'std'], data['duplicate_id'] + self.duplicate_id = self.duplicate_id.astype(np.int32).tolist() + + def _get_observation_dict(self, state_desc): + obs_dict = {} + + obs_dict['v_tgt_field'] = state_desc['v_tgt_field'] + + # pelvis state (in local frame) + obs_dict['pelvis'] = {} + obs_dict['pelvis']['height'] = state_desc['body_pos']['pelvis'][1] + obs_dict['pelvis']['pitch'] = -state_desc['joint_pos'][ + 'ground_pelvis'][0] # (+) pitching forward + obs_dict['pelvis']['roll'] = state_desc['joint_pos']['ground_pelvis'][ + 1] # (+) rolling around the forward axis (to the right) + yaw = state_desc['joint_pos']['ground_pelvis'][2] + dx_local, dy_local = rotate_frame(state_desc['body_vel']['pelvis'][0], + state_desc['body_vel']['pelvis'][2], + yaw) + dz_local = state_desc['body_vel']['pelvis'][1] + obs_dict['pelvis']['vel'] = [ + dx_local, # (+) forward + -dy_local, # (+) leftward + dz_local, # (+) upward + -state_desc['joint_vel']['ground_pelvis'] + [0], # (+) pitch angular velocity + state_desc['joint_vel']['ground_pelvis'] + [1], # (+) roll angular velocity + state_desc['joint_vel']['ground_pelvis'][2] + ] # (+) yaw angular velocity + + # leg state + for leg, side in zip(['r_leg', 'l_leg'], ['r', 'l']): + obs_dict[leg] = {} + grf = [ + f / (self.MASS * self.G) + for f in state_desc['forces']['foot_{}'.format(side)][0:3] + ] # forces normalized by bodyweight + grm = [ + m / (self.MASS * self.G) + for m in state_desc['forces']['foot_{}'.format(side)][3:6] + ] # forces normalized by bodyweight + grfx_local, grfy_local = rotate_frame(-grf[0], -grf[2], yaw) + if leg == 'r_leg': + obs_dict[leg]['ground_reaction_forces'] = [ + grfx_local, # (+) forward + grfy_local, # (+) lateral (rightward) + -grf[1] + ] # (+) upward + if leg == 'l_leg': + obs_dict[leg]['ground_reaction_forces'] = [ + grfx_local, # (+) forward + -grfy_local, # (+) lateral (leftward) + -grf[1] + ] # (+) upward + + # joint angles + obs_dict[leg]['joint'] = {} + obs_dict[leg]['joint']['hip_abd'] = -state_desc['joint_pos'][ + 'hip_{}'.format(side)][1] # (+) hip abduction + obs_dict[leg]['joint']['hip'] = -state_desc['joint_pos'][ + 
'hip_{}'.format(side)][0] # (+) extension + obs_dict[leg]['joint']['knee'] = state_desc['joint_pos'][ + 'knee_{}'.format(side)][0] # (+) extension + obs_dict[leg]['joint']['ankle'] = -state_desc['joint_pos'][ + 'ankle_{}'.format(side)][0] # (+) extension + # joint angular velocities + obs_dict[leg]['d_joint'] = {} + obs_dict[leg]['d_joint']['hip_abd'] = -state_desc['joint_vel'][ + 'hip_{}'.format(side)][1] # (+) hip abduction + obs_dict[leg]['d_joint']['hip'] = -state_desc['joint_vel'][ + 'hip_{}'.format(side)][0] # (+) extension + obs_dict[leg]['d_joint']['knee'] = state_desc['joint_vel'][ + 'knee_{}'.format(side)][0] # (+) extension + obs_dict[leg]['d_joint']['ankle'] = -state_desc['joint_vel'][ + 'ankle_{}'.format(side)][0] # (+) extension + + # muscles + for MUS, mus in zip([ + 'HAB', 'HAD', 'HFL', 'GLU', 'HAM', 'RF', 'VAS', 'BFSH', + 'GAS', 'SOL', 'TA' + ], [ + 'abd', 'add', 'iliopsoas', 'glut_max', 'hamstrings', + 'rect_fem', 'vasti', 'bifemsh', 'gastroc', 'soleus', + 'tib_ant' + ]): + obs_dict[leg][MUS] = {} + obs_dict[leg][MUS]['f'] = state_desc['muscles']['{}_{}'.format( + mus, side)]['fiber_force'] / self.Fmax[leg][MUS] + obs_dict[leg][MUS]['l'] = state_desc['muscles']['{}_{}'.format( + mus, side)]['fiber_length'] / self.lopt[leg][MUS] + obs_dict[leg][MUS]['v'] = state_desc['muscles']['{}_{}'.format( + mus, side)]['fiber_velocity'] / self.lopt[leg][MUS] + + return obs_dict + + def _get_observation(self, state_desc): + + obs_dict = self._get_observation_dict(state_desc) + res = [] + + # target velocity field (in body frame) + #v_tgt = np.ndarray.flatten(obs_dict['v_tgt_field']) + #res += v_tgt.tolist() + + res.append(obs_dict['pelvis']['height']) + res.append(obs_dict['pelvis']['pitch']) + res.append(obs_dict['pelvis']['roll']) + res.append(obs_dict['pelvis']['vel'][0]) + res.append(obs_dict['pelvis']['vel'][1]) + res.append(obs_dict['pelvis']['vel'][2]) + res.append(obs_dict['pelvis']['vel'][3]) + res.append(obs_dict['pelvis']['vel'][4]) + res.append(obs_dict['pelvis']['vel'][5]) + + for leg in ['r_leg', 'l_leg']: + res += obs_dict[leg]['ground_reaction_forces'] + res.append(obs_dict[leg]['joint']['hip_abd']) + res.append(obs_dict[leg]['joint']['hip']) + res.append(obs_dict[leg]['joint']['knee']) + res.append(obs_dict[leg]['joint']['ankle']) + res.append(obs_dict[leg]['d_joint']['hip_abd']) + res.append(obs_dict[leg]['d_joint']['hip']) + res.append(obs_dict[leg]['d_joint']['knee']) + res.append(obs_dict[leg]['d_joint']['ankle']) + for MUS in [ + 'HAB', 'HAD', 'HFL', 'GLU', 'HAM', 'RF', 'VAS', 'BFSH', + 'GAS', 'SOL', 'TA' + ]: + res.append(obs_dict[leg][MUS]['f']) + res.append(obs_dict[leg][MUS]['l']) + res.append(obs_dict[leg][MUS]['v']) + + res = np.array(res) + + res = self.feature_normalize( + res, mean=self.mean, std=self.std, duplicate_id=self.duplicate_id) + + remaining_time = (self.step_fea - (self.max_timelimit / 2.0)) / ( + self.max_timelimit / 2.0) * -1.0 + res = np.append(res, remaining_time) + + # target driven (Relative coordinate system) + current_v_x = obs_dict['pelvis']['vel'][0] # (+) forward + current_v_z = obs_dict['pelvis']['vel'][1] # (+) leftward + + # future vels (0m, 1m, ..., 5m) + for index in range(5, 11): + target_v_x, target_v_z = state_desc['v_tgt_field'][0][index][ + 5], state_desc['v_tgt_field'][1][index][5] + + diff_vel_x = target_v_x - current_v_x + diff_vel_z = target_v_z - current_v_z + diff_vel = np.sqrt(target_v_x ** 2 + target_v_z ** 2) - \ + np.sqrt(current_v_x ** 2 + current_v_z ** 2) + res = np.append( + res, [diff_vel_x / 5.0, diff_vel_z 
/ 5.0, diff_vel / 5.0]) + + # current relative target theta + target_v_x, target_v_z = state_desc['v_tgt_field'][0][5][ + 5], state_desc['v_tgt_field'][1][5][5] + target_theta = math.atan2(target_v_z, target_v_x) + + diff_theta = target_theta + + res = np.append(res, [diff_theta / np.pi]) + + return res + + +def rotate_frame(x, y, theta): + x_rot = np.cos(theta) * x - np.sin(theta) * y + y_rot = np.sin(theta) * x + np.cos(theta) * y + return x_rot, y_rot diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/evaluate.py b/examples/NeurIPS2019-Learn-to-Move-Challenge/evaluate.py new file mode 100755 index 0000000000000000000000000000000000000000..9e6230efe547c0dae53e15dad2488fce357b8ac1 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/evaluate.py @@ -0,0 +1,298 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import parl +import queue +import six +import threading +import time +import numpy as np +from actor import Actor +from opensim_model import OpenSimModel +from opensim_agent import OpenSimAgent +from parl.utils import logger, ReplayMemory, tensorboard +from parl.utils.window_stat import WindowStat +from parl.remote.client import get_global_client +from shutil import copy2 + +ACT_DIM = 22 +VEL_DIM = 19 +OBS_DIM = 98 + VEL_DIM +GAMMA = 0.96 +TAU = 0.001 +ACTOR_LR = 3e-5 +CRITIC_LR = 3e-5 + + +class TransitionExperience(object): + """ A transition of state, or experience""" + + def __init__(self, obs, action, reward, info, **kwargs): + """ kwargs: whatever other attribute you want to save""" + self.obs = obs + self.action = action + self.reward = reward + self.info = info + for k, v in six.iteritems(kwargs): + setattr(self, k, v) + + +class ActorState(object): + """Maintain incomplete trajectories data of actor.""" + + def __init__(self): + self.memory = [] # list of Experience + self.model_name = None + + def reset(self): + self.memory = [] + + +class Evaluator(object): + def __init__(self, args): + model = OpenSimModel(OBS_DIM, VEL_DIM, ACT_DIM) + algorithm = parl.algorithms.DDPG( + model, + gamma=GAMMA, + tau=TAU, + actor_lr=ACTOR_LR, + critic_lr=CRITIC_LR) + self.agent = OpenSimAgent(algorithm, OBS_DIM, ACT_DIM) + + self.evaluate_result = [] + + self.lock = threading.Lock() + self.model_lock = threading.Lock() + self.model_queue = queue.Queue() + + self.best_shaping_reward = 0 + self.best_env_reward = 0 + + if args.offline_evaluate: + self.offline_evaluate() + else: + t = threading.Thread(target=self.online_evaluate) + t.start() + + with self.lock: + while True: + model_path = self.model_queue.get() + if not args.offline_evaluate: + # online evaluate + while not self.model_queue.empty(): + model_path = self.model_queue.get() + try: + self.agent.restore(model_path) + break + except Exception as e: + logger.warn("Agent restore Exception: {} ".format(e)) + + self.cur_model = model_path + + self.create_actors() + + def create_actors(self): + """Connect to the cluster and start sampling of the 
remote actor. + """ + parl.connect(args.cluster_address, ['official_obs_scaler.npz']) + + for i in range(args.actor_num): + logger.info('Remote actor count: {}'.format(i + 1)) + + remote_thread = threading.Thread(target=self.run_remote_sample) + remote_thread.setDaemon(True) + remote_thread.start() + + # There is a memory-leak problem in osim-rl package. + # So we will dynamically add actors when remote actors killed due to excessive memory usage. + time.sleep(10 * 60) + parl_client = get_global_client() + while True: + if parl_client.actor_num < args.actor_num: + logger.info( + 'Dynamic adding acotr, current actor num:{}'.format( + parl_client.actor_num)) + remote_thread = threading.Thread(target=self.run_remote_sample) + remote_thread.setDaemon(True) + remote_thread.start() + time.sleep(5) + + def offline_evaluate(self): + ckpt_paths = set([]) + for x in os.listdir(args.saved_models_dir): + path = os.path.join(args.saved_models_dir, x) + ckpt_paths.add(path) + ckpt_paths = list(ckpt_paths) + steps = [int(x.split('-')[-1]) for x in ckpt_paths] + sorted_idx = sorted(range(len(steps)), key=lambda k: steps[k]) + ckpt_paths = [ckpt_paths[i] for i in sorted_idx] + ckpt_paths.reverse() + logger.info("All checkpoints: {}".format(ckpt_paths)) + for ckpt_path in ckpt_paths: + self.model_queue.put(ckpt_path) + + def online_evaluate(self): + last_model_step = None + while True: + ckpt_paths = set([]) + for x in os.listdir(args.saved_models_dir): + path = os.path.join(args.saved_models_dir, x) + ckpt_paths.add(path) + if len(ckpt_paths) == 0: + time.sleep(60) + continue + ckpt_paths = list(ckpt_paths) + steps = [int(x.split('-')[-1]) for x in ckpt_paths] + sorted_idx = sorted(range(len(steps)), key=lambda k: steps[k]) + ckpt_paths = [ckpt_paths[i] for i in sorted_idx] + model_step = ckpt_paths[-1].split('-')[-1] + if model_step != last_model_step: + logger.info("Adding new checkpoint: :{}".format( + ckpt_paths[-1])) + self.model_queue.put(ckpt_paths[-1]) + last_model_step = model_step + time.sleep(60) + + def run_remote_sample(self): + remote_actor = Actor( + difficulty=args.difficulty, + vel_penalty_coeff=args.vel_penalty_coeff, + muscle_penalty_coeff=args.muscle_penalty_coeff, + penalty_coeff=args.penalty_coeff, + only_first_target=args.only_first_target) + + actor_state = ActorState() + + while True: + actor_state.model_name = self.cur_model + actor_state.reset() + + obs = remote_actor.reset() + + while True: + if actor_state.model_name != self.cur_model: + break + + actor_state.memory.append( + TransitionExperience( + obs=obs, + action=None, + reward=None, + info=None, + timestamp=time.time())) + + action = self.pred_batch(obs) + + obs, reward, done, info = remote_actor.step(action) + + actor_state.memory[-1].reward = reward + actor_state.memory[-1].info = info + actor_state.memory[-1].action = action + if done: + self._parse_memory(actor_state) + break + + def _parse_memory(self, actor_state): + mem = actor_state.memory + n = len(mem) + episode_shaping_reward = np.sum( + [exp.info['shaping_reward'] for exp in mem]) + episode_env_reward = np.sum([exp.info['env_reward'] for exp in mem]) + + with self.lock: + if actor_state.model_name == self.cur_model: + self.evaluate_result.append({ + 'shaping_reward': + episode_shaping_reward, + 'env_reward': + episode_env_reward, + 'episode_length': + mem[-1].info['frame_count'], + 'falldown': + not mem[-1].info['timeout'], + }) + logger.info('{}, finish_cnt: {}'.format( + self.cur_model, len(self.evaluate_result))) + 
logger.info('{}'.format(self.evaluate_result[-1])) + if len(self.evaluate_result) >= args.evaluate_times: + mean_value = {} + for key in self.evaluate_result[0].keys(): + mean_value[key] = np.mean( + [x[key] for x in self.evaluate_result]) + logger.info('Model: {}, mean_value: {}'.format( + self.cur_model, mean_value)) + + eval_num = len(self.evaluate_result) + falldown_num = len( + [x for x in self.evaluate_result if x['falldown']]) + falldown_rate = falldown_num / eval_num + logger.info('Falldown rate: {}'.format(falldown_rate)) + for key in self.evaluate_result[0].keys(): + mean_value[key] = np.mean([ + x[key] for x in self.evaluate_result + if not x['falldown'] + ]) + logger.info( + 'Model: {}, Exclude falldown, mean_value: {}'.format( + self.cur_model, mean_value)) + if mean_value['shaping_reward'] > self.best_shaping_reward: + self.best_shaping_reward = mean_value['shaping_reward'] + copy2(self.cur_model, './model_zoo') + logger.info( + "[best shaping reward updated:{}] path:{}".format( + self.best_shaping_reward, self.cur_model)) + if mean_value[ + 'env_reward'] > self.best_env_reward and falldown_rate < 0.3: + self.best_env_reward = mean_value['env_reward'] + copy2(self.cur_model, './model_zoo') + logger.info( + "[best env reward updated:{}] path:{}, falldown rate: {}" + .format(self.best_env_reward, self.cur_model, + falldown_num / eval_num)) + + self.evaluate_result = [] + while True: + model_path = self.model_queue.get() + if not args.offline_evaluate: + # online evaluate + while not self.model_queue.empty(): + model_path = self.model_queue.get() + try: + self.agent.restore(model_path) + break + except Exception as e: + logger.warn( + "Agent restore Exception: {} ".format(e)) + self.cur_model = model_path + else: + actor_state.model_name = self.cur_model + actor_state.reset() + + def pred_batch(self, obs): + batch_obs = np.expand_dims(obs, axis=0) + with self.model_lock: + action = self.agent.predict(batch_obs.astype('float32')) + + action = np.squeeze(action, axis=0) + return action + + +if __name__ == '__main__': + from evaluate_args import get_args + args = get_args() + if args.logdir is not None: + logger.set_dir(args.logdir) + + evaluate = Evaluator(args) diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/evaluate_args.py b/examples/NeurIPS2019-Learn-to-Move-Challenge/evaluate_args.py new file mode 100644 index 0000000000000000000000000000000000000000..d501f893dfc1cabebcf5e21786164d8d7a8d4284 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/evaluate_args.py @@ -0,0 +1,80 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
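+# Command-line arguments for evaluate.py: xparl cluster address, number of
+# remote actors, environment difficulty, reward-shaping coefficients, and the
+# directory of saved models to evaluate (online or offline).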
+ +import argparse + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + '--cluster_address', + default='localhost:8081', + type=str, + help='cluster address of xparl.') + parser.add_argument( + '--actor_num', type=int, required=True, help='number of actors.') + parser.add_argument( + '--logdir', + type=str, + default='logdir', + help='directory to save model/tensorboard data') + + parser.add_argument( + '--difficulty', + type=int, + required=True, + help= + 'difficulty of L2M2019Env. difficulty=1 means Round 1 environment but target theta is always 0; difficulty=2 menas Round 1 environment; difficulty=3 means Round 2 environment.' + ) + parser.add_argument( + '--vel_penalty_coeff', + type=float, + default=1.0, + help='coefficient of velocity penalty in reward shaping.') + parser.add_argument( + '--muscle_penalty_coeff', + type=float, + default=1.0, + help='coefficient of muscle penalty in reward shaping.') + parser.add_argument( + '--penalty_coeff', + type=float, + default=1.0, + help='coefficient of all penalty in reward shaping.') + parser.add_argument( + '--only_first_target', + action="store_true", + help= + 'if set, will terminate the environment run after the first target finished.' + ) + + parser.add_argument( + '--saved_models_dir', + type=str, + required=True, + help='directory of saved models.') + parser.add_argument( + '--offline_evaluate', + action="store_true", + help='if set, will evaluate models offline.') + parser.add_argument( + '--evaluate_times', + default=300, + type=int, + help='evaluate episodes per model.') + + args = parser.parse_args() + + return args diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/image/performance.gif b/examples/NeurIPS2019-Learn-to-Move-Challenge/image/performance.gif new file mode 100644 index 0000000000000000000000000000000000000000..1a38bae66974f2f7058ca4cc1401da2b123cf998 Binary files /dev/null and b/examples/NeurIPS2019-Learn-to-Move-Challenge/image/performance.gif differ diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/official_obs_scaler.npz b/examples/NeurIPS2019-Learn-to-Move-Challenge/official_obs_scaler.npz new file mode 100644 index 0000000000000000000000000000000000000000..1099b5f2b036833d9149987b7c10e5a1ec78dbef Binary files /dev/null and b/examples/NeurIPS2019-Learn-to-Move-Challenge/official_obs_scaler.npz differ diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/opensim_agent.py b/examples/NeurIPS2019-Learn-to-Move-Challenge/opensim_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..7f380237a34ef54e56b3b5e044572cce31a64cb1 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/opensim_agent.py @@ -0,0 +1,90 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
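+# DDPG agent for the OpenSim environment. Separate predict/learn programs are
+# built and executed with fluid.ParallelExecutor to speed up batch prediction
+# and training.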
+ +import parl +import numpy as np +from parl import layers +from parl.utils import machine_info +from paddle import fluid + + +class OpenSimAgent(parl.Agent): + def __init__(self, algorithm, obs_dim, act_dim): + self.obs_dim = obs_dim + self.act_dim = act_dim + super(OpenSimAgent, self).__init__(algorithm) + + # Use ParallelExecutor to make program running faster + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.use_experimental_executor = True + exec_strategy.num_threads = 4 + build_strategy = fluid.BuildStrategy() + build_strategy.remove_unnecessary_lock = True + + with fluid.scope_guard(fluid.global_scope().new_scope()): + self.learn_pe = fluid.ParallelExecutor( + use_cuda=machine_info.is_gpu_available(), + main_program=self.learn_program, + exec_strategy=exec_strategy, + build_strategy=build_strategy) + + with fluid.scope_guard(fluid.global_scope().new_scope()): + self.pred_pe = fluid.ParallelExecutor( + use_cuda=machine_info.is_gpu_available(), + main_program=self.pred_program, + exec_strategy=exec_strategy, + build_strategy=build_strategy) + + # Attention: In the beginning, sync target model totally. + self.alg.sync_target( + decay=0, share_vars_parallel_executor=self.learn_pe) + + def build_program(self): + self.pred_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.pred_act = self.alg.predict(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + act = layers.data( + name='act', shape=[self.act_dim], dtype='float32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data( + name='next_obs', shape=[self.obs_dim], dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + _, self.critic_cost = self.alg.learn(obs, act, reward, next_obs, + terminal) + + def predict(self, obs): + feed = {'obs': obs} + act = self.pred_pe.run(feed=[feed], fetch_list=[self.pred_act.name])[0] + return act + + def learn(self, obs, act, reward, next_obs, terminal): + feed = { + 'obs': obs, + 'act': act, + 'reward': reward, + 'next_obs': next_obs, + 'terminal': terminal + } + critic_cost = self.learn_pe.run( + feed=[feed], fetch_list=[self.critic_cost.name])[0] + self.alg.sync_target(share_vars_parallel_executor=self.learn_pe) + return critic_cost diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/opensim_model.py b/examples/NeurIPS2019-Learn-to-Move-Challenge/opensim_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8196c58dac482da01f5282b8350d1551cc422133 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/opensim_model.py @@ -0,0 +1,162 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
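+# Actor and critic networks. The observation is split into body features and
+# the trailing target-velocity features (the last vel_obs_dim dimensions);
+# each part is processed by its own branch before the branches are merged.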
+ +import parl +from parl import layers +from paddle import fluid +from paddle.fluid.param_attr import ParamAttr + + +class OpenSimModel(parl.Model): + def __init__(self, obs_dim, vel_obs_dim, act_dim): + self.actor_model = ActorModel(obs_dim, vel_obs_dim, act_dim) + self.critic_model = CriticModel(obs_dim, vel_obs_dim, act_dim) + + def policy(self, obs): + return self.actor_model.policy(obs) + + def value(self, obs, action): + return self.critic_model.value(obs, action) + + def get_actor_params(self): + return self.actor_model.parameters() + + +class ActorModel(parl.Model): + def __init__(self, obs_dim, vel_obs_dim, act_dim): + hid0_size = 800 + hid1_size = 400 + hid2_size = 200 + vel_hid0_size = 200 + vel_hid1_size = 400 + + self.obs_dim = obs_dim + self.vel_obs_dim = vel_obs_dim + + scope_name = 'policy' + + self.fc0 = layers.fc( + size=hid0_size, + act='tanh', + param_attr=ParamAttr(name='{}/h0/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/h0/b'.format(scope_name))) + self.fc1 = layers.fc( + size=hid1_size, + act='tanh', + param_attr=ParamAttr(name='{}/h1/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/h1/b'.format(scope_name))) + self.vel_fc0 = layers.fc( + size=vel_hid0_size, + act='tanh', + param_attr=ParamAttr(name='{}/vel_h0/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/vel_h0/b'.format(scope_name))) + self.vel_fc1 = layers.fc( + size=vel_hid1_size, + act='tanh', + param_attr=ParamAttr(name='{}/vel_h1/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/vel_h1/b'.format(scope_name))) + self.fc2 = layers.fc( + size=hid2_size, + act='tanh', + param_attr=ParamAttr(name='{}/h2/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/h2/b'.format(scope_name))) + self.fc3 = layers.fc( + size=act_dim, + act='tanh', + param_attr=ParamAttr(name='{}/means/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/means/b'.format(scope_name))) + + def policy(self, obs): + real_obs = layers.slice( + obs, axes=[1], starts=[0], ends=[-self.vel_obs_dim]) + # target related fetures + vel_obs = layers.slice( + obs, axes=[1], starts=[-self.vel_obs_dim], ends=[self.obs_dim]) + + hid0 = self.fc0(real_obs) + hid1 = self.fc1(hid0) + vel_hid0 = self.vel_fc0(vel_obs) + vel_hid1 = self.vel_fc1(vel_hid0) + concat = layers.concat([hid1, vel_hid1], axis=1) + hid2 = self.fc2(concat) + means = self.fc3(hid2) + return means + + +class CriticModel(parl.Model): + def __init__(self, obs_dim, vel_obs_dim, act_dim): + super(CriticModel, self).__init__() + hid0_size = 800 + hid1_size = 400 + vel_hid0_size = 200 + vel_hid1_size = 400 + + self.obs_dim = obs_dim + self.vel_obs_dim = vel_obs_dim + + scope_name = 'critic' + + self.fc0 = layers.fc( + size=hid0_size, + act='selu', + param_attr=ParamAttr(name='{}/w1/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/w1/b'.format(scope_name))) + self.fc1 = layers.fc( + size=hid1_size, + act='selu', + param_attr=ParamAttr(name='{}/h1/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/h1/b'.format(scope_name))) + self.vel_fc0 = layers.fc( + size=vel_hid0_size, + act='selu', + param_attr=ParamAttr(name='{}/vel_h0/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/vel_h0/b'.format(scope_name))) + self.vel_fc1 = layers.fc( + size=vel_hid1_size, + act='selu', + param_attr=ParamAttr(name='{}/vel_h1/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/vel_h1/b'.format(scope_name))) + self.act_fc0 = layers.fc( + size=hid1_size, + act='selu', + param_attr=ParamAttr(name='{}/a1/W'.format(scope_name)), + 
bias_attr=ParamAttr(name='{}/a1/b'.format(scope_name))) + self.fc2 = layers.fc( + size=hid1_size, + act='selu', + param_attr=ParamAttr(name='{}/h3/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/h3/b'.format(scope_name))) + self.fc3 = layers.fc( + size=1, + act='selu', + param_attr=ParamAttr(name='{}/value/W'.format(scope_name)), + bias_attr=ParamAttr(name='{}/value/b'.format(scope_name))) + + def value(self, obs, action): + real_obs = layers.slice( + obs, axes=[1], starts=[0], ends=[-self.vel_obs_dim]) + # target related fetures + vel_obs = layers.slice( + obs, axes=[1], starts=[-self.vel_obs_dim], ends=[self.obs_dim]) + + hid0 = self.fc0(real_obs) + hid1 = self.fc1(hid0) + vel_hid0 = self.vel_fc0(vel_obs) + vel_hid1 = self.vel_fc1(vel_hid0) + a1 = self.act_fc0(action) + concat = layers.concat([hid1, a1, vel_hid1], axis=1) + hid2 = self.fc2(concat) + Q = self.fc3(hid2) + Q = layers.squeeze(Q, axes=[1]) + return Q diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty1.sh b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty1.sh new file mode 100644 index 0000000000000000000000000000000000000000..b9dcc895d3f66e15908d56db9615305ab38680a0 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty1.sh @@ -0,0 +1,5 @@ +python evaluate.py --actor_num 160 \ + --difficulty 1 \ + --penalty_coeff 3.0 \ + --saved_models_dir ./output/difficulty1/model_every_100_episodes \ + --evaluate_times 300 diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty2.sh b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty2.sh new file mode 100644 index 0000000000000000000000000000000000000000..705eacae2ffcabc738c145e2289f3cf8a7b846ba --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty2.sh @@ -0,0 +1,5 @@ +python evaluate.py --actor_num 160 \ + --difficulty 2 \ + --penalty_coeff 5.0 \ + --saved_models_dir ./output/difficulty2/model_every_100_episodes \ + --evaluate_times 300 diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty3.sh b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty3.sh new file mode 100644 index 0000000000000000000000000000000000000000..7bc2d81973235edcb73bc8c3240f3deddfd759c1 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty3.sh @@ -0,0 +1,6 @@ +python evaluate.py --actor_num 160 \ + --difficulty 3 \ + --vel_penalty_coeff 3.0 \ + --penalty_coeff 2.0 \ + --saved_models_dir ./output/difficulty3/model_every_100_episodes \ + --evaluate_times 300 diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty3_first_target.sh b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty3_first_target.sh new file mode 100644 index 0000000000000000000000000000000000000000..840fec0143529fd602ea25334146d141ba245577 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/eval_difficulty3_first_target.sh @@ -0,0 +1,7 @@ +python evaluate.py --actor_num 160 \ + --difficulty 3 \ + --vel_penalty_coeff 3.0 \ + --penalty_coeff 3.0 \ + --only_first_target \ + --saved_models_dir ./output/difficulty3_first_target/model_every_100_episodes \ + --evaluate_times 300 diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty1.sh b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty1.sh new file mode 100644 index 0000000000000000000000000000000000000000..b23d284a7fa16fe0a468fc059fa659686788d5cb 
--- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty1.sh @@ -0,0 +1,11 @@ +echo `which python` +if [ $# != 1 ]; then + echo "Usage: sh train_difficulty1.sh [RESTORE_MODEL_PATH]" + exit 0 +fi + +python train.py --actor_num 300 \ + --difficulty 1 \ + --penalty_coeff 3.0 \ + --logdir ./output/difficulty1 \ + --restore_model_path $1 diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty2.sh b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty2.sh new file mode 100644 index 0000000000000000000000000000000000000000..0e305adc2ef3c2046eff69ae280654fa35279eb5 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty2.sh @@ -0,0 +1,10 @@ +if [ $# != 1 ]; then + echo "Usage: sh train_difficulty2.sh [RESTORE_MODEL_PATH]" + exit 0 +fi + +python train.py --actor_num 300 \ + --difficulty 2 \ + --penalty_coeff 5.0 \ + --logdir ./output/difficulty2 \ + --restore_model_path $1 diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty3.sh b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty3.sh new file mode 100644 index 0000000000000000000000000000000000000000..9800fb1d7a87422305dcc4b12a6ec084e142ad4a --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty3.sh @@ -0,0 +1,13 @@ +if [ $# != 1 ]; then + echo "Usage: sh train_difficulty3.sh [RESTORE_MODEL_PATH]" + exit 0 +fi + +python train.py --actor_num 300 \ + --difficulty 3 \ + --vel_penalty_coeff 3.0 \ + --penalty_coeff 2.0 \ + --rpm_size 6e6 \ + --train_times 250 \ + --logdir ./output/difficulty3 \ + --restore_model_path $1 diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty3_first_target.sh b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty3_first_target.sh new file mode 100644 index 0000000000000000000000000000000000000000..8631ee4de5bbd2d8737b964d0f4f9cd4b0bdebfc --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/scripts/train_difficulty3_first_target.sh @@ -0,0 +1,12 @@ +if [ $# != 1 ]; then + echo "Usage: sh train_difficulty3_first_target.sh [RESTORE_MODEL_PATH]" + exit 0 +fi + +python train.py --actor_num 300 \ + --difficulty 3 \ + --vel_penalty_coeff 3.0 \ + --penalty_coeff 3.0 \ + --only_first_target \ + --logdir ./output/difficulty3_first_target \ + --restore_model_path $1 diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/train.py b/examples/NeurIPS2019-Learn-to-Move-Challenge/train.py new file mode 100755 index 0000000000000000000000000000000000000000..cfd0e657c8350ccbb3357cf723e6db96cf5f0822 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/train.py @@ -0,0 +1,327 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
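+# Distributed DDPG training: a Learner trains from a shared replay memory
+# while local threads drive remote Actors (launched through xparl) that
+# sample episodes in parallel and feed transitions back into the memory.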
+ +import os +import parl +import queue +import six +import threading +import time +import numpy as np +from actor import Actor +from opensim_model import OpenSimModel +from opensim_agent import OpenSimAgent +from parl.utils import logger, ReplayMemory, tensorboard +from parl.utils.window_stat import WindowStat +from parl.remote.client import get_global_client + +ACT_DIM = 22 +VEL_DIM = 19 +OBS_DIM = 98 + VEL_DIM +GAMMA = 0.96 +TAU = 0.001 +ACTOR_LR = 3e-5 +CRITIC_LR = 3e-5 +BATCH_SIZE = 128 +NOISE_DECAY = 0.999998 + + +class TransitionExperience(object): + """ A transition of state, or experience""" + + def __init__(self, obs, action, reward, info, **kwargs): + """ kwargs: whatever other attribute you want to save""" + self.obs = obs + self.action = action + self.reward = reward + self.info = info + for k, v in six.iteritems(kwargs): + setattr(self, k, v) + + +class ActorState(object): + """Maintain incomplete trajectories data of actor.""" + + def __init__(self): + self.memory = [] # list of Experience + self.ident = np.random.randint(int(1e18)) + self.last_target_changed_steps = 0 + + def reset(self): + self.memory = [] + self.last_target_changed_steps = 0 + + def update_last_target_changed(self): + self.last_target_changed_steps = len(self.memory) + + +class Learner(object): + def __init__(self, args): + model = OpenSimModel(OBS_DIM, VEL_DIM, ACT_DIM) + algorithm = parl.algorithms.DDPG( + model, + gamma=GAMMA, + tau=TAU, + actor_lr=ACTOR_LR, + critic_lr=CRITIC_LR) + self.agent = OpenSimAgent(algorithm, OBS_DIM, ACT_DIM) + + self.rpm = ReplayMemory(args.rpm_size, OBS_DIM, ACT_DIM) + + if args.restore_rpm_path is not None: + self.rpm.load(args.restore_rpm_path) + if args.restore_model_path is not None: + self.restore(args.restore_model_path) + + # add lock between training and predicting + self.model_lock = threading.Lock() + + # add lock when appending data to rpm or writing scalars to tensorboard + self.memory_lock = threading.Lock() + + self.ready_actor_queue = queue.Queue() + + self.total_steps = 0 + self.noiselevel = 0.5 + + self.critic_loss_stat = WindowStat(500) + self.env_reward_stat = WindowStat(500) + self.shaping_reward_stat = WindowStat(500) + self.max_env_reward = 0 + + # thread to keep training + learn_thread = threading.Thread(target=self.keep_training) + learn_thread.setDaemon(True) + learn_thread.start() + + self.create_actors() + + def create_actors(self): + """Connect to the cluster and start sampling of the remote actor. + """ + parl.connect(args.cluster_address, ['official_obs_scaler.npz']) + + for i in range(args.actor_num): + logger.info('Remote actor count: {}'.format(i + 1)) + + remote_thread = threading.Thread(target=self.run_remote_sample) + remote_thread.setDaemon(True) + remote_thread.start() + + # There is a memory-leak problem in osim-rl package. + # So we will dynamically add actors when remote actors killed due to excessive memory usage. 
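+        # Wait for the initial actors to start, then keep the actor pool at
+        # args.actor_num by spawning a replacement sampling thread whenever
+        # the cluster reports fewer connected actors.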
+ time.sleep(10 * 60) + parl_client = get_global_client() + while True: + if parl_client.actor_num < args.actor_num: + logger.info( + 'Dynamic adding acotr, current actor num:{}'.format( + parl_client.actor_num)) + remote_thread = threading.Thread(target=self.run_remote_sample) + remote_thread.setDaemon(True) + remote_thread.start() + time.sleep(5) + + def _new_ready_actor(self): + """ + + The actor is ready to start new episode, + but blocking until training thread call actor_ready_event.set() + """ + actor_ready_event = threading.Event() + self.ready_actor_queue.put(actor_ready_event) + logger.info( + "[new_avaliabe_actor] approximate size of ready actors:{}".format( + self.ready_actor_queue.qsize())) + actor_ready_event.wait() + + def run_remote_sample(self): + remote_actor = Actor( + difficulty=args.difficulty, + vel_penalty_coeff=args.vel_penalty_coeff, + muscle_penalty_coeff=args.muscle_penalty_coeff, + penalty_coeff=args.penalty_coeff, + only_first_target=args.only_first_target) + + actor_state = ActorState() + + while True: + obs = remote_actor.reset() + actor_state.reset() + + while True: + actor_state.memory.append( + TransitionExperience( + obs=obs, + action=None, + reward=None, + info=None, + timestamp=time.time())) + + action = self.pred_batch(obs) + + # For each target, decay noise as the steps increase. + step = len( + actor_state.memory) - actor_state.last_target_changed_steps + current_noise = self.noiselevel * (0.98**(step - 1)) + + noise = np.zeros((ACT_DIM, ), dtype=np.float32) + if actor_state.ident % 3 == 0: + if step % 5 == 0: + noise = np.random.randn(ACT_DIM) * current_noise + elif actor_state.ident % 3 == 1: + if step % 5 == 0: + noise = np.random.randn(ACT_DIM) * current_noise * 2 + action += noise + + action = np.clip(action, -1, 1) + + obs, reward, done, info = remote_actor.step(action) + + reward_scale = (1 - GAMMA) + info['shaping_reward'] *= reward_scale + + actor_state.memory[-1].reward = reward + actor_state.memory[-1].info = info + actor_state.memory[-1].action = action + + if 'target_changed' in info and info['target_changed']: + actor_state.update_last_target_changed() + + if done: + self._parse_memory(actor_state, last_obs=obs) + break + + self._new_ready_actor() + + def _parse_memory(self, actor_state, last_obs): + mem = actor_state.memory + n = len(mem) + + episode_shaping_reward = np.sum( + [exp.info['shaping_reward'] for exp in mem]) + episode_env_reward = np.sum([exp.info['env_reward'] for exp in mem]) + episode_time = time.time() - mem[0].timestamp + + episode_rpm = [] + for i in range(n - 1): + episode_rpm.append([ + mem[i].obs, mem[i].action, mem[i].info['shaping_reward'], + mem[i + 1].obs, False + ]) + episode_rpm.append([ + mem[-1].obs, mem[-1].action, mem[-1].info['shaping_reward'], + last_obs, not mem[-1].info['timeout'] + ]) + + with self.memory_lock: + self.total_steps += n + self.add_episode_rpm(episode_rpm) + + if actor_state.ident % 3 == 2: # trajectory without noise + self.env_reward_stat.add(episode_env_reward) + self.shaping_reward_stat.add(episode_shaping_reward) + self.max_env_reward = max(self.max_env_reward, + episode_env_reward) + + if self.env_reward_stat.count > 500: + tensorboard.add_scalar('recent_env_reward', + self.env_reward_stat.mean, + self.total_steps) + tensorboard.add_scalar('recent_shaping_reward', + self.shaping_reward_stat.mean, + self.total_steps) + if self.critic_loss_stat.count > 500: + tensorboard.add_scalar('recent_critic_loss', + self.critic_loss_stat.mean, + self.total_steps) + 
tensorboard.add_scalar('episode_length', n, self.total_steps) + tensorboard.add_scalar('max_env_reward', self.max_env_reward, + self.total_steps) + tensorboard.add_scalar('ready_actor_num', + self.ready_actor_queue.qsize(), + self.total_steps) + tensorboard.add_scalar('episode_time', episode_time, + self.total_steps) + + self.noiselevel = self.noiselevel * NOISE_DECAY + + def learn(self): + start_time = time.time() + + for T in range(args.train_times): + [states, actions, rewards, new_states, + dones] = self.rpm.sample_batch(BATCH_SIZE) + with self.model_lock: + critic_loss = self.agent.learn(states, actions, rewards, + new_states, dones) + self.critic_loss_stat.add(critic_loss) + logger.info( + "[learn] time consuming:{}".format(time.time() - start_time)) + + def keep_training(self): + episode_count = 1000000 + for T in range(episode_count): + if self.rpm.size() > BATCH_SIZE * args.warm_start_batchs: + self.learn() + logger.info( + "[keep_training/{}] trying to acq a new env".format(T)) + + # Keep training and predicting balance + # After training, wait for a ready actor, and make the actor start new episode + ready_actor_event = self.ready_actor_queue.get() + ready_actor_event.set() + + if np.mod(T, 100) == 0: + logger.info("saving models") + self.save(T) + if np.mod(T, 10000) == 0: + logger.info("saving rpm") + self.save_rpm() + + def save_rpm(self): + save_path = os.path.join(logger.get_dir(), "rpm.npz") + self.rpm.save(save_path) + + def save(self, T): + save_path = os.path.join( + logger.get_dir(), 'model_every_100_episodes/episodes-{}'.format(T)) + self.agent.save(save_path) + + def restore(self, model_path): + logger.info('restore model from {}'.format(model_path)) + self.agent.restore(model_path) + + def add_episode_rpm(self, episode_rpm): + for x in episode_rpm: + self.rpm.append( + obs=x[0], act=x[1], reward=x[2], next_obs=x[3], terminal=x[4]) + + def pred_batch(self, obs): + batch_obs = np.expand_dims(obs, axis=0) + + with self.model_lock: + action = self.agent.predict(batch_obs.astype('float32')) + + action = np.squeeze(action, axis=0) + return action + + +if __name__ == '__main__': + from train_args import get_args + args = get_args() + if args.logdir is not None: + logger.set_dir(args.logdir) + + learner = Learner(args) diff --git a/examples/NeurIPS2019-Learn-to-Move-Challenge/train_args.py b/examples/NeurIPS2019-Learn-to-Move-Challenge/train_args.py new file mode 100644 index 0000000000000000000000000000000000000000..501abb4ee3f5ca599dd865e04a9a2d2a1170ddd0 --- /dev/null +++ b/examples/NeurIPS2019-Learn-to-Move-Challenge/train_args.py @@ -0,0 +1,87 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
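+# Command-line arguments for train.py: xparl cluster address, number of
+# remote actors, environment difficulty, reward-shaping coefficients,
+# replay-memory size, and warm-start/restore options.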
+ +import argparse + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + '--cluster_address', + default='localhost:8081', + type=str, + help='cluster address of xparl.') + parser.add_argument( + '--actor_num', type=int, required=True, help='number of actors.') + parser.add_argument( + '--logdir', + type=str, + default='logdir', + help='directory to save model/tensorboard data') + + parser.add_argument( + '--difficulty', + type=int, + required=True, + help= + 'difficulty of L2M2019Env. difficulty=1 means Round 1 environment but target theta is always 0; difficulty=2 menas Round 1 environment; difficulty=3 means Round 2 environment.' + ) + parser.add_argument( + '--vel_penalty_coeff', + type=float, + default=1.0, + help='coefficient of velocity penalty in reward shaping.') + parser.add_argument( + '--muscle_penalty_coeff', + type=float, + default=1.0, + help='coefficient of muscle penalty in reward shaping.') + parser.add_argument( + '--penalty_coeff', + type=float, + default=1.0, + help='coefficient of all penalty in reward shaping.') + parser.add_argument( + '--only_first_target', + action="store_true", + help= + 'if set, will terminate the environment run after the first target finished.' + ) + + parser.add_argument( + '--rpm_size', + type=lambda x: int(float(x)), + default=int(2e6), + help='size of replay memory.') + parser.add_argument( + '--train_times', + type=int, + default=100, + help='training times (batches) when finishing an episode.') + parser.add_argument( + '--restore_model_path', + type=str, + help='restore model path for warm start') + parser.add_argument( + '--restore_rpm_path', type=str, help='restore rpm path for warm start') + parser.add_argument( + '--warm_start_batchs', + type=int, + default=2000, + help='collect how many batch data to warm start') + + args = parser.parse_args() + + return args diff --git a/parl/__init__.py b/parl/__init__.py index 618ebe05dd92b706bc0effad806df6a76738f6d5..7d3c26a00c4671f6aef2810a78e1f92bccaf35ed 100644 --- a/parl/__init__.py +++ b/parl/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2" +__version__ = "1.2.1" """ generates new PARL python API """ diff --git a/parl/algorithms/fluid/ddpg.py b/parl/algorithms/fluid/ddpg.py index b0319d0655e40879cc3b64af6d4eee833deec97f..c127109c7d92f3f5b6e42d4eac25a796ae0c89ae 100644 --- a/parl/algorithms/fluid/ddpg.py +++ b/parl/algorithms/fluid/ddpg.py @@ -115,7 +115,10 @@ class DDPG(Algorithm): optimizer.minimize(cost) return cost - def sync_target(self, gpu_id=None, decay=None): + def sync_target(self, + gpu_id=None, + decay=None, + share_vars_parallel_executor=None): if gpu_id is not None: warnings.warn( "the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DDPG` is deprecated since version 1.2 and will be removed in version 1.3.", @@ -123,4 +126,7 @@ class DDPG(Algorithm): stacklevel=2) if decay is None: decay = 1.0 - self.tau - self.model.sync_weights_to(self.target_model, decay=decay) + self.model.sync_weights_to( + self.target_model, + decay=decay, + share_vars_parallel_executor=share_vars_parallel_executor)