Commit ad394fc5 authored by niuyazhe

test(nyz): add test for ding/utils and remove DistributionTimeImage

Parent 1568e53d
@@ -5,6 +5,9 @@ omit =
ding/utils/linklink_dist_helper.py
ding/utils/pytorch_ddp_dist_helper.py
ding/utils/k8s_helper.py
ding/utils/time_helper_cuda.py
ding/utils/time_helper_base.py
ding/utils/data/tests/test_dataloader.py
ding/config/utils.py
ding/entry/tests/test_serial_entry_algo.py
ding/entry/tests/test_serial_entry.py
......
@@ -13,7 +13,7 @@ from .k8s_helper import get_operator_server_kwargs, exist_operator_server, DEFAU
K8sLauncher
from .orchestrator_launcher import OrchestratorLauncher
from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock
from .log_helper import build_logger, DistributionTimeImage, pretty_print, LoggerFactory
from .log_helper import build_logger, pretty_print, LoggerFactory
from .registry_factory import registries, POLICY_REGISTRY, ENV_REGISTRY, LEARNER_REGISTRY, COMM_LEARNER_REGISTRY, \
SERIAL_COLLECTOR_REGISTRY, PARALLEL_COLLECTOR_REGISTRY, COMM_COLLECTOR_REGISTRY, \
COMMANDER_REGISTRY, LEAGUE_REGISTRY, PLAYER_REGISTRY, MODEL_REGISTRY, \
......
@@ -19,7 +19,7 @@ cfg2 = dict(policy=dict(collect=dict(
cfg3 = dict(env=dict(env_id='hopper-expert-v0'), policy=dict(collect=dict(data_type='d4rl', ), ))
cfgs = [cfg1, cfg2, cfg3]
cfgs = [cfg1, cfg2] # cfg3
unittest_args = ['naive', 'hdf5']
# fake transition & data
@@ -36,11 +36,13 @@ expert_data_path = './expert.pkl'
@pytest.mark.parametrize('data_type', unittest_args)
@pytest.mark.unittest
def test_offline_data_save_type(data_type):
offline_data_save_type(exp_data=fake_data, expert_data_path=expert_data_path, data_type=data_type)
@pytest.mark.parametrize('cfg', cfgs)
@pytest.mark.unittest
def test_dataset(cfg):
cfg = EasyDict(cfg)
create_dataset(cfg)
@@ -102,63 +102,6 @@ class LoggerFactory(object):
return s
class DistributionTimeImage:
r"""
Overview:
``DistributionTimeImage`` can be used to store images according to ``time_steps``,
for data with 3 dims ``(time, category, value)``
Interface:
``__init__``, ``add_one_time_step``, ``get_image``
"""
def __init__(self, maxlen: int = 600, val_range: Optional[dict] = None):
r"""
Overview:
Init the ``DistributionTimeImage`` class
Arguments:
- maxlen (:obj:`int`): The max length of data inputs
- val_range (:obj:`dict` or :obj:`None`): Dict with ``val_range['min']`` and ``val_range['max']``.
"""
self.maxlen = maxlen
self.val_range = val_range
self.img = np.ones((maxlen, maxlen))
self.time_step = 0
self.one_img = np.ones((maxlen, maxlen))
def add_one_time_step(self, data: np.ndarray) -> None:
r"""
Overview:
Advance one time step in ``DistributionTimeImage`` and add the data to the distribution image
Arguments:
- data (:obj:`np.ndarray`): The data input
"""
assert (isinstance(data, np.ndarray))
data = np.expand_dims(data, 1)
data = np.resize(data, (1, self.maxlen))
if self.time_step >= self.maxlen:
self.img = np.concatenate([self.img[:, 1:], data])
else:
self.img[:, self.time_step:self.time_step + 1] = data
self.time_step += 1
def get_image(self) -> np.ndarray:
r"""
Overview:
Return the distribution image
Returns:
- img (:obj:`np.ndarray`): The calculated distribution image
"""
norm_img = np.copy(self.img)
valid = norm_img[:, :self.time_step]
if self.val_range is None:
valid = (valid - valid.min()) / (valid.max() - valid.min())
else:
valid = np.clip(valid, self.val_range['min'], self.val_range['max'])
valid = (valid - self.val_range['min']) / (self.val_range['max'] - self.val_range['min'])
norm_img[:, :self.time_step] = valid
return np.stack([self.one_img, norm_img, norm_img], axis=0)
def pretty_print(result: dict, direct_print: bool = True) -> str:
r"""
Overview:
......
@@ -3,8 +3,9 @@ import numpy as np
import torch
from collections import namedtuple
from ding.utils.default_helper import lists_to_dicts, dicts_to_lists, squeeze, default_get, override, error_wrapper,\
list_split, LimitedSpaceContainer, set_pkg_seed, deep_merge_dicts, deep_update, flatten_dict
from ding.utils.default_helper import lists_to_dicts, dicts_to_lists, squeeze, default_get, override, error_wrapper, \
list_split, LimitedSpaceContainer, set_pkg_seed, deep_merge_dicts, deep_update, flatten_dict, RunningMeanStd, \
one_time_warning, split_data_generator
@pytest.mark.unittest
@@ -84,6 +85,7 @@ class TestDefaultHelper():
wrap_bad_ret = error_wrapper(bad_ret, 0)
assert wrap_bad_ret(1) == 0
wrap_bad_ret_with_customized_log = error_wrapper(bad_ret, 0, 'customized_information')
def test_list_split(self):
data = [i for i in range(10)]
@@ -213,3 +215,60 @@ class TestDict:
assert flat['b/d/e'] == 6
assert flat['b/d/f'] == 5
assert flat['b/z'] == 4
def test_one_time_warning(self):
one_time_warning('test_one_time_warning')
def test_running_mean_std(self):
running = RunningMeanStd()
running.reset()
running.update(np.arange(1, 10))
assert running.mean == pytest.approx(5, abs=1e-4)
assert running.std == pytest.approx(2.582030, abs=1e-6)
running.update(np.arange(2, 11))
assert running.mean == pytest.approx(5.5, abs=1e-4)
assert running.std == pytest.approx(2.629981, abs=1e-6)
running.reset()
running.update(np.arange(1, 10))
assert running.mean == pytest.approx(5, abs=1e-4)
assert running.std == pytest.approx(2.582030, abs=1e-6)
new_shape = running.new_shape((2, 4), (3, ), (1, ))
assert isinstance(new_shape, tuple) and len(new_shape) == 3
running = RunningMeanStd(shape=(4, ))
running.reset()
running.update(np.random.random((10, 4)))
assert isinstance(running.mean, torch.Tensor) and running.mean.shape == (4, )
assert isinstance(running.std, torch.Tensor) and running.std.shape == (4, )
def test_split_data_generator(self):
def get_data():
return {
'obs': torch.randn(5),
'action': torch.randint(0, 10, size=(1, )),
'prev_state': [None, None],
'info': {
'other_obs': torch.randn(5)
},
}
data = [get_data() for _ in range(4)]
data = lists_to_dicts(data)
data['obs'] = torch.stack(data['obs'])
data['action'] = torch.stack(data['action'])
data['info'] = {'other_obs': torch.stack([t['other_obs'] for t in data['info']])}
assert len(data['obs']) == 4
data['NoneKey'] = None
generator = split_data_generator(data, 3)
generator_result = list(generator)
assert len(generator_result) == 2
assert generator_result[0]['NoneKey'] is None
assert len(generator_result[0]['obs']) == 3
assert generator_result[0]['info']['other_obs'].shape == (3, 5)
assert generator_result[1]['NoneKey'] is None
assert len(generator_result[1]['obs']) == 3
assert generator_result[1]['info']['other_obs'].shape == (3, 5)
generator = split_data_generator(data, 3, shuffle=False)
import pytest
import ding
from ding.utils.import_helper import try_import_ceph, try_import_mc, try_import_redis, try_import_rediscluster, \
try_import_link, import_module
@@ -12,3 +13,5 @@ def test_try_import():
try_import_rediscluster()
try_import_link()
import_module(['ding.utils'])
ding.enable_linklink = True
try_import_link()
import pytest
import numpy as np
import time
from ding.utils.time_helper import build_time_helper, WatchDog
from ding.utils.time_helper import build_time_helper, WatchDog, TimeWrapperTime, EasyTimer
@pytest.mark.unittest
@@ -17,10 +17,12 @@ class TestTimeHelper:
setattr(cfg.common, 'time_wrapper_type', 'time')
with pytest.raises(RuntimeError):
time_handle = build_time_helper()
build_time_helper(cfg=None, wrapper_type="??")
# with pytest.raises(KeyError):
# build_time_helper(cfg=None,wrapper_type="not_implement")
with pytest.raises(KeyError):
build_time_helper(cfg=None, wrapper_type="not_implement")
time_handle = build_time_helper(cfg)
time_handle = build_time_helper(wrapper_type='cuda')
# wrapper_type='cuda' but cuda is not available
assert issubclass(time_handle, TimeWrapperTime)
time_handle = build_time_helper(wrapper_type='time')
@time_handle.wrapper
@@ -52,14 +54,21 @@ class TestTimeHelper:
# assert abs(t-1) < 1e-3
assert abs(t - 1) < 1e-2
timer = EasyTimer()
with timer:
tmp = np.random.random(size=(4, 100))
tmp = tmp ** 2
value = timer.value
assert isinstance(value, float)
@pytest.mark.unittest
class TestWatchDog:
def test_naive(self):
watchdog = WatchDog(5)
watchdog = WatchDog(3)
watchdog.start()
time.sleep(4)
time.sleep(2)
with pytest.raises(TimeoutError):
time.sleep(4)
time.sleep(4)
time.sleep(2)
watchdog.stop()
@@ -4,6 +4,8 @@ from typing import Any, Callable
import torch
from easydict import EasyDict
from .time_helper_base import TimeWrapper
from .time_helper_cuda import get_cuda_time_wrapper
def build_time_helper(cfg: EasyDict = None, wrapper_type: str = None) -> Callable[[], 'TimeWrapper']:
@@ -31,11 +33,14 @@ def build_time_helper(cfg: EasyDict = None, wrapper_type: str = None) -> Callabl
else:
raise RuntimeError('Either wrapper_type or cfg should be provided.')
if time_wrapper_type == 'time' or (not torch.cuda.is_available()):
if time_wrapper_type == 'time':
return TimeWrapperTime
elif time_wrapper_type == 'cuda':
# lazy initialize to make code runnable locally
return get_cuda_time_wrapper()
if torch.cuda.is_available():
# lazy initialize to make code runnable locally
return get_cuda_time_wrapper()
else:
return TimeWrapperTime
else:
raise KeyError('invalid time_wrapper_type: {}'.format(time_wrapper_type))
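A minimal usage sketch of the new fallback path (assuming a host without a CUDA device; ``build_time_helper`` and the ``wrapper`` decorator come from the diff above, while the toy function is hypothetical):

```python
from ding.utils.time_helper import build_time_helper

# With this change, requesting the CUDA wrapper on a CPU-only machine
# falls back to the wall-clock wrapper instead of raising.
time_handle = build_time_helper(wrapper_type='cuda')


@time_handle.wrapper
def dummy_op(n):
    # hypothetical workload, only used to have something to time
    return sum(i * i for i in range(n))


result, elapsed = dummy_op(100000)  # the wrapper returns (return_value, elapsed_seconds)
```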
@@ -86,49 +91,6 @@ class EasyTimer:
self.value = self._timer.end_time()
class TimeWrapper(object):
r"""
Overview:
Abstract base class that defines the ``TimeWrapper`` timing interface
Interface:
``wrapper``, ``start_time``, ``end_time``
"""
@classmethod
def wrapper(cls, fn):
r"""
Overview:
Classmethod wrapper, wrap a function and automatically return its result together with its running time
Arguments:
- fn (:obj:`function`): The function to be wrapped and timed
"""
def time_func(*args, **kwargs):
cls.start_time()
ret = fn(*args, **kwargs)
t = cls.end_time()
return ret, t
return time_func
@classmethod
def start_time(cls):
r"""
Overview:
Abstract classmethod, start timing
"""
raise NotImplementedError
@classmethod
def end_time(cls):
r"""
Overview:
Abstract classmethod, stop timing
"""
raise NotImplementedError
class TimeWrapperTime(TimeWrapper):
r"""
Overview:
@@ -161,62 +123,6 @@ class TimeWrapperTime(TimeWrapper):
return cls.end - cls.start
def get_cuda_time_wrapper() -> Callable[[], 'TimeWrapper']:
r"""
Overview:
Return the ``TimeWrapperCuda`` class
Returns:
- TimeWrapperCuda(:obj:`class`): See ``TimeWrapperCuda`` class
.. note::
Must use ``torch.cuda.synchronize()``, reference: <https://blog.csdn.net/u013548568/article/details/81368019>
"""
# TODO find a way to autodoc the class within method
class TimeWrapperCuda(TimeWrapper):
r"""
Overview:
A timer class that inherits from the ``TimeWrapper`` class
Notes:
Must use torch.cuda.synchronize(), reference: \
<https://blog.csdn.net/u013548568/article/details/81368019>
Interface:
``start_time``, ``end_time``
"""
# cls variable is initialized on loading this class
start_record = torch.cuda.Event(enable_timing=True)
end_record = torch.cuda.Event(enable_timing=True)
# overwrite
@classmethod
def start_time(cls):
r"""
Overview:
Implement and override the ``start_time`` method of the ``TimeWrapper`` class
"""
torch.cuda.synchronize()
cls.start = cls.start_record.record()
# overwrite
@classmethod
def end_time(cls):
r"""
Overview:
Implement and override the ``end_time`` method of the ``TimeWrapper`` class
Returns:
- time(:obj:`float`): The time between ``start_time`` and ``end_time``
"""
cls.end = cls.end_record.record()
torch.cuda.synchronize()
return cls.start_record.elapsed_time(cls.end_record) / 1000
return TimeWrapperCuda
class WatchDog(object):
"""
Overview:
......
class TimeWrapper(object):
r"""
Overview:
Abstract base class that defines the ``TimeWrapper`` timing interface
Interface:
``wrapper``, ``start_time``, ``end_time``
"""
@classmethod
def wrapper(cls, fn):
r"""
Overview:
Classmethod wrapper, wrap a function and automatically return its result together with its running time
Arguments:
- fn (:obj:`function`): The function to be wrapped and timed
"""
def time_func(*args, **kwargs):
cls.start_time()
ret = fn(*args, **kwargs)
t = cls.end_time()
return ret, t
return time_func
@classmethod
def start_time(cls):
r"""
Overview:
Abstract classmethod, start timing
"""
raise NotImplementedError
@classmethod
def end_time(cls):
r"""
Overview:
Abstract classmethod, stop timing
"""
raise NotImplementedError
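To illustrate the interface the new ``time_helper_base`` module exposes, here is a sketch of a custom subclass (``TimeWrapperPerfCounter`` is hypothetical and not part of this commit):

```python
import time

from ding.utils.time_helper_base import TimeWrapper


class TimeWrapperPerfCounter(TimeWrapper):
    """Hypothetical wall-clock timer built on time.perf_counter."""

    @classmethod
    def start_time(cls):
        cls._start = time.perf_counter()

    @classmethod
    def end_time(cls):
        return time.perf_counter() - cls._start


# The inherited ``wrapper`` classmethod then times any function call.
@TimeWrapperPerfCounter.wrapper
def add(a, b):
    return a + b


result, elapsed = add(1, 2)  # -> (3, elapsed time in seconds)
```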
from typing import Callable
import torch
from .time_helper_base import TimeWrapper
def get_cuda_time_wrapper() -> Callable[[], 'TimeWrapper']:
r"""
Overview:
Return the ``TimeWrapperCuda`` class; it is defined lazily inside this function so the module stays importable on machines without a CUDA device
Returns:
- TimeWrapperCuda(:obj:`class`): See ``TimeWrapperCuda`` class
.. note::
Must use ``torch.cuda.synchronize()``, reference: <https://blog.csdn.net/u013548568/article/details/81368019>
"""
# TODO find a way to autodoc the class within method
class TimeWrapperCuda(TimeWrapper):
r"""
Overview:
A timer class that inherits from the ``TimeWrapper`` class
Notes:
Must use torch.cuda.synchronize(), reference: \
<https://blog.csdn.net/u013548568/article/details/81368019>
Interface:
``start_time``, ``end_time``
"""
# cls variable is initialized on loading this class
start_record = torch.cuda.Event(enable_timing=True)
end_record = torch.cuda.Event(enable_timing=True)
# overwrite
@classmethod
def start_time(cls):
r"""
Overview:
Implement and override the ``start_time`` method of the ``TimeWrapper`` class
"""
torch.cuda.synchronize()
cls.start = cls.start_record.record()
# overwrite
@classmethod
def end_time(cls):
r"""
Overview:
Implement and override the ``end_time`` method of the ``TimeWrapper`` class
Returns:
- time(:obj:`float`): The time between ``start_time`` and ``end_time``
"""
cls.end = cls.end_record.record()
torch.cuda.synchronize()
return cls.start_record.elapsed_time(cls.end_record) / 1000
return TimeWrapperCuda
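For completeness, a sketch of the CUDA path, assuming a machine where ``torch.cuda.is_available()`` is true (the matrix multiply is a hypothetical workload; timings come back in seconds because ``elapsed_time`` reports milliseconds):

```python
import torch
from ding.utils.time_helper import build_time_helper

if torch.cuda.is_available():
    time_handle = build_time_helper(wrapper_type='cuda')  # -> TimeWrapperCuda
    time_handle.start_time()
    x = torch.randn(1024, 1024, device='cuda')
    y = x @ x  # GPU work to be timed
    elapsed = time_handle.end_time()  # CUDA-event timing, converted to seconds
```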