diff --git a/ding/utils/data/tests/test_dataset.py b/ding/utils/data/tests/test_dataset.py
index facc5ce3f3b6767089e42729acc4a02b3c1774ff..747fe687184b56dc711d570532180a8cdf4476bd 100644
--- a/ding/utils/data/tests/test_dataset.py
+++ b/ding/utils/data/tests/test_dataset.py
@@ -1,11 +1,8 @@
 import pytest
-import threading
-import time
 import torch
-import torch.nn as nn
 from easydict import EasyDict
 
-from ding.utils.data import offline_data_save_type, create_dataset
+from ding.utils.data import offline_data_save_type, create_dataset, NaiveRLDataset, D4RLDataset, HDF5Dataset
 
 cfg1 = dict(policy=dict(collect=dict(
     data_type='naive',
@@ -46,3 +43,24 @@ def test_offline_data_save_type(data_type):
 def test_dataset(cfg):
     cfg = EasyDict(cfg)
     create_dataset(cfg)
+
+
+@pytest.mark.parametrize('cfg', [cfg1])
+@pytest.mark.unittest
+def test_NaiveRLDataset(cfg):
+    cfg = EasyDict(cfg)
+    dataset = NaiveRLDataset(cfg)
+
+
+# @pytest.mark.parametrize('cfg', [cfg3])
+# @pytest.mark.unittest
+# def test_D4RLDataset(cfg):
+#     cfg = EasyDict(cfg)
+#     dataset = D4RLDataset(cfg)
+
+
+@pytest.mark.parametrize('cfg', [cfg2])
+@pytest.mark.unittest
+def test_HDF5Dataset(cfg):
+    cfg = EasyDict(cfg)
+    dataset = HDF5Dataset(cfg)
diff --git a/ding/worker/adapter/tests/test_learner_aggregator.py b/ding/worker/adapter/tests/test_learner_aggregator.py
index 048b922ad413897c14b1bf18441e41adda36060b..511d8de7ef876b3473fb06cf030085766f1d371f 100644
--- a/ding/worker/adapter/tests/test_learner_aggregator.py
+++ b/ding/worker/adapter/tests/test_learner_aggregator.py
@@ -93,6 +93,13 @@ def test_learner_aggregator():
     assert task.result['a_list'] == [1, 2] * 4
     assert task.status == TaskStatus.COMPLETED
 
+    task = conn.new_task({'name': 'fake_task', 'task_info': {}})
+    task.start().join()
+    assert task.status == TaskStatus.FAILED
+    assert task.result == {'message': 'task name error'}
+
+    assert learner_aggregator.deal_with_get_resource()['gpu'] == len(learner_slaves)
+
     learner_aggregator.close()
     for learner_slave in learner_slaves:
         learner_slave.close()
diff --git a/ding/worker/collector/metric_serial_evaluator.py b/ding/worker/collector/metric_serial_evaluator.py
index 6fb190ce2fb59c9fc1248033068144b4038b8925..4313e745e2180fbc1e0f4b993fc8f188d03eb0a9 100644
--- a/ding/worker/collector/metric_serial_evaluator.py
+++ b/ding/worker/collector/metric_serial_evaluator.py
@@ -135,7 +135,6 @@ class MetricSerialEvaluator(ISerialEvaluator):
         if self._end_flag:
             return
         self._end_flag = True
-        self._env.close()
         self._tb_logger.flush()
         self._tb_logger.close()
 
diff --git a/ding/worker/collector/tests/fake_cls_policy.py b/ding/worker/collector/tests/fake_cls_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bbebc0fd6496b4a0c59a26c2973c18500ff113a
--- /dev/null
+++ b/ding/worker/collector/tests/fake_cls_policy.py
@@ -0,0 +1,34 @@
+from ding.policy import Policy
+from ding.model import model_wrap
+
+
+class fake_policy(Policy):
+
+    def _init_learn(self):
+        pass
+
+    def _forward_learn(self, data):
+        pass
+
+    def _init_eval(self):
+        self._eval_model = model_wrap(self._model, 'base')
+
+    def _forward_eval(self, data):
+        self._eval_model.eval()
+        output = self._eval_model.forward(data)
+        return output
+
+    def _monitor_vars_learn(self):
+        return ['forward_time', 'backward_time', 'sync_time']
+
+    def _init_collect(self):
+        pass
+
+    def _forward_collect(self, data):
+        pass
+
+    def _process_transition(self):
+        pass
+
+    def _get_train_sample(self):
+        pass
diff --git a/ding/worker/collector/tests/test_metric_serial_evaluator.py b/ding/worker/collector/tests/test_metric_serial_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2652f0ace4bda3c1f57d938af6e75a5dc9ed8787
--- /dev/null
+++ b/ding/worker/collector/tests/test_metric_serial_evaluator.py
@@ -0,0 +1,102 @@
+from ding.worker import MetricSerialEvaluator, IMetric
+from torch.utils.data import DataLoader
+import pytest
+import torch.utils.data as data
+
+import torch.nn as nn
+from ding.torch_utils import to_tensor
+import torch
+from easydict import EasyDict
+from ding.worker.collector.tests.fake_cls_policy import fake_policy
+
+fake_cls_config = dict(
+    exp_name='fake_config_for_test_metric_serial_evaluator',
+    policy=dict(
+        on_policy=False,
+        cuda=False,
+        eval=dict(batch_size=1, evaluator=dict(eval_freq=1, multi_gpu=False, stop_value=dict(acc=75.0))),
+    ),
+    env=dict(),
+)
+
+cfg = EasyDict(fake_cls_config)
+
+
+class fake_eval_dataset(data.Dataset):
+
+    def __init__(self) -> None:
+        self.data = [i for i in range(5)]  # [0, 1, 2, 3, 4]
+        self.target = [2 * i + 1 for i in range(5)]  # [1, 3, 5, 7, 9]
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def __getitem__(self, index: int):
+        data = self.data[index]
+        target = self.target[index]
+        return data, target
+
+
+class fake_model(nn.Module):  # y = 2*x+1
+
+    def __init__(self) -> None:
+        super(fake_model, self).__init__()
+        self.linear = nn.Linear(1, 1)
+        nn.init.constant_(self.linear.bias, 1)
+        nn.init.constant_(self.linear.weight, 2)
+
+    def forward(self, x):
+        x = to_tensor(x).float()
+        return self.linear(x)
+
+
+class fake_ClassificationMetric(IMetric):
+
+    @staticmethod
+    def accuracy(inputs: torch.Tensor, label: torch.Tensor) -> dict:
+        batch_size = label.size(0)
+        correct = inputs.eq(label)
+        return {'acc': correct.reshape(-1).float().sum(0) * 100. / batch_size}
+
+    def eval(self, inputs: torch.Tensor, label: torch.Tensor) -> dict:
+        output = self.accuracy(inputs, label)
+        for k in output:
+            output[k] = output[k].item()
+        return output
+
+    def reduce_mean(self, inputs) -> dict:
+        L = len(inputs)
+        output = {}
+        for k in inputs[0].keys():
+            output[k] = sum([t[k] for t in inputs]) / L
+        return output
+
+    def gt(self, metric1: dict, metric2: dict) -> bool:
+        if metric2 is None:
+            return True
+        for k in metric1:
+            if metric1[k] < metric2[k]:
+                return False
+        return True
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('cfg', [cfg])
+def test_evaluator(cfg):
+    model = fake_model()
+    eval_dataset = fake_eval_dataset()
+    eval_dataloader = DataLoader(eval_dataset, cfg.policy.eval.batch_size, num_workers=2)
+    policy = fake_policy(cfg.policy, model=model, enable_field=['eval'])
+    eval_metric = fake_ClassificationMetric()
+    evaluator = MetricSerialEvaluator(
+        cfg.policy.eval.evaluator, [eval_dataloader, eval_metric], policy.eval_mode, exp_name=cfg.exp_name
+    )
+
+    cur_iter = 0
+    assert evaluator.should_eval(cur_iter)
+
+    evaluator._last_eval_iter = 0
+    cur_iter = 1
+    stop, reward = evaluator.eval(None, cur_iter, 0)
+    assert stop
+    assert reward['acc'] == 100
diff --git a/ding/worker/learner/tests/test_learner_hook.py b/ding/worker/learner/tests/test_learner_hook.py
index ff49f8bfbebdf613251f7ce191a58dac30cf0b80..bf29f91a382544d5cb95092f088dd6907e9d3599 100644
--- a/ding/worker/learner/tests/test_learner_hook.py
+++ b/ding/worker/learner/tests/test_learner_hook.py
@@ -1,6 +1,9 @@
+import easydict
 import pytest
 from ding.worker.learner import register_learner_hook, build_learner_hook_by_cfg, LearnerHook
-from ding.worker.learner.learner_hook import SaveCkptHook, show_hooks
+from ding.worker.learner.learner_hook import SaveCkptHook, LoadCkptHook, LogShowHook, LogReduceHook
+from ding.worker.learner.learner_hook import show_hooks, add_learner_hook, merge_hooks
+from easydict import EasyDict
 
 
 @pytest.fixture(scope='function')
@@ -11,6 +14,14 @@ def setup_simplified_hook_cfg():
     )
 
 
+@pytest.fixture(scope='function')
+def fake_setup_simplified_hook_cfg():
+    return dict(
+        log_show_after_iter=20,
+        log_reduce_after_iter=True,
+    )
+
+
 @pytest.mark.unittest
 class TestLearnerHook:
 
@@ -32,3 +43,33 @@ class TestLearnerHook:
         assert isinstance(hooks['after_iter'][0], SaveCkptHook)
         assert len(hooks['after_run']) == 1
         assert isinstance(hooks['after_run'][0], SaveCkptHook)
+
+    def test_add_learner_hook(self, setup_simplified_hook_cfg):
+        hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
+        hook_1 = LogShowHook('log_show', 20, position='after_iter', ext_args=EasyDict({'freq': 100}))
+        add_learner_hook(hooks, hook_1)
+        hook_2 = LoadCkptHook('load_ckpt', 20, position='before_run', ext_args=EasyDict({'load_path': './model.pth'}))
+        add_learner_hook(hooks, hook_2)
+        hook_3 = LogReduceHook('log_reduce', 10, position='after_iter')
+        add_learner_hook(hooks, hook_3)
+
+        show_hooks(hooks)
+        assert len(hooks['after_iter']) == 3
+        assert len(hooks['after_run']) == 1
+        assert len(hooks['before_run']) == 1
+        assert len(hooks['before_iter']) == 0
+        assert isinstance(hooks['after_run'][0], SaveCkptHook)
+        assert isinstance(hooks['before_run'][0], LoadCkptHook)
+
+    def test_merge_hooks(self, setup_simplified_hook_cfg, fake_setup_simplified_hook_cfg):
+        hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
+        show_hooks(hooks)
+        fake_hooks = build_learner_hook_by_cfg(fake_setup_simplified_hook_cfg)
+        show_hooks(fake_hooks)
+        hooks_ = merge_hooks(hooks, fake_hooks)
+        show_hooks(hooks_)
+        assert len(hooks_['after_iter']) == 3
+        assert len(hooks_['after_run']) == 1
+        assert len(hooks_['before_run']) == 0
+        assert len(hooks_['before_iter']) == 0
+        assert isinstance(hooks['after_run'][0], SaveCkptHook)