未验证 提交 1026052c 编写于 作者: Y Yuang Liu 提交者: GitHub

fix_dp_grad_merge_with_grad_clip_by_global_norm (#36334)

上级 00245cfd
...@@ -28,6 +28,7 @@ from .dygraph import base as imperative_base ...@@ -28,6 +28,7 @@ from .dygraph import base as imperative_base
from .data_feeder import check_variable_and_dtype from .data_feeder import check_variable_and_dtype
from .framework import in_dygraph_mode from .framework import in_dygraph_mode
from .layer_helper import LayerHelper from .layer_helper import LayerHelper
from .framework import default_main_program
__all__ = [ __all__ = [
'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue', 'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue',
...@@ -547,7 +548,12 @@ class ClipGradByGlobalNorm(ClipGradBase): ...@@ -547,7 +548,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
scale_input = (scale_var.astype('float16') scale_input = (scale_var.astype('float16')
if g.dtype == core.VarDesc.VarType.FP16 else if g.dtype == core.VarDesc.VarType.FP16 else
scale_var) scale_var)
p.block.append_op( # NOTE(Yuang Liu): For pure dp with gradient merge, p and g
# live in a different block from the gradient-clip related ops.
# Append the op to the main program's current block instead; otherwise
# a 'NotFoundError' is raised at compile time.
block = default_main_program().current_block()
block.append_op(
type='elementwise_mul', type='elementwise_mul',
inputs={'X': g, inputs={'X': g,
'Y': scale_input}, 'Y': scale_input},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册