diff --git a/README.md b/README.md
index 4924195e306f5b28e95a5c7ed7cdef9d145f7706..a2bfecb8d9698a24f70e8e008b35a3bbb5363683 100644
--- a/README.md
+++ b/README.md
@@ -120,12 +120,13 @@ ding -m serial -e cartpole -p dqn -s 0
 | 20 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
 | 21 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_reward_model -c cartpole_dqn_config.py -s 0 |
 | 22 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
-| 23 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
-| 24 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
-| 25 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
-| 26 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
-| 27 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
-| 28 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
+| 23 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
+| 24 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
+| 25 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
+| 26 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
+| 27 | 
[PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` | +| 28 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` | +| 29 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is only label in normal DRL algorithms(1-15) diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py index 53fb9909adfd93036266f6c85cde16baad4796b7..33065a3436365c23f20405bb12ab6fcd8270082d 100644 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -4,5 +4,7 @@ from .serial_entry_onpolicy import serial_pipeline_onpolicy from .serial_entry_offline import serial_pipeline_offline from .serial_entry_il import serial_pipeline_il from .serial_entry_reward_model import serial_pipeline_reward_model +from .serial_entry_dqfd import serial_pipeline_dqfd +from .serial_entry_sqil import serial_pipeline_sqil from .parallel_entry import parallel_pipeline from .application_entry import eval, collect_demo_data diff --git a/ding/entry/cli.py b/ding/entry/cli.py index cc9df3e48add3b3979c59d79789ccb5e3fd17121..69ee3b3c06e0301da7ff92f864ade3c4ce8b7c71 100644 --- a/ding/entry/cli.py +++ b/ding/entry/cli.py @@ -52,7 +52,7 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) @click.option( '-m', '--mode', - type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'parallel', 'dist', 'eval']), + type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'parallel', 'dist', 'eval']), help='serial-train or parallel-train or dist-train or eval' ) @click.option('-c', '--config', type=str, help='Path to DRL experiment config') @@ -157,6 +157,12 @@ def cli( config = get_predefined_config(env, policy) expert_config = input("Enter the name of the config you used to generate your expert model: ") serial_pipeline_sqil(config, expert_config, seed, max_iterations=train_iter) + elif mode == 'serial_dqfd': + from .serial_entry_dqfd import serial_pipeline_dqfd + if config is None: + config = get_predefined_config(env, policy) + expert_config = input("Enter the name of the config you used to generate your expert model: ") + serial_pipeline_dqfd(config, expert_config, seed, max_iterations=train_iter) elif mode == 'parallel': from .parallel_entry import parallel_pipeline parallel_pipeline(config, seed, enable_total_log, disable_flask_log) diff --git a/ding/entry/serial_entry_dqfd.py b/ding/entry/serial_entry_dqfd.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce033889750990b20647cb3f1b13061c53f50d7 --- /dev/null +++ b/ding/entry/serial_entry_dqfd.py @@ -0,0 +1,455 @@ +#from ding.policy.base_policy import Policy +from typing import Union, Optional, List, Any, Tuple +import os +import torch +import numpy as np +import logging +from functools import partial +from tensorboardX import SummaryWriter + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \ + 
    create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy, PolicyFactory
+from ding.utils import set_pkg_seed
+from ding.model import DQN
+from copy import deepcopy
+from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import main_config, create_config  # for testing
+#from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config_1, create_config_1 # for testing
+
+
+def serial_pipeline_dqfd(
+        input_cfg: Union[str, Tuple[dict, dict]],
+        expert_cfg: Union[str, Tuple[dict, dict]],
+        seed: int = 0,
+        env_setting: Optional[List[Any]] = None,
+        model: Optional[torch.nn.Module] = None,
+        expert_model: Optional[torch.nn.Module] = None,
+        max_iterations: Optional[int] = int(1e10),
+) -> 'Policy':  # noqa
+    """
+    Overview:
+        Serial pipeline DQfD entry: this pipeline implements DQfD in DI-engine. For now, the supported \
+        environments are CartPole, LunarLander, Pong and SpaceInvaders. The demonstration data are \
+        generated online by a well-trained expert model.
+    Arguments:
+        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+            ``str`` type means config file path. \
+            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+        - expert_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config of the expert policy used to generate \
+            demonstration data, in the same format as ``input_cfg``.
+        - seed (:obj:`int`): Random seed.
+        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+            ``BaseEnv`` subclass, collector env config, and evaluator env config.
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+        - expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. \
+            The default model is DQN(**cfg.policy.model).
+        - max_iterations (:obj:`Optional[int]`): Learner's max iteration. Pipeline will stop \
+            when reaching this iteration.
+    Returns:
+        - policy (:obj:`Policy`): Converged policy.
+ """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + expert_cfg, expert_create_cfg = read_config(expert_cfg) + else: + cfg, create_cfg = input_cfg + expert_cfg, expert_create_cfg = expert_cfg + create_cfg.policy.type = create_cfg.policy.type + '_command' + expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command' + env_fn = None if env_setting is None else env_setting[0] + cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + expert_cfg = compile_config( + expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True + ) + # Create main components: env, policy + if env_setting is None: + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + else: + env_fn, collector_env_cfg, evaluator_env_cfg = env_setting + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + expert_collector_env = create_env_manager( + expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg] + ) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + expert_collector_env.seed(cfg.seed) + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + #expert_model = DQN(**cfg.policy.model) + expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect', 'command']) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + #model = DQN(**cfg.policy.model) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command']) + expert_policy.collect_mode.load_state_dict( + torch.load(cfg.policy.collect.demonstration_info_path, map_location='cpu') + ) + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + collector = create_serial_collector( + cfg.policy.collect.collector, + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name + ) + expert_collector = create_serial_collector( + expert_cfg.policy.collect.collector, + env=expert_collector_env, + policy=expert_policy.collect_mode, + tb_logger=tb_logger, + exp_name=expert_cfg.exp_name + ) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name) + commander = BaseSerialCommander( + cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode + ) + expert_commander = BaseSerialCommander( + expert_cfg.policy.other.commander, learner, expert_collector, evaluator, replay_buffer, + expert_policy.command_mode + ) # we create this to avoid the issue of eps, this is an issue due to the sample collector part. + expert_collect_kwargs = expert_commander.step() + if 'eps' in expert_collect_kwargs: + expert_collect_kwargs['eps'] = -1 + # ========== + # Main loop + # ========== + # Learner's before_run hook. 
+ learner.call_hook('before_run') + if cfg.policy.learn.expert_replay_buffer_size != 0: # for ablation study + dummy_variable = deepcopy(cfg.policy.other.replay_buffer) + dummy_variable['replay_buffer_size'] = cfg.policy.learn.expert_replay_buffer_size + expert_buffer = create_buffer(dummy_variable, tb_logger=tb_logger, exp_name=cfg.exp_name) + expert_data = expert_collector.collect( + n_sample=cfg.policy.learn.expert_replay_buffer_size, policy_kwargs=expert_collect_kwargs + ) + for i in range(len(expert_data)): + expert_data[i]['is_expert'] = 1 # set is_expert flag(expert 1, agent 0) + expert_buffer.push(expert_data, cur_collector_envstep=0) + for _ in range(cfg.policy.learn.per_train_iter_k): + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + # Learn policy from collected data + # Expert_learner will train ``update_per_collect == 1`` times in one iteration. + train_data = expert_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) + learner.train(train_data, collector.envstep) + if learner.policy.get_attribute('priority'): + expert_buffer.update(learner.priority_info) + learner.priority_info = {} + # Accumulate plenty of data at the beginning of training. + if cfg.policy.get('random_collect_size', 0) > 0: + action_space = collector_env.env_info().act_space + random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space) + collector.reset_policy(random_policy) + collect_kwargs = commander.step() + new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs) + for i in range(len(new_data)): + new_data[i]['is_expert'] = 0 # set is_expert flag(expert 1, agent 0) + replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) + collector.reset_policy(policy.collect_mode) + for _ in range(max_iterations): + collect_kwargs = commander.step() + # Evaluate policy performance + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + # Collect data by default config n_sample/n_episode + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + for i in range(len(new_data)): + new_data[i]['is_expert'] = 0 # set is_expert flag(expert 1, agent 0) + replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) + # Learn policy from collected data + for i in range(cfg.policy.learn.update_per_collect): + if cfg.policy.learn.expert_replay_buffer_size != 0: + # Learner will train ``update_per_collect`` times in one iteration. + # The hyperparameter pho, the demo ratio, control the propotion of data coming\ + # from expert demonstrations versus from the agent's own experience. 
+ stats = np.random.choice( + (learner.policy.get_attribute('batch_size')), size=(learner.policy.get_attribute('batch_size')) + ) < ( + learner.policy.get_attribute('batch_size') + ) * cfg.policy.collect.pho # torch.rand((learner.policy.get_attribute('batch_size')))\ + # < cfg.policy.collect.pho + expert_batch_size = stats[stats].shape[0] + demo_batch_size = (learner.policy.get_attribute('batch_size')) - expert_batch_size + train_data = replay_buffer.sample(demo_batch_size, learner.train_iter) + train_data_demonstration = expert_buffer.sample(expert_batch_size, learner.train_iter) + if train_data is None: + # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times + logging.warning( + "Replay buffer's data can only train for {} steps. ".format(i) + + "You can modify data collect config, e.g. increasing n_sample, n_episode." + ) + break + train_data = train_data + train_data_demonstration + learner.train(train_data, collector.envstep) + if learner.policy.get_attribute('priority'): + # When collector, set replay_buffer_idx and replay_unique_id for each data item, priority = 1.\ + # When learner, assign priority for each data item according their loss + learner.priority_info_agent = deepcopy(learner.priority_info) + learner.priority_info_expert = deepcopy(learner.priority_info) + learner.priority_info_agent['priority'] = learner.priority_info['priority'][0:demo_batch_size] + learner.priority_info_agent['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][ + 0:demo_batch_size] + learner.priority_info_agent['replay_unique_id'] = learner.priority_info['replay_unique_id'][ + 0:demo_batch_size] + learner.priority_info_expert['priority'] = learner.priority_info['priority'][demo_batch_size:] + learner.priority_info_expert['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][ + demo_batch_size:] + learner.priority_info_expert['replay_unique_id'] = learner.priority_info['replay_unique_id'][ + demo_batch_size:] + # Expert data and demo data update their priority separately. + replay_buffer.update(learner.priority_info_agent) + expert_buffer.update(learner.priority_info_expert) + else: + # Learner will train ``update_per_collect`` times in one iteration. + train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) + if train_data is None: + # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times + logging.warning( + "Replay buffer's data can only train for {} steps. ".format(i) + + "You can modify data collect config, e.g. increasing n_sample, n_episode." + ) + break + learner.train(train_data, collector.envstep) + if learner.policy.get_attribute('priority'): + replay_buffer.update(learner.priority_info) + if cfg.policy.on_policy: + # On-policy algorithm must clear the replay buffer. + replay_buffer.clear() + + # Learner's after_run hook. 
+    learner.call_hook('after_run')
+    return policy
+
+
+if __name__ == '__main__':
+    main_config_1 = deepcopy(main_config)
+    main_config_2 = deepcopy(create_config)
+    serial_pipeline_dqfd([main_config, create_config], [main_config_1, main_config_2], seed=0)
diff --git a/ding/entry/tests/test_serial_entry_dqfd.py b/ding/entry/tests/test_serial_entry_dqfd.py
new file mode 100644
index 0000000000000000000000000000000000000000..9748f6d23bddd26f017f26d81284247c85a78f07
--- /dev/null
+++ b/ding/entry/tests/test_serial_entry_dqfd.py
@@ -0,0 +1,23 @@
+import pytest
+import torch
+from copy import deepcopy
+from ding.entry import serial_pipeline
+from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
+from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config
+
+
+@pytest.mark.unittest
+def test_dqfd():
+    expert_policy_state_dict_path = './expert_policy.pth'
+    config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
+    expert_policy = serial_pipeline(config, seed=0)
+    torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
+
+    config = [deepcopy(cartpole_dqfd_config), deepcopy(cartpole_dqfd_create_config)]
+    config[0].policy.collect.demonstration_info_path = expert_policy_state_dict_path
+    config[0].policy.learn.update_per_collect = 1
+    try:
+        serial_pipeline_dqfd(config, [cartpole_dqfd_config, cartpole_dqfd_create_config], seed=0, max_iterations=1)
+    except Exception:
+        assert False, "pipeline fail"
diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py
index f1fd77ac2b46284f4590cd5ec672c8bef8429b1c..f5e7d00c127c1d8ed7f011cbf1700f3962f29bc0 100644
--- a/ding/policy/command_mode_policy_instance.py
+++ b/ding/policy/command_mode_policy_instance.py
@@ -24,6 +24,7 @@ from .atoc import ATOCPolicy
 from .acer import ACERPolicy
 from .qtran import QTRANPolicy
 from .sql import SQLPolicy
+from .dqfd import DQFDPolicy
 from .d4pg import D4PGPolicy
 
 from .cql import CQLPolicy, CQLDiscretePolicy
@@ -81,6 +82,11 @@ class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
     pass
 
 
+@POLICY_REGISTRY.register('dqfd_command')
+class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
+    pass
+
+
 @POLICY_REGISTRY.register('c51_command')
 class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
     pass
diff --git a/ding/policy/dqfd.py b/ding/policy/dqfd.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dc41df29be3d46179beac851fc973c59bbeae32
--- /dev/null
+++ b/ding/policy/dqfd.py
@@ -0,0 +1,262 @@
+from typing import List, Dict, Any, Tuple
+from collections import namedtuple
+import copy
+import torch
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample, \
+    dqfd_nstep_td_error, dqfd_nstep_td_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .dqn import DQNPolicy
+from .common_utils import default_preprocess_learn
+from copy import deepcopy
+
+
+@POLICY_REGISTRY.register('dqfd')
+class DQFDPolicy(DQNPolicy):
+    r"""
+    Overview:
+        Policy class of DQFD algorithm, extended by Double DQN/Dueling DQN/PER/multi-step TD.
+ + Config: + == ==================== ======== ============== ======================================== ======================= + ID Symbol Type Default Value Description Other(Shape) + == ==================== ======== ============== ======================================== ======================= + 1 ``type`` str dqn | RL policy register name, refer to | This arg is optional, + | registry ``POLICY_REGISTRY`` | a placeholder + 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff- + | erent from modes + 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy + | or off-policy + 4 ``priority`` bool True | Whether use priority(PER) | Priority sample, + | update priority + 5 | ``priority_IS`` bool True | Whether use Importance Sampling Weight + | ``_weight`` | to correct biased update. If True, + | priority must be True. + 6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse + | ``factor`` [0.95, 0.999] | gamma | reward env + 7 ``nstep`` int 10, | N-step reward discount sum for target + [3, 5] | q_value estimation + 8 | ``learn.update`` int 3 | How many updates(iterations) to train | This args can be vary + | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val + | valid in serial training | means more off-policy + 9 | ``learn.batch_`` int 64 | The number of samples of an iteration + | ``size`` + 10 | ``learn.learning`` float 0.001 | Gradient step length of an iteration. + | ``_rate`` + 11 | ``learn.target_`` int 100 | Frequence of target network update. | Hard(assign) update + | ``update_freq`` + 12 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some + | ``done`` | calculation. | fake termination env + 13 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from + | call of collector. | different envs + 14 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1 + | ``_len`` + == ==================== ======== ============== ======================================== ======================= + """ + + config = dict( + type='dqfd', + cuda=False, + on_policy=False, + priority=True, + # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + priority_IS_weight=True, + discount_factor=0.99, + nstep=10, + learn=dict( + # multiplicative factor for each loss + lambda1=1.0, + lambda2=1.0, + lambda3=1e-5, + # margin function in JE, here we implement this as a constant + margin_function=0.8, + # number of pertraining iterations + per_train_iter_k=10, + # (bool) Whether to use multi gpu + multi_gpu=False, + # How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + update_per_collect=3, + batch_size=64, + learning_rate=0.001, + # ============================================================== + # The following configs are algorithm-specific + # ============================================================== + # (int) Frequence of target network update. + target_update_freq=100, + # (bool) Whether ignore done(usually for max step termination env) + ignore_done=False, + ), + # collect_mode config + collect=dict( + # (int) Only one of [n_sample, n_episode] should be set + # n_sample=8, + # (int) Cut trajectories into pieces with length "unroll_len". 
+ unroll_len=1, + # The hyperparameter pho, the demo ratio, control the propotion of data\ + # coming from expert demonstrations versus from the agent's own experience. + pho=0.5, + ), + eval=dict(), + # other config + other=dict( + # Epsilon greedy with decay. + eps=dict( + # (str) Decay type. Support ['exp', 'linear']. + type='exp', + start=0.95, + end=0.1, + # (int) Decay length(env step) + decay=10000, + ), + replay_buffer=dict(replay_buffer_size=10000, ), + ), + ) + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \ + and target model. + """ + self.lambda1 = self._cfg.learn.lambda1, # n-step return + self.lambda2 = self._cfg.learn.lambda2, # supervised loss + self.lambda3 = self._cfg.learn.lambda3, # L2 + # margin function in JE, here we implement this as a constant + self.margin_function = self._cfg.learn.margin_function + self._priority = self._cfg.priority + self._priority_IS_weight = self._cfg.priority_IS_weight + # Optimizer + self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3[0]) + + self._gamma = self._cfg.discount_factor + self._nstep = self._cfg.nstep + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.learn.target_update_freq} + ) + self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample') + self._learn_model.reset() + self._target_model.reset() + + def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Overview: + Forward computation graph of learn mode(updating policy). + Arguments: + - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \ + np.ndarray or dict/list combinations. + Returns: + - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \ + recorded in text log and tensorboard, values are python scalar or a list of scalars. 
+ ArgumentsKeys: + - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done`` + - optional: ``value_gamma``, ``IS`` + ReturnsKeys: + - necessary: ``cur_lr``, ``total_loss``, ``priority`` + - optional: ``action_distribution`` + """ + data = default_preprocess_learn( + data, + use_priority=self._priority, + use_priority_IS_weight=self._cfg.priority_IS_weight, + ignore_done=self._cfg.learn.ignore_done, + use_nstep=True + ) + data['done_1'] = data['done_1'].float() + if self._cuda: + data = to_device(data, self._device) + # ==================== + # Q-learning forward + # ==================== + self._learn_model.train() + self._target_model.train() + # Current q value (main model) + q_value = self._learn_model.forward(data['obs'])['logit'] + # Target q value + with torch.no_grad(): + target_q_value = self._target_model.forward(data['next_obs'])['logit'] + target_q_value_one_step = self._target_model.forward(data['next_obs_1'])['logit'] + # Max q value action (main model) + target_q_action = self._learn_model.forward(data['next_obs'])['action'] + target_q_action_one_step = self._learn_model.forward(data['next_obs_1'])['action'] + + data_n = dqfd_nstep_td_data( + q_value, + target_q_value, + data['action'], + target_q_action, + data['reward'], + data['done'], + data['done_1'], + data['weight'], + target_q_value_one_step, + target_q_action_one_step, + data['is_expert'] # set is_expert flag(expert 1, agent 0) + ) + value_gamma = data.get('value_gamma') + loss, td_error_per_sample = dqfd_nstep_td_error( + data_n, + self._gamma, + self.lambda1, + self.lambda2, + self.margin_function, + nstep=self._nstep, + value_gamma=value_gamma + ) + + # ==================== + # Q-learning update + # ==================== + self._optimizer.zero_grad() + loss.backward() + if self._cfg.learn.multi_gpu: + self.sync_gradients(self._learn_model) + self._optimizer.step() + + # ============= + # after update + # ============= + self._target_model.update(self._learn_model.state_dict()) + return { + 'cur_lr': self._optimizer.defaults['lr'], + 'total_loss': loss.item(), + 'priority': td_error_per_sample.abs().tolist(), + # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard. + # '[histogram]action_distribution': data['action'], + } + + def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \ + or some continuous transitions(DRQN). + Arguments: + - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \ + format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`dict`): The list of training samples. + + .. note:: + We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \ + And the user can customize the this data processing procecure by overriding this two methods and collector \ + itself. 
+ """ + data_1 = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma)) + data = get_nstep_return_data( + data, self._nstep, gamma=self._gamma + ) # here we want to include one-step next observation + for i in range(len(data)): + data[i]['next_obs_1'] = data_1[i]['next_obs'] # concat the one-step next observation + data[i]['done_1'] = data_1[i]['done'] + return get_train_sample(data, self._unroll_len) diff --git a/ding/rl_utils/__init__.py b/ding/rl_utils/__init__.py index d115f0b814ebeb1e49a8381c5a995e22015597c7..61e8dc5b80952ed7f24b2a7c9b69e94c046bb4ce 100644 --- a/ding/rl_utils/__init__.py +++ b/ding/rl_utils/__init__.py @@ -9,7 +9,7 @@ from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_e q_nstep_td_error_with_rescale, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, \ generalized_lambda_returns, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_error, dist_nstep_td_data, \ nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error,\ - q_nstep_sql_td_error + q_nstep_sql_td_error, dqfd_nstep_td_error, dqfd_nstep_td_data from .vtrace import vtrace_loss, compute_importance_weights from .upgo import upgo_loss from .adder import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample diff --git a/ding/rl_utils/td.py b/ding/rl_utils/td.py index 094d9bf7706370f4e234d35fed9c549e7c0620af..ad9041bc5b4aa50d0586785ad24992a8fc3885b4 100644 --- a/ding/rl_utils/td.py +++ b/ding/rl_utils/td.py @@ -259,6 +259,13 @@ q_nstep_td_data = namedtuple( 'q_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'] ) +dqfd_nstep_td_data = namedtuple( + 'dqfd_nstep_td_data', [ + 'q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_1', 'weight', 'new_n_q_one_step', + 'next_n_action_one_step', 'is_expert' + ] +) + def shape_fn_qntd(args, kwargs): r""" @@ -329,6 +336,107 @@ def q_nstep_td_error( return (td_error_per_sample * weight).mean(), td_error_per_sample +def dqfd_nstep_td_error( + data: namedtuple, + gamma: float, + lambda1: tuple, + lambda2: tuple, + margin_function: float, + nstep: int = 1, + cum_reward: bool = False, + value_gamma: Optional[torch.Tensor] = None, + criterion: torch.nn.modules = nn.MSELoss(reduction='none'), +) -> torch.Tensor: + """ + Overview: + Multistep n step td_error + 1 step td_error + supervised margin loss or dqfd + Arguments: + - data (:obj:`dqfd_nstep_td_data`): the input data, dqfd_nstep_td_data to calculate loss + - gamma (:obj:`float`): discount factor + - cum_reward (:obj:`bool`): whether to use cumulative nstep reward, which is figured out when collecting data + - value_gamma (:obj:`torch.Tensor`): gamma discount value for target q_value + - criterion (:obj:`torch.nn.modules`): loss function criterion + - nstep (:obj:`int`): nstep num, default set to 10 + Returns: + - loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor + - td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\ + + supervised margin loss, 1-dim tensor + Shapes: + - data (:obj:`q_nstep_td_data`): the q_nstep_td_data containing\ + ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'\ + , 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert'] + - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. 
[batch_size, action_dim] + - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` + - action (:obj:`torch.LongTensor`): :math:`(B, )` + - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` + - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) + - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep + - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` + - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)` + - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )` + - is_expert (:obj:`int`) : 0 or 1 + """ + q, next_n_q, action, next_n_action, reward, done, done_1, weight, new_n_q_one_step, next_n_action_one_step,\ + is_expert = data # set is_expert flag(expert 1, agent 0) + assert len(action.shape) == 1, action.shape + if weight is None: + weight = torch.ones_like(action) + + batch_range = torch.arange(action.shape[0]) + q_s_a = q[batch_range, action] + target_q_s_a = next_n_q[batch_range, next_n_action] + target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step] + + # calculate n-step TD-loss + if cum_reward: + if value_gamma is None: + target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done) + else: + target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done) + else: + target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma) + td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) + + # calculate 1-step TD-loss + nstep = 1 + reward = reward[0].unsqueeze(0) + value_gamma = None + if cum_reward: + if value_gamma is None: + target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_1) + else: + target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_1) + else: + target_q_s_a_one_step = nstep_return( + nstep_return_data(reward, target_q_s_a_one_step, done_1), gamma, nstep, value_gamma + ) + td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach()) + + # calculate the supervised loss + device = q_s_a.device + device_cpu = torch.device('cpu') + ''' + max_action = torch.argmax(q, dim=-1) + JE = is_expert * ( + q[batch_range, max_action] + margin_function * + torch.where(action == max_action, torch.ones_like(action), torch.zeros_like(action)).float().to(device) - q_s_a + ) + ''' + l = margin_function * torch.ones_like(q).to(device_cpu) + l.scatter_( + 1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu) + ) # along the first dimension. 
for the index of the action, fill the corresponding position in l with 0 + JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a) + ''' + Js = is_expert * ( + q[batch_range, max_action.type(torch.int64)] + + 0.8 * torch.from_numpy((action == max_action).numpy().astype(int)).float().to(device) - q_s_a + ) + ''' + return ((lambda1[0] * td_error_per_sample + td_error_one_step_per_sample + lambda2[0] * JE) * + weight).mean(), td_error_per_sample + td_error_one_step_per_sample + JE + + def shape_fn_qntd_rescale(args, kwargs): r""" Overview: diff --git a/ding/rl_utils/tests/test_td.py b/ding/rl_utils/tests/test_td.py index 5e5950955b4c35123036f948ed3aa883f16d6573..a937df79e84ec9549915bea1a1d0f07557919a7d 100644 --- a/ding/rl_utils/tests/test_td.py +++ b/ding/rl_utils/tests/test_td.py @@ -1,9 +1,10 @@ import pytest import torch from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data,\ - td_lambda_error, q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data, \ - dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, q_nstep_sql_td_error, \ - iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error + td_lambda_error, q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data,\ + dqfd_nstep_td_data, dqfd_nstep_td_error, dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data,\ + v_nstep_td_error, q_nstep_sql_td_error, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data,\ + qrdqn_nstep_td_error from ding.rl_utils.td import shape_fn_dntd, shape_fn_qntd, shape_fn_td_lambda, shape_fn_qntd_rescale @@ -214,6 +215,35 @@ def test_v_nstep_td(): assert isinstance(v.grad, torch.Tensor) +@pytest.mark.unittest +def test_dqfd_nstep_td(): + batch_size = 4 + action_dim = 3 + next_q = torch.randn(batch_size, action_dim) + done = torch.randn(batch_size) + done_1 = torch.randn(batch_size) + next_q_one_step = torch.randn(batch_size, action_dim) + action = torch.randint(0, action_dim, size=(batch_size, )) + next_action = torch.randint(0, action_dim, size=(batch_size, )) + next_action_one_step = torch.randint(0, action_dim, size=(batch_size, )) + is_expert = torch.ones((batch_size)) + for nstep in range(1, 10): + q = torch.randn(batch_size, action_dim).requires_grad_(True) + reward = torch.rand(nstep, batch_size) + data = dqfd_nstep_td_data( + q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert + ) + loss, td_error_per_sample = dqfd_nstep_td_error( + data, 0.95, lambda1=(1, ), lambda2=(1, ), margin_function=0.8, nstep=nstep + ) + assert td_error_per_sample.shape == (batch_size, ) + assert loss.shape == () + assert q.grad is None + loss.backward() + assert isinstance(q.grad, torch.Tensor) + print(loss) + + @pytest.mark.unittest def test_q_nstep_sql_td(): batch_size = 4 diff --git a/dizoo/atari/config/serial/pong/pong_dqfd_config.py b/dizoo/atari/config/serial/pong/pong_dqfd_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6807746c48a008a50d2cdd0398213fcf6ac28f00 --- /dev/null +++ b/dizoo/atari/config/serial/pong/pong_dqfd_config.py @@ -0,0 +1,63 @@ +from copy import deepcopy +from ding.entry import serial_pipeline +from easydict import EasyDict + +pong_dqfd_config = dict( + exp_name='pong_dqfd', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=20, + env_id='PongNoFrameskip-v4', 
+ frame_stack=4, + manager=dict(shared_memory=True, force_reproducibility=True) + ), + policy=dict( + cuda=True, + priority=True, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[128, 128, 512], + ), + nstep=3, + discount_factor=0.99, + learn=dict( + update_per_collect=10, + batch_size=32, + learning_rate=0.0001, + target_update_freq=500, + lambda1 = 1.0, + lambda2 = 1.0, + lambda3 = 1e-5, + per_train_iter_k = 10, + expert_replay_buffer_size = 10000, # justify the buffer size of the expert buffer + ), + collect=dict(n_sample=96, demonstration_info_path = 'path'), #Users should add their own path here (path should lead to a well-trained model) + other=dict( + eps=dict( + type='exp', + start=1., + end=0.05, + decay=250000, + ), + replay_buffer=dict(replay_buffer_size=100000, ), + ), + ), +) +pong_dqfd_config = EasyDict(pong_dqfd_config) +main_config = pong_dqfd_config +pong_dqfd_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='dqfd'), +) +pong_dqfd_create_config = EasyDict(pong_dqfd_create_config) +create_config = pong_dqfd_create_config + +if __name__ == '__main__': + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/atari/config/serial/qbert/qbert_dqfd_config.py b/dizoo/atari/config/serial/qbert/qbert_dqfd_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b839fea2c72d5925e3b9615c72ea1fe9385ae35d --- /dev/null +++ b/dizoo/atari/config/serial/qbert/qbert_dqfd_config.py @@ -0,0 +1,63 @@ +from ding.entry import serial_pipeline +from easydict import EasyDict + +qbert_dqn_config = dict( + exp_name='qbert_dqfd', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=30000, + env_id='QbertNoFrameskip-v4', + frame_stack=4, + manager=dict(shared_memory=True, force_reproducibility=True) + ), + policy=dict( + cuda=True, + priority=True, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[128, 128, 512], + ), + nstep=3, + discount_factor=0.99, + learn=dict( + update_per_collect=10, + batch_size=32, + learning_rate=0.0001, + target_update_freq=500, + lambda1 = 1.0, + lambda2 = 1.0, + lambda3 = 1e-5, + per_train_iter_k = 10, + expert_replay_buffer_size = 10000, # justify the buffer size of the expert buffer + ), + collect=dict(n_sample=100, demonstration_info_path = 'path'), #Users should add their own path here (path should lead to a well-trained model) + eval=dict(evaluator=dict(eval_freq=4000, )), + other=dict( + eps=dict( + type='exp', + start=1., + end=0.05, + decay=1000000, + ), + replay_buffer=dict(replay_buffer_size=400000, ), + ), + ), +) +qbert_dqn_config = EasyDict(qbert_dqn_config) +main_config = qbert_dqn_config +qbert_dqn_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='dqfd'), +) +qbert_dqn_create_config = EasyDict(qbert_dqn_create_config) +create_config = qbert_dqn_create_config + +if __name__ == '__main__': + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqfd_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqfd_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2f7ff8ee24f4c90cd9bb69e81865c8801be513 --- /dev/null +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqfd_config.py 
@@ -0,0 +1,64 @@ +from copy import deepcopy +from ding.entry import serial_pipeline +from easydict import EasyDict + +space_invaders_dqfd_config = dict( + exp_name='space_invaders_dqfd', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=10000000000, + env_id='SpaceInvadersNoFrameskip-v4', + frame_stack=4, + manager=dict(shared_memory=True, force_reproducibility=True) + ), + policy=dict( + cuda=True, + priority=True, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[128, 128, 512], + ), + nstep=3, + discount_factor=0.99, + learn=dict( + update_per_collect=10, + batch_size=32, + learning_rate=0.0001, + target_update_freq=500, + lambda1 = 1.0, + lambda2 = 1.0, + lambda3 = 1e-5, + per_train_iter_k = 10, + expert_replay_buffer_size = 10000, # justify the buffer size of the expert buffer + ), + collect=dict(n_sample=100, demonstration_info_path = 'path'), #Users should add their own path here (path should lead to a well-trained model) + eval=dict(evaluator=dict(eval_freq=4000, )), + other=dict( + eps=dict( + type='exp', + start=1., + end=0.05, + decay=1000000, + ), + replay_buffer=dict(replay_buffer_size=400000, ), + ), + ), +) +space_invaders_dqfd_config = EasyDict(space_invaders_dqfd_config) +main_config = space_invaders_dqfd_config +space_invaders_dqfd_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='dqfd'), +) +space_invaders_dqfd_create_config = EasyDict(space_invaders_dqfd_create_config) +create_config = space_invaders_dqfd_create_config + +if __name__ == '__main__': + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_dqfd_config.py b/dizoo/box2d/lunarlander/config/lunarlander_dqfd_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6f00c13921dbbb57fc80d70a2f5ac18d0e71e983 --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_dqfd_config.py @@ -0,0 +1,64 @@ +from easydict import EasyDict +from ding.entry import serial_pipeline + +lunarlander_dqfd_config = dict( + exp_name='lunarlander_dqfd', + env=dict( + # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess' + manager=dict(shared_memory=True, force_reproducibility=True), + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=200, + ), + policy=dict( + cuda=True, + model=dict( + obs_shape=8, + action_shape=4, + encoder_hidden_size_list=[512, 64], + dueling=True, + ), + nstep=3, + discount_factor=0.97, + learn=dict(batch_size=64, learning_rate=0.001, + lambda1 = 1.0, + lambda2 = 1.0, + lambda3 = 1e-5, + per_train_iter_k = 10, + expert_replay_buffer_size = 10000, # justify the buffer size of the expert buffer + ), + collect=dict( + n_sample=64, + # Users should add their own path here (path should lead to a well-trained model) + demonstration_info_path='path', + # Cut trajectories into pieces with length "unroll_len". 
+ unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=50, )), # note: this is the times after which you learns to evaluate + other=dict( + eps=dict( + type='exp', + start=0.95, + end=0.1, + decay=10000, + ), + replay_buffer=dict(replay_buffer_size=20000, ), + ), + ), +) +lunarlander_dqfd_config = EasyDict(lunarlander_dqfd_config) +main_config = lunarlander_dqfd_config +lunarlander_dqfd_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='dqfd'), +) +lunarlander_dqfd_create_config = EasyDict(lunarlander_dqfd_create_config) +create_config = lunarlander_dqfd_create_config + +if __name__ == "__main__": + serial_pipeline([main_config, create_config], seed=0) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_dqn_config.py b/dizoo/box2d/lunarlander/config/lunarlander_dqn_config.py index f7f1da069cc97e7b5c87df2d0d67d048201cb75d..03941312850fac27afaf6edc01a3e6b758b92334 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_dqn_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_dqn_config.py @@ -24,7 +24,7 @@ lunarlander_dqn_default_config = dict( # Whether to use dueling head. dueling=True, ), - # Reward's future discount facotr, aka. gamma. + # Reward's future discount factor, aka. gamma. discount_factor=0.99, # How many steps in td error. nstep=nstep, @@ -33,7 +33,7 @@ lunarlander_dqn_default_config = dict( update_per_collect=10, batch_size=64, learning_rate=0.001, - # Frequence of target network update. + # Frequency of target network update. target_update_freq=100, ), # collect_mode config @@ -41,7 +41,7 @@ lunarlander_dqn_default_config = dict( # You can use either "n_sample" or "n_episode" in collector.collect. # Get "n_sample" samples per collect. n_sample=64, - # Cut trajectories into pieces with length "unrol_len". + # Cut trajectories into pieces with length "unroll_len". 
         unroll_len=1,
     ),
     # command_mode config
diff --git a/dizoo/classic_control/cartpole/config/__init__.py b/dizoo/classic_control/cartpole/config/__init__.py
index 3a6ee98b74455b318eaafe9a13232bbc75f90bd4..4296aa4d0bd3157ab6240784e21164fded0cdb1b 100644
--- a/dizoo/classic_control/cartpole/config/__init__.py
+++ b/dizoo/classic_control/cartpole/config/__init__.py
@@ -11,4 +11,7 @@ from .cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config
 from .cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config
 from .cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config
 from .cartpole_acer_config import cartpole_acer_config, cartpole_acer_create_config
+from .cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config
+from .cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
+from .cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config
 # from .cartpole_ppo_default_loader import cartpole_ppo_default_loader
diff --git a/dizoo/classic_control/cartpole/config/cartpole_dqfd_config.py b/dizoo/classic_control/cartpole/config/cartpole_dqfd_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4007c558ef3d95f5004e1a838af11f01cb87cbe2
--- /dev/null
+++ b/dizoo/classic_control/cartpole/config/cartpole_dqfd_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+cartpole_dqfd_config = dict(
+    exp_name='cartpole_dqfd',
+    env=dict(
+        manager=dict(shared_memory=True, force_reproducibility=True),
+        collector_env_num=8,
+        evaluator_env_num=5,
+        n_evaluator_episode=5,
+        stop_value=195,
+    ),
+    policy=dict(
+        cuda=True,
+        priority=True,
+        model=dict(
+            obs_shape=4,
+            action_shape=2,
+            encoder_hidden_size_list=[128, 128, 64],
+            dueling=True,
+        ),
+        nstep=3,
+        discount_factor=0.97,
+        learn=dict(
+            batch_size=64,
+            learning_rate=0.001,
+            lambda1=1,
+            lambda2=3.0,
+            lambda3=0,  # setting lambda3=0 (no L2 loss), expert_replay_buffer_size=0 and lambda1=0 recovers one-step PDD DQN
+            per_train_iter_k=10,
+            expert_replay_buffer_size=10000,  # size of the expert replay buffer
+        ),
+        # Users should add their own path here (path should lead to a well-trained model)
+        collect=dict(n_sample=8, demonstration_info_path='path'),
+        # note: evaluate the policy every ``eval_freq`` training iterations
+        eval=dict(evaluator=dict(eval_freq=50, )),
+        other=dict(
+            eps=dict(
+                type='exp',
+                start=0.95,
+                end=0.1,
+                decay=10000,
+            ),
+            replay_buffer=dict(replay_buffer_size=20000, ),
+        ),
+    ),
+)
+cartpole_dqfd_config = EasyDict(cartpole_dqfd_config)
+main_config = cartpole_dqfd_config
+cartpole_dqfd_create_config = dict(
+    env=dict(
+        type='cartpole',
+        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(type='dqfd'),
+)
+cartpole_dqfd_create_config = EasyDict(cartpole_dqfd_create_config)
+create_config = cartpole_dqfd_create_config
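To see how the pieces added in this patch fit together, the sketch below follows the same steps as `test_serial_entry_dqfd.py` above: train an expert DQN on CartPole, save its collect-mode state dict, point `demonstration_info_path` at that checkpoint, and launch the DQfD pipeline. The checkpoint path is illustrative, and the reuse of the DQfD config pair as the expert config mirrors the unit test; adapt both to your own setup. The same run can also be started from the CLI with `ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0` after editing the config's `demonstration_info_path`.

```python
from copy import deepcopy

import torch

from ding.entry import serial_pipeline
from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config

# 1. Train an expert DQN on CartPole and save its collect-mode state dict
#    (the checkpoint path is illustrative).
expert_ckpt = './expert_policy.pth'
expert_policy = serial_pipeline([deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)], seed=0)
torch.save(expert_policy.collect_mode.state_dict(), expert_ckpt)

# 2. Point the DQfD config at the expert checkpoint and run the DQfD pipeline.
#    The second argument is the config pair used to build the expert collector;
#    the unit test reuses the DQfD config for this purpose.
dqfd_config = [deepcopy(cartpole_dqfd_config), deepcopy(cartpole_dqfd_create_config)]
dqfd_config[0].policy.collect.demonstration_info_path = expert_ckpt
serial_pipeline_dqfd(dqfd_config, [deepcopy(cartpole_dqfd_config), deepcopy(cartpole_dqfd_create_config)], seed=0)
```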