Commit b518fa2a (PaddlePaddle/Paddle)
Authored on Mar 23, 2022 by Yulong Ao; committed via GitHub on Mar 23, 2022.
[Auto Parallel] Add distributed mul op for paddle.fluid.layers.fc (#40207)
* [Auto Parallel] Add distributed mul for the old version
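For context: `paddle.fluid.layers.fc` in the old static-graph API is lowered to a `mul` op (plus an `elementwise_add` for the bias) rather than `matmul`/`matmul_v2`, which is why auto parallel needs a distributed implementation registered for `mul` as well. Below is a minimal sketch of that lowering, assuming a Paddle build contemporary with this commit in which `paddle.fluid.layers.fc` is still available:

import paddle
import paddle.fluid as fluid

# Build a tiny static-graph program with the legacy fc layer and inspect the
# ops it produces; the weight multiplication shows up as a 'mul' op, which is
# what the DistributedMul implementations in this diff are registered to handle.
paddle.enable_static()
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[8, 1024], dtype='float32')
    out = fluid.layers.fc(input=x, size=4096)

print([op.type for op in main_prog.global_block().ops])
# expected: ['mul', 'elementwise_add']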
Parent: 9121115b
Showing 1 changed file with 509 additions and 0 deletions (+509, −0).
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -1482,3 +1482,512 @@ register_distributed_operator_impl("matmul_v2",
                                   DistributedMatmulV2Impl1("row_parallel"))
register_distributed_operator_impl("matmul_v2",
                                   DistributedMatmulV2Impl2("replicate_parallel"))
class DistributedMul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedMul, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedMul("mul"))
# ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
                y_dims_mapping[-1]):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_replicate(out_dims_mapping[-1]):
            return False
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or \
                (not self.is_output_compatible(dist_op)):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name)

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-1]
        assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
        x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var)
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr,
                                  identity_var_dist_attr)
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                    out_var_dist_attr)

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(".".join(
                ["c_identity", 'tmp'])),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient)
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             identity_var_dist_attr)

        check_variable_and_dtype(
            X_var, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
            outputs={'Out': intermediate_var_0},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
            })
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(intermediate_var_0, 'x',
                                 ['float16', 'float32', 'float64'], 'linear')
        check_dtype(intermediate_var_0.dtype, 'dtype',
                    ['float16', 'float32', 'float64'], 'linear')
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        mul_op = main_block.append_op(
            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs)
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)

        # set dist op's dist_attr with serial op's dist_attr
        # c_identity
        identity_op_dist_attr = OperatorDistributedAttribute()
        identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        identity_op_dist_attr.impl_type = op_dist_attr.impl_type
        identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        # input
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        identity_op_dist_attr.set_input_dist_attr(input_varname,
                                                  input_dist_attr)
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(output_varname,
                                                   input_dist_attr)
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname)
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr)
                matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
                                                          input_dist_attr)
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var)
                matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
                                                          tensor_dist_attr)
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
                                                       output_dist_attr)
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
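The column-parallel implementation above shards the weight along its last (column) dimension: the `c_identity` op ties the replicated input to the model-parallel process group (in this scheme its backward is an all-reduce of the input gradient), and each rank's `mul` produces a column slice of the output. An illustration of the underlying arithmetic only, in plain NumPy (not part of the committed code):

import numpy as np

np.random.seed(0)
X = np.random.rand(8, 16)           # input, replicated on every rank
W = np.random.rand(16, 32)          # full weight
W_shards = np.split(W, 4, axis=1)   # 4-way column shard, one slice per rank

# Each "rank" multiplies the full input by its weight slice; concatenating
# the per-rank slices along the column axis recovers the serial result.
partials = [X @ W_i for W_i in W_shards]
assert np.allclose(np.concatenate(partials, axis=1), X @ W)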
# RowParallel
class DistributedMulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl1, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
                y_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in out_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
                return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or \
                (not self.is_output_compatible(dist_op)):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        """
        kwargs: inputname_mapping & outputname_mapping
        """
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
            str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
                                              rank_id)

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name)

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
        Out_var = main_block.var(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name)[-2]
        assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping)
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
                                      parallel_axis, rank_id)
        group = new_process_group(group_ranks)

        check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
                                 'linear')
        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                    'linear')
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
        }
        inputs = {'X': X_var, 'Y': Weight_var}

        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr,
                                out_var_dist_attr)

        intermediate_var_0 = main_block.create_var(
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed())
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(intermediate_var_0,
                                             out_var_dist_attr)

        mul_op = main_block.append_op(
            type='mul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs)
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)

        c_allreduce_sum_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': intermediate_var_0},
            outputs={'Out': Out_var},
            attrs={
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True
            })
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)

        # set dist op's dist_attr with serial op's dist_attr
        # matmulv2
        matmulv2_op_dist_attr = OperatorDistributedAttribute()
        matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
        matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            matmulv2_op_dist_attr.set_input_dist_attr(input_varname,
                                                      input_dist_attr)
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr)
        matmulv2_op_dist_attr.set_output_dist_attr(output_varname,
                                                   output_dist_attr)
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # allreduce
        allreduce_op_dist_attr = OperatorDistributedAttribute()
        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(input_varname,
                                                       tensor_dist_attr)
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr)
            allreduce_op_dist_attr.set_output_dist_attr(output_varname,
                                                        output_dist_attr)
        ctx.set_op_dist_attr_for_program(c_allreduce_sum_op,
                                         allreduce_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
                             rank_id)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
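The row-parallel implementation shards the weight along its first (row) dimension and expects the input to be sharded along the reduced dimension, so each rank's `mul` yields only a partial sum; the `c_allreduce_sum` op then combines the partial results into `Out`. An illustration of the arithmetic only, in plain NumPy (not part of the committed code):

import numpy as np

np.random.seed(0)
X = np.random.rand(8, 16)
W = np.random.rand(16, 32)
X_shards = np.split(X, 4, axis=1)   # input sharded along the reduced dim
W_shards = np.split(W, 4, axis=0)   # weight sharded by rows, one per rank

# Each "rank" computes a partial product; summing the partials (the role of
# c_allreduce_sum above) recovers the serial result.
partials = [X_i @ W_i for X_i, W_i in zip(X_shards, W_shards)]
assert np.allclose(sum(partials), X @ W)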
# ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl2, self).__init__(name)

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping,
                               -2) and is_dim_shard(x_dims_mapping[-2]):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping,
                               -2) and is_dim_shard(y_dims_mapping[-2]):
            return False
        return True

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        out_name = op_desc.output('Out')[0]
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping,
                               -2) and is_dim_shard(out_dims_mapping[-2]):
            return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or \
                (not self.is_output_compatible(dist_op)):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True

    def update_dims_mapping(self, dist_op):
        changed = False
        dim_changed = _update_dims_mapping_for_matmul(dist_op)
        if dim_changed:
            changed = True
        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
register_distributed_operator_impl("mul",
                                   DistributedMulImpl0("column_parallel"))
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl("mul",
                                   DistributedMulImpl2("replicate_parallel"))
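The three implementations are registered for the `mul` op type in the same order as the matmul and matmul_v2 containers earlier in this file: column-parallel, row-parallel, then replicate-parallel. A self-contained sketch of this registration pattern (plain Python; the registry and helper below are illustrative stand-ins, not Paddle's internal API):

# Hypothetical mini-registry mirroring the pattern of
# register_distributed_operator_impl: each op type keeps an ordered list of
# implementations, and an implementation's position in that list plays a role
# analogous to the impl_idx values propagated in the forward functions above.
_impl_registry = {}

def register_impl(op_type, impl_name):
    _impl_registry.setdefault(op_type, []).append(impl_name)

register_impl("mul", "column_parallel")     # index 0
register_impl("mul", "row_parallel")        # index 1
register_impl("mul", "replicate_parallel")  # index 2

assert _impl_registry["mul"].index("row_parallel") == 1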