Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
3d698d0f
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 接近 3 年
通知
67
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3d698d0f
编写于
11月 08, 2021
作者:
X
Xu Jingxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix sample with indices, ensure return size is equal to input size or indices size
上级
1572fd3e
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
84 addition
and
83 deletion
+84
-83
ding/worker/buffer/buffer.py
ding/worker/buffer/buffer.py
+2
-2
ding/worker/buffer/deque_buffer.py
ding/worker/buffer/deque_buffer.py
+12
-20
ding/worker/buffer/middleware/clone_object.py
ding/worker/buffer/middleware/clone_object.py
+3
-56
ding/worker/buffer/middleware/priority.py
ding/worker/buffer/middleware/priority.py
+8
-5
ding/worker/buffer/utils/__init__.py
ding/worker/buffer/utils/__init__.py
+1
-0
ding/worker/buffer/utils/fast_copy.py
ding/worker/buffer/utils/fast_copy.py
+58
-0
未找到文件。
ding/worker/buffer/buffer.py
浏览文件 @
3d698d0f
...
@@ -52,7 +52,7 @@ class Buffer:
...
@@ -52,7 +52,7 @@ class Buffer:
self
.
middleware
=
[]
self
.
middleware
=
[]
@
abstractmethod
@
abstractmethod
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
Any
:
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
BufferedData
:
"""
"""
Overview:
Overview:
Push data and it's meta information in buffer.
Push data and it's meta information in buffer.
...
@@ -60,7 +60,7 @@ class Buffer:
...
@@ -60,7 +60,7 @@ class Buffer:
- data (:obj:`Any`): The data which will be pushed into buffer.
- data (:obj:`Any`): The data which will be pushed into buffer.
- meta (:obj:`dict`): Meta information, e.g. priority, count, staleness.
- meta (:obj:`dict`): Meta information, e.g. priority, count, staleness.
Returns:
Returns:
-
index (:obj:`Any`): The index of
pushed data.
-
buffered_data (:obj:`BufferedData`): The
pushed data.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
...
ding/worker/buffer/deque_buffer.py
浏览文件 @
3d698d0f
from
typing
import
Any
,
Iterable
,
List
,
Optional
,
Union
from
typing
import
Any
,
Iterable
,
List
,
Optional
,
Union
from
collections
import
deque
from
collections
import
deque
from
ding.worker.buffer
import
Buffer
,
apply_middleware
,
BufferedData
import
itertools
import
itertools
import
random
import
random
import
uuid
import
uuid
import
logging
import
logging
from
ding.worker.buffer
import
Buffer
,
apply_middleware
,
BufferedData
class
DequeBuffer
(
Buffer
):
class
DequeBuffer
(
Buffer
):
...
@@ -14,10 +14,11 @@ class DequeBuffer(Buffer):
...
@@ -14,10 +14,11 @@ class DequeBuffer(Buffer):
self
.
storage
=
deque
(
maxlen
=
size
)
self
.
storage
=
deque
(
maxlen
=
size
)
@
apply_middleware
(
"push"
)
@
apply_middleware
(
"push"
)
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
str
:
def
push
(
self
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
)
->
BufferedData
:
index
=
uuid
.
uuid1
().
hex
index
=
uuid
.
uuid1
().
hex
self
.
storage
.
append
(
BufferedData
(
data
=
data
,
index
=
index
,
meta
=
meta
))
buffered
=
BufferedData
(
data
=
data
,
index
=
index
,
meta
=
meta
)
return
index
self
.
storage
.
append
(
buffered
)
return
buffered
@
apply_middleware
(
"sample"
)
@
apply_middleware
(
"sample"
)
def
sample
(
def
sample
(
...
@@ -40,18 +41,12 @@ class DequeBuffer(Buffer):
...
@@ -40,18 +41,12 @@ class DequeBuffer(Buffer):
value_error
=
None
value_error
=
None
sampled_data
=
[]
sampled_data
=
[]
if
indices
:
if
indices
:
sampled_data
=
list
(
filter
(
lambda
item
:
item
.
index
in
indices
,
self
.
storage
))
indices_set
=
set
(
indices
)
# for the same indices
hashed_data
=
filter
(
lambda
item
:
item
.
index
in
indices_set
,
self
.
storage
)
if
len
(
indices
)
!=
len
(
set
(
indices
)):
hashed_data
=
map
(
lambda
item
:
(
item
.
index
,
item
),
hashed_data
)
sampled_data_no_same
=
sampled_data
hashed_data
=
dict
(
hashed_data
)
sampled_data
=
[
sampled_data_no_same
[
0
]]
# Re-sample and return in indices order
j
=
0
sampled_data
=
[
hashed_data
[
index
]
for
index
in
indices
]
for
i
in
range
(
1
,
len
(
indices
)):
if
indices
[
i
-
1
]
==
indices
[
i
]:
sampled_data
.
append
(
copy
.
deepcopy
(
sampled_data_no_same
[
j
]))
else
:
sampled_data
.
append
(
sampled_data_no_same
[
j
])
j
+=
1
else
:
else
:
if
replace
:
if
replace
:
sampled_data
=
random
.
choices
(
storage
,
k
=
size
)
sampled_data
=
random
.
choices
(
storage
,
k
=
size
)
...
@@ -67,9 +62,6 @@ class DequeBuffer(Buffer):
...
@@ -67,9 +62,6 @@ class DequeBuffer(Buffer):
"Sample operation is ignored due to data insufficient, current buffer count is {} while sample size is {}"
"Sample operation is ignored due to data insufficient, current buffer count is {} while sample size is {}"
.
format
(
self
.
count
(),
size
)
.
format
(
self
.
count
(),
size
)
)
)
else
:
if
value_error
:
raise
ValueError
(
"Some errors in sample operation"
)
from
value_error
else
:
else
:
raise
ValueError
(
"There are less than {} data in buffer({})"
.
format
(
size
,
self
.
count
()))
raise
ValueError
(
"There are less than {} data in buffer({})"
.
format
(
size
,
self
.
count
()))
...
...
ding/worker/buffer/middleware/clone_object.py
浏览文件 @
3d698d0f
from
typing
import
Callable
,
Any
,
List
from
typing
import
Callable
,
Any
,
List
from
ding.worker.buffer
import
BufferedData
from
ding.worker.buffer
import
BufferedData
import
torch
from
ding.worker.buffer.utils
import
fastcopy
import
numpy
as
np
class FastCopy:
    """
    Fast, type-dispatched deep copy for buffer payloads.

    The idea of this class comes from this article
    https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
    We use recursive calls to copy each object that needs to be copied, which will be 5x faster
    than copy.deepcopy.
    """

    def __init__(self):
        # Map each supported type to its dedicated copy routine. Any type not
        # listed here is returned unchanged by `copy` (treated as shared).
        dispatch = {}
        dispatch[list] = self._copy_list
        dispatch[dict] = self._copy_dict
        dispatch[torch.Tensor] = self._copy_tensor
        dispatch[np.ndarray] = self._copy_ndarray
        dispatch[BufferedData] = self._copy_buffereddata
        self.dispatch = dispatch

    def _copy_list(self, l: List) -> List:
        # Bug fix: the return annotation was `-> dict`, but this method
        # returns a list. Shallow-copy first, then deep-copy known elements.
        ret = l.copy()
        for idx, item in enumerate(ret):
            cp = self.dispatch.get(type(item))
            if cp is not None:
                ret[idx] = cp(item)
        return ret

    def _copy_dict(self, d: dict) -> dict:
        # Shallow-copy first; replacing values of existing keys while
        # iterating is safe because no keys are added or removed.
        ret = d.copy()
        for key, value in ret.items():
            cp = self.dispatch.get(type(value))
            if cp is not None:
                ret[key] = cp(value)
        return ret

    def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
        return t.clone()

    def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
        return np.copy(a)

    def _copy_buffereddata(self, d: BufferedData) -> BufferedData:
        # Copy payload and meta; `index` is kept as-is (identifier).
        return BufferedData(data=self.copy(d.data), index=d.index, meta=self.copy(d.meta))

    def copy(self, sth: Any) -> Any:
        """
        Overview:
            Return a copy of ``sth``, dispatching by exact type.
            NOTE: dispatch uses ``type(sth)``, so subclasses of registered
            types are NOT copied — they are returned unchanged.
        """
        cp = self.dispatch.get(type(sth))
        if cp is None:
            return sth
        else:
            return cp(sth)
