Unverified · Commit 81602ce9 authored by 蒲源, committed by GitHub

polish(pu): add loss statistics and polish r2d3 pong config (#126)

* fix(pu): fix adam weight decay bug

* feature(pu): add pitfall offppo config

* feature(pu): add qbert spaceinvaders pitfall r2d3 config

* fix(pu): fix expert offppo config in r2d3

* fix(pu): fix pong config

* polish(pu): add loss statistics

* fix(pu): fix loss statistics bug

* polish(pu): polish pong r2d3 config

* polish(pu): polish r2d3 pong and lunarlander config

* polish(pu): delete unused files
Parent f88bc0e0
......@@ -123,15 +123,16 @@ ding -m serial -e cartpole -p dqn -s 0
| 23 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
| 24 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
| 25 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
| 26 | [GCL](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
| 27 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 28 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 29 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
| 30 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 31 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 32 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 33 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 34 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 26 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/r2d3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d3.py) | python3 -u pong_r2d3_r2d2expert_config.py |
| 27 | [GCL](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
| 28 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 29 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 30 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
| 31 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 32 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 33 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 34 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 35 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is the only label used for the standard DRL algorithms (1-16)
......
......@@ -148,7 +148,9 @@ class R2D3Policy(Policy):
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3)
self._optimizer = Adam(
self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3, optim_type='adamw'
)
self._gamma = self._cfg.discount_factor
self._nstep = self._cfg.nstep
self._burnin_step = self._cfg.burnin_step
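For context on the weight-decay fix in the hunk above: plain Adam folds the L2 term into the gradient, while the AdamW variant decays the weights directly, which is what the added `optim_type='adamw'` argument selects in DI-engine's Adam wrapper. The snippet below is a minimal plain-PyTorch sketch of the two behaviours; the placeholder model and values are illustrative only.

```python
# Minimal plain-PyTorch sketch (not DI-engine internals) of coupled vs. decoupled
# weight decay; the linear model is a hypothetical placeholder.
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)   # hypothetical placeholder network
lambda3 = 1e-5            # L2 / weight-decay coefficient, matching the config

# Adam: weight decay is added to the gradient (classic L2), so it is rescaled by
# the adaptive per-parameter learning rates.
adam = optim.Adam(model.parameters(), lr=5e-4, weight_decay=lambda3)

# AdamW: weight decay is applied directly to the parameters (decoupled), the
# behaviour selected by optim_type='adamw' in the diff above.
adamw = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=lambda3)
```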
......@@ -316,6 +318,9 @@ class R2D3Policy(Policy):
# T, B, nstep -> T, nstep, B
reward = reward.permute(0, 2, 1).contiguous()
loss = []
loss_nstep = []
loss_1step = []
loss_sl = []
td_error = []
for t in range(self._unroll_len_add_burnin_step - self._burnin_step - self._nstep):
# here t=0 means timestep <self._burnin_step> in the original sample sequence; we subtract self._nstep
......@@ -335,7 +340,7 @@ class R2D3Policy(Policy):
)
if self._value_rescale:
l, e = dqfd_nstep_td_error_with_rescale(
l, e, loss_statistics = dqfd_nstep_td_error_with_rescale(
td_data,
self._gamma,
self.lambda1,
......@@ -348,6 +353,10 @@ class R2D3Policy(Policy):
)
loss.append(l)
td_error.append(e.abs())
# loss statistics for debugging
loss_nstep.append(loss_statistics[0])
loss_1step.append(loss_statistics[1])
loss_sl.append(loss_statistics[2])
else:
l, e = dqfd_nstep_td_error(
......@@ -365,6 +374,10 @@ class R2D3Policy(Policy):
td_error.append(e.abs())
loss = sum(loss) / (len(loss) + 1e-8)
# loss statistics for debugging
loss_nstep = sum(loss_nstep) / (len(loss_nstep) + 1e-8)
loss_1step = sum(loss_1step) / (len(loss_1step) + 1e-8)
loss_sl = sum(loss_sl) / (len(loss_sl) + 1e-8)
# using the mixture of max and mean absolute n-step TD-errors as the priority of the sequence
td_error_per_sample = 0.9 * torch.max(
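The comment above describes the sequence priority as a mixture of the max and mean absolute n-step TD errors. A hedged sketch of that mixture with dummy shapes follows; the exact tensors in the policy may differ.

```python
# Hedged sketch: priority of a sequence as 0.9 * max + 0.1 * mean of the absolute
# per-timestep TD errors; shapes are illustrative ([T, B] dummy data).
import torch

td_error = torch.rand(40, 32)  # dummy absolute TD errors, T timesteps x B sequences
td_error_per_sample = 0.9 * td_error.max(dim=0)[0] + 0.1 * td_error.mean(dim=0)
```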
......@@ -388,6 +401,10 @@ class R2D3Policy(Policy):
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
# loss statistics for debugging
'nstep_loss': loss_nstep.item(),
'1step_loss': loss_1step.item(),
'sl_loss': loss_sl.item(),
'priority': td_error_per_sample.abs().tolist(),
# the first timestep in the sequence, which may not be the start of an episode
'q_s_taken-a_t0': q_s_a_t0.mean().item(),
......@@ -541,5 +558,6 @@ class R2D3Policy(Policy):
def _monitor_vars_learn(self) -> List[str]:
return super()._monitor_vars_learn() + [
'total_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0', 'q_s_a-mean_t0'
'total_loss', 'nstep_loss', '1step_loss', 'sl_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0',
'q_s_a-mean_t0'
]
......@@ -739,7 +739,7 @@ def dqfd_nstep_td_error_with_rescale(
lambda_supervised_loss * JE
) * weight
).mean(), lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
lambda_supervised_loss * JE
lambda_supervised_loss * JE, (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
)
......
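To make the rl_utils change above easier to follow, here is an illustrative sketch (not the actual ding.rl_utils function) of how the DQFD-style loss combines the n-step TD term, the 1-step TD term and the supervised margin term JE, and how the newly returned per-component means can be produced for logging.

```python
# Illustrative sketch of the combined DQFD-style loss and the per-component
# statistics tuple added in this commit; argument names follow the diff above.
import torch

def combined_dqfd_loss(td_error_per_sample, td_error_one_step_per_sample, JE, weight,
                       lambda_n_step_td, lambda_one_step_td, lambda_supervised_loss):
    # all per-sample terms have shape [B]; weight holds the IS weights
    per_sample = (
        lambda_n_step_td * td_error_per_sample
        + lambda_one_step_td * td_error_one_step_per_sample
        + lambda_supervised_loss * JE
    )
    loss = (per_sample * weight).mean()
    # per-component means for debugging, mirroring the returned loss_statistics
    stats = (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
    return loss, per_sample, stats
```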
......@@ -4,7 +4,7 @@ from ding.entry import serial_pipeline
collector_env_num = 8
evaluator_env_num = 5
pong_r2d2_config = dict(
exp_name='debug_pong_r2d2_n5_bs2_ul40',
exp_name='debug_pong_r2d2_n5_bs2_ul40_rbs1e4_seed0',
env=dict(
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
......@@ -55,7 +55,7 @@ pong_r2d2_config = dict(
decay=1e5,
),
replay_buffer=dict(
replay_buffer_size=20000, # TODO(pu)
replay_buffer_size=10000, # TODO(pu)
# (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
alpha=0.6,
# (Float type) How much correction is used: 0 means no correction while 1 means full correction
......
......@@ -6,10 +6,10 @@ module_path = os.path.dirname(__file__)
collector_env_num = 8
evaluator_env_num = 5
expert_replay_buffer_size=1 #TODO 1000
expert_replay_buffer_size=1000 #TODO 1000
"""agent config"""
pong_r2d3_config = dict(
exp_name='debug_pong_r2d3_k0_pho0',
exp_name='debug_pong_r2d3_offppoexpert_k0_pho1-256_rbs2e4',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......@@ -50,10 +50,9 @@ pong_r2d3_config = dict(
target_update_theta=0.001,
# DQFD related parameters
lambda1=1.0, # n-step return
lambda2=0, # 1.0, # supervised loss
lambda3=1e-5, # L2
lambda_one_step_td=0, # 1-step return
lambda2=1, # 1.0, # supervised loss
lambda3=1e-5, # L2 regularization; it is important to set the Adam optimizer's optim_type='adamw'.
lambda_one_step_td=1, # 1-step return
margin_function=0.8, # margin function in JE, here we implement this as a constant
per_train_iter_k=0, # TODO(pu)
),
......@@ -63,15 +62,15 @@ pong_r2d3_config = dict(
env_num=collector_env_num,
# The hyperparameter pho, the demo ratio, controls the proportion of data coming\
# from expert demonstrations versus from the agent's own experience.
pho=0, # 1/256, #TODO(pu), 0.25,
pho=1/256, # 1/256, #TODO(pu), 0.25,
),
eval=dict(env_num=evaluator_env_num, ),
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=100000,
end=0.05,
decay=1e5,
),
replay_buffer=dict(
replay_buffer_size=20000, # TODO(pu): with sequence_length=42, 10000 obs need about 11GB of memory; if rbs=20000, at least 140GB
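The pho setting above is the demo ratio. A minimal sketch of what mixing with pho could look like is shown below; the actual mixing is handled inside the serial_pipeline_r2d3 entry, so this is purely illustrative.

```python
# Purely illustrative sketch of demo-ratio mixing: roughly pho * batch_size samples
# come from the expert (demonstration) buffer, the rest from the agent's own buffer.
import random

def mix_batch(expert_data, agent_data, batch_size, pho):
    n_expert = min(int(round(batch_size * pho)), len(expert_data))
    batch = random.sample(expert_data, n_expert)
    batch += random.sample(agent_data, batch_size - n_expert)
    return batch

# e.g. pho=1/256 as configured above yields roughly one expert sequence per 256 samples
```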
......@@ -99,7 +98,7 @@ create_config = pong_r2d3_create_config
"""export config"""
expert_pong_r2d3_config = dict(
exp_name='debug_pong_r2d3',
exp_name='expert_pong_r2d3_ppoexpert_k0_pho1-256_rbs2e4',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......
from easydict import EasyDict
from ding.entry import serial_pipeline_r2d3
import os
module_path = os.path.dirname(__file__)
collector_env_num = 8
evaluator_env_num = 5
expert_replay_buffer_size=1 #TODO 1000
expert_replay_buffer_size=int(5e3)
"""agent config"""
pong_r2d3_config = dict(
exp_name='debug_pong_r2d3_r2d2expert_k0_pho0_no1td_nosl',
exp_name='debug_pong_r2d3_r2d2expert_k0_pho1-4_rbs2e4_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......@@ -46,14 +45,14 @@ pong_r2d3_config = dict(
# in most environments
value_rescale=True,
update_per_collect=8,
batch_size=64,
batch_size=64, # TODO(pu)
learning_rate=0.0005,
target_update_theta=0.001,
# DQFD related parameters
lambda1=1.0, # n-step return
lambda2=0, # supervised loss
lambda3=1e-5, # L2
lambda_one_step_td=0, # 1-step return
lambda2=1.0, # supervised loss
lambda3=1e-5, # L2 regularization; it is important to set the Adam optimizer's optim_type='adamw'.
lambda_one_step_td=1.0, # 1-step return
margin_function=0.8, # margin function in JE, here we implement this as a constant
per_train_iter_k=0, # TODO(pu)
),
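The per_train_iter_k entry above presumably plays the DQfD-style role of k pre-training gradient steps on demonstration data only (k=0 disables pre-training). A hedged sketch under that assumption, with a hypothetical buffer/policy API:

```python
# Hedged sketch, assuming per_train_iter_k means k pre-training steps on expert data
# only; sample() and learn() are hypothetical stand-ins, not the DI-engine API.
def pretrain_on_demonstrations(policy, expert_buffer, per_train_iter_k, batch_size=64):
    for _ in range(per_train_iter_k):  # k=0 skips pre-training entirely
        batch = expert_buffer.sample(batch_size)
        policy.learn(batch)
```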
......@@ -63,18 +62,18 @@ pong_r2d3_config = dict(
env_num=collector_env_num,
# The hyperparameter pho, the demo ratio, controls the proportion of data coming\
# from expert demonstrations versus from the agent's own experience.
pho=0, #TODO(pu), 0.25,
pho=1/4, # TODO(pu)
),
eval=dict(env_num=evaluator_env_num, ),
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
end=0.05,
decay=100000,
),
replay_buffer=dict(
replay_buffer_size=20000, # TODO(pu) sequence_length 42 10000 obs need 11GB memory, if rbs=20000, at least 140gb
replay_buffer_size=int(2e4), # TODO(pu)
# (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
alpha=0.6,
# (Float type) How much correction is used: 0 means no correction while 1 means full correction
......@@ -96,10 +95,9 @@ pong_r2d3_create_config = dict(
pong_r2d3_create_config = EasyDict(pong_r2d3_create_config)
create_config = pong_r2d3_create_config
"""export config"""
expert_pong_r2d3_config = dict(
# exp_name='debug_pong_r2d3',
exp_name='expert_pong_r2d3_r2d2expert_k0_pho1-4_rbs1e4_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......@@ -119,22 +117,18 @@ expert_pong_r2d3_config = dict(
obs_shape=[4, 84, 84],
action_shape=6,
# encoder_hidden_size_list=[64, 64, 128], # ppo expert policy
encoder_hidden_size_list=[128, 128, 512], # r2d2
# actor_head_hidden_size=128,
# critic_head_hidden_size=128,
encoder_hidden_size_list=[128, 128, 512], # r2d2 expert policy
),
discount_factor=0.997,
burnin_step=20,
burnin_step=2,
nstep=5,
learn=dict(
expert_replay_buffer_size=expert_replay_buffer_size, # TODO(pu)
expert_replay_buffer_size=expert_replay_buffer_size,
),
collect=dict(
# NOTE: it is important not to include the key n_sample here, to make sure self._traj_len=INF
each_iter_n_sample=32,
# Users should add their own path here (path should lead to a well-trained model)
# demonstration_info_path='dizoo/atari/config/serial/pong/demo_path/ppo-off_iteration_16127.pth.tar',
# demonstration_info_path=module_path + '/demo_path/ppo-off_iteration_16127.pth.tar',
# demonstration_info_path=module_path + '/demo_path/ppo-off_ckpt_best.pth.tar',
demonstration_info_path=module_path + '/demo_path/r2d2_iteration_15000.pth.tar',
# Cut trajectories into pieces of length "unroll_len"; it should be set to self._unroll_len_add_burnin_step of r2d2
......@@ -144,7 +138,7 @@ expert_pong_r2d3_config = dict(
eval=dict(env_num=evaluator_env_num, ),
other=dict(
replay_buffer=dict(
replay_buffer_size=expert_replay_buffer_size, # TODO(pu)
replay_buffer_size=expert_replay_buffer_size,
# (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
alpha=0.6,
# (Float type) How much correction is used: 0 means no correction while 1 means full correction
......@@ -161,7 +155,7 @@ expert_pong_r2d3_create_config = dict(
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='base'),
policy=dict(type='r2d2_collect_traj'),
policy=dict(type='r2d2_collect_traj'), # this policy is designed to collect r2d2 expert traj for r2d3
)
expert_pong_r2d3_create_config = EasyDict(expert_pong_r2d3_create_config)
expert_create_config = expert_pong_r2d3_create_config
......
......@@ -8,10 +8,11 @@ module_path = os.path.dirname(__file__)
collector_env_num = 8
evaluator_env_num = 5
expert_replay_buffer_size=int(5e3)
"""agent config"""
lunarlander_r2d3_config = dict(
exp_name='debug_lunarlander_r2d3_k0_pho0',
exp_name='debug_lunarlander_r2d3_ppoexpert_k100_pho1-4_rbs1e5_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......@@ -45,13 +46,14 @@ lunarlander_r2d3_config = dict(
# in most environments
value_rescale=True,
update_per_collect=8,
batch_size=64, #32, # TODO(pu)
batch_size=64, # TODO(pu)
learning_rate=0.0005,
target_update_theta=0.001,
# DQFD related parameters
lambda1=1.0, # n-step return
lambda2=1.0, # supervised loss
lambda3=1e-5, # L2
lambda3=1e-5, # L2 regularization; it is important to set the Adam optimizer's optim_type='adamw'.
lambda_one_step_td=1, # 1-step return
margin_function=0.8, # margin function in JE, here we implement this as a constant
per_train_iter_k=0, # TODO(pu)
),
......@@ -59,19 +61,19 @@ lunarlander_r2d3_config = dict(
# NOTE: it is important not to include the key n_sample here, to make sure self._traj_len=INF
each_iter_n_sample=32,
env_num=collector_env_num,
# The hyperparameter pho, the demo ratio, control the propotion of data coming\
# The hyperparameter pho, the demo ratio, controls the proportion of data coming
# from expert demonstrations versus from the agent's own experience.
pho=0, # TODO(pu) 0.25
pho=1/4., # TODO(pu)
),
eval=dict(env_num=evaluator_env_num, ),
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
end=0.05,
decay=100000,
),
replay_buffer=dict(replay_buffer_size=10000,
replay_buffer=dict(replay_buffer_size=int(1e5),
# (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
alpha=0.6, # priority exponent default=0.6
# (Float type) How much correction is used: 0 means no correction while 1 means full correction
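The eps block above configures an exponential epsilon-greedy schedule. A common form of such a schedule is sketched below; the exact formula inside DI-engine may differ slightly.

```python
# Sketch of an exponential epsilon schedule with the values configured above
# (start=0.95, end=0.05, decay=1e5); assumed form, not necessarily DI-engine's exact one.
import math

def eps_exp(step, start=0.95, end=0.05, decay=1e5):
    return end + (start - end) * math.exp(-step / decay)

# eps_exp(0) ~= 0.95, eps_exp(1e5) ~= 0.38, and the value approaches 0.05 for large steps
```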
......@@ -96,7 +98,7 @@ create_config = lunarlander_r2d3_create_config
"""export config"""
expert_lunarlander_r2d3_config = dict(
exp_name='debug_lunarlander_r2d3',
exp_name='expert_lunarlander_r2d3_ppoexpert_k0_pho1-4_rbs1e5_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......@@ -119,12 +121,11 @@ expert_lunarlander_r2d3_config = dict(
burnin_step=2,
nstep=5,
learn=dict(
expert_replay_buffer_size=1000, # 10000, TODO(pu)
expert_replay_buffer_size=expert_replay_buffer_size,
),
collect=dict(
# n_sample=32, # NOTE it is important that don't include key n_sample here, to make sure self._traj_len=INF
# n_sample=32, NOTE: it is important not to include the key n_sample here, to make sure self._traj_len=INF
# Users should add their own path here (path should lead to a well-trained model)
# demonstration_info_path='dizoo/box2d/lunarlander/config/demo_path/ppo-off_iteration_12948.pth.tar',
demonstration_info_path=module_path + '/demo_path/ppo-off_iteration_12948.pth.tar',
# Cut trajectories into pieces of length "unroll_len"; it should be set to self._unroll_len_add_burnin_step of r2d2
unroll_len=42, # TODO(pu) should equal self._unroll_len_add_burnin_step in the r2d2 policy
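The unroll_len comments above tie the expert trajectory length to the r2d2 policy's internal sequence length. Assuming that length is the learn unroll length plus the burn-in steps, the relation is simply:

```python
# Assumed relation behind the comments above: the r2d2 training sequence length is
# the learn unroll length plus the burn-in steps, so expert trajectories are cut to it.
learn_unroll_len = 40   # assumed r2d2 learner unroll length
burnin_step = 2         # burn-in steps, as in the expert config above
unroll_len_add_burnin_step = learn_unroll_len + burnin_step  # = 42, the unroll_len used here
```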
......@@ -132,7 +133,7 @@ expert_lunarlander_r2d3_config = dict(
),
eval=dict(env_num=evaluator_env_num, ),
other=dict(
replay_buffer=dict(replay_buffer_size=1000, # 10000,8 TODO(pu)
replay_buffer=dict(replay_buffer_size=expert_replay_buffer_size,
# (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
alpha=0.9, # priority exponent default=0.6
# (Float type) How much correction is used: 0 means no correction while 1 means full correction
......@@ -149,7 +150,7 @@ expert_lunarlander_r2d3_create_config = dict(
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='base'),
policy=dict(type='ppo_offpolicy_collect_traj'), # NOTE
policy=dict(type='ppo_offpolicy_collect_traj'), # this policy is designed to collect off-ppo expert traj for r2d3
)
expert_lunarlander_r2d3_create_config = EasyDict(expert_lunarlander_r2d3_create_config)
expert_create_config = expert_lunarlander_r2d3_create_config
......
......@@ -8,10 +8,11 @@ module_path = os.path.dirname(__file__)
collector_env_num = 8
evaluator_env_num = 5
expert_replay_buffer_size=int(5e3)
"""agent config"""
lunarlander_r2d3_config = dict(
exp_name='debug_lunarlander_r2d3_r2d2expert_k0_pho0',
exp_name='debug_lunarlander_r2d3_r2d2expert_k0_pho1-4_rbs1e4_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......@@ -51,7 +52,8 @@ lunarlander_r2d3_config = dict(
# DQFD related parameters
lambda1=1.0, # n-step return
lambda2=1.0, # supervised loss
lambda3=1e-5, # L2
lambda3=1e-5, # L2 regularization; it is important to set the Adam optimizer's optim_type='adamw'.
lambda_one_step_td=1, # 1-step return
margin_function=0.8, # margin function in JE, here we implement this as a constant
per_train_iter_k=0, # TODO(pu)
),
......@@ -61,7 +63,7 @@ lunarlander_r2d3_config = dict(
env_num=collector_env_num,
# The hyperparameter pho, the demo ratio, controls the proportion of data coming\
# from expert demonstrations versus from the agent's own experience.
pho=0, # TODO(pu) 0.25
pho=1/4, # TODO(pu)
),
eval=dict(env_num=evaluator_env_num, ),
other=dict(
......@@ -71,7 +73,7 @@ lunarlander_r2d3_config = dict(
end=0.1,
decay=100000,
),
replay_buffer=dict(replay_buffer_size=10000,
replay_buffer=dict(replay_buffer_size=int(1e4),
# (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
alpha=0.6, # priority exponent default=0.6
# (Float type) How much correction is used: 0 means no correction while 1 means full correction
......@@ -93,10 +95,10 @@ lunarlander_r2d3_create_config = dict(
lunarlander_r2d3_create_config = EasyDict(lunarlander_r2d3_create_config)
create_config = lunarlander_r2d3_create_config
"""export config"""
"""export config"""
expert_lunarlander_r2d3_config = dict(
# exp_name='debug_lunarlander_r2d3',
exp_name='expert_lunarlander_r2d3_r2d2expert_k0_pho1-4_ds5e3',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
......@@ -120,22 +122,21 @@ expert_lunarlander_r2d3_config = dict(
burnin_step=2,
nstep=5,
learn=dict(
expert_replay_buffer_size=1000,
expert_replay_buffer_size=expert_replay_buffer_size,
),
collect=dict(
# n_sample=32, NOTE it is important that don't include key n_sample here, to make sure self._traj_len=INF
# NOTE: it is important not to include the key n_sample here, to make sure self._traj_len=INF
each_iter_n_sample=32,
# Users should add their own path here (path should lead to a well-trained model)
# demonstration_info_path='dizoo/box2d/lunarlander/config/demo_path/ppo-off_iteration_12948.pth.tar',
# demonstration_info_path=module_path + '/demo_path/ppo-off_iteration_12948.pth.tar',
demonstration_info_path=module_path + '/demo_path/r2d2_iteration_13000.pth.tar',
# Cut trajectories into pieces of length "unroll_len"; it should be set to self._unroll_len_add_burnin_step of r2d2
unroll_len=40, # TODO(pu): if using ppo_offpolicy, this key should equal self._unroll_len_add_burnin_step in the r2d2 policy
env_num=collector_env_num,
),
eval=dict(env_num=evaluator_env_num, ),
other=dict(
replay_buffer=dict(replay_buffer_size=1000,
replay_buffer=dict(replay_buffer_size=expert_replay_buffer_size,
# (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
alpha=0.9, # priority exponent default=0.6
# (Float type) How much correction is used: 0 means no correction while 1 means full correction
......@@ -152,7 +153,7 @@ expert_lunarlander_r2d3_create_config = dict(
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='base'),
policy=dict(type='r2d2_collect_traj'), # NOTE
policy=dict(type='r2d2_collect_traj'), # this policy is designed to collect r2d2 expert traj for r2d3
)
expert_lunarlander_r2d3_create_config = EasyDict(expert_lunarlander_r2d3_create_config)
expert_create_config = expert_lunarlander_r2d3_create_config
......