diff --git a/ding/worker/buffer/__init__.py b/ding/worker/buffer/__init__.py
index 244fab5c160448c16e784506abcce2c009129c64..d3cf7c0b13a5fa792832cb0151f9beb80a2cf4fc 100644
--- a/ding/worker/buffer/__init__.py
+++ b/ding/worker/buffer/__init__.py
@@ -1,2 +1,3 @@
 from .buffer import Buffer, apply_middleware, BufferedData
 from .deque_buffer import DequeBuffer
+from .deque_buffer_wrapper import DequeBufferWrapper
diff --git a/ding/worker/buffer/deque_buffer_wrapper.py b/ding/worker/buffer/deque_buffer_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b12bced33123133eb0257fde4c2174019e757c51
--- /dev/null
+++ b/ding/worker/buffer/deque_buffer_wrapper.py
@@ -0,0 +1,34 @@
+from typing import Optional
+import copy
+from easydict import EasyDict
+from ding.worker.buffer import DequeBuffer
+from ding.utils import BUFFER_REGISTRY
+
+
+@BUFFER_REGISTRY.register('deque')
+class DequeBufferWrapper(object):
+
+    @classmethod
+    def default_config(cls: type) -> EasyDict:
+        cfg = EasyDict(copy.deepcopy(cls.config))
+        cfg.cfg_type = cls.__name__ + 'Dict'
+        return cfg
+
+    config = dict(replay_buffer_size=10000, )
+
+    def __init__(
+            self, cfg: EasyDict, tb_logger: Optional[object] = None, exp_name: str = 'default_experiment'
+    ) -> None:
+        self.buffer = DequeBuffer(size=cfg.replay_buffer_size)
+
+    def sample(self, size: int, train_iter: int):
+        output = self.buffer.sample(size=size, ignore_insufficient=True)
+        if len(output) > 0:
+            return [o.data for o in output]
+        else:
+            return None
+
+    def push(self, data, cur_collector_envstep: int = -1) -> None:
+        # meta = {'train_iter_data_collected': }
+        for d in data:
+            self.buffer.push(d)
diff --git a/dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py b/dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..561ea89d20c6e76c42ae814e033de88730d492b9
--- /dev/null
+++ b/dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py
@@ -0,0 +1,80 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, DequeBufferWrapper
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config
+
+
+# Wrap the gym CartPole env into a DI-engine style env
+def wrapped_cartpole_env():
+    return DingEnvWrapper(gym.make('CartPole-v0'))
+
+
+def main(cfg, seed=0):
+    cfg = compile_config(
+        cfg,
+        BaseEnvManager,
+        DQNPolicy,
+        BaseLearner,
+        SampleSerialCollector,
+        InteractionSerialEvaluator,
+        DequeBufferWrapper,
+        save_cfg=True
+    )
+    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+    collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
+    evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+
+    # Set random seed for all packages and instances
+    collector_env.seed(seed)
+    evaluator_env.seed(seed, dynamic_seed=False)
+    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+    # Set up RL Policy
+    model = DQN(**cfg.policy.model)
+    policy = DQNPolicy(cfg.policy, model=model)
+
+    # Set up collection, training and evaluation utilities
+    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+    collector = SampleSerialCollector(
+        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+    )
+    evaluator = InteractionSerialEvaluator(
+        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+    )
+    replay_buffer = DequeBufferWrapper(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+    # Set up other modules, e.g. epsilon greedy exploration
+    eps_cfg = cfg.policy.other.eps
+    epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+    # Training & Evaluation loop
+    while True:
+        # Evaluate at the beginning and then at a fixed frequency
+        if evaluator.should_eval(learner.train_iter):
+            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            if stop:
+                break
+        # Update other modules
+        eps = epsilon_greedy(collector.envstep)
+        # Sample data from environments
+        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+        # Training
+        for i in range(cfg.policy.learn.update_per_collect):
+            train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+            if train_data is None:
+                break
+            learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+    main(cartpole_dqn_config)
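
A minimal standalone sketch (not part of the diff above) of how the new DequeBufferWrapper push/sample API can be exercised on its own; the transition dict fields below are illustrative placeholders, not keys required by the buffer:

    from easydict import EasyDict
    from ding.worker.buffer import DequeBufferWrapper

    # Build the wrapper from a plain config; only replay_buffer_size is read in __init__
    buffer = DequeBufferWrapper(EasyDict(replay_buffer_size=100))
    # push expects an iterable of collected transitions
    buffer.push([{'obs': 0, 'action': 1, 'reward': 1.0}], cur_collector_envstep=1)
    # sample returns a list of raw data items, or None when the buffer cannot satisfy the request
    batch = buffer.sample(1, train_iter=0)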