Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
ad394fc5
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
56
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
ad394fc5
编写于
10月 17, 2021
作者:
N
niuyazhe
1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test(nyz): add test for ding/utils and remove DistributionImage
上级
1568e53d
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
196 addition
and
171 deletion
+196
-171
.coveragerc
.coveragerc
+3
-0
ding/utils/__init__.py
ding/utils/__init__.py
+1
-1
ding/utils/data/tests/test_dataset.py
ding/utils/data/tests/test_dataset.py
+3
-1
ding/utils/log_helper.py
ding/utils/log_helper.py
+0
-57
ding/utils/tests/test_default_helper.py
ding/utils/tests/test_default_helper.py
+61
-2
ding/utils/tests/test_import_helper.py
ding/utils/tests/test_import_helper.py
+3
-0
ding/utils/tests/test_time_helper.py
ding/utils/tests/test_time_helper.py
+17
-8
ding/utils/time_helper.py
ding/utils/time_helper.py
+8
-102
ding/utils/time_helper_base.py
ding/utils/time_helper_base.py
+41
-0
ding/utils/time_helper_cuda.py
ding/utils/time_helper_cuda.py
+59
-0
未找到文件。
.coveragerc
浏览文件 @
ad394fc5
...
...
@@ -5,6 +5,9 @@ omit =
ding/utils/linklink_dist_helper.py
ding/utils/pytorch_ddp_dist_helper.py
ding/utils/k8s_helper.py
ding/utils/time_helper_cuda.py
ding/utils/time_helper_base.py
ding/utils/data/tests/test_dataloader.py
ding/config/utils.py
ding/entry/tests/test_serial_entry_algo.py
ding/entry/tests/test_serial_entry.py
...
...
ding/utils/__init__.py
浏览文件 @
ad394fc5
...
...
@@ -13,7 +13,7 @@ from .k8s_helper import get_operator_server_kwargs, exist_operator_server, DEFAU
K8sLauncher
from
.orchestrator_launcher
import
OrchestratorLauncher
from
.lock_helper
import
LockContext
,
LockContextType
,
get_file_lock
,
get_rw_file_lock
from
.log_helper
import
build_logger
,
DistributionTimeImage
,
pretty_print
,
LoggerFactory
from
.log_helper
import
build_logger
,
pretty_print
,
LoggerFactory
from
.registry_factory
import
registries
,
POLICY_REGISTRY
,
ENV_REGISTRY
,
LEARNER_REGISTRY
,
COMM_LEARNER_REGISTRY
,
\
SERIAL_COLLECTOR_REGISTRY
,
PARALLEL_COLLECTOR_REGISTRY
,
COMM_COLLECTOR_REGISTRY
,
\
COMMANDER_REGISTRY
,
LEAGUE_REGISTRY
,
PLAYER_REGISTRY
,
MODEL_REGISTRY
,
\
...
...
ding/utils/data/tests/test_dataset.py
浏览文件 @
ad394fc5
...
...
@@ -19,7 +19,7 @@ cfg2 = dict(policy=dict(collect=dict(
cfg3
=
dict
(
env
=
dict
(
env_id
=
'hopper-expert-v0'
),
policy
=
dict
(
collect
=
dict
(
data_type
=
'd4rl'
,
),
))
cfgs
=
[
cfg1
,
cfg2
,
cfg3
]
cfgs
=
[
cfg1
,
cfg2
]
# cfg3
unittest_args
=
[
'naive'
,
'hdf5'
]
# fake transition & data
...
...
@@ -36,11 +36,13 @@ expert_data_path = './expert.pkl'
@
pytest
.
mark
.
parametrize
(
'data_type'
,
unittest_args
)
@
pytest
.
mark
.
unittest
def
test_offline_data_save_type
(
data_type
):
offline_data_save_type
(
exp_data
=
fake_data
,
expert_data_path
=
expert_data_path
,
data_type
=
data_type
)
@
pytest
.
mark
.
parametrize
(
'cfg'
,
cfgs
)
@
pytest
.
mark
.
unittest
def
test_dataset
(
cfg
):
cfg
=
EasyDict
(
cfg
)
create_dataset
(
cfg
)
ding/utils/log_helper.py
浏览文件 @
ad394fc5
...
...
@@ -102,63 +102,6 @@ class LoggerFactory(object):
return
s
class
DistributionTimeImage
:
r
"""
Overview:
``DistributionTimeImage`` can be used to store images accorrding to ``time_steps``,
for data with 3 dims``(time, category, value)``
Interface:
``__init__``, ``add_one_time_step``, ``get_image``
"""
def
__init__
(
self
,
maxlen
:
int
=
600
,
val_range
:
Optional
[
dict
]
=
None
):
r
"""
Overview:
Init the ``DistributionTimeImage`` class
Arguments:
- maxlen (:obj:`int`): The max length of data inputs
- val_range (:obj:`dict` or :obj:`None`): Dict with ``val_range['min']`` and ``val_range['max']``.
"""
self
.
maxlen
=
maxlen
self
.
val_range
=
val_range
self
.
img
=
np
.
ones
((
maxlen
,
maxlen
))
self
.
time_step
=
0
self
.
one_img
=
np
.
ones
((
maxlen
,
maxlen
))
def
add_one_time_step
(
self
,
data
:
np
.
ndarray
)
->
None
:
r
"""
Overview:
Step one timestep in ``DistributionTimeImage`` and add the data to distribution image
Arguments:
- data (:obj:`np.ndarray`): The data input
"""
assert
(
isinstance
(
data
,
np
.
ndarray
))
data
=
np
.
expand_dims
(
data
,
1
)
data
=
np
.
resize
(
data
,
(
1
,
self
.
maxlen
))
if
self
.
time_step
>=
self
.
maxlen
:
self
.
img
=
np
.
concatenate
([
self
.
img
[:,
1
:],
data
])
else
:
self
.
img
[:,
self
.
time_step
:
self
.
time_step
+
1
]
=
data
self
.
time_step
+=
1
def
get_image
(
self
)
->
np
.
ndarray
:
r
"""
Overview:
Return the distribution image
Returns:
- img (:obj:`np.ndarray`): The calculated distribution image
"""
norm_img
=
np
.
copy
(
self
.
img
)
valid
=
norm_img
[:,
:
self
.
time_step
]
if
self
.
val_range
is
None
:
valid
=
(
valid
-
valid
.
min
())
/
(
valid
.
max
()
-
valid
.
min
())
else
:
valid
=
np
.
clip
(
valid
,
self
.
val_range
[
'min'
],
self
.
val_range
[
'max'
])
valid
=
(
valid
-
self
.
val_range
[
'min'
])
/
(
self
.
val_range
[
'max'
]
-
self
.
val_range
[
'min'
])
norm_img
[:,
:
self
.
time_step
]
=
valid
return
np
.
stack
([
self
.
one_img
,
norm_img
,
norm_img
],
axis
=
0
)
def
pretty_print
(
result
:
dict
,
direct_print
:
bool
=
True
)
->
str
:
r
"""
Overview:
...
...
ding/utils/tests/test_default_helper.py
浏览文件 @
ad394fc5
...
...
@@ -3,8 +3,9 @@ import numpy as np
import
torch
from
collections
import
namedtuple
from
ding.utils.default_helper
import
lists_to_dicts
,
dicts_to_lists
,
squeeze
,
default_get
,
override
,
error_wrapper
,
\
list_split
,
LimitedSpaceContainer
,
set_pkg_seed
,
deep_merge_dicts
,
deep_update
,
flatten_dict
from
ding.utils.default_helper
import
lists_to_dicts
,
dicts_to_lists
,
squeeze
,
default_get
,
override
,
error_wrapper
,
\
list_split
,
LimitedSpaceContainer
,
set_pkg_seed
,
deep_merge_dicts
,
deep_update
,
flatten_dict
,
RunningMeanStd
,
\
one_time_warning
,
split_data_generator
@
pytest
.
mark
.
unittest
...
...
@@ -84,6 +85,7 @@ class TestDefaultHelper():
wrap_bad_ret
=
error_wrapper
(
bad_ret
,
0
)
assert
wrap_bad_ret
(
1
)
==
0
wrap_bad_ret_with_customized_log
=
error_wrapper
(
bad_ret
,
0
,
'customized_information'
)
def
test_list_split
(
self
):
data
=
[
i
for
i
in
range
(
10
)]
...
...
@@ -213,3 +215,60 @@ class TestDict:
assert
flat
[
'b/d/e'
]
==
6
assert
flat
[
'b/d/f'
]
==
5
assert
flat
[
'b/z'
]
==
4
def
test_one_time_warning
(
self
):
one_time_warning
(
'test_one_time_warning'
)
def
test_running_mean_std
(
self
):
running
=
RunningMeanStd
()
running
.
reset
()
running
.
update
(
np
.
arange
(
1
,
10
))
assert
running
.
mean
==
pytest
.
approx
(
5
,
abs
=
1e-4
)
assert
running
.
std
==
pytest
.
approx
(
2.582030
,
abs
=
1e-6
)
running
.
update
(
np
.
arange
(
2
,
11
))
assert
running
.
mean
==
pytest
.
approx
(
5.5
,
abs
=
1e-4
)
assert
running
.
std
==
pytest
.
approx
(
2.629981
,
abs
=
1e-6
)
running
.
reset
()
running
.
update
(
np
.
arange
(
1
,
10
))
assert
pytest
.
approx
(
running
.
mean
,
5
)
assert
running
.
mean
==
pytest
.
approx
(
5
,
abs
=
1e-4
)
assert
running
.
std
==
pytest
.
approx
(
2.582030
,
abs
=
1e-6
)
new_shape
=
running
.
new_shape
((
2
,
4
),
(
3
,
),
(
1
,
))
assert
isinstance
(
new_shape
,
tuple
)
and
len
(
new_shape
)
==
3
running
=
RunningMeanStd
(
shape
=
(
4
,
))
running
.
reset
()
running
.
update
(
np
.
random
.
random
((
10
,
4
)))
assert
isinstance
(
running
.
mean
,
torch
.
Tensor
)
and
running
.
mean
.
shape
==
(
4
,
)
assert
isinstance
(
running
.
std
,
torch
.
Tensor
)
and
running
.
std
.
shape
==
(
4
,
)
def
test_split_data_generator
(
self
):
def
get_data
():
return
{
'obs'
:
torch
.
randn
(
5
),
'action'
:
torch
.
randint
(
0
,
10
,
size
=
(
1
,
)),
'prev_state'
:
[
None
,
None
],
'info'
:
{
'other_obs'
:
torch
.
randn
(
5
)
},
}
data
=
[
get_data
()
for
_
in
range
(
4
)]
data
=
lists_to_dicts
(
data
)
data
[
'obs'
]
=
torch
.
stack
(
data
[
'obs'
])
data
[
'action'
]
=
torch
.
stack
(
data
[
'action'
])
data
[
'info'
]
=
{
'other_obs'
:
torch
.
stack
([
t
[
'other_obs'
]
for
t
in
data
[
'info'
]])}
assert
len
(
data
[
'obs'
])
==
4
data
[
'NoneKey'
]
=
None
generator
=
split_data_generator
(
data
,
3
)
generator_result
=
list
(
generator
)
assert
len
(
generator_result
)
==
2
assert
generator_result
[
0
][
'NoneKey'
]
is
None
assert
len
(
generator_result
[
0
][
'obs'
])
==
3
assert
generator_result
[
0
][
'info'
][
'other_obs'
].
shape
==
(
3
,
5
)
assert
generator_result
[
1
][
'NoneKey'
]
is
None
assert
len
(
generator_result
[
1
][
'obs'
])
==
3
assert
generator_result
[
1
][
'info'
][
'other_obs'
].
shape
==
(
3
,
5
)
generator
=
split_data_generator
(
data
,
3
,
shuffle
=
False
)
ding/utils/tests/test_import_helper.py
浏览文件 @
ad394fc5
import
pytest
import
ding
from
ding.utils.import_helper
import
try_import_ceph
,
try_import_mc
,
try_import_redis
,
try_import_rediscluster
,
\
try_import_link
,
import_module
...
...
@@ -12,3 +13,5 @@ def test_try_import():
try_import_rediscluster
()
try_import_link
()
import_module
([
'ding.utils'
])
ding
.
enable_linklink
=
True
try_import_link
()
ding/utils/tests/test_time_helper.py
浏览文件 @
ad394fc5
import
pytest
import
numpy
as
np
import
time
from
ding.utils.time_helper
import
build_time_helper
,
WatchDog
from
ding.utils.time_helper
import
build_time_helper
,
WatchDog
,
TimeWrapperTime
,
EasyTimer
@
pytest
.
mark
.
unittest
...
...
@@ -17,10 +17,12 @@ class TestTimeHelper:
setattr
(
cfg
.
common
,
'time_wrapper_type'
,
'time'
)
with
pytest
.
raises
(
RuntimeError
):
time_handle
=
build_time_helper
()
build_time_helper
(
cfg
=
None
,
wrapper_type
=
"??"
)
# with pytest.raises(KeyError):
# build_time_helper(cfg=None,wrapper_type="not_implement")
with
pytest
.
raises
(
KeyError
):
build_time_helper
(
cfg
=
None
,
wrapper_type
=
"not_implement"
)
time_handle
=
build_time_helper
(
cfg
)
time_handle
=
build_time_helper
(
wrapper_type
=
'cuda'
)
# wrapper_type='cuda' but cuda is not available
assert
issubclass
(
time_handle
,
TimeWrapperTime
)
time_handle
=
build_time_helper
(
wrapper_type
=
'time'
)
@
time_handle
.
wrapper
...
...
@@ -52,14 +54,21 @@ class TestTimeHelper:
# assert abs(t-1) < 1e-3
assert
abs
(
t
-
1
)
<
1e-2
timer
=
EasyTimer
()
with
timer
:
tmp
=
np
.
random
.
random
(
size
=
(
4
,
100
))
tmp
=
tmp
**
2
value
=
timer
.
value
assert
isinstance
(
value
,
float
)
@
pytest
.
mark
.
unittest
class
TestWatchDog
:
def
test_naive
(
self
):
watchdog
=
WatchDog
(
5
)
watchdog
=
WatchDog
(
3
)
watchdog
.
start
()
time
.
sleep
(
4
)
time
.
sleep
(
2
)
with
pytest
.
raises
(
TimeoutError
):
time
.
sleep
(
4
)
time
.
sleep
(
4
)
time
.
sleep
(
2
)
watchdog
.
stop
(
)
ding/utils/time_helper.py
浏览文件 @
ad394fc5
...
...
@@ -4,6 +4,8 @@ from typing import Any, Callable
import
torch
from
easydict
import
EasyDict
from
.time_helper_base
import
TimeWrapper
from
.time_helper_cuda
import
get_cuda_time_wrapper
def
build_time_helper
(
cfg
:
EasyDict
=
None
,
wrapper_type
:
str
=
None
)
->
Callable
[[],
'TimeWrapper'
]:
...
...
@@ -31,11 +33,14 @@ def build_time_helper(cfg: EasyDict = None, wrapper_type: str = None) -> Callabl
else
:
raise
RuntimeError
(
'Either wrapper_type or cfg should be provided.'
)
if
time_wrapper_type
==
'time'
or
(
not
torch
.
cuda
.
is_available
())
:
if
time_wrapper_type
==
'time'
:
return
TimeWrapperTime
elif
time_wrapper_type
==
'cuda'
:
# lazy initialize to make code runnable locally
return
get_cuda_time_wrapper
()
if
torch
.
cuda
.
is_available
():
# lazy initialize to make code runnable locally
return
get_cuda_time_wrapper
()
else
:
return
TimeWrapperTime
else
:
raise
KeyError
(
'invalid time_wrapper_type: {}'
.
format
(
time_wrapper_type
))
...
...
@@ -86,49 +91,6 @@ class EasyTimer:
self
.
value
=
self
.
_timer
.
end_time
()
class
TimeWrapper
(
object
):
r
"""
Overview:
Abstract class method that defines ``TimeWrapper`` class
Interface:
``wrapper``, ``start_time``, ``end_time``
"""
@
classmethod
def
wrapper
(
cls
,
fn
):
r
"""
Overview:
Classmethod wrapper, wrap a function and automatically return its running time
- fn (:obj:`function`): The function to be wrap and timed
"""
def
time_func
(
*
args
,
**
kwargs
):
cls
.
start_time
()
ret
=
fn
(
*
args
,
**
kwargs
)
t
=
cls
.
end_time
()
return
ret
,
t
return
time_func
@
classmethod
def
start_time
(
cls
):
r
"""
Overview:
Abstract classmethod, start timing
"""
raise
NotImplementedError
@
classmethod
def
end_time
(
cls
):
r
"""
Overview:
Abstract classmethod, stop timing
"""
raise
NotImplementedError
class
TimeWrapperTime
(
TimeWrapper
):
r
"""
Overview:
...
...
@@ -161,62 +123,6 @@ class TimeWrapperTime(TimeWrapper):
return
cls
.
end
-
cls
.
start
def
get_cuda_time_wrapper
()
->
Callable
[[],
'TimeWrapper'
]:
r
"""
Overview:
Return the ``TimeWrapperCuda`` class
Returns:
- TimeWrapperCuda(:obj:`class`): See ``TimeWrapperCuda`` class
.. note::
Must use ``torch.cuda.synchronize()``, reference: <https://blog.csdn.net/u013548568/article/details/81368019>
"""
# TODO find a way to autodoc the class within method
class
TimeWrapperCuda
(
TimeWrapper
):
r
"""
Overview:
A class method that inherit from ``TimeWrapper`` class
Notes:
Must use torch.cuda.synchronize(), reference: \
<https://blog.csdn.net/u013548568/article/details/81368019>
Interface:
``start_time``, ``end_time``
"""
# cls variable is initialized on loading this class
start_record
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_record
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
# overwrite
@
classmethod
def
start_time
(
cls
):
r
"""
Overview:
Implement and overide the ``start_time`` method in ``TimeWrapper`` class
"""
torch
.
cuda
.
synchronize
()
cls
.
start
=
cls
.
start_record
.
record
()
# overwrite
@
classmethod
def
end_time
(
cls
):
r
"""
Overview:
Implement and overide the end_time method in ``TimeWrapper`` class
Returns:
- time(:obj:`float`): The time between ``start_time`` and ``end_time``
"""
cls
.
end
=
cls
.
end_record
.
record
()
torch
.
cuda
.
synchronize
()
return
cls
.
start_record
.
elapsed_time
(
cls
.
end_record
)
/
1000
return
TimeWrapperCuda
class
WatchDog
(
object
):
"""
Overview:
...
...
ding/utils/time_helper_base.py
0 → 100644
浏览文件 @
ad394fc5
class
TimeWrapper
(
object
):
r
"""
Overview:
Abstract class method that defines ``TimeWrapper`` class
Interface:
``wrapper``, ``start_time``, ``end_time``
"""
@
classmethod
def
wrapper
(
cls
,
fn
):
r
"""
Overview:
Classmethod wrapper, wrap a function and automatically return its running time
- fn (:obj:`function`): The function to be wrap and timed
"""
def
time_func
(
*
args
,
**
kwargs
):
cls
.
start_time
()
ret
=
fn
(
*
args
,
**
kwargs
)
t
=
cls
.
end_time
()
return
ret
,
t
return
time_func
@
classmethod
def
start_time
(
cls
):
r
"""
Overview:
Abstract classmethod, start timing
"""
raise
NotImplementedError
@
classmethod
def
end_time
(
cls
):
r
"""
Overview:
Abstract classmethod, stop timing
"""
raise
NotImplementedError
ding/utils/time_helper_cuda.py
0 → 100644
浏览文件 @
ad394fc5
from
typing
import
Callable
import
torch
from
.time_helper_base
import
TimeWrapper
def
get_cuda_time_wrapper
()
->
Callable
[[],
'TimeWrapper'
]:
r
"""
Overview:
Return the ``TimeWrapperCuda`` class, this wrapper aims to ensure compatibility in no cuda device
Returns:
- TimeWrapperCuda(:obj:`class`): See ``TimeWrapperCuda`` class
.. note::
Must use ``torch.cuda.synchronize()``, reference: <https://blog.csdn.net/u013548568/article/details/81368019>
"""
# TODO find a way to autodoc the class within method
class
TimeWrapperCuda
(
TimeWrapper
):
r
"""
Overview:
A class method that inherit from ``TimeWrapper`` class
Notes:
Must use torch.cuda.synchronize(), reference: \
<https://blog.csdn.net/u013548568/article/details/81368019>
Interface:
``start_time``, ``end_time``
"""
# cls variable is initialized on loading this class
start_record
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_record
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
# overwrite
@
classmethod
def
start_time
(
cls
):
r
"""
Overview:
Implement and overide the ``start_time`` method in ``TimeWrapper`` class
"""
torch
.
cuda
.
synchronize
()
cls
.
start
=
cls
.
start_record
.
record
()
# overwrite
@
classmethod
def
end_time
(
cls
):
r
"""
Overview:
Implement and overide the end_time method in ``TimeWrapper`` class
Returns:
- time(:obj:`float`): The time between ``start_time`` and ``end_time``
"""
cls
.
end
=
cls
.
end_record
.
record
()
torch
.
cuda
.
synchronize
()
return
cls
.
start_record
.
elapsed_time
(
cls
.
end_record
)
/
1000
return
TimeWrapperCuda
OpenDILab开源决策智能平台
@m0_55289267
mentioned in commit
3a67012b
·
10月 18, 2021
mentioned in commit
3a67012b
mentioned in commit 3a67012bbd2d06d9d003dc4ab7ecfc58d9a7035b
开关提交列表
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录