Unverified    Commit 11cc97e8    Authored by: S Swain    Committed by: GitHub

feature(nyz): add supervised learning image classification training demo (#27)

* feature(nyz): add resnet for cv sl task

* feature(nyz): add imagenet classification dataset and adapt compile config for sl

* feature(nyz): add naive image training entry demo

* style(nyz): polish image cls train log

* polish(nyz): polish multi gpu training setting

* feature(nyz): add nn training bp and update async execution

* feature(nyz): add distributed sampler for different dist backend

* fix(nyz): fix compile config collector and buffer compatibility problem

* style(nyz): correct yapf format

* fix(nyz): fix env manager compile config compatibility bug

* refactor(nyz): abstract ISerialEvaluator and rename serial evaluation implementation

* refactor(nyz): refactor collector name

* feature(nyz): add metric evaluator and image cls acc metric eval demo

* fix(nyz): fix cuda and multi gpu bug in image cls demo
Parent 5e52c1a0
......@@ -11,10 +11,11 @@ import yaml
from easydict import EasyDict
from ding.utils import deep_merge_dicts
from ding.envs import get_env_cls, get_env_manager_cls
from ding.envs import get_env_cls, get_env_manager_cls, BaseEnvManager
from ding.policy import get_policy_cls
from ding.worker import BaseLearner, BaseSerialEvaluator, BaseSerialCommander, Coordinator, AdvancedReplayBuffer, \
get_parallel_commander_cls, get_parallel_collector_cls, get_buffer_cls, get_serial_collector_cls
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, Coordinator, \
AdvancedReplayBuffer, get_parallel_commander_cls, get_parallel_collector_cls, get_buffer_cls, \
get_serial_collector_cls, MetricSerialEvaluator, BattleInteractionSerialEvaluator
from ding.reward_model import get_reward_model_cls
from .utils import parallel_transform, parallel_transform_slurm, parallel_transform_k8s, save_config_formatted
......@@ -308,7 +309,7 @@ def compile_config(
policy: type = None,
learner: type = BaseLearner,
collector: type = None,
evaluator: type = BaseSerialEvaluator,
evaluator: type = InteractionSerialEvaluator,
buffer: type = AdvancedReplayBuffer,
env: type = None,
reward_model: type = None,
......@@ -328,7 +329,7 @@ def compile_config(
- policy (:obj:`type`): Policy class which is to be used in the following pipeline
- learner (:obj:`type`): Input learner class, defaults to BaseLearner
- collector (:obj:`type`): Input collector class, defaults to BaseSerialCollector
- evaluator (:obj:`type`): Input evaluator class, defaults to BaseSerialEvaluator
- evaluator (:obj:`type`): Input evaluator class, defaults to InteractionSerialEvaluator
- buffer (:obj:`type`): Input buffer class, defaults to BufferManager
- env (:obj:`type`): Environment class which is to be used in the following pipeline
- reward_model (:obj:`type`): Reward model class which aims to offer various and valuable reward
......@@ -361,6 +362,7 @@ def compile_config(
env_config.update(create_cfg.env)
env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
env_config.manager.update(create_cfg.env_manager)
print(env_config)
policy_config = policy.default_config()
policy_config = deep_merge_dicts(policy_config_template, policy_config)
policy_config.update(create_cfg.policy)
......@@ -379,6 +381,8 @@ def compile_config(
else:
env_config = EasyDict() # env does not have default_config
env_config = deep_merge_dicts(env_config_template, env_config)
if env_manager is None:
env_manager = BaseEnvManager # for compatibility
env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
policy_config = policy.default_config()
policy_config = deep_merge_dicts(policy_config_template, policy_config)
......@@ -390,26 +394,32 @@ def compile_config(
learner.default_config(),
policy_config.learn.learner,
)
policy_config.collect.collector = compile_collector_config(policy_config, cfg, collector)
if create_cfg is not None or collector is not None:
policy_config.collect.collector = compile_collector_config(policy_config, cfg, collector)
policy_config.eval.evaluator = deep_merge_dicts(
evaluator.default_config(),
policy_config.eval.evaluator,
)
policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, buffer)
if create_cfg is not None or buffer is not None:
policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, buffer)
default_config = EasyDict({'env': env_config, 'policy': policy_config})
if len(reward_model_config) > 0:
default_config['reward_model'] = reward_model_config
cfg = deep_merge_dicts(default_config, cfg)
cfg.seed = seed
# check important key in config
assert all([k in cfg.env for k in ['n_evaluator_episode', 'stop_value']]), cfg.env
cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]: # env interaction evaluation
assert all([k in cfg.env for k in ['n_evaluator_episode', 'stop_value']]), cfg.env
cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
if 'exp_name' not in cfg:
cfg.exp_name = 'default_experiment'
if save_cfg:
if not os.path.exists(cfg.exp_name):
os.mkdir(cfg.exp_name)
try:
os.mkdir(cfg.exp_name)
except FileExistsError:
pass
save_path = os.path.join(cfg.exp_name, save_path)
save_config(cfg, save_path, save_formatted=True)
return cfg
......
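For context on the evaluator-type check above: n_evaluator_episode and stop_value are now asserted only when an environment-interaction evaluator is used, so a metric-based supervised-learning pipeline can omit them. A minimal sketch of the two call patterns, mirroring the entry files later in this commit; rl_main_config, sl_main_config and ImageClassificationPolicy are placeholder names, not code from this repository:

from ding.config import compile_config
from ding.envs import BaseEnvManager
from ding.policy import DQNPolicy
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, \
    MetricSerialEvaluator, AdvancedReplayBuffer

# RL pipeline: InteractionSerialEvaluator keeps the assertion, so the env section of
# rl_main_config (an assumed EasyDict) must provide both keys.
rl_main_config.env.n_evaluator_episode = 5
rl_main_config.env.stop_value = 195
rl_cfg = compile_config(
    rl_main_config, BaseEnvManager, DQNPolicy, BaseLearner,
    SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer, save_cfg=True
)

# SL pipeline: with MetricSerialEvaluator the assertion is skipped, so an image
# classification config needs neither of those env keys.
sl_cfg = compile_config(
    sl_main_config, seed=0, policy=ImageClassificationPolicy,  # hypothetical policy class
    evaluator=MetricSerialEvaluator, save_cfg=True
)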
......@@ -501,7 +501,7 @@ parallel_test_create_config = dict(
),
collector=dict(
type='zergling',
import_names=['ding.worker.collector.zergling_collector'],
import_names=['ding.worker.collector.zergling_parallel_collector'],
),
commander=dict(
type='naive',
......
......@@ -4,7 +4,7 @@ import torch
from functools import partial
from ding.config import compile_config, read_config
from ding.worker import SampleCollector, BaseSerialEvaluator
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.torch_utils import to_device
......@@ -63,7 +63,7 @@ def eval(
load_path = cfg.policy.learn.learner.load_path
state_dict = torch.load(load_path, map_location='cpu')
policy.eval_mode.load_state_dict(state_dict)
evaluator = BaseSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode)
evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode)
# Evaluate
_, eval_reward = evaluator.eval()
......@@ -135,7 +135,7 @@ def collect_demo_data(
if state_dict is None:
state_dict = torch.load(cfg.learner.load_path, map_location='cpu')
policy.collect_mode.load_state_dict(state_dict)
collector = SampleCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)
collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)
# Let's collect some expert demonstrations
exp_data = collector.collect(n_sample=collect_count)
......
......@@ -6,7 +6,7 @@ from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, BaseSerialCommander, create_buffer, \
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
......@@ -65,7 +65,7 @@ def serial_pipeline(
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
......
......@@ -5,7 +5,7 @@ from functools import partial
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from ding.worker import BaseLearner, BaseSerialEvaluator
from ding.worker import BaseLearner, InteractionSerialEvaluator
from ding.envs import get_vec_env_setting, create_env_manager
from ding.config import read_config, compile_config
from ding.policy import create_policy
......@@ -52,7 +52,7 @@ def serial_pipeline_il(
dataset = NaiveRLDataset(data_path)
dataloader = DataLoader(dataset, cfg.policy.learn.batch_size, collate_fn=lambda x: x)
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
# ==========
......
......@@ -6,7 +6,7 @@ from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, BaseSerialCommander, create_buffer, \
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
......@@ -58,7 +58,7 @@ def serial_pipeline_offline(
dataset = create_dataset(cfg)
dataloader = DataLoader(dataset, cfg.policy.learn.batch_size, collate_fn=lambda x: x)
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
# ==========
......
......@@ -6,7 +6,7 @@ from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, BaseSerialCommander, create_buffer, \
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
......@@ -65,7 +65,7 @@ def serial_pipeline_onpolicy(
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
# ==========
......
......@@ -6,7 +6,7 @@ from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, BaseSerialCommander, create_buffer, \
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
......@@ -66,7 +66,7 @@ def serial_pipeline_reward_model(
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
......
......@@ -7,7 +7,7 @@ from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, BaseSerialCommander, create_buffer, \
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
......@@ -94,7 +94,7 @@ def serial_pipeline_sqil(
tb_logger=tb_logger,
exp_name=expert_cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
......
......@@ -7,7 +7,7 @@ import copy
from easydict import EasyDict
from ding.model import create_model
from ding.utils import import_module, allreduce, broadcast, get_rank, POLICY_REGISTRY
from ding.utils import import_module, allreduce, broadcast, get_rank, allreduce_async, synchronize, POLICY_REGISTRY
class Policy(ABC):
......@@ -78,7 +78,9 @@ class Policy(ABC):
torch.cuda.set_device(self._rank % torch.cuda.device_count())
model.cuda()
if self._cfg.learn.multi_gpu:
self._init_multi_gpu_setting(model)
bp_update_sync = self._cfg.learn.get('bp_update_sync', True)
self._bp_update_sync = bp_update_sync
self._init_multi_gpu_setting(model, bp_update_sync)
else:
self._rank = 0
if self._cuda:
......@@ -94,12 +96,26 @@ class Policy(ABC):
for field in self._enable_field:
getattr(self, '_init_' + field)()
def _init_multi_gpu_setting(self, model: torch.nn.Module) -> None:
def _init_multi_gpu_setting(self, model: torch.nn.Module, bp_update_sync: bool) -> None:
for name, param in model.state_dict().items():
assert isinstance(param.data, torch.Tensor), type(param.data)
broadcast(param.data, 0)
for name, param in model.named_parameters():
setattr(param, 'grad', torch.zeros_like(param))
if not bp_update_sync:
def make_hook(name, p):
def hook(*ignore):
allreduce_async(name, p.grad.data)
return hook
for i, (name, p) in enumerate(model.named_parameters()):
if p.requires_grad:
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(make_hook(name, p))
def _create_model(self, cfg: dict, model: Optional[torch.nn.Module] = None) -> torch.nn.Module:
if model is None:
......@@ -183,9 +199,12 @@ class Policy(ABC):
return "DI-engine DRL Policy\n{}".format(repr(self._model))
def sync_gradients(self, model: torch.nn.Module) -> None:
for name, param in model.named_parameters():
if param.requires_grad:
allreduce(param.grad.data)
if self._bp_update_sync:
for name, param in model.named_parameters():
if param.requires_grad:
allreduce(param.grad.data)
else:
synchronize()
# don't need to implement default_model method by force
def default_model(self) -> Tuple[str, List[str]]:
......
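The bp_update_sync=False path above overlaps gradient communication with the backward pass: a hook registered on each parameter's gradient-accumulation node launches a non-blocking all-reduce as soon as that gradient is ready, so sync_gradients only has to wait for the outstanding work. A minimal standalone sketch of the same idea with plain torch.distributed, assuming the process group is already initialized (an illustration, not the DI-engine implementation itself):

import torch
import torch.distributed as dist


def register_async_allreduce_hooks(model: torch.nn.Module) -> list:
    # Keep references to the AccumulateGrad nodes so their hooks are not garbage-collected
    # along with the temporary expand_as views.
    grad_accs = []

    def make_hook(p: torch.nn.Parameter):

        def hook(*unused):
            # Average this parameter's gradient across ranks without blocking backward.
            p.grad.data.div_(dist.get_world_size())
            dist.all_reduce(p.grad.data, async_op=True)

        return hook

    for p in model.parameters():
        if p.requires_grad:
            p_tmp = p.expand_as(p)  # view whose grad_fn leads to the AccumulateGrad node
            grad_acc = p_tmp.grad_fn.next_functions[0][0]
            grad_acc.register_hook(make_hook(p))
            grad_accs.append(grad_acc)
    return grad_accs


# Usage sketch inside one training step:
#   loss.backward()             # hooks fire as each gradient becomes ready
#   torch.cuda.synchronize()    # same role as sync_gradients() when bp_update_sync is False
#   optimizer.step()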
......@@ -108,7 +108,7 @@ spec:
learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
collector=dict(
type='zergling',
import_names=['ding.worker.collector.zergling_collector'],
import_names=['ding.worker.collector.zergling_parallel_collector'],
),
commander=dict(
type='solo',
......
......@@ -7,3 +7,4 @@ from .rnn import get_lstm, sequence_mask
from .soft_argmax import SoftArgmax
from .transformer import Transformer
from .scatter_connection import ScatterConnection
from .resnet import resnet18, ResNet
This diff is collapsed.
import pytest
import torch
from ding.torch_utils.network import resnet18
@pytest.mark.unittest
def test_resnet18():
model = resnet18()
print(model)
inputs = torch.randn(4, 3, 224, 224)
outputs = model(inputs)
assert outputs.shape == (4, 1000)
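The resnet18 exported from ding.torch_utils.network is a standard torchvision-style backbone, so the supervised image-classification training demo named in the commit message reduces to an ordinary training loop around it. A minimal sketch with random tensors standing in for the ImageNet dataloader (illustrative only, not the actual entry file added by this commit):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from ding.torch_utils.network import resnet18

# Toy stand-in for the ImageNet dataloader used by the real demo.
dataset = TensorDataset(torch.randn(64, 3, 224, 224), torch.randint(0, 1000, (64, )))
loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model.train()
for epoch in range(2):
    for x, y in loader:
        logits = model(x)  # (B, 1000) class scores
        loss = criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()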
......@@ -15,7 +15,7 @@ from .log_helper import build_logger, DistributionTimeImage, pretty_print, Logge
from .registry_factory import registries, POLICY_REGISTRY, ENV_REGISTRY, LEARNER_REGISTRY, COMM_LEARNER_REGISTRY, \
SERIAL_COLLECTOR_REGISTRY, PARALLEL_COLLECTOR_REGISTRY, COMM_COLLECTOR_REGISTRY, \
COMMANDER_REGISTRY, LEAGUE_REGISTRY, PLAYER_REGISTRY, MODEL_REGISTRY, \
ENV_MANAGER_REGISTRY, REWARD_MODEL_REGISTRY, BUFFER_REGISTRY, DATASET_REGISTRY
ENV_MANAGER_REGISTRY, REWARD_MODEL_REGISTRY, BUFFER_REGISTRY, DATASET_REGISTRY, SERIAL_EVALUATOR_REGISTRY
from .segment_tree import SumSegmentTree, MinSegmentTree, SegmentTree
from .slurm_helper import find_free_port_slurm, node_to_host, node_to_partition
from .system_helper import get_ip, get_pid, get_task_uid, PropagatingThread, find_free_port
......@@ -25,7 +25,7 @@ from .data import create_dataset
if ding.enable_linklink:
from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
allreduce, broadcast, DistContext
allreduce, broadcast, DistContext, allreduce_async, synchronize
else:
from .pytorch_ddp_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
allreduce, broadcast, DistContext
allreduce, broadcast, DistContext, allreduce_async, synchronize
......@@ -78,6 +78,26 @@ def allreduce(data: torch.Tensor, op: str = 'sum') -> None:
data.div_(get_world_size())
def allreduce_async(data: torch.Tensor, op: str = 'sum') -> None:
r"""
Overview:
Call ``linklink.allreduce_async`` on the data
Arguments:
- data (:obj:`torch.Tensor`): the data to reduce
- op (:obj:`str`): the operation to perform on data, support ``['sum', 'max']``
"""
link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
if op not in link_op_map.keys():
raise KeyError("not support allreduce op type: {}".format(op))
else:
link_op = link_op_map[op]
if is_fake_link():
return data
if op == 'sum':
data.div_(get_world_size())
get_link().allreduce_async(data, reduce_op=link_op)
def get_group(group_size: int) -> List:
r"""
Overview:
......@@ -166,3 +186,7 @@ def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
groups.append(get_link().new_group(rank_list[i]))
group_size = world_size // num_groups
return groups[rank // group_size]
def synchronize():
get_link().synchronize()
......@@ -29,10 +29,22 @@ def get_world_size() -> int:
broadcast = dist.broadcast
allreduce = dist.all_reduce
allgather = dist.all_gather
def allreduce(x: torch.Tensor) -> None:
dist.all_reduce(x)
x.div_(get_world_size())
def allreduce_async(name: str, x: torch.Tensor) -> None:
x.div_(get_world_size())
dist.all_reduce(x, async_op=True)
synchronize = torch.cuda.synchronize
def get_group(group_size: int) -> List:
r"""
Overview:
......
......@@ -15,6 +15,7 @@ MODEL_REGISTRY = Registry()
ENV_MANAGER_REGISTRY = Registry()
REWARD_MODEL_REGISTRY = Registry()
DATASET_REGISTRY = Registry()
SERIAL_EVALUATOR_REGISTRY = Registry()
registries = {
'policy': POLICY_REGISTRY,
......@@ -32,4 +33,5 @@ registries = {
'player': PLAYER_REGISTRY,
'buffer': BUFFER_REGISTRY,
'dataset': DATASET_REGISTRY,
'serial_evaluator': SERIAL_EVALUATOR_REGISTRY,
}
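The new SERIAL_EVALUATOR_REGISTRY follows the same pattern as the other registries: an evaluator class registers itself under a string type and can later be referenced from a create config. A minimal sketch of the decorator side, mirroring how InteractionSerialEvaluator and MetricSerialEvaluator register themselves below; the class name is illustrative and the ISerialEvaluator import path is inferred from the collector __init__ in this diff:

from ding.utils import SERIAL_EVALUATOR_REGISTRY
from ding.worker.collector.base_serial_evaluator import ISerialEvaluator  # path assumed


@SERIAL_EVALUATOR_REGISTRY.register('my_custom')
class MyCustomEvaluator(ISerialEvaluator):
    """Illustrative placeholder; a concrete evaluator must implement reset, reset_env,
    reset_policy, close, should_eval and eval, as the registered evaluators below do."""
    pass


# A create config can then refer to it by type, e.g.
#   evaluator=dict(type='my_custom', import_names=['my_pkg.my_custom_evaluator'])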
# serial
from .base_serial_collector import ISerialCollector, create_serial_collector, get_serial_collector_cls, \
to_tensor_transitions
from .sample_serial_collector import SampleCollector
from .episode_serial_collector import EpisodeCollector
from .episode_one_vs_one_serial_collector import Episode1v1Collector
from .base_serial_evaluator import BaseSerialEvaluator
from .one_vs_one_serial_evaluator import OnevOneEvaluator
from .sample_serial_collector import SampleSerialCollector
from .episode_serial_collector import EpisodeSerialCollector
from .battle_episode_serial_collector import BattleEpisodeSerialCollector
from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
from .interaction_serial_evaluator import InteractionSerialEvaluator
from .battle_interaction_serial_evaluator import BattleInteractionSerialEvaluator
from .metric_serial_evaluator import MetricSerialEvaluator, IMetric
# parallel
from .base_parallel_collector import BaseCollector, create_parallel_collector, get_parallel_collector_cls
from .zergling_collector import ZerglingCollector
from .base_parallel_collector import BaseParallelCollector, create_parallel_collector, get_parallel_collector_cls
from .zergling_parallel_collector import ZerglingParallelCollector
from .marine_parallel_collector import MarineParallelCollector
from .comm import BaseCommCollector, FlaskFileSystemCollector, create_comm_collector, NaiveCollector
......@@ -15,7 +15,7 @@ from ding.utils import build_logger, EasyTimer, get_task_uid, import_module, pre
from ding.torch_utils import build_log_buffer, to_tensor, to_ndarray
class BaseCollector(ABC):
class BaseParallelCollector(ABC):
"""
Overview:
Abstract base class for collector.
......@@ -199,7 +199,7 @@ class BaseCollector(ABC):
self._env_manager = _env_manager
def create_parallel_collector(cfg: EasyDict) -> BaseCollector:
def create_parallel_collector(cfg: EasyDict) -> BaseParallelCollector:
import_module(cfg.get('import_names', []))
return PARALLEL_COLLECTOR_REGISTRY.build(cfg.type, cfg=cfg)
......
from typing import Any, Optional, Callable, Tuple
from abc import ABC, abstractmethod
from collections import namedtuple, deque
from easydict import EasyDict
import copy
import numpy as np
import torch
from ding.utils import build_logger, EasyTimer, lists_to_dicts
from ding.envs import BaseEnvManager
from ding.utils import lists_to_dicts
from ding.torch_utils import to_tensor, to_ndarray, tensor_to_list
class BaseSerialEvaluator(object):
class ISerialEvaluator(ABC):
"""
Overview:
Base class for serial evaluator.
Basic interface class for serial evaluator.
Interfaces:
__init__, reset, reset_policy, reset_env, close, should_eval, eval
reset, reset_policy, reset_env, close, should_eval, eval
Property:
env, policy
"""
......@@ -33,232 +33,35 @@ class BaseSerialEvaluator(object):
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg
config = dict(
# Evaluate every "eval_freq" training iterations.
eval_freq=50,
)
def __init__(
self,
cfg: dict,
env: BaseEnvManager = None,
policy: namedtuple = None,
tb_logger: 'SummaryWriter' = None, # noqa
exp_name: Optional[str] = 'default_experiment',
instance_name: Optional[str] = 'evaluator',
) -> None:
"""
Overview:
Init method. Load config and use ``self._cfg`` setting to build common serial evaluator components,
e.g. logger helper, timer.
Arguments:
- cfg (:obj:`EasyDict`): Configuration EasyDict.
"""
self._cfg = cfg
self._exp_name = exp_name
self._instance_name = instance_name
if tb_logger is not None:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
)
self.reset(policy, env)
self._timer = EasyTimer()
self._default_n_episode = cfg.n_episode
self._stop_value = cfg.stop_value
def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset evaluator's environment. In some case, we need evaluator use the same policy in different \
environments. We can use reset_env to reset the environment.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the evaluator with the \
new passed in environment and launch.
Arguments:
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
"""
if _env is not None:
self._env = _env
self._env.launch()
self._env_num = self._env.env_num
else:
self._env.reset()
@abstractmethod
def reset_env(self, _env: Optional[Any] = None) -> None:
raise NotImplementedError
@abstractmethod
def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
"""
Overview:
Reset evaluator's policy. In some case, we need evaluator work in this same environment but use\
different policy. We can use reset_policy to reset the policy.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
"""
assert hasattr(self, '_env'), "please set env first"
if _policy is not None:
self._policy = _policy
self._policy.reset()
raise NotImplementedError
def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset evaluator's policy and environment. Use new policy and environment to collect data.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the evaluator with the new passed in \
environment and launch.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
"""
if _env is not None:
self.reset_env(_env)
if _policy is not None:
self.reset_policy(_policy)
self._max_eval_reward = float("-inf")
self._last_eval_iter = 0
self._end_flag = False
@abstractmethod
def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[Any] = None) -> None:
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""
Overview:
Close the evaluator. If end_flag is False, close the environment, flush the tb_logger\
and close the tb_logger.
"""
if self._end_flag:
return
self._end_flag = True
self._env.close()
self._tb_logger.flush()
self._tb_logger.close()
def __del__(self):
"""
Overview:
Execute the close command and close the evaluator. __del__ is automatically called \
to destroy the evaluator instance when the evaluator finishes its work
"""
self.close()
raise NotImplementedError
@abstractmethod
def should_eval(self, train_iter: int) -> bool:
"""
Overview:
Determine whether you need to start the evaluation mode, if the number of training has reached\
the maximum number of times to start the evaluator, return True
"""
if train_iter == self._last_eval_iter:
return False
if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
return False
self._last_eval_iter = train_iter
return True
raise NotImplementedError
@abstractmethod
def eval(
self,
save_ckpt_fn: Callable = None,
train_iter: int = -1,
envstep: int = -1,
n_episode: Optional[int] = None
) -> Tuple[bool, float]:
'''
Overview:
Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
Arguments:
- save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
- train_iter (:obj:`int`): Current training iteration.
- envstep (:obj:`int`): Current env interaction step.
- n_episode (:obj:`int`): Number of evaluation episodes.
Returns:
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- eval_reward (:obj:`float`): Current eval_reward.
'''
if n_episode is None:
n_episode = self._default_n_episode
assert n_episode is not None, "please indicate eval n_episode"
envstep_count = 0
info = {}
eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
self._env.reset()
self._policy.reset()
with self._timer:
while not eval_monitor.is_finished():
obs = self._env.ready_obs
obs = to_tensor(obs, dtype=torch.float32)
policy_output = self._policy.forward(obs)
actions = {i: a['action'] for i, a in policy_output.items()}
actions = to_ndarray(actions)
timesteps = self._env.step(actions)
timesteps = to_tensor(timesteps, dtype=torch.float32)
for env_id, t in timesteps.items():
if t.info.get('abnormal', False):
# If there is an abnormal timestep, reset all the related variables(including this env).
self._policy.reset([env_id])
continue
if t.done:
# Env reset is done by env_manager automatically.
self._policy.reset([env_id])
reward = t.info['final_eval_reward']
if 'episode_info' in t.info:
eval_monitor.update_info(env_id, t.info['episode_info'])
eval_monitor.update_reward(env_id, reward)
self._logger.info(
"[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
)
)
envstep_count += 1
duration = self._timer.value
episode_reward = eval_monitor.get_episode_reward()
info = {
'train_iter': train_iter,
'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
'episode_count': n_episode,
'envstep_count': envstep_count,
'avg_envstep_per_episode': envstep_count / n_episode,
'evaluate_time': duration,
'avg_envstep_per_sec': envstep_count / duration,
'avg_time_per_episode': n_episode / duration,
'reward_mean': np.mean(episode_reward),
'reward_std': np.std(episode_reward),
'reward_max': np.max(episode_reward),
'reward_min': np.min(episode_reward),
# 'each_reward': episode_reward,
}
episode_info = eval_monitor.get_episode_info()
if episode_info is not None:
info.update(episode_info)
self._logger.info(self._logger.get_tabulate_vars_hor(info))
# self._logger.info(self._logger.get_tabulate_vars(info))
for k, v in info.items():
if k in ['train_iter', 'ckpt_name', 'each_reward']:
continue
if not np.isscalar(v):
continue
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
eval_reward = np.mean(episode_reward)
if eval_reward > self._max_eval_reward:
if save_ckpt_fn:
save_ckpt_fn('ckpt_best.pth.tar')
self._max_eval_reward = eval_reward
stop_flag = eval_reward >= self._stop_value and train_iter > 0
if stop_flag:
self._logger.info(
"[DI-engine serial pipeline] " +
"Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) +
", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
)
return stop_flag, eval_reward
) -> Any:
raise NotImplementedError
class VectorEvalMonitor(object):
......
......@@ -11,10 +11,10 @@ from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF,
@SERIAL_COLLECTOR_REGISTRY.register('episode_1v1')
class Episode1v1Collector(ISerialCollector):
class BattleEpisodeSerialCollector(ISerialCollector):
"""
Overview:
Episode collector(n_episode)
Episode collector (n_episode) for two-policy battles
Interfaces:
__init__, reset, reset_env, reset_policy, collect, close
Property:
......
......@@ -6,13 +6,15 @@ import copy
import numpy as np
import torch
from ding.utils import build_logger, EasyTimer, deep_merge_dicts, lists_to_dicts, dicts_to_lists
from ding.utils import build_logger, EasyTimer, deep_merge_dicts, lists_to_dicts, dicts_to_lists, \
SERIAL_EVALUATOR_REGISTRY
from ding.envs import BaseEnvManager
from ding.torch_utils import to_tensor, to_ndarray, tensor_to_list
from .base_serial_collector import CachePool
from .base_serial_evaluator import ISerialEvaluator
class OnevOneEvaluator(object):
@SERIAL_EVALUATOR_REGISTRY.register('battle_interaction')
class BattleInteractionSerialEvaluator(ISerialEvaluator):
"""
Overview:
1v1 battle evaluator class.
......
......@@ -3,7 +3,7 @@ from typing import Any
from easydict import EasyDict
from ding.utils import get_task_uid, import_module, COMM_COLLECTOR_REGISTRY
from ..base_parallel_collector import create_parallel_collector, BaseCollector
from ..base_parallel_collector import create_parallel_collector, BaseParallelCollector
class BaseCommCollector(ABC):
......@@ -80,14 +80,14 @@ class BaseCommCollector(ABC):
def collector_uid(self) -> str:
return self._collector_uid
def _create_collector(self, task_info: dict) -> BaseCollector:
def _create_collector(self, task_info: dict) -> BaseParallelCollector:
"""
Overview:
Receive ``task_info`` passed from coordinator and create a collector.
Arguments:
- task_info (:obj:`dict`): Task info dict from coordinator. Should be like \
Returns:
- collector (:obj:`BaseCollector`): Created base collector.
- collector (:obj:`BaseParallelCollector`): Created base collector.
Note:
Four methods('send_metadata', 'send_stepdata', 'get_policy_update_info'), and policy are set.
The reason why they are set here rather than base collector is, they highly depend on the specific task.
......
......@@ -11,7 +11,7 @@ from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF,
@SERIAL_COLLECTOR_REGISTRY.register('episode')
class EpisodeCollector(ISerialCollector):
class EpisodeSerialCollector(ISerialCollector):
"""
Overview:
Episode collector(n_episode)
......
from typing import Optional, Callable, Tuple
from collections import namedtuple
import numpy as np
import torch
from ding.envs import BaseEnvManager
from ding.torch_utils import to_tensor, to_ndarray
from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY
from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
@SERIAL_EVALUATOR_REGISTRY.register('interaction')
class InteractionSerialEvaluator(ISerialEvaluator):
"""
Overview:
Interaction serial evaluator class, policy interacts with env.
Interfaces:
__init__, reset, reset_policy, reset_env, close, should_eval, eval
Property:
env, policy
"""
config = dict(
# Evaluate every "eval_freq" training iterations.
eval_freq=50,
)
def __init__(
self,
cfg: dict,
env: BaseEnvManager = None,
policy: namedtuple = None,
tb_logger: 'SummaryWriter' = None, # noqa
exp_name: Optional[str] = 'default_experiment',
instance_name: Optional[str] = 'evaluator',
) -> None:
"""
Overview:
Init method. Load config and use ``self._cfg`` setting to build common serial evaluator components,
e.g. logger helper, timer.
Arguments:
- cfg (:obj:`EasyDict`): Configuration EasyDict.
"""
self._cfg = cfg
self._exp_name = exp_name
self._instance_name = instance_name
if tb_logger is not None:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
)
self.reset(policy, env)
self._timer = EasyTimer()
self._default_n_episode = cfg.n_episode
self._stop_value = cfg.stop_value
def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset evaluator's environment. In some cases, we need the evaluator to use the same policy in different \
environments; ``reset_env`` is used to reset the environment.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the evaluator with the \
newly passed-in environment and launch it.
Arguments:
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
"""
if _env is not None:
self._env = _env
self._env.launch()
self._env_num = self._env.env_num
else:
self._env.reset()
def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
"""
Overview:
Reset evaluator's policy. In some cases, we need the evaluator to work in the same environment but with a \
different policy; ``reset_policy`` is used to reset the policy.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
"""
assert hasattr(self, '_env'), "please set env first"
if _policy is not None:
self._policy = _policy
self._policy.reset()
def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset evaluator's policy and environment. Use the new policy and environment for evaluation.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the evaluator with the new passed in \
environment and launch.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
"""
if _env is not None:
self.reset_env(_env)
if _policy is not None:
self.reset_policy(_policy)
self._max_eval_reward = float("-inf")
self._last_eval_iter = 0
self._end_flag = False
def close(self) -> None:
"""
Overview:
Close the evaluator. If end_flag is False, close the environment, flush the tb_logger\
and close the tb_logger.
"""
if self._end_flag:
return
self._end_flag = True
self._env.close()
self._tb_logger.flush()
self._tb_logger.close()
def __del__(self):
"""
Overview:
Execute the close command and close the evaluator. __del__ is automatically called \
to destroy the evaluator instance when the evaluator finishes its work
"""
self.close()
def should_eval(self, train_iter: int) -> bool:
"""
Overview:
Determine whether to start evaluation: return True if at least ``eval_freq`` training iterations\
have passed since the last evaluation.
"""
if train_iter == self._last_eval_iter:
return False
if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
return False
self._last_eval_iter = train_iter
return True
def eval(
self,
save_ckpt_fn: Callable = None,
train_iter: int = -1,
envstep: int = -1,
n_episode: Optional[int] = None
) -> Tuple[bool, float]:
'''
Overview:
Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
Arguments:
- save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
- train_iter (:obj:`int`): Current training iteration.
- envstep (:obj:`int`): Current env interaction step.
- n_episode (:obj:`int`): Number of evaluation episodes.
Returns:
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- eval_reward (:obj:`float`): Current eval_reward.
'''
if n_episode is None:
n_episode = self._default_n_episode
assert n_episode is not None, "please indicate eval n_episode"
envstep_count = 0
info = {}
eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
self._env.reset()
self._policy.reset()
with self._timer:
while not eval_monitor.is_finished():
obs = self._env.ready_obs
obs = to_tensor(obs, dtype=torch.float32)
policy_output = self._policy.forward(obs)
actions = {i: a['action'] for i, a in policy_output.items()}
actions = to_ndarray(actions)
timesteps = self._env.step(actions)
timesteps = to_tensor(timesteps, dtype=torch.float32)
for env_id, t in timesteps.items():
if t.info.get('abnormal', False):
# If there is an abnormal timestep, reset all the related variables(including this env).
self._policy.reset([env_id])
continue
if t.done:
# Env reset is done by env_manager automatically.
self._policy.reset([env_id])
reward = t.info['final_eval_reward']
if 'episode_info' in t.info:
eval_monitor.update_info(env_id, t.info['episode_info'])
eval_monitor.update_reward(env_id, reward)
self._logger.info(
"[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
)
)
envstep_count += 1
duration = self._timer.value
episode_reward = eval_monitor.get_episode_reward()
info = {
'train_iter': train_iter,
'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
'episode_count': n_episode,
'envstep_count': envstep_count,
'avg_envstep_per_episode': envstep_count / n_episode,
'evaluate_time': duration,
'avg_envstep_per_sec': envstep_count / duration,
'avg_time_per_episode': n_episode / duration,
'reward_mean': np.mean(episode_reward),
'reward_std': np.std(episode_reward),
'reward_max': np.max(episode_reward),
'reward_min': np.min(episode_reward),
# 'each_reward': episode_reward,
}
episode_info = eval_monitor.get_episode_info()
if episode_info is not None:
info.update(episode_info)
self._logger.info(self._logger.get_tabulate_vars_hor(info))
# self._logger.info(self._logger.get_tabulate_vars(info))
for k, v in info.items():
if k in ['train_iter', 'ckpt_name', 'each_reward']:
continue
if not np.isscalar(v):
continue
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
eval_reward = np.mean(episode_reward)
if eval_reward > self._max_eval_reward:
if save_ckpt_fn:
save_ckpt_fn('ckpt_best.pth.tar')
self._max_eval_reward = eval_reward
stop_flag = eval_reward >= self._stop_value and train_iter > 0
if stop_flag:
self._logger.info(
"[DI-engine serial pipeline] " +
"Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) +
", so your RL agent has converged; you can refer to 'log/evaluator/evaluator_logger.txt' for details."
)
return stop_flag, eval_reward
......@@ -13,14 +13,14 @@ from ding.policy import create_policy, Policy
from ding.envs import get_vec_env_setting, create_env_manager
from ding.utils import get_data_compressor, pretty_print, PARALLEL_COLLECTOR_REGISTRY
from ding.envs import BaseEnvTimestep, BaseEnvManager
from .base_parallel_collector import BaseCollector
from .base_parallel_collector import BaseParallelCollector
from .base_serial_collector import CachePool, TrajBuffer
INF = float("inf")
@PARALLEL_COLLECTOR_REGISTRY.register('one_vs_one')
class OneVsOneCollector(BaseCollector):
@PARALLEL_COLLECTOR_REGISTRY.register('marine')
class MarineParallelCollector(BaseParallelCollector):
"""
Feature:
- one policy or two policies, many envs
......@@ -343,4 +343,4 @@ class OneVsOneCollector(BaseCollector):
}
def __repr__(self) -> str:
return "OneVsOneCollector"
return "MarineParallelCollector"
from typing import Optional, Callable, Tuple, Any, List
from abc import ABC, abstractmethod
from collections import namedtuple
import numpy as np
import torch
from torch.utils.data import DataLoader
from ding.torch_utils import to_tensor, to_ndarray
from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY, allreduce
from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
class IMetric(ABC):
@abstractmethod
def eval(self, inputs: Any, label: Any) -> dict:
raise NotImplementedError
@abstractmethod
def reduce_mean(self, inputs: List[Any]) -> Any:
raise NotImplementedError
@abstractmethod
def gt(self, metric1: Any, metric2: Any) -> bool:
"""
Overview:
Whether metric1 is greater than metric2 (>=)
.. note::
If metric2 is None, return True
"""
raise NotImplementedError
@SERIAL_EVALUATOR_REGISTRY.register('metric')
class MetricSerialEvaluator(ISerialEvaluator):
"""
Overview:
Metric serial evaluator class, in which the policy is evaluated by an objective metric (the dataloader/metric pair plays the role of the env).
Interfaces:
__init__, reset, reset_policy, reset_env, close, should_eval, eval
Property:
env, policy
"""
config = dict(
# Evaluate every "eval_freq" training iterations.
eval_freq=50,
)
def __init__(
self,
cfg: dict,
env: Tuple[DataLoader, IMetric] = None,
policy: namedtuple = None,
tb_logger: 'SummaryWriter' = None, # noqa
exp_name: Optional[str] = 'default_experiment',
instance_name: Optional[str] = 'evaluator',
) -> None:
"""
Overview:
Init method. Load config and use ``self._cfg`` setting to build common serial evaluator components,
e.g. logger helper, timer.
Arguments:
- cfg (:obj:`EasyDict`): Configuration EasyDict.
"""
self._cfg = cfg
self._exp_name = exp_name
self._instance_name = instance_name
if tb_logger is not None:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
)
self.reset(policy, env)
self._timer = EasyTimer()
self._stop_value = cfg.stop_value
def reset_env(self, _env: Optional[Tuple[DataLoader, IMetric]] = None) -> None:
"""
Overview:
Reset evaluator's environment. In some cases, we need the evaluator to use the same policy in different \
environments; ``reset_env`` is used to reset the environment.
If _env is not None, replace the old environment in the evaluator with the new one.
Arguments:
- env (:obj:`Optional[Tuple[DataLoader, IMetric]]`): Instance of the DataLoader and Metric
"""
if _env is not None:
self._dataloader, self._metric = _env
def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
"""
Overview:
Reset evaluator's policy. In some cases, we need the evaluator to work in the same environment but with a \
different policy; ``reset_policy`` is used to reset the policy.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
"""
if _policy is not None:
self._policy = _policy
self._policy.reset()
def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[Tuple[DataLoader, IMetric]] = None) -> None:
"""
Overview:
Reset evaluator's policy and environment. Use the new policy and environment for evaluation.
If _env is not None, replace the old environment in the evaluator with the new one
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
- env (:obj:`Optional[Tuple[DataLoader, IMetric]]`): Instance of the DataLoader and Metric
"""
if _env is not None:
self.reset_env(_env)
if _policy is not None:
self.reset_policy(_policy)
self._max_avg_eval_result = None
self._last_eval_iter = -1
self._end_flag = False
def close(self) -> None:
"""
Overview:
Close the evaluator. If end_flag is False, flush and close the tb_logger\
(there is no env manager to close here, only the dataloader and metric).
"""
if self._end_flag:
return
self._end_flag = True
# no env manager to close; the dataloader and metric need no explicit cleanup
self._tb_logger.flush()
self._tb_logger.close()
def __del__(self):
"""
Overview:
Execute the close command and close the evaluator. __del__ is automatically called \
to destroy the evaluator instance when the evaluator finishes its work
"""
self.close()
def should_eval(self, train_iter: int) -> bool:
"""
Overview:
Determine whether to start evaluation: return True if at least ``eval_freq`` training iterations\
have passed since the last evaluation.
"""
if train_iter == self._last_eval_iter:
return False
if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
return False
self._last_eval_iter = train_iter
return True
def eval(
self,
save_ckpt_fn: Callable = None,
train_iter: int = -1,
envstep: int = -1,
) -> Tuple[bool, Any]:
'''
Overview:
Evaluate policy and store the best policy based on whether it reaches the highest historical evaluation result.
Arguments:
- save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
- train_iter (:obj:`int`): Current training iteration.
- envstep (:obj:`int`): Current env interaction step.
Returns:
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- eval_metric (:obj:`float`): Current evaluation metric result.
'''
self._policy.reset()
eval_results = []
with self._timer:
self._logger.info("Evaluation begin...")
for batch_idx, batch_data in enumerate(self._dataloader):
inputs, label = to_tensor(batch_data)
policy_output = self._policy.forward(inputs)
eval_results.append(self._metric.eval(policy_output, label))
avg_eval_result = self._metric.reduce_mean(eval_results)
if self._cfg.multi_gpu:
device = self._policy.get_attribute('device')
for k in avg_eval_result.keys():
value_tensor = torch.FloatTensor([avg_eval_result[k]]).to(device)
allreduce(value_tensor)
avg_eval_result[k] = value_tensor.item()
duration = self._timer.value
info = {
'train_iter': train_iter,
'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
'data_length': len(self._dataloader),
'evaluate_time': duration,
'avg_time_per_data': duration / len(self._dataloader),
}
info.update(avg_eval_result)
self._logger.info(self._logger.get_tabulate_vars_hor(info))
# self._logger.info(self._logger.get_tabulate_vars(info))
for k, v in info.items():
if k in ['train_iter', 'ckpt_name']:
continue
if not np.isscalar(v):
continue
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
if self._metric.gt(avg_eval_result, self._max_avg_eval_result):
if save_ckpt_fn:
save_ckpt_fn('ckpt_best.pth.tar')
self._max_avg_eval_result = avg_eval_result
stop_flag = self._metric.gt(avg_eval_result, self._stop_value) and train_iter > 0
if stop_flag:
self._logger.info(
"[DI-engine serial pipeline] " +
"Current eval_reward: {} is greater than stop_value: {}".format(avg_eval_result, self._stop_value) +
", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
)
return stop_flag, avg_eval_result
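MetricSerialEvaluator above treats a (DataLoader, IMetric) pair as its environment, so plugging in a new task only requires an IMetric implementation. A minimal sketch of a top-1 accuracy metric for the image-classification demo; the class name, the assumption that the policy's eval forward returns raw logits, and the stop_value format are illustrative rather than the demo's actual code:

import torch
from ding.worker import MetricSerialEvaluator, IMetric  # re-export assumed from the worker package


class Top1AccMetric(IMetric):

    def eval(self, inputs: torch.Tensor, label: torch.Tensor) -> dict:
        # Assumes inputs are raw logits of shape (B, num_class) and label has shape (B, ).
        correct = (inputs.argmax(dim=-1) == label).sum().item()
        return {'acc': correct / label.shape[0]}

    def reduce_mean(self, inputs: list) -> dict:
        # Average the per-batch accuracy over the whole evaluation pass.
        return {'acc': sum(x['acc'] for x in inputs) / len(inputs)}

    def gt(self, metric1: dict, metric2) -> bool:
        # metric2 is None on the first evaluation (see IMetric.gt above).
        if metric2 is None:
            return True
        if isinstance(metric2, dict):  # assume stop_value may use the same {'acc': ...} format
            metric2 = metric2['acc']
        return metric1['acc'] >= metric2


# Wiring sketch, following the constructor signature above:
#   evaluator = MetricSerialEvaluator(
#       cfg.policy.eval.evaluator, (eval_dataloader, Top1AccMetric()), policy.eval_mode, tb_logger
#   )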
......@@ -11,7 +11,7 @@ from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF,
@SERIAL_COLLECTOR_REGISTRY.register('sample')
class SampleCollector(ISerialCollector):
class SampleSerialCollector(ISerialCollector):
"""
Overview:
Sample collector(n_sample), a sample is one training sample for updating model,
......
......@@ -60,15 +60,15 @@ main_config = fake_cpong_dqn_config
fake_cpong_dqn_create_config = dict(
env=dict(
import_names=['ding.worker.collector.tests.test_one_vs_one_collector'],
import_names=['ding.worker.collector.tests.test_marine_parallel_collector'],
type='fake_competitive_rl',
),
env_manager=dict(type='base'),
policy=dict(type='dqn_command'),
learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
collector=dict(
type='one_vs_one',
import_names=['ding.worker.collector.one_vs_one_collector'],
type='marine',
import_names=['ding.worker.collector.marine_parallel_collector'],
),
commander=dict(
type='one_vs_one',
......
......@@ -5,7 +5,7 @@ import pytest
from functools import partial
import copy
from ding.worker import SampleCollector, NaiveReplayBuffer
from ding.worker import SampleSerialCollector, NaiveReplayBuffer
from ding.envs import get_vec_env_setting, create_env_manager, AsyncSubprocessEnvManager, SyncSubprocessEnvManager,\
BaseEnvManager
from ding.utils import deep_merge_dicts, set_pkg_seed
......@@ -48,10 +48,10 @@ def compare_test(cfg, out_str, seed):
collector_env.seed(seed)
# cfg.policy.collect.collector = deep_merge_dicts(
# SampleCollector.default_config(), cfg.policy.collect.collector)
# SampleSerialCollector.default_config(), cfg.policy.collect.collector)
policy = FakePolicy(cfg.policy)
collector_cfg = deep_merge_dicts(SampleCollector.default_config(), cfg.policy.collect.collector)
collector = SampleCollector(collector_cfg, collector_env, policy.collect_mode)
collector_cfg = deep_merge_dicts(SampleSerialCollector.default_config(), cfg.policy.collect.collector)
collector = SampleSerialCollector(collector_cfg, collector_env, policy.collect_mode)
buffer_cfg = deep_merge_dicts(cfg.policy.other.replay_buffer, NaiveReplayBuffer.default_config())
replay_buffer = NaiveReplayBuffer(buffer_cfg)
......
import pytest
from ding.worker import EpisodeCollector
from ding.worker import EpisodeSerialCollector
from ding.envs import BaseEnvManager, SyncSubprocessEnvManager, AsyncSubprocessEnvManager
from ding.policy import DQNPolicy
from ding.model import DQN
......@@ -13,7 +13,7 @@ def test_collect(env_manager_type):
env.seed(0)
model = DQN(obs_shape=4, action_shape=1)
policy = DQNPolicy(DQNPolicy.default_config(), model=model).collect_mode
collector = EpisodeCollector(EpisodeCollector.default_config(), env, policy)
collector = EpisodeSerialCollector(EpisodeSerialCollector.default_config(), env, policy)
collected_episode = collector.collect(
n_episode=18, train_iter=collector._collect_print_freq, policy_kwargs={'eps': 0.5}
......
......@@ -12,14 +12,14 @@ from easydict import EasyDict
from ding.policy import create_policy, Policy
from ding.envs import get_vec_env_setting, create_env_manager, BaseEnvManager
from ding.utils import get_data_compressor, pretty_print, PARALLEL_COLLECTOR_REGISTRY
from .base_parallel_collector import BaseCollector
from .base_parallel_collector import BaseParallelCollector
from .base_serial_collector import CachePool, TrajBuffer
INF = float("inf")
@PARALLEL_COLLECTOR_REGISTRY.register('zergling')
class ZerglingCollector(BaseCollector):
class ZerglingParallelCollector(BaseParallelCollector):
"""
Feature:
- one policy, many envs
......@@ -292,4 +292,4 @@ class ZerglingCollector(BaseCollector):
}
def __repr__(self) -> str:
return "ZerglingCollector"
return "ZerglingParallelCollector"
......@@ -26,7 +26,7 @@ class BaseSerialCommander(object):
cfg: dict,
learner: 'BaseLearner', # noqa
collector: 'BaseSerialCollector', # noqa
evaluator: 'BaseSerialEvaluator', # noqa
evaluator: 'InteractionSerialEvaluator', # noqa
replay_buffer: 'IBuffer', # noqa
policy: namedtuple = None,
) -> None:
......@@ -37,7 +37,7 @@ class BaseSerialCommander(object):
- cfg (:obj:`dict`): the config of commander
- learner (:obj:`BaseLearner`): the learner
- collector (:obj:`BaseSerialCollector`): the collector
- evaluator (:obj:`BaseSerialEvaluator`): the evaluator
- evaluator (:obj:`InteractionSerialEvaluator`): the evaluator
- replay_buffer (:obj:`IBuffer`): the buffer
"""
self._cfg = cfg
......
......@@ -79,7 +79,7 @@ def setup_1v1commander():
learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
collector=dict(
type='zergling',
import_names=['ding.worker.collector.zergling_collector'],
import_names=['ding.worker.collector.zergling_parallel_collector'],
),
commander=dict(
type='one_vs_one',
......
......@@ -70,7 +70,7 @@ qbert_dqn_create_config = dict(
learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
collector=dict(
type='zergling',
import_names=['ding.worker.collector.zergling_collector'],
import_names=['ding.worker.collector.zergling_parallel_collector'],
),
commander=dict(
type='solo',
......
......@@ -71,7 +71,7 @@ qbert_dqn_create_config = dict(
learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
collector=dict(
type='zergling',
import_names=['ding.worker.collector.zergling_collector'],
import_names=['ding.worker.collector.zergling_parallel_collector'],
),
commander=dict(
type='solo',
......
......@@ -5,7 +5,7 @@ from easydict import EasyDict
from functools import partial
from ding.config import compile_config
from ding.worker import BaseLearner, EpisodeCollector, BaseSerialEvaluator, EpisodeReplayBuffer
from ding.worker import BaseLearner, EpisodeSerialCollector, InteractionSerialEvaluator, EpisodeReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import DQNPolicy
from ding.model import DQN
......@@ -22,8 +22,8 @@ def main(cfg, seed=0, max_iterations=int(1e8)):
BaseEnvManager,
DQNPolicy,
BaseLearner,
EpisodeCollector,
BaseSerialEvaluator,
EpisodeSerialCollector,
InteractionSerialEvaluator,
EpisodeReplayBuffer,
save_cfg=True
)
......@@ -47,10 +47,10 @@ def main(cfg, seed=0, max_iterations=int(1e8)):
# Set up collection, training and evaluation utilities
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = EpisodeCollector(
collector = EpisodeSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = EpisodeReplayBuffer(
......
......@@ -67,7 +67,7 @@ cartpole_dqn_create_config = dict(
learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
collector=dict(
type='zergling',
import_names=['ding.worker.collector.zergling_collector'],
import_names=['ding.worker.collector.zergling_parallel_collector'],
),
commander=dict(
type='solo',
......
......@@ -68,7 +68,7 @@ cartpole_dqn_create_config = dict(
learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
collector=dict(
type='zergling',
import_names=['ding.worker.collector.zergling_collector'],
import_names=['ding.worker.collector.zergling_parallel_collector'],
),
commander=dict(
type='solo',
......
......@@ -4,7 +4,7 @@ import torch
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, AdvancedReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import DQNPolicy
from ding.model import DQN
......@@ -24,8 +24,8 @@ def main(cfg, seed=0):
BaseEnvManager,
DQNPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
......@@ -44,7 +44,7 @@ def main(cfg, seed=0):
# evaluate
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator.eval()
......
......@@ -3,7 +3,7 @@ import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, AdvancedReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import DQNPolicy
from ding.model import DQN
......@@ -23,8 +23,8 @@ def main(cfg, seed=0):
BaseEnvManager,
DQNPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
......@@ -44,10 +44,10 @@ def main(cfg, seed=0):
# Set up collection, training and evaluation utilities
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleCollector(
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
......@@ -77,7 +77,7 @@ def main(cfg, seed=0):
# evaluate
evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
evaluator_env.enable_save_replay(cfg.env.replay_path) # switch save replay interface
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
......
......@@ -5,7 +5,7 @@ from easydict import EasyDict
from copy import deepcopy
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, AdvancedReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPGPolicy
from ding.model import PPG
......@@ -23,8 +23,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
BaseEnvManager,
PPGPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator, {
SampleSerialCollector,
InteractionSerialEvaluator, {
'policy': AdvancedReplayBuffer,
'value': AdvancedReplayBuffer
},
......@@ -42,10 +42,10 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
policy = PPGPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleCollector(
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
policy_buffer = AdvancedReplayBuffer(
......
......@@ -3,7 +3,7 @@ import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
......@@ -17,7 +17,7 @@ def wrapped_cartpole_env():
def main(cfg, seed=0, max_iterations=int(1e10)):
cfg = compile_config(
cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleCollector, BaseSerialEvaluator, save_cfg=True
cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
)
collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
......@@ -31,10 +31,10 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
policy = PPOPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleCollector(
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
......
......@@ -3,7 +3,7 @@ import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOOffPolicy
from ding.model import VAC
......@@ -21,8 +21,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
BaseEnvManager,
PPOOffPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
SampleSerialCollector,
InteractionSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
......@@ -38,10 +38,10 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
policy = PPOOffPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleCollector(
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = NaiveReplayBuffer(cfg.policy.other.replay_buffer, exp_name=cfg.exp_name)
......
......@@ -3,7 +3,7 @@ import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOOffPolicy
from ding.model import VAC
......@@ -22,8 +22,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
BaseEnvManager,
PPOOffPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
SampleSerialCollector,
InteractionSerialEvaluator,
NaiveReplayBuffer,
reward_model=RndRewardModel,
save_cfg=True
......@@ -40,10 +40,10 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
policy = PPOOffPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleCollector(
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = NaiveReplayBuffer(cfg.policy.other.replay_buffer, exp_name=cfg.exp_name)
......
......@@ -4,7 +4,7 @@ from tensorboardX import SummaryWriter
from easydict import EasyDict
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
......@@ -19,8 +19,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
BaseEnvManager,
PPOPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
SampleSerialCollector,
InteractionSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
......@@ -40,8 +40,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
policy = PPOPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
collector = SampleCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = BaseSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
......
......@@ -3,7 +3,7 @@ import gym
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, AdvancedReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import DDPGPolicy
from ding.model import QAC
......@@ -18,8 +18,8 @@ def main(cfg, seed=0):
BaseEnvManager,
DDPGPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
......@@ -45,10 +45,10 @@ def main(cfg, seed=0):
# Set up collection, training and evaluation utilities
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = SampleCollector(
collector = SampleSerialCollector(
cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator = BaseSerialEvaluator(
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
......
......@@ -68,8 +68,8 @@ cpong_dqn_create_config = dict(
policy=dict(type='dqn_command'),
learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
collector=dict(
type='one_vs_one',
import_names=['ding.worker.collector.one_vs_one_collector'],
type='marine',
import_names=['ding.worker.collector.marine_parallel_collector'],
),
commander=dict(
type='one_vs_one',
......
......@@ -49,7 +49,7 @@ __base_learner_default_config = dict(
__zergling_collector_default_config = dict(
collector_type='zergling',
import_names=['ding.worker.collector.zergling_collector'],
import_names=['ding.worker.collector.zergling_parallel_collector'],
print_freq=10,
compressor='lz4',
policy_update_freq=3,
......
......@@ -67,8 +67,8 @@ gfootball_ppo_create_config = dict(
policy=dict(type='ppo_lstm_command', import_names=['dizoo.gfootball.policy.ppo_lstm']),
learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
collector=dict(
type='one_vs_one',
import_names=['ding.worker.collector.one_vs_one_collector'],
type='marine',
import_names=['ding.worker.collector.marine_parallel_collector'],
),
commander=dict(
type='one_vs_one',
......
from .dataset import ImageNetDataset
from .sampler import DistributedSampler
from typing import Callable, Tuple
import os
import re
import math
from PIL import Image
import numpy as np
import torch
import torch.utils.data as data
from torchvision import transforms
class ToNumpy:
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
def _pil_interp(method):
if method == 'bicubic':
return Image.BICUBIC
elif method == 'lanczos':
return Image.LANCZOS
elif method == 'hamming':
return Image.HAMMING
else:
# default bilinear, do we want to allow nearest?
return Image.BILINEAR
def natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def find_images_and_targets(folder, types=('.png', '.jpg', '.jpeg'), class_to_idx=None, leaf_name_only=True, sort=True):
labels = []
filenames = []
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
for f in files:
base, ext = os.path.splitext(f)
if ext.lower() in types:
filenames.append(os.path.join(root, f))
labels.append(label)
if class_to_idx is None:
# building class index
unique_labels = set(labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
if sort:
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
return images_and_targets, class_to_idx
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
DEFAULT_CROP_PCT = 0.875
def transforms_noaug_train(
img_size=224,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
):
if interpolation == 'random':
# random interpolation not supported with no-aug
interpolation = 'bilinear'
tfl = [transforms.Resize(img_size, _pil_interp(interpolation)), transforms.CenterCrop(img_size)]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
tfl += [transforms.ToTensor(), transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
return transforms.Compose(tfl)
def transforms_imagenet_eval(
img_size=224,
crop_pct=None,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
):
crop_pct = crop_pct or DEFAULT_CROP_PCT
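    # The shorter edge is resized to img_size / crop_pct before the center crop,
    # e.g. img_size=224 with the default crop_pct=0.875 resizes to 256 and then crops to 224x224.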
if isinstance(img_size, (tuple, list)):
assert len(img_size) == 2
if img_size[-1] == img_size[-2]:
# fall-back to older behaviour so Resize scales to shortest edge if target is square
scale_size = int(math.floor(img_size[0] / crop_pct))
else:
scale_size = tuple([int(x / crop_pct) for x in img_size])
else:
scale_size = int(math.floor(img_size / crop_pct))
tfl = [
transforms.Resize(scale_size, _pil_interp(interpolation)),
transforms.CenterCrop(img_size),
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
tfl += [transforms.ToTensor(), transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
return transforms.Compose(tfl)
class ImageNetDataset(data.Dataset):
def __init__(self, root: str, is_training: bool, transform: Callable = None) -> None:
self.root = root
if transform is None:
if is_training:
transform = transforms_noaug_train()
else:
transform = transforms_imagenet_eval()
self.transform = transform
self.data, _ = find_images_and_targets(root)
def __len__(self) -> int:
return len(self.data)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
img, target = self.data[index]
img = Image.open(img).convert('RGB')
if self.transform is not None:
img = self.transform(img)
if target is None:
target = torch.tensor(-1, dtype=torch.long)
return img, target
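The dataset expects an ImageNet-style layout where every class has its own sub-folder and labels are derived from the folder names by find_images_and_targets. A minimal usage sketch, assuming a local copy of the data under the hypothetical path ./data/imagenet:
from torch.utils.data import DataLoader
from dizoo.image_classification.data import ImageNetDataset
# layout: ./data/imagenet/train/<class_name>/<image>.jpeg  (hypothetical local path)
train_set = ImageNetDataset('./data/imagenet/train', is_training=True)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=3)
img, target = train_set[0]  # transformed CHW image tensor and its integer class index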
import math
import torch
from torch.utils.data import Sampler
from ding.utils import get_rank, get_world_size
class DistributedSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Arguments:
dataset: Dataset used for sampling.
world_size (optional): Number of processes participating in
distributed training.
        rank (optional): Rank of the current process within world_size.
        round_up (optional): Whether to pad the shuffled indices so that every
            rank receives the same number of samples. Default: True.
    """
def __init__(self, dataset, world_size=None, rank=None, round_up=True):
if world_size is None:
world_size = get_world_size()
if rank is None:
rank = get_rank()
self.dataset = dataset
self.world_size = world_size
self.rank = rank
self.round_up = round_up
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
if self.round_up:
self.total_size = self.num_samples * self.world_size
else:
self.total_size = len(self.dataset)
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = list(torch.randperm(len(self.dataset), generator=g))
# add extra samples to make it evenly divisible
if self.round_up:
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
if self.round_up or (not self.round_up and self.rank < self.world_size - 1):
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
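A minimal sketch of wiring this sampler into a PyTorch DataLoader; the data path is hypothetical, and set_epoch should be called once per epoch so each epoch uses a different deterministic shuffle:
from torch.utils.data import DataLoader
from dizoo.image_classification.data import ImageNetDataset, DistributedSampler
train_set = ImageNetDataset('./data/imagenet/train', is_training=True)  # hypothetical path
sampler = DistributedSampler(train_set)  # world_size/rank come from ding.utils (1/0 when not distributed)
loader = DataLoader(train_set, batch_size=32, sampler=sampler, num_workers=3)
for epoch in range(10):
    sampler.set_epoch(epoch)  # re-seed the per-epoch shuffle
    for img, target in loader:
        pass  # training step goes here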
from easydict import EasyDict
imagenet_res18_config = dict(
exp_name='imagenet_res18',
policy=dict(
cuda=True,
learn=dict(
multi_gpu=True,
bp_update_sync=True,
train_epoch=200,
batch_size=32,
learning_rate=0.01,
decay_epoch=30,
decay_rate=0.1,
warmup_lr=1e-4,
warmup_epoch=3,
weight_decay=1e-4,
learner=dict(
log_show_freq=10,
hook=dict(
                    log_show_after_iter=int(1e9),  # logging is handled by a user-defined hook, so disable the built-in log hook
save_ckpt_after_iter=1000,
)
)
),
collect=dict(
learn_data_path='/mnt/lustre/share/images/train',
eval_data_path='/mnt/lustre/share/images/val',
),
eval=dict(
batch_size=32,
evaluator=dict(
eval_freq=1,
multi_gpu=True,
stop_value=dict(
loss=0.5,
acc1=75.0,
acc5=95.0
)
)
),
),
env=dict(),
)
imagenet_res18_config = EasyDict(imagenet_res18_config)
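Since the config is a plain EasyDict, individual fields can be overridden before it is handed to the training entry; a sketch of a single-GPU variant with hypothetical local data paths:
from copy import deepcopy
cfg = deepcopy(imagenet_res18_config)
cfg.policy.learn.multi_gpu = False
cfg.policy.eval.evaluator.multi_gpu = False
cfg.policy.collect.learn_data_path = './data/imagenet/train'  # hypothetical path
cfg.policy.collect.eval_data_path = './data/imagenet/val'  # hypothetical path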
from typing import Union, Optional, Tuple, List
import time
import os
import torch
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from ding.worker import BaseLearner, LearnerHook, MetricSerialEvaluator, IMetric
from ding.config import read_config, compile_config
from ding.torch_utils import resnet18
from ding.utils import set_pkg_seed, get_rank, dist_init
from dizoo.image_classification.policy import ImageClassificationPolicy
from dizoo.image_classification.data import ImageNetDataset, DistributedSampler
from dizoo.image_classification.entry.imagenet_res18_config import imagenet_res18_config
class ImageClsLogShowHook(LearnerHook):
def __init__(self, *args, freq: int = 1, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._freq = freq
def __call__(self, engine: 'BaseLearner') -> None: # noqa
# Only show log for rank 0 learner
if engine.rank != 0:
for k in engine.log_buffer:
engine.log_buffer[k].clear()
return
# For 'scalar' type variables: log_buffer -> tick_monitor -> monitor_time.step
for k, v in engine.log_buffer['scalar'].items():
setattr(engine.monitor, k, v)
engine.monitor.time.step()
iters = engine.last_iter.val
if iters % self._freq == 0:
# For 'scalar' type variables: tick_monitor -> var_dict -> text_logger & tb_logger
var_dict = {}
log_vars = engine.policy.monitor_vars()
attr = 'avg'
for k in log_vars:
k_attr = k + '_' + attr
var_dict[k_attr] = getattr(engine.monitor, attr)[k]()
# user-defined variable
var_dict['data_time_val'] = engine.data_time
epoch_info = engine.epoch_info
var_dict['epoch_val'] = epoch_info[0]
engine.logger.info(
'Epoch: {} [{:>4d}/{}]\t'
'Loss: {:>6.4f}\t'
'Data Time: {:.3f}\t'
'Forward Time: {:.3f}\t'
'Backward Time: {:.3f}\t'
'GradSync Time: {:.3f}\t'
'LR: {:.3e}'.format(
var_dict['epoch_val'], epoch_info[1], epoch_info[2], var_dict['total_loss_avg'],
var_dict['data_time_val'], var_dict['forward_time_avg'], var_dict['backward_time_avg'],
var_dict['sync_time_avg'], var_dict['cur_lr_avg']
)
)
for k, v in var_dict.items():
engine.tb_logger.add_scalar('{}/'.format(engine.instance_name) + k, v, iters)
# For 'histogram' type variables: log_buffer -> tb_var_dict -> tb_logger
tb_var_dict = {}
for k in engine.log_buffer['histogram']:
new_k = '{}/'.format(engine.instance_name) + k
tb_var_dict[new_k] = engine.log_buffer['histogram'][k]
for k, v in tb_var_dict.items():
engine.tb_logger.add_histogram(k, v, iters)
for k in engine.log_buffer:
engine.log_buffer[k].clear()
class ImageClassificationMetric(IMetric):
def __init__(self) -> None:
self.loss = torch.nn.CrossEntropyLoss()
@staticmethod
def accuracy(inputs: torch.Tensor, label: torch.Tensor, topk: Tuple = (1, 5)) -> dict:
"""Computes the accuracy over the k top predictions for the specified values of k"""
maxk = max(topk)
batch_size = label.size(0)
_, pred = inputs.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(label.reshape(1, -1).expand_as(pred))
return {'acc{}'.format(k): correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk}
def eval(self, inputs: torch.Tensor, label: torch.Tensor) -> dict:
"""
Returns:
- eval_result (:obj:`dict`): {'loss': xxx, 'acc1': xxx, 'acc5': xxx}
"""
loss = self.loss(inputs, label)
output = self.accuracy(inputs, label)
output['loss'] = loss
for k in output:
output[k] = output[k].item()
return output
def reduce_mean(self, inputs: List[dict]) -> dict:
L = len(inputs)
output = {}
for k in inputs[0].keys():
output[k] = sum([t[k] for t in inputs]) / L
return output
def gt(self, metric1: dict, metric2: dict) -> bool:
if metric2 is None:
return True
for k in metric1:
if metric1[k] < metric2[k]:
return False
return True
def main(cfg: dict, seed: int) -> None:
cfg = compile_config(cfg, seed=seed, policy=ImageClassificationPolicy, evaluator=MetricSerialEvaluator)
if cfg.policy.learn.multi_gpu:
rank, world_size = dist_init()
else:
rank, world_size = 0, 1
# Random seed
set_pkg_seed(cfg.seed + rank, use_cuda=cfg.policy.cuda)
model = resnet18()
policy = ImageClassificationPolicy(cfg.policy, model=model, enable_field=['learn', 'eval'])
learn_dataset = ImageNetDataset(cfg.policy.collect.learn_data_path, is_training=True)
eval_dataset = ImageNetDataset(cfg.policy.collect.eval_data_path, is_training=False)
if cfg.policy.learn.multi_gpu:
learn_sampler = DistributedSampler(learn_dataset)
eval_sampler = DistributedSampler(eval_dataset)
else:
learn_sampler, eval_sampler = None, None
learn_dataloader = DataLoader(learn_dataset, cfg.policy.learn.batch_size, sampler=learn_sampler, num_workers=3)
eval_dataloader = DataLoader(eval_dataset, cfg.policy.eval.batch_size, sampler=eval_sampler, num_workers=2)
# Main components
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
log_show_hook = ImageClsLogShowHook(
name='image_cls_log_show_hook', priority=0, position='after_iter', freq=cfg.policy.learn.learner.log_show_freq
)
learner.register_hook(log_show_hook)
eval_metric = ImageClassificationMetric()
evaluator = MetricSerialEvaluator(
cfg.policy.eval.evaluator, [eval_dataloader, eval_metric], policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
# ==========
# Main loop
# ==========
learner.call_hook('before_run')
end = time.time()
for epoch in range(cfg.policy.learn.train_epoch):
# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, epoch, 0)
if stop:
break
for i, train_data in enumerate(learn_dataloader):
learner.data_time = time.time() - end
learner.epoch_info = (epoch, i, len(learn_dataloader))
learner.train(train_data)
end = time.time()
learner.policy.get_attribute('lr_scheduler').step()
learner.call_hook('after_run')
if __name__ == "__main__":
main(imagenet_res18_config, 0)
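For reference, the ImageClassificationMetric defined above can also be exercised on its own; a minimal sketch with random logits and labels (not part of the training entry):
import torch
metric = ImageClassificationMetric()
results = []
for _ in range(4):  # pretend we have 4 evaluation batches
    logits = torch.randn(32, 1000)  # (batch, num_classes)
    labels = torch.randint(0, 1000, (32, ))  # ground-truth class indices
    results.append(metric.eval(logits, labels))
summary = metric.reduce_mean(results)  # {'loss': ..., 'acc1': ..., 'acc5': ...}
print(summary)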
from .policy import ImageClassificationPolicy
import math
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from ding.policy import Policy
from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils import EasyTimer
class ImageClassificationPolicy(Policy):
config = dict(
type='image_classification',
on_policy=False,
)
def _init_learn(self):
self._optimizer = SGD(
self._model.parameters(),
lr=self._cfg.learn.learning_rate,
weight_decay=self._cfg.learn.weight_decay,
momentum=0.9
)
self._timer = EasyTimer(cuda=True)
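        # LambdaLR multiplies the base learning_rate by the returned ratio: a constant
        # warmup_lr for the first warmup_epoch epochs, then a step decay by decay_rate
        # every decay_epoch epochs (the scheduler is stepped once per epoch by the training entry).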
def lr_scheduler_fn(epoch):
if epoch <= self._cfg.learn.warmup_epoch:
return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
else:
ratio = epoch // self._cfg.learn.decay_epoch
return math.pow(self._cfg.learn.decay_rate, ratio)
self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
self._lr_scheduler.step()
self._learn_model = model_wrap(self._model, 'base')
self._learn_model.reset()
self._ce_loss = nn.CrossEntropyLoss()
def _forward_learn(self, data):
if self._cuda:
data = to_device(data, self._device)
self._learn_model.train()
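        # time the forward, backward and (optional) gradient-sync stages separately,
        # so they appear as individual variables in the training log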
with self._timer:
img, target = data
logit = self._learn_model.forward(img)
loss = self._ce_loss(logit, target)
forward_time = self._timer.value
with self._timer:
self._optimizer.zero_grad()
loss.backward()
backward_time = self._timer.value
with self._timer:
if self._cfg.learn.multi_gpu:
self.sync_gradients(self._learn_model)
sync_time = self._timer.value
self._optimizer.step()
cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
cur_lr = sum(cur_lr) / len(cur_lr)
return {
'cur_lr': cur_lr,
'total_loss': loss.item(),
'forward_time': forward_time,
'backward_time': backward_time,
'sync_time': sync_time,
}
def _monitor_vars_learn(self):
return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']
def _init_eval(self):
self._eval_model = model_wrap(self._model, 'base')
def _forward_eval(self, data):
if self._cuda:
data = to_device(data, self._device)
self._eval_model.eval()
with torch.no_grad():
output = self._eval_model.forward(data)
if self._cuda:
output = to_device(output, 'cpu')
return output
def _init_collect(self):
pass
def _forward_collect(self, data):
pass
def _process_transition(self):
pass
def _get_train_sample(self):
pass
......@@ -6,7 +6,7 @@ import torch
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, Episode1v1Collector, OnevOneEvaluator, NaiveReplayBuffer
from ding.worker import BaseLearner, BattleEpisodeSerialCollector, BattleInteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
......@@ -54,8 +54,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
BaseEnvManager,
PPOPolicy,
BaseLearner,
Episode1v1Collector,
OnevOneEvaluator,
BattleEpisodeSerialCollector,
BattleInteractionSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
......@@ -100,7 +100,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
exp_name=cfg.exp_name,
instance_name=player_id + '_learner'
)
collectors[player_id] = Episode1v1Collector(
collectors[player_id] = BattleEpisodeSerialCollector(
cfg.policy.collect.collector,
collector_env,
tb_logger=tb_logger,
......@@ -120,7 +120,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
    # collect_mode ppo uses multinomial sampling for selecting actions
evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator1_cfg.stop_value = cfg.env.stop_value[0]
evaluator1 = OnevOneEvaluator(
evaluator1 = BattleInteractionSerialEvaluator(
evaluator1_cfg,
evaluator_env1, [policies[main_key].collect_mode, eval_policy1],
tb_logger,
......@@ -129,7 +129,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
)
evaluator2_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator2_cfg.stop_value = cfg.env.stop_value[1]
evaluator2 = OnevOneEvaluator(
evaluator2 = BattleInteractionSerialEvaluator(
evaluator2_cfg,
evaluator_env2, [policies[main_key].collect_mode, eval_policy2],
tb_logger,
......@@ -138,7 +138,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
)
evaluator3_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator3_cfg.stop_value = 99999999 # stop_value of evaluator3 is a placeholder
evaluator3 = OnevOneEvaluator(
evaluator3 = BattleInteractionSerialEvaluator(
evaluator3_cfg,
evaluator_env3, [policies[main_key].collect_mode, eval_policy3],
tb_logger,
......
......@@ -6,7 +6,7 @@ import torch
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.worker import BaseLearner, Episode1v1Collector, OnevOneEvaluator, NaiveReplayBuffer
from ding.worker import BaseLearner, BattleEpisodeSerialCollector, BattleInteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
......@@ -45,8 +45,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
BaseEnvManager,
PPOPolicy,
BaseLearner,
Episode1v1Collector,
OnevOneEvaluator,
BattleEpisodeSerialCollector,
BattleInteractionSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
......@@ -81,7 +81,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
learner2 = BaseLearner(
cfg.policy.learn.learner, policy2.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner2'
)
collector = Episode1v1Collector(
collector = BattleEpisodeSerialCollector(
cfg.policy.collect.collector,
collector_env, [policy1.collect_mode, policy2.collect_mode],
tb_logger,
......@@ -90,7 +90,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
    # collect_mode ppo uses multinomial sampling for selecting actions
evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator1_cfg.stop_value = cfg.env.stop_value[0]
evaluator1 = OnevOneEvaluator(
evaluator1 = BattleInteractionSerialEvaluator(
evaluator1_cfg,
evaluator_env1, [policy1.collect_mode, eval_policy1],
tb_logger,
......@@ -99,7 +99,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
)
evaluator2_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
evaluator2_cfg.stop_value = cfg.env.stop_value[1]
evaluator2 = OnevOneEvaluator(
evaluator2 = BattleInteractionSerialEvaluator(
evaluator2_cfg,
evaluator_env2, [policy1.collect_mode, eval_policy2],
tb_logger,
......
......@@ -4,7 +4,7 @@ from tensorboardX import SummaryWriter
from easydict import EasyDict
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
......@@ -21,8 +21,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
BaseEnvManager,
PPOPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
SampleSerialCollector,
InteractionSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
......@@ -42,8 +42,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
policy = PPOPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
collector = SampleCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = BaseSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
......
......@@ -4,7 +4,7 @@ from tensorboardX import SummaryWriter
from easydict import EasyDict
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, AdvancedReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import SyncSubprocessEnvManager
from ding.policy import QMIXPolicy
from ding.model import QMix
......@@ -24,8 +24,8 @@ def main(cfg, seed=0):
SyncSubprocessEnvManager,
QMIXPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
SampleSerialCollector,
InteractionSerialEvaluator,
AdvancedReplayBuffer,
save_cfg=True
)
......@@ -45,8 +45,8 @@ def main(cfg, seed=0):
policy = QMIXPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
collector = SampleCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = BaseSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger)
eps_cfg = cfg.policy.other.eps
......
......@@ -4,7 +4,7 @@ from tensorboardX import SummaryWriter
from easydict import EasyDict
from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
......@@ -21,8 +21,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
BaseEnvManager,
PPOPolicy,
BaseLearner,
SampleCollector,
BaseSerialEvaluator,
SampleSerialCollector,
InteractionSerialEvaluator,
NaiveReplayBuffer,
save_cfg=True
)
......@@ -42,8 +42,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
policy = PPOPolicy(cfg.policy, model=model)
tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
collector = SampleCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = BaseSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
......