From 8dae34cb0499d089c00a97a55c321873834d5e54 Mon Sep 17 00:00:00 2001
From: niuyazhe
Date: Thu, 25 Nov 2021 22:51:07 +0800
Subject: [PATCH] refactor(nyz): add basic new sil entry(ci skip)

---
 ding/entry/serial_entry_sil.py                 | 78 ++++++++++---------
 ding/model/wrapper/model_wrappers.py           |  2 +-
 ding/policy/a2c.py                             |  4 +-
 ding/policy/sil.py                             | 12 +--
 .../config/lunarlander_a2c_config.py           | 54 +++++++++++++
 .../config/lunarlander_a2c_sil_config.py       | 39 +++++-----
 6 files changed, 125 insertions(+), 64 deletions(-)
 create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_a2c_config.py

diff --git a/ding/entry/serial_entry_sil.py b/ding/entry/serial_entry_sil.py
index 6796bea..f876fac 100644
--- a/ding/entry/serial_entry_sil.py
+++ b/ding/entry/serial_entry_sil.py
@@ -6,7 +6,7 @@
 from functools import partial
 from tensorboardX import SummaryWriter
 from ding.envs import get_vec_env_setting, create_env_manager
-from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
     create_serial_collector
 from ding.config import read_config, compile_config
 from ding.policy import create_policy, PolicyFactory, create_sil
@@ -58,7 +58,12 @@
 
     # Create worker components: learner, collector, evaluator, replay buffer, commander.
     tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
-    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+    base_learner = BaseLearner(
+        cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='base_learner'
+    )
+    sil_learner = BaseLearner(
+        cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='sil_learner'
+    )
     collector = create_serial_collector(
         cfg.policy.collect.collector,
         env=collector_env,
@@ -71,64 +76,67 @@
     )
     replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
     commander = BaseSerialCommander(
-        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+        cfg.policy.other.commander, base_learner, collector, evaluator, replay_buffer, policy.command_mode
     )
     # ==========
     # Main loop
     # ==========
     # Learner's before_run hook.
-    learner.call_hook('before_run')
-    new_ptr = old_ptr = 0
+    base_learner.call_hook('before_run')
+    # Accumulate plenty of data at the beginning of training.
     if cfg.policy.get('random_collect_size', 0) > 0:
-        action_space = collector_env.env_info().act_space
-        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
-        collector.reset_policy(policy.collect_mode)
+        if cfg.policy.get('transition_with_policy_data', False):
+            collector.reset_policy(policy.collect_mode)
+        else:
+            action_space = collector_env.env_info().act_space
+            random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
+            collector.reset_policy(random_policy)
         collect_kwargs = commander.step()
-        new_data = collector.collect(n_episode=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
+        new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
         replay_buffer.push(new_data, cur_collector_envstep=0)
         collector.reset_policy(policy.collect_mode)
-        new_ptr += len(new_data)
 
     for _ in range(max_iterations):
         collect_kwargs = commander.step()
         # Evaluate policy performance
-        if evaluator.should_eval(learner.train_iter):
-            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+        if evaluator.should_eval(base_learner.train_iter):
+            stop, reward = evaluator.eval(base_learner.save_checkpoint, base_learner.train_iter, collector.envstep)
             if stop:
                 break
         # Collect data by default config n_sample/n_episode
-        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        new_data = collector.collect(train_iter=base_learner.train_iter, policy_kwargs=collect_kwargs)
         replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
-        new_ptr += len(new_data)
         # Learn policy from collected data
-        for i in range(cfg.policy.learn.update_per_collect):
-            # Learner will train ``update_per_collect`` times in one iteration.
-            if cfg.policy.on_policy:
-                train_data_base_policy = replay_buffer.sample(
-                    learner.policy.get_attribute('batch_size'),
-                    learner.train_iter,
-                    sample_range=slice(old_ptr - new_ptr, None)
-                )
-            else:
+        if cfg.policy.on_policy:
+            train_data_base_policy = replay_buffer.sample(
+                base_learner.policy.get_attribute('batch_size'),
+                base_learner.train_iter,
+                sample_range=slice(-len(new_data), None)
+            )
+            base_learner.train(train_data_base_policy)
+        else:
+            for i in range(cfg.policy.learn.update_per_collect):
+                # Learner will train ``update_per_collect`` times in one iteration.
                 train_data_base_policy = replay_buffer.sample(
-                    learner.policy.get_attribute('batch_size'), learner.train_iter
+                    base_learner.policy.get_attribute('batch_size'), base_learner.train_iter
                 )
-            train_data_sil = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
-            if train_data_base_policy is None or train_data_sil is None:
+                base_learner.train(train_data_base_policy)
+                if base_learner.policy.get_attribute('priority'):
+                    replay_buffer.update(base_learner.priority_info)
+        for i in range(cfg.policy.other.sil.update_per_collect):
+            train_data_sil = replay_buffer.sample(
+                cfg.policy.other.sil.n_episode_per_train, sil_learner.train_iter, groupby='episode'
+            )
+            train_data_sil = policy.process_sil_data(train_data_sil)
+            if train_data_sil is None:
                 # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
                 logging.warning(
-                    "Replay buffer's data can only train for {} steps. ".format(i) +
+                    "Replay buffer's data can only train SIL for {} steps. ".format(i) +
                     "You can modify data collect config, e.g. increasing n_sample, n_episode."
                 )
                 break
-            learner.train({'base_policy': train_data_base_policy, 'sil': train_data_sil}, collector.envstep)
-            if learner.policy.get_attribute('priority'):
-                replay_buffer.update(learner.priority_info)
-        if cfg.policy.on_policy:
-            # On-policy algorithm must clear the replay buffer.
-            # replay_buffer.clear()
-            old_ptr = new_ptr
+            sil_learner.train(train_data_sil)
 
     # Learner's after_run hook.
-    learner.call_hook('after_run')
+    base_learner.call_hook('after_run')
     return policy
diff --git a/ding/model/wrapper/model_wrappers.py b/ding/model/wrapper/model_wrappers.py
index b53d527..66353df 100644
--- a/ding/model/wrapper/model_wrappers.py
+++ b/ding/model/wrapper/model_wrappers.py
@@ -320,7 +320,7 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
 
     def forward(self, *args, **kwargs):
         eps = kwargs.pop('eps')
-        alpha = kwargs.pop('alpha')
+        alpha = kwargs.pop('alpha', 1)
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
diff --git a/ding/policy/a2c.py b/ding/policy/a2c.py
index ff54410..fd6c28f 100644
--- a/ding/policy/a2c.py
+++ b/ding/policy/a2c.py
@@ -170,7 +170,7 @@ class A2CPolicy(Policy):
         self._gamma = self._cfg.collect.discount_factor
         self._gae_lambda = self._cfg.collect.gae_lambda
 
-    def _forward_collect(self, data: dict) -> dict:
+    def _forward_collect(self, data: dict, **kwargs) -> dict:
         r"""
         Overview:
             Forward function of collect mode.
@@ -188,7 +188,7 @@ class A2CPolicy(Policy):
         data = to_device(data, self._device)
         self._collect_model.eval()
         with torch.no_grad():
-            output = self._collect_model.forward(data, mode='compute_actor_critic')
+            output = self._collect_model.forward(data, mode='compute_actor_critic', **kwargs)
         if self._cuda:
             output = to_device(output, 'cpu')
         output = default_decollate(output)
diff --git a/ding/policy/sil.py b/ding/policy/sil.py
index 883052a..448d428 100644
--- a/ding/policy/sil.py
+++ b/ding/policy/sil.py
@@ -17,12 +17,14 @@ def create_sil(policy: Policy, cfg):
     return sil_policy
 
 
-class SIL(Policy):
+class SILPolicy(Policy):
     r"""
     Overview:
         Policy class of SIL algorithm.
""" - sil_config = dict( + config = dict( + update_per_collect=10, + n_episode_per_train=4, value_weight=0.5, learning_rate=0.001, betas=(0.9, 0.999), @@ -32,8 +34,8 @@ class SIL(Policy): def __init__(self, policy: Policy, cfg): self.base_policy = policy self._model = policy._model - cfg.policy.other.sil = deep_merge_dicts(self.sil_config, cfg.policy.other.sil) - super(SIL, self).__init__(cfg.policy, model=policy._model, enable_field=policy._enable_field) + cfg.policy.other.sil = deep_merge_dicts(self.config, cfg.policy.other.sil) + super(SILPolicy, self).__init__(cfg.policy, model=policy._model, enable_field=policy._enable_field) def _init_learn(self) -> None: r""" @@ -184,5 +186,5 @@ class SIL(Policy): ] -class SILCommand(SIL, DummyCommandModePolicy): +class SILCommand(SILPolicy, DummyCommandModePolicy): pass diff --git a/dizoo/box2d/lunarlander/config/lunarlander_a2c_config.py b/dizoo/box2d/lunarlander/config/lunarlander_a2c_config.py new file mode 100644 index 0000000..0a41503 --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_a2c_config.py @@ -0,0 +1,54 @@ +from ding.entry.serial_entry_onpolicy import serial_pipeline_onpolicy +from easydict import EasyDict + +lunarlander_a2c_config = dict( + exp_name='lunarlander_a2c_seed0', + env=dict( + collector_env_num=4, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=200, + ), + policy=dict( + cuda=False, + model=dict( + obs_shape=8, + action_shape=4, + encoder_hidden_size_list=[128, 64], + share_encoder=False, + ), + learn=dict( + batch_size=64, + # (bool) Whether to normalize advantage. Default to False. + adv_norm=False, + learning_rate=0.001, + # (float) loss weight of the value network, the weight of policy network is set to 1 + value_weight=0.1, + # (float) loss weight of the entropy regularization, the weight of policy network is set to 1 + entropy_weight=0.00001, + ), + collect=dict( + # (int) collect n_sample data, train model n_iteration times + n_sample=64, + # (float) the trade-off factor lambda to balance 1step td and mc + gae_lambda=0.95, + discount_factor=0.995, + ), + ), +) +lunarlander_a2c_config = EasyDict(lunarlander_a2c_config) +main_config = lunarlander_a2c_config + +lunarlander_a2c_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='a2c'), +) +lunarlander_a2c_create_config = EasyDict(lunarlander_a2c_create_config) +create_config = lunarlander_a2c_create_config + +if __name__ == '__main__': + serial_pipeline_onpolicy((main_config, create_config), seed=0) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_a2c_sil_config.py b/dizoo/box2d/lunarlander/config/lunarlander_a2c_sil_config.py index 0903c6f..01d197b 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_a2c_sil_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_a2c_sil_config.py @@ -2,47 +2,45 @@ from ding.entry.serial_entry_sil import serial_pipeline_sil from easydict import EasyDict lunarlander_a2c_config = dict( - exp_name='lunarlander_a2c', + exp_name='lunarlander_a2c_sil_seed0', env=dict( - collector_env_num=8, + collector_env_num=4, evaluator_env_num=5, n_evaluator_episode=5, - stop_value=195, + stop_value=200, ), policy=dict( - on_policy=True, cuda=False, - # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same) model=dict( obs_shape=8, action_shape=4, - encoder_hidden_size_list=[512, 64], + encoder_hidden_size_list=[128, 64], + share_encoder=False, 
         ),
         learn=dict(
             batch_size=64,
             # (bool) Whether to normalize advantage. Default to False.
-            unroll_len=1,
-            normalize_advantage=False,
+            adv_norm=False,
             learning_rate=0.001,
             # (float) loss weight of the value network, the weight of policy network is set to 1
-            value_weight=0.5,
+            value_weight=0.1,
             # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
-            entropy_weight=1.0,
+            entropy_weight=0.00001,
         ),
         collect=dict(
-            collector=dict(
-                type='episode',
-                get_train_sample=True,
-            ),
             # (int) collect n_sample data, train model n_iteration times
-            n_episode=8,
+            n_sample=64,
             # (float) the trade-off factor lambda to balance 1step td and mc
             gae_lambda=0.95,
+            discount_factor=0.995,
         ),
-        other=dict(sil=dict(
-            value_weight=0.5,
-            learning_rate=0.0001,
-        ), replay_buffer=dict(replay_buffer_size=200000, )),
+        other=dict(
+            replay_buffer=dict(replay_buffer_size=100000, ),
+            sil=dict(
+                update_per_collect=10,
+                n_episode_per_train=8,
+            ),
+        )
     ),
 )
 lunarlander_a2c_config = EasyDict(lunarlander_a2c_config)
@@ -55,11 +53,10 @@ lunarlander_a2c_create_config = dict(
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='a2c'),
+    replay_buffer=dict(type='deque'),
 )
 lunarlander_a2c_create_config = EasyDict(lunarlander_a2c_create_config)
 create_config = lunarlander_a2c_create_config
 
 if __name__ == '__main__':
-    from ding.entry import serial_entry_sil
-
     serial_pipeline_sil((main_config, create_config), seed=0)
-- 
GitLab
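
Usage note (editor's sketch, not part of the patch): with the files above applied, the new SIL pipeline is expected to be launched through the refactored entry together with the updated LunarLander config. The names below all come from this patch (the config import path simply mirrors its file location); anything beyond that is an assumption, not a confirmed interface.

    # Illustrative only: run A2C with the additional self-imitation learner on LunarLander.
    # serial_pipeline_sil builds the base_learner/sil_learner pair shown in
    # ding/entry/serial_entry_sil.py; main_config and create_config come from
    # dizoo/box2d/lunarlander/config/lunarlander_a2c_sil_config.py.
    from ding.entry.serial_entry_sil import serial_pipeline_sil
    from dizoo.box2d.lunarlander.config.lunarlander_a2c_sil_config import main_config, create_config

    if __name__ == '__main__':
        serial_pipeline_sil((main_config, create_config), seed=0)

Running the config module directly (python dizoo/box2d/lunarlander/config/lunarlander_a2c_sil_config.py) should be equivalent, since its __main__ block makes the same call.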