Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
7e9e4e88
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
61
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7e9e4e88
编写于
8月 31, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feature(nyz): add sample_range arg in replay buffer
上级
608fee41
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
52 additions
and
12 deletions
+52
-12
ding/worker/replay_buffer/advanced_buffer.py
ding/worker/replay_buffer/advanced_buffer.py
+23
-6
ding/worker/replay_buffer/naive_buffer.py
ding/worker/replay_buffer/naive_buffer.py
+13
-5
ding/worker/replay_buffer/tests/test_advanced_buffer.py
ding/worker/replay_buffer/tests/test_advanced_buffer.py
+7
-0
ding/worker/replay_buffer/tests/test_naive_buffer.py
ding/worker/replay_buffer/tests/test_naive_buffer.py
+9
-1
未找到文件。
ding/worker/replay_buffer/advanced_buffer.py
浏览文件 @
7e9e4e88
...
...
@@ -10,6 +10,13 @@ from ding.utils.autolog import TickTime
from
.utils
import
UsedDataRemover
,
generate_id
,
SampledDataAttrMonitor
,
PeriodicThruputMonitor
,
ThruputController
def to_positive_index(idx: Union[int, None], size: int) -> Union[int, None]:
    """
    Overview:
        Convert a Python-style index, which may be negative (counted from the end), into a
        non-negative index relative to a container of length ``size``.
    Arguments:
        - idx (:obj:`Union[int, None]`): The index to convert. ``None`` (an open slice bound) and \
            non-negative values are returned unchanged.
        - size (:obj:`int`): Length of the container the index refers to.
    Returns:
        - index (:obj:`Union[int, None]`): ``None`` if ``idx`` is ``None``; otherwise a non-negative index.
    """
    if idx is None or idx >= 0:
        return idx
    # A negative index reaching before the start is clamped to 0, matching Python slice
    # semantics (e.g. ``list(range(5))[-100:]`` starts at element 0), so the result is never negative.
    return max(size + idx, 0)
