未验证 提交 772b4906 编写于 作者: Z Zhang Ting 提交者: GitHub

fix dtype missmatch error (#53712)

上级 1019b264
......@@ -629,12 +629,15 @@ def cast_model_to_fp16(
def need_process(op):
need_process = True
if op.type in ["cast", "create_py_reader", "read"]:
if op.type in ["create_py_reader", "read"]:
need_process = False
else:
for attr_name in ['out_dtype', 'dtype']:
if op.has_attr(attr_name) and is_float_dtype(
op.attr(attr_name)
# output type of some operators such as fill_constant will be determined by the attribute value.
#
if not op.has_attr('in_dtype') and (
op.has_attr(attr_name)
and is_float_dtype(op.attr(attr_name))
):
need_process = False
......@@ -667,6 +670,24 @@ def cast_model_to_fp16(
"---- Add into keep_fp16_ops because the op in white_list ----"
)
else:
# if cast in orgin program, we only modifiy attr and output's dtype to avoid dtype mismatch errors.
if op.type == 'cast':
in_var = block._find_var_recursive(op.input('X')[0])
out_var = block._find_var_recursive(op.output('Out')[0])
op._set_attr('in_dtype', in_var.dtype)
out_var.desc.set_dtype(paddle.dtype(op.attr('out_dtype')))
_logger.debug(
"---- op type: {}, in var [name: {} dtype: {}], out var [name: {} dtype: {}], attr [in_dtype {} out_dtype {}] ----".format(
op.type,
op.input('X')[0],
in_var.dtype,
op.output('Out')[0],
out_var.dtype,
op.attr('in_dtype'),
op.attr('out_dtype'),
)
)
continue
# divide others ops into fp16/fp32 sets according to promoting principle.
dst_dtype = dest_type
if not use_promote:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册