Unverified commit f087d2c7, authored by Ke Li and committed by GitHub

feature(lk): implement multi pass DQN (#131)

* feature(lk): add initial version of MP-PDQN

* fix(lk): fix expand function bug

* refactor(nyz): refactor mpdqn continuous args inputs module

* fix(nyz): fix pdqn scatter index generation

* fix(lk): fix pdqn scatter assignment bug

* feature(lk): polish mpdqn code and style format

* feature(lk): add mpdqn config and test file

* feature(lk): polish mpdqn code and style format

* fix(lk): fix import bug

* polish(lk): add test for mpdqn

* polish(lk): polish code style and format

* polish(lk): rm print debug info

* polish(lk): rm print debug info

* polish(lk): polish code style and format

* polish(lk): add MPDQN in readme.md
Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
Parent 5ee17ad1
@@ -125,24 +125,25 @@ ding -m serial -e cartpole -p dqn -s 0
| 15 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
| 16 | [SAC](https://arxiv.org/abs/1801.01290) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/sac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sac.py) | ding -m serial -c pendulum_sac_config.py -s 0 |
| 17 | [PDQN](https://arxiv.org/pdf/1810.06394.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_pdqn_config.py -s 0 |
| 18 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 |
| 19 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 |
| 20 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
| 21 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
| 22 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
| 23 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
| 24 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
| 25 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
| 26 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/r2d3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d3.py) | python3 -u pong_r2d3_r2d2expert_config.py |
| 27 | [GCL](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
| 28 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 29 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 30 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
| 31 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 32 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 33 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 34 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 35 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 18 | [MPDQN](https://arxiv.org/pdf/1905.04388.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_mpdqn_config.py -s 0 |
| 19 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 |
| 20 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 |
| 21 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
| 22 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
| 23 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
| 24 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
| 25 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
| 26 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
| 27 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/r2d3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d3.py) | python3 -u pong_r2d3_r2d2expert_config.py |
| 28 | [GCL](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
| 29 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 30 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 31 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
| 32 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 33 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 34 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 35 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 36 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means a discrete action space, which is the only label used for the standard DRL algorithms above (1-16)
@@ -43,6 +43,7 @@ from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config i
from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config # noqa
from dizoo.gym_hybrid.config.gym_hybrid_ddpg_config import gym_hybrid_ddpg_config, gym_hybrid_ddpg_create_config
from dizoo.gym_hybrid.config.gym_hybrid_pdqn_config import gym_hybrid_pdqn_config, gym_hybrid_pdqn_create_config
from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config
@pytest.mark.unittest
@@ -88,6 +89,16 @@ def test_hybrid_pdqn():
        assert False, "pipeline fail"
# @pytest.mark.unittest
def test_hybrid_mpdqn():
    config = [deepcopy(gym_hybrid_mpdqn_config), deepcopy(gym_hybrid_mpdqn_create_config)]
    config[0].policy.learn.update_per_collect = 1
    try:
        serial_pipeline(config, seed=0, max_iterations=1)
    except Exception:
        assert False, "pipeline fail"
@pytest.mark.unittest
def test_td3():
    config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
@@ -22,7 +22,9 @@ class PDQN(nn.Module):
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
            norm_type: Optional[str] = None,
            multi_pass: Optional[bool] = False,
            action_mask: Optional[list] = None
    ) -> None:
r"""
Overview:
@@ -40,18 +42,39 @@ class PDQN(nn.Module):
                if ``None`` then default set it to ``nn.ReLU()``
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details.
            - multi_pass (:obj:`Optional[bool]`): Whether to use the multi-pass version (MP-PDQN).
            - action_mask (:obj:`Optional[list]`): An action mask indicating which continuous action args
                are associated with each discrete action. For example, if there are 3 discrete actions and
                4 continuous action args, where the first discrete action is associated with the first
                continuous action arg, the second discrete action with the second arg, and the third
                discrete action with the remaining 2 args, the action mask is
                [[1,0,0,0],[0,1,0,0],[0,0,1,1]] with shape 3*4.
        """
        super(PDQN, self).__init__()
        self.multi_pass = multi_pass
        if self.multi_pass:
            assert isinstance(
                action_mask, list
            ), 'Please indicate action mask in list form if you set multi_pass to True'
            self.action_mask = torch.LongTensor(action_mask)
            nonzero = torch.nonzero(self.action_mask)
            index = torch.zeros(action_shape.action_args_shape).long()
            index.scatter_(dim=0, index=nonzero[:, 1], src=nonzero[:, 0])
            self.action_scatter_index = index  # (self.action_args_shape, )
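            # For example, with the docstring's mask [[1,0,0,0],[0,1,0,0],[0,0,1,1]],
            # nonzero[:, 0] is [0, 1, 2, 2] and nonzero[:, 1] is [0, 1, 2, 3], so
            # action_scatter_index becomes [0, 1, 2, 2]: continuous arg j is routed to the
            # input column of the discrete action that owns it during the multi-pass forward.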
        # squeeze obs input for compatibility: 1, (1, ), [4, 32, 32]
        obs_shape = squeeze(obs_shape)
        # squeeze action shape input like (3,) to 3
        action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
        action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
        self.action_args_shape = action_shape.action_args_shape
        self.action_type_shape = action_shape.action_type_shape
        # init head hidden size
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]
        # squeeze obs input for compatibility: 1, (1, ), [4, 32, 32]
        obs_shape = squeeze(obs_shape)
        # Obs Encoder Type
        if isinstance(obs_shape, int) or len(obs_shape) == 1:  # FC Encoder
            self.dis_encoder = FCEncoder(
@@ -69,7 +92,7 @@ class PDQN(nn.Module):
            )
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own PDQN".format(obs_shape)
                "Pre-defined encoder does not support obs_shape {}, please customize your own PDQN.".format(obs_shape)
            )
        # Continuous Action Head Type
@@ -142,8 +165,33 @@ class PDQN(nn.Module):
            - 'action_args': the continuous action args (same as the inputs['action_args']) for later usage
        """
        dis_x = self.encoder[0](inputs['state'])  # size (B, encoded_state_shape)
        action_args = inputs['action_args']  # size (B, action_args_shape
        state_action_cat = torch.cat((dis_x, action_args), dim=-1)  # size (B, encoded_state_shape + action_args_shape)
        logit = self.actor_head[0](state_action_cat)['logit']  # size (B, action_type_shape)
        action_args = inputs['action_args']  # size (B, action_args_shape)
        if self.multi_pass:  # mpdqn
            # fill_value=-2 is a mask value, which is not in the normal action range
            # (B, action_args_shape, K) where K is the action_type_shape
            mp_action = torch.full(
                (dis_x.shape[0], self.action_args_shape, self.action_type_shape),
                fill_value=-2,
                device=dis_x.device,
                dtype=dis_x.dtype
            )
            index = self.action_scatter_index.view(1, -1, 1).repeat(dis_x.shape[0], 1, 1)
            # index: (B, action_args_shape, 1) src: (B, action_args_shape, 1)
            mp_action.scatter_(dim=-1, index=index, src=action_args.unsqueeze(-1))
            mp_action = mp_action.permute(0, 2, 1)  # (B, K, action_args_shape)
            mp_state = dis_x.unsqueeze(1).repeat(1, self.action_type_shape, 1)  # (B, K, obs_shape)
            mp_state_action_cat = torch.cat([mp_state, mp_action], dim=-1)
            logit = self.actor_head[0](mp_state_action_cat)['logit']  # (B, K, K)
            logit = torch.diagonal(logit, dim1=-2, dim2=-1)  # (B, K)
        else:  # pdqn
            # size (B, encoded_state_shape + action_args_shape)
            state_action_cat = torch.cat((dis_x, action_args), dim=-1)
            logit = self.actor_head[0](state_action_cat)['logit']  # size (B, K) where K is action_type_shape
        outputs = {'logit': logit, 'action_args': action_args}
        return outputs
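For readers skimming this hunk, the scatter-plus-diagonal step is the heart of the multi-pass idea: the encoded state is repeated once per discrete action, each copy carries only the continuous args owned by that action (all other args stay at the -2 mask value), and the Q-value for action k is read from the diagonal of the (K, K) head output. Below is a minimal, self-contained sketch with toy shapes and a plain `nn.Linear` standing in for the actor head; both are assumptions for illustration, not DI-engine code.

```python
import torch
import torch.nn as nn

B, K, ARGS, OBS = 2, 3, 4, 5                # batch, discrete actions, continuous args, encoded obs
scatter_index = torch.tensor([0, 1, 2, 2])  # continuous arg j belongs to discrete action scatter_index[j]
action_args = torch.rand(B, ARGS)
state = torch.rand(B, OBS)

# One "pass" per discrete action: every arg not owned by that action stays at the mask value -2.
mp_action = torch.full((B, ARGS, K), fill_value=-2.0)
mp_action.scatter_(-1, scatter_index.view(1, -1, 1).repeat(B, 1, 1), action_args.unsqueeze(-1))
mp_action = mp_action.permute(0, 2, 1)                  # (B, K, ARGS)

mp_state = state.unsqueeze(1).repeat(1, K, 1)           # (B, K, OBS)
head = nn.Linear(OBS + ARGS, K)                         # stand-in for the discrete Q head
logit = head(torch.cat([mp_state, mp_action], dim=-1))  # (B, K, K)
q = torch.diagonal(logit, dim1=-2, dim2=-1)             # (B, K): entry k only uses action k's own args
print(q.shape)  # torch.Size([2, 3])
```

Reading the diagonal is what lets a single batched head call replace K separate forward passes.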
@@ -2,7 +2,6 @@ import pytest
from easydict import EasyDict
import torch
from ding.model.template import PDQN
from ding.torch_utils import is_differentiable
@pytest.mark.unittest
@@ -29,3 +28,29 @@ class TestPQQN:
        else:
            for i, s in enumerate(act_shape):
                assert dis_outputs['logit'][i].shape == (B, s)
    def test_mdqn(self):
        T, B = 3, 4
        obs_shape = (4, )
        act_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 5})
        if isinstance(obs_shape, int):
            cont_inputs = torch.randn(B, obs_shape)
        else:
            cont_inputs = torch.randn(B, *obs_shape)
        model = PDQN(
            obs_shape, act_shape, multi_pass=True, action_mask=[[1, 1, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 0, 0, 0]]
        )
        cont_outputs = model.forward(cont_inputs, mode='compute_continuous')
        assert isinstance(cont_outputs, dict)
        dis_inputs = {'state': cont_inputs, 'action_args': cont_outputs['action_args']}
        dis_outputs = model.forward(dis_inputs, mode='compute_discrete')
        assert isinstance(dis_outputs, dict)
        if isinstance(act_shape['action_type_shape'], int):
            assert dis_outputs['logit'].shape == (B, act_shape.action_type_shape)
        elif len(act_shape['action_type_shape']) == 1:
            assert dis_outputs['logit'].shape == (B, *act_shape.action_type_shape)
        else:
            for i, s in enumerate(act_shape):
                assert dis_outputs['logit'][i].shape == (B, s)
from collections import namedtuple
from typing import List, Dict, TypeVar
from typing import List, Tuple, TypeVar
SequenceType = TypeVar('SequenceType', List, Dict, namedtuple)
SequenceType = TypeVar('SequenceType', List, Tuple, namedtuple)
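The `SequenceType` change above (the `Dict` constraint replaced by `Tuple`) appears to be what lets tuple-style shapes such as `obs_shape=(4, )` in the new MPDQN test satisfy shape annotations. A small illustration, assuming `SequenceType` keeps being used as a shape-like annotation; the `shape_to_str` helper is hypothetical:

```python
from collections import namedtuple
from typing import List, Tuple, TypeVar

SequenceType = TypeVar('SequenceType', List, Tuple, namedtuple)

def shape_to_str(shape: SequenceType) -> str:
    # Both list and tuple shapes are now valid SequenceType values.
    return 'x'.join(str(dim) for dim in shape)

print(shape_to_str((4, 84, 84)))  # 4x84x84
print(shape_to_str([3, 5]))       # 3x5
```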
from easydict import EasyDict
from ding.entry import serial_pipeline
gym_hybrid_mpdqn_config = dict(
    exp_name='gym_hybrid_mpdqn_seed1',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=3,
        # (bool) Scale output action into legal range [-1, 1].
        act_scale=True,
        env_id='Moving-v0',  # ['Sliding-v0', 'Moving-v0']
        n_evaluator_episode=5,
        stop_value=1.5,  # 1.85 for hybrid_ddpg
    ),
    policy=dict(
        cuda=True,
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        discount_factor=0.99,
        nstep=1,
        model=dict(
            obs_shape=10,
            action_shape=dict(
                action_type_shape=3,
                action_args_shape=2,
            ),
            multi_pass=True,
            action_mask=[[1, 0], [0, 1], [0, 0]],
        ),
        learn=dict(
            # (bool) Whether to use multi gpu
            multi_gpu=False,
            # How many updates (iterations) to train after one collection by the collector.
            # A bigger "update_per_collect" means more off-policy updates.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=500,  # 100, 10,
            batch_size=320,  # 32,
            learning_rate_dis=3e-4,  # 1e-5, 3e-4, alpha
            learning_rate_cont=3e-4,  # beta
            target_theta=0.001,  # 0.005,
            # cont_update_freq=10,
            # disc_update_freq=10,
            update_circle=10,
        ),
        # collect_mode config
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set
            n_sample=3200,  # 128,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
            noise_sigma=0.1,  # 0.05,
            collector=dict(collect_print_freq=1000, ),
        ),
        eval=dict(evaluator=dict(eval_freq=1000, ), ),
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                start=1,  # 0.95,
                end=0.1,  # 0.05,
                # (int) Decay length (env step)
                decay=int(1e5),
            ),
            replay_buffer=dict(replay_buffer_size=int(1e6), ),
        ),
    ),
)
gym_hybrid_mpdqn_config = EasyDict(gym_hybrid_mpdqn_config)
main_config = gym_hybrid_mpdqn_config

gym_hybrid_mpdqn_create_config = dict(
    env=dict(
        type='gym_hybrid',
        import_names=['dizoo.gym_hybrid.envs.gym_hybrid_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='pdqn'),
)
gym_hybrid_mpdqn_create_config = EasyDict(gym_hybrid_mpdqn_create_config)
create_config = gym_hybrid_mpdqn_create_config

if __name__ == "__main__":
    serial_pipeline([main_config, create_config], seed=1)
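For reference, `action_mask=[[1, 0], [0, 1], [0, 0]]` above says that discrete action 0 owns continuous arg 0, discrete action 1 owns arg 1, and discrete action 2 takes no continuous arg. A minimal sketch reproducing the scatter-index computation from `PDQN.__init__` for this mask (illustration only, not part of the config file):

```python
import torch

action_mask = torch.LongTensor([[1, 0], [0, 1], [0, 0]])
nonzero = torch.nonzero(action_mask)              # [[0, 0], [1, 1]]
index = torch.zeros(action_mask.shape[1]).long()  # one slot per continuous arg
index.scatter_(dim=0, index=nonzero[:, 1], src=nonzero[:, 0])
print(index)  # tensor([0, 1]): arg 0 -> discrete action 0, arg 1 -> discrete action 1
```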
@@ -2,15 +2,12 @@ from easydict import EasyDict
from ding.entry import serial_pipeline
gym_hybrid_pdqn_config = dict(
    # exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd3e-4_lrc1e-3_upc10_auf100_seed0',
    # exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd3e-4_lrc3e-4_upc100_uc10v2_seed0',
    # exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd3e-4_lrc3e-4_upc500_uc10v2_seed0',
    exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd3e-4_lrc3e-4_upc500_ed1e5_rbs1e6_uc10v2_seed1',
    exp_name='gym_hybrid_pdqn_seed1',
    # exp_name='gym_hybrid_pdqn_dataaction_1encoder_lrd1e-5_lrc1e-3_upc100_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        evaluator_env_num=3,
        # (bool) Scale output action into legal range [-1, 1].
        act_scale=True,
        env_id='Moving-v0',  # ['Sliding-v0', 'Moving-v0']
@@ -30,6 +27,8 @@ gym_hybrid_pdqn_config = dict(
                action_type_shape=3,
                action_args_shape=2,
            ),
            # multi_pass=True,
            # action_mask=[[1,0],[0,1],[0,0]],
        ),
        learn=dict(
            # (bool) Whether to use multi gpu
@@ -39,7 +38,7 @@ gym_hybrid_pdqn_config = dict(
            # collect data -> update policy -> collect data -> ...
            update_per_collect=500,  # 100, 10,
            batch_size=320,  # 32,
            learning_rate_dis=3e-4,  # 1e-5,#3e-4, alpha
            learning_rate_dis=3e-4,  # 1e-5, 3e-4, alpha
            learning_rate_cont=3e-4,  # beta
            target_theta=0.001,  # 0.005,
            # cont_update_freq=10,