Unverified commit c8dac674, authored by 蒲源, committed by GitHub

fix(pu): fix r2d2 bug (#36)

* test rnd

* fix mz config

* fix config

* fix(pu): fix r2d2

* feature(puyuan): add minigrid r2d2 config

* polish minigrid config

* modified as review

* fix(pu): fix bug for compatibility

* polish(pu): add annotations and polish slice operation

* style(pu): run format.sh

* style(pu): correct yapf format
Parent fae79267
......@@ -78,10 +78,13 @@ class R2D2Policy(Policy):
# (float) Reward's future discount factor, aka. gamma.
discount_factor=0.997,
# (int) N-step reward for target q_value estimation
nstep=3,
nstep=5,
# (int) the timestep of burnin operation, which is designed to mitigate the RNN hidden state difference
# caused by off-policy training
burnin_step=2,
# (int) the trajectory length to unroll the RNN network, excluding
# the timesteps of the burnin operation
unroll_len=80,
learn=dict(
# (bool) Whether to use multi gpu
multi_gpu=False,
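For reference (not part of the diff), a quick check of how these defaults interact: each R2D2 training sequence now spans unroll_len + burnin_step timesteps, and the n-step TD loss is computed over unroll_len - nstep of them.

```python
# Illustrative arithmetic only; the names mirror the config keys above.
nstep, burnin_step, unroll_len = 5, 2, 80

sequence_len = unroll_len + burnin_step           # timesteps stored per training sample -> 82
loss_steps = sequence_len - burnin_step - nstep   # timesteps contributing to the TD loss -> 75
print(sequence_len, loss_steps)
```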
......@@ -99,7 +102,7 @@ class R2D2Policy(Policy):
),
collect=dict(
# (int) Only one of [n_sample, n_episode] should be set
# n_sample=64,
n_sample=64,
# `env_num` is used in hidden state, should equal to that one in env config.
# User should specify this value in user config.
env_num=None,
......@@ -176,24 +179,36 @@ class R2D2Policy(Policy):
data = timestep_collate(data)
if self._cuda:
data = to_device(data, self._device)
assert len(data['obs']) == 2 * self._nstep + self._burnin_step, data['obs'].shape # todo: why 2*a+b
bs = self._burnin_step
data['weight'] = data.get('weight', [None for _ in range(self._nstep)])
# data['done'], data['weight'] and data['value_gamma'] are used in _forward_learn() to calculate
# the q_nstep_td_error and should have length [self._unroll_len_add_burnin_step - self._burnin_step - self._nstep]
ignore_done = self._cfg.learn.ignore_done
if ignore_done:
data['done'] = [None for _ in range(self._nstep)]
data['done'] = [None for _ in range(self._unroll_len_add_burnin_step - bs - self._nstep)]
else:
data['done'] = data['done'][bs:bs + self._nstep].float()
data['action'] = data['action'][bs:bs + self._nstep]
data['reward'] = data['reward'][bs:]
# split obs into three parts ['burnin_obs'(0~bs), 'main_obs'(bs~bs+nstep), 'target_obs'(bs+nstep~bss+2nstep)]
data['burnin_obs'] = data['obs'][:bs]
data['main_obs'] = data['obs'][bs:bs + self._nstep]
data['target_obs'] = data['obs'][bs + self._nstep:]
data['done'] = data['done'][bs:-self._nstep].float()
# if the data does not include 'weight' or 'value_gamma', fill the corresponding entry with a list of None
# of length [self._unroll_len_add_burnin_step - self._burnin_step - self._nstep];
# below are two different implementations
if 'value_gamma' not in data:
data['value_gamma'] = [None for _ in range(self._nstep)]
data['value_gamma'] = [None for _ in range(self._unroll_len_add_burnin_step - bs - self._nstep)]
else:
data['value_gamma'] = data['value_gamma'][bs:bs + self._nstep]
data['value_gamma'] = data['value_gamma'][bs:-self._nstep]
data['weight'] = data.get('weight', [None for _ in range(self._unroll_len_add_burnin_step - bs - self._nstep)])
data['action'] = data['action'][bs:-self._nstep]
data['reward'] = data['reward'][bs:-self._nstep]
# burnin_obs is used to calculate the initial RNN hidden state for the q_value computation
data['burnin_obs'] = data['obs'][:bs]
# main_obs is used to calculate the q_value; the slice [bs:-self._nstep] selects the data from
# timestep [bs] to timestep [self._unroll_len_add_burnin_step - self._nstep]
data['main_obs'] = data['obs'][bs:-self._nstep]
# the target_obs is used to calculate the target_q_value
data['target_obs'] = data['obs'][bs + self._nstep:]
return data
def _forward_learn(self, data: dict) -> Dict[str, Any]:
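A minimal standalone sketch of the new slicing scheme (shapes and tensor names are illustrative, not the library code): with T = unroll_len + burnin_step collated timesteps, obs is split into burnin_obs = obs[:bs], main_obs = obs[bs:-nstep] and target_obs = obs[bs + nstep:], so the main and target segments both cover T - bs - nstep steps, shifted by nstep.

```python
import torch

# Hypothetical batch after timestep_collate: shape (T, B, obs_dim)
burnin_step, nstep, unroll_len = 2, 5, 80
T, B, obs_dim = unroll_len + burnin_step, 4, 8
obs = torch.randn(T, B, obs_dim)

burnin_obs = obs[:burnin_step]            # warms up the RNN hidden state
main_obs = obs[burnin_step:-nstep]        # used for the online q_value
target_obs = obs[burnin_step + nstep:]    # used for the target q_value, shifted by nstep

assert burnin_obs.shape[0] == burnin_step
assert main_obs.shape[0] == T - burnin_step - nstep
assert target_obs.shape[0] == T - burnin_step - nstep
```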
......@@ -235,7 +250,7 @@ class R2D2Policy(Policy):
loss = []
td_error = []
value_gamma = data['value_gamma']
for t in range(self._nstep):
for t in range(self._unroll_len_add_burnin_step - self._burnin_step - self._nstep):
td_data = q_nstep_td_data(
q_value[t], target_q_value[t], action[t], target_q_action[t], reward[t], done[t], weight[t]
)
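As a rough illustration of what each iteration of this loop feeds into the TD error (a simplified n-step target in the double-DQN style, not the q_nstep_td_error implementation from ding.rl_utils):

```python
import torch


def nstep_td_target(reward, done, target_q, target_q_action, gamma, nstep):
    """Simplified n-step return for one main timestep (illustrative sketch only).

    reward:          (nstep, B) rewards r_t .. r_{t+nstep-1}
    done:            (B,) terminal flag at the bootstrap step
    target_q:        (B, A) target-network Q values at t + nstep
    target_q_action: (B,) greedy action of the online network at t + nstep
    """
    bootstrap = target_q.gather(1, target_q_action.unsqueeze(1)).squeeze(1)
    ret = (1.0 - done.float()) * (gamma ** nstep) * bootstrap
    for i in range(nstep):
        ret = ret + (gamma ** i) * reward[i]
    return ret
```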
......@@ -284,7 +299,9 @@ class R2D2Policy(Policy):
self._nstep = self._cfg.nstep
self._burnin_step = self._cfg.burnin_step
self._gamma = self._cfg.discount_factor
self._unroll_len = self._burnin_step + 2 * self._nstep
self._unroll_len_add_burnin_step = self._cfg.unroll_len + self._cfg.burnin_step
self._unroll_len = self._unroll_len_add_burnin_step # for compatibility
self._collect_model = model_wrap(
self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
)
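With self._unroll_len_add_burnin_step = unroll_len + burnin_step, the configs touched in this PR collect training sequences of the following lengths (simple arithmetic, listed for reference):

```python
# (unroll_len, burnin_step) -> timesteps per training sequence
configs = {
    'lunarlander_r2d2': (40, 2),       # -> 42
    'cartpole_r2d2': (40, 5),          # -> 45
    'minigrid_empty8_r2d2': (80, 2),   # -> 82
}
for name, (unroll_len, burnin_step) in configs.items():
    print(name, unroll_len + burnin_step)
```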
......@@ -351,7 +368,7 @@ class R2D2Policy(Policy):
- samples (:obj:`dict`): The training samples generated
"""
data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
return get_train_sample(data, self._unroll_len)
return get_train_sample(data, self._unroll_len_add_burnin_step)
def _init_eval(self) -> None:
r"""
......
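The switch to get_train_sample(data, self._unroll_len_add_burnin_step) means each collected trajectory is cut into fixed-length sequences of unroll_len + burnin_step steps before entering the replay buffer. A hedged sketch of that kind of splitting (ding's actual get_train_sample handles padding and the residual tail on its own terms; this is not its implementation):

```python
from typing import Any, Dict, List


def split_trajectory(traj: List[Dict[str, Any]], seq_len: int) -> List[List[Dict[str, Any]]]:
    """Cut a trajectory (list of transition dicts) into sequences of length seq_len.

    The remainder is covered by a tail window ending at the last transition, which
    duplicates a few steps; this is one common choice, not necessarily ding's.
    """
    samples = []
    for start in range(0, len(traj) - seq_len + 1, seq_len):
        samples.append(traj[start:start + seq_len])
    if len(traj) >= seq_len and len(traj) % seq_len != 0:
        samples.append(traj[-seq_len:])
    return samples
```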
from easydict import EasyDict
from ding.entry import serial_pipeline
collector_env_num = 8
evaluator_env_num = 5
lunarlander_r2d2_config = dict(
exp_name='lunarlander_r2d2_bs2_n5_ul40',
env=dict(
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
......@@ -16,13 +17,14 @@ lunarlander_r2d2_config = dict(
model=dict(
obs_shape=8,
action_shape=4,
hidden_size_list=[128, 128, 64],
encoder_hidden_size_list=[128, 128, 64],
),
discount_factor=0.999,
burnin_step=20,
nstep=2,
burnin_step=2,
nstep=5,
unroll_len=40, #80,
learn=dict(
update_per_collect=4,
update_per_collect=20,
batch_size=64,
learning_rate=0.0005,
target_update_freq=100,
......@@ -54,3 +56,6 @@ lunarlander_r2d2_create_config = dict(
)
lunarlander_r2d2_create_config = EasyDict(lunarlander_r2d2_create_config)
create_config = lunarlander_r2d2_create_config
if __name__ == "__main__":
serial_pipeline([main_config, create_config], seed=0)
\ No newline at end of file
......@@ -21,6 +21,9 @@ cartpole_r2d2_config = dict(
discount_factor=0.997,
burnin_step=5,
nstep=5,
# (int) the trajectory length to unroll the RNN network, excluding
# the timesteps of the burnin operation
unroll_len=40,
learn=dict(
update_per_collect=4,
batch_size=64,
......
......@@ -56,4 +56,4 @@ minigrid_ppo_rnd_create_config = EasyDict(minigrid_ppo_rnd_create_config)
create_config = minigrid_ppo_rnd_create_config
if __name__ == "__main__":
serial_pipeline_reward_model([main_config, create_config], seed=0)
serial_pipeline_reward_model([main_config, create_config], seed=0)
\ No newline at end of file
from easydict import EasyDict
from ding.entry import serial_pipeline
collector_env_num = 8
evaluator_env_num = 5
minigrid_r2d2_config = dict(
exp_name='minigrid_empty8_r2d2_bs2_n5',
env=dict(
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
env_id='MiniGrid-Empty-8x8-v0',
n_evaluator_episode=5,
stop_value=0.96,
),
policy=dict(
cuda=False,
on_policy=False,
priority=False,
model=dict(
obs_shape=2739,
action_shape=7,
encoder_hidden_size_list=[256, 128, 64, 64],
),
discount_factor=0.999,
burnin_step=2,
nstep=5,
unroll_len=80,
learn=dict(
update_per_collect=20,
batch_size=64,
learning_rate=0.0005,
target_update_freq=100,
),
collect=dict(
n_sample=32,
env_num=collector_env_num,
),
eval=dict(env_num=evaluator_env_num, ),
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.05,
decay=10000,
), replay_buffer=dict(replay_buffer_size=10000, )
),
),
)
minigrid_r2d2_config = EasyDict(minigrid_r2d2_config)
main_config = minigrid_r2d2_config
minigrid_r2d2_create_config = dict(
env=dict(
type='minigrid',
import_names=['dizoo.minigrid.envs.minigrid_env'],
),
env_manager=dict(type='base'),
policy=dict(type='r2d2'),
)
minigrid_r2d2_create_config = EasyDict(minigrid_r2d2_create_config)
create_config = minigrid_r2d2_create_config
if __name__ == "__main__":
serial_pipeline([main_config, create_config], seed=0)
\ No newline at end of file
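The obs_shape=2739 above presumably comes from MiniGrid's flattened observation (an assumption about the wrapper used by dizoo's minigrid env, not stated in this diff): a 7x7x3 agent-centric image plus a fixed-size one-hot mission encoding.

```python
# Hypothetical derivation of obs_shape=2739; verify against the actual MiniGrid wrapper.
image_dim = 7 * 7 * 3      # 147: partially observed grid view
mission_dim = 96 * 27      # 2592: max mission string length x character codes
print(image_dim + mission_dim)  # 2739
```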