未验证 提交 fad4b3b4 编写于 作者: Y Yuang Liu 提交者: GitHub

[hybrid performance] optim the grad fuse for pipeline mode by sorting the grad by dtype (#35070)

上级 b6dc16cb
...@@ -5216,6 +5216,9 @@ class PipelineOptimizer(object): ...@@ -5216,6 +5216,9 @@ class PipelineOptimizer(object):
if len(grad_param_pairs) == 0: if len(grad_param_pairs) == 0:
return return
grad_param_pairs = self._sort_grad_param_by_dtype(main_block,
grad_param_pairs)
grad_param_segments = [] grad_param_segments = []
merged_suffix = '@MERGED@FP16' if fp16 else '@MERGED' merged_suffix = '@MERGED@FP16' if fp16 else '@MERGED'
dtype = paddle.float16 if fp16 else paddle.float32 dtype = paddle.float16 if fp16 else paddle.float32
...@@ -5409,6 +5412,24 @@ class PipelineOptimizer(object): ...@@ -5409,6 +5412,24 @@ class PipelineOptimizer(object):
return fused_merged_gradients return fused_merged_gradients
def _sort_grad_param_by_dtype(self, main_block, grad_param_pairs):
# sort the grad param paris by the dtype
fp16_pairs = []
fp32_pairs = []
other_pairs = []
for pairs in grad_param_pairs:
dtype = main_block.var(pairs[0]).dtype
if dtype == paddle.float32:
fp32_pairs.append(pairs)
elif dtype == paddle.float16:
fp16_pairs.append(pairs)
else:
other_pairs.append(pairs)
sorted_pairs = fp16_pairs
sorted_pairs.extend(fp32_pairs)
sorted_pairs.extend(other_pairs)
return sorted_pairs
def _get_var_size(self, var): def _get_var_size(self, var):
dtype_to_size = { dtype_to_size = {
core.VarDesc.VarType.FP16: 2, core.VarDesc.VarType.FP16: 2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册