def
clone_object
():
def
clone_object
():
...
@@ -61,13 +9,12 @@ def clone_object():
...
@@ -61,13 +9,12 @@ def clone_object():
try this middleware when you need to keep the object unchanged in buffer, and modify
try this middleware when you need to keep the object unchanged in buffer, and modify
the object after sampling it (usually in multiple threads)
the object after sampling it (usually in multiple threads)
"""
"""
fastcopy
=
FastCopy
()
def
push
(
chain
:
Callable
,
data
:
Any
,
*
args
,
**
kwargs
)
->
Any
:
def
push
(
chain
:
Callable
,
data
:
Any
,
*
args
,
**
kwargs
)
->
BufferedData
:
data
=
fastcopy
.
copy
(
data
)
data
=
fastcopy
.
copy
(
data
)
return
chain
(
data
,
*
args
,
**
kwargs
)
return
chain
(
data
,
*
args
,
**
kwargs
)
def
sample
(
chain
:
Callable
,
*
args
,
**
kwargs
)
->
List
[
Any
]:
def
sample
(
chain
:
Callable
,
*
args
,
**
kwargs
)
->
List
[
BufferedData
]:
data
=
chain
(
*
args
,
**
kwargs
)
data
=
chain
(
*
args
,
**
kwargs
)
return
fastcopy
.
copy
(
data
)
return
fastcopy
.
copy
(
data
)
...
...
ding/worker/buffer/middleware/priority.py
浏览文件 @
3d698d0f
from
collections
import
defaultdict
from
typing
import
Callable
,
Any
,
List
,
Dict
,
Optional
from
typing
import
Callable
,
Any
,
List
,
Dict
,
Optional
import
copy
import
copy
import
numpy
as
np
import
numpy
as
np
from
ding.utils
import
SumSegmentTree
,
MinSegmentTree
from
ding.utils
import
SumSegmentTree
,
MinSegmentTree
from
ding.worker.buffer.buffer
import
BufferedData
class
PriorityExperienceReplay
:
class
PriorityExperienceReplay
:
...
@@ -33,7 +35,7 @@ class PriorityExperienceReplay:
...
@@ -33,7 +35,7 @@ class PriorityExperienceReplay:
self
.
delta_anneal
=
(
1
-
self
.
IS_weight_power_factor
)
/
self
.
IS_weight_anneal_train_iter
self
.
delta_anneal
=
(
1
-
self
.
IS_weight_power_factor
)
/
self
.
IS_weight_anneal_train_iter
self
.
pivot
=
0
self
.
pivot
=
0
def
push
(
self
,
chain
:
Callable
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
,
*
args
,
**
kwargs
)
->
Any
:
def
push
(
self
,
chain
:
Callable
,
data
:
Any
,
meta
:
Optional
[
dict
]
=
None
,
*
args
,
**
kwargs
)
->
BufferedData
:
if
meta
is
None
:
if
meta
is
None
:
meta
=
{
'priority'
:
self
.
max_priority
}
meta
=
{
'priority'
:
self
.
max_priority
}
else
:
else
:
...
@@ -41,12 +43,13 @@ class PriorityExperienceReplay:
...
@@ -41,12 +43,13 @@ class PriorityExperienceReplay:
meta
[
'priority'
]
=
self
.
max_priority
meta
[
'priority'
]
=
self
.
max_priority
meta
[
'priority_idx'
]
=
self
.
pivot
meta
[
'priority_idx'
]
=
self
.
pivot
self
.
_update_tree
(
meta
[
'priority'
],
self
.
pivot
)
self
.
_update_tree
(
meta
[
'priority'
],
self
.
pivot
)
index
=
chain
(
data
,
meta
=
meta
,
*
args
,
**
kwargs
)
buffered
=
chain
(
data
,
meta
=
meta
,
*
args
,
**
kwargs
)
index
=
buffered
.
index
self
.
buffer_idx
[
self
.
pivot
]
=
index
self
.
buffer_idx
[
self
.
pivot
]
=
index
self
.
pivot
=
(
self
.
pivot
+
1
)
%
self
.
buffer_size
self
.
pivot
=
(
self
.
pivot
+
1
)
%
self
.
buffer_size
return
index
return
buffered
def
sample
(
self
,
chain
:
Callable
,
size
:
int
,
*
args
,
**
kwargs
)
->
List
[
Any
]:
def
sample
(
self
,
chain
:
Callable
,
size
:
int
,
*
args
,
**
kwargs
)
->
List
[
BufferedData
]:
# Divide [0, 1) into size intervals on average
# Divide [0, 1) into size intervals on average
intervals
=
np
.
array
([
i
*
1.0
/
size
for
i
in
range
(
size
)])
intervals
=
np
.
array
([
i
*
1.0
/
size
for
i
in
range
(
size
)])
# Uniformly sample within each interval
# Uniformly sample within each interval
...
@@ -55,7 +58,7 @@ class PriorityExperienceReplay:
...
@@ -55,7 +58,7 @@ class PriorityExperienceReplay:
mass
*=
self
.
sum_tree
.
reduce
()
mass
*=
self
.
sum_tree
.
reduce
()
indices
=
[
self
.
sum_tree
.
find_prefixsum_idx
(
m
)
for
m
in
mass
]
indices
=
[
self
.
sum_tree
.
find_prefixsum_idx
(
m
)
for
m
in
mass
]
indices
=
[
self
.
buffer_idx
[
i
]
for
i
in
indices
]
indices
=
[
self
.
buffer_idx
[
i
]
for
i
in
indices
]
#
s
ample with indices
#
S
ample with indices
data
=
chain
(
indices
=
indices
,
*
args
,
**
kwargs
)
data
=
chain
(
indices
=
indices
,
*
args
,
**
kwargs
)
if
self
.
IS_weight
:
if
self
.
IS_weight
:
# Calculate max weight for normalizing IS
# Calculate max weight for normalizing IS
...
...
ding/worker/buffer/utils/__init__.py
0 → 100644
浏览文件 @
3d698d0f
from
.fast_copy
import
FastCopy
,
fastcopy
ding/worker/buffer/utils/fast_copy.py
0 → 100644
浏览文件 @
3d698d0f
import
torch
import
numpy
as
np
from
typing
import
Any
,
List
from
ding.worker.buffer.buffer
import
BufferedData
class FastCopy:
    """
    The idea of this class comes from this article
    https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
    We use recursive calls to copy each object that needs to be copied, which will be 5x faster
    than copy.deepcopy.
    """

    def __init__(self):
        # Registry of per-type copy handlers, consulted by exact type.
        self.dispatch = {
            list: self._copy_list,
            dict: self._copy_dict,
            torch.Tensor: self._copy_tensor,
            np.ndarray: self._copy_ndarray,
            BufferedData: self._copy_buffereddata,
        }

    def _copy_list(self, l: List) -> dict:
        # Start from a shallow copy, then deep-copy elements of known types.
        out = l.copy()
        for i, elem in enumerate(out):
            handler = self.dispatch.get(type(elem))
            if handler is not None:
                out[i] = handler(elem)
        return out

    def _copy_dict(self, d: dict) -> dict:
        # Shallow copy, then deep-copy values of known types in place.
        out = d.copy()
        for key, value in out.items():
            handler = self.dispatch.get(type(value))
            if handler is not None:
                out[key] = handler(value)
        return out

    def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
        return t.clone()

    def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
        return np.copy(a)

    def _copy_buffereddata(self, d: BufferedData) -> BufferedData:
        # Recursively copy payload and meta; the index identifier is shared.
        return BufferedData(data=self.copy(d.data), index=d.index, meta=self.copy(d.meta))

    def copy(self, sth: Any) -> Any:
        # Unknown types fall through uncopied.
        handler = self.dispatch.get(type(sth))
        return sth if handler is None else handler(sth)


# Shared module-level instance used by buffer middleware.
fastcopy = FastCopy()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录