def _sort_grad_param_by_dtype(self, main_block, grad_param_pairs):
    """Reorder grad/param pairs so they are grouped by dtype.

    Each pair's first element is passed to ``main_block.var()`` and the
    ``dtype`` of the variable returned decides which group the pair
    belongs to. Groups are emitted in this order:

    1. ``paddle.float16``
    2. ``paddle.float32``
    3. every other dtype

    Pairs that fall into the same group keep their original relative
    order.

    Args:
        main_block: object exposing ``var(name)``; the returned variable
            must have a ``dtype`` attribute.
        grad_param_pairs: iterable of indexable pairs; only element 0 of
            each pair is inspected.

    Returns:
        list: a new list holding the same pairs, grouped by dtype. The
        input is not modified.
    """

    def group_rank(entry):
        # Element 0 is the lookup key for the variable whose dtype
        # drives the grouping.
        var_dtype = main_block.var(entry[0]).dtype
        if var_dtype == paddle.float32:
            return 1
        if var_dtype == paddle.float16:
            return 0
        return 2

    # sorted() is stable, so ties (same group) stay in input order.
    return sorted(grad_param_pairs, key=group_rank)