提交 b249dee3 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

add testing module of NeurIPS2018-AI-for-Prosthetics-Challenge (#32)

* add testing module of NeurIPS2018-AI-for-Prosthetics-Challenge, add dependencies of setup

* add copyright

* add google drive link

* fix depedencie

* refine setup
上级 ec005b50
## The Winning Solution for the NeurIPS 2018: AI for Prosthetics Challenge
This folder contains the code used to train the winning models for the [NeurIPS 2018: AI for Prosthetics Challenge](https://www.crowdai.org/challenges/neurips-2018-ai-for-prosthetics-challenge), along with the resulting models. (The training code is still being organized, but the resulting models are available now.)
### Dependencies
- python3.6
- [paddlepaddle>=1.2.0](https://github.com/PaddlePaddle/Paddle)
- [PARL](https://github.com/PaddlePaddle/PARL)
- [osim-rl](https://github.com/stanfordnmbl/osim-rl)
### Start Testing best models
- How to Run
```bash
# cd current directory
# download the best models file (saved_model.tar.gz) into this directory
tar zxvf saved_model.tar.gz
python test.py
```
> You can download the models file from [Baidu Pan](https://pan.baidu.com/s/1NN1auY2eDblGzUiqR8Bfqw) or [Google Drive](https://drive.google.com/open?id=1DQHrwtXzgFbl9dE7jGOe9ZbY0G9-qfq3).
- More arguments
```bash
# Run with GPU
python test.py --use_cuda
# Visualize the game
python test.py --vis
# Set the random seed
python test.py --seed 1024
# Set the episode number to run
python test.py --episode_num 2
```
### Start Training
- [ ] To be Done
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import copy
import gym
import math
import numpy as np
from collections import OrderedDict
from osim.env import ProstheticsEnv
from parl.utils import logger
# Step budget of one episode. It is used as the timeout threshold in
# RewardShaping.step and as the starting value of the countdown kept by
# ObsTranformerBase.
MAXTIME_LIMIT = 1000
# Module-level side effect: overrides the class attribute of the osim
# environment as soon as this module is imported.
ProstheticsEnv.time_limit = MAXTIME_LIMIT
# Written by FrameSkip.__init__ and read by ObsTranformerBase, which calls
# int(FRAME_SKIP) in its constructor; a FrameSkip wrapper therefore has to be
# built before any observation wrapper.
FRAME_SKIP = None
# Added to info['shaping_reward'] when an episode ends before the step budget
# is used up (see RewardShaping.step).
FALL_PENALTY = 0
class RemoteEnv(gym.Wrapper):
    """Expose a remote client through the gym ``step``/``reset`` interface.

    The wrapped object is expected to provide ``env_create``, ``env_reset``
    and ``env_step``.
    """

    def __init__(self, env):
        # Give the client the attributes that are assigned before it is handed
        # to gym.Wrapper; their values are placeholders.
        placeholders = (
            ('metadata', {}),
            ('action_space', None),
            ('observation_space', None),
            ('reward_range', None),
        )
        for attr_name, attr_value in placeholders:
            setattr(env, attr_name, attr_value)
        gym.Wrapper.__init__(self, env)
        self.remote_env = env
        self.first_time = True

    def step(self, act):
        """Forward ``act`` (converted with ``tolist()``) to the remote client."""
        return self.remote_env.env_step(act.tolist())

    def reset(self):
        """Create the remote environment on the first call, reset it afterwards.

        Returns:
            The observation given by the client, or ``None`` when a reset
            yields a falsy observation.
        """
        if self.first_time:
            self.first_time = False
            return self.remote_env.env_create()
        observation = self.remote_env.env_reset()
        return observation if observation else None
def calc_vel_diff(state_desc):
    """Compare the pelvis velocity and heading with the target velocity.

    Args:
        state_desc(dict): state description; this function reads
            ``body_vel['pelvis']``, ``target_vel`` and
            ``body_pos_rot['pelvis']`` (components 0 and 2 are used as x and
            z, component 1 of the rotation as the angle around the y axis).

    Returns:
        tuple: ``(cur_vel_x, cur_vel_z, diff_vel_x, diff_vel_z, diff_vel,
        diff_theta)`` where the ``diff_*`` entries are current minus target.
    """
    cur_vel_x = state_desc['body_vel']['pelvis'][0]
    cur_vel_z = state_desc['body_vel']['pelvis'][2]
    target_vel_x = state_desc['target_vel'][0]
    target_vel_z = state_desc['target_vel'][2]

    diff_vel_x = cur_vel_x - target_vel_x
    diff_vel_z = cur_vel_z - target_vel_z

    # Difference of the speeds (vector magnitudes in the x-z plane).
    cur_vel = (cur_vel_x**2 + cur_vel_z**2)**0.5
    target_vel = (target_vel_x**2 + target_vel_z**2)**0.5
    diff_vel = cur_vel - target_vel

    if target_vel_x == 0:
        # The quotient below is undefined here and used to raise
        # ZeroDivisionError. Use the limit of atan(-z / x) for x -> 0+, and
        # 0 when there is no target direction at all.
        if target_vel_z == 0:
            target_theta = 0.0
        else:
            target_theta = math.copysign(math.pi / 2, -target_vel_z)
    else:
        target_theta = math.atan(-1.0 * target_vel_z / target_vel_x)

    # Rotation of the pelvis around the y axis.
    cur_theta = state_desc['body_pos_rot']['pelvis'][1]
    diff_theta = cur_theta - target_theta
    return cur_vel_x, cur_vel_z, diff_vel_x, diff_vel_z, diff_vel, diff_theta
class ActionScale(gym.Wrapper):
    """Rescale actions with ``(a + 1) / 2`` and clip them to ``[0, 1]``."""

    def __init__(self, env):
        gym.Wrapper.__init__(self, env)

    def step(self, action, **kwargs):
        """Step the wrapped env with the rescaled copy of ``action``.

        The caller's array is not modified; extra keyword arguments are
        passed through unchanged.
        """
        rescaled = np.clip((np.copy(action) + 1.0) * 0.5, 0.0, 1.0)
        return self.env.step(rescaled, **kwargs)

    def reset(self, **kwargs):
        """Reset the wrapped env, passing keyword arguments through."""
        return self.env.reset(**kwargs)
class FrameSkip(gym.Wrapper):
    """Repeat every action for up to ``k`` steps of the wrapped env."""

    def __init__(self, env, k):
        global FRAME_SKIP
        gym.Wrapper.__init__(self, env)
        self.frame_skip = k
        # Publish the skip length at module level; ObsTranformerBase reads it.
        FRAME_SKIP = k
        self.frame_count = 0

    def step(self, action, **kwargs):
        """Apply ``action`` repeatedly and merge the per-step results.

        Rewards are summed. In the merged info dict, keys containing
        ``'reward'`` are accumulated and every other key keeps the value of
        the last executed step. The repetition stops early when the target
        changed or the episode finished.

        Returns:
            tuple: ``(obs, summed_reward, done, merged_info)``; the merged
            info also carries ``'frame_count'``, the number of raw steps
            since the last reset.
        """
        reward_sum = 0.0
        merged = {}
        for _ in range(self.frame_skip):
            self.frame_count += 1
            obs, reward, done, info = self.env.step(action, **kwargs)
            reward_sum += reward

            for name, value in info.items():
                if 'reward' not in name:
                    merged[name] = value
                    continue
                # Guard against silently ignoring a reward key added later:
                # any new reward needs its merge logic decided here.
                assert name in ('shaping_reward', 'r2_reward')
                merged[name] = merged.get(name, 0.0) + value

            if info['target_changed']:
                # The frames that are skipped by this break are not
                # compensated in the accumulated rewards.
                logger.warn("[FrameSkip] early break since target was changed")
                break
            if done:
                break

        merged['frame_count'] = self.frame_count
        return obs, reward_sum, done, merged

    def reset(self, **kwargs):
        """Zero the raw step counter and reset the wrapped env."""
        self.frame_count = 0
        return self.env.reset(**kwargs)
class RewardShaping(gym.Wrapper):
    """Base wrapper for reward shaping; it must be the first wrapper.

    Subclasses implement ``reward_shaping``. ``step`` additionally limits the
    target speed, detects target changes and reports timeouts through the
    info dict.
    """

    def __init__(self, env):
        logger.info("[RewardShaping]type:{}".format(type(env)))
        self.step_count = 0
        self.pre_state_desc = None
        self.last_target_vel = None
        self.last_target_change_step = 0
        self.target_change_times = 0
        gym.Wrapper.__init__(self, env)

    @abc.abstractmethod
    def reward_shaping(self, state_desc, reward, done, action):
        """Compute the info dict holding the shaped reward.

        Args:
            state_desc(dict): state description for current model
            reward(scalar): generic reward generated by env
            done(bool): generic done flag generated by env
            action: action that was passed to ``step``

        Returns:
            dict: ``step`` reads ``'target_vel'`` and ``'shaping_reward'``
            from it.
        """

    def _rescale_target(self, obs, info, limit):
        """Scale the x/z target velocity in ``obs`` so its speed is ``|limit|``.

        ``obs['target_vel']`` is modified in place and ``info['target_vel']``
        is set to ``limit``.
        """
        rate = math.sqrt((limit**2) / (info['target_vel']**2))
        logger.warn('Changing targets, origin targets: {}'.format(
            obs['target_vel']))
        obs['target_vel'][0] = obs['target_vel'][0] * rate
        obs['target_vel'][2] = obs['target_vel'][2] * rate
        logger.warn('Changing targets, new targets: {}'.format(
            obs['target_vel']))
        info['target_vel'] = limit

    def step(self, action, **kwargs):
        """Step the wrapped env and build the shaped info dict.

        The info returned by the wrapped env is replaced by the result of
        ``reward_shaping``, which is then extended with ``'target_changed'``,
        ``'target_change_times'`` and ``'timeout'``.

        Returns:
            tuple: ``(obs, r, done, info)`` with the unshaped reward ``r``.
        """
        self.step_count += 1
        obs, r, done, _ = self.env.step(action, **kwargs)
        info = self.reward_shaping(obs, r, done, action)

        # Keep the signed target speed inside [-0.25, 2.75].
        if info['target_vel'] > 2.75:
            self._rescale_target(obs, info, 2.75)
        if info['target_vel'] < -0.25:
            self._rescale_target(obs, info, -0.25)

        # A target counts as changed when any component moved by 1e-5 or more
        # since the previous step; the first step never counts.
        previous_target = self.last_target_vel
        if previous_target is None:
            target_changed = False
        else:
            delta = np.absolute(
                np.array(previous_target) - np.array(obs['target_vel']))
            target_changed = not np.all(delta < 1e-5)
        info['target_changed'] = target_changed
        if target_changed:
            logger.info("[env_wrapper] target_changed, vx:{} vz:{}".format(
                obs['target_vel'][0], obs['target_vel'][2]))
            self.last_target_change_step = self.step_count
            self.target_change_times += 1
        info['target_change_times'] = self.target_change_times
        self.last_target_vel = obs['target_vel']

        assert 'shaping_reward' in info
        timeout = self.step_count >= MAXTIME_LIMIT
        if done and not timeout:
            # penalty for falling down
            info['shaping_reward'] += FALL_PENALTY
        info['timeout'] = timeout

        self.pre_state_desc = obs
        return obs, r, done, info

    def reset(self, **kwargs):
        """Clear the per-episode bookkeeping and reset the wrapped env."""
        self.step_count = 0
        self.last_target_vel = None
        self.last_target_change_step = 0
        self.target_change_times = 0
        first_obs = self.env.reset(**kwargs)
        self.pre_state_desc = first_obs
        return first_obs
class ForwardReward(RewardShaping):
    """Reward shaping wrapper that reuses the env reward as shaped reward."""

    def __init__(self, env):
        RewardShaping.__init__(self, env)

    def reward_shaping(self, state_desc, r2_reward, done, action):
        """Return the info dict for one step.

        ``'target_vel'`` is the target speed in the x-z plane, negated when
        the x component of the target velocity is negative. Both reward
        entries hold ``r2_reward`` unchanged.
        """
        target_x = state_desc["target_vel"][0]
        target_z = state_desc["target_vel"][2]
        speed = math.sqrt(target_x**2 + target_z**2)
        signed_speed = -speed if target_x < 0 else speed
        return {
            'shaping_reward': r2_reward,
            'target_vel': signed_speed,
            'r2_reward': r2_reward,
        }
class ObsTranformerBase(gym.Wrapper):
    """Base wrapper turning a state description into a feature vector.

    Subclasses implement ``_get_observation``. The wrapper keeps a deep copy
    of the last state description in ``raw_obs`` and a step countdown in
    ``step_fea``.
    """

    def __init__(self, env):
        global FRAME_SKIP
        gym.Wrapper.__init__(self, env)
        self.step_fea = MAXTIME_LIMIT
        self.raw_obs = None
        # FRAME_SKIP is assigned by FrameSkip.__init__.
        self.frame_skip = int(FRAME_SKIP)

    def get_observation(self, state_desc):
        """Return the feature vector for ``state_desc``.

        For every subclass except ``PelvisBasedObs`` the six values of
        ``calc_vel_diff`` are appended to the subclass features.
        """
        features = self._get_observation(state_desc)
        if isinstance(self, PelvisBasedObs):
            return features
        return np.append(features, list(calc_vel_diff(state_desc)))

    @abc.abstractmethod
    def _get_observation(self, state_desc):
        """Build the subclass-specific features from ``state_desc``."""

    def feature_normalize(self, obs, mean, std, duplicate_id):
        """Standardize the leading entries of ``obs`` and drop duplicates.

        The first ``len(mean)`` entries are normalized in place with
        ``(x - mean) / std``; entries whose index is listed in
        ``duplicate_id`` are left out of the returned array.
        """
        scaler_len = mean.shape[0]
        assert obs.shape[0] >= scaler_len
        obs[:scaler_len] = (obs[:scaler_len] - mean) / std
        kept = [
            obs[idx] for idx in range(obs.shape[0]) if idx not in duplicate_id
        ]
        return np.array(kept)

    def _transform(self, state_desc):
        """Remember ``state_desc`` in ``raw_obs`` and return its features."""
        self.raw_obs = copy.deepcopy(state_desc)
        features = self.get_observation(state_desc)
        self.raw_obs['step_count'] = MAXTIME_LIMIT - self.step_fea
        return features

    def step(self, action, **kwargs):
        """Step the wrapped env and return the transformed observation."""
        state_desc, r, done, info = self.env.step(action, **kwargs)
        if info['target_changed']:
            # reset step_fea when change target
            self.step_fea = MAXTIME_LIMIT
        self.step_fea -= FRAME_SKIP
        return self._transform(state_desc), r, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env; ``None`` observations are passed through."""
        state_desc = self.env.reset(**kwargs)
        if state_desc is None:
            return None
        self.step_fea = MAXTIME_LIMIT
        return self._transform(state_desc)
class PelvisBasedObs(ObsTranformerBase):
    """Observation wrapper producing features expressed relative to the pelvis.

    See ``_get_observation`` for the layout of the feature vector.
    """

    def __init__(self, env):
        ObsTranformerBase.__init__(self, env)
        # The scaler file is looked up relative to the current working
        # directory and must provide 'mean', 'std' and 'duplicate_id'.
        data = np.load('./pelvisBasedObs_scaler.npz')
        self.mean, self.std, self.duplicate_id = data['mean'], data[
            'std'], data['duplicate_id']
        # Plain list of ints; feature_normalize() drops these indices.
        self.duplicate_id = self.duplicate_id.astype(np.int32).tolist()

    def get_core_matrix(self, yaw):
        """Return the 3x3 matrix [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]].

        It rotates x/z coordinates by ``yaw`` and leaves the y coordinate
        untouched.
        """
        core_matrix = np.zeros(shape=(3, 3))
        core_matrix[0][0] = math.cos(yaw)
        core_matrix[0][2] = -1.0 * math.sin(yaw)
        core_matrix[1][1] = 1
        core_matrix[2][0] = math.sin(yaw)
        core_matrix[2][2] = math.cos(yaw)
        return core_matrix

    def _get_observation(self, state_desc):
        """Build the feature vector for one state description.

        The named features are collected in an ordered dict (so insertion
        order defines the layout), converted to pelvis-relative values,
        normalized with ``feature_normalize`` and followed by: the x-z
        distance between the two tibias, a remaining-time feature and four
        target-driven values (three velocity differences and one heading
        difference).

        Args:
            state_desc(dict): state description; the keys read here are
                'body_pos', 'body_vel', 'body_pos_rot', 'body_vel_rot',
                'joint_pos', 'joint_vel', 'muscles', 'misc' and
                'target_vel'.

        Returns:
            np.ndarray: the flat feature vector.
        """
        o = OrderedDict()
        for body_part in [
                'pelvis', 'femur_r', 'pros_tibia_r', 'pros_foot_r', 'femur_l',
                'tibia_l', 'talus_l', 'calcn_l', 'toes_l', 'torso', 'head'
        ]:
            # position
            o[body_part + '_x'] = state_desc['body_pos'][body_part][0]
            o[body_part + '_y'] = state_desc['body_pos'][body_part][1]
            o[body_part + '_z'] = state_desc['body_pos'][body_part][2]
            # velocity
            o[body_part + '_v_x'] = state_desc["body_vel"][body_part][0]
            o[body_part + '_v_y'] = state_desc["body_vel"][body_part][1]
            o[body_part + '_v_z'] = state_desc["body_vel"][body_part][2]

            # rotation
            o[body_part + '_x_r'] = state_desc["body_pos_rot"][body_part][0]
            o[body_part + '_y_r'] = state_desc["body_pos_rot"][body_part][1]
            o[body_part + '_z_r'] = state_desc["body_pos_rot"][body_part][2]

            # rotational velocity
            o[body_part + '_v_x_r'] = state_desc["body_vel_rot"][body_part][0]
            o[body_part + '_v_y_r'] = state_desc["body_vel_rot"][body_part][1]
            o[body_part + '_v_z_r'] = state_desc["body_vel_rot"][body_part][2]

        for joint in [
                'hip_r', 'knee_r', 'ankle_r', 'hip_l', 'knee_l', 'ankle_l',
                'back'
        ]:
            # Hip joints contribute three components, every other joint only
            # its first one.
            if 'hip' not in joint:
                o[joint + '_joint_pos'] = state_desc['joint_pos'][joint][0]
                o[joint + '_joint_vel'] = state_desc['joint_vel'][joint][0]
            else:
                for i in range(3):
                    o[joint + '_joint_pos_' +
                      str(i)] = state_desc['joint_pos'][joint][i]
                    o[joint + '_joint_vel_' +
                      str(i)] = state_desc['joint_vel'][joint][i]

        # In NIPS2017, only use activation
        # Muscles are visited in sorted name order so the layout is stable;
        # a bare float is treated as a one-element list.
        for muscle in sorted(state_desc["muscles"].keys()):
            activation = state_desc["muscles"][muscle]["activation"]
            if isinstance(activation, float):
                activation = [activation]
            for i, val in enumerate(activation):
                o[muscle + '_activation_' + str(i)] = activation[i]

            fiber_length = state_desc["muscles"][muscle]["fiber_length"]
            if isinstance(fiber_length, float):
                fiber_length = [fiber_length]
            for i, val in enumerate(fiber_length):
                o[muscle + '_fiber_length_' + str(i)] = fiber_length[i]

            fiber_velocity = state_desc["muscles"][muscle]["fiber_velocity"]
            if isinstance(fiber_velocity, float):
                fiber_velocity = [fiber_velocity]
            for i, val in enumerate(fiber_velocity):
                o[muscle + '_fiber_velocity_' + str(i)] = fiber_velocity[i]

        # z axis of mass have some problem now, delete it later
        o['mass_x'] = state_desc["misc"]["mass_center_pos"][0]
        o['mass_y'] = state_desc["misc"]["mass_center_pos"][1]
        o['mass_z'] = state_desc["misc"]["mass_center_pos"][2]
        o['mass_v_x'] = state_desc["misc"]["mass_center_vel"][0]
        o['mass_v_y'] = state_desc["misc"]["mass_center_vel"][1]
        o['mass_v_z'] = state_desc["misc"]["mass_center_vel"][2]

        # Two clipped linear functions of the (still absolute) height of the
        # left talus and toes, each limited to [0, 1].
        for key in ['talus_l_y', 'toes_l_y']:
            o['touch_indicator_' + key] = np.clip(0.05 - o[key] * 10 + 0.5, 0.,
                                                  1.)
            o['touch_indicator_2_' + key] = np.clip(0.1 - o[key] * 10 + 0.5,
                                                    0., 1.)

        # Transform positions and velocities into pelvis-relative values,
        # using the matrix built from the pelvis rotation around the y axis.
        core_matrix = self.get_core_matrix(o['pelvis_y_r'])
        pelvis_pos = np.array([o['pelvis_x'], o['pelvis_y'],
                               o['pelvis_z']]).reshape((3, 1))
        pelvis_vel = np.array(
            [o['pelvis_v_x'], o['pelvis_v_y'], o['pelvis_v_z']]).reshape((3,
                                                                          1))
        for body_part in [
                'mass', 'femur_r', 'pros_tibia_r', 'pros_foot_r', 'femur_l',
                'tibia_l', 'talus_l', 'calcn_l', 'toes_l', 'torso', 'head'
        ]:
            # rotation ('mass' has no rotation entries)
            if body_part != 'mass':
                o[body_part + '_y_r'] -= o['pelvis_y_r']
                o[body_part + '_v_y_r'] -= o['pelvis_v_y_r']
            # position/velocity
            global_pos = []
            global_vel = []
            for each in ['_x', '_y', '_z']:
                global_pos.append(o[body_part + each])
                global_vel.append(o[body_part + '_v' + each])
            global_pos = np.array(global_pos).reshape((3, 1))
            global_vel = np.array(global_vel).reshape((3, 1))
            pelvis_rel_pos = core_matrix.dot(global_pos - pelvis_pos)
            w = o['pelvis_v_y_r']
            # NOTE(review): offset is built from one-element slices and added
            # to a (3, 1) product; confirm that the resulting broadcast shape
            # is the intended one before changing anything here, since the
            # saved models were trained on these exact features.
            offset = np.array(
                [-w * pelvis_rel_pos[2], 0, w * pelvis_rel_pos[0]])
            pelvis_rel_vel = core_matrix.dot(global_vel - pelvis_vel) + offset
            for i, each in enumerate(['_x', '_y', '_z']):
                o[body_part + each] = pelvis_rel_pos[i][0]
                o[body_part + '_v' + each] = pelvis_rel_vel[i][0]

        # These absolute pelvis entries are removed from the feature dict.
        for key in ['pelvis_x', 'pelvis_z', 'pelvis_y_r']:
            del o[key]

        # Replace the pelvis x/z velocity with its rotated counterpart.
        current_v = np.array(state_desc['body_vel']['pelvis']).reshape((3, 1))
        pelvis_current_v = core_matrix.dot(current_v)
        o['pelvis_v_x'] = pelvis_current_v[0]
        o['pelvis_v_z'] = pelvis_current_v[2]

        res = np.array(list(o.values()))

        res = self.feature_normalize(
            res, mean=self.mean, std=self.std, duplicate_id=self.duplicate_id)

        # Everything appended from here on bypasses feature_normalize.
        feet_dis = ((o['tibia_l_x'] - o['pros_tibia_r_x'])**2 +
                    (o['tibia_l_z'] - o['pros_tibia_r_z'])**2)**0.5
        res = np.append(res, feet_dis)
        # Maps step_fea == MAXTIME_LIMIT to -1 and step_fea == 0 to +1.
        remaining_time = (self.step_fea -
                          (MAXTIME_LIMIT / 2.0)) / (MAXTIME_LIMIT / 2.0) * -1.0
        #logger.info('remaining_time fea: {}'.format(remaining_time))
        res = np.append(res, remaining_time)

        # target driven
        target_v = np.array(state_desc['target_vel']).reshape((3, 1))
        pelvis_target_v = core_matrix.dot(target_v)
        diff_vel_x = pelvis_target_v[0] - pelvis_current_v[0]
        diff_vel_z = pelvis_target_v[2] - pelvis_current_v[2]
        diff_vel = np.sqrt(pelvis_target_v[0] ** 2 + pelvis_target_v[2] ** 2) - \
                   np.sqrt(pelvis_current_v[0] ** 2 + pelvis_current_v[2] ** 2)

        # The heading is computed from the untransformed target velocity.
        target_vel_x = target_v[0]
        target_vel_z = target_v[2]
        target_theta = math.atan(-1.0 * target_vel_z / target_vel_x)
        current_theta = state_desc['body_pos_rot']['pelvis'][1]
        diff_theta = target_theta - current_theta
        # Fixed divisors (3.0 and 3 * pi / 8) scale the four trailing values.
        res = np.append(res, [
            diff_vel_x[0] / 3.0, diff_vel_z[0] / 3.0, diff_vel[0] / 3.0,
            diff_theta / (np.pi * 3 / 8)
        ])

        return res
if __name__ == '__main__':
    # Smoke test: build the full wrapper stack and reset it repeatedly.
    from osim.env import ProstheticsEnv

    base_env = ProstheticsEnv(visualize=False)
    base_env.change_model(model='3D', difficulty=1, prosthetic=True)
    # Wrappers are constructed innermost first: reward shaping, frame skip,
    # action scaling, observation transform.
    env = PelvisBasedObs(ActionScale(FrameSkip(ForwardReward(base_env), 4)))
    for _ in range(64):
        observation = env.reset(project=False)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl.layers as layers
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from parl.framework.model_base import Model
class ActorModel(Model):
    """Two-branch tanh policy network.

    One branch processes the leading observation entries, the other the
    trailing ``vel_obs_dim`` entries; their outputs are concatenated and
    mapped to ``act_dim`` tanh outputs.
    """

    def __init__(self,
                 obs_dim,
                 vel_obs_dim,
                 act_dim,
                 stage_name=None,
                 model_id=0,
                 shared=False):
        """Create the layers with explicitly named parameters.

        Args:
            obs_dim: total length of the observation vector.
            vel_obs_dim: number of trailing entries fed to the second branch.
            act_dim: size of the output layer.
            stage_name: optional prefix prepended to every parameter scope.
            model_id: index used in the per-model parameter scope.
            shared: when true, the bottom layers use the 'policy_shared'
                scope instead of the per-model scope.
        """
        super().__init__()
        main_sizes = (800, 400, 200)
        vel_sizes = (200, 400)

        self.obs_dim = obs_dim
        self.vel_obs_dim = vel_obs_dim

        def with_stage(scope):
            # Prefix the scope with the stage name when one was given.
            if stage_name is None:
                return scope
            return '{}_{}'.format(stage_name, scope)

        # bottom layers
        scope_name = with_stage('policy_shared' if shared else
                                'policy_identity_{}'.format(model_id))

        self.fc0 = layers.fc(
            size=main_sizes[0],
            act='tanh',
            param_attr=ParamAttr(name='{}/h0/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/h0/b'.format(scope_name)))
        self.fc1 = layers.fc(
            size=main_sizes[1],
            act='tanh',
            param_attr=ParamAttr(name='{}/h1/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/h1/b'.format(scope_name)))
        self.vel_fc0 = layers.fc(
            size=vel_sizes[0],
            act='tanh',
            param_attr=ParamAttr(name='{}/vel_h0/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/vel_h0/b'.format(scope_name)))
        self.vel_fc1 = layers.fc(
            size=vel_sizes[1],
            act='tanh',
            param_attr=ParamAttr(name='{}/vel_h1/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/vel_h1/b'.format(scope_name)))

        # top layers always use the per-model scope
        scope_name = with_stage('policy_identity_{}'.format(model_id))

        self.fc2 = layers.fc(
            size=main_sizes[2],
            act='tanh',
            param_attr=ParamAttr(name='{}/h2/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/h2/b'.format(scope_name)))
        self.fc3 = layers.fc(
            size=act_dim,
            act='tanh',
            param_attr=ParamAttr(name='{}/means/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/means/b'.format(scope_name)))

    def predict(self, obs):
        """Return the output of the final tanh layer for ``obs``."""
        split_at = self.obs_dim - self.vel_obs_dim
        body_part = layers.slice(obs, axes=[1], starts=[0], ends=[split_at])
        vel_part = layers.slice(
            obs, axes=[1], starts=[-self.vel_obs_dim], ends=[self.obs_dim])

        body_feat = self.fc1(self.fc0(body_part))
        vel_feat = self.vel_fc1(self.vel_fc0(vel_part))

        merged = layers.concat([body_feat, vel_feat], axis=1)
        return self.fc3(self.fc2(merged))
class CriticModel(Model):
    """Three-branch selu value network scoring (observation, action) pairs.

    The leading observation entries, the trailing ``vel_obs_dim`` entries and
    the action are processed separately, concatenated and reduced to one
    value per row.
    """

    def __init__(self,
                 obs_dim,
                 vel_obs_dim,
                 act_dim,
                 stage_name=None,
                 model_id=0,
                 shared=False):
        """Create the layers with explicitly named parameters.

        Args:
            obs_dim: total length of the observation vector.
            vel_obs_dim: number of trailing entries fed to the second branch.
            act_dim: accepted for symmetry with ``ActorModel``; not used when
                building the layers.
            stage_name: optional prefix prepended to every parameter scope.
            model_id: index used in the per-model parameter scope.
            shared: when true, the bottom layers use the 'critic_shared'
                scope instead of the per-model scope.
        """
        super().__init__()
        main_sizes = (800, 400)
        vel_sizes = (200, 400)

        self.obs_dim = obs_dim
        self.vel_obs_dim = vel_obs_dim

        def with_stage(scope):
            # Prefix the scope with the stage name when one was given.
            if stage_name is None:
                return scope
            return '{}_{}'.format(stage_name, scope)

        # bottom layers
        scope_name = with_stage('critic_shared' if shared else
                                'critic_identity_{}'.format(model_id))

        self.fc0 = layers.fc(
            size=main_sizes[0],
            act='selu',
            param_attr=ParamAttr(name='{}/w1/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/w1/b'.format(scope_name)))
        self.fc1 = layers.fc(
            size=main_sizes[1],
            act='selu',
            param_attr=ParamAttr(name='{}/h1/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/h1/b'.format(scope_name)))
        self.vel_fc0 = layers.fc(
            size=vel_sizes[0],
            act='selu',
            param_attr=ParamAttr(name='{}/vel_h0/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/vel_h0/b'.format(scope_name)))
        self.vel_fc1 = layers.fc(
            size=vel_sizes[1],
            act='selu',
            param_attr=ParamAttr(name='{}/vel_h1/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/vel_h1/b'.format(scope_name)))
        self.act_fc0 = layers.fc(
            size=main_sizes[1],
            act='selu',
            param_attr=ParamAttr(name='{}/a1/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/a1/b'.format(scope_name)))

        # top layers always use the per-model scope
        scope_name = with_stage('critic_identity_{}'.format(model_id))

        self.fc2 = layers.fc(
            size=main_sizes[1],
            act='selu',
            param_attr=ParamAttr(name='{}/h3/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/h3/b'.format(scope_name)))
        self.fc3 = layers.fc(
            size=1,
            act='selu',
            param_attr=ParamAttr(name='{}/value/W'.format(scope_name)),
            bias_attr=ParamAttr(name='{}/value/b'.format(scope_name)))

    def predict(self, obs, action):
        """Return one value per row of ``obs``/``action`` (axis 1 squeezed)."""
        split_at = self.obs_dim - self.vel_obs_dim
        body_part = layers.slice(obs, axes=[1], starts=[0], ends=[split_at])
        vel_part = layers.slice(
            obs, axes=[1], starts=[-self.vel_obs_dim], ends=[self.obs_dim])

        body_feat = self.fc1(self.fc0(body_part))
        vel_feat = self.vel_fc1(self.vel_fc0(vel_part))
        act_feat = self.act_fc0(action)

        merged = layers.concat([body_feat, act_feat, vel_feat], axis=1)
        value = self.fc3(self.fc2(merged))
        return layers.squeeze(value, axes=[1])
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl.layers as layers
from mlp_model import ActorModel, CriticModel
from paddle import fluid
from parl.utils import logger
# Number of trailing observation entries that ActorModel/CriticModel slice off
# and feed to their separate velocity branch.
VEL_OBS_DIM = 4
# Total observation length; also the shape of the 'obs' input of the
# prediction program.
OBS_DIM = 185 + VEL_OBS_DIM
# Size of the actor output layer.
ACT_DIM = 19
class EnsembleBaseModel(object):
    """Ensemble of actor/critic pairs that votes on a single action.

    Every actor proposes an action, every critic scores all proposals, and
    the proposal with the highest mean normalized score is returned.
    """

    def __init__(self,
                 model_dirname=None,
                 stage_name=None,
                 ensemble_num=12,
                 use_cuda=False):
        """Build the models and the prediction program.

        Args:
            model_dirname: directory to load parameters from; nothing is
                loaded when it is ``None``.
            stage_name: prefix of the parameter scopes of all models.
            ensemble_num: number of actor/critic pairs.
            use_cuda: run on ``CUDAPlace(0)`` instead of the CPU.
        """
        self.stage_name = stage_name
        self.ensemble_num = ensemble_num
        self.actors = []
        self.critics = []
        # Actor and critic of the same index are created back to back.
        for model_id in range(ensemble_num):
            actor = ActorModel(
                OBS_DIM,
                VEL_OBS_DIM,
                ACT_DIM,
                stage_name=stage_name,
                model_id=model_id)
            self.actors.append(actor)
            critic = CriticModel(
                OBS_DIM,
                VEL_OBS_DIM,
                ACT_DIM,
                stage_name=stage_name,
                model_id=model_id)
            self.critics.append(critic)

        self._define_program()

        if use_cuda:
            self.place = fluid.CUDAPlace(0)
        else:
            self.place = fluid.CPUPlace()
        self.fluid_executor = fluid.Executor(self.place)
        self.fluid_executor.run(self.startup_program)

        if model_dirname is not None:
            self._load_params(model_dirname)

    def _load_params(self, dirname):
        """Load the parameters of the prediction program from ``dirname``."""
        logger.info('[{}]: Loading model from {}'.format(
            self.stage_name, dirname))
        fluid.io.load_params(
            executor=self.fluid_executor,
            dirname=dirname,
            main_program=self.ensemble_predict_program)

    def _define_program(self):
        """Create the startup program and the ensemble prediction program."""
        self.ensemble_predict_program = fluid.Program()
        self.startup_program = fluid.Program()

        with fluid.program_guard(self.ensemble_predict_program,
                                 self.startup_program):
            obs_input = layers.data(
                name='obs', shape=[OBS_DIM], dtype='float32')
            self.ensemble_predict_output = [self._ensemble_predict(obs_input)]

    def _ensemble_predict(self, obs):
        """Return the proposal with the highest mean normalized critic score."""
        member_ids = range(self.ensemble_num)

        # One proposal per actor, stacked along axis 0.
        proposals = [self.actors[i].predict(obs) for i in member_ids]
        batch_actions = layers.concat(proposals, axis=0)
        batch_obs = layers.expand(obs, expand_times=[self.ensemble_num, 1])

        # Column j of the score matrix holds the scores given by critic j.
        score_columns = [
            layers.unsqueeze(
                self.critics[i].predict(batch_obs, batch_actions), axes=[1])
            for i in member_ids
        ]
        score_matrix = layers.concat(score_columns, axis=1)

        # Normalize scores given by each critic
        column_sums = layers.reduce_sum(score_matrix, dim=0, keep_dim=True)
        column_sums = layers.expand(
            column_sums, expand_times=[self.ensemble_num, 1])
        norm_score_matrix = score_matrix / column_sums

        mean_scores = layers.reduce_mean(
            norm_score_matrix, dim=1, keep_dim=True)
        best_id = layers.cast(
            layers.argmax(mean_scores, axis=0), dtype='int32')
        best_action = layers.gather(batch_actions, best_id)
        return layers.squeeze(best_action, axes=[0])

    def pred_batch(self, obs):
        """Run the prediction program on ``obs`` and return the action."""
        fetched = self.fluid_executor.run(
            self.ensemble_predict_program,
            feed={'obs': obs},
            fetch_list=self.ensemble_predict_output)
        return fetched[0]
class StartModel(EnsembleBaseModel):
    """Ensemble using the 'stage0' parameters from 'saved_model'."""

    def __init__(self, use_cuda):
        super().__init__(
            model_dirname='saved_model',
            stage_name='stage0',
            use_cuda=use_cuda)
class Stage123Model(EnsembleBaseModel):
    """Ensemble using the 'stage123' parameters from 'saved_model'."""

    def __init__(self, use_cuda):
        super().__init__(
            model_dirname='saved_model',
            stage_name='stage123',
            use_cuda=use_cuda)
class SubmitModel(object):
    """Dispatch predictions to the start ensemble or the stage123 ensemble."""

    def __init__(self, use_cuda=False):
        self.start_model = StartModel(use_cuda=use_cuda)
        self.stage123_model = Stage123Model(use_cuda=use_cuda)

    def pred_batch(self, obs, stage_idx):
        """Predict the action for a single observation.

        Args:
            obs: one observation; a leading batch axis is added and the data
                is cast to float32 before prediction.
            stage_idx: 0 selects the start model, any other value selects the
                stage123 model.
        """
        if stage_idx == 0:
            model = self.start_model
        else:
            model = self.stage123_model
        single_batch = np.expand_dims(obs, axis=0).astype('float32')
        return model.pred_batch(single_batch)
if __name__ == '__main__':
    # Smoke test: constructing the model builds both ensembles on the CPU and
    # loads their parameters from the 'saved_model' directory.
    submit_model = SubmitModel()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import numpy as np
import time
from env_wrapper import FrameSkip, ActionScale, PelvisBasedObs, ForwardReward
from osim.env import ProstheticsEnv
from parl.utils import logger
from submit_model import SubmitModel
def play_multi_episode(submit_model, episode_num=2, vis=False, seed=0):
    """Run the submitted model for several episodes and log the rewards.

    Args:
        submit_model: object with ``pred_batch(observation, stage_idx)``.
        episode_num: number of episodes to play.
        vis: passed to ``ProstheticsEnv(visualize=...)``.
        seed: seed for numpy and for the environment model.

    The function returns nothing; per-episode statistics and the mean reward
    are written to the logger.
    """
    np.random.seed(seed)
    env = ProstheticsEnv(visualize=vis)
    env.change_model(model='3D', difficulty=1, prosthetic=True, seed=seed)
    env = ForwardReward(env)
    env = FrameSkip(env, 4)
    env = ActionScale(env)
    env = PelvisBasedObs(env)
    all_reward = []
    all_shaping_reward = 0
    for e in range(episode_num):
        t = time.time()
        episode_reward = 0.0
        episode_shaping_reward = 0.0
        observation = env.reset(project=False)
        # info['frame_count'] restarts from zero on every reset, so the
        # reference count has to restart with it; otherwise the first
        # step_frames of every later episode is computed against the final
        # count of the previous episode.
        last_frames_count = 0
        target_change_times = 0
        step = 0
        loss = []
        while True:
            step += 1
            action = submit_model.pred_batch(observation, target_change_times)
            observation, reward, done, info = env.step(action, project=False)
            step_frames = info['frame_count'] - last_frames_count
            last_frames_count = info['frame_count']
            episode_reward += reward
            # Evaluated before the counter below is updated, so the step on
            # which the target changes for the first time is not recorded.
            if target_change_times >= 1:
                loss.append(10 * step_frames - reward)
            if info['target_changed']:
                # The stage index passed to the model is capped at 3.
                target_change_times = min(target_change_times + 1, 3)
            logger.info("[step/{}]reward:{} info:{}".format(
                step, reward, info))
            episode_shaping_reward += info['shaping_reward']
            if done:
                break
        all_reward.append(episode_reward)
        all_shaping_reward += episode_shaping_reward
        t = time.time() - t
        logger.info(
            "[episode/{}] time: {} episode_reward:{} change_loss:{} after_change_loss:{} mean_reward:{}"
            .format(e, t, episode_reward, np.sum(loss[:15]), np.sum(loss[15:]),
                    np.mean(all_reward)))
    logger.info("Mean reward:{}".format(np.mean(all_reward)))
if __name__ == '__main__':
    # Command line entry point: parse the options, build the model and play.
    parser = argparse.ArgumentParser()
    option_specs = (
        ('--use_cuda', dict(
            action="store_true", help='If set, will run in gpu 0')),
        ('--vis', dict(action="store_true", help='If set, will visualize.')),
        ('--seed', dict(type=int, default=0, help='Random seed.')),
        ('--episode_num', dict(
            type=int, default=1, help='Episode number to run.')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    submit_model = SubmitModel(use_cuda=args.use_cuda)
    play_multi_episode(
        submit_model,
        episode_num=args.episode_num,
        vis=args.vis,
        seed=args.seed)
......@@ -32,4 +32,8 @@ setup(
name='parl',
version=0.1,
packages=_find_packages(),
package_data={'': ['*.so']})
package_data={'': ['*.so']},
install_requires=[
"termcolor>=1.1.0",
],
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册