Commit 7a4de160 authored by niuyazhe

feature(nyz): add naive offpolicy demo(ci skip)

Parent ad394fc5
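# Naive off-policy DQN demo: an off-policy training pipeline (collector +
# replay buffer + learner) and an evaluation pipeline share one policy and
# run side by side under ding.run.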
import time

import torch
import treetensor.torch as ttorch
from easydict import EasyDict

import ding
from ding import compile_config, set_pkg_seed, Policy, Pool, BaseLearner, SampleSerialCollector, \
    InteractionSerialEvaluator, MemoryBuffer, VariableManager, Pipeline

from .model import CustomizedModel
from .env import CustomizedEnv


class DQNPolicy(Policy):

    def _init_learn(self):
        pass

    def _init_collect(self):
        pass

    def _init_eval(self):
        pass

    def forward_learn(self, learner):
        data = learner.train_data
        # The TD update itself is left as TODO in this commit; a standard DQN
        # target would be r + gamma * (1 - done) * max_a Q_target(s', a).
        log_info = {}  # TODO
        return log_info

    def forward_collect(self, collector):
        # TODO pass train_iter
        obs = collector.env.obs
        obs = ttorch.tensor(obs)
        if self._cuda:
            obs = obs.cuda()
        self._collect_model.eval()
        with torch.no_grad():
            # Epsilon-greedy exploration, annealed by the current env step count.
            eps = self._eps_fn(collector.env_step)
            output = self._collect_model.forward(obs, eps=eps)
        if self._cuda:
            output = output.cpu()
        action = output.action
        timestep = collector.env.step(action)
        # TODO s_t+1 problem
        train_timestep = self._process_train_timestep(obs, output, timestep)  # TODO async case
        self._reset_collect(timestep.env_id, timestep.done)
        log_info = {}  # TODO
        return train_timestep, timestep, log_info
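
    # A minimal sketch of the transition packing referenced above. The field
    # layout is an assumption (a standard one-step DQN transition); the real
    # implementation, including the s_t+1 handling, is left as TODO in this
    # commit.
    def _process_train_timestep(self, obs, output, timestep):
        return {
            'obs': obs,
            'action': output.action,
            'next_obs': timestep.obs,
            'reward': timestep.reward,
            'done': timestep.done,
        }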

    def forward_eval(self, evaluator):
        obs = evaluator.env.obs
        obs = ttorch.tensor(obs)
        if self._cuda:
            obs = obs.cuda()
        self._eval_model.eval()
        with torch.no_grad():
            # Greedy action selection: no exploration noise at evaluation time.
            output = self._eval_model.forward(obs)
        if self._cuda:
            output = output.cpu()
        action = output.action
        timestep = evaluator.env.step(action)
        self._reset_eval(timestep.env_id, timestep.done)
        log_info = {}  # TODO
        return timestep, log_info
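

# A minimal sketch of the exploration schedule consumed as ``self._eps_fn`` in
# ``forward_collect``: linear decay from ``start`` to ``end`` over ``decay``
# environment steps. The function name and its parameters are assumptions for
# illustration, not part of this commit.
def linear_epsilon_fn(start: float = 1.0, end: float = 0.05, decay: int = 100000):

    def eps_fn(env_step: int) -> float:
        return max(end, start - (start - end) * env_step / decay)

    return eps_fn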


class OffPolicyTrainPipeline(Pipeline):

    def __init__(self, cfg: EasyDict, env, policy):
        super(OffPolicyTrainPipeline, self).__init__()
        self.env = env
        self.policy = policy
        self.collector = SampleSerialCollector(cfg.collector, self.env, forward_fn=self.policy.forward_collect)
        self.learner = BaseLearner(cfg.learner, self.env, forward_fn=self.policy.forward_learn)
        self.buffer = MemoryBuffer(cfg.buffer, strategy_fn=self.policy.buffer_strategy)

    def run(self):
        # Standard off-policy loop: push collected transitions into the replay
        # buffer, then sample a batch and take one learner step, until both
        # the learner and the collector report their stop condition.
        while not (self.learner.stop() and self.collector.stop()):
            self.collector.collect(self.buffer)
            data = self.buffer.sample()
            self.learner.learn(data)


class EvalPipeline(Pipeline):

    def __init__(self, cfg: EasyDict, env, policy):
        super(EvalPipeline, self).__init__()
        self.env = env
        self.policy = policy
        self.evaluator = InteractionSerialEvaluator(cfg.evaluator, self.env, forward_fn=self.policy.forward_eval)

    def run(self):
        # Poll until the evaluator decides it is time to evaluate; exit once
        # the stop criterion (e.g. a reward threshold) is reached.
        while True:
            if self.evaluator.should_eval():
                stop_flag, reward = self.evaluator.eval()
                if stop_flag:
                    break
                # TODO trigger save ckpt
            time.sleep(1)
        # TODO shutdown the whole program


def main(config_path: str, seed: int):
    cfg = compile_config(config_path, seed)
    set_pkg_seed(seed)
    model = CustomizedModel(cfg.policy.model)
    # Pass the customized model into the policy (the original demo constructed
    # ``model`` but never used it).
    policy = DQNPolicy(cfg.policy, model=model)
    train_env = CustomizedEnv(cfg.env, seed)
    eval_env = CustomizedEnv(cfg.env, seed)
    # train_env = CustomizedEnv(cfg.env, seed).clone(8)
    train_pipeline = OffPolicyTrainPipeline(cfg, train_env, policy)
    eval_pipeline = EvalPipeline(cfg, eval_env, policy)
    ding.run([train_pipeline, eval_pipeline])
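

# Hypothetical entry point for running the demo; the config path and seed
# below are placeholders for illustration, not values shipped with this commit.
if __name__ == "__main__":
    main(config_path='./dqn_demo_config.py', seed=0)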