diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
index 93013ef8bf8442311621202e0a86dd65e7c38b30..328dafe6219adb3c6355de0bafc430c52725024f 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -74,7 +74,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 continue
         for in_var_name in op.input(in_name):
             in_var = block.var(in_var_name)
-            if in_var.type not in valid_types:
+            if in_var.type not in valid_types or in_var.dtype == dest_dtype:
                 continue
             if in_var.dtype == src_dtype:
                 cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
@@ -84,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                         name=cast_name,
                         dtype=dest_dtype,
                         persistable=False,
-                        stop_gradient=False)
+                        stop_gradient=in_var.stop_gradient)

                     block._insert_op(
                         idx,
@@ -100,7 +100,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
             else:
                 if op.has_attr('in_dtype'):
                     op._set_attr('in_dtype', dest_dtype)
-    if src_dtype == core.VarDesc.VarType.FP32:
+    if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
         for out_name in op.output_names:
             if op.type == 'batch_norm' and out_name != 'Y':
                 continue
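
The diff makes three behavioral changes to `_insert_cast_op`: skip inserting a cast when the input var is already at `dest_dtype`, propagate `stop_gradient` from the source var to the cast output instead of hardcoding `False`, and rewrite output dtypes only for the FP32-to-FP16 direction rather than whenever `src_dtype` is FP32. Below is a minimal, framework-free sketch of that decision logic; the `Var` dataclass and the helper functions are hypothetical stand-ins for Paddle's block/var machinery, written only to make the three rules testable in isolation, and are not part of the patched file.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for core.VarDesc.VarType.* and valid_types.
FP16, FP32 = "float16", "float32"
VALID_TYPES = {"lod_tensor", "selected_rows", "lod_tensor_array"}


@dataclass
class Var:
    name: str
    type: str
    dtype: str
    stop_gradient: bool = False


def should_insert_cast(in_var: Var, src_dtype: str, dest_dtype: str) -> bool:
    """First fix: never cast a var that already has dest_dtype
    (previously a redundant cast op could be inserted for it)."""
    if in_var.type not in VALID_TYPES or in_var.dtype == dest_dtype:
        return False
    return in_var.dtype == src_dtype


def make_cast_output(in_var: Var, dest_dtype: str) -> Var:
    """Second fix: the cast output inherits stop_gradient from its
    source var instead of hardcoding stop_gradient=False."""
    return Var(
        name=in_var.name + ".cast_" + dest_dtype,
        type=in_var.type,
        dtype=dest_dtype,
        stop_gradient=in_var.stop_gradient)


def should_rewrite_outputs(src_dtype: str, dest_dtype: str) -> bool:
    """Third fix: output dtypes are rewritten only on the FP32 -> FP16
    downcast, not whenever src_dtype happens to be FP32."""
    return src_dtype == FP32 and dest_dtype == FP16


if __name__ == "__main__":
    x = Var("x", "lod_tensor", FP32, stop_gradient=True)
    assert should_insert_cast(x, FP32, FP16)
    # The cast output keeps the source var's stop_gradient flag.
    assert make_cast_output(x, FP16).stop_gradient is True
    # An input already in FP16 no longer triggers a cast.
    assert not should_insert_cast(Var("y", "lod_tensor", FP16), FP32, FP16)
    # A pass casting back toward FP32 must not rewrite output dtypes.
    assert not should_rewrite_outputs(FP16, FP32)
```

Taken together, the first fix avoids redundant cast ops, the second keeps gradient bookkeeping consistent across the inserted cast (a var that needed no gradient before the cast still needs none after), and the third guards the output-dtype rewrite so that only the genuine FP32-to-FP16 downcast path is affected.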