Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a6a4895a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a6a4895a
编写于
12月 20, 2022
作者:
Y
Yulong Ao
提交者:
GitHub
12月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[0D Tensor] Add tests of 0D Tensor for allgather and allreduce (#49175)
上级
495c1fc0
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
102 addition
and
0 deletion
+102
-0
python/paddle/fluid/tests/unittests/collective/process_group_nccl.py
...le/fluid/tests/unittests/collective/process_group_nccl.py
+102
-0
未找到文件。
python/paddle/fluid/tests/unittests/collective/process_group_nccl.py
浏览文件 @
a6a4895a
...
...
@@ -68,6 +68,24 @@ class TestProcessGroupFp32(unittest.TestCase):
print
(
"test allreduce sum api ok"
)
# test allreduce sum with shape = []
# rank 0
x
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
sum_result
=
tensor_x
+
tensor_y
if
pg
.
rank
()
==
0
:
task
=
dist
.
all_reduce
(
tensor_x
)
assert
np
.
array_equal
(
tensor_x
,
sum_result
)
else
:
task
=
dist
.
all_reduce
(
tensor_y
)
assert
np
.
array_equal
(
tensor_y
,
sum_result
)
print
(
"test allreduce sum api with = [] ok"
)
# test allreduce max
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
...
...
@@ -89,6 +107,27 @@ class TestProcessGroupFp32(unittest.TestCase):
print
(
"test allreduce max api ok"
)
# test allreduce max with shape = []
# rank 0
x
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
max_result
=
paddle
.
maximum
(
tensor_x
,
tensor_y
)
if
pg
.
rank
()
==
0
:
task
=
dist
.
all_reduce
(
tensor_x
,
dist
.
ReduceOp
.
MAX
,
sync_op
=
False
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_x
,
max_result
)
else
:
task
=
dist
.
all_reduce
(
tensor_y
,
dist
.
ReduceOp
.
MAX
,
sync_op
=
False
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_y
,
max_result
)
print
(
"test allreduce max api with shape = [] ok"
)
# test allreduce min
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
...
...
@@ -110,6 +149,27 @@ class TestProcessGroupFp32(unittest.TestCase):
print
(
"test allreduce min api ok"
)
# test allreduce min with shape = []
# rank 0
x
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
min_result
=
paddle
.
minimum
(
tensor_x
,
tensor_y
)
if
pg
.
rank
()
==
0
:
task
=
dist
.
all_reduce
(
tensor_x
,
dist
.
ReduceOp
.
MIN
,
sync_op
=
False
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_x
,
min_result
)
else
:
task
=
dist
.
all_reduce
(
tensor_y
,
dist
.
ReduceOp
.
MIN
,
sync_op
=
False
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_y
,
min_result
)
print
(
"test allreduce min api with shape [] ok"
)
# test allreduce prod
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
...
...
@@ -131,6 +191,27 @@ class TestProcessGroupFp32(unittest.TestCase):
print
(
"test allreduce prod api ok"
)
# test allreduce prod with shape = []
# rank 0
x
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
prod_result
=
np
.
multiply
(
x
,
y
)
if
pg
.
rank
()
==
0
:
task
=
dist
.
all_reduce
(
tensor_x
,
dist
.
ReduceOp
.
PROD
,
sync_op
=
False
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_x
,
prod_result
)
else
:
task
=
dist
.
all_reduce
(
tensor_y
,
dist
.
ReduceOp
.
PROD
,
sync_op
=
False
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_y
,
prod_result
)
print
(
"test allreduce prod api with shape = [] ok"
)
# test broadcast
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
...
...
@@ -236,6 +317,27 @@ class TestProcessGroupFp32(unittest.TestCase):
assert
np
.
array_equal
(
tensor_y
,
out_2
)
print
(
"test allgather api2 ok
\n
"
)
# test allgather with shape = []
# rank 0
x
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
y
=
np
.
random
.
random
([]).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
tensor_y
=
paddle
.
to_tensor
(
y
)
tensor_out_list
=
[]
if
pg
.
rank
()
==
0
:
task
=
dist
.
all_gather
(
tensor_out_list
,
tensor_x
)
task
.
wait
()
paddle
.
device
.
cuda
.
synchronize
()
# rank 1
else
:
task
=
dist
.
all_gather
(
tensor_out_list
,
tensor_y
,
sync_op
=
False
)
paddle
.
device
.
cuda
.
synchronize
()
out_1
=
tensor_out_list
[
0
]
out_2
=
tensor_out_list
[
1
]
assert
np
.
array_equal
(
tensor_x
,
out_1
)
assert
np
.
array_equal
(
tensor_y
,
out_2
)
print
(
"test allgather api with shape [] ok
\n
"
)
# test alltoall
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录