Unverified commit 0414eda5, authored by J jayyoung0802, committed by GitHub

test(yzj): add unittest for dataset, metric_serial_evaluator and learner (#107)

* add 4 pytest test suites: dataset.py, learner_aggregator.py, learner_hook.py, metric_serial_evaluator.py

* fix yapf and flake8 issues and remove the invalid self._env

* fix fake_cls_config.py flake8
Parent c5af1cf2
import pytest
import threading
import time
import torch
import torch.nn as nn
from easydict import EasyDict
from ding.utils.data import offline_data_save_type, create_dataset, NaiveRLDataset, D4RLDataset, HDF5Dataset
cfg1 = dict(policy=dict(collect=dict(
data_type='naive',
@@ -46,3 +43,24 @@ def test_offline_data_save_type(data_type):
def test_dataset(cfg):
cfg = EasyDict(cfg)
create_dataset(cfg)
@pytest.mark.parametrize('cfg', [cfg1])
@pytest.mark.unittest
def test_NaiveRLDataset(cfg):
cfg = EasyDict(cfg)
dataset = NaiveRLDataset(cfg)
# @pytest.mark.parametrize('cfg', [cfg3])
# @pytest.mark.unittest
# def test_D4RLDataset(cfg):
# cfg = EasyDict(cfg)
# dataset = D4RLDataset(cfg)
@pytest.mark.parametrize('cfg', [cfg2])
@pytest.mark.unittest
def test_HDF5Dataset(cfg):
cfg = EasyDict(cfg)
dataset = HDF5Dataset(cfg)
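# Illustrative sketch, not part of this commit: HDF5Dataset is a map-style
# torch dataset, so it can be batched with a standard DataLoader. Assumes the
# given cfg's data_path points at a real .hdf5 file written by
# offline_data_save_type.
def _example_hdf5_dataloader(cfg):
    from torch.utils.data import DataLoader
    dataset = HDF5Dataset(EasyDict(cfg))
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    return next(iter(loader))  # one batch of stacked transition tensors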
@@ -93,6 +93,13 @@ def test_learner_aggregator():
assert task.result['a_list'] == [1, 2] * 4
assert task.status == TaskStatus.COMPLETED
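    # an unrecognized task name should fail gracefully with an error message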
task = conn.new_task({'name': 'fake_task', 'task_info': {}})
task.start().join()
assert task.status == TaskStatus.FAILED
assert task.result == {'message': 'task name error'}
assert learner_aggregator.deal_with_get_resource()['gpu'] == len(learner_slaves)
learner_aggregator.close()
for learner_slave in learner_slaves:
learner_slave.close()
@@ -135,7 +135,6 @@ class MetricSerialEvaluator(ISerialEvaluator):
if self._end_flag:
return
self._end_flag = True
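        # this evaluator pulls data from a dataloader and owns no environment, so only the tb logger needs cleanup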
self._tb_logger.flush()
self._tb_logger.close()
from ding.policy import Policy
from ding.model import model_wrap
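# fake_policy stubs the abstract Policy interface with no-ops; only the eval
# path is functional, wrapping the model with the plain 'base' wrapper so the
# evaluator can call eval()/forward() on it.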
class fake_policy(Policy):
def _init_learn(self):
pass
def _forward_learn(self, data):
pass
def _init_eval(self):
self._eval_model = model_wrap(self._model, 'base')
def _forward_eval(self, data):
self._eval_model.eval()
output = self._eval_model.forward(data)
return output
def _monitor_vars_learn(self):
return ['forward_time', 'backward_time', 'sync_time']
def _init_collect(self):
pass
def _forward_collect(self, data):
pass
def _process_transition(self):
pass
def _get_train_sample(self):
pass
from ding.worker import MetricSerialEvaluator, IMetric
from torch.utils.data import DataLoader
import pytest
import torch.utils.data as data
import torch.nn as nn
from ding.torch_utils import to_tensor
import torch
from easydict import EasyDict
from ding.worker.collector.tests.fake_cls_policy import fake_policy
fake_cls_config = dict(
exp_name='fake_config_for_test_metric_serial_evaluator',
policy=dict(
on_policy=False,
cuda=False,
eval=dict(batch_size=1, evaluator=dict(eval_freq=1, multi_gpu=False, stop_value=dict(acc=75.0))),
),
env=dict(),
)
cfg = EasyDict(fake_cls_config)
class fake_eval_dataset(data.Dataset):
def __init__(self) -> None:
        self.data = [i for i in range(5)]  # [0, 1, 2, 3, 4]
        self.target = [2 * i + 1 for i in range(5)]  # [1, 3, 5, 7, 9]
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, index: int):
data = self.data[index]
target = self.target[index]
return data, target
class fake_model(nn.Module): # y = 2*x+1
def __init__(self) -> None:
super(fake_model, self).__init__()
self.linear = nn.Linear(1, 1)
nn.init.constant_(self.linear.bias, 1)
nn.init.constant_(self.linear.weight, 2)
def forward(self, x):
x = to_tensor(x).float()
return self.linear(x)
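# Illustrative sanity check, not part of this commit: the constant-initialized
# layer computes y = 2 * x + 1 exactly, which is why eval accuracy reaches 100
# in test_evaluator below.
def _check_fake_model():
    m = fake_model()
    assert torch.allclose(m(torch.tensor([3.0])), torch.tensor([7.0]))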
class fake_ClassificationMetric(IMetric):
@staticmethod
def accuracy(inputs: torch.Tensor, label: torch.Tensor) -> dict:
batch_size = label.size(0)
correct = inputs.eq(label)
return {'acc': correct.reshape(-1).float().sum(0) * 100. / batch_size}
def eval(self, inputs: torch.Tensor, label: torch.Tensor) -> dict:
output = self.accuracy(inputs, label)
for k in output:
output[k] = output[k].item()
return output
def reduce_mean(self, inputs) -> dict:
L = len(inputs)
output = {}
for k in inputs[0].keys():
output[k] = sum([t[k] for t in inputs]) / L
return output
def gt(self, metric1: dict, metric2: dict) -> bool:
if metric2 is None:
return True
for k in metric1:
if metric1[k] < metric2[k]:
return False
return True
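# Illustrative usage, not part of this commit: reduce_mean averages a list of
# per-batch metric dicts, and gt compares two metric dicts key by key (a None
# baseline always loses).
def _check_fake_metric():
    metric = fake_ClassificationMetric()
    assert metric.reduce_mean([{'acc': 100.0}, {'acc': 50.0}]) == {'acc': 75.0}
    assert metric.gt({'acc': 80.0}, {'acc': 75.0})
    assert metric.gt({'acc': 80.0}, None)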
@pytest.mark.unittest
@pytest.mark.parametrize('cfg', [cfg])
def test_evaluator(cfg):
model = fake_model()
eval_dataset = fake_eval_dataset()
eval_dataloader = DataLoader(eval_dataset, cfg.policy.eval.batch_size, num_workers=2)
policy = fake_policy(cfg.policy, model=model, enable_field=['eval'])
eval_metric = fake_ClassificationMetric()
evaluator = MetricSerialEvaluator(
cfg.policy.eval.evaluator, [eval_dataloader, eval_metric], policy.eval_mode, exp_name=cfg.exp_name
)
cur_iter = 0
assert evaluator.should_eval(cur_iter)
evaluator._last_eval_iter = 0
cur_iter = 1
stop, reward = evaluator.eval(None, cur_iter, 0)
assert stop
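    # the model reproduces y = 2 * x + 1 exactly, so every prediction equals
    # its target: accuracy is 100, which exceeds the stop_value of acc=75.0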
assert reward['acc'] == 100
import easydict
import pytest
from ding.worker.learner import register_learner_hook, build_learner_hook_by_cfg, LearnerHook
from ding.worker.learner.learner_hook import SaveCkptHook, LoadCkptHook, LogShowHook, LogReduceHook
from ding.worker.learner.learner_hook import show_hooks, add_learner_hook, merge_hooks
from easydict import EasyDict
@pytest.fixture(scope='function')
@@ -11,6 +14,14 @@ def setup_simplified_hook_cfg():
)
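# this cfg builds two after_iter hooks (LogShowHook and LogReduceHook), which
# test_merge_hooks below expects to merge alongside the base SaveCkptHook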
@pytest.fixture(scope='function')
def fake_setup_simplified_hook_cfg():
return dict(
log_show_after_iter=20,
log_reduce_after_iter=True,
)
@pytest.mark.unittest
class TestLearnerHook:
@@ -32,3 +43,33 @@ class TestLearnerHook:
assert isinstance(hooks['after_iter'][0], SaveCkptHook)
assert len(hooks['after_run']) == 1
assert isinstance(hooks['after_run'][0], SaveCkptHook)
def test_add_learner_hook(self, setup_simplified_hook_cfg):
hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
hook_1 = LogShowHook('log_show', 20, position='after_iter', ext_args=EasyDict({'freq': 100}))
add_learner_hook(hooks, hook_1)
hook_2 = LoadCkptHook('load_ckpt', 20, position='before_run', ext_args=EasyDict({'load_path': './model.pth'}))
add_learner_hook(hooks, hook_2)
hook_3 = LogReduceHook('log_reduce', 10, position='after_iter')
add_learner_hook(hooks, hook_3)
show_hooks(hooks)
assert len(hooks['after_iter']) == 3
assert len(hooks['after_run']) == 1
assert len(hooks['before_run']) == 1
assert len(hooks['before_iter']) == 0
assert isinstance(hooks['after_run'][0], SaveCkptHook)
assert isinstance(hooks['before_run'][0], LoadCkptHook)
def test_merge_hooks(self, setup_simplified_hook_cfg, fake_setup_simplified_hook_cfg):
hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
show_hooks(hooks)
fake_hooks = build_learner_hook_by_cfg(fake_setup_simplified_hook_cfg)
show_hooks(fake_hooks)
hooks_ = merge_hooks(hooks, fake_hooks)
show_hooks(hooks_)
assert len(hooks_['after_iter']) == 3
assert len(hooks_['after_run']) == 1
assert len(hooks_['before_run']) == 0
assert len(hooks_['before_iter']) == 0
assert isinstance(hooks['after_run'][0], SaveCkptHook)