diff --git a/ding/policy/dqfd.py b/ding/policy/dqfd.py
index 22108b6ebced0754ffe5a1116e18d086a640f564..d64e75f25e37f53727362b0564af5646f4d659d1 100644
--- a/ding/policy/dqfd.py
+++ b/ding/policy/dqfd.py
@@ -209,7 +209,7 @@ class DQFDPolicy(DQNPolicy):
             data['is_expert']  # set is_expert flag(expert 1, agent 0)
         )
         value_gamma = data.get('value_gamma')
-        loss, td_error_per_sample = dqfd_nstep_td_error(
+        loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
             data_n,
             self._gamma,
             self.lambda1,
diff --git a/ding/rl_utils/tests/test_td.py b/ding/rl_utils/tests/test_td.py
index 06fbe60d2d8e974dc6a7dc4cb6e8a6f2ce8f6392..3fd32cf0af4efe5630c403a6d87e326d339eb9af 100644
--- a/ding/rl_utils/tests/test_td.py
+++ b/ding/rl_utils/tests/test_td.py
@@ -253,7 +253,7 @@ def test_dqfd_nstep_td():
     data = dqfd_nstep_td_data(
         q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert
     )
-    loss, td_error_per_sample = dqfd_nstep_td_error(
+    loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
         data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, margin_function=0.8, nstep=nstep
     )
     assert td_error_per_sample.shape == (batch_size, )
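
Context for reviewers: dqfd_nstep_td_error now returns a third value, loss_statistics, which is why both call sites above unpack three values. A minimal sketch of how a caller might surface the extra statistics for logging follows; it mirrors the call signature exercised in test_td.py, and the assumption that loss_statistics is an iterable of scalar tensors (e.g. the individual loss terms) is mine, not something this diff confirms:

    # Sketch only: same call as in test_dqfd_nstep_td above.
    # Assumption (not confirmed by this diff): loss_statistics is an iterable of
    # scalar tensors holding the individual DQfD loss terms, handy for logging.
    loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
        data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, margin_function=0.8, nstep=nstep
    )
    stats = {f'dqfd_loss_term_{i}': float(v) for i, v in enumerate(loss_statistics)}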