Unverified · Commit 38ec37cd authored by kangguangli, committed by GitHub

[Perf] fix static graph performance issue in amp mode with multicard (#52724)

* fix

* fix

* fix

* fix

* fix

* fix fuse group order
Parent f6f18835
@@ -1534,7 +1534,7 @@ class Fleet:
             # i.e. users can not modify current computation graph anymore
             context["graph_optimize_ops"] = optimize_ops
             context["graph_optimize_grads"] = params_grads
-        else:
+        elif loss.block.program._pass_applied is None:
             apply_ir_passes(loss.block.program, startup_program, self)
         if not self._role_maker._is_heter_parameter_server_mode:
...
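The `else` branch becomes an `elif` on `loss.block.program._pass_applied`, so Fleet only runs `apply_ir_passes` on a program that has not already been processed by the meta-optimizer (the attribute itself is added in the `Program` hunk at the end of this diff). A minimal standalone sketch of that guard, with stand-in `Program` and `apply_ir_passes` definitions rather than Paddle's:

```python
# Minimal sketch, not Paddle code: IR passes run at most once per program.
# For brevity the marker is set inside apply_ir_passes here; in the diff it
# is RawProgramOptimizer that sets main_program._pass_applied = True.
class Program:
    def __init__(self):
        self._pass_applied = None  # same default the Program hunk below adds


def apply_ir_passes(program):
    print("applying IR passes")
    program._pass_applied = True


def maybe_apply_ir_passes(program, graph_optimized=False):
    if graph_optimized:
        pass  # graph already compiled/optimized elsewhere, nothing to do
    elif program._pass_applied is None:
        apply_ir_passes(program)  # first and only application


prog = Program()
maybe_apply_ir_passes(prog)  # prints "applying IR passes"
maybe_apply_ir_passes(prog)  # skipped: _pass_applied is no longer None
```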
@@ -13,6 +13,8 @@
 from paddle import static
 from paddle.fluid import core
+from paddle.framework import _global_flags
+from paddle.framework.ir import apply_build_strategy
 from paddle.utils import unique_name
 from .common import (
@@ -146,6 +148,18 @@ class RawProgramOptimizer(MetaOptimizerBase):
         optimize_ops, params_grads = self.inner_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set
         )
+        if _global_flags()['FLAGS_apply_pass_to_program']:
+            pass_attrs = {"use_cuda": True}
+            build_strategy = self.user_defined_strategy.build_strategy._copy()
+            build_strategy.fuse_all_optimizer_ops = False
+            build_strategy.fuse_all_reduce_ops = False
+            apply_build_strategy(
+                self.main_program,
+                self.startup_program,
+                build_strategy,
+                pass_attrs,
+            )
+            self.main_program._pass_applied = True
         if self.nranks == 1:
             return optimize_ops, params_grads
         self._init_process_group()
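With `FLAGS_apply_pass_to_program` enabled, the meta-optimizer now applies a copy of the user's build strategy directly to the program, with `fuse_all_optimizer_ops` and `fuse_all_reduce_ops` switched off because this multi-card path inserts its own fused `c_allreduce_sum` ops, and then marks the program via `_pass_applied`. A rough usage sketch for turning the flag on; the flag name is taken from the diff, but how it is best set (`paddle.set_flags` versus an environment variable) may differ between releases:

```python
# Illustrative only: enable the global flag checked by the new branch above.
# Assumes a GPU build of PaddlePaddle running in static-graph mode.
import paddle

paddle.enable_static()
paddle.set_flags({'FLAGS_apply_pass_to_program': True})

# From here on, RawProgramOptimizer.minimize_impl applies the remaining fuse
# passes (via apply_build_strategy) to the main/startup programs and sets
# main_program._pass_applied, so Fleet will not re-run apply_ir_passes on the
# same program.
```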
@@ -357,24 +371,39 @@ class RawProgramOptimizer(MetaOptimizerBase):
         # [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
         # each entry of the list is a tuple stores the grads segment list and
         # the corresponding params segment list
-        grad_param_segments = []
-        last_dtype = None
+        # its type is: dict[dtype, list[tuple[list[grad], list[param]]]]
+        grad_param_segments_by_dtype = {}
         # split the grad based on dtype and fused size
         for param, grad in param_grads:
-            if (
-                len(grad_param_segments) == 0
-                or len(grad_param_segments[-1][0]) == self.fuse_grad_size_in_num
-                or grad.dtype != last_dtype
-            ):
-                grad_param_segments.append(([grad], [param]))
-                last_dtype = grad.dtype
-            else:
-                grad_param_segments[-1][0].append(grad)
-                grad_param_segments[-1][1].append(param)
+            if grad.dtype not in grad_param_segments_by_dtype:
+                grad_param_segments_by_dtype[grad.dtype] = [([], [])]
+            grad_segment, param_segment = grad_param_segments_by_dtype[
+                grad.dtype
+            ][-1]
+            if len(param_segment) == self.fuse_grad_size_in_num:
+                grad_param_segments_by_dtype[grad.dtype].append(([], []))
+                grad_segment, param_segment = grad_param_segments_by_dtype[
+                    grad.dtype
+                ][-1]
+            param_segment.append(param)
+            grad_segment.append(grad)
+        grad_param_segments = []
+        for _, group in grad_param_segments_by_dtype.items():
+            grad_param_segments.extend(group)
         if len(grad_param_segments) == 0:
             return
+        # because the regroup operation make the relative order invalid,
+        # we need to reorder these fuse group by after_idx
+        def get_after_idx_of_fuse_group(grad_param_segments):
+            grad_segment, param_segment = grad_param_segments
+            return max([outputs_name_to_idx[grad][1] for grad in grad_segment])
+        grad_param_segments.sort(key=get_after_idx_of_fuse_group)
         fused_vars = [None] * len(grad_param_segments)
         for i in range(len(grad_param_segments) - 1, -1, -1):
             # travers the grad_param_segments in backward
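Previously a new fuse group was started whenever the gradient dtype changed, so the fp16/fp32 interleaving typical of AMP produced many small groups and therefore many all-reduce calls. The rewrite buckets gradients per dtype first and only splits a bucket once it reaches `fuse_grad_size_in_num`. A standalone toy version of that regrouping (plain strings instead of Paddle tensors; `group_by_dtype` is an illustrative name, not a Paddle API):

```python
# Toy regrouping sketch mirroring the loop above, with no Paddle dependency.
def group_by_dtype(param_grads, fuse_grad_size_in_num):
    segments_by_dtype = {}
    for param, grad_dtype in param_grads:
        if grad_dtype not in segments_by_dtype:
            segments_by_dtype[grad_dtype] = [([], [])]
        grads, params = segments_by_dtype[grad_dtype][-1]
        if len(params) == fuse_grad_size_in_num:
            # current bucket for this dtype is full, start a new one
            segments_by_dtype[grad_dtype].append(([], []))
            grads, params = segments_by_dtype[grad_dtype][-1]
        params.append(param)
        grads.append(param + "@GRAD")
    segments = []
    for group in segments_by_dtype.values():
        segments.extend(group)
    return segments


# fp16 and fp32 grads interleaved, as is typical under AMP: grouping by dtype
# yields one fused buffer per dtype instead of a new group on every switch.
pairs = [("w0", "fp16"), ("w1", "fp32"), ("w2", "fp16"), ("w3", "fp32")]
print(group_by_dtype(pairs, fuse_grad_size_in_num=2))
# [(['w0@GRAD', 'w2@GRAD'], ['w0', 'w2']), (['w1@GRAD', 'w3@GRAD'], ['w1', 'w3'])]
```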
@@ -390,7 +419,9 @@ class RawProgramOptimizer(MetaOptimizerBase):
                 stop_gradient=True,
             )
             fused_vars[i] = fused_var
-            after_idx = outputs_name_to_idx[grad_segment[-1]][1]
+            after_idx = max(
+                [outputs_name_to_idx[grad][1] for grad in grad_segment]
+            )
             block._insert_op_without_sync(
                 after_idx + 1,
                 type='c_allreduce_sum',
...
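Because a regrouped segment can now hold gradients produced at non-adjacent positions in the backward pass, the fused all-reduce has to be inserted after the latest producer in the segment rather than after the producer of the segment's last element. A toy comparison of the two index computations (integers stand in for the `(first_idx, last_idx)` entries of `outputs_name_to_idx`):

```python
# Toy check of the after_idx change; values are made up for illustration.
outputs_name_to_idx = {"g0": (2, 2), "g1": (7, 7), "g2": (5, 5)}
grad_segment = ["g0", "g1", "g2"]  # same-dtype grads, producers not in order

# old: index of the producer of the segment's last grad -> 5 (too early, g1 not ready)
old_after_idx = outputs_name_to_idx[grad_segment[-1]][1]
# new: index of the latest producer in the whole segment -> 7
new_after_idx = max(outputs_name_to_idx[g][1] for g in grad_segment)

print(old_after_idx, new_after_idx)  # 5 7
```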
@@ -5335,6 +5335,8 @@ class Program:
         self._fleet_opt = None
         self._program_config = None
+        self._pass_applied = None
         # assigned if this program has been parsed by a pipeline optimizer
         self._pipeline_opt = None
...