BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit b518fa2a (unverified)
Authored Mar 23, 2022 by Yulong Ao · committed via GitHub on Mar 23, 2022
Parent: 9121115b

[Auto Parallel] Add distributed mul op for paddle.fluid.layers.fc (#40207)

* [Auto Parallel] Add distributed mul for the old version
Showing 1 changed file with 509 additions and 0 deletions.

python/paddle/distributed/auto_parallel/operators/dist_matmul.py (+509, -0)
...
@@ -1482,3 +1482,512 @@ register_distributed_operator_impl("matmul_v2",
                                   DistributedMatmulV2Impl1("row_parallel"))
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel"))


class DistributedMul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedMul, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedMul("mul"))


# ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
                y_dims_mapping[-1]):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or \
            (not self.is_output_compatible(dist_op)):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name)

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-1]
        assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
                                  identity_var_dist_attr)
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                    out_var_dist_attr)

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(".".join(
                ["c_identity", 'tmp'])),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient)
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             identity_var_dist_attr)

        check_variable_and_dtype(
            X_var, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
            })
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(intermediate_var_0, 'x',
                                 ['float16', 'float32', 'float64'], 'linear')
        check_dtype(intermediate_var_0.dtype, 'dtype',
                    ['float16', 'float32', 'float64'], 'linear')
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        mul_op = main_block.append_op(
            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs)
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        identity_op_dist_attr.set_input_dist_attr(input_varname,
                                                  input_dist_attr)
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(output_varname,
                                                   input_dist_attr)
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname)
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr)
                matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
                                                          input_dist_attr)
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var)
                matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
                                                          tensor_dist_attr)
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
                                                       output_dist_attr)
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# RowParallel
class DistributedMulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl1, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
                y_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or \
            (not self.is_output_compatible(dist_op)):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name)

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-2]
        assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
                                 'linear')
        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                    'linear')
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                out_var_dist_attr)

        intermediate_var_0 = main_block.create_var(
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed())
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             out_var_dist_attr)

        mul_op = main_block.append_op(
            type='mul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs)
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True
            })
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
                                                      input_dist_attr)
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
                                                   output_dist_attr)
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(input_varname,
                                                       tensor_dist_attr)
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            allreduce_op_dist_attr.set_output_dist_attr(output_varname,
                                                        output_dist_attr)
        ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
                                         allreduce_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


# ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl2, self).__init__(name)

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping,
                               -2) and is_dim_shard(x_dims_mapping[-2]):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping,
                               -2) and is_dim_shard(y_dims_mapping[-2]):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping,
                               -2) and is_dim_shard(out_dims_mapping[-2]):
            return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or \
            (not self.is_output_compatible(dist_op)):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl("mul",
                                   DistributedMulImpl0("column_parallel"))
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl("mul",
                                   DistributedMulImpl2("replicate_parallel"))
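
For context on why the column-parallel implementation above only inserts a c_identity op before the local mul, while the row-parallel one appends a c_allreduce_sum after it, the following is a minimal NumPy sketch of the two partitioning schemes. It is illustrative only and not part of the commit; the shapes and the 2-way split are arbitrary assumptions.

# Illustrative sketch (not from this commit): simulate a 2-way model-parallel
# group with plain NumPy to show the math behind DistributedMulImpl0/1.
import numpy as np

np.random.seed(0)
X = np.random.rand(4, 6)      # input, replicated on every rank
W = np.random.rand(6, 8)      # weight of the serial mul op
serial_out = X @ W

# Column parallel (DistributedMulImpl0): W is sharded along its last axis.
# Each rank keeps the full X (c_identity just forwards it inside the group)
# and computes a slice of the output columns; no reduction is needed and the
# output stays column-sharded.
W_cols = np.split(W, 2, axis=1)
col_parallel_out = np.concatenate([X @ w for w in W_cols], axis=1)

# Row parallel (DistributedMulImpl1): W is sharded along its first axis and X
# along its last axis, so each rank produces a partial result of full shape;
# c_allreduce_sum adds the partials to recover the serial output.
X_cols = np.split(X, 2, axis=1)
W_rows = np.split(W, 2, axis=0)
row_parallel_out = sum(x @ w for x, w in zip(X_cols, W_rows))

assert np.allclose(serial_out, col_parallel_out)
assert np.allclose(serial_out, row_parallel_out)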