Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
9c67db8b
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 接近 3 年
通知
65
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9c67db8b
编写于
11月 07, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish(nyz): add return index in push and copy same data in sample
上级
f764de31
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
55 addition
and
23 deletion
+55
-23
ding/worker/buffer/buffer.py
ding/worker/buffer/buffer.py
+5
-3
ding/worker/buffer/deque_buffer.py
ding/worker/buffer/deque_buffer.py
+31
-8
ding/worker/buffer/middleware/clone_object.py
ding/worker/buffer/middleware/clone_object.py
+1
-1
ding/worker/buffer/middleware/priority.py
ding/worker/buffer/middleware/priority.py
+13
-6
ding/worker/buffer/middleware/staleness_check.py
ding/worker/buffer/middleware/staleness_check.py
+1
-1
ding/worker/buffer/middleware/use_time_check.py
ding/worker/buffer/middleware/use_time_check.py
+1
-1
ding/worker/buffer/tests/test_buffer.py
ding/worker/buffer/tests/test_buffer.py
+3
-3
未找到文件。
ding/worker/buffer/buffer.py
浏览文件 @
9c67db8b
...
...
@@ -52,13 +52,15 @@ class Buffer:
self
.
middleware
=
[]
@
abstractmethod
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
None
:
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
Any
:
"""
Overview:
Push data and it's meta information in buffer.
Arguments:
- data (:obj:`Any`): The data which will be pushed into buffer.
- meta (:obj:`dict`): Meta information, e.g. priority, count, staleness.
Returns:
- index (:obj:`Any`): The index of pushed data.
"""
raise
NotImplementedError
...
...
@@ -68,7 +70,7 @@ class Buffer:
size
:
Optional
[
int
]
=
None
,
indices
:
Optional
[
List
[
str
]]
=
None
,
replace
:
bool
=
False
,
range
:
Optional
[
slice
]
=
None
,
sample_
range
:
Optional
[
slice
]
=
None
,
ignore_insufficient
:
bool
=
False
)
->
List
[
BufferedData
]:
"""
...
...
@@ -78,7 +80,7 @@ class Buffer:
- size (:obj:`Optional[int]`): The number of the data that will be sampled.
- indices (:obj:`Optional[List[str]]`): Sample with multiple indices.
- replace (:obj:`bool`): If use replace is true, you may receive duplicated data from the buffer.
-
range (:obj:`slice`): R
ange slice.
-
sample_range (:obj:`slice`): Sample r
ange slice.
- ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size
with no repetition will not cause an exception.
Returns:
...
...
ding/worker/buffer/deque_buffer.py
浏览文件 @
9c67db8b
import
enum
import
copy
from
typing
import
Any
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
collections
import
deque
from
ding.worker.buffer
import
Buffer
,
apply_middleware
,
BufferedData
import
itertools
import
random
import
uuid
import
logging
from
ding.worker.buffer
import
Buffer
,
apply_middleware
,
BufferedData
class
DequeBuffer
(
Buffer
):
...
...
@@ -14,9 +16,10 @@ class DequeBuffer(Buffer):
self
.
storage
=
deque
(
maxlen
=
size
)
@
apply_middleware
(
"push"
)
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
None
:
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
str
:
index
=
uuid
.
uuid1
().
hex
self
.
storage
.
append
(
BufferedData
(
data
=
data
,
index
=
index
,
meta
=
meta
))
return
index
@
apply_middleware
(
"sample"
)
def
sample
(
...
...
@@ -24,12 +27,12 @@ class DequeBuffer(Buffer):
size
:
Optional
[
int
]
=
None
,
indices
:
Optional
[
List
[
str
]]
=
None
,
replace
:
bool
=
False
,
range
:
Optional
[
slice
]
=
None
,
ignore_insufficient
:
bool
=
False
sample_
range
:
Optional
[
slice
]
=
None
,
ignore_insufficient
:
bool
=
False
,
)
->
List
[
BufferedData
]:
storage
=
self
.
storage
if
range
:
storage
=
list
(
itertools
.
islice
(
self
.
storage
,
range
.
start
,
range
.
stop
,
range
.
step
))
if
sample_
range
:
storage
=
list
(
itertools
.
islice
(
self
.
storage
,
sample_range
.
start
,
sample_range
.
stop
,
sample_
range
.
step
))
assert
size
or
indices
,
"One of size and indices must not be empty."
if
(
size
and
indices
)
and
(
size
!=
len
(
indices
)):
raise
AssertionError
(
"Size and indices length must be equal."
)
...
...
@@ -40,6 +43,17 @@ class DequeBuffer(Buffer):
sampled_data
=
[]
if
indices
:
sampled_data
=
list
(
filter
(
lambda
item
:
item
.
index
in
indices
,
self
.
storage
))
# for the same indices
if
len
(
indices
)
!=
len
(
set
(
indices
)):
sampled_data_no_same
=
sampled_data
sampled_data
=
[
sampled_data_no_same
[
0
]]
j
=
0
for
i
in
range
(
1
,
len
(
indices
)):
if
indices
[
i
-
1
]
==
indices
[
i
]:
sampled_data
.
append
(
copy
.
deepcopy
(
sampled_data_no_same
[
j
]))
else
:
sampled_data
.
append
(
sampled_data_no_same
[
j
])
j
+=
1
else
:
if
replace
:
sampled_data
=
random
.
choices
(
storage
,
k
=
size
)
...
...
@@ -49,8 +63,17 @@ class DequeBuffer(Buffer):
except
ValueError
as
e
:
value_error
=
e
if
not
ignore_insufficient
and
(
value_error
or
len
(
sampled_data
)
!=
size
):
raise
ValueError
(
"There are less than {} data in buffer"
.
format
(
size
))
if
value_error
or
len
(
sampled_data
)
!=
size
:
if
ignore_insufficient
:
logging
.
warning
(
"Sample operation is ignored due to data insufficient, current buffer count is {} while sample size is {}"
.
format
(
self
.
count
(),
size
)
)
else
:
if
value_error
:
raise
ValueError
(
"Some errors in sample operation"
)
from
value_error
else
:
raise
ValueError
(
"There are less than {} data in buffer({})"
.
format
(
size
,
self
.
count
()))
return
sampled_data
...
...
ding/worker/buffer/middleware/clone_object.py
浏览文件 @
9c67db8b
...
...
@@ -63,7 +63,7 @@ def clone_object():
"""
fastcopy
=
FastCopy
()
def
push
(
chain
:
Callable
,
data
:
Any
,
*
args
,
**
kwargs
)
->
None
:
def
push
(
chain
:
Callable
,
data
:
Any
,
*
args
,
**
kwargs
)
->
Any
:
data
=
fastcopy
.
copy
(
data
)
return
chain
(
data
,
*
args
,
**
kwargs
)
...
...
ding/worker/buffer/middleware/priority.py
浏览文件 @
9c67db8b
...
...
@@ -16,6 +16,7 @@ class PriorityExperienceReplay:
IS_weight_anneal_train_iter
:
int
=
int
(
1e5
)
)
->
None
:
self
.
buffer
=
buffer
self
.
buffer_idx
=
{}
self
.
buffer_size
=
buffer_size
self
.
IS_weight
=
IS_weight
self
.
priority_power_factor
=
priority_power_factor
...
...
@@ -32,7 +33,7 @@ class PriorityExperienceReplay:
self
.
delta_anneal
=
(
1
-
self
.
IS_weight_power_factor
)
/
self
.
IS_weight_anneal_train_iter
self
.
pivot
=
0
def
push
(
self
,
chain
:
Callable
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
,
*
args
,
**
kwargs
)
->
None
:
def
push
(
self
,
chain
:
Callable
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
,
*
args
,
**
kwargs
)
->
Any
:
if
meta
is
None
:
meta
=
{
'priority'
:
self
.
max_priority
}
else
:
...
...
@@ -40,8 +41,10 @@ class PriorityExperienceReplay:
meta
[
'priority'
]
=
self
.
max_priority
meta
[
'priority_idx'
]
=
self
.
pivot
self
.
_update_tree
(
meta
[
'priority'
],
self
.
pivot
)
index
=
chain
(
data
,
meta
=
meta
,
*
args
,
**
kwargs
)
self
.
buffer_idx
[
self
.
pivot
]
=
index
self
.
pivot
=
(
self
.
pivot
+
1
)
%
self
.
buffer_size
return
chain
(
data
,
meta
=
meta
,
*
args
,
**
kwargs
)
return
index
def
sample
(
self
,
chain
:
Callable
,
size
:
int
,
*
args
,
**
kwargs
)
->
List
[
Any
]:
# Divide [0, 1) into size intervals on average
...
...
@@ -51,8 +54,9 @@ class PriorityExperienceReplay:
# Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree)
mass
*=
self
.
sum_tree
.
reduce
()
indices
=
[
self
.
sum_tree
.
find_prefixsum_idx
(
m
)
for
m
in
mass
]
# TODO sample with indices
data
=
chain
(
size
,
*
args
,
**
kwargs
)
indices
=
[
self
.
buffer_idx
[
i
]
for
i
in
indices
]
# sample with indices
data
=
chain
(
indices
=
indices
,
*
args
,
**
kwargs
)
if
self
.
IS_weight
:
# Calculate max weight for normalizing IS
sum_tree_root
=
self
.
sum_tree
.
reduce
()
...
...
@@ -83,6 +87,7 @@ class PriorityExperienceReplay:
priority_idx
=
meta
[
'priority_idx'
]
self
.
sum_tree
[
priority_idx
]
=
self
.
sum_tree
.
neutral_element
self
.
min_tree
[
priority_idx
]
=
self
.
min_tree
.
neutral_element
self
.
buffer_idx
.
pop
(
priority_idx
)
return
chain
(
index
,
*
args
,
**
kwargs
)
def
clear
(
self
,
chain
:
Callable
)
->
None
:
...
...
@@ -91,6 +96,7 @@ class PriorityExperienceReplay:
self
.
sum_tree
=
SumSegmentTree
(
capacity
)
if
self
.
IS_weight
:
self
.
min_tree
=
MinSegmentTree
(
capacity
)
self
.
buffer_idx
=
{}
self
.
pivot
=
0
chain
()
...
...
@@ -105,14 +111,15 @@ class PriorityExperienceReplay:
'IS_weight_power_factor'
:
self
.
IS_weight_power_factor
,
'sumtree'
:
self
.
sumtree
,
'mintree'
:
self
.
mintree
,
'buffer_idx'
:
self
.
buffer_idx
,
}
def
load_state_dict
(
self
,
_state_dict
:
Dict
,
deepcopy
:
bool
=
False
)
->
None
:
for
k
,
v
in
_state_dict
.
items
():
if
deepcopy
:
setattr
(
self
,
'
_
{}'
.
format
(
k
),
copy
.
deepcopy
(
v
))
setattr
(
self
,
'{}'
.
format
(
k
),
copy
.
deepcopy
(
v
))
else
:
setattr
(
self
,
'
_
{}'
.
format
(
k
),
v
)
setattr
(
self
,
'{}'
.
format
(
k
),
v
)
def
priority
(
*
per_args
,
**
per_kwargs
):
...
...
ding/worker/buffer/middleware/staleness_check.py
浏览文件 @
9c67db8b
...
...
@@ -9,7 +9,7 @@ def staleness_check(buffer_: 'Buffer', max_staleness: int = float("inf")) -> Cal
If data's staleness is greater(>) than max_staleness, this data will be removed from buffer as soon as possible.
"""
def
push
(
next
:
Callable
,
data
:
Any
,
*
args
,
**
kwargs
)
->
None
:
def
push
(
next
:
Callable
,
data
:
Any
,
*
args
,
**
kwargs
)
->
Any
:
assert
'meta'
in
kwargs
and
'train_iter_data_collected'
in
kwargs
[
'meta'
],
"staleness_check middleware must push data with meta={'train_iter_data_collected': <iter>}"
return
next
(
data
,
*
args
,
**
kwargs
)
...
...
ding/worker/buffer/middleware/use_time_check.py
浏览文件 @
9c67db8b
...
...
@@ -8,7 +8,7 @@ def use_time_check(buffer_: 'Buffer', max_use: int = float("inf")) -> Callable:
greater than or equal to max_use, this data will be removed from buffer as soon as possible.
"""
def
push
(
chain
:
Callable
,
data
:
Any
,
meta
:
dict
=
None
,
*
args
,
**
kwargs
)
->
None
:
def
push
(
chain
:
Callable
,
data
:
Any
,
meta
:
dict
=
None
,
*
args
,
**
kwargs
)
->
Any
:
if
meta
:
meta
[
"use_count"
]
=
0
else
:
...
...
ding/worker/buffer/tests/test_buffer.py
浏览文件 @
9c67db8b
...
...
@@ -76,8 +76,8 @@ def test_naive_push_sample():
buffer
.
clear
()
for
i
in
range
(
10
):
buffer
.
push
(
i
)
assert
len
(
buffer
.
sample
(
5
,
range
=
slice
(
5
,
10
)))
==
5
assert
0
not
in
[
item
.
data
for
item
in
buffer
.
sample
(
5
,
range
=
slice
(
5
,
10
))]
assert
len
(
buffer
.
sample
(
5
,
sample_
range
=
slice
(
5
,
10
)))
==
5
assert
0
not
in
[
item
.
data
for
item
in
buffer
.
sample
(
5
,
sample_
range
=
slice
(
5
,
10
))]
@
pytest
.
mark
.
unittest
...
...
@@ -162,6 +162,6 @@ def test_ignore_insufficient():
buffer
.
push
(
i
)
with
pytest
.
raises
(
ValueError
):
buffer
.
sample
(
3
)
buffer
.
sample
(
3
,
ignore_insufficient
=
False
)
data
=
buffer
.
sample
(
3
,
ignore_insufficient
=
True
)
assert
len
(
data
)
==
0
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录