Unverified commit a2edf6a2 authored by Will-Nie, committed by GitHub

polish(nyp):add R2d2 comments (#149)

* add comments for r2d2

* sort style

* revise according to the comments

* fix style
Parent 490691fb
......@@ -678,11 +678,14 @@ class DRQN(nn.Module):
"""
x, prev_state = inputs['obs'], inputs['prev_state']
# for both inference and other cases, the network structure is encoder -> rnn network -> head
# the difference is that inference takes data with seq_len=1 (i.e. T=1)
if inference:
x = self.encoder(x)
x = x.unsqueeze(0)
x = x.unsqueeze(0) # for rnn input, add a seq_len dim of size 1 in front
# prev_state dtype: List[Tuple[torch.Tensor]]; initially it is a list of None
x, next_state = self.rnn(x, prev_state)
x = x.squeeze(0)
x = x.squeeze(0) # remove the seq_len dim to match the head network input
x = self.head(x)
x['next_state'] = next_state
return x
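As a sanity check on the shape handling in the inference branch, here is a minimal sketch using a plain torch.nn.LSTM in place of DI-engine's encoder/lstm/head stack; the sizes (B=3, obs_dim=4, hidden_size=8) are illustrative assumptions, not values from this repo.

```python
import torch
import torch.nn as nn

# Minimal sketch of the inference-branch shape flow; sizes are illustrative.
B, obs_dim, hidden_size = 3, 4, 8
rnn = nn.LSTM(obs_dim, hidden_size)      # stands in for the DRQN rnn

x = torch.randn(B, obs_dim)              # encoder output for a single timestep
x = x.unsqueeze(0)                       # (1, B, obs_dim): seq_len dim of size 1
x, next_state = rnn(x, None)             # None -> zero initial (h, c), playing the role of the initial list of None
x = x.squeeze(0)                         # (B, hidden_size): drop seq_len for the head
assert x.shape == (B, hidden_size)
```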
......@@ -700,11 +703,14 @@ class DRQN(nn.Module):
saved_hidden_state.append(prev_state)
lstm_embedding.append(output)
hidden_state = list(zip(*prev_state)) # {list: 2{tuple: B{Tensor: (1, 1, head_hidden_size)}}}
# only keep h_t, {list: x.shape[0]{Tensor: (1, batch_size, head_hidden_size)}}
hidden_state_list.append(torch.cat(hidden_state[0], dim=1))
x = torch.cat(lstm_embedding, 0) # (T, B, head_hidden_size)
x = parallel_wrapper(self.head)(x) # (T, B, action_shape)
x['next_state'] = prev_state # the last timestep state including h and c
x['hidden_state'] = torch.cat(hidden_state_list, dim=-3) # the all hidden state h
# the last timestep state including h and c for lstm, {list: B{tuple: 2{Tensor: (1, 1, head_hidden_size)}}}
x['next_state'] = prev_state
# all the hidden states h, concatenated into a tensor of shape (seq_len, batch_size, head_hidden_size)
x['hidden_state'] = torch.cat(hidden_state_list, dim=-3)
if saved_hidden_state_timesteps is not None:
x['saved_hidden_state'] = saved_hidden_state # the selected saved hidden states, including h and c
return x
......
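To make the hidden-state bookkeeping above concrete, the toy loop below mimics how one h tensor per timestep is extracted from the per-sample (h, c) tuples and concatenated into a (T, B, H) tensor; T, B, H and the zero states are assumptions for illustration only.

```python
import torch

# Toy illustration of collecting only h at every timestep; sizes are assumed.
T, B, H = 5, 2, 8
hidden_state_list = []
for t in range(T):
    # prev_state mimics the per-sample (h, c) tuples kept by the lstm: B tuples of (1, 1, H) tensors
    prev_state = [(torch.zeros(1, 1, H), torch.zeros(1, 1, H)) for _ in range(B)]
    hidden_state = list(zip(*prev_state))                        # (all h's, all c's)
    hidden_state_list.append(torch.cat(hidden_state[0], dim=1))  # keep only h: (1, B, H)
x_hidden = torch.cat(hidden_state_list, dim=-3)                  # (T, B, H)
assert x_hidden.shape == (T, B, H)
```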
......@@ -89,18 +89,20 @@ class HiddenStateWrapper(IModelWrapper):
"""
super().__init__(model)
self._state_num = state_num
# This wrapper maintains the hidden states: in before_forward, self._state is mapped into
# data['prev_state'], and in after_forward the model's next_state is stored back into self._state.
self._state = {i: init_fn() for i in range(state_num)}
self._save_prev_state = save_prev_state
self._init_fn = init_fn
def forward(self, data, **kwargs):
state_id = kwargs.pop('data_id', None)
valid_id = kwargs.pop('valid_id', None)
data, state_info = self.before_forward(data, state_id)
valid_id = kwargs.pop('valid_id', None) # valid_id is None here and is not used anywhere else in DI-engine
data, state_info = self.before_forward(data, state_id) # update data['prev_state'] with self._state
output = self._model.forward(data, **kwargs)
h = output.pop('next_state', None)
if h is not None:
self.after_forward(h, state_info, valid_id)
self.after_forward(h, state_info, valid_id) # store the next hidden state so it becomes prev_state at the next timestep
if self._save_prev_state:
prev_state = get_tensor_data(data['prev_state'])
output['prev_state'] = prev_state
......
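For readers unfamiliar with the wrapper, the sketch below reproduces only the prev_state/next_state bookkeeping described in the new comments; the class and the dummy model are hypothetical, and the real HiddenStateWrapper additionally handles init_fn, valid_id and save_prev_state.

```python
# Hypothetical, stripped-down version of the hidden-state bookkeeping (not DI-engine code).
class TinyHiddenStateKeeper:

    def __init__(self, model, state_num):
        self._model = model
        self._state = {i: None for i in range(state_num)}  # one hidden state per env id

    def forward(self, data, data_id):
        # before_forward: inject the stored states as data['prev_state']
        data['prev_state'] = [self._state[i] for i in data_id]
        output = self._model.forward(data)
        # after_forward: store next_state so it becomes prev_state at the next call
        next_state = output.pop('next_state', None)
        if next_state is not None:
            for i, h in zip(data_id, next_state):
                self._state[i] = h
        return output


class DummyRNNModel:
    """Pretends each env's hidden state is just a step counter."""

    def forward(self, data):
        prev = data['prev_state']
        nxt = [0 if p is None else p + 1 for p in prev]
        return {'action': [0 for _ in prev], 'next_state': nxt}


keeper = TinyHiddenStateKeeper(DummyRNNModel(), state_num=4)
keeper.forward({'obs': None}, data_id=[0, 1])
keeper.forward({'obs': None}, data_id=[0, 1])
assert keeper._state[0] == 1 and keeper._state[2] is None
```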
......@@ -229,11 +229,13 @@ class R2D2Policy(Policy):
data['weight'] = data['weight'] * torch.ones_like(data['done'])
# every timestep in the sequence has the same weight, namely the _priority_IS_weight in PER
data['action'] = data['action'][bs:-self._nstep]
data['reward'] = data['reward'][bs:-self._nstep]
data['action'] = data['action'][bs:-self._nstep] # keep timesteps [burnin_step, seq_len - nstep) along the seq_len dim
data['reward'] = data['reward'][bs:-self._nstep] # keep timesteps [burnin_step, seq_len - nstep) along the seq_len dim
# the burnin_nstep_obs is used to calculate the init hidden state of the rnn for the calculation of the q_value,
# target_q_value, and target_q_action
# all of this slicing is done along the outermost dim, i.e. the seq_len dim
data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
# the main_obs is used to calculate the q_value; [bs:-self._nstep] means using the data from
# timestep bs to timestep (self._unroll_len_add_burnin_step - self._nstep)
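The slicing described above can be checked with a small shape experiment; unroll_len=80, burnin_step=20 and nstep=5 below are assumed example values, not necessarily this config's defaults.

```python
import torch

# Which timesteps each slice keeps along the seq_len dim (example sizes).
unroll_len, bs, nstep, B, obs_dim = 80, 20, 5, 4, 8
T = unroll_len + bs                      # self._unroll_len_add_burnin_step
obs = torch.randn(T, B, obs_dim)

burnin_nstep_obs = obs[:bs + nstep]      # timesteps [0, bs + nstep): used to init the rnn hidden state
main_obs = obs[bs:-nstep]                # timesteps [bs, T - nstep): used to compute q_value
assert burnin_nstep_obs.shape[0] == bs + nstep
assert main_obs.shape[0] == T - bs - nstep
```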
......@@ -259,10 +261,11 @@ class R2D2Policy(Policy):
- total_loss (:obj:`float`): The calculated loss
"""
# forward
data = self._data_preprocess_learn(data)
data = self._data_preprocess_learn(data) # output datatype: Dict
self._learn_model.train()
self._target_model.train()
# use the hidden state in timestep=0
# note that the reset method is implemented in the hidden state wrapper and resets self._state.
self._learn_model.reset(data_id=None, state=data['prev_state'][0])
self._target_model.reset(data_id=None, state=data['prev_state'][0])
......@@ -271,7 +274,8 @@ class R2D2Policy(Policy):
inputs = {'obs': data['burnin_nstep_obs'], 'enable_fast_timestep': True}
burnin_output = self._learn_model.forward(
inputs, saved_hidden_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
)
) # output keys include 'logit', 'hidden_state', 'saved_hidden_state' and 'action';
# for their specific dims, please refer to the DRQN model
burnin_output_target = self._target_model.forward(
inputs, saved_hidden_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
)
......@@ -307,6 +311,8 @@ class R2D2Policy(Policy):
else:
l, e = q_nstep_td_error(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
loss.append(l)
# td_error will be a list of length (self._unroll_len_add_burnin_step - self._burnin_step - self._nstep),
# and each element is a tensor of size batch_size
td_error.append(e.abs())
loss = sum(loss) / (len(loss) + 1e-8)
......@@ -314,6 +320,7 @@ class R2D2Policy(Policy):
td_error_per_sample = 0.9 * torch.max(
torch.stack(td_error), dim=0
)[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
# torch.max(torch.stack(td_error), dim=0) returns a (values, indices) namedtuple; please refer to torch.max
# after stacking, td_error has shape (<self._unroll_len_add_burnin_step - self._burnin_step - self._nstep>, B), e.g. (75, 64)
# torch.sum(torch.stack(td_error), dim=0) can also be replaced with sum(td_error)
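The mixing above follows the R2D2 priority rule eta * max + (1 - eta) * mean with eta = 0.9; a small reproduction with the assumed (75, 64) sizes:

```python
import torch

# Per-sequence priority from per-timestep abs td errors; sizes follow the (75, 64) example above.
T_eff, B = 75, 64
td_error = [torch.rand(B) for _ in range(T_eff)]       # one tensor of abs td errors per timestep

stacked = torch.stack(td_error)                        # (T_eff, B)
max_part = torch.max(stacked, dim=0)[0]                # torch.max with dim returns (values, indices)
mean_part = torch.sum(stacked, dim=0) / (len(td_error) + 1e-8)
priority = 0.9 * max_part + (1 - 0.9) * mean_part      # (B,): one priority per sampled sequence
assert priority.shape == (B,)
```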
......@@ -332,7 +339,7 @@ class R2D2Policy(Policy):
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
'priority': td_error_per_sample.abs().tolist(),
'priority': td_error_per_sample.tolist(), # note that the abs operation has already been performed above
# the first timestep in the sequence, which may not be the start of an episode
'q_s_taken-a_t0': q_s_a_t0.mean().item(),
'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
......@@ -365,6 +372,8 @@ class R2D2Policy(Policy):
self._unroll_len_add_burnin_step = self._cfg.unroll_len + self._cfg.burnin_step
self._unroll_len = self._unroll_len_add_burnin_step # for compatibility
# for r2d2, this hidden_state wrapper adds the 'prev hidden state' to each transition.
# Note that the collect env forms a batch, and the key is added for the whole batch at once.
self._collect_model = model_wrap(
self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
)
......
......@@ -125,7 +125,7 @@ def timestep_collate(batch: List[Dict[str, Any]]) -> Dict[str, Union[torch.Tenso
prev_state = [b.pop('prev_state') for b in batch]
batch_data = default_collate(batch) # -> {some_key: T lists}, each list is [B, some_dim]
batch_data = stack(batch_data) # -> {some_key: [T, B, some_dim]}
batch_data['prev_state'] = list(zip(*prev_state))
batch_data['prev_state'] = list(zip(*prev_state)) # swap the batch_size dim with the seq_len dim
# append back prev_state, avoiding the bug where multiple batches share the same data
for i in range(len(batch)):
batch[i]['prev_state'] = prev_state[i]
......
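A quick check of the zip transpose in timestep_collate; B, T, H below are assumed toy sizes.

```python
import torch

# prev_state arrives as B per-sample lists of length T; zip(*...) regroups it into T tuples of length B.
B, T, H = 2, 3, 4
prev_state = [[(torch.zeros(1, 1, H), torch.zeros(1, 1, H)) for _ in range(T)] for _ in range(B)]
transposed = list(zip(*prev_state))
assert len(transposed) == T and len(transposed[0]) == B
```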
......@@ -257,6 +257,14 @@ class SampleSerialCollector(ISerialCollector):
self._total_envstep_count += 1
# prepare data
if timestep.done or len(self._traj_buffer[env_id]) == self._traj_len:
# for r2d2:
# 1. for each collect_env, we want to collect data of length self._traj_len,
# except when an episode terminates (done) earlier.
# 2. even if the timestep is done and we have only collected, say, 9 transitions,
# passing them through self._policy.get_train_sample pads them automatically.
# 3. so one unit of training data for r2d2 has seq len (burnin + nstep) (collected_sample=1),
# and we need to collect n_sample such units.
# Episode is done or traj_buffer(maxlen=traj_len) is full.
transitions = to_tensor_transitions(self._traj_buffer[env_id])
train_sample = self._policy.get_train_sample(transitions)
......
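Point 2 of the new comment refers to the padding performed by get_train_sample; the helper below is a hypothetical illustration of that idea (pad a short trajectory by repeating its last transition), not DI-engine's actual implementation.

```python
# Hypothetical padding sketch; pad_trajectory is an illustrative helper, not a DI-engine API.
def pad_trajectory(traj, seq_len):
    """Pad a short trajectory to seq_len by repeating the last transition."""
    if len(traj) >= seq_len:
        return traj[:seq_len]
    return traj + [dict(traj[-1]) for _ in range(seq_len - len(traj))]


short_traj = [{'obs': i, 'done': i == 8} for i in range(9)]  # episode ended after 9 transitions
padded = pad_trajectory(short_traj, seq_len=12)
assert len(padded) == 12
```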
......@@ -30,7 +30,7 @@ cartpole_r2d2_config = dict(
learn=dict(
# according to the R2D2 paper, actor parameter update interval is 400
# environment timesteps, and in per collect phase, we collect 32 sequence
# samples, the length of each samlpe sequence is <burnin_step> + <unroll_len>,
# samples, the length of each sample sequence is <burnin_step> + <unroll_len>,
# which is 100 in our setting, 32*100/400=8, so we set update_per_collect=8
# in most environments
update_per_collect=8,
......
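The arithmetic behind update_per_collect in this comment, written out; the variable names below are illustrative only.

```python
# 32 sequences per collect, each of length burnin_step + unroll_len = 100,
# divided by the 400-env-step actor update interval from the R2D2 paper.
n_sample = 32
sample_seq_len = 100
actor_update_interval = 400
update_per_collect = n_sample * sample_seq_len // actor_update_interval
assert update_per_collect == 8
```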