提交 0453f9cc 编写于 作者: N niuyazhe

hotfix(nyz): fix c51 head dimension mismatch bug and ppo config mismatch bug

上级 9929dc37
......@@ -105,7 +105,7 @@ class C51DQN(nn.Module):
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [128, 128, 64],
-head_hidden_size: int = 64,
+head_hidden_size: int = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
......@@ -132,6 +132,8 @@ class C51DQN(nn.Module):
super(C51DQN, self).__init__()
# For compatibility: 1, (1, ), [4, 32, 32]
obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+if head_hidden_size is None:
+head_hidden_size = encoder_hidden_size_list[-1]
# FC Encoder
if isinstance(obs_shape, int) or len(obs_shape) == 1:
self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
......
......@@ -20,6 +20,8 @@ pong_a2c_config = dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[64, 64, 128],
+actor_head_hidden_size=128,
+critic_head_hidden_size=128,
),
learn=dict(
update_per_collect=1,
......
......@@ -70,6 +70,4 @@ pong_ppg_create_config = dict(
policy=dict(type='ppg'),
)
create_config = EasyDict(pong_ppg_create_config)
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
# PPG needs to use specific entry, like `cartpole_ppg_main.py`
......@@ -20,6 +20,8 @@ pong_ppo_config = dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[64, 64, 128],
+actor_head_hidden_size=128,
+critic_head_hidden_size=128,
),
learn=dict(
update_per_collect=24,
......@@ -44,7 +46,6 @@ pong_ppo_config = dict(
other=dict(replay_buffer=dict(
replay_buffer_size=100000,
max_use=3,
-min_sample_ratio=1,
), ),
),
)
......@@ -56,7 +57,7 @@ pong_ppo_create_config = dict(
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
-policy=dict(type='ppo'),
+policy=dict(type='ppo_offpolicy'),
)
create_config = EasyDict(pong_ppo_create_config)
......
......@@ -59,7 +59,7 @@ qbert_ppo_create_config = dict(
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
-policy=dict(type='ppo'),
+policy=dict(type='ppo_offpolicy'),
)
create_config = EasyDict(qbert_ppo_create_config)
......
......@@ -58,7 +58,7 @@ space_invaders_ppo_create_config = dict(
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
-policy=dict(type='ppo'),
+policy=dict(type='ppo_offpolicy'),
)
create_config = EasyDict(space_invaders_ppo_create_config)
......
......@@ -41,7 +41,7 @@ lunarlander_ppo_create_config = dict(
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
-policy=dict(type='ppo'),
+policy=dict(type='ppo_offpolicy'),
)
lunarlander_ppo_create_config = EasyDict(lunarlander_ppo_create_config)
create_config = lunarlander_ppo_create_config
......
......@@ -48,7 +48,7 @@ lunarlander_ppo_rnd_create_config = dict(
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
-policy=dict(type='ppo'),
+policy=dict(type='ppo_offpolicy'),
reward_model=dict(type='rnd')
)
lunarlander_ppo_rnd_create_config = EasyDict(lunarlander_ppo_rnd_create_config)
......
......@@ -42,7 +42,7 @@ minigrid_ppo_create_config = dict(
import_names=['dizoo.minigrid.envs.minigrid_env'],
),
env_manager=dict(type='base'),
-policy=dict(type='ppo'),
+policy=dict(type='ppo_offpolicy'),
)
minigrid_ppo_create_config = EasyDict(minigrid_ppo_create_config)
create_config = minigrid_ppo_create_config
......
......@@ -49,7 +49,7 @@ minigrid_ppo_rnd_create_config = dict(
import_names=['dizoo.minigrid.envs.minigrid_env'],
),
env_manager=dict(type='base'),
-policy=dict(type='ppo'),
+policy=dict(type='ppo_offpolicy'),
reward_model=dict(type='rnd'),
)
minigrid_ppo_rnd_create_config = EasyDict(minigrid_ppo_rnd_create_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册