Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
e9246b74
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 接近 3 年
通知
65
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e9246b74
编写于
10月 21, 2021
作者:
X
Xu Jingxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Test slicing
上级
d40bdf16
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
23 addition
and
6 deletion
+23
-6
ding/worker/buffer/buffer.py
ding/worker/buffer/buffer.py
+9
-2
ding/worker/buffer/memory_storage.py
ding/worker/buffer/memory_storage.py
+6
-3
ding/worker/buffer/storage.py
ding/worker/buffer/storage.py
+1
-1
ding/worker/buffer/tests/test_buffer.py
ding/worker/buffer/tests/test_buffer.py
+7
-0
未找到文件。
ding/worker/buffer/buffer.py
浏览文件 @
e9246b74
...
...
@@ -8,6 +8,13 @@ def apply_middleware(func_name: str):
def
wrap_func
(
base_func
:
Callable
):
def
handler
(
buffer
,
*
args
,
**
kwargs
):
"""
The real processing starts here, we apply the middlewares one by one,
each middleware will receive a `next` function, which is an executor of next
middleware. You can change the input arguments to the `next` middleware, and you
also can get the return value from the next middleware, so you have the
maximum freedom to choose at what stage to implement your method.
"""
def
wrap_handler
(
middlewares
,
*
args
,
**
kwargs
):
if
len
(
middlewares
)
==
0
:
...
...
@@ -49,7 +56,7 @@ class Buffer:
self
.
storage
.
append
(
data
)
@
apply_middleware
(
"sample"
)
def
sample
(
self
,
size
:
int
,
replace
:
bool
=
False
)
->
List
[
Any
]:
def
sample
(
self
,
size
:
int
,
replace
:
bool
=
False
,
range
:
slice
=
None
)
->
List
[
Any
]:
"""
Overview:
Sample data with length ``size``, this function may be wrapped by middlewares.
...
...
@@ -58,7 +65,7 @@ class Buffer:
Returns:
- sample_data (:obj:`list`): A list of data with length ``size``.
"""
return
self
.
storage
.
sample
(
size
,
replace
)
return
self
.
storage
.
sample
(
size
,
replace
=
replace
,
range
=
range
)
@
apply_middleware
(
"clear"
)
def
clear
(
self
)
->
None
:
...
...
ding/worker/buffer/memory_storage.py
浏览文件 @
e9246b74
...
...
@@ -2,8 +2,8 @@ from typing import Any, List
from
collections
import
deque
from
operator
import
itemgetter
from
ding.worker.buffer
import
Storage
import
random
import
numpy
as
np
import
itertools
class
MemoryStorage
(
Storage
):
...
...
@@ -17,8 +17,11 @@ class MemoryStorage(Storage):
def
get
(
self
,
indices
:
List
[
int
])
->
List
[
Any
]:
return
itemgetter
(
*
indices
)(
self
.
storage
)
def
sample
(
self
,
size
:
int
,
replace
:
bool
=
False
)
->
List
[
Any
]:
return
np
.
random
.
choice
(
self
.
storage
,
size
,
replace
=
replace
)
def
sample
(
self
,
size
:
int
,
replace
:
bool
=
False
,
range
:
slice
=
None
)
->
List
[
Any
]:
storage
=
self
.
storage
if
range
:
storage
=
list
(
itertools
.
islice
(
self
.
storage
,
range
.
start
,
range
.
stop
,
range
.
step
))
return
np
.
random
.
choice
(
storage
,
size
,
replace
=
replace
)
def
count
(
self
)
->
int
:
return
len
(
self
.
storage
)
...
...
ding/worker/buffer/storage.py
浏览文件 @
e9246b74
...
...
@@ -13,7 +13,7 @@ class Storage:
raise
NotImplementedError
@
abstractmethod
def
sample
(
self
,
size
:
int
,
replace
:
bool
=
False
)
->
List
[
Any
]:
def
sample
(
self
,
size
:
int
,
replace
:
bool
=
False
,
range
:
slice
=
None
)
->
List
[
Any
]:
raise
NotImplementedError
@
abstractmethod
...
...
ding/worker/buffer/tests/test_buffer.py
浏览文件 @
e9246b74
...
...
@@ -75,6 +75,13 @@ def test_naive_push_sample():
assert
storage
.
count
()
==
5
assert
len
(
buffer
.
sample
(
10
,
replace
=
True
))
==
10
# Test slicing
buffer
.
clear
()
for
i
in
range
(
10
):
buffer
.
push
(
i
)
assert
len
(
buffer
.
sample
(
5
,
range
=
slice
(
5
,
10
)))
==
5
assert
0
not
in
buffer
.
sample
(
5
,
range
=
slice
(
5
,
10
))
@
pytest
.
mark
.
unittest
def
test_rate_limit_push_sample
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录