Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
638f5110
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 接近 3 年
通知
65
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
638f5110
编写于
11月 02, 2021
作者:
N
niuyazhe
提交者:
Xu Jingxin
11月 05, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feature(nyz): add naive priority experience replay
上级
d066372a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
153 additions
and
1 deletion
+153
-1
ding/worker/buffer/middleware/__init__.py
ding/worker/buffer/middleware/__init__.py
+1
-0
ding/worker/buffer/middleware/priority.py
ding/worker/buffer/middleware/priority.py
+125
-0
ding/worker/buffer/tests/test_middleware.py
ding/worker/buffer/tests/test_middleware.py
+27
-1
未找到文件。
ding/worker/buffer/middleware/__init__.py
浏览文件 @
638f5110
from
.clone_object
import
clone_object
from
.use_time_check
import
use_time_check
from
.staleness_check
import
staleness_check
from
.priority
import
priority
ding/worker/buffer/middleware/priority.py
0 → 100644
浏览文件 @
638f5110
from
typing
import
Callable
,
Any
,
List
,
Dict
,
Optional
import
copy
import
numpy
as
np
from
ding.utils
import
SumSegmentTree
,
MinSegmentTree
class PriorityExperienceReplay:
    """
    Overview:
        Buffer middleware implementing naive prioritized experience replay (PER).
        It mirrors the wrapped buffer with a sum segment tree (and, when
        importance-sampling correction is enabled, a min segment tree) so that
        items can be drawn with probability proportional to their priority.
    """

    def __init__(
            self,
            buffer: 'Buffer',  # noqa
            buffer_size: int,
            IS_weight: bool = True,
            priority_power_factor: float = 0.6,
            IS_weight_power_factor: float = 0.4,
            IS_weight_anneal_train_iter: int = int(1e5)
    ) -> None:
        """
        Arguments:
            - buffer: the wrapped buffer object whose contents this middleware tracks.
            - buffer_size: maximum number of tracked items; priority slots wrap at this size.
            - IS_weight: whether to compute importance-sampling correction weights on sample.
            - priority_power_factor: alpha exponent applied to raw priorities.
            - IS_weight_power_factor: beta exponent for IS weights, annealed towards 1.
            - IS_weight_anneal_train_iter: number of sample calls over which beta anneals to 1.
        """
        self.buffer = buffer
        self.buffer_size = buffer_size
        self.IS_weight = IS_weight
        self.priority_power_factor = priority_power_factor
        self.IS_weight_power_factor = IS_weight_power_factor
        self.IS_weight_anneal_train_iter = IS_weight_anneal_train_iter
        # Max priority till now; used to initialize data's priority if
        # "priority" is not passed in with the data.
        self.max_priority = 1.0
        # Segment-tree capacity needs to be a power of 2 (>= buffer_size).
        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
        self.sum_tree = SumSegmentTree(capacity)
        if self.IS_weight:
            # min_tree is only created when IS correction is enabled.
            self.min_tree = MinSegmentTree(capacity)
            # Per-sample-call increment annealing beta from its initial value to 1.
            self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
        # Next free priority slot; wraps around at buffer_size.
        self.pivot = 0

    def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> None:
        """Assign a priority slot to the incoming data (defaulting to the max
        priority seen so far) and forward the push down the middleware chain."""
        if meta is None:
            meta = {'priority': self.max_priority}
        elif 'priority' not in meta:
            meta['priority'] = self.max_priority
        meta['priority_idx'] = self.pivot
        self._update_tree(meta['priority'], self.pivot)
        self.pivot = (self.pivot + 1) % self.buffer_size
        return chain(data, meta=meta, *args, **kwargs)

    def sample(self, chain: Callable, size: int, *args, **kwargs) -> List[Any]:
        """Sample ``size`` items via stratified prefix-sum lookup and, when
        enabled, attach normalized IS weights under ``meta['priority_IS']``."""
        # Divide [0, 1) into `size` intervals on average.
        intervals = np.array([i * 1.0 / size for i in range(size)])
        # Uniformly sample within each interval (stratified sampling).
        mass = intervals + np.random.uniform(size=(size, )) * 1. / size
        # Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree).
        mass *= self.sum_tree.reduce()
        indices = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
        # TODO sample with indices (the prioritized indices above are computed
        # but not yet used by the underlying buffer sample).
        data = chain(size, return_index=True, return_meta=True, *args, **kwargs)
        if self.IS_weight:
            # Calculate max weight for normalizing IS.
            sum_tree_root = self.sum_tree.reduce()
            p_min = self.min_tree.reduce() / sum_tree_root
            buffer_count = self.buffer.count()
            max_weight = (buffer_count * p_min) ** (-self.IS_weight_power_factor)
            for i in range(len(data)):
                meta = data[i][-1]
                priority_idx = meta['priority_idx']
                p_sample = self.sum_tree[priority_idx] / sum_tree_root
                weight = (buffer_count * p_sample) ** (-self.IS_weight_power_factor)
                meta['priority_IS'] = weight / max_weight
            # Anneal beta towards 1 so the IS correction becomes exact over training.
            self.IS_weight_power_factor = min(1.0, self.IS_weight_power_factor + self.delta_anneal)
        return data

    def update(self, chain: Callable, index: str, data: Any, meta: dict, *args, **kwargs) -> None:
        """Forward the update, then refresh the item's priority in the tree(s)
        and track the running max priority."""
        update_flag = chain(index, data, meta, *args, **kwargs)
        if update_flag:  # when update succeeds
            new_priority, idx = meta['priority'], meta['priority_idx']
            assert new_priority >= 0, "new_priority should greater than 0, but found {}".format(new_priority)
            new_priority += 1e-5  # Add epsilon to avoid priority == 0
            self._update_tree(new_priority, idx)
            self.max_priority = max(self.max_priority, new_priority)

    def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
        """Neutralize the deleted item's priority slot, then forward the delete."""
        # Fixed: the original neutralized EVERY stored item's slot, wiping the
        # priorities of all remaining data on a single delete. Only the entry
        # matching `index` is reset here.
        # NOTE(review): assumes storage yields (data, index, meta) triples, as
        # the unpacking in the original code and the tests suggest — confirm.
        for (_, item_index, meta) in self.buffer.storage:
            if item_index != index:
                continue
            priority_idx = meta['priority_idx']
            self.sum_tree[priority_idx] = self.sum_tree.neutral_element
            if self.IS_weight:
                # min_tree only exists when IS weights are enabled.
                self.min_tree[priority_idx] = self.min_tree.neutral_element
            break
        return chain(index, *args, **kwargs)

    def clear(self, chain: Callable) -> None:
        """Reset all priority state (trees, pivot, max priority) and forward the clear."""
        self.max_priority = 1.0
        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
        self.sum_tree = SumSegmentTree(capacity)
        if self.IS_weight:
            self.min_tree = MinSegmentTree(capacity)
        self.pivot = 0
        chain()

    def _update_tree(self, priority: float, idx: int) -> None:
        """Write priority ** alpha into the segment tree(s) at slot ``idx``."""
        weight = priority ** self.priority_power_factor
        self.sum_tree[idx] = weight
        if self.IS_weight:
            # Fixed: guard the min_tree write — the attribute does not exist
            # when IS_weight is False (original raised AttributeError).
            self.min_tree[idx] = weight

    def state_dict(self) -> Dict:
        """Serialize priority state. Keys 'sumtree'/'mintree' are kept as in the
        original interface, but now read the real attributes (the original
        accessed non-existent `self.sumtree`/`self.mintree` and crashed)."""
        state = {
            'max_priority': self.max_priority,
            'IS_weight_power_factor': self.IS_weight_power_factor,
            'sumtree': self.sum_tree,
        }
        if self.IS_weight:
            # Only present when IS correction is enabled.
            state['mintree'] = self.min_tree
        return state

    def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
        """Restore state produced by ``state_dict``.

        Fixed: the original wrote attributes named ``_<key>``, which never
        matched any real attribute, so a save/load round-trip had no effect.
        Serialized keys are mapped back to the actual attribute names.
        """
        key_map = {'sumtree': 'sum_tree', 'mintree': 'min_tree'}
        for k, v in _state_dict.items():
            setattr(self, key_map.get(k, k), copy.deepcopy(v) if deepcopy else v)
