Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
d3f1a516
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 接近 3 年
通知
65
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d3f1a516
编写于
10月 26, 2021
作者:
X
Xu Jingxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add buffer copy middleware
上级
e9246b74
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
118 addition
and
2 deletion
+118
-2
ding/worker/buffer/memory_storage.py
ding/worker/buffer/memory_storage.py
+5
-2
ding/worker/buffer/middlewares/__init__.py
ding/worker/buffer/middlewares/__init__.py
+1
-0
ding/worker/buffer/middlewares/clone_object.py
ding/worker/buffer/middlewares/clone_object.py
+76
-0
ding/worker/buffer/tests/test_middleware.py
ding/worker/buffer/tests/test_middleware.py
+36
-0
未找到文件。
ding/worker/buffer/memory_storage.py
浏览文件 @
d3f1a516
...
...
@@ -2,8 +2,8 @@ from typing import Any, List
from
collections
import
deque
from
operator
import
itemgetter
from
ding.worker.buffer
import
Storage
import
numpy
as
np
import
itertools
import
random
class
MemoryStorage
(
Storage
):
...
...
@@ -21,7 +21,10 @@ class MemoryStorage(Storage):
storage
=
self
.
storage
if
range
:
storage
=
list
(
itertools
.
islice
(
self
.
storage
,
range
.
start
,
range
.
stop
,
range
.
step
))
return
np
.
random
.
choice
(
storage
,
size
,
replace
=
replace
)
if
replace
:
return
random
.
choices
(
storage
,
k
=
size
)
else
:
return
random
.
sample
(
storage
,
k
=
size
)
def
count
(
self
)
->
int
:
return
len
(
self
.
storage
)
...
...
ding/worker/buffer/middlewares/__init__.py
0 → 100644
浏览文件 @
d3f1a516
from
.clone_object
import
clone_object
ding/worker/buffer/middlewares/clone_object.py
0 → 100644
浏览文件 @
d3f1a516
from
typing
import
Callable
,
Any
,
List
import
torch
import
numpy
as
np
class
FastCopy
:
"""
The idea of this class comes from this article
https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
We use recursive calls to copy each object that needs to be copied, which will be 5x faster
than copy.deepcopy.
"""
def
__init__
(
self
):
dispatch
=
{}
dispatch
[
list
]
=
self
.
_copy_list
dispatch
[
dict
]
=
self
.
_copy_dict
dispatch
[
torch
.
Tensor
]
=
self
.
_copy_tensor
dispatch
[
np
.
ndarray
]
=
self
.
_copy_ndarray
self
.
dispatch
=
dispatch
def
_copy_list
(
self
,
l
:
List
)
->
dict
:
ret
=
l
.
copy
()
for
idx
,
item
in
enumerate
(
ret
):
cp
=
self
.
dispatch
.
get
(
type
(
item
))
if
cp
is
not
None
:
ret
[
idx
]
=
cp
(
item
)
return
ret
def
_copy_dict
(
self
,
d
:
dict
)
->
dict
:
ret
=
d
.
copy
()
for
key
,
value
in
ret
.
items
():
cp
=
self
.
dispatch
.
get
(
type
(
value
))
if
cp
is
not
None
:
ret
[
key
]
=
cp
(
value
)
return
ret
def
_copy_tensor
(
self
,
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
t
.
clone
()
def
_copy_ndarray
(
self
,
a
:
np
.
ndarray
)
->
np
.
ndarray
:
return
np
.
copy
(
a
)
def
copy
(
self
,
sth
:
Any
)
->
Any
:
cp
=
self
.
dispatch
.
get
(
type
(
sth
))
if
cp
is
None
:
return
sth
else
:
return
cp
(
sth
)
def
clone_object
():
"""
This middleware freezes the objects saved in memory buffer as a copy,
try this middleware when you need to keep the object unchanged in buffer, and modify
the object after sampling it (usuallly in multiple threads)
"""
fastcopy
=
FastCopy
()
def
push
(
next
:
Callable
,
data
:
Any
,
*
args
,
**
kwargs
)
->
None
:
data
=
fastcopy
.
copy
(
data
)
return
next
(
data
,
*
args
,
**
kwargs
)
def
sample
(
next
:
Callable
,
*
args
,
**
kwargs
)
->
List
[
Any
]:
data
=
next
(
*
args
,
**
kwargs
)
return
fastcopy
.
copy
(
data
)
def
_immutable_object
(
action
:
str
,
next
:
Callable
,
*
args
,
**
kwargs
):
if
action
==
"push"
:
return
push
(
next
,
*
args
,
**
kwargs
)
elif
action
==
"sample"
:
return
sample
(
next
,
*
args
,
**
kwargs
)
return
next
(
*
args
,
**
kwargs
)
return
_immutable_object
ding/worker/buffer/tests/test_middleware.py
0 → 100644
浏览文件 @
d3f1a516
import
pytest
import
torch
from
ding.worker.buffer
import
Buffer
,
MemoryStorage
from
ding.worker.buffer.middlewares
import
clone_object
@
pytest
.
mark
.
unittest
def
test_clone_object
():
buffer
=
Buffer
(
MemoryStorage
(
maxlen
=
10
)).
use
(
clone_object
())
# Store a dict, a list, a tensor
arr
=
[{
"key"
:
"v1"
},
[
"a"
],
torch
.
Tensor
([
1
,
2
,
3
])]
for
o
in
arr
:
buffer
.
push
(
o
)
# Modify it
for
item
in
buffer
.
sample
(
len
(
arr
)):
if
isinstance
(
item
,
dict
):
item
[
"key"
]
=
"v2"
elif
isinstance
(
item
,
list
):
item
.
append
(
"b"
)
elif
isinstance
(
item
,
torch
.
Tensor
):
item
[
0
]
=
3
else
:
raise
Exception
(
"Unexpected type"
)
# Resample it, and check their values
for
item
in
buffer
.
sample
(
len
(
arr
)):
if
isinstance
(
item
,
dict
):
assert
item
[
"key"
]
==
"v1"
elif
isinstance
(
item
,
list
):
assert
len
(
item
)
==
1
elif
isinstance
(
item
,
torch
.
Tensor
):
assert
item
[
0
]
==
1
else
:
raise
Exception
(
"Unexpected type"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录