未验证 提交 b532b4cd 编写于 作者: 蒲源 提交者: GitHub

polish(pu): polish r2d3 for abs priority (#158)

* polish(pu): polish r2d3

* polish(pu): first abs then sum each item in td-error
上级 b7cd6751
......@@ -352,14 +352,15 @@ class R2D3Policy(Policy):
value_gamma=value_gamma[t],
)
loss.append(l)
td_error.append(e.abs())
# td_error.append(e.abs()) # first sum then abs
td_error.append(e) # first abs then sum
# loss statistics for debugging
loss_nstep.append(loss_statistics[0])
loss_1step.append(loss_statistics[1])
loss_sl.append(loss_statistics[2])
else:
l, e = dqfd_nstep_td_error(
l, e, loss_statistics = dqfd_nstep_td_error(
td_data,
self._gamma,
self.lambda1,
......@@ -371,7 +372,12 @@ class R2D3Policy(Policy):
value_gamma=value_gamma[t],
)
loss.append(l)
td_error.append(e.abs())
# td_error.append(e.abs()) # first sum then abs
td_error.append(e) # first abs then sum
# loss statistics for debugging
loss_nstep.append(loss_statistics[0])
loss_1step.append(loss_statistics[1])
loss_sl.append(loss_statistics[2])
loss = sum(loss) / (len(loss) + 1e-8)
# loss statistics for debugging
......
......@@ -669,8 +669,9 @@ def dqfd_nstep_td_error(
lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
lambda_supervised_loss * JE
) * weight
).mean(), lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
lambda_supervised_loss * JE
).mean(), lambda_n_step_td * td_error_per_sample.abs() +
lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(),
(td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
)
......@@ -775,8 +776,9 @@ def dqfd_nstep_td_error_with_rescale(
lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
lambda_supervised_loss * JE
) * weight
).mean(), lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
lambda_supervised_loss * JE, (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
).mean(), lambda_n_step_td * td_error_per_sample.abs() +
lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(),
(td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
)
......
......@@ -4,7 +4,7 @@ from ding.entry import serial_pipeline
collector_env_num = 8
evaluator_env_num = 5
pong_r2d2_config = dict(
exp_name='debug_pong_r2d2_n5_bs2_ul40_rbs1e4_seed0',
exp_name='pong_r2d2_n5_bs2_ul40_rbs1e4_seed0',
env=dict(
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
......
......@@ -6,10 +6,11 @@ module_path = os.path.dirname(__file__)
collector_env_num = 8
evaluator_env_num = 5
expert_replay_buffer_size=1000 #TODO 1000
expert_replay_buffer_size = int(5e3) # TODO(pu)
"""agent config"""
pong_r2d3_config = dict(
exp_name='debug_pong_r2d3_offppoexpert_k0_pho1-256_rbs2e4',
exp_name='pong_r2d3_offppoexpert_k0_pho1-4_rbs2e4_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......@@ -62,7 +63,7 @@ pong_r2d3_config = dict(
env_num=collector_env_num,
# The hyperparameter pho, the demo ratio, control the propotion of data coming\
# from expert demonstrations versus from the agent's own experience.
pho=1/256, # 1/256, #TODO(pu), 0.25,
pho=1/4, # TODO(pu)
),
eval=dict(env_num=evaluator_env_num, ),
other=dict(
......@@ -98,7 +99,7 @@ create_config = pong_r2d3_create_config
"""export config"""
expert_pong_r2d3_config = dict(
exp_name='expert_pong_r2d3_ppoexpert_k0_pho1-256_rbs2e4',
exp_name='expert_pong_r2d3_ppoexpert_k0_pho1-4_rbs2e4_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......
......@@ -5,11 +5,11 @@ module_path = os.path.dirname(__file__)
collector_env_num = 8
evaluator_env_num = 5
expert_replay_buffer_size=int(5e3)
expert_replay_buffer_size = int(5e3) # TODO(pu)
"""agent config"""
pong_r2d3_config = dict(
exp_name='debug_pong_r2d3_r2d2expert_k0_pho1-4_rbs2e4_ds5e3',
exp_name='pong_r2d3_r2d2expert_k0_pho1-4_rbs2e4_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......@@ -45,7 +45,7 @@ pong_r2d3_config = dict(
# in most environments
value_rescale=True,
update_per_collect=8,
batch_size=64, # TODO(pu)
batch_size=64,
learning_rate=0.0005,
target_update_theta=0.001,
# DQFD related parameters
......@@ -97,7 +97,7 @@ create_config = pong_r2d3_create_config
"""export config"""
expert_pong_r2d3_config = dict(
exp_name='expert_pong_r2d3_r2d2expert_k0_pho1-4_rbs1e4_ds5e3',
exp_name='expert_pong_r2d3_r2d2expert_k0_pho1-4_rbs2e4_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册