未验证 提交 3603b9b1 编写于 作者: K kangguangli 提交者: GitHub

[Perf] fix static graph performance issue in amp mode with multicard (#52724) (#53115)

* fix

* fix

* fix

* fix

* fix

* fix fuse group order

(cherry picked from commit 38ec37cd)
上级 02f44fcc
......@@ -1534,7 +1534,7 @@ class Fleet:
# i.e. users can not modify current computation graph anymore
context["graph_optimize_ops"] = optimize_ops
context["graph_optimize_grads"] = params_grads
else:
elif loss.block.program._pass_applied is None:
apply_ir_passes(loss.block.program, startup_program, self)
if not self._role_maker._is_heter_parameter_server_mode:
......
......@@ -13,6 +13,8 @@
from paddle import static
from paddle.fluid import core
from paddle.framework import _global_flags
from paddle.framework.ir import apply_build_strategy
from paddle.utils import unique_name
from .common import (
......@@ -146,6 +148,18 @@ class RawProgramOptimizer(MetaOptimizerBase):
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set
)
if _global_flags()['FLAGS_apply_pass_to_program']:
pass_attrs = {"use_cuda": True}
build_strategy = self.user_defined_strategy.build_strategy._copy()
build_strategy.fuse_all_optimizer_ops = False
build_strategy.fuse_all_reduce_ops = False
apply_build_strategy(
self.main_program,
self.startup_program,
build_strategy,
pass_attrs,
)
self.main_program._pass_applied = True
if self.nranks == 1:
return optimize_ops, params_grads
self._init_process_group()
......@@ -357,24 +371,39 @@ class RawProgramOptimizer(MetaOptimizerBase):
# [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
# each entry of the list is a tuple stores the grads segment list and
# the corresponding params segment list
grad_param_segments = []
last_dtype = None
# its type is: dict[dtype, list[tuple[list[grad], list[param]]]]
grad_param_segments_by_dtype = {}
# split the grad based on dtype and fused size
for param, grad in param_grads:
if (
len(grad_param_segments) == 0
or len(grad_param_segments[-1][0]) == self.fuse_grad_size_in_num
or grad.dtype != last_dtype
):
grad_param_segments.append(([grad], [param]))
last_dtype = grad.dtype
else:
grad_param_segments[-1][0].append(grad)
grad_param_segments[-1][1].append(param)
if grad.dtype not in grad_param_segments_by_dtype:
grad_param_segments_by_dtype[grad.dtype] = [([], [])]
grad_segment, param_segment = grad_param_segments_by_dtype[
grad.dtype
][-1]
if len(param_segment) == self.fuse_grad_size_in_num:
grad_param_segments_by_dtype[grad.dtype].append(([], []))
grad_segment, param_segment = grad_param_segments_by_dtype[
grad.dtype
][-1]
param_segment.append(param)
grad_segment.append(grad)
grad_param_segments = []
for _, group in grad_param_segments_by_dtype.items():
grad_param_segments.extend(group)
if len(grad_param_segments) == 0:
return
# because the regroup operation make the relative order invalid,
# we need to reorder these fuse group by after_idx
def get_after_idx_of_fuse_group(grad_param_segments):
grad_segment, param_segment = grad_param_segments
return max([outputs_name_to_idx[grad][1] for grad in grad_segment])
grad_param_segments.sort(key=get_after_idx_of_fuse_group)
fused_vars = [None] * len(grad_param_segments)
for i in range(len(grad_param_segments) - 1, -1, -1):
# travers the grad_param_segments in backward
......@@ -390,7 +419,9 @@ class RawProgramOptimizer(MetaOptimizerBase):
stop_gradient=True,
)
fused_vars[i] = fused_var
after_idx = outputs_name_to_idx[grad_segment[-1]][1]
after_idx = max(
[outputs_name_to_idx[grad][1] for grad in grad_segment]
)
block._insert_op_without_sync(
after_idx + 1,
type='c_allreduce_sum',
......
......@@ -5335,6 +5335,8 @@ class Program:
self._fleet_opt = None
self._program_config = None
self._pass_applied = None
# assigned if this program has been parsed by a pipeline optimizer
self._pipeline_opt = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册