From f087d2c716754c4b1c914ff3a0d7fc7de2df5135 Mon Sep 17 00:00:00 2001
From: Ke Li <118020231@link.cuhk.edu.cn>
Date: Fri, 3 Dec 2021 18:44:33 +0800
Subject: [PATCH] feature(lk): implement multi pass DQN (#131)

* feature(lk): add initial version of MP-PDQN

* fix(lk): fix expand function bug

* refactor(nyz): refactor mpdqn continuous args inputs module

* fix(nyz): fix pdqn scatter index generation

* fix(lk): fix pdqn scatter assignment bug

* feature(lk): polish mpdqn code and style format

* feature(lk): add mpdqn config and test file

* feature(lk): polish mpdqn code and style format

* fix(lk): fix import bug

* polish(lk): add test for mpdqn

* polish(lk): polish code style and format

* polish(lk): rm print debug info

* polish(lk): rm print debug info

* polish(lk): polish code style and format

* polish(lk): add MPDQN in readme.md

Co-authored-by: niuyazhe
---
 README.md                                  | 37 ++++----
 ding/entry/tests/test_serial_entry.py      | 11 +++
 ding/model/template/pdqn.py                | 62 +++++++++++--
 ding/model/template/tests/test_pdqn.py     | 27 +++++-
 ding/utils/type_helper.py                  |  4 +-
 .../config/gym_hybrid_mpdqn_config.py      | 87 +++++++++++++++++++
 .../config/gym_hybrid_pdqn_config.py       | 11 ++-
 7 files changed, 205 insertions(+), 34 deletions(-)
 create mode 100644 dizoo/gym_hybrid/config/gym_hybrid_mpdqn_config.py

diff --git a/README.md b/README.md
index 1331bca..1aa075d 100644
--- a/README.md
+++ b/README.md
@@ -125,24 +125,25 @@ ding -m serial -e cartpole -p dqn -s 0
 | 15 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
 | 16 | [SAC](https://arxiv.org/abs/1801.01290) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/sac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sac.py) | ding -m serial -c pendulum_sac_config.py -s 0 |
 | 17 | [PDQN](https://arxiv.org/pdf/1810.06394.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_pdqn_config.py -s 0 |
-| 18 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 |
-| 19 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 |
-| 20 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
-| 21 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
-| 22 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
-| 23 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
-| 24 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
-| 25 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
-| 26 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/r2d3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d3.py) | python3 -u pong_r2d3_r2d2expert_config.py |
-| 27 | [GCL](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py
-| 28 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
-| 29 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
-| 30 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
-| 31 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
-| 32 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
-| 33 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
-| 34 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
-| 35 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
+| 18 | [MPDQN](https://arxiv.org/pdf/1905.04388.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_mpdqn_config.py -s 0 |
+| 19 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 |
+| 20 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 |
+| 21 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
+| 22 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
+| 23 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
+| 24 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
+| 25 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
+| 26 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
+| 27 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/r2d3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d3.py) | python3 -u pong_r2d3_r2d2expert_config.py |
+| 28 | [GCL](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
+| 29 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
+| 30 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
+| 31 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
+| 32 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
+| 33 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
+| 34 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
+| 35 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
+| 36 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
 
 ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is the only label used in normal DRL algorithms (1-16)
diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py
index 1c04f7b..77ed417 100644
--- a/ding/entry/tests/test_serial_entry.py
+++ b/ding/entry/tests/test_serial_entry.py
@@ -43,6 +43,7 @@ from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config i
 from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config  # noqa
 from dizoo.gym_hybrid.config.gym_hybrid_ddpg_config import gym_hybrid_ddpg_config, gym_hybrid_ddpg_create_config
 from dizoo.gym_hybrid.config.gym_hybrid_pdqn_config import gym_hybrid_pdqn_config, gym_hybrid_pdqn_create_config
+from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config
 
 
 @pytest.mark.unittest
@@ -88,6 +89,16 @@ def test_hybrid_pdqn():
         assert False, "pipeline fail"
 
 
+# @pytest.mark.unittest
+def test_hybrid_mpdqn():
+    config = [deepcopy(gym_hybrid_mpdqn_config), deepcopy(gym_hybrid_mpdqn_create_config)]
+    config[0].policy.learn.update_per_collect = 1
+    try:
+        serial_pipeline(config, seed=0, max_iterations=1)
+    except Exception:
+        assert False, "pipeline fail"
+
+
 @pytest.mark.unittest
 def test_td3():
     config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
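The pdqn.py hunk below derives `action_scatter_index` from `action_mask` with a single `scatter_`: for every continuous action arg, it records which discrete action owns that arg. A minimal standalone sketch of just that mapping, using the 3x4 example mask from the new docstring (the names here are illustrative, not part of the patch):

```python
import torch

# Example mask from the docstring: 3 discrete actions, 4 continuous args.
# Row k marks which args belong to discrete action k.
action_mask = torch.LongTensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]])

nonzero = torch.nonzero(action_mask)  # (discrete, arg) pairs: [[0,0],[1,1],[2,2],[2,3]]
index = torch.zeros(action_mask.shape[1]).long()
# Write the owning discrete action id into each arg's slot.
index.scatter_(dim=0, index=nonzero[:, 1], src=nonzero[:, 0])
print(index)  # tensor([0, 1, 2, 2]): arg i belongs to discrete action index[i]
```

During the multi-pass forward, arg `i` is then written into column `index[i]` of the per-pass action tensor, so each pass only sees the args of its own discrete action.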
diff --git a/ding/model/template/pdqn.py b/ding/model/template/pdqn.py
index bd44559..532d579 100644
--- a/ding/model/template/pdqn.py
+++ b/ding/model/template/pdqn.py
@@ -22,7 +22,9 @@ class PDQN(nn.Module):
             head_hidden_size: Optional[int] = None,
             head_layer_num: int = 1,
             activation: Optional[nn.Module] = nn.ReLU(),
-            norm_type: Optional[str] = None
+            norm_type: Optional[str] = None,
+            multi_pass: Optional[bool] = False,
+            action_mask: Optional[list] = None
     ) -> None:
         r"""
         Overview:
@@ -40,18 +42,39 @@
                 if ``None`` then default set it to ``nn.ReLU()``
             - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                 ``ding.torch_utils.fc_block`` for more details.
+            - multi_pass (:obj:`Optional[bool]`): Whether to use the multi-pass version (MPDQN).
+            - action_mask (:obj:`Optional[list]`): An action mask indicating how continuous action args
+                are associated with each discrete action. For example, if there are 3 discrete actions
+                and 4 continuous action args, where the first discrete action is associated with the
+                first continuous action arg, the second with the second arg, and the third with the
+                remaining 2 args, the action mask is [[1,0,0,0],[0,1,0,0],[0,0,1,1]] with shape 3*4.
         """
         super(PDQN, self).__init__()
+        self.multi_pass = multi_pass
+        if self.multi_pass:
+            assert isinstance(
+                action_mask, list
+            ), 'Please indicate action mask in list form if you set multi_pass to True'
+            self.action_mask = torch.LongTensor(action_mask)
+            nonzero = torch.nonzero(self.action_mask)
+            index = torch.zeros(action_shape.action_args_shape).long()
+            index.scatter_(dim=0, index=nonzero[:, 1], src=nonzero[:, 0])
+            self.action_scatter_index = index  # (action_args_shape, ): arg i belongs to discrete action index[i]
 
-        # squeeze obs input for compatibility: 1, (1, ), [4, 32, 32]
-        obs_shape = squeeze(obs_shape)
         # squeeze action shape input like (3,) to 3
         action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
         action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
+        self.action_args_shape = action_shape.action_args_shape
+        self.action_type_shape = action_shape.action_type_shape
+
         # init head hidden size
         if head_hidden_size is None:
             head_hidden_size = encoder_hidden_size_list[-1]
 
+        # squeeze obs input for compatibility: 1, (1, ), [4, 32, 32]
+        obs_shape = squeeze(obs_shape)
+
         # Obs Encoder Type
         if isinstance(obs_shape, int) or len(obs_shape) == 1:  # FC Encoder
             self.dis_encoder = FCEncoder(
@@ -69,7 +92,7 @@ class PDQN(nn.Module):
             )
         else:
             raise RuntimeError(
-                "not support obs_shape for pre-defined encoder: {}, please customize your own PDQN".format(obs_shape)
+                "Pre-defined encoder does not support obs_shape {}, please customize your own PDQN.".format(obs_shape)
             )
 
         # Continuous Action Head Type
@@ -142,8 +165,33 @@
             - 'action_args': the continuous action args (same as the inputs['action_args']) for later usage
         """
         dis_x = self.encoder[0](inputs['state'])  # size (B, encoded_state_shape)
-        action_args = inputs['action_args']  # size (B, action_args_shape
-        state_action_cat = torch.cat((dis_x, action_args), dim=-1)  # size (B, encoded_state_shape + action_args_shape)
-        logit = self.actor_head[0](state_action_cat)['logit']  # size (B, action_type_shape)
+        action_args = inputs['action_args']  # size (B, action_args_shape)
+
+        if self.multi_pass:  # MPDQN
+            # fill_value=-2 is a mask value, which is not in the normal action range
+            # (B, action_args_shape, K) where K is the action_type_shape
+            mp_action = torch.full(
+                (dis_x.shape[0], self.action_args_shape, self.action_type_shape),
+                fill_value=-2,
+                device=dis_x.device,
+                dtype=dis_x.dtype
+            )
+            index = self.action_scatter_index.view(1, -1, 1).repeat(dis_x.shape[0], 1, 1)
+
+            # index: (B, action_args_shape, 1)  src: (B, action_args_shape, 1)
+            mp_action.scatter_(dim=-1, index=index, src=action_args.unsqueeze(-1))
+            mp_action = mp_action.permute(0, 2, 1)  # (B, K, action_args_shape)
+
+            mp_state = dis_x.unsqueeze(1).repeat(1, self.action_type_shape, 1)  # (B, K, obs_shape)
+            mp_state_action_cat = torch.cat([mp_state, mp_action], dim=-1)
+
+            logit = self.actor_head[0](mp_state_action_cat)['logit']  # (B, K, K)
+
+            logit = torch.diagonal(logit, dim1=-2, dim2=-1)  # (B, K): keep Q_k of action k from the k-th pass
+        else:  # PDQN
+            # size (B, encoded_state_shape + action_args_shape)
+            state_action_cat = torch.cat((dis_x, action_args), dim=-1)
+            logit = self.actor_head[0](state_action_cat)['logit']  # size (B, K) where K is action_type_shape
+
         outputs = {'logit': logit, 'action_args': action_args}
         return outputs
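To make the multi-pass forward above concrete, here is a self-contained sketch of the same trick with a plain `nn.Linear` standing in for the discrete Q head. The sizes match the gym_hybrid config further down; the head and shapes are illustrative assumptions, not the patch's actual modules:

```python
import torch
import torch.nn as nn

B, K, args_dim, obs_dim = 4, 3, 2, 10          # batch, discrete actions, args, encoded obs
scatter_index = torch.LongTensor([0, 1])       # arg 0 -> action 0, arg 1 -> action 1
state = torch.randn(B, obs_dim)
action_args = torch.randn(B, args_dim)

# Build one (state, args) row per discrete action; foreign args stay masked at -2.
mp_action = torch.full((B, args_dim, K), -2.0)
index = scatter_index.view(1, -1, 1).repeat(B, 1, 1)
mp_action.scatter_(dim=-1, index=index, src=action_args.unsqueeze(-1))
mp_action = mp_action.permute(0, 2, 1)          # (B, K, args_dim)
mp_state = state.unsqueeze(1).repeat(1, K, 1)   # (B, K, obs_dim)
x = torch.cat([mp_state, mp_action], dim=-1)    # (B, K, obs_dim + args_dim)

head = nn.Linear(obs_dim + args_dim, K)         # stand-in for the actor head's Q output
q_all = head(x)                                 # (B, K, K): pass k scores every action
q = torch.diagonal(q_all, dim1=-2, dim2=-1)     # (B, K): keep Q_k from pass k
assert q.shape == (B, K)
```

The diagonal is the core of MPDQN: pass k is the only pass whose input contains action k's true args, so only its k-th Q output is kept.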
diff --git a/ding/model/template/tests/test_pdqn.py b/ding/model/template/tests/test_pdqn.py
index e00167a..10dfa25 100644
--- a/ding/model/template/tests/test_pdqn.py
+++ b/ding/model/template/tests/test_pdqn.py
@@ -2,7 +2,6 @@ import pytest
 from easydict import EasyDict
 import torch
 from ding.model.template import PDQN
-from ding.torch_utils import is_differentiable
 
 
 @pytest.mark.unittest
@@ -29,3 +28,29 @@ class TestPQQN:
         else:
             for i, s in enumerate(act_shape):
                 assert dis_outputs['logit'][i].shape == (B, s)
+
+    def test_mpdqn(self):
+        T, B = 3, 4
+        obs_shape = (4, )
+        act_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 5})
+        if isinstance(obs_shape, int):
+            cont_inputs = torch.randn(B, obs_shape)
+        else:
+            cont_inputs = torch.randn(B, *obs_shape)
+        model = PDQN(
+            obs_shape, act_shape, multi_pass=True, action_mask=[[1, 1, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 0, 0, 0]]
+        )
+        cont_outputs = model.forward(cont_inputs, mode='compute_continuous')
+        assert isinstance(cont_outputs, dict)
+        dis_inputs = {'state': cont_inputs, 'action_args': cont_outputs['action_args']}
+
+        dis_outputs = model.forward(dis_inputs, mode='compute_discrete')
+
+        assert isinstance(dis_outputs, dict)
+        if isinstance(act_shape['action_type_shape'], int):
+            assert dis_outputs['logit'].shape == (B, act_shape.action_type_shape)
+        elif len(act_shape['action_type_shape']) == 1:
+            assert dis_outputs['logit'].shape == (B, *act_shape.action_type_shape)
+        else:
+            for i, s in enumerate(act_shape):
+                assert dis_outputs['logit'][i].shape == (B, s)
diff --git a/ding/utils/type_helper.py b/ding/utils/type_helper.py
index fab1ad3..855b453 100644
--- a/ding/utils/type_helper.py
+++ b/ding/utils/type_helper.py
@@ -1,4 +1,4 @@
 from collections import namedtuple
-from typing import List, Dict, TypeVar
+from typing import List, Tuple, TypeVar
 
-SequenceType = TypeVar('SequenceType', List, Dict, namedtuple)
+SequenceType = TypeVar('SequenceType', List, Tuple, namedtuple)
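The new config below enables the multi-pass path through the model kwargs. As a rough sanity check of how those settings drive the model for Moving-v0 (mirroring `test_mpdqn` above, and assuming the obs/action shapes declared in the config):

```python
import torch
from easydict import EasyDict
from ding.model.template import PDQN

# Mirrors the model block of gym_hybrid_mpdqn_config: 10-dim obs,
# 3 discrete actions, 2 continuous args, one arg per discrete action.
act_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 2})
model = PDQN((10, ), act_shape, multi_pass=True, action_mask=[[1, 0], [0, 1], [0, 0]])

obs = torch.randn(4, 10)  # batch of 4 observations
cont = model.forward(obs, mode='compute_continuous')
dis = model.forward({'state': obs, 'action_args': cont['action_args']}, mode='compute_discrete')
assert dis['logit'].shape == (4, 3)  # one Q-value per discrete action
```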
diff --git a/dizoo/gym_hybrid/config/gym_hybrid_mpdqn_config.py b/dizoo/gym_hybrid/config/gym_hybrid_mpdqn_config.py
new file mode 100644
index 0000000..de207c8
--- /dev/null
+++ b/dizoo/gym_hybrid/config/gym_hybrid_mpdqn_config.py
@@ -0,0 +1,87 @@
+from easydict import EasyDict
+from ding.entry import serial_pipeline
+
+gym_hybrid_mpdqn_config = dict(
+    exp_name='gym_hybrid_mpdqn_seed1',
+    env=dict(
+        collector_env_num=8,
+        evaluator_env_num=3,
+        # (bool) Scale output action into legal range [-1, 1].
+        act_scale=True,
+        env_id='Moving-v0',  # ['Sliding-v0', 'Moving-v0']
+        n_evaluator_episode=5,
+        stop_value=1.5,  # 1.85 for hybrid_ddpg
+    ),
+    policy=dict(
+        cuda=True,
+        priority=False,
+        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+        priority_IS_weight=False,
+        discount_factor=0.99,
+        nstep=1,
+        model=dict(
+            obs_shape=10,
+            action_shape=dict(
+                action_type_shape=3,
+                action_args_shape=2,
+            ),
+            multi_pass=True,
+            action_mask=[[1, 0], [0, 1], [0, 0]],
+        ),
+        learn=dict(
+            # (bool) Whether to use multi gpu
+            multi_gpu=False,
+            # How many updates (iterations) to train after one collection by the collector.
+            # Bigger "update_per_collect" means more off-policy.
+            # collect data -> update policy -> collect data -> ...
+            update_per_collect=500,  # 100, 10,
+            batch_size=320,  # 32,
+            learning_rate_dis=3e-4,  # 1e-5, 3e-4, alpha
+            learning_rate_cont=3e-4,  # beta
+            target_theta=0.001,  # 0.005,
+            # cont_update_freq=10,
+            # disc_update_freq=10,
+            update_circle=10,
+        ),
+        # collect_mode config
+        collect=dict(
+            # (int) Only one of [n_sample, n_episode] should be set
+            n_sample=3200,  # 128,
+            # (int) Cut trajectories into pieces with length "unroll_len".
+            unroll_len=1,
+            noise_sigma=0.1,  # 0.05,
+            collector=dict(collect_print_freq=1000, ),
+        ),
+        eval=dict(evaluator=dict(eval_freq=1000, ), ),
+        # other config
+        other=dict(
+            # Epsilon greedy with decay.
+            eps=dict(
+                # (str) Decay type. Support ['exp', 'linear'].
+                type='exp',
+                start=1,  # 0.95,
+                end=0.1,  # 0.05,
+                # (int) Decay length (env step)
+                decay=int(1e5),
+            ),
+            replay_buffer=dict(replay_buffer_size=int(1e6), ),
+        ),
+    )
+)
+
+gym_hybrid_mpdqn_config = EasyDict(gym_hybrid_mpdqn_config)
+main_config = gym_hybrid_mpdqn_config
+
+gym_hybrid_mpdqn_create_config = dict(
+    env=dict(
+        type='gym_hybrid',
+        import_names=['dizoo.gym_hybrid.envs.gym_hybrid_env'],
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(type='pdqn'),
+)
+gym_hybrid_mpdqn_create_config = EasyDict(gym_hybrid_mpdqn_create_config)
+create_config = gym_hybrid_mpdqn_create_config
+
+if __name__ == "__main__":
+    serial_pipeline([main_config, create_config], seed=1)
diff --git a/dizoo/gym_hybrid/config/gym_hybrid_pdqn_config.py b/dizoo/gym_hybrid/config/gym_hybrid_pdqn_config.py
index 6d8ff49..138980e 100644
--- a/dizoo/gym_hybrid/config/gym_hybrid_pdqn_config.py
+++ b/dizoo/gym_hybrid/config/gym_hybrid_pdqn_config.py
@@ -2,15 +2,12 @@ from easydict import EasyDict
 from ding.entry import serial_pipeline
 
 gym_hybrid_pdqn_config = dict(
-    # exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd3e-4_lrc1e-3_upc10_auf100_seed0',
-    # exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd3e-4_lrc3e-4_upc100_uc10v2_seed0',
-    # exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd3e-4_lrc3e-4_upc500_uc10v2_seed0',
-    exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd3e-4_lrc3e-4_upc500_ed1e5_rbs1e6_uc10v2_seed1',
+    exp_name='gym_hybrid_pdqn_seed1',
     # exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd1e-5_lrc1e-3_upc100_seed0',
     env=dict(
         collector_env_num=8,
-        evaluator_env_num=5,
+        evaluator_env_num=3,
         # (bool) Scale output action into legal range [-1, 1].
         act_scale=True,
         env_id='Moving-v0',  # ['Sliding-v0', 'Moving-v0']
@@ -30,6 +27,8 @@ gym_hybrid_pdqn_config = dict(
                 action_type_shape=3,
                 action_args_shape=2,
             ),
+            # multi_pass=True,
+            # action_mask=[[1,0],[0,1],[0,0]],
         ),
         learn=dict(
             # (bool) Whether to use multi gpu
             multi_gpu=False,
@@ -39,7 +38,7 @@ gym_hybrid_pdqn_config = dict(
             # collect data -> update policy-> collect data -> ...
             update_per_collect=500,  # 100, 10,
             batch_size=320,  # 32,
-            learning_rate_dis=3e-4,  # 1e-5,#3e-4, alpha
+            learning_rate_dis=3e-4,  # 1e-5, 3e-4, alpha
             learning_rate_cont=3e-4,  # beta
             target_theta=0.001,  # 0.005,
             # cont_update_freq=10,
-- 
GitLab
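For a quick end-to-end check of this patch, the same invocation as `test_hybrid_mpdqn` above should work (one iteration with a trimmed `update_per_collect`); the full run is the README entry `ding -m serial -c gym_hybrid_mpdqn_config.py -s 0`:

```python
from copy import deepcopy
from ding.entry import serial_pipeline
from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import (
    gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config
)

config = [deepcopy(gym_hybrid_mpdqn_config), deepcopy(gym_hybrid_mpdqn_create_config)]
config[0].policy.learn.update_per_collect = 1  # keep the single-iteration smoke test cheap
serial_pipeline(config, seed=0, max_iterations=1)
```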