Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6992170e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6992170e
编写于
11月 22, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
11月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix_var_recursive (#48206)
上级
3c0bd3af
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
57 addition
and
85 deletion
+57
-85
python/paddle/distributed/auto_parallel/operators/common.py
python/paddle/distributed/auto_parallel/operators/common.py
+3
-3
python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
.../auto_parallel/operators/dist_check_finite_and_unscale.py
+4
-4
python/paddle/distributed/auto_parallel/operators/dist_default.py
...addle/distributed/auto_parallel/operators/dist_default.py
+2
-3
python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
...addle/distributed/auto_parallel/operators/dist_eltwise.py
+0
-1
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
...dle/distributed/auto_parallel/operators/dist_embedding.py
+7
-7
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...paddle/distributed/auto_parallel/operators/dist_matmul.py
+25
-43
python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
.../paddle/distributed/auto_parallel/operators/dist_pnorm.py
+4
-4
python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py
.../distributed/auto_parallel/operators/dist_reduce_sum_p.py
+2
-2
python/paddle/distributed/auto_parallel/operators/dist_reshape.py
...addle/distributed/auto_parallel/operators/dist_reshape.py
+9
-15
python/paddle/distributed/auto_parallel/operators/dist_softmax.py
...addle/distributed/auto_parallel/operators/dist_softmax.py
+0
-1
python/paddle/distributed/auto_parallel/operators/dist_transpose.py
...dle/distributed/auto_parallel/operators/dist_transpose.py
+0
-1
python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
...buted/auto_parallel/operators/dist_update_loss_scaling.py
+1
-1
未找到文件。
python/paddle/distributed/auto_parallel/operators/common.py
浏览文件 @
6992170e
...
@@ -266,13 +266,13 @@ def is_parameter_related(varname, block):
...
@@ -266,13 +266,13 @@ def is_parameter_related(varname, block):
varname
=
varname
[:
varname
.
index
(
".cast_fp"
)]
varname
=
varname
[:
varname
.
index
(
".cast_fp"
)]
if
".quantized"
in
varname
:
if
".quantized"
in
varname
:
varname
=
varname
[:
varname
.
index
(
".quantized"
)]
varname
=
varname
[:
varname
.
index
(
".quantized"
)]
assert
block
.
has_var
(
varname
)
assert
block
.
_find_var_recursive
(
varname
)
var
=
block
.
var
(
varname
)
var
=
block
.
_var_recursive
(
varname
)
return
var
.
is_parameter
return
var
.
is_parameter
def
infer_shape
(
block
,
src_var
,
src_var_dist_attr
,
op_input_dist_attr
):
def
infer_shape
(
block
,
src_var
,
src_var_dist_attr
,
op_input_dist_attr
):
var_shape
=
block
.
var
(
src_var
.
name
).
shape
var_shape
=
block
.
_var_recursive
(
src_var
.
name
).
shape
var_topoloy
=
src_var_dist_attr
.
process_mesh
.
topology
var_topoloy
=
src_var_dist_attr
.
process_mesh
.
topology
var_dims_mapping
=
src_var_dist_attr
.
dims_mapping
var_dims_mapping
=
src_var_dist_attr
.
dims_mapping
...
...
python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
浏览文件 @
6992170e
...
@@ -117,7 +117,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
...
@@ -117,7 +117,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
if
(
if
(
rank_id
rank_id
in
ctx
.
get_tensor_dist_attr_for_program
(
in
ctx
.
get_tensor_dist_attr_for_program
(
main_block
.
var
(
varname
)
main_block
.
_var_recursive
(
varname
)
).
process_mesh
.
processes
).
process_mesh
.
processes
):
):
filter_vars
.
append
(
varname
)
filter_vars
.
append
(
varname
)
...
@@ -132,7 +132,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
...
@@ -132,7 +132,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
# sync result
# sync result
group
=
new_process_group
(
world_process_group
.
ranks
)
group
=
new_process_group
(
world_process_group
.
ranks
)
inf_var
=
main_block
.
var
(
kwargs
[
'FoundInfinite'
][
0
])
inf_var
=
main_block
.
_var_recursive
(
kwargs
[
'FoundInfinite'
][
0
])
inf_var_int32
=
main_block
.
create_var
(
inf_var_int32
=
main_block
.
create_var
(
name
=
inf_var
.
name
+
"@cast_int32"
,
name
=
inf_var
.
name
+
"@cast_int32"
,
shape
=
inf_var
.
shape
,
shape
=
inf_var
.
shape
,
...
@@ -179,7 +179,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
...
@@ -179,7 +179,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
new_op_dist_attr
=
OperatorDistributedAttribute
()
new_op_dist_attr
=
OperatorDistributedAttribute
()
for
varname
in
op
.
input_arg_names
:
for
varname
in
op
.
input_arg_names
:
var_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
var_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
main_block
.
var
(
varname
)
main_block
.
_var_recursive
(
varname
)
)
)
assert
var_dist_attr
is
not
None
assert
var_dist_attr
is
not
None
new_op_dist_attr
.
set_input_dims_mapping
(
new_op_dist_attr
.
set_input_dims_mapping
(
...
@@ -187,7 +187,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
...
@@ -187,7 +187,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
)
)
for
varname
in
op
.
output_arg_names
:
for
varname
in
op
.
output_arg_names
:
var_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
var_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
main_block
.
var
(
varname
)
main_block
.
_var_recursive
(
varname
)
)
)
new_op_dist_attr
.
set_output_dims_mapping
(
new_op_dist_attr
.
set_output_dims_mapping
(
varname
,
var_dist_attr
.
dims_mapping
varname
,
var_dist_attr
.
dims_mapping
...
...
python/paddle/distributed/auto_parallel/operators/dist_default.py
浏览文件 @
6992170e
...
@@ -69,7 +69,7 @@ def prim_operator_data_parallel_functor(ctx, src_op):
...
@@ -69,7 +69,7 @@ def prim_operator_data_parallel_functor(ctx, src_op):
},
},
)
)
grad_var
=
main_block
.
var
(
var_name
)
grad_var
=
main_block
.
_var_recursive
(
var_name
)
dims_mapping
=
ctx
.
get_tensor_dist_attr_for_program
(
dims_mapping
=
ctx
.
get_tensor_dist_attr_for_program
(
grad_var
grad_var
).
dims_mapping
).
dims_mapping
...
@@ -140,7 +140,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
...
@@ -140,7 +140,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
res
.
append
(
cost_mapping
)
res
.
append
(
cost_mapping
)
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
need_gradient_allreduce
=
False
need_gradient_allreduce
=
False
for
input_name
in
backward_op
.
desc
.
input_names
():
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
...
@@ -588,7 +587,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
...
@@ -588,7 +587,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for
varname
in
backward_op
.
desc
.
output
(
output_name
):
for
varname
in
backward_op
.
desc
.
output
(
output_name
):
if
varname
in
kwargs
[
"grad_var_to_var"
]:
if
varname
in
kwargs
[
"grad_var_to_var"
]:
fwd_name
=
kwargs
[
"grad_var_to_var"
][
varname
]
fwd_name
=
kwargs
[
"grad_var_to_var"
][
varname
]
if
fwd_name
not
in
main_block
.
vars
:
if
not
main_block
.
_find_var_recursive
(
fwd_name
)
:
continue
continue
if
is_parameter_related
(
fwd_name
,
main_block
):
if
is_parameter_related
(
fwd_name
,
main_block
):
out_grad_names
.
append
(
varname
)
out_grad_names
.
append
(
varname
)
...
...
python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
浏览文件 @
6992170e
...
@@ -84,7 +84,6 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
...
@@ -84,7 +84,6 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
res
.
append
(
cost_mapping
)
res
.
append
(
cost_mapping
)
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
need_gradient_allreduce
=
False
need_gradient_allreduce
=
False
for
input_name
in
backward_op
.
desc
.
input_names
():
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
...
...
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
浏览文件 @
6992170e
...
@@ -370,9 +370,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -370,9 +370,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs
[
'Out'
]
kwargs
[
'Out'
]
)
)
Ids_var
=
main_block
.
var
(
kwargs
[
'Ids'
][
0
])
Ids_var
=
main_block
.
_var_recursive
(
kwargs
[
'Ids'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'W'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'W'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
# support lookup_table_v1
# support lookup_table_v1
if
src_op
.
type
==
'lookup_table'
:
if
src_op
.
type
==
'lookup_table'
:
...
@@ -507,7 +507,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -507,7 +507,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
allreduce_op_dist_attr
.
impl_type
=
op_dist_attr
.
impl_type
allreduce_op_dist_attr
.
impl_type
=
op_dist_attr
.
impl_type
allreduce_op_dist_attr
.
impl_idx
=
op_dist_attr
.
impl_idx
allreduce_op_dist_attr
.
impl_idx
=
op_dist_attr
.
impl_idx
for
input_varname
in
c_allreduce_sum_op
.
desc
.
input_arg_names
():
for
input_varname
in
c_allreduce_sum_op
.
desc
.
input_arg_names
():
input_var
=
main_block
.
var
(
input_varname
)
input_var
=
main_block
.
_var_recursive
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
assert
tensor_dist_attr
is
not
None
assert
tensor_dist_attr
is
not
None
allreduce_op_dist_attr
.
set_input_dist_attr
(
allreduce_op_dist_attr
.
set_input_dist_attr
(
...
@@ -607,10 +607,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -607,10 +607,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs
[
'W@GRAD'
]
kwargs
[
'W@GRAD'
]
)
)
Ids_var
=
main_block
.
var
(
kwargs
[
'Ids'
][
0
])
Ids_var
=
main_block
.
_var_recursive
(
kwargs
[
'Ids'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'W'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'W'
][
0
])
Out_grad
=
main_block
.
var
(
kwargs
[
'Out@GRAD'
][
0
])
Out_grad
=
main_block
.
_var_recursive
(
kwargs
[
'Out@GRAD'
][
0
])
Weight_grad
=
main_block
.
var
(
kwargs
[
'W@GRAD'
][
0
])
Weight_grad
=
main_block
.
_var_recursive
(
kwargs
[
'W@GRAD'
][
0
])
embedding_row_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
embedding_row_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
Weight_var
.
name
...
...
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
6992170e
...
@@ -316,10 +316,10 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
...
@@ -316,10 +316,10 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
kwargs
[
'Y@GRAD'
]
kwargs
[
'Y@GRAD'
]
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Y_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Y_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_grad
=
main_block
.
var
(
kwargs
[
'Out@GRAD'
][
0
])
Out_grad
=
main_block
.
_var_recursive
(
kwargs
[
'Out@GRAD'
][
0
])
Y_grad
=
main_block
.
var
(
kwargs
[
'Y@GRAD'
][
0
])
Y_grad
=
main_block
.
_var_recursive
(
kwargs
[
'Y@GRAD'
][
0
])
assert
not
is_parameter_related
(
assert
not
is_parameter_related
(
X_var
.
name
,
main_block
X_var
.
name
,
main_block
...
@@ -433,7 +433,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
...
@@ -433,7 +433,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
has_x_grad
=
len
(
kwargs
[
'X@GRAD'
])
>
0
has_x_grad
=
len
(
kwargs
[
'X@GRAD'
])
>
0
if
has_x_grad
:
if
has_x_grad
:
assert
len
(
kwargs
[
'X@GRAD'
])
==
1
assert
len
(
kwargs
[
'X@GRAD'
])
==
1
X_grad
=
main_block
.
var
(
kwargs
[
'X@GRAD'
][
0
])
X_grad
=
main_block
.
_var_recursive
(
kwargs
[
'X@GRAD'
][
0
])
intermediate_var_0
=
main_block
.
create_var
(
intermediate_var_0
=
main_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
([
"c_identity"
,
'tmp'
])
"."
.
join
([
"c_identity"
,
'tmp'
])
...
@@ -572,7 +572,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -572,7 +572,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
]
backward_op
.
input
(
"Y"
)[
0
]
)
)
...
@@ -647,7 +646,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -647,7 +646,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# calc comm op cost
# calc comm op cost
serial_op
=
dist_op
.
serial_op
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
1
]
)[
-
1
]
...
@@ -762,9 +760,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -762,9 +760,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
"transpose_X"
)
trans_x
=
src_op
.
attr
(
"transpose_X"
)
trans_y
=
src_op
.
attr
(
"transpose_Y"
)
trans_y
=
src_op
.
attr
(
"transpose_Y"
)
...
@@ -906,7 +904,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -906,7 +904,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
input_varname
,
input_dist_attr
input_varname
,
input_dist_attr
)
)
else
:
else
:
input_var
=
main_block
.
var
(
input_varname
)
input_var
=
main_block
.
_var_recursive
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
input_var
)
)
...
@@ -958,7 +956,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -958,7 +956,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
]
backward_op
.
input
(
"Y"
)[
0
]
)
)
...
@@ -1023,8 +1020,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -1023,8 +1020,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# calc comm op cost
# calc comm op cost
serial_op
=
dist_op
.
serial_op
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
2
]
)[
-
2
]
...
@@ -1147,9 +1142,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -1147,9 +1142,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
'transpose_X'
)
trans_x
=
src_op
.
attr
(
'transpose_X'
)
trans_y
=
src_op
.
attr
(
'transpose_Y'
)
trans_y
=
src_op
.
attr
(
'transpose_Y'
)
...
@@ -1268,7 +1263,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -1268,7 +1263,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
allreduce_op_dist_attr
.
impl_type
=
op_dist_attr
.
impl_type
allreduce_op_dist_attr
.
impl_type
=
op_dist_attr
.
impl_type
allreduce_op_dist_attr
.
impl_idx
=
op_dist_attr
.
impl_idx
allreduce_op_dist_attr
.
impl_idx
=
op_dist_attr
.
impl_idx
for
input_varname
in
c_allreduce_sum_op
.
desc
.
input_arg_names
():
for
input_varname
in
c_allreduce_sum_op
.
desc
.
input_arg_names
():
input_var
=
main_block
.
var
(
input_varname
)
input_var
=
main_block
.
_var_recursive
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
assert
tensor_dist_attr
is
not
None
assert
tensor_dist_attr
is
not
None
allreduce_op_dist_attr
.
set_input_dist_attr
(
allreduce_op_dist_attr
.
set_input_dist_attr
(
...
@@ -1316,7 +1311,6 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
...
@@ -1316,7 +1311,6 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
# calc comp op cost
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
desc_mapping
=
build_comp_desc_from_dist_op
(
...
@@ -1469,7 +1463,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1469,7 +1463,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
]
backward_op
.
input
(
"Y"
)[
0
]
)
)
...
@@ -1549,8 +1542,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1549,8 +1542,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# calc comm op cost
# calc comm op cost
serial_op
=
dist_op
.
serial_op
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
1
]
)[
-
1
]
...
@@ -1665,9 +1656,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1665,9 +1656,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
'trans_x'
)
trans_x
=
src_op
.
attr
(
'trans_x'
)
trans_y
=
src_op
.
attr
(
'trans_y'
)
trans_y
=
src_op
.
attr
(
'trans_y'
)
...
@@ -1808,7 +1799,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1808,7 +1799,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
input_varname
,
input_dist_attr
input_varname
,
input_dist_attr
)
)
else
:
else
:
input_var
=
main_block
.
var
(
input_varname
)
input_var
=
main_block
.
_var_recursive
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
input_var
)
)
...
@@ -1858,7 +1849,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
...
@@ -1858,7 +1849,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
]
backward_op
.
input
(
"Y"
)[
0
]
)
)
...
@@ -1924,8 +1915,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
...
@@ -1924,8 +1915,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# calc comm op cost
# calc comm op cost
serial_op
=
dist_op
.
serial_op
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
2
]
)[
-
2
]
...
@@ -2047,9 +2036,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
...
@@ -2047,9 +2036,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
'trans_x'
)
trans_x
=
src_op
.
attr
(
'trans_x'
)
trans_y
=
src_op
.
attr
(
'trans_y'
)
trans_y
=
src_op
.
attr
(
'trans_y'
)
...
@@ -2167,7 +2156,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
...
@@ -2167,7 +2156,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
allreduce_op_dist_attr
.
impl_type
=
op_dist_attr
.
impl_type
allreduce_op_dist_attr
.
impl_type
=
op_dist_attr
.
impl_type
allreduce_op_dist_attr
.
impl_idx
=
op_dist_attr
.
impl_idx
allreduce_op_dist_attr
.
impl_idx
=
op_dist_attr
.
impl_idx
for
input_varname
in
c_allreduce_sum_op
.
desc
.
input_arg_names
():
for
input_varname
in
c_allreduce_sum_op
.
desc
.
input_arg_names
():
input_var
=
main_block
.
var
(
input_varname
)
input_var
=
main_block
.
_var_recursive
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
assert
tensor_dist_attr
is
not
None
assert
tensor_dist_attr
is
not
None
allreduce_op_dist_attr
.
set_input_dist_attr
(
allreduce_op_dist_attr
.
set_input_dist_attr
(
...
@@ -2215,7 +2204,6 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
...
@@ -2215,7 +2204,6 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
process_mesh
=
dist_attr
.
process_mesh
process_mesh
=
dist_attr
.
process_mesh
# calc comp op cost
# calc comp op cost
...
@@ -2370,7 +2358,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
...
@@ -2370,7 +2358,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
]
backward_op
.
input
(
"Y"
)[
0
]
)
)
...
@@ -2445,7 +2432,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
...
@@ -2445,7 +2432,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# calc comm op cost
# calc comm op cost
serial_op
=
dist_op
.
serial_op
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
1
]
)[
-
1
]
...
@@ -2555,9 +2541,9 @@ class DistributedMulImpl0(DistributedOperatorImpl):
...
@@ -2555,9 +2541,9 @@ class DistributedMulImpl0(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
# TODO infer logic comm presentation
# TODO infer logic comm presentation
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
...
@@ -2712,7 +2698,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
...
@@ -2712,7 +2698,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
input_varname
,
input_dist_attr
input_varname
,
input_dist_attr
)
)
else
:
else
:
input_var
=
main_block
.
var
(
input_varname
)
input_var
=
main_block
.
_var_recursive
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
input_var
)
)
...
@@ -2763,7 +2749,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
...
@@ -2763,7 +2749,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
process_mesh
=
dist_attr
.
process_mesh
process_mesh
=
dist_attr
.
process_mesh
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
]
backward_op
.
input
(
"Y"
)[
0
]
)
)
...
@@ -2827,8 +2812,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
...
@@ -2827,8 +2812,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
# calc comm op cost
# calc comm op cost
serial_op
=
dist_op
.
serial_op
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
2
]
)[
-
2
]
...
@@ -2947,9 +2930,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
...
@@ -2947,9 +2930,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
# TODO infer logic comm presentation
# TODO infer logic comm presentation
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
...
@@ -3082,7 +3065,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
...
@@ -3082,7 +3065,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
allreduce_op_dist_attr
.
impl_type
=
op_dist_attr
.
impl_type
allreduce_op_dist_attr
.
impl_type
=
op_dist_attr
.
impl_type
allreduce_op_dist_attr
.
impl_idx
=
op_dist_attr
.
impl_idx
allreduce_op_dist_attr
.
impl_idx
=
op_dist_attr
.
impl_idx
for
input_varname
in
c_allreduce_sum_op
.
desc
.
input_arg_names
():
for
input_varname
in
c_allreduce_sum_op
.
desc
.
input_arg_names
():
input_var
=
main_block
.
var
(
input_varname
)
input_var
=
main_block
.
_var_recursive
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
assert
tensor_dist_attr
is
not
None
assert
tensor_dist_attr
is
not
None
allreduce_op_dist_attr
.
set_input_dist_attr
(
allreduce_op_dist_attr
.
set_input_dist_attr
(
...
@@ -3130,7 +3113,6 @@ class DistributedMulImpl2(DistributedOperatorImpl):
...
@@ -3130,7 +3113,6 @@ class DistributedMulImpl2(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
# calc comp op cost
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
desc_mapping
=
build_comp_desc_from_dist_op
(
...
...
python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
浏览文件 @
6992170e
...
@@ -155,7 +155,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
...
@@ -155,7 +155,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
in_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
X_var
.
name
)
in_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
X_var
.
name
)
for
axis
in
range
(
len
(
in_dims_mapping
)):
for
axis
in
range
(
len
(
in_dims_mapping
)):
if
in_dims_mapping
[
axis
]
!=
-
1
:
if
in_dims_mapping
[
axis
]
!=
-
1
:
...
@@ -260,13 +260,13 @@ class DistributedPNormImpl(DistributedOperatorImpl):
...
@@ -260,13 +260,13 @@ class DistributedPNormImpl(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
X_grad_var
=
main_block
.
var
(
kwargs
[
'X@GRAD'
][
0
])
X_grad_var
=
main_block
.
_var_recursive
(
kwargs
[
'X@GRAD'
][
0
])
# 1. copy p_norm_grad op and reset input name and output name
# 1. copy p_norm_grad op and reset input name and output name
new_kwargs
=
copy
.
deepcopy
(
kwargs
)
new_kwargs
=
copy
.
deepcopy
(
kwargs
)
new_kwargs
[
'X'
]
=
[
"."
.
join
([
"c_allgather"
,
X_var
.
name
])]
new_kwargs
[
'X'
]
=
[
"."
.
join
([
"c_allgather"
,
X_var
.
name
])]
new_X_var
=
main_block
.
var
(
new_kwargs
[
'X'
][
0
])
new_X_var
=
main_block
.
_var_recursive
(
new_kwargs
[
'X'
][
0
])
new_X_grad
=
main_block
.
create_var
(
new_X_grad
=
main_block
.
create_var
(
name
=
"."
.
join
([
"c_allgather"
,
X_grad_var
.
name
]),
name
=
"."
.
join
([
"c_allgather"
,
X_grad_var
.
name
]),
dtype
=
X_grad_var
.
dtype
,
dtype
=
X_grad_var
.
dtype
,
...
...
python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py
浏览文件 @
6992170e
...
@@ -54,7 +54,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
...
@@ -54,7 +54,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
return
False
return
False
output_name
=
outputs
[
0
]
output_name
=
outputs
[
0
]
output_var
=
dist_op
.
serial_op
.
block
.
var
(
output_name
)
output_var
=
dist_op
.
serial_op
.
block
.
_var_recursive
(
output_name
)
if
output_var
.
shape
!=
(
1
,):
if
output_var
.
shape
!=
(
1
,):
return
False
return
False
...
@@ -124,7 +124,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
...
@@ -124,7 +124,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
)
)
# dist attr
# dist attr
var
=
main_block
.
var
(
var_name
)
var
=
main_block
.
_var_recursive
(
var_name
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
var
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
var
)
op_dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
src_op
)
op_dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
src_op
)
new_op_attr
=
OperatorDistributedAttribute
()
new_op_attr
=
OperatorDistributedAttribute
()
...
...
python/paddle/distributed/auto_parallel/operators/dist_reshape.py
浏览文件 @
6992170e
...
@@ -53,7 +53,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
...
@@ -53,7 +53,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
res
=
[]
res
=
[]
op
=
dist_op
.
serial_op
op
=
dist_op
.
serial_op
vars
=
op
.
block
.
vars
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
shape_list
=
op
.
desc
.
attr
(
"shape"
)
shape_list
=
op
.
desc
.
attr
(
"shape"
)
...
@@ -103,7 +102,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
...
@@ -103,7 +102,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
is_parameter_related
(
if
"@GRAD"
not
in
varname
and
is_parameter_related
(
...
@@ -246,9 +244,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
...
@@ -246,9 +244,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
XShape_var
=
main_block
.
var
(
kwargs
[
'XShape'
][
0
])
XShape_var
=
main_block
.
_var_recursive
(
kwargs
[
'XShape'
][
0
])
shape_list
=
src_op
.
desc
.
attr
(
"shape"
)
shape_list
=
src_op
.
desc
.
attr
(
"shape"
)
ShapeTensor_var_list
=
[]
ShapeTensor_var_list
=
[]
for
name
in
kwargs
[
'ShapeTensor'
]:
for
name
in
kwargs
[
'ShapeTensor'
]:
...
@@ -303,7 +301,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
...
@@ -303,7 +301,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
res
=
[]
res
=
[]
op
=
dist_op
.
serial_op
op
=
dist_op
.
serial_op
vars
=
op
.
block
.
vars
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
shape_list
=
op
.
desc
.
attr
(
"shape"
)
shape_list
=
op
.
desc
.
attr
(
"shape"
)
...
@@ -353,7 +350,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
...
@@ -353,7 +350,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
...
@@ -499,9 +495,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
...
@@ -499,9 +495,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
XShape_var
=
main_block
.
var
(
kwargs
[
'XShape'
][
0
])
XShape_var
=
main_block
.
_var_recursive
(
kwargs
[
'XShape'
][
0
])
shape_list
=
src_op
.
desc
.
attr
(
"shape"
)
shape_list
=
src_op
.
desc
.
attr
(
"shape"
)
ShapeTensor_var_list
=
[]
ShapeTensor_var_list
=
[]
for
name
in
kwargs
[
'ShapeTensor'
]:
for
name
in
kwargs
[
'ShapeTensor'
]:
...
@@ -556,7 +552,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
...
@@ -556,7 +552,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
res
=
[]
res
=
[]
op
=
dist_op
.
serial_op
op
=
dist_op
.
serial_op
vars
=
op
.
block
.
vars
dist_attr
=
dist_op
.
dist_attr
dist_attr
=
dist_op
.
dist_attr
shape_list
=
op
.
desc
.
attr
(
"shape"
)
shape_list
=
op
.
desc
.
attr
(
"shape"
)
...
@@ -606,7 +601,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
...
@@ -606,7 +601,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
...
@@ -745,9 +739,9 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
...
@@ -745,9 +739,9 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
output_name
output_name
)
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
_var_recursive
(
kwargs
[
'X'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
_var_recursive
(
kwargs
[
'Out'
][
0
])
XShape_var
=
main_block
.
var
(
kwargs
[
'XShape'
][
0
])
XShape_var
=
main_block
.
_var_recursive
(
kwargs
[
'XShape'
][
0
])
shape_list
=
src_op
.
desc
.
attr
(
"shape"
)
shape_list
=
src_op
.
desc
.
attr
(
"shape"
)
ShapeTensor_var_list
=
[]
ShapeTensor_var_list
=
[]
for
name
in
kwargs
[
'ShapeTensor'
]:
for
name
in
kwargs
[
'ShapeTensor'
]:
...
...
python/paddle/distributed/auto_parallel/operators/dist_softmax.py
浏览文件 @
6992170e
...
@@ -79,7 +79,6 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
...
@@ -79,7 +79,6 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
is_parameter_related
(
if
"@GRAD"
not
in
varname
and
is_parameter_related
(
...
...
python/paddle/distributed/auto_parallel/operators/dist_transpose.py
浏览文件 @
6992170e
...
@@ -160,7 +160,6 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
...
@@ -160,7 +160,6 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
backward_op
=
dist_op
.
serial_op
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
is_parameter_related
(
if
"@GRAD"
not
in
varname
and
is_parameter_related
(
...
...
python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
浏览文件 @
6992170e
...
@@ -151,7 +151,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
...
@@ -151,7 +151,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
if
(
if
(
rank_id
rank_id
in
ctx
.
get_tensor_dist_attr_for_program
(
in
ctx
.
get_tensor_dist_attr_for_program
(
main_block
.
var
(
varname
)
main_block
.
_var_recursive
(
varname
)
).
process_mesh
.
processes
).
process_mesh
.
processes
):
):
filter_vars
.
append
(
varname
)
filter_vars
.
append
(
varname
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录