Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
39540b0e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
39540b0e
编写于
9月 08, 2021
作者:
L
lilong12
提交者:
GitHub
9月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add checkers for auto parallel apis (#35486)
* update, test=develop
上级
c4a3e8b4
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
29 addition
and
6 deletion
+29
-6
python/paddle/distributed/auto_parallel/interface.py
python/paddle/distributed/auto_parallel/interface.py
+27
-4
python/paddle/fluid/tests/unittests/test_auto_parallel_api.py
...on/paddle/fluid/tests/unittests/test_auto_parallel_api.py
+2
-2
未找到文件。
python/paddle/distributed/auto_parallel/interface.py
浏览文件 @
39540b0e
...
...
@@ -271,12 +271,22 @@ class ProcessMesh(object):
def
_dim_mapping_checker
(
tensor
,
mesh
,
dim_mapping
):
assert
len
(
tensor
.
shape
)
==
len
(
dim_mapping
)
assert
isinstance
(
mesh
,
ProcessMesh
),
'The type of mesh must be ProcessMesh.'
assert
isinstance
(
dim_mapping
,
list
),
'The type of dim_mapping must be list.'
assert
len
(
tensor
.
shape
)
==
len
(
dim_mapping
),
(
'The number of dimensions '
'of tensor must be the same as the length of its corresponding '
'dim_mapping.'
)
mesh_dim
=
len
(
mesh
.
topology
)
dim_set
=
set
()
for
i
in
range
(
len
(
dim_mapping
)):
assert
dim_mapping
[
i
]
==
-
1
or
(
dim_mapping
[
i
]
<
mesh_dim
and
dim_mapping
[
i
]
>=
0
)
assert
dim_mapping
[
i
]
==
-
1
or
(
dim_mapping
[
i
]
<
mesh_dim
and
dim_mapping
[
i
]
>=
0
),
(
'Each element '
'in dim_mapping must be greater than zero and less than the '
'length of its corresponding topology, or it must be -1.'
)
if
dim_mapping
[
i
]
>=
0
:
assert
dim_mapping
[
i
]
not
in
dim_set
dim_set
.
add
(
dim_mapping
[
i
])
...
...
@@ -347,6 +357,7 @@ def set_shard_mask(x, mask):
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mask = [[1, 0, 1], [0, 1, 0]]
x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [-1, 1])
dist.set_shard_mask(x, mask)
"""
...
...
@@ -355,6 +366,9 @@ def set_shard_mask(x, mask):
np_mask
=
numpy
.
array
(
mask
)
min_ele
=
numpy
.
min
(
np_mask
)
max_ele
=
numpy
.
max
(
np_mask
)
mesh_attr_name
=
_append_attr_suffix
(
'mesh_id'
)
assert
x
.
_has_attr
(
mesh_attr_name
),
\
"Please set process mesh for the variable firstly."
assert
min_ele
>=
0
and
max_ele
<=
1
,
"Elements in mask must be 0 or 1."
x_mesh
=
x
.
process_mesh
assert
x_mesh
,
"Please set process mesh for the variable firstly."
...
...
@@ -403,7 +417,15 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
op_size
=
len
(
main_block
.
ops
)
output
=
op_fn
(
**
kwargs
)
new_op_size
=
len
(
main_block
.
ops
)
if
dim_mapping_dict
is
None
:
dim_mapping_dict
=
dict
()
if
dim_mapping_dict
is
None
:
dim_mapping_dict
=
dict
()
else
:
assert
isinstance
(
dim_mapping_dict
,
dict
),
'The type of dim_mapping_dict must be dict.'
for
var_name
in
dim_mapping_dict
.
keys
():
dim_mapping
=
dim_mapping_dict
[
var_name
]
tensor
=
main_block
.
var
(
var_name
)
_dim_mapping_checker
(
tensor
,
mesh
,
dim_mapping
)
for
idx
in
range
(
op_size
,
new_op_size
):
op
=
main_block
.
ops
[
idx
]
attr_name
=
_append_attr_suffix
(
'mesh_id'
)
...
...
@@ -477,4 +499,5 @@ def set_pipeline_stage(stage):
"""
from
paddle.fluid.framework
import
_set_pipeline_stage
_static_mode_check
()
assert
isinstance
(
stage
,
int
),
'The type of stage must be int.'
_set_pipeline_stage
(
stage
)
python/paddle/fluid/tests/unittests/test_auto_parallel_api.py
浏览文件 @
39540b0e
...
...
@@ -97,8 +97,8 @@ class TestAutoParallelAPI(unittest.TestCase):
self
.
assertEqual
(
last_op
.
pipeline_stage
,
LAST_PP_STAGE
)
DIMS_MAPPING1
=
[
0
,
1
,
-
1
]
DIMS_MAPPING2
=
[
-
1
,
2
,
0
]
DIMS_MAPPING1
=
[
0
,
1
]
DIMS_MAPPING2
=
[
-
1
,
0
]
kwargs
=
{
'x'
:
data2
,
'y'
:
data3
}
dist
.
shard_op
(
paddle
.
add
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录