Unverified commit a0435286, authored by Robin Chen, committed by GitHub

polish(nyz): update multi-discrete policies (#167)

Parent 2699aa5e
@@ -31,7 +31,10 @@ class PolicyFactory:
         def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:

             def discrete_random_action(min_val, max_val, shape):
-                return np.random.randint(min_val, max_val, shape)
+                action = np.random.randint(min_val, max_val, shape)
+                if len(action) > 1:
+                    action = list(np.expand_dims(action, axis=1))
+                return action

             def continuous_random_action(min_val, max_val, shape):
                 bounded_below = min_val != float("inf")
......
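The change above makes PolicyFactory's random policy compatible with multi-discrete action spaces: when the requested shape contains more than one discrete branch, the sampled array is split into a list with one entry per branch. A minimal standalone sketch of that behaviour (numpy only, outside DI-engine):

# Standalone sketch of the updated random-action helper (numpy only).
import numpy as np

def discrete_random_action(min_val, max_val, shape):
    action = np.random.randint(min_val, max_val, shape)
    if len(action) > 1:
        # multi-discrete case: split (N,) into a list of N arrays of shape (1,)
        action = list(np.expand_dims(action, axis=1))
    return action

print(discrete_random_action(0, 4, (1, )))  # single branch -> array of shape (1,)
print(discrete_random_action(0, 4, (3, )))  # 3 branches    -> list of 3 arrays of shape (1,)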
@@ -56,9 +56,10 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
         value_gamma = data.get('value_gamma')
         if isinstance(q_value, list):
-            tl_num = len(q_value)
+            act_num = len(q_value)
             loss, td_error_per_sample = [], []
-            for i in range(tl_num):
+            q_value_list = []
+            for i in range(act_num):
                 td_data = q_nstep_td_data(
                     q_value[i], target_q_value[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
                     data['weight']
@@ -68,8 +69,10 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
                 )
                 loss.append(loss_)
                 td_error_per_sample.append(td_error_per_sample_.abs())
+                q_value_list.append(q_value[i].mean().item())
             loss = sum(loss) / (len(loss) + 1e-8)
             td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)
+            q_value_mean = sum(q_value_list) / act_num
         else:
             data_n = q_nstep_td_data(
                 q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
@@ -77,6 +80,7 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
             loss, td_error_per_sample = q_nstep_td_error(
                 data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
             )
+            q_value_mean = q_value.mean().item()
         # ====================
         # Q-learning update
         # ====================
@@ -94,5 +98,6 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
         return {
             'cur_lr': self._optimizer.defaults['lr'],
             'total_loss': loss.item(),
+            'q_value_mean': q_value_mean,
             'priority': td_error_per_sample.abs().tolist(),
         }
......
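In the multi-discrete branch above, one n-step TD loss is computed per action branch and the results are averaged; the new q_value_mean entry additionally logs the mean Q-value across branches. The sketch below reproduces only this aggregation pattern with dummy tensors and a hypothetical per-branch loss function, not DI-engine's q_nstep_td_error.

# Aggregation pattern only; `branch_td_loss` is a hypothetical stand-in for
# DI-engine's per-branch n-step TD loss.
import torch

def aggregate_branches(q_value, branch_td_loss):
    # q_value: list of (B, A_i) tensors, one per action branch
    act_num = len(q_value)
    loss, td_error_per_sample, q_value_list = [], [], []
    for i in range(act_num):
        loss_i, td_error_i = branch_td_loss(q_value[i])
        loss.append(loss_i)
        td_error_per_sample.append(td_error_i.abs())
        q_value_list.append(q_value[i].mean().item())
    loss = sum(loss) / (len(loss) + 1e-8)
    td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)
    q_value_mean = sum(q_value_list) / act_num  # logged as 'q_value_mean'
    return loss, td_error_per_sample, q_value_mean

# Toy usage: batch of 4, two branches with 3 and 5 actions.
dummy = lambda q: (q.pow(2).mean(), q.mean(dim=-1))
print(aggregate_branches([torch.randn(4, 3), torch.randn(4, 5)], dummy))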
@@ -34,26 +34,19 @@ class MultiDiscretePPOPolicy(PPOPolicy):
         # ====================
         return_infos = []
         self._learn_model.train()
-        if self._value_norm:
-            unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
-            data['return'] = unnormalized_return / self._running_mean_std.std
-            self._running_mean_std.update(unnormalized_return.cpu().numpy())
-        else:
-            data['return'] = data['adv'] + data['value']
         for epoch in range(self._cfg.learn.epoch_per_collect):
             if self._recompute_adv:
                 with torch.no_grad():
                     # obs = torch.cat([data['obs'], data['next_obs'][-1:]])
                     value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
                     next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
                     if self._value_norm:
                         value *= self._running_mean_std.std
                         next_value *= self._running_mean_std.std
-                    gae_data_ = gae_data(value, next_value, data['reward'], data['done'])
+                    compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], data['traj_flag'])
                     # GAE need (T, B) shape input and return (T, B) output
-                    data['adv'] = gae(gae_data_, self._gamma, self._gae_lambda)
+                    data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
                     # value = value[:-1]
                     unnormalized_returns = value + data['adv']
@@ -65,6 +58,14 @@ class MultiDiscretePPOPolicy(PPOPolicy):
                     data['value'] = value
                     data['return'] = unnormalized_returns
             else:  # don't recompute adv
+                if self._value_norm:
+                    unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
+                    data['return'] = unnormalized_return / self._running_mean_std.std
+                    self._running_mean_std.update(unnormalized_return.cpu().numpy())
+                else:
+                    data['return'] = data['adv'] + data['value']
             for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                 output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                 adv = batch['adv']
......
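The PPO change passes data['traj_flag'] into gae_data, so advantage estimation can cut the GAE recursion at trajectory boundaries, and it moves return normalization into the branch that does not recompute advantages. Below is a simplified, standalone GAE sketch with such a trajectory flag; it is written for illustration only and is not DI-engine's gae implementation.

# Simplified GAE with a trajectory-termination flag (illustration only).
import torch

def gae_sketch(value, next_value, reward, done, traj_flag, gamma=0.99, lambda_=0.95):
    # All inputs are (T, B); returns advantages of shape (T, B).
    delta = reward + gamma * next_value * (1 - done) - value
    factor = gamma * lambda_ * (1 - traj_flag)  # traj_flag cuts the recursion
    adv = torch.zeros_like(value)
    gae_item = torch.zeros_like(value[0])
    for t in reversed(range(reward.shape[0])):
        gae_item = delta[t] + factor[t] * gae_item
        adv[t] = gae_item
    return adv

T, B = 8, 2
value, next_value, reward = (torch.randn(T, B) for _ in range(3))
done = torch.zeros(T, B)
traj_flag = done.clone()
traj_flag[3] = 1.  # e.g. a timeout: stop accumulating GAE across this step
print(gae_sketch(value, next_value, reward, done, traj_flag).shape)  # torch.Size([8, 2])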
@@ -67,10 +67,10 @@ class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
                 value_gamma=value_gamma
             )
         else:
-            tl_num = len(q_dist)
+            act_num = len(q_dist)
             losses = []
             td_error_per_samples = []
-            for i in range(tl_num):
+            for i in range(act_num):
                 td_data = dist_nstep_td_data(
                     q_dist[i], target_q_dist[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
                     data['weight']
@@ -87,7 +87,7 @@ class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
                 losses.append(td_loss)
                 td_error_per_samples.append(td_error_per_sample)
             loss = sum(losses) / (len(losses) + 1e-8)
-            td_error_per_sample_mean = sum(td_error_per_samples)
+            td_error_per_sample_mean = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)
         # ====================
         # Rainbow update
         # ====================
......
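The Rainbow fix above divides the summed per-branch TD errors by the number of branches, so replay priorities keep the same scale no matter how many action branches the environment has. A toy comparison with made-up numbers:

# Toy numbers showing why the priority is averaged instead of summed.
import torch

td_error_per_samples = [torch.tensor([0.5, 1.0, 0.2]), torch.tensor([0.4, 0.9, 0.3])]
summed = sum(td_error_per_samples)                                         # grows with branch count
averaged = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)  # scale stays comparable
print(summed)    # tensor([0.9000, 1.9000, 0.5000])
print(averaged)  # tensor([0.4500, 0.9500, 0.2500])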