Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
ec068cb5
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 接近 3 年
通知
65
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ec068cb5
编写于
11月 08, 2021
作者:
X
Xu Jingxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support sample by grouped meta key
上级
12db1bad
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
82 addition
and
3 deletion
+82
-3
ding/worker/buffer/buffer.py
ding/worker/buffer/buffer.py
+5
-1
ding/worker/buffer/deque_buffer.py
ding/worker/buffer/deque_buffer.py
+50
-2
ding/worker/buffer/tests/test_buffer.py
ding/worker/buffer/tests/test_buffer.py
+27
-0
未找到文件。
ding/worker/buffer/buffer.py
浏览文件 @
ec068cb5
...
...
@@ -71,7 +71,9 @@ class Buffer:
indices
:
Optional
[
List
[
str
]]
=
None
,
replace
:
bool
=
False
,
sample_range
:
Optional
[
slice
]
=
None
,
ignore_insufficient
:
bool
=
False
ignore_insufficient
:
bool
=
False
,
groupby
:
str
=
None
,
rolling_window
:
int
=
None
)
->
List
[
BufferedData
]:
"""
Overview:
...
...
@@ -83,6 +85,8 @@ class Buffer:
- sample_range (:obj:`slice`): Sample range slice.
- ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size
with no repetition will not cause an exception.
- groupby (:obj:`str`): Groupby key in meta.
- rolling_window (:obj:`int`): Return batches of window size.
Returns:
- sample_data (:obj:`List[BufferedData]`):
A list of data with length ``size``.
...
...
ding/worker/buffer/deque_buffer.py
浏览文件 @
ec068cb5
from
typing
import
Any
,
Iterable
,
List
,
Optional
,
Union
from
collections
import
defaultdict
,
deque
from
ding.worker.buffer
import
Buffer
,
apply_middleware
,
BufferedData
from
ding.worker.buffer.utils
import
fastcopy
import
itertools
...
...
@@ -13,12 +14,20 @@ class DequeBuffer(Buffer):
def __init__(self, size: int) -> None:
    """
    Overview:
        Create an empty buffer backed by a bounded deque.
    Arguments:
        - size (:obj:`int`): Passed to ``deque`` as ``maxlen``, i.e. the maximum number of items kept.
    """
    super().__init__()
    # Meta index: a dict mapping a meta key to a deque of that key's values.
    # It starts empty; entries are added by whoever needs an index for a key.
    self.meta_index = {}
    self.storage = deque(maxlen=size)
@apply_middleware("push")
def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
    """
    Overview:
        Wrap ``data`` in a ``BufferedData`` with a freshly generated index, append it to storage
        and record its meta values in every meta index that already exists.
    Arguments:
        - data (:obj:`Any`): The payload to store.
        - meta (:obj:`Optional[dict]`): Meta information of the payload; ``None`` is treated as ``{}``.
    Returns:
        - buffered (:obj:`BufferedData`): The wrapper that was appended to storage.
    """
    item_index = uuid.uuid1().hex
    meta = {} if meta is None else meta
    item = BufferedData(data=data, index=item_index, meta=meta)
    self.storage.append(item)
    # Keep each existing meta index in step with storage; a missing key is recorded as None.
    for key, values in self.meta_index.items():
        values.append(meta.get(key))
    return item
@
apply_middleware
(
"sample"
)
...
...
@@ -29,25 +38,33 @@ class DequeBuffer(Buffer):
replace
:
bool
=
False
,
sample_range
:
Optional
[
slice
]
=
None
,
ignore_insufficient
:
bool
=
False
,
groupby
:
str
=
None
,
rolling_window
:
int
=
None
)
->
List
[
BufferedData
]:
storage
=
self
.
storage
if
sample_range
:
storage
=
list
(
itertools
.
islice
(
self
.
storage
,
sample_range
.
start
,
sample_range
.
stop
,
sample_range
.
step
))
# Size and indices
assert
size
or
indices
,
"One of size and indices must not be empty."
if
(
size
and
indices
)
and
(
size
!=
len
(
indices
)):
raise
AssertionError
(
"Size and indices length must be equal."
)
if
not
size
:
size
=
len
(
indices
)
# Indices and groupby
assert
not
(
indices
and
groupby
),
"Cannot use groupby and indicex at the same time."
value_error
=
None
sampled_data
=
[]
if
indices
:
indices_set
=
set
(
indices
)
hashed_data
=
filter
(
lambda
item
:
item
.
index
in
indices_set
,
s
elf
.
s
torage
)
hashed_data
=
filter
(
lambda
item
:
item
.
index
in
indices_set
,
storage
)
hashed_data
=
map
(
lambda
item
:
(
item
.
index
,
item
),
hashed_data
)
hashed_data
=
dict
(
hashed_data
)
# Re-sample and return in indices order
sampled_data
=
[
hashed_data
[
index
]
for
index
in
indices
]
elif
groupby
:
sampled_data
=
self
.
_sample_by_group
(
size
=
size
,
groupby
=
groupby
,
storage
=
storage
,
replace
=
replace
)
else
:
if
replace
:
sampled_data
=
random
.
choices
(
storage
,
k
=
size
)
...
...
@@ -66,7 +83,10 @@ class DequeBuffer(Buffer):
else
:
raise
ValueError
(
"There are less than {} data in buffer({})"
.
format
(
size
,
self
.
count
()))
sampled_data
=
self
.
_independence
(
sampled_data
)
if
groupby
:
sampled_data
=
[
self
.
_independence
(
data
)
for
data
in
sampled_data
]
else
:
sampled_data
=
self
.
_independence
(
sampled_data
)
return
sampled_data
...
...
@@ -109,6 +129,34 @@ class DequeBuffer(Buffer):
buffered_samples
[
i
]
=
fastcopy
.
copy
(
buffered
)
return
buffered_samples
def _sample_by_group(self, size: int, groupby: str, storage: deque,
                     replace: bool = False) -> List[List["BufferedData"]]:
    """
    Overview:
        Sample whole groups of buffered data. A group is made of all items in ``storage`` whose
        ``meta[groupby]`` value is the same; items without that key form the ``None`` group.
    Arguments:
        - size (:obj:`int`): Number of groups to draw.
        - groupby (:obj:`str`): Meta key used to group items.
        - storage (:obj:`deque`): Items from which the members of the drawn groups are collected.
        - replace (:obj:`bool`): Draw group names with replacement. A name drawn more than once
            still yields a single group, so fewer than ``size`` groups may be returned.
    Returns:
        - sampled_data (:obj:`List[List[BufferedData]]`): One list per drawn group that has at least
            one member in ``storage``, ordered by first appearance in ``storage``.
    """
    if groupby not in self.meta_index:
        self._create_index(groupby)
    # Candidate group names come from the meta index, de-duplicated.
    group_names = list(set(self.meta_index[groupby]))
    if replace:
        drawn = random.choices(group_names, k=size)
    else:
        try:
            drawn = random.sample(group_names, k=size)
        except ValueError:
            # Best effort, as before: when ``size`` cannot be satisfied
            # (more groups requested than exist) no group is selected.
            drawn = []
    # Group names are hashable (they were just put through ``set``), so use a set for O(1) lookups.
    selected = set(drawn)
    grouped = defaultdict(list)
    for buffered in storage:
        group = buffered.meta.get(groupby)
        if group in selected:
            # Key by the computed group value: indexing ``meta[groupby]`` here would raise
            # KeyError for items of the ``None`` group, which do not carry the key at all.
            grouped[group].append(buffered)
    # Return a real list, matching the annotation, instead of a dict view.
    return list(grouped.values())
def _create_index(self, meta_key: str):
    """
    Overview:
        Build (or rebuild) the meta index for ``meta_key`` from the items currently in storage.
    Arguments:
        - meta_key (:obj:`str`): Meta key to index; items lacking the key contribute ``None``.
    """
    # The index gets the same ``maxlen`` as storage.
    self.meta_index[meta_key] = index = deque(maxlen=self.storage.maxlen)
    index.extend(item.meta.get(meta_key) for item in self.storage)
def __iter__(self) -> Iterable["BufferedData"]:
    """
    Overview:
        Iterate over the items held in storage.
    Returns:
        - iterator (:obj:`Iterable[BufferedData]`): An iterator over ``self.storage``; note that this is
            an iterator object, not the deque itself.
    """
    return iter(self.storage)
...
...
ding/worker/buffer/tests/test_buffer.py
浏览文件 @
ec068cb5
...
...
@@ -187,3 +187,30 @@ def test_independence():
assert
len
(
sampled_data
)
==
2
sampled_data
[
0
].
data
[
"key"
]
=
"new"
assert
sampled_data
[
1
].
data
[
"key"
]
==
"origin"
@pytest.mark.unittest
def test_groupby():
    """Sampling with ``groupby`` returns whole groups, also after old data has been swapped out."""
    buffer = DequeBuffer(size=3)
    for payload, group in (("a", 1), ("b", 2), ("c", 2)):
        buffer.push(payload, {"group": group})

    sampled_data = buffer.sample(2, groupby="group")
    assert len(sampled_data) == 2
    # The two groups may come back in either order, so tell them apart by size.
    first, second = sampled_data[0], sampled_data[1]
    group1 = first if len(first) == 1 else second
    group2 = first if len(first) == 2 else second
    # Group1 should contain a
    assert "a" == group1[0].data
    # Group2 should contain b and c
    data = [buffered.data for buffered in group2]  # ["b", "c"]
    for expected in ("b", "c"):
        assert expected in data

    # Push new data and swap out a
    buffer.push("d", {"group": 2})
    sampled_data = buffer.sample(1, groupby="group")
    assert len(sampled_data) == 1
    only_group = sampled_data[0]
    assert len(only_group) == 3
    data = [buffered.data for buffered in only_group]
    assert "d" in data
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录