From 7a4de160db4a2423e762ad693c1364b71e7dc5a9 Mon Sep 17 00:00:00 2001
From: niuyazhe
Date: Tue, 19 Oct 2021 20:25:53 +0800
Subject: [PATCH] feature(nyz): add naive offpolicy demo(ci skip)

---
 ding/entry/offpolicy_demo.py | 112 +++++++++++++++++++++++++++++++++++
 1 file changed, 112 insertions(+)
 create mode 100644 ding/entry/offpolicy_demo.py

diff --git a/ding/entry/offpolicy_demo.py b/ding/entry/offpolicy_demo.py
new file mode 100644
index 0000000..36bcf5b
--- /dev/null
+++ b/ding/entry/offpolicy_demo.py
@@ -0,0 +1,112 @@
+from easydict import EasyDict
+import time
+import torch
+import treetensor.torch as ttorch
+
+import ding
+from ding import compile_config, set_pkg_seed, Policy, Pool, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, MemoryBuffer, VariableManager, Pipeline
+from .model import CustomizedModel
+from .env import CustomizedEnv
+
+
+class DQNPolicy(Policy):
+
+    def _init_learn(self):
+        pass
+
+    def _init_collect(self):
+        pass
+
+    def _init_eval(self):
+        pass
+
+    def forward_learn(self, learner):
+        data = learner.train_data
+        log_info = {}  # TODO
+        return log_info
+
+    def forward_collect(self, collector):
+        # TODO pass train_iter
+        obs = collector.env.obs
+        obs = ttorch.tensor(obs)
+        if self._cuda:
+            obs = obs.cuda()
+        self._collect_model.eval()
+        with torch.no_grad():
+            eps = self._eps_fn(collector.env_step)
+            output = self._collect_model.forward(obs, eps=eps)
+        if self._cuda:
+            output = output.cpu()
+        action = output.action
+        timestep = collector.env.step(action)
+        # TODO s_t+1 problem
+        train_timestep = self._process_train_timestep(obs, output, timestep)  # TODO async case
+        self._reset_collect(timestep.env_id, timestep.done)
+        log_info = {}  # TODO
+        return train_timestep, timestep, log_info
+
+    def forward_eval(self, evaluator):
+        obs = evaluator.env.obs
+        obs = ttorch.tensor(obs)
+        if self._cuda:
+            obs = obs.cuda()
+        self._eval_model.eval()
+        with torch.no_grad():
+            output = self._eval_model.forward(obs)
+        if self._cuda:
+            output = output.cpu()
+        action = output.action
+        timestep = evaluator.env.step(action)
+        self._reset_eval(timestep.env_id, timestep.done)
+        log_info = {}  # TODO
+        return timestep, log_info
+
+
+class OffPolicyTrainPipeline(Pipeline):
+
+    def __init__(self, cfg: EasyDict, env, policy):
+        super(OffPolicyTrainPipeline, self).__init__()
+        self.env = env
+        self.policy = policy
+        self.collector = SampleSerialCollector(cfg.collector, self.env, forward_fn=self.policy.forward_collect)
+        self.learner = BaseLearner(cfg.learner, self.env, forward_fn=self.policy.forward_learn)
+        self.buffer = MemoryBuffer(cfg.buffer, strategy_fn=self.policy.buffer_strategy)
+
+    def run(self):
+        while not (self.learner.stop() and self.collector.stop()):
+            self.collector.collect(self.buffer)
+            data = self.buffer.sample()
+            self.learner.learn(data)
+
+
+class EvalPipeline(Pipeline):
+
+    def __init__(self, cfg: EasyDict, env, policy):
+        super(EvalPipeline, self).__init__()
+        self.env = env
+        self.policy = policy
+        self.evaluator = InteractionSerialEvaluator(cfg.evaluator, self.env, forward_fn=self.policy.forward_eval)
+
+    def run(self):
+        while True:
+            if self.evaluator.should_eval():
+                stop_flag, reward = self.evaluator.eval()
+                if stop_flag:
+                    break
+            # TODO trigger save ckpt
+            # TODO shutdown the whole program
+            time.sleep(1)
+
+
+def main(config_path: str, seed: int):
+    cfg = compile_config(config_path, seed)
+    set_pkg_seed(seed)
+    model = CustomizedModel(cfg.policy.model)
+    policy = DQNPolicy(cfg.policy, model=model)  # pass the customized model into the policy
+    train_env = CustomizedEnv(cfg.env, seed)
+    eval_env = CustomizedEnv(cfg.env, seed)
+    # train_env = CustomizedEnv(cfg.env, seed).clone(8)
+
+    train_pipeline = OffPolicyTrainPipeline(cfg, train_env, policy)
+    eval_pipeline = EvalPipeline(cfg, eval_env, policy)
+    ding.run([train_pipeline, eval_pipeline])
--
GitLab
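
A minimal sketch of the config layout that main() above appears to assume, based only on the fields accessed in this patch (cfg.policy.model, cfg.env, cfg.collector, cfg.learner, cfg.buffer, cfg.evaluator). Every concrete key, value, and file name below is an illustrative placeholder, not a confirmed API.

    from easydict import EasyDict

    # Hypothetical config: only the top-level sub-config names come from the
    # attribute accesses in offpolicy_demo.py; the inner keys are placeholders.
    cartpole_dqn_demo_config = EasyDict(
        dict(
            env=dict(env_id='CartPole-v0'),
            policy=dict(
                cuda=False,
                model=dict(obs_shape=4, action_shape=2),
            ),
            collector=dict(n_sample=96),
            learner=dict(max_train_iter=int(1e4)),
            buffer=dict(replay_buffer_size=int(1e4)),
            evaluator=dict(eval_freq=1000, stop_value=195),
        )
    )

    # Usage sketch: save the dict above to a config file (e.g.
    # cartpole_dqn_demo_config.py, hypothetical path) and call the entry point:
    #   from ding.entry.offpolicy_demo import main
    #   main('cartpole_dqn_demo_config.py', seed=0)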