Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
3d698d0f
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 接近 3 年
通知
67
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3d698d0f
编写于
11月 08, 2021
作者:
X
Xu Jingxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix sample with indices, ensure return size is equal to input size or indices size
上级
1572fd3e
变更
6
显示空白变更内容
内联
并排
Showing
6 changed files
with
84 additions
and
83 deletions
+84
-83
ding/worker/buffer/buffer.py
ding/worker/buffer/buffer.py
+2
-2
ding/worker/buffer/deque_buffer.py
ding/worker/buffer/deque_buffer.py
+12
-20
ding/worker/buffer/middleware/clone_object.py
ding/worker/buffer/middleware/clone_object.py
+3
-56
ding/worker/buffer/middleware/priority.py
ding/worker/buffer/middleware/priority.py
+8
-5
ding/worker/buffer/utils/__init__.py
ding/worker/buffer/utils/__init__.py
+1
-0
ding/worker/buffer/utils/fast_copy.py
ding/worker/buffer/utils/fast_copy.py
+58
-0
未找到文件。
ding/worker/buffer/buffer.py
浏览文件 @
3d698d0f
...
...
@@ -52,7 +52,7 @@ class Buffer:
self
.
middleware
=
[]
@
abstractmethod
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
Any
:
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
BufferedData
:
"""
Overview:
Push data and it's meta information in buffer.
...
...
@@ -60,7 +60,7 @@ class Buffer:
- data (:obj:`Any`): The data which will be pushed into buffer.
- meta (:obj:`dict`): Meta information, e.g. priority, count, staleness.
Returns:
-
index (:obj:`Any`): The index of
pushed data.
-
buffered_data (:obj:`BufferedData`): The
pushed data.
"""
raise
NotImplementedError
...
...
ding/worker/buffer/deque_buffer.py
浏览文件 @
3d698d0f
from
typing
import
Any
,
Iterable
,
List
,
Optional
,
Union
from
collections
import
deque
from
ding.worker.buffer
import
Buffer
,
apply_middleware
,
BufferedData
import
itertools
import
random
import
uuid
import
logging
from
ding.worker.buffer
import
Buffer
,
apply_middleware
,
BufferedData
class
DequeBuffer
(
Buffer
):
...
...
@@ -14,10 +14,11 @@ class DequeBuffer(Buffer):
self
.
storage
=
deque
(
maxlen
=
size
)
@
apply_middleware
(
"push"
)
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
str
:
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
BufferedData
:
index
=
uuid
.
uuid1
().
hex
self
.
storage
.
append
(
BufferedData
(
data
=
data
,
index
=
index
,
meta
=
meta
))
return
index
buffered
=
BufferedData
(
data
=
data
,
index
=
index
,
meta
=
meta
)
self
.
storage
.
append
(
buffered
)
return
buffered
@
apply_middleware
(
"sample"
)
def
sample
(
...
...
@@ -40,18 +41,12 @@ class DequeBuffer(Buffer):
value_error
=
None
sampled_data
=
[]
if
indices
:
sampled_data
=
list
(
filter
(
lambda
item
:
item
.
index
in
indices
,
self
.
storage
))
# for the same indices
if
len
(
indices
)
!=
len
(
set
(
indices
)):
sampled_data_no_same
=
sampled_data
sampled_data
=
[
sampled_data_no_same
[
0
]]
j
=
0
for
i
in
range
(
1
,
len
(
indices
)):
if
indices
[
i
-
1
]
==
indices
[
i
]:
sampled_data
.
append
(
copy
.
deepcopy
(
sampled_data_no_same
[
j
]))
else
:
sampled_data
.
append
(
sampled_data_no_same
[
j
])
j
+=
1
indices_set
=
set
(
indices
)
hashed_data
=
filter
(
lambda
item
:
item
.
index
in
indices_set
,
self
.
storage
)
hashed_data
=
map
(
lambda
item
:
(
item
.
index
,
item
),
hashed_data
)
hashed_data
=
dict
(
hashed_data
)
# Re-sample and return in indices order
sampled_data
=
[
hashed_data
[
index
]
for
index
in
indices
]
else
:
if
replace
:
sampled_data
=
random
.
choices
(
storage
,
k
=
size
)
...
...
@@ -67,9 +62,6 @@ class DequeBuffer(Buffer):
"Sample operation is ignored due to data insufficient, current buffer count is {} while sample size is {}"
.
format
(
self
.
count
(),
size
)
)
else
:
if
value_error
:
raise
ValueError
(
"Some errors in sample operation"
)
from
value_error
else
:
raise
ValueError
(
"There are less than {} data in buffer({})"
.
format
(
size
,
self
.
count
()))
...
...
ding/worker/buffer/middleware/clone_object.py
浏览文件 @
3d698d0f
from
typing
import
Callable
,
Any
,
List
from
ding.worker.buffer
import
BufferedData
import
torch
import
numpy
as
np
class FastCopy:
    """
    Overview:
        A fast, type-dispatched replacement for ``copy.deepcopy`` on buffer data.
        The idea of this class comes from this article:
        https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
        We use recursive calls to copy each object that needs to be copied, which will be 5x faster
        than copy.deepcopy.
        Note: dispatch is keyed on the *exact* type (``type(x)``), so subclasses of the
        registered types fall through to the identity branch of ``copy``.
    """

    def __init__(self):
        # Registry mapping an exact type to its specialized copy routine.
        dispatch = {}
        dispatch[list] = self._copy_list
        dispatch[dict] = self._copy_dict
        dispatch[torch.Tensor] = self._copy_tensor
        dispatch[np.ndarray] = self._copy_ndarray
        dispatch[BufferedData] = self._copy_buffereddata
        self.dispatch = dispatch

    def _copy_list(self, src: List) -> List:
        """Shallow-copy ``src``, then recursively copy elements of registered types."""
        # Bugfix: the return annotation was ``dict`` although a list is returned.
        ret = src.copy()
        for idx, item in enumerate(ret):
            cp = self.dispatch.get(type(item))
            if cp is not None:
                ret[idx] = cp(item)
        return ret

    def _copy_dict(self, src: dict) -> dict:
        """Shallow-copy ``src``, then recursively copy values of registered types."""
        ret = src.copy()
        for key, value in ret.items():
            cp = self.dispatch.get(type(value))
            if cp is not None:
                ret[key] = cp(value)
        return ret

    def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
        """Return a copy of the tensor's data via ``Tensor.clone``."""
        return t.clone()

    def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
        """Return a copy of the array's data via ``np.copy``."""
        return np.copy(a)

    def _copy_buffereddata(self, d: BufferedData) -> BufferedData:
        """Recursively copy ``data`` and ``meta``; ``index`` is shared as-is."""
        return BufferedData(data=self.copy(d.data), index=d.index, meta=self.copy(d.meta))

    def copy(self, sth: Any) -> Any:
        """
        Overview:
            Copy an arbitrary object. Objects whose exact type has no registered
            copier (e.g. int, str, None) are returned unchanged.
        """
        cp = self.dispatch.get(type(sth))
        if cp is None:
            return sth
        return cp(sth)
