From e69cc215ad0bafaeb47b577a709d793ca264f355 Mon Sep 17 00:00:00 2001 From: Yuang Liu <liuyuang@baidu.com> Date: Tue, 31 Aug 2021 13:42:17 +0800 Subject: [PATCH] [cherry-pick][hybrid performance] optim the grad fuse for pipeline mode by sorting the grad by dtype (#35070) (#35300) --- python/paddle/fluid/optimizer.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 6e7e7e0399f..eb3d559ddcd 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -5223,6 +5223,9 @@ class PipelineOptimizer(object): if len(grad_param_pairs) == 0: return + grad_param_pairs = self._sort_grad_param_by_dtype(main_block, + grad_param_pairs) + grad_param_segments = [] merged_suffix = '@MERGED@FP16' if fp16 else '@MERGED' dtype = paddle.float16 if fp16 else paddle.float32 @@ -5416,6 +5419,24 @@ class PipelineOptimizer(object): return fused_merged_gradients + def _sort_grad_param_by_dtype(self, main_block, grad_param_pairs): + # sort the grad param paris by the dtype + fp16_pairs = [] + fp32_pairs = [] + other_pairs = [] + for pairs in grad_param_pairs: + dtype = main_block.var(pairs[0]).dtype + if dtype == paddle.float32: + fp32_pairs.append(pairs) + elif dtype == paddle.float16: + fp16_pairs.append(pairs) + else: + other_pairs.append(pairs) + sorted_pairs = fp16_pairs + sorted_pairs.extend(fp32_pairs) + sorted_pairs.extend(other_pairs) + return sorted_pairs + def _get_var_size(self, var): dtype_to_size = { core.VarDesc.VarType.FP16: 2, -- GitLab