Commit 28930a86 authored by niuyazhe

fix(nyz): fix r2d2 and dqfd td_error unittest bugs

Parent 286ea243
@@ -188,7 +188,7 @@ def test_r2d2():
     config = [deepcopy(cartpole_r2d2_config), deepcopy(cartpole_r2d2_create_config)]
     config[0].policy.learn.update_per_collect = 1
     try:
-        serial_pipeline(config, seed=0, max_iterations=1)
+        serial_pipeline(config, seed=0, max_iterations=5)
     except Exception:
         assert False, "pipeline fail"

@@ -223,7 +223,7 @@ class R2D2Policy(Policy):
         else:
             data['value_gamma'] = data['value_gamma'][bs:]
-        if 'weight' not in data:
+        if 'weight' not in data or data['weight'] is None:
             data['weight'] = [None for _ in range(self._unroll_len_add_burnin_step - bs)]
         else:
             data['weight'] = data['weight'] * torch.ones_like(data['done'])
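
The guard above is the heart of the r2d2 fix: the replay buffer can hand the learner a batch whose 'weight' key is present but set to None (for example when prioritized IS weighting is disabled, as in the config hunk below), so the old membership test passed and execution fell through to a multiplication on None. A minimal sketch of the two guards against a hypothetical batch:

    import torch

    data = {'done': torch.zeros(4), 'weight': None}  # hypothetical batch: key present, value None

    old_guard = 'weight' not in data                            # False -> falls through to
                                                                # None * torch.ones_like(...) -> TypeError
    new_guard = 'weight' not in data or data['weight'] is None  # True -> pads with [None, ...] instead
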
@@ -13,6 +13,7 @@ class TestAdder:
         return {
             'value': torch.randn(1),
             'reward': torch.rand(1),
+            'action': torch.rand(3),
             'other': np.random.randint(0, 10, size=(4, )),
             'obs': torch.randn(3),
             'done': False

@@ -22,6 +23,7 @@ class TestAdder:
         return {
             'value': torch.randn(1, 8),
             'reward': torch.rand(1, 1),
+            'action': torch.rand(3),
             'other': np.random.randint(0, 10, size=(4, )),
             'obs': torch.randn(3),
             'done': False
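
Both fixture methods above gain an 'action' field. The diff does not say why, but a plausible reading (an assumption, not confirmed here) is that the adder utilities under test index every field a transition is expected to carry, so a fixture without 'action' raises KeyError as soon as that field is touched. A self-contained sketch of the pattern, with illustrative names rather than DI-engine's API:

    import numpy as np
    import torch

    def fake_transition():
        # Mirrors the updated fixture, including the newly added 'action' field.
        return {
            'value': torch.randn(1),
            'reward': torch.rand(1),
            'action': torch.rand(3),
            'other': np.random.randint(0, 10, size=(4, )),
            'obs': torch.randn(3),
            'done': False,
        }

    traj = [fake_transition() for _ in range(4)]
    # Any consumer that indexes the field directly would raise KeyError
    # with the old fixture:
    actions = [step['action'] for step in traj]
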
@@ -234,7 +234,7 @@ def test_dqfd_nstep_td():
         q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert
     )
     loss, td_error_per_sample = dqfd_nstep_td_error(
-        data, 0.95, lambda1=1, lambda2=1, margin_function=0.8, nstep=nstep
+        data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, margin_function=0.8, nstep=nstep
     )
     assert td_error_per_sample.shape == (batch_size, )
     assert loss.shape == ()
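
The keyword rename makes the two DQfD coefficients self-describing: in the DQfD objective J(Q) = J_DQ(Q) + λ1·J_n(Q) + λ2·J_E(Q), lambda_n_step_td scales the n-step TD term and lambda_supervised_loss scales the large-margin supervised term on expert transitions. A standalone sketch of that supervised term, using illustrative names rather than the library's internals:

    import torch
    import torch.nn.functional as F

    def large_margin_loss(q, expert_action, is_expert, margin=0.8):
        # DQfD supervised term: max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E), where
        # l(a_E, a) equals `margin` for a != a_E and 0 otherwise; only expert
        # transitions contribute to the mean.
        num_action = q.shape[1]
        l = margin * (1 - F.one_hot(expert_action, num_action).float())
        per_sample = (q + l).max(dim=1).values - q.gather(1, expert_action.unsqueeze(1)).squeeze(1)
        return (per_sample * is_expert).mean()

    q = torch.randn(8, 4)                       # Q-values for a batch of 8 over 4 actions
    expert_action = torch.randint(0, 4, (8, ))
    is_expert = torch.randint(0, 2, (8, )).float()
    loss = large_margin_loss(q, expert_action, is_expert)
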
@@ -14,6 +14,7 @@ cartpole_r2d2_config = dict(
     policy=dict(
         cuda=False,
         priority=False,
+        priority_IS_weight=False,
         model=dict(
             obs_shape=4,
             action_shape=2,
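
The new priority_IS_weight=False flag lines up with priority=False: without prioritized sampling there are no importance-sampling corrections to apply, and the policy-side guard earlier in this commit then tolerates the resulting weight=None batches. For reference, the correction such a flag usually toggles is w_i = (N * P(i)) ** (-beta); a generic sketch, not DI-engine's buffer code:

    import torch

    def is_weights(priorities, beta=0.4):
        # Importance-sampling weights for prioritized replay, normalized by the
        # largest weight so corrections stay in (0, 1].
        probs = priorities / priorities.sum()
        w = (len(priorities) * probs) ** (-beta)
        return w / w.max()

    w = is_weights(torch.tensor([1.0, 0.5, 2.0, 0.1]))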