from
ding.worker.buffer.utils
import
fastcopy
def
clone_object
():
...
...
@@ -61,13 +9,12 @@ def clone_object():
try this middleware when you need to keep the object unchanged in buffer, and modify
the object after sampling it (usually in multiple threads)
"""
fastcopy
=
FastCopy
()
def
push
(
chain
:
Callable
,
data
:
Any
,
*
args
,
**
kwargs
)
->
Any
:
def
push
(
chain
:
Callable
,
data
:
Any
,
*
args
,
**
kwargs
)
->
BufferedData
:
data
=
fastcopy
.
copy
(
data
)
return
chain
(
data
,
*
args
,
**
kwargs
)
def
sample
(
chain
:
Callable
,
*
args
,
**
kwargs
)
->
List
[
Any
]:
def
sample
(
chain
:
Callable
,
*
args
,
**
kwargs
)
->
List
[
BufferedData
]:
data
=
chain
(
*
args
,
**
kwargs
)
return
fastcopy
.
copy
(
data
)
...
...
ding/worker/buffer/middleware/priority.py
浏览文件 @
3d698d0f
from
collections
import
defaultdict
from
typing
import
Callable
,
Any
,
List
,
Dict
,
Optional
import
copy
import
numpy
as
np
from
ding.utils
import
SumSegmentTree
,
MinSegmentTree
from
ding.worker.buffer.buffer
import
BufferedData
class
PriorityExperienceReplay
:
...
...
@@ -33,7 +35,7 @@ class PriorityExperienceReplay:
self
.
delta_anneal
=
(
1
-
self
.
IS_weight_power_factor
)
/
self
.
IS_weight_anneal_train_iter
self
.
pivot
=
0
def
push
(
self
,
chain
:
Callable
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
,
*
args
,
**
kwargs
)
->
Any
:
def
push
(
self
,
chain
:
Callable
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
,
*
args
,
**
kwargs
)
->
BufferedData
:
if
meta
is
None
:
meta
=
{
'priority'
:
self
.
max_priority
}
else
:
...
...
@@ -41,12 +43,13 @@ class PriorityExperienceReplay:
meta
[
'priority'
]
=
self
.
max_priority
meta
[
'priority_idx'
]
=
self
.
pivot
self
.
_update_tree
(
meta
[
'priority'
],
self
.
pivot
)
index
=
chain
(
data
,
meta
=
meta
,
*
args
,
**
kwargs
)
buffered
=
chain
(
data
,
meta
=
meta
,
*
args
,
**
kwargs
)
index
=
buffered
.
index
self
.
buffer_idx
[
self
.
pivot
]
=
index
self
.
pivot
=
(
self
.
pivot
+
1
)
%
self
.
buffer_size
return
index
return
buffered
def
sample
(
self
,
chain
:
Callable
,
size
:
int
,
*
args
,
**
kwargs
)
->
List
[
Any
]:
def
sample
(
self
,
chain
:
Callable
,
size
:
int
,
*
args
,
**
kwargs
)
->
List
[
BufferedData
]:
# Divide [0, 1) into size intervals on average
intervals
=
np
.
array
([
i
*
1.0
/
size
for
i
in
range
(
size
)])
# Uniformly sample within each interval
...
...
@@ -55,7 +58,7 @@ class PriorityExperienceReplay:
mass
*=
self
.
sum_tree
.
reduce
()
indices
=
[
self
.
sum_tree
.
find_prefixsum_idx
(
m
)
for
m
in
mass
]
indices
=
[
self
.
buffer_idx
[
i
]
for
i
in
indices
]
#
s
ample with indices
#
S
ample with indices
data
=
chain
(
indices
=
indices
,
*
args
,
**
kwargs
)
if
self
.
IS_weight
:
# Calculate max weight for normalizing IS
...
...
ding/worker/buffer/utils/__init__.py
0 → 100644
浏览文件 @
3d698d0f
from
.fast_copy
import
FastCopy
,
fastcopy
ding/worker/buffer/utils/fast_copy.py
0 → 100644
浏览文件 @
3d698d0f
import
torch
import
numpy
as
np
from
typing
import
Any
,
List
from
ding.worker.buffer.buffer
import
BufferedData
class FastCopy:
    """
    Overview:
        A fast, type-dispatched replacement for ``copy.deepcopy`` on buffer data.
        The idea of this class comes from this article:
        https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
        We use recursive calls to copy each object that needs to be copied, which will be 5x faster
        than copy.deepcopy.
        Note: dispatch is keyed on the *exact* type (``type(x)``), so subclasses of the
        registered types fall through to the identity branch of ``copy``.
    """

    def __init__(self):
        # Registry mapping an exact type to its specialized copy routine.
        dispatch = {}
        dispatch[list] = self._copy_list
        dispatch[dict] = self._copy_dict
        dispatch[torch.Tensor] = self._copy_tensor
        dispatch[np.ndarray] = self._copy_ndarray
        dispatch[BufferedData] = self._copy_buffereddata
        self.dispatch = dispatch

    def _copy_list(self, src: List) -> List:
        """Shallow-copy ``src``, then recursively copy elements of registered types."""
        # Bugfix: the return annotation was ``dict`` although a list is returned.
        ret = src.copy()
        for idx, item in enumerate(ret):
            cp = self.dispatch.get(type(item))
            if cp is not None:
                ret[idx] = cp(item)
        return ret

    def _copy_dict(self, src: dict) -> dict:
        """Shallow-copy ``src``, then recursively copy values of registered types."""
        ret = src.copy()
        for key, value in ret.items():
            cp = self.dispatch.get(type(value))
            if cp is not None:
                ret[key] = cp(value)
        return ret

    def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
        """Return a copy of the tensor's data via ``Tensor.clone``."""
        return t.clone()

    def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
        """Return a copy of the array's data via ``np.copy``."""
        return np.copy(a)

    def _copy_buffereddata(self, d: BufferedData) -> BufferedData:
        """Recursively copy ``data`` and ``meta``; ``index`` is shared as-is."""
        return BufferedData(data=self.copy(d.data), index=d.index, meta=self.copy(d.meta))

    def copy(self, sth: Any) -> Any:
        """
        Overview:
            Copy an arbitrary object. Objects whose exact type has no registered
            copier (e.g. int, str, None) are returned unchanged.
        """
        cp = self.dispatch.get(type(sth))
        if cp is None:
            return sth
        return cp(sth)
# Shared module-level instance, re-exported via the package __init__.
fastcopy = FastCopy()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录