diff --git a/ding/policy/ppg.py b/ding/policy/ppg.py index c2c310310e876832853b312a9832292db993359b..1231ca105895b4908a9854ff768d85d4b2847166 100644 --- a/ding/policy/ppg.py +++ b/ding/policy/ppg.py @@ -398,7 +398,7 @@ class PPGPolicy(Policy): data[-1]['done'], gamma=self._gamma, gae_lambda=self._gae_lambda, - cuda=self._cuda, + cuda=False, ) data = get_train_sample(data, self._unroll_len) for d in data: diff --git a/ding/worker/collector/interaction_serial_evaluator.py b/ding/worker/collector/interaction_serial_evaluator.py index 36d904779274cab209a9f75d623c0a276be8296b..34ef1b1db9894b5f3bdeb72d5c75e75a38b40616 100644 --- a/ding/worker/collector/interaction_serial_evaluator.py +++ b/ding/worker/collector/interaction_serial_evaluator.py @@ -22,7 +22,7 @@ class InteractionSerialEvaluator(ISerialEvaluator): config = dict( # Evaluate every "eval_freq" training iterations. - eval_freq=50, + eval_freq=1000, ) def __init__( diff --git a/dizoo/atari/config/serial/pong/pong_ppg_config.py b/dizoo/atari/config/serial/pong/pong_ppg_config.py index c20796973778400c790ac41a9519f916b1821e8e..fee84f79891dbb927fe17294981c0815c4195976 100644 --- a/dizoo/atari/config/serial/pong/pong_ppg_config.py +++ b/dizoo/atari/config/serial/pong/pong_ppg_config.py @@ -20,6 +20,8 @@ pong_ppg_config = dict( obs_shape=[4, 84, 84], action_shape=6, encoder_hidden_size_list=[64, 64, 128], + critic_head_hidden_size=128, + actor_head_hidden_size=128, ), learn=dict( update_per_collect=24, @@ -46,14 +48,14 @@ pong_ppg_config = dict( eval=dict(evaluator=dict(eval_freq=1000, )), other=dict( replay_buffer=dict( - buffer_name=['policy', 'value'], + multi_buffer=True, policy=dict( replay_buffer_size=100000, max_use=3, ), value=dict( replay_buffer_size=100000, - max_use=3, + max_use=5, ), ), ), diff --git a/dizoo/atari/config/serial/qbert/qbert_ppg_config.py b/dizoo/atari/config/serial/qbert/qbert_ppg_config.py index 1f223bb989211a8ef4fccd733b362ea0f7ee891a..a58945b91b37901abe290e105b56cb55f3ced264 100644 --- a/dizoo/atari/config/serial/qbert/qbert_ppg_config.py +++ b/dizoo/atari/config/serial/qbert/qbert_ppg_config.py @@ -49,14 +49,14 @@ qbert_ppg_config = dict( eval=dict(evaluator=dict(eval_freq=1000, )), other=dict( replay_buffer=dict( - buffer_name=['policy', 'value'], + multi_buffer=True, policy=dict( replay_buffer_size=100000, max_use=3, ), value=dict( replay_buffer_size=100000, - max_use=3, + max_use=10, ), ), ), diff --git a/dizoo/atari/config/serial/space_invaders/c51_f1/config.py b/dizoo/atari/config/serial/spaceinvaders/c51_f1/config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/c51_f1/config.py rename to dizoo/atari/config/serial/spaceinvaders/c51_f1/config.py diff --git a/dizoo/atari/config/serial/space_invaders/dqn_f1/config.py b/dizoo/atari/config/serial/spaceinvaders/dqn_f1/config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/dqn_f1/config.py rename to dizoo/atari/config/serial/spaceinvaders/dqn_f1/config.py diff --git a/dizoo/atari/config/serial/space_invaders/iqn_f1/config.py b/dizoo/atari/config/serial/spaceinvaders/iqn_f1/config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/iqn_f1/config.py rename to dizoo/atari/config/serial/spaceinvaders/iqn_f1/config.py diff --git a/dizoo/atari/config/serial/space_invaders/qrdqn_f1/config.py b/dizoo/atari/config/serial/spaceinvaders/qrdqn_f1/config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/qrdqn_f1/config.py rename to dizoo/atari/config/serial/spaceinvaders/qrdqn_f1/config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_a2c_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_a2c_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_a2c_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_a2c_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_acer_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_acer_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_acer_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_acer_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_c51_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_c51_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_c51_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_c51_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_dqn_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_dqn_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_impala_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_impala_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_impala_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_impala_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_iqn_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_iqn_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_iqn_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_iqn_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_ppg_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ppg_config.py similarity index 86% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_ppg_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ppg_config.py index 9c7fdf5aac121afc1bcabf14ac7b621632ab8b25..dfe6cbb2ec7b9d1c2f13c8fe08387a873206fbc3 100644 --- a/dizoo/atari/config/serial/space_invaders/spaceinvaders_ppg_config.py +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ppg_config.py @@ -2,7 +2,7 @@ from copy import deepcopy from ding.entry import serial_pipeline from easydict import EasyDict -space_invaders_ppg_config = dict( +spaceinvaders_ppg_config = dict( env=dict( collector_env_num=16, evaluator_env_num=8, @@ -49,22 +49,22 @@ space_invaders_ppg_config = dict( eval=dict(evaluator=dict(eval_freq=1000, )), other=dict( replay_buffer=dict( - buffer_name=['policy', 'value'], + multi_buffer=True, policy=dict( replay_buffer_size=100000, max_use=3, ), value=dict( replay_buffer_size=100000, - max_use=3, + max_use=10, ), ), ), ), ) -main_config = EasyDict(space_invaders_ppg_config) +main_config = EasyDict(spaceinvaders_ppg_config) -space_invaders_ppg_create_config = dict( +spaceinvaders_ppg_create_config = dict( env=dict( type='atari', import_names=['dizoo.atari.envs.atari_env'], @@ -72,7 +72,4 @@ space_invaders_ppg_create_config = dict( env_manager=dict(type='subprocess'), policy=dict(type='ppg'), ) -create_config = EasyDict(space_invaders_ppg_create_config) - -if __name__ == '__main__': - serial_pipeline((main_config, create_config), seed=0) +create_config = EasyDict(spaceinvaders_ppg_create_config) diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_ppo_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ppo_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_ppo_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ppo_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_qrdqn_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_qrdqn_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_qrdqn_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_qrdqn_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_rainbow_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_rainbow_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_rainbow_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_rainbow_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_sqil_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sqil_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_sqil_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sqil_config.py diff --git a/dizoo/atari/config/serial/space_invaders/spaceinvaders_sql_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sql_config.py similarity index 100% rename from dizoo/atari/config/serial/space_invaders/spaceinvaders_sql_config.py rename to dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sql_config.py diff --git a/dizoo/atari/entry/__init__.py b/dizoo/atari/entry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dizoo/atari/entry/atari_ppg_main.py b/dizoo/atari/entry/atari_ppg_main.py new file mode 100644 index 0000000000000000000000000000000000000000..c51d2a1ede488332b9fd2983fb76bb4f695ee155 --- /dev/null +++ b/dizoo/atari/entry/atari_ppg_main.py @@ -0,0 +1,78 @@ +import os +import gym +from tensorboardX import SummaryWriter +from easydict import EasyDict +from copy import deepcopy +from functools import partial + +from ding.config import compile_config +from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer +from ding.envs import SyncSubprocessEnvManager +from ding.policy import PPGPolicy +from ding.model import PPG +from ding.utils import set_pkg_seed, deep_merge_dicts +from dizoo.atari.envs import AtariEnv +from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_ppg_config import spaceinvaders_ppg_config + + +def main(cfg, seed=0, max_iterations=int(1e10)): + cfg.exp_name = 'spaceinvaders_ppg_seed0' + cfg = compile_config( + cfg, + SyncSubprocessEnvManager, + PPGPolicy, + BaseLearner, + SampleSerialCollector, + InteractionSerialEvaluator, { + 'policy': AdvancedReplayBuffer, + 'value': AdvancedReplayBuffer + }, + save_cfg=True + ) + collector_env_cfg = AtariEnv.create_collector_env_cfg(cfg.env) + evaluator_env_cfg = AtariEnv.create_evaluator_env_cfg(cfg.env) + collector_env = SyncSubprocessEnvManager(env_fn=[partial(AtariEnv, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager) + evaluator_env = SyncSubprocessEnvManager(env_fn=[partial(AtariEnv, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager) + + collector_env.seed(seed) + evaluator_env.seed(seed, dynamic_seed=False) + set_pkg_seed(seed, use_cuda=cfg.policy.cuda) + + model = PPG(**cfg.policy.model) + policy = PPGPolicy(cfg.policy, model=model) + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + collector = SampleSerialCollector( + cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name + ) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + policy_buffer = AdvancedReplayBuffer( + cfg.policy.other.replay_buffer.policy, tb_logger, exp_name=cfg.exp_name, instance_name='policy_buffer' + ) + value_buffer = AdvancedReplayBuffer( + cfg.policy.other.replay_buffer.value, tb_logger, exp_name=cfg.exp_name, instance_name='value_buffer' + ) + + while True: + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + new_data = collector.collect(train_iter=learner.train_iter) + policy_buffer.push(new_data, cur_collector_envstep=collector.envstep) + value_buffer.push(deepcopy(new_data), cur_collector_envstep=collector.envstep) + for i in range(cfg.policy.learn.update_per_collect): + batch_size = learner.policy.get_attribute('batch_size') + policy_data = policy_buffer.sample(batch_size['policy'], learner.train_iter) + value_data = value_buffer.sample(batch_size['value'], learner.train_iter) + if policy_data is not None and value_data is not None: + train_data = {'policy': policy_data, 'value': value_data} + learner.train(train_data, collector.envstep) + policy_buffer.clear() + value_buffer.clear() + + +if __name__ == "__main__": + main(EasyDict(spaceinvaders_ppg_config))