From bcdbac17534624056443bfd9f91c02c948f01be9 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 3 Sep 2020 14:32:10 +0800 Subject: [PATCH] fix some cast error. (#26884) --- python/paddle/fluid/contrib/mixed_precision/fp16_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 93013ef8bf..328dafe621 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -74,7 +74,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): continue for in_var_name in op.input(in_name): in_var = block.var(in_var_name) - if in_var.type not in valid_types: + if in_var.type not in valid_types or in_var.dtype == dest_dtype: continue if in_var.dtype == src_dtype: cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype) @@ -84,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): name=cast_name, dtype=dest_dtype, persistable=False, - stop_gradient=False) + stop_gradient=in_var.stop_gradient) block._insert_op( idx, @@ -100,7 +100,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): else: if op.has_attr('in_dtype'): op._set_attr('in_dtype', dest_dtype) - if src_dtype == core.VarDesc.VarType.FP32: + if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16: for out_name in op.output_names: if op.type == 'batch_norm' and out_name != 'Y': continue -- GitLab