Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e48cb42b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e48cb42b
编写于
8月 01, 2022
作者:
L
LiYuRio
提交者:
GitHub
8月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix all_gather_object with various length, test=allcases (#44718)
上级
3e8708bc
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
7 deletion
+22
-7
python/paddle/distributed/collective.py
python/paddle/distributed/collective.py
+20
-7
python/paddle/fluid/tests/unittests/test_collective_api_base.py
.../paddle/fluid/tests/unittests/test_collective_api_base.py
+2
-0
未找到文件。
python/paddle/distributed/collective.py
浏览文件 @
e48cb42b
...
...
@@ -1032,12 +1032,12 @@ def _convert_object_to_tensor(obj):
_pickler
(
f
).
dump
(
obj
)
data
=
np
.
frombuffer
(
f
.
getvalue
(),
dtype
=
np
.
uint8
)
tensor
=
paddle
.
to_tensor
(
data
)
return
tensor
return
tensor
,
tensor
.
numel
()
def
_convert_tensor_to_object
(
tensor
):
def
_convert_tensor_to_object
(
tensor
,
len_of_tensor
):
_unpickler
=
pickle
.
Unpickler
return
_unpickler
(
io
.
BytesIO
(
tensor
.
numpy
())).
load
()
return
_unpickler
(
io
.
BytesIO
(
tensor
.
numpy
()
[:
len_of_tensor
]
)).
load
()
def
all_gather_object
(
object_list
,
obj
,
group
=
None
):
...
...
@@ -1076,12 +1076,25 @@ def all_gather_object(object_list, obj, group=None):
assert
in_dygraph_mode
(
),
"all_gather_object doesn't support static graph mode."
tensor
=
_convert_object_to_tensor
(
obj
)
tensor
,
len_of_tensor
=
_convert_object_to_tensor
(
obj
)
# gather len_of_tensor from all ranks
list_len_of_tensor
=
[]
all_gather
(
list_len_of_tensor
,
len_of_tensor
,
group
)
# get the max length from list
max_len_of_tensor
=
int
(
max
(
list_len_of_tensor
).
item
())
# resize the input tensor to max length avoid hang in all gather
# Note(liyurui): Maybe we should support various length all_gather?
# Now this operation is efficient for we don't support resize in python.
numpy_data
=
tensor
.
numpy
()
numpy_data
=
np
.
resize
(
numpy_data
,
[
max_len_of_tensor
])
input_tensor
=
paddle
.
to_tensor
(
numpy_data
)
tensor_list
=
[]
all_gather
(
tensor_list
,
tensor
,
group
)
for
tensor
in
tensor_list
:
object_list
.
append
(
_convert_tensor_to_object
(
tensor
))
all_gather
(
tensor_list
,
input_tensor
,
group
)
for
i
,
tensor
in
enumerate
(
tensor_list
):
object_list
.
append
(
_convert_tensor_to_object
(
tensor
,
list_len_of_tensor
[
i
]))
def
scatter
(
tensor
,
tensor_list
=
None
,
src
=
0
,
group
=
None
,
use_calc_stream
=
True
):
...
...
python/paddle/fluid/tests/unittests/test_collective_api_base.py
浏览文件 @
e48cb42b
...
...
@@ -63,6 +63,8 @@ def create_complex_test_data(shape=None, dtype=None, seed=None):
def
create_pylist_test_data
(
shape
=
None
,
seed
=
None
):
if
seed
:
np
.
random
.
seed
(
seed
)
# Generate random shape test case for xxx_object api
shape
=
np
.
random
.
randint
(
0
,
high
=
100
,
size
=
(
2
)).
tolist
()
data
=
np
.
random
.
random
(
shape
).
tolist
()
return
data
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录