Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
70770d0d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
70770d0d
编写于
8月 03, 2022
作者:
J
JZ-LIANG
提交者:
GitHub
8月 03, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Unify gradient synchronization procedure of data parallel (#44815)
上级
0e6bf744
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
171 addition
and
181 deletion
+171
-181
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+7
-7
python/paddle/distributed/auto_parallel/operators/common.py
python/paddle/distributed/auto_parallel/operators/common.py
+132
-0
python/paddle/distributed/auto_parallel/operators/dist_default.py
...addle/distributed/auto_parallel/operators/dist_default.py
+16
-76
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
...dle/distributed/auto_parallel/operators/dist_embedding.py
+6
-49
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...paddle/distributed/auto_parallel/operators/dist_matmul.py
+10
-49
未找到文件。
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
70770d0d
...
@@ -223,8 +223,8 @@ class Engine:
...
@@ -223,8 +223,8 @@ class Engine:
assert
"dataset"
in
self
.
_user_tuning_config
,
"Optimization Tuning should provide with dataset."
assert
"dataset"
in
self
.
_user_tuning_config
,
"Optimization Tuning should provide with dataset."
batch_size
=
self
.
_user_tuning_config
[
"batch_size"
]
batch_size
=
self
.
_user_tuning_config
[
"batch_size"
]
dataset
=
self
.
_user_tuning_config
[
"dataset"
]
dataset
=
self
.
_user_tuning_config
[
"dataset"
]
dataset
.
dp_world_size
=
self
.
_
dp_world
_size
dataset
.
dp_world_size
=
self
.
_
input_split
_size
dataset
.
dp_rank
=
self
.
_
dp
_rank
dataset
.
dp_rank
=
self
.
_
input_split
_rank
from
.tuner.optimization_tuner
import
OptimizationTuner
from
.tuner.optimization_tuner
import
OptimizationTuner
self
.
_optimization_tuner
=
OptimizationTuner
(
self
.
_user_tuning_config
,
self
.
_optimization_tuner
=
OptimizationTuner
(
self
.
_user_tuning_config
,
...
@@ -262,7 +262,7 @@ class Engine:
...
@@ -262,7 +262,7 @@ class Engine:
if
var
.
name
in
block
.
vars
:
if
var
.
name
in
block
.
vars
:
feed_list
.
append
(
block
.
vars
[
var
.
name
])
feed_list
.
append
(
block
.
vars
[
var
.
name
])
self
.
_
dp_world_size
,
self
.
_dp_rank
=
self
.
_get_data_parallel
_info
(
self
.
_
input_split_size
,
self
.
_input_split_rank
=
self
.
_get_input_split
_info
(
feed_list
[
0
],
self
.
_dist_contexts
[
mode
])
feed_list
[
0
],
self
.
_dist_contexts
[
mode
])
def
_parallel
(
self
,
mode
,
all_ranks
):
def
_parallel
(
self
,
mode
,
all_ranks
):
...
@@ -554,8 +554,8 @@ class Engine:
...
@@ -554,8 +554,8 @@ class Engine:
batch_size
,
batch_size
,
epochs
,
epochs
,
steps_per_epoch
,
steps_per_epoch
,
data_parallel_world_size
=
self
.
_
dp_world
_size
,
data_parallel_world_size
=
self
.
_
input_split
_size
,
data_parallel_rank
=
self
.
_
dp
_rank
)
data_parallel_rank
=
self
.
_
input_split
_rank
)
# move read op from the end of program to the start of program
# move read op from the end of program to the start of program
new_op_size
=
len
(
dist_main_block
.
ops
)
new_op_size
=
len
(
dist_main_block
.
ops
)
...
@@ -615,8 +615,8 @@ class Engine:
...
@@ -615,8 +615,8 @@ class Engine:
fetches
=
dict
(
inner_fetch
,
**
usr_fetch
)
fetches
=
dict
(
inner_fetch
,
**
usr_fetch
)
return
list
(
fetches
.
keys
()),
fetches
return
list
(
fetches
.
keys
()),
fetches
def
_get_
data_parallel
_info
(
self
,
var
,
dist_context
):
def
_get_
input_split
_info
(
self
,
var
,
dist_context
):
#
get data parallel world size and current data parallel rank
#
deduce how the input data is split among the cluster
from
.utils
import
_get_comm_group
,
_get_corresponding_rank
from
.utils
import
_get_comm_group
,
_get_corresponding_rank
tensor_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
var
)
tensor_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
var
)
...
...
python/paddle/distributed/auto_parallel/operators/common.py
浏览文件 @
70770d0d
...
@@ -13,7 +13,11 @@
...
@@ -13,7 +13,11 @@
# limitations under the License
# limitations under the License
import
abc
import
abc
import
paddle
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
,
OP_ROLE_KEY
,
OP_ROLE_VAR_KEY
from
..dist_attribute
import
OperatorDistributedAttribute
from
..dist_attribute
import
OperatorDistributedAttribute
from
..utils
import
_get_comm_group
,
_get_corresponding_rank
from
..process_group
import
new_process_group
_g_distributed_operator_impl_containers
=
{}
_g_distributed_operator_impl_containers
=
{}
...
@@ -24,6 +28,16 @@ _g_elementwise_ops = [
...
@@ -24,6 +28,16 @@ _g_elementwise_ops = [
BACKWARD_ONLY_DIST_OPS
=
{
'check_finite_and_unscale'
,
'update_loss_scaling'
}
BACKWARD_ONLY_DIST_OPS
=
{
'check_finite_and_unscale'
,
'update_loss_scaling'
}
class
ParallelMode
():
"""
the parallel mode for communication or auxiliary operator
"""
DataParallel
=
"auto_parallel/data_parallel"
ModelParallel
=
"auto_parallel/model_parallel"
PipelineParalel
=
"auto_parallel/pipeline_paralel"
MoEParallel
=
"auto_parallel/moe_parallel"
def
is_elementwise_op
(
op_type
):
def
is_elementwise_op
(
op_type
):
if
op_type
in
_g_elementwise_ops
:
if
op_type
in
_g_elementwise_ops
:
return
True
return
True
...
@@ -303,3 +317,121 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
...
@@ -303,3 +317,121 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
new_op
.
output
(
output_name
)[
0
],
ref_tensor_dist_attr
)
new_op
.
output
(
output_name
)[
0
],
ref_tensor_dist_attr
)
ctx
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
ctx
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
def
get_data_parallel_group
(
dist_ctx
,
op
,
act_grad_names
,
rank
):
"""
deduce the data parallel communication group for current operator.
Args:
dist_ctx (DistributedContext): dist context.
op (Operator): the current (backward) operator which might need.
act_grad_names (list): list of input activation grads variable name to the current operator.
out_grad_names (list): list of the output parameter's grads variable name of the current operator.
rank (int): global ranks index for current process.
"""
dp_group
=
None
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
process_mesh
=
op_dist_attr
.
process_mesh
mesh_shape
=
process_mesh
.
topology
# FIXME Hack for Pipeline Parallelism where the current operator
# not belong to the mesh the current rank belong to.
if
rank
not
in
process_mesh
.
processes
:
rank
=
_get_corresponding_rank
(
dist_ctx
,
process_mesh
,
rank
)
for
var_name
in
act_grad_names
:
var_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
var_name
)
# consider that the variable's shape is None
# TODO utilize the batch_dim attr instead of "0" in future
batch_size_axis
=
var_dim_mapping
[
0
]
if
len
(
var_dim_mapping
)
>
0
else
-
1
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
batch_size_axis
,
rank
)
dp_group
=
new_process_group
(
group_ranks
)
break
return
dp_group
def
sync_and_scale_gradients
(
dist_ctx
,
op
,
dp_group
,
allreduce_var_names
):
"""
insert the allreudce and scale ops for gradients of model
parameters for operator in data parallelism.
Args:
dist_ctx (DistributedContext): dist context.
op (Operator): the current (backward) operator which might need.
allreduce_var_names (list): list of the parameter's grads variable name in the current operator output.
"""
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
process_mesh
=
op_dist_attr
.
process_mesh
dist_op_context
=
dist_ctx
.
dist_op_context
main_block
=
dist_op_context
.
work_block
dp_degree
=
len
(
dp_group
.
ranks
)
for
var_name
in
allreduce_var_names
:
added_ops
=
[]
grad_var
=
main_block
.
var
(
var_name
)
allreduce_op
=
main_block
.
append_op
(
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
[
grad_var
]},
outputs
=
{
'Out'
:
[
grad_var
]},
attrs
=
{
'ring_id'
:
dp_group
.
id
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
allreduce_op
.
_set_attr
(
'op_namescope'
,
str
(
'/'
)
+
ParallelMode
.
DataParallel
)
added_ops
.
append
(
allreduce_op
)
if
dist_ctx
.
gradient_scale
:
scale_op
=
main_block
.
append_op
(
type
=
'scale'
,
inputs
=
{
'X'
:
grad_var
},
outputs
=
{
'Out'
:
grad_var
},
attrs
=
{
'scale'
:
1.0
/
dp_degree
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
scale_op
.
_set_attr
(
'op_namescope'
,
str
(
'/'
)
+
ParallelMode
.
DataParallel
)
added_ops
.
append
(
scale_op
)
dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
grad_var
.
name
)
assert
dims_mapping
is
not
None
,
"Unexception: dims_mapping of output [{}] of op [{}] is None"
.
format
(
grad_var
.
name
,
op_dist_attr
.
op_type
)
# NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
for
new_op
in
added_ops
:
new_op_attr
=
OperatorDistributedAttribute
()
new_op_attr
.
process_mesh
=
process_mesh
new_op_attr
.
set_output_dims_mapping
(
grad_var
.
name
,
dims_mapping
)
new_op_attr
.
set_input_dims_mapping
(
grad_var
.
name
,
dims_mapping
)
dist_ctx
.
set_op_dist_attr_for_program
(
new_op
,
new_op_attr
)
def
gradient_synchronization
(
dist_ctx
,
op
,
act_grad_names
,
out_grad_names
,
rank
):
"""
conduct the allreudce and scaling(dp size)for gradients of model
parameters for operator in data parallelism.
Args:
dist_ctx (DistributedContext): dist context.
op (Operator): the current (backward) operator which might need.
act_grad_names (list): list of input activation grads variable name to the current operator.
out_grad_names (list): list of the output parameter's grads variable name of the current operator.
rank (int): global ranks index for current process.
"""
if
len
(
act_grad_names
)
==
0
or
len
(
out_grad_names
)
==
0
:
return
dp_group
=
get_data_parallel_group
(
dist_ctx
,
op
,
act_grad_names
,
rank
)
if
not
dp_group
:
return
sync_and_scale_gradients
(
dist_ctx
,
op
,
dp_group
,
out_grad_names
)
python/paddle/distributed/auto_parallel/operators/dist_default.py
浏览文件 @
70770d0d
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
from
.common
import
DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl_container
from
.common
import
gradient_synchronization
from
.common
import
register_distributed_operator_impl
,
is_parameter_related
from
.common
import
register_distributed_operator_impl
,
is_parameter_related
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_dim_replicate
...
@@ -537,39 +538,15 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
...
@@ -537,39 +538,15 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for
output_name
in
backward_op
.
desc
.
output_names
():
for
output_name
in
backward_op
.
desc
.
output_names
():
dist_op_desc
.
set_output
(
output_name
,
kwargs
[
output_name
])
dist_op_desc
.
set_output
(
output_name
,
kwargs
[
output_name
])
# check if need gradient allreduce
# data parallel gradient synchronization
# if there is a non-gradient & non-parameter input and its batch dimension is splited,
act_grad_names
=
[]
# we need insert gradient allreduce for the gradient of parameter in its output
need_gradient_allreduce
=
False
for
input_name
in
backward_op
.
desc
.
input_names
():
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
varname
,
main_block
):
varname
,
main_block
):
act_grad_names
.
append
(
varname
)
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
out_grad_names
=
[]
process_mesh
=
dist_attr
.
process_mesh
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
varname
)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if
rank_id
not
in
process_mesh
.
processes
:
rank_id
=
_get_corresponding_rank
(
ctx
,
process_mesh
,
rank_id
)
# NOTE: consider that the variable's shape is None
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
len
(
var_dim_mapping
)
>
0
else
-
1
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
need_gradient_allreduce
=
True
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
batch_size_axis
,
rank_id
)
dp_degree
=
len
(
group_ranks
)
dp_group
=
new_process_group
(
group_ranks
)
break
if
need_gradient_allreduce
:
allreduce_vars
=
[]
for
output_name
in
backward_op
.
desc
.
output_names
():
for
output_name
in
backward_op
.
desc
.
output_names
():
for
varname
in
backward_op
.
desc
.
output
(
output_name
):
for
varname
in
backward_op
.
desc
.
output
(
output_name
):
if
varname
in
kwargs
[
"grad_var_to_var"
]:
if
varname
in
kwargs
[
"grad_var_to_var"
]:
...
@@ -577,47 +554,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
...
@@ -577,47 +554,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if
fwd_name
not
in
main_block
.
vars
:
if
fwd_name
not
in
main_block
.
vars
:
continue
continue
if
is_parameter_related
(
fwd_name
,
main_block
):
if
is_parameter_related
(
fwd_name
,
main_block
):
allreduce_vars
.
append
(
varname
)
out_grad_names
.
append
(
varname
)
if
len
(
allreduce_vars
)
>
0
:
for
varname
in
allreduce_vars
:
added_ops
=
[]
grad_var
=
main_block
.
var
(
varname
)
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
allreduce_op
=
main_block
.
append_op
(
out_grad_names
,
rank_id
)
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
[
grad_var
]},
outputs
=
{
'Out'
:
[
grad_var
]},
attrs
=
{
'ring_id'
:
dp_group
.
id
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
allreduce_op
)
if
ctx
.
gradient_scale
:
scale_op
=
main_block
.
append_op
(
type
=
'scale'
,
inputs
=
{
'X'
:
grad_var
},
outputs
=
{
'Out'
:
grad_var
},
attrs
=
{
'scale'
:
1.0
/
dp_degree
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
scale_op
)
dims_mapping
=
ctx
.
get_tensor_dist_attr_for_program
(
grad_var
).
dims_mapping
process_mesh
=
dist_attr
.
process_mesh
for
op
in
added_ops
:
op_attr
=
OperatorDistributedAttribute
()
op_attr
.
process_mesh
=
process_mesh
op_attr
.
set_output_dims_mapping
(
grad_var
.
name
,
dims_mapping
)
op_attr
.
set_input_dims_mapping
(
grad_var
.
name
,
dims_mapping
)
ctx
.
set_op_dist_attr_for_program
(
op
,
op_attr
)
register_distributed_operator_impl
(
register_distributed_operator_impl
(
...
...
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
浏览文件 @
70770d0d
...
@@ -16,6 +16,7 @@ from .common import infer_shape
...
@@ -16,6 +16,7 @@ from .common import infer_shape
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
from
.common
import
DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl_container
from
.common
import
gradient_synchronization
from
.common
import
register_distributed_operator_impl
,
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
from
.common
import
register_distributed_operator_impl
,
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_dim_replicate
...
@@ -518,56 +519,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -518,56 +519,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
naive_copy_op_dist_attr_for_program
(
c_embedding_grad_op
,
backward_op
,
naive_copy_op_dist_attr_for_program
(
c_embedding_grad_op
,
backward_op
,
ctx
)
ctx
)
# check if need gradient allreduce
# data parallel gradient synchronization
need_gradient_allreduce
=
False
act_grad_names
=
[
Ids_var
.
name
]
out_grad_names
=
[
kwargs
[
'W@GRAD'
][
0
]]
process_mesh
=
dist_attr
.
process_mesh
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Ids_var
.
name
)
out_grad_names
,
rank_id
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
need_gradient_allreduce
=
True
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
batch_size_axis
,
rank_id
)
dp_degree
=
len
(
group_ranks
)
dp_group
=
new_process_group
(
group_ranks
)
if
need_gradient_allreduce
:
added_ops
=
[]
W_Grad_var
=
main_block
.
var
(
kwargs
[
'W@GRAD'
][
0
])
allreduce_op
=
main_block
.
append_op
(
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
[
W_Grad_var
]},
outputs
=
{
'Out'
:
[
W_Grad_var
]},
attrs
=
{
'ring_id'
:
dp_group
.
id
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
allreduce_op
)
if
ctx
.
gradient_scale
:
scale_op
=
main_block
.
append_op
(
type
=
'scale'
,
inputs
=
{
'X'
:
W_Grad_var
},
outputs
=
{
'Out'
:
W_Grad_var
},
attrs
=
{
'scale'
:
1.0
/
dp_degree
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
scale_op
)
main_block
.
_sync_with_cpp
()
dims_mapping
=
ctx
.
get_tensor_dist_attr_for_program
(
W_Grad_var
).
dims_mapping
process_mesh
=
dist_attr
.
process_mesh
for
op
in
added_ops
:
op_attr
=
OperatorDistributedAttribute
()
op_attr
.
process_mesh
=
process_mesh
op_attr
.
set_output_dims_mapping
(
W_Grad_var
.
name
,
dims_mapping
)
op_attr
.
set_input_dims_mapping
(
W_Grad_var
.
name
,
dims_mapping
)
ctx
.
set_op_dist_attr_for_program
(
op
,
op_attr
)
register_distributed_operator_impl
(
"lookup_table_v2"
,
register_distributed_operator_impl
(
"lookup_table_v2"
,
...
...
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
70770d0d
...
@@ -19,6 +19,7 @@ from .common import DistributedOperatorImplContainer
...
@@ -19,6 +19,7 @@ from .common import DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
from
.common
import
DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl
from
.common
import
register_distributed_operator_impl
from
.common
import
gradient_synchronization
from
.common
import
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
from
.common
import
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_dim_replicate
...
@@ -422,55 +423,15 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
...
@@ -422,55 +423,15 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
matmul_op_desc
=
copy_op_with_new_input_output
(
ctx
,
main_block
,
matmul_op_desc
=
copy_op_with_new_input_output
(
ctx
,
main_block
,
backward_op
,
**
kwargs
)
backward_op
,
**
kwargs
)
#
check if need gradient allreduce
#
data parallel gradient synchronization
need_gradient_allreduce
=
False
act_grad_names
=
[
X_var
.
name
]
process_mesh
=
dist_attr
.
process_mesh
out_grad_names
=
[]
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
X_var
.
name
)
if
is_parameter_related
(
Y_var
.
name
,
main_block
):
mesh_shape
=
process_mesh
.
topology
out_grad_names
=
[
kwargs
[
'Y@GRAD'
][
0
]]
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
need_gradient_allreduce
=
True
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
batch_size_axis
,
rank_id
)
dp_degree
=
len
(
group_ranks
)
dp_group
=
new_process_group
(
group_ranks
)
if
need_gradient_allreduce
and
is_parameter_related
(
Y_var
.
name
,
main_block
):
added_ops
=
[]
Y_Grad_var
=
main_block
.
var
(
kwargs
[
'Y@GRAD'
][
0
])
allreduce_op
=
main_block
.
append_op
(
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
[
Y_Grad_var
]},
outputs
=
{
'Out'
:
[
Y_Grad_var
]},
attrs
=
{
'ring_id'
:
dp_group
.
id
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
allreduce_op
)
if
ctx
.
gradient_scale
:
scale_op
=
main_block
.
append_op
(
type
=
'scale'
,
inputs
=
{
'X'
:
Y_Grad_var
},
outputs
=
{
'Out'
:
Y_Grad_var
},
attrs
=
{
'scale'
:
1.0
/
dp_degree
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
scale_op
)
main_block
.
_sync_with_cpp
()
dims_mapping
=
ctx
.
get_tensor_dist_attr_for_program
(
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
out_grad_names
,
Y_Grad_var
).
dims_mapping
rank_id
)
process_mesh
=
dist_attr
.
process_mesh
for
op
in
added_ops
:
op_attr
=
OperatorDistributedAttribute
()
op_attr
.
process_mesh
=
process_mesh
op_attr
.
set_output_dims_mapping
(
Y_Grad_var
.
name
,
dims_mapping
)
op_attr
.
set_input_dims_mapping
(
Y_Grad_var
.
name
,
dims_mapping
)
ctx
.
set_op_dist_attr_for_program
(
op
,
op_attr
)
def
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
):
def
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录