未验证 提交 bcdbac17 编写于 作者: Z Zhen Wang 提交者: GitHub

fix some cast error. (#26884)

上级 6a09b8f1
......@@ -74,7 +74,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
continue
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
if in_var.type not in valid_types:
if in_var.type not in valid_types or in_var.dtype == dest_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
......@@ -84,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=False)
stop_gradient=in_var.stop_gradient)
block._insert_op(
idx,
......@@ -100,7 +100,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32:
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names:
if op.type == 'batch_norm' and out_name != 'Y':
continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册