@
BUFFER_REGISTRY
.
register
(
'advanced'
)
class
AdvancedReplayBuffer
(
IBuffer
):
r
"""
...
...
@@ -206,13 +213,15 @@ class AdvancedReplayBuffer(IBuffer):
if
self
.
_enable_track_used_data
:
self
.
_used_data_remover
.
close
()
def
sample
(
self
,
size
:
int
,
cur_learner_iter
:
int
)
->
Optional
[
list
]:
r
"""
def
sample
(
self
,
size
:
int
,
cur_learner_iter
:
int
,
sample_range
:
slice
=
None
)
->
Optional
[
list
]:
"""
Overview:
Sample data with length ``size``.
Arguments:
- size (:obj:`int`): The number of the data that will be sampled.
- cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
- sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which
\
means only sample among the last 10 data
Returns:
- sample_data (:obj:`list`): A list of data with length ``size``
ReturnsKeys:
...
...
@@ -235,7 +244,7 @@ class AdvancedReplayBuffer(IBuffer):
)
return
None
with
self
.
_lock
:
indices
=
self
.
_get_indices
(
size
)
indices
=
self
.
_get_indices
(
size
,
sample_range
)
result
=
self
.
_sample_with_indices
(
indices
,
cur_learner_iter
)
# Deepcopy ``result``'s same indice datas in case ``self._get_indices`` may get datas with
# the same indices, i.e. the same datas would be sampled afterwards.
...
...
@@ -498,7 +507,7 @@ class AdvancedReplayBuffer(IBuffer):
# only the data passes all the check functions, would the check return True
return
all
([
fn
(
d
)
for
fn
in
self
.
check_list
])
def
_get_indices
(
self
,
size
:
int
)
->
list
:
def
_get_indices
(
self
,
size
:
int
,
sample_range
:
slice
=
None
)
->
list
:
r
"""
Overview:
Get the sample index list according to the priority probability.
...
...
@@ -511,8 +520,16 @@ class AdvancedReplayBuffer(IBuffer):
intervals
=
np
.
array
([
i
*
1.0
/
size
for
i
in
range
(
size
)])
# Uniformly sample within each interval
mass
=
intervals
+
np
.
random
.
uniform
(
size
=
(
size
,
))
*
1.
/
size
# Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree)
mass
*=
self
.
_sum_tree
.
reduce
()
if
sample_range
is
None
:
# Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree)
mass
*=
self
.
_sum_tree
.
reduce
()
else
:
# Rescale to [a, b)
start
=
to_positive_index
(
sample_range
.
start
,
self
.
_replay_buffer_size
)
end
=
to_positive_index
(
sample_range
.
stop
,
self
.
_replay_buffer_size
)
a
=
self
.
_sum_tree
.
reduce
(
0
,
start
)
b
=
self
.
_sum_tree
.
reduce
(
0
,
end
)
mass
=
mass
*
(
b
-
a
)
+
a
# Find prefix sum index to sample with probability
return
[
self
.
_sum_tree
.
find_prefixsum_idx
(
m
)
for
m
in
mass
]
...
...
ding/worker/replay_buffer/naive_buffer.py
浏览文件 @
7e9e4e88
...
...
@@ -95,14 +95,16 @@ class NaiveReplayBuffer(IBuffer):
else
:
self
.
_append
(
data
,
cur_collector_envstep
)
def
sample
(
self
,
size
:
int
,
cur_learner_iter
:
int
)
->
Optional
[
list
]:
r
"""
def
sample
(
self
,
size
:
int
,
cur_learner_iter
:
int
,
sample_range
:
slice
=
None
)
->
Optional
[
list
]:
"""
Overview:
Sample data with length ``size``.
Arguments:
- size (:obj:`int`): The number of the data that will be sampled.
- cur_learner_iter (:obj:`int`): Learner's current iteration.
\
Not used in naive buffer, but preserved for compatibility.
- sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which
\
means only sample among the last 10 data
Returns:
- sample_data (:obj:`list`): A list of data with length ``size``.
"""
...
...
@@ -112,7 +114,7 @@ class NaiveReplayBuffer(IBuffer):
if
not
can_sample
:
return
None
with
self
.
_lock
:
indices
=
self
.
_get_indices
(
size
)
indices
=
self
.
_get_indices
(
size
,
sample_range
)
result
=
self
.
_sample_with_indices
(
indices
,
cur_learner_iter
)
return
result
...
...
@@ -234,12 +236,14 @@ class NaiveReplayBuffer(IBuffer):
"""
self
.
close
()
def
_get_indices
(
self
,
size
:
int
)
->
list
:
def
_get_indices
(
self
,
size
:
int
,
sample_range
:
slice
=
None
)
->
list
:
r
"""
Overview:
Get the sample index list.
Arguments:
- size (:obj:`int`): The number of the data that will be sampled
- sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
means only sample among the last 10 data
Returns:
- index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
"""
...
...
@@ -248,7 +252,11 @@ class NaiveReplayBuffer(IBuffer):
tail
=
self
.
_replay_buffer_size
else
:
tail
=
self
.
_tail
indices
=
list
(
np
.
random
.
choice
(
a
=
tail
,
size
=
size
,
replace
=
False
))
if
sample_range
is
None
:
indices
=
list
(
np
.
random
.
choice
(
a
=
tail
,
size
=
size
,
replace
=
False
))
else
:
indices
=
list
(
range
(
tail
))[
sample_range
]
indices
=
list
(
np
.
random
.
choice
(
indices
,
size
=
size
,
replace
=
False
))
return
indices
def
_sample_with_indices
(
self
,
indices
:
List
[
int
],
cur_learner_iter
:
int
)
->
list
:
...
...
ding/worker/replay_buffer/tests/test_advanced_buffer.py
浏览文件 @
7e9e4e88
...
...
@@ -129,6 +129,13 @@ class TestAdvancedBuffer:
if
v
>
advanced_buffer
.
_max_use
:
assert
advanced_buffer
.
_data
[
k
]
is
None
for
_
in
range
(
64
):
data
=
generate_data
()
data
[
'priority'
]
=
None
advanced_buffer
.
push
(
data
,
0
)
batch
=
advanced_buffer
.
sample
(
10
,
0
,
sample_range
=
slice
(
-
20
,
-
2
))
assert
len
(
batch
)
==
10
def
test_head_tail
(
self
):
buffer_cfg
=
deep_merge_dicts
(
AdvancedReplayBuffer
.
default_config
(),
EasyDict
(
dict
(
replay_buffer_size
=
64
,
max_use
=
4
))
...
...
ding/worker/replay_buffer/tests/test_naive_buffer.py
浏览文件 @
7e9e4e88
...
...
@@ -43,7 +43,15 @@ class TestNaiveBuffer:
for
_
in
range
(
64
):
naive_buffer
.
push
(
generate_data
(),
0
)
batch
=
naive_buffer
.
sample
(
32
,
0
)
assert
(
len
(
batch
)
==
32
)
assert
len
(
batch
)
==
32
last_one_batch
=
naive_buffer
.
sample
(
1
,
0
,
sample_range
=
slice
(
-
1
,
None
))
assert
len
(
last_one_batch
)
==
1
assert
last_one_batch
[
0
]
==
naive_buffer
.
_data
[
-
1
]
batch
=
naive_buffer
.
sample
(
5
,
0
,
sample_range
=
slice
(
-
10
,
-
2
))
sample_range_data
=
naive_buffer
.
_data
[
-
10
:
-
2
]
assert
len
(
batch
)
==
5
for
b
in
batch
:
assert
any
([
b
[
'data_id'
]
==
d
[
'data_id'
]
for
d
in
sample_range_data
])
# test clear
naive_buffer
.
clear
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录