Commit 7e51de4f authored by niuyazhe

fix(nyz): simplify onppo with traj_flag

Parent 0b46dd24
......@@ -69,9 +69,8 @@ def serial_pipeline_onpolicy(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
commander = BaseSerialCommander(
cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
cfg.policy.other.commander, learner, collector, evaluator, None, policy.command_mode
)
# ==========
......@@ -80,15 +79,6 @@ def serial_pipeline_onpolicy(
# Learner's before_run hook.
learner.call_hook('before_run')
# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
......@@ -100,23 +90,7 @@ def serial_pipeline_onpolicy(
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
# Learn policy from collected data
for i in range(cfg.policy.learn.update_per_collect): # update_per_collect=1, for onppo
# Learner will train ``update_per_collect`` times in one iteration.
train_data = new_data
if train_data is None:
# It is possible that the replay buffer holds too little data to train ``update_per_collect`` times
logging.warning(
"Replay buffer's data can only train for {} steps. ".format(i) +
"You can modify data collect config, e.g. increasing n_sample, n_episode."
)
break
learner.train(train_data, collector.envstep)
if learner.policy.get_attribute('priority'):
replay_buffer.update(learner.priority_info)
if cfg.policy.on_policy:
# On-policy algorithm must clear the replay buffer.
replay_buffer.clear()
learner.train(new_data, collector.envstep)
# Learner's after_run hook.
learner.call_hook('after_run')
......
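Since this pipeline is strictly on-policy, removing the replay buffer and the random-collection warm-up leaves a simpler loop: each iteration trains once on the freshly collected batch and then discards it. A minimal, self-contained sketch of that loop shape (toy data and a scalar stand-in for the policy parameters, not DI-engine code):

import random

# Toy sketch of the simplified on-policy loop: collect a fresh batch with the
# current policy, train directly on it, then let it go out of scope.
def on_policy_train(n_iterations=3, n_sample=8):
    theta = 0.0  # stand-in for policy parameters
    for _ in range(n_iterations):
        # 1. Collect a fresh batch with the current policy (toy rollout).
        batch = [{'obs': random.random(), 'reward': random.random()} for _ in range(n_sample)]
        # 2. Train on the collected batch; no replay buffer push, update, or clear.
        theta += 0.01 * sum(t['reward'] for t in batch) / n_sample
    return theta

if __name__ == '__main__':
    print(on_policy_train())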
......@@ -17,91 +17,6 @@ from .common_utils import default_preprocess_learn
from ding.utils import dicts_to_lists, lists_to_dicts
def compute_adv(data, last_value, cfg):
# last_value could be the real last value of the last timestep in the whole traj,
# or the next_value sequence for each timestep.
data = get_gae(data, last_value, gamma=cfg.collect.discount_factor, gae_lambda=cfg.collect.gae_lambda, cuda=False)
# data: list of T transition dicts (batch size 1): [{'value': ..., 'reward': ..., 'adv': ...}, ...]
return get_nstep_return_data(data,
cfg.nstep) if cfg.nstep_return else get_train_sample(data, cfg.collect.unroll_len)
def dict_data_split_traj_and_compute_adv(data, next_value, cfg):
# get_gae expects trajectory data from a single episode, not a mix of episodes,
# so we split the data into trajectories according to the key 'done' (and 'traj_flag' if present),
# bounded by max_traj_length = cfg.collect.n_sample // cfg.collect.collector_env_num.
# data shape: dict of torch.FloatTensor transitions, e.g.
# {'obs': [torch.FloatTensor], ..., 'reward': [torch.FloatTensor], ...}
# A traj is a run of consecutive transitions within one episode: the whole episode, a truncated
# episode, or a consecutive slice of one episode, due to the max_traj_len restriction.
processed_data = []
start_index = 0
timesteps = 0
for i in range(data['reward'].shape[0]):
timesteps += 1
traj_data = []
if 'traj_flag' in data.keys():
# for compatibility with MuJoCo: when done is ignored, split the data according to traj_flag
traj_flag = data['traj_flag'][i]
else:
traj_flag = data['done'][i]
if traj_flag: # data['done'][i]: torch.tensor(1.) or True
for k in range(start_index, i + 1):
# transform to shape like this:
# traj_data.append( {'value':data['value'][k] ,'reward':data['reward'][k] ,'adv':data['adv'][k] } )
# if discrete action: traj_data.append({key: data[key][k] for key in data.keys()})
# if continuous action: data['logit'] list(torch.tensor(3200,6)); data['weight'] list
traj_data.append(
{
key: [data[key][logit_index][k] for logit_index in range(len(data[key]))]
if isinstance(data[key], list) and key == 'logit' else data[key][k]
for key in data.keys()
}
)
if data['done'][i]: # if done
next_value[i] = torch.zeros(1)[0].to(data['obs'][0].device)
processed_data.extend(traj_data)
start_index = i + 1
timesteps = 0
continue
if timesteps == cfg.collect.n_sample // cfg.collect.collector_env_num: # equals self._traj_len, e.g. 64
for k in range(start_index, i + 1):
traj_data.append(
{
key: [data[key][logit_index][k] for logit_index in range(len(data[key]))]
if isinstance(data[key], list) and key == 'logit' else data[key][k]
for key in data.keys()
}
)
# traj_data = compute_adv(traj_data, next_value[i], cfg)
if data['done'][i]: # if done
next_value[i] = torch.zeros(1)[0].to(data['obs'][0].device)
processed_data.extend(traj_data)
start_index = i + 1
timesteps = 0
continue
remaining_traj_data = []
for k in range(start_index, i + 1):
remaining_traj_data.append(
{
key: [data[key][logit_index][k] for logit_index in range(len(data[key]))]
if isinstance(data[key], list) and key == 'logit' else data[key][k]
for key in data.keys()
}
)
if data['done'][i]: # if done
next_value[i] = torch.zeros(1)[0].to(data['obs'][0].device)
# add the remaining data; the return value is a list of dicts
data = processed_data + remaining_traj_data
return compute_adv(data, next_value, cfg)
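The removed helper above split the flat batch into per-trajectory chunks only so that GAE never crossed episode boundaries. The new learn-mode code instead passes done and traj_flag directly into gae_data / gae and lets the advantage computation cut at the boundaries itself. A toy sketch of that idea under assumed semantics (done zeroes the bootstrap term, traj_flag stops the recursion); this is not ding's actual gae implementation:

import torch

# Toy GAE with traj_flag (assumed semantics, not ding.rl_utils.gae):
# done marks a true terminal state and zeroes the bootstrap term, while
# traj_flag marks any trajectory boundary (termination or truncation) and
# stops the advantage recursion from leaking across it.
def gae_with_traj_flag(value, next_value, reward, done, traj_flag, gamma=0.99, lambda_=0.95):
    adv = torch.zeros_like(reward)
    running_adv = 0.
    for t in reversed(range(reward.shape[0])):
        not_done = 1. - done[t].float()
        not_boundary = 1. - traj_flag[t].float()
        delta = reward[t] + gamma * next_value[t] * not_done - value[t]
        running_adv = delta + gamma * lambda_ * not_boundary * running_adv
        adv[t] = running_adv
    return adv

Under these assumptions, a time-limit truncation (traj_flag set, done cleared) still bootstraps from next_value, which is exactly the behaviour ignore_done needs.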
@POLICY_REGISTRY.register('ppo')
class PPOPolicy(Policy):
r"""
......@@ -232,11 +147,6 @@ class PPOPolicy(Policy):
Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
adv_abs_max, approx_kl, clipfrac
"""
# for transition in data:
# # for compatibility in mujoco, when ignore done, we should split the data according to the traj_flag
# if 'traj_flag' not in transition.keys():
# transition['traj_flag'] = copy.deepcopy(transition['done'])
data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
if self._cuda:
data = to_device(data, self._device)
......@@ -257,39 +167,26 @@ class PPOPolicy(Policy):
for epoch in range(self._cfg.learn.epoch_per_collect):
if self._recompute_adv: # new v network compute new value
with torch.no_grad():
# obs = torch.cat([data['obs'], data['next_obs'][-1:]])
value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
if self._value_norm:
value *= self._running_mean_std.std
next_value *= self._running_mean_std.std
data['value'] = value
data['weight'] = [None for i in range(data['reward'].shape[0])]
processed_data = dict_data_split_traj_and_compute_adv(
data, next_value.to(self._device), self._cfg
)
compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], data['traj_flag'])
data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
processed_data = lists_to_dicts(processed_data)
for k, v in processed_data.items():
if isinstance(v[0], torch.Tensor):
processed_data[k] = torch.stack(v, dim=0)
processed_data['weight'] = None
unnormalized_returns = processed_data['value'] + processed_data['adv']
unnormalized_returns = value + data['adv']
if self._value_norm:
processed_data['value'] = processed_data['value'] / self._running_mean_std.std
processed_data['return'] = unnormalized_returns / self._running_mean_std.std
data['value'] = data['value'] / self._running_mean_std.std
data['return'] = unnormalized_returns / self._running_mean_std.std
self._running_mean_std.update(unnormalized_returns.cpu().numpy())
else:
processed_data['value'] = processed_data['value']
processed_data['return'] = unnormalized_returns
else:
processed_data = data
data['value'] = data['value']
data['return'] = unnormalized_returns
for batch in split_data_generator(processed_data, self._cfg.learn.batch_size, shuffle=True):
for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
adv = batch['adv']
if self._adv_norm:
......@@ -429,11 +326,9 @@ class PPOPolicy(Policy):
"""
data = to_device(data, self._device)
for transition in data:
# for compatibility with MuJoCo: when done is ignored, split the data according to traj_flag
if 'traj_flag' not in transition.keys():
transition['traj_flag'] = copy.deepcopy(transition['done'])
transition['traj_flag'] = copy.deepcopy(transition['done'])
data[-1]['traj_flag'] = True
# adder is defined in _init_collect
if self._cfg.learn.ignore_done:
data[-1]['done'] = False
......
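With this simplification, the collect side marks trajectory boundaries exactly once per collected segment: traj_flag mirrors done for every transition, and the last transition of the segment is always flagged even when its episode continues. A toy illustration with made-up transitions (not repo code):

import copy

# Toy illustration of how traj_flag is attached to a collected segment.
transitions = [
    {'obs': 0, 'done': False},
    {'obs': 1, 'done': True},    # real episode end inside the segment
    {'obs': 2, 'done': False},
    {'obs': 3, 'done': False},   # segment is cut here because n_sample is reached
]
for t in transitions:
    t['traj_flag'] = copy.deepcopy(t['done'])
transitions[-1]['traj_flag'] = True  # truncation is also a trajectory boundary for GAE
print([t['traj_flag'] for t in transitions])  # [False, True, False, True]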
......@@ -34,10 +34,7 @@ cartpole_ppo_config = dict(
),
eval=dict(
evaluator=dict(
eval_freq=1000,
cfg_type='InteractionSerialEvaluatorDict',
stop_value=195,
n_episode=5,
eval_freq=100,
),
),
),
......
......@@ -11,7 +11,7 @@ pendulum_ppo_config = dict(
policy=dict(
cuda=False,
continuous=True,
recompute_adv=False,
recompute_adv=True,
model=dict(
obs_shape=3,
action_shape=1,
......@@ -31,7 +31,7 @@ pendulum_ppo_config = dict(
clip_ratio=0.2,
adv_norm=False,
value_norm=True,
ignore_done=False,
ignore_done=True,
),
collect=dict(
n_sample=200,
......
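Pendulum terminates only through a fixed time limit, so once traj_flag carries the trajectory boundary the config can set ignore_done=True (keep bootstrapping through the artificial cut-off) and recompute_adv=True (re-estimate advantages with the refreshed critic each epoch). A toy check with made-up numbers, not taken from the repo:

# With ignore_done=True the time-limit step is treated as non-terminal, so the
# TD error still bootstraps from next_value; traj_flag (not shown here) keeps
# the GAE recursion from crossing into the next episode.
gamma = 0.9
reward, value, next_value = -1.0, -10.0, -9.0
done = 0.0  # time-limit cut-off, ignored as a termination
delta = reward + gamma * next_value * (1.0 - done) - value
print(delta)  # -1.0 - 8.1 + 10.0 = 0.9, instead of 9.0 if done were kept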