From e69cc215ad0bafaeb47b577a709d793ca264f355 Mon Sep 17 00:00:00 2001
From: Yuang Liu <liuyuang@baidu.com>
Date: Tue, 31 Aug 2021 13:42:17 +0800
Subject: [PATCH] [cherry-pick][hybrid performance] optimize grad fuse for
 pipeline mode by sorting grads by dtype (#35070) (#35300)

---
 python/paddle/fluid/optimizer.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)
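
Note (not part of the applied diff): a minimal, self-contained sketch of the idea
behind this change, using hypothetical gradient names and plain string dtypes in
place of Paddle block variables. Sorting the (grad, param) pairs so that
same-dtype gradients are contiguous means the later fusion pass can coalesce one
merged buffer per dtype run instead of switching buffers at every dtype change.
Python's sort is stable, so the relative order within each dtype is preserved,
matching the behavior of the method added below.

    from itertools import groupby

    # Hypothetical grad/param names and dtypes, for illustration only.
    pairs = [("w1@GRAD", "w1"), ("w2@GRAD", "w2"),
             ("w3@GRAD", "w3"), ("w4@GRAD", "w4")]
    dtypes = {"w1@GRAD": "float16", "w2@GRAD": "float32",
              "w3@GRAD": "float16", "w4@GRAD": "float32"}

    # fp16 first, then fp32, then everything else.
    order = {"float16": 0, "float32": 1}
    sorted_pairs = sorted(pairs, key=lambda p: order.get(dtypes[p[0]], 2))

    # After sorting, grads of the same dtype are contiguous, so grouping by
    # dtype yields one run (one fused buffer) per dtype instead of four.
    runs = [(d, len(list(g)))
            for d, g in groupby(sorted_pairs, key=lambda p: dtypes[p[0]])]
    print(runs)  # [('float16', 2), ('float32', 2)]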

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 6e7e7e0399f..eb3d559ddcd 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -5223,6 +5223,9 @@ class PipelineOptimizer(object):
         if len(grad_param_pairs) == 0:
             return
 
+        grad_param_pairs = self._sort_grad_param_by_dtype(main_block,
+                                                          grad_param_pairs)
+
         grad_param_segments = []
         merged_suffix = '@MERGED@FP16' if fp16 else '@MERGED'
         dtype = paddle.float16 if fp16 else paddle.float32
@@ -5416,6 +5419,24 @@ class PipelineOptimizer(object):
 
         return fused_merged_gradients
 
+    def _sort_grad_param_by_dtype(self, main_block, grad_param_pairs):
+        # sort the grad param pairs by dtype: fp16 first, then fp32, then others
+        fp16_pairs = []
+        fp32_pairs = []
+        other_pairs = []
+        for pairs in grad_param_pairs:
+            dtype = main_block.var(pairs[0]).dtype
+            if dtype == paddle.float32:
+                fp32_pairs.append(pairs)
+            elif dtype == paddle.float16:
+                fp16_pairs.append(pairs)
+            else:
+                other_pairs.append(pairs)
+        sorted_pairs = fp16_pairs
+        sorted_pairs.extend(fp32_pairs)
+        sorted_pairs.extend(other_pairs)
+        return sorted_pairs
+
     def _get_var_size(self, var):
         dtype_to_size = {
             core.VarDesc.VarType.FP16: 2,
-- 
GitLab