def priority(*per_args, **per_kwargs):
    """Middleware factory: wrap a buffer with prioritized experience replay.

    All positional/keyword arguments are forwarded to
    ``PriorityExperienceReplay``; the returned callable dispatches buffer
    actions to the shared PER instance.
    """
    per = PriorityExperienceReplay(*per_args, **per_kwargs)
    handled_actions = ("push", "sample", "update", "delete", "clear")

    def _priority(action: str, chain: Callable, *args, **kwargs) -> Any:
        if action not in handled_actions:
            # NOTE(review): the fall-through passes `chain` as the first
            # argument to itself, mirroring the original code — confirm this
            # matches the middleware protocol.
            return chain(chain, *args, **kwargs)
        handler = getattr(per, action)
        return handler(chain, *args, **kwargs)

    return _priority
ding/worker/buffer/tests/test_middleware.py
浏览文件 @
638f5110
import
pytest
import
torch
from
ding.worker.buffer
import
Buffer
,
DequeStorage
from
ding.worker.buffer.middleware
import
clone_object
,
use_time_check
,
staleness_check
from
ding.worker.buffer.middleware
import
clone_object
,
use_time_check
,
staleness_check
,
priority
@
pytest
.
mark
.
unittest
...
...
@@ -76,3 +76,29 @@ def test_staleness_check():
with
pytest
.
raises
(
ValueError
):
data
=
buffer
.
sample
(
size
=
N
,
replace
=
False
,
train_iter_sample_data
=
11
)
assert
buffer
.
count
()
==
2
@pytest.mark.unittest
def test_priority():
    """End-to-end check of the priority middleware: push (with and without
    explicit priority), sample with IS metadata, update, delete, clear."""
    half = 5
    total = half + half
    buffer = Buffer(DequeStorage(maxlen=10))
    buffer.use(priority(buffer, buffer_size=10, IS_weight=True))

    # Pushes without meta fall back to the current max priority.
    for _ in range(half):
        buffer.push(get_data())
    assert buffer.count() == half

    # Pushes with an explicit priority.
    for _ in range(half):
        buffer.push(get_data(), meta={'priority': 2.0})
    assert buffer.count() == total

    sampled = buffer.sample(size=total, replace=False)
    assert len(sampled) == total
    required_keys = {'priority', 'priority_idx', 'priority_IS'}
    for _, _, meta in sampled:
        assert required_keys.issubset(meta.keys())
        meta['priority'] = 3.0
    for item, index, meta in sampled:
        buffer.update(index, item, meta)

    picked = buffer.sample(size=1)
    assert picked[0][2]['priority'] == 3.0
    buffer.delete(picked[0][1])
    assert buffer.count() == total - 1

    buffer.clear()
    assert buffer.count() == 0
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录