未验证 提交 1040c9fc 编写于 作者: 蒲源 提交者: GitHub

fix(pu): fix dqfd compatibility (#161)

* polish(pu): polish r2d3

* polish(pu): first abs then sum each item in td-error

* fix(pu): fix dqfd compatibility
上级 b532b4cd
......@@ -209,7 +209,7 @@ class DQFDPolicy(DQNPolicy):
data['is_expert'] # set is_expert flag(expert 1, agent 0)
)
value_gamma = data.get('value_gamma')
-loss, td_error_per_sample = dqfd_nstep_td_error(
+loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
data_n,
self._gamma,
self.lambda1,
......
......@@ -253,7 +253,7 @@ def test_dqfd_nstep_td():
data = dqfd_nstep_td_data(
q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert
)
-loss, td_error_per_sample = dqfd_nstep_td_error(
+loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, margin_function=0.8, nstep=nstep
)
assert td_error_per_sample.shape == (batch_size, )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册