Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
70770d0d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
70770d0d
编写于
8月 03, 2022
作者:
J
JZ-LIANG
提交者:
GitHub
8月 03, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Unify gradient synchronization procedure of data parallel (#44815)
上级
0e6bf744
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
171 addition
and
181 deletion
+171
-181
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+7
-7
python/paddle/distributed/auto_parallel/operators/common.py
python/paddle/distributed/auto_parallel/operators/common.py
+132
-0
python/paddle/distributed/auto_parallel/operators/dist_default.py
...addle/distributed/auto_parallel/operators/dist_default.py
+16
-76
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
...dle/distributed/auto_parallel/operators/dist_embedding.py
+6
-49
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...paddle/distributed/auto_parallel/operators/dist_matmul.py
+10
-49
未找到文件。
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
70770d0d
...
@@ -223,8 +223,8 @@ class Engine:
...
@@ -223,8 +223,8 @@ class Engine:
assert
"dataset"
in
self
.
_user_tuning_config
,
"Optimization Tuning should provide with dataset."
assert
"dataset"
in
self
.
_user_tuning_config
,
"Optimization Tuning should provide with dataset."
batch_size
=
self
.
_user_tuning_config
[
"batch_size"
]
batch_size
=
self
.
_user_tuning_config
[
"batch_size"
]
dataset
=
self
.
_user_tuning_config
[
"dataset"
]
dataset
=
self
.
_user_tuning_config
[
"dataset"
]
dataset
.
dp_world_size
=
self
.
_
dp_world
_size
dataset
.
dp_world_size
=
self
.
_
input_split
_size
dataset
.
dp_rank
=
self
.
_
dp
_rank
dataset
.
dp_rank
=
self
.
_
input_split
_rank
from
.tuner.optimization_tuner
import
OptimizationTuner
from
.tuner.optimization_tuner
import
OptimizationTuner
self
.
_optimization_tuner
=
OptimizationTuner
(
self
.
_user_tuning_config
,
self
.
_optimization_tuner
=
OptimizationTuner
(
self
.
_user_tuning_config
,
...
@@ -262,7 +262,7 @@ class Engine:
...
@@ -262,7 +262,7 @@ class Engine:
if
var
.
name
in
block
.
vars
:
if
var
.
name
in
block
.
vars
:
feed_list
.
append
(
block
.
vars
[
var
.
name
])
feed_list
.
append
(
block
.
vars
[
var
.
name
])
self
.
_
dp_world_size
,
self
.
_dp_rank
=
self
.
_get_data_parallel
_info
(
self
.
_
input_split_size
,
self
.
_input_split_rank
=
self
.
_get_input_split
_info
(
feed_list
[
0
],
self
.
_dist_contexts
[
mode
])
feed_list
[
0
],
self
.
_dist_contexts
[
mode
])
def
_parallel
(
self
,
mode
,
all_ranks
):
def
_parallel
(
self
,
mode
,
all_ranks
):
...
@@ -554,8 +554,8 @@ class Engine:
...
@@ -554,8 +554,8 @@ class Engine:
batch_size
,
batch_size
,
epochs
,
epochs
,
steps_per_epoch
,
steps_per_epoch
,
data_parallel_world_size
=
self
.
_
dp_world
_size
,
data_parallel_world_size
=
self
.
_
input_split
_size
,
data_parallel_rank
=
self
.
_
dp
_rank
)
data_parallel_rank
=
self
.
_
input_split
_rank
)
# move read op from the end of program to the start of program
# move read op from the end of program to the start of program
new_op_size
=
len
(
dist_main_block
.
ops
)
new_op_size
=
len
(
dist_main_block
.
ops
)
...
@@ -615,8 +615,8 @@ class Engine:
...
@@ -615,8 +615,8 @@ class Engine:
fetches
=
dict
(
inner_fetch
,
**
usr_fetch
)
fetches
=
dict
(
inner_fetch
,
**
usr_fetch
)
return
list
(
fetches
.
keys
()),
fetches
return
list
(
fetches
.
keys
()),
fetches
def
_get_
data_parallel
_info
(
self
,
var
,
dist_context
):
def
_get_
input_split
_info
(
self
,
var
,
dist_context
):
#
get data parallel world size and current data parallel rank
#
deduce how the input data is split among the cluster
from
.utils
import
_get_comm_group
,
_get_corresponding_rank
from
.utils
import
_get_comm_group
,
_get_corresponding_rank
tensor_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
var
)
tensor_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
var
)
...
...
python/paddle/distributed/auto_parallel/operators/common.py
浏览文件 @
70770d0d
...
@@ -13,7 +13,11 @@
...
@@ -13,7 +13,11 @@
# limitations under the License
# limitations under the License
import
abc
import
abc
import
paddle
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
,
OP_ROLE_KEY
,
OP_ROLE_VAR_KEY
from
..dist_attribute
import
OperatorDistributedAttribute
from
..dist_attribute
import
OperatorDistributedAttribute
from
..utils
import
_get_comm_group
,
_get_corresponding_rank
from
..process_group
import
new_process_group
_g_distributed_operator_impl_containers
=
{}
_g_distributed_operator_impl_containers
=
{}
...
@@ -24,6 +28,16 @@ _g_elementwise_ops = [
...
@@ -24,6 +28,16 @@ _g_elementwise_ops = [
BACKWARD_ONLY_DIST_OPS
=
{
'check_finite_and_unscale'
,
'update_loss_scaling'
}
BACKWARD_ONLY_DIST_OPS
=
{
'check_finite_and_unscale'
,
'update_loss_scaling'
}
class
ParallelMode
():
"""
the parallel mode for communication or auxiliary operator
"""
DataParallel
=
"auto_parallel/data_parallel"
ModelParallel
=
"auto_parallel/model_parallel"
PipelineParalel
=
"auto_parallel/pipeline_paralel"
MoEParallel
=
"auto_parallel/moe_parallel"
def
is_elementwise_op
(
op_type
):
def
is_elementwise_op
(
op_type
):
if
op_type
in
_g_elementwise_ops
:
if
op_type
in
_g_elementwise_ops
:
return
True
return
True
...
@@ -303,3 +317,121 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
...
@@ -303,3 +317,121 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
new_op
.
output
(
output_name
)[
0
],
ref_tensor_dist_attr
)
new_op
.
output
(
output_name
)[
0
],
ref_tensor_dist_attr
)
ctx
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
ctx
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
def
get_data_parallel_group
(
dist_ctx
,
op
,
act_grad_names
,
rank
):
"""
deduce the data parallel communication group for current operator.
Args:
dist_ctx (DistributedContext): dist context.
op (Operator): the current (backward) operator which might need.
act_grad_names (list): list of input activation grads variable name to the current operator.
out_grad_names (list): list of the output parameter's grads variable name of the current operator.
rank (int): global ranks index for current process.
"""
dp_group
=
None
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
process_mesh
=
op_dist_attr
.
process_mesh
mesh_shape
=
process_mesh
.
topology
# FIXME Hack for Pipeline Parallelism where the current operator
# not belong to the mesh the current rank belong to.
if
rank
not
in
process_mesh
.
processes
:
rank
=
_get_corresponding_rank
(
dist_ctx
,
process_mesh
,
rank
)
for
var_name
in
act_grad_names
:
var_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
var_name
)
# consider that the variable's shape is None
# TODO utilize the batch_dim attr instead of "0" in future
batch_size_axis
=
var_dim_mapping
[
0
]
if
len
(
var_dim_mapping
)
>
0
else
-
1
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
batch_size_axis
,
rank
)
dp_group
=
new_process_group
(
group_ranks
)
break
return
dp_group
def
sync_and_scale_gradients
(
dist_ctx
,
op
,
dp_group
,
allreduce_var_names
):
"""
insert the allreudce and scale ops for gradients of model
parameters for operator in data parallelism.
Args:
dist_ctx (DistributedContext): dist context.
op (Operator): the current (backward) operator which might need.
allreduce_var_names (list): list of the parameter's grads variable name in the current operator output.
"""
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
process_mesh
=
op_dist_attr
.
process_mesh
dist_op_context
=
dist_ctx
.
dist_op_context
main_block
=
dist_op_context
.
work_block
dp_degree
=
len
(
dp_group
.
ranks
)
for
var_name
in
allreduce_var_names
:
added_ops
=
[]
grad_var
=
main_block
.
var
(
var_name
)
allreduce_op
=
main_block
.
append_op
(
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
[
grad_var
]},
outputs
=
{
'Out'
:
[
grad_var
]},
attrs
=
{
'ring_id'
:
dp_group
.
id
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
allreduce_op
.
_set_attr
(
'op_namescope'
,
str
(
'/'
)
+
ParallelMode
.
DataParallel
)
added_ops
.
append
(
allreduce_op
)
if
dist_ctx
.
gradient_scale
:
scale_op
=
main_block
.
append_op
(
type
=
'scale'
,
inputs
=
{
'X'
:
grad_var
},
outputs
=
{
'Out'
:
grad_var
},
attrs
=
{
'scale'
:
1.0
/
dp_degree
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
scale_op
.
_set_attr
(
'op_namescope'
,
str
(
'/'
)
+
ParallelMode
.
DataParallel
)
added_ops
.
append
(
scale_op
)
dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
grad_var
.
name
)
assert
dims_mapping
is
not
None
,
"Unexception: dims_mapping of output [{}] of op [{}] is None"
.
format
(
grad_var
.
name
,
op_dist_attr
.
op_type
)
# NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
for
new_op
in
added_ops
:
new_op_attr
=
OperatorDistributedAttribute
()
new_op_attr
.
process_mesh
=
process_mesh
new_op_attr
.
set_output_dims_mapping
(
grad_var
.
name
,
dims_mapping
)
new_op_attr
.
set_input_dims_mapping
(
grad_var
.
name
,
dims_mapping
)
dist_ctx
.
set_op_dist_attr_for_program
(
new_op
,
new_op_attr
)
def
gradient_synchronization
(
dist_ctx
,
op
,
act_grad_names
,
out_grad_names
,
rank
):
"""
conduct the allreudce and scaling(dp size)for gradients of model
parameters for operator in data parallelism.
Args:
dist_ctx (DistributedContext): dist context.
op (Operator): the current (backward) operator which might need.
act_grad_names (list): list of input activation grads variable name to the current operator.
out_grad_names (list): list of the output parameter's grads variable name of the current operator.
rank (int): global ranks index for current process.
"""
if
len
(
act_grad_names
)
==
0
or
len
(
out_grad_names
)
==
0
:
return
dp_group
=
get_data_parallel_group
(
dist_ctx
,
op
,
act_grad_names
,
rank
)
if
not
dp_group
:
return
sync_and_scale_gradients
(
dist_ctx
,
op
,
dp_group
,
out_grad_names
)
python/paddle/distributed/auto_parallel/operators/dist_default.py
浏览文件 @
70770d0d
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
from
.common
import
DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl_container
from
.common
import
gradient_synchronization
from
.common
import
register_distributed_operator_impl
,
is_parameter_related
from
.common
import
register_distributed_operator_impl
,
is_parameter_related
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_dim_replicate
...
@@ -537,39 +538,15 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
...
@@ -537,39 +538,15 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for
output_name
in
backward_op
.
desc
.
output_names
():
for
output_name
in
backward_op
.
desc
.
output_names
():
dist_op_desc
.
set_output
(
output_name
,
kwargs
[
output_name
])
dist_op_desc
.
set_output
(
output_name
,
kwargs
[
output_name
])
# check if need gradient allreduce
# data parallel gradient synchronization
# if there is a non-gradient & non-parameter input and its batch dimension is splited,
act_grad_names
=
[]
# we need insert gradient allreduce for the gradient of parameter in its output
need_gradient_allreduce
=
False
for
input_name
in
backward_op
.
desc
.
input_names
():
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
varname
,
main_block
):
varname
,
main_block
):
act_grad_names
.
append
(
varname
)
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
out_grad_names
=
[]
process_mesh
=
dist_attr
.
process_mesh
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
varname
)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if
rank_id
not
in
process_mesh
.
processes
:
rank_id
=
_get_corresponding_rank
(
ctx
,
process_mesh
,
rank_id
)
# NOTE: consider that the variable's shape is None
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
len
(
var_dim_mapping
)
>
0
else
-
1
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
need_gradient_allreduce
=
True
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
batch_size_axis
,
rank_id
)
dp_degree
=
len
(
group_ranks
)
dp_group
=
new_process_group
(
group_ranks
)
break
if
need_gradient_allreduce
:
allreduce_vars
=
[]
for
output_name
in
backward_op
.
desc
.
output_names
():
for
output_name
in
backward_op
.
desc
.
output_names
():
for
varname
in
backward_op
.
desc
.
output
(
output_name
):
for
varname
in
backward_op
.
desc
.
output
(
output_name
):
if
varname
in
kwargs
[
"grad_var_to_var"
]:
if
varname
in
kwargs
[
"grad_var_to_var"
]:
...
@@ -577,47 +554,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
...
@@ -577,47 +554,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if
fwd_name
not
in
main_block
.
vars
:
if
fwd_name
not
in
main_block
.
vars
:
continue
continue
if
is_parameter_related
(
fwd_name
,
main_block
):
if
is_parameter_related
(
fwd_name
,
main_block
):
allreduce_vars
.
append
(
varname
)
out_grad_names
.
append
(
varname
)
if
len
(
allreduce_vars
)
>
0
:
for
varname
in
allreduce_vars
:
added_ops
=
[]
grad_var
=
main_block
.
var
(
varname
)
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
allreduce_op
=
main_block
.
append_op
(
out_grad_names
,
rank_id
)
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
[
grad_var
]},
outputs
=
{
'Out'
:
[
grad_var
]},
attrs
=
{
'ring_id'
:
dp_group
.
id
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
allreduce_op
)
if
ctx
.
gradient_scale
:
scale_op
=
main_block
.
append_op
(
type
=
'scale'
,
inputs
=
{
'X'
:
grad_var
},
outputs
=
{
'Out'
:
grad_var
},
attrs
=
{
'scale'
:
1.0
/
dp_degree
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
scale_op
)
dims_mapping
=
ctx
.
get_tensor_dist_attr_for_program
(
grad_var
).
dims_mapping
process_mesh
=
dist_attr
.
process_mesh
for
op
in
added_ops
:
op_attr
=
OperatorDistributedAttribute
()
op_attr
.
process_mesh
=
process_mesh
op_attr
.
set_output_dims_mapping
(
grad_var
.
name
,
dims_mapping
)
op_attr
.
set_input_dims_mapping
(
grad_var
.
name
,
dims_mapping
)
ctx
.
set_op_dist_attr_for_program
(
op
,
op_attr
)
register_distributed_operator_impl
(
register_distributed_operator_impl
(
...
...
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
浏览文件 @
70770d0d
...
@@ -16,6 +16,7 @@ from .common import infer_shape
...
@@ -16,6 +16,7 @@ from .common import infer_shape
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
from
.common
import
DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl_container
from
.common
import
gradient_synchronization
from
.common
import
register_distributed_operator_impl
,
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
from
.common
import
register_distributed_operator_impl
,
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_dim_replicate
...
@@ -518,56 +519,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -518,56 +519,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
naive_copy_op_dist_attr_for_program
(
c_embedding_grad_op
,
backward_op
,
naive_copy_op_dist_attr_for_program
(
c_embedding_grad_op
,
backward_op
,
ctx
)
ctx
)
# check if need gradient allreduce
# data parallel gradient synchronization
need_gradient_allreduce
=
False
act_grad_names
=
[
Ids_var
.
name
]
out_grad_names
=
[
kwargs
[
'W@GRAD'
][
0
]]
process_mesh
=
dist_attr
.
process_mesh
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Ids_var
.
name
)
out_grad_names
,
rank_id
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
need_gradient_allreduce
=
True
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
batch_size_axis
,
rank_id
)
dp_degree
=
len
(
group_ranks
)
dp_group
=
new_process_group
(
group_ranks
)
if
need_gradient_allreduce
:
added_ops
=
[]
W_Grad_var
=
main_block
.
var
(
kwargs
[
'W@GRAD'
][
0
])
allreduce_op
=
main_block
.
append_op
(
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
[
W_Grad_var
]},
outputs
=
{
'Out'
:
[
W_Grad_var
]},
attrs
=
{
'ring_id'
:
dp_group
.
id
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
allreduce_op
)
if
ctx
.
gradient_scale
:
scale_op
=
main_block
.
append_op
(
type
=
'scale'
,
inputs
=
{
'X'
:
W_Grad_var
},
outputs
=
{
'Out'
:
W_Grad_var
},
attrs
=
{
'scale'
:
1.0
/
dp_degree
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
scale_op
)
main_block
.
_sync_with_cpp
()
dims_mapping
=
ctx
.
get_tensor_dist_attr_for_program
(
W_Grad_var
).
dims_mapping
process_mesh
=
dist_attr
.
process_mesh
for
op
in
added_ops
:
op_attr
=
OperatorDistributedAttribute
()
op_attr
.
process_mesh
=
process_mesh
op_attr
.
set_output_dims_mapping
(
W_Grad_var
.
name
,
dims_mapping
)
op_attr
.
set_input_dims_mapping
(
W_Grad_var
.
name
,
dims_mapping
)
ctx
.
set_op_dist_attr_for_program
(
op
,
op_attr
)
register_distributed_operator_impl
(
"lookup_table_v2"
,
register_distributed_operator_impl
(
"lookup_table_v2"
,
...
...
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
70770d0d
...
@@ -19,6 +19,7 @@ from .common import DistributedOperatorImplContainer
...
@@ -19,6 +19,7 @@ from .common import DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
from
.common
import
DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl
from
.common
import
register_distributed_operator_impl
from
.common
import
gradient_synchronization
from
.common
import
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
from
.common
import
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_dim_replicate
...
@@ -422,55 +423,15 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
...
@@ -422,55 +423,15 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
matmul_op_desc
=
copy_op_with_new_input_output
(
ctx
,
main_block
,
matmul_op_desc
=
copy_op_with_new_input_output
(
ctx
,
main_block
,
backward_op
,
**
kwargs
)
backward_op
,
**
kwargs
)
#
check if need gradient allreduce
#
data parallel gradient synchronization
need_gradient_allreduce
=
False
act_grad_names
=
[
X_var
.
name
]
process_mesh
=
dist_attr
.
process_mesh
out_grad_names
=
[]
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
X_var
.
name
)
if
is_parameter_related
(
Y_var
.
name
,
main_block
):
mesh_shape
=
process_mesh
.
topology
out_grad_names
=
[
kwargs
[
'Y@GRAD'
][
0
]]
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
need_gradient_allreduce
=
True
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
batch_size_axis
,
rank_id
)
dp_degree
=
len
(
group_ranks
)
dp_group
=
new_process_group
(
group_ranks
)
if
need_gradient_allreduce
and
is_parameter_related
(
Y_var
.
name
,
main_block
):
added_ops
=
[]
Y_Grad_var
=
main_block
.
var
(
kwargs
[
'Y@GRAD'
][
0
])
allreduce_op
=
main_block
.
append_op
(
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
[
Y_Grad_var
]},
outputs
=
{
'Out'
:
[
Y_Grad_var
]},
attrs
=
{
'ring_id'
:
dp_group
.
id
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
allreduce_op
)
if
ctx
.
gradient_scale
:
scale_op
=
main_block
.
append_op
(
type
=
'scale'
,
inputs
=
{
'X'
:
Y_Grad_var
},
outputs
=
{
'Out'
:
Y_Grad_var
},
attrs
=
{
'scale'
:
1.0
/
dp_degree
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
added_ops
.
append
(
scale_op
)
main_block
.
_sync_with_cpp
()
dims_mapping
=
ctx
.
get_tensor_dist_attr_for_program
(
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
out_grad_names
,
Y_Grad_var
).
dims_mapping
rank_id
)
process_mesh
=
dist_attr
.
process_mesh
for
op
in
added_ops
:
op_attr
=
OperatorDistributedAttribute
()
op_attr
.
process_mesh
=
process_mesh
op_attr
.
set_output_dims_mapping
(
Y_Grad_var
.
name
,
dims_mapping
)
op_attr
.
set_input_dims_mapping
(
Y_Grad_var
.
name
,
dims_mapping
)
ctx
.
set_op_dist_attr_for_program
(
op
,
op_attr
)
def
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
):
def
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录