You need to sign in or sign up before continuing.
未验证 提交 4be3b057 编写于 作者: H huangxu96 提交者: GitHub

fix bug in amp O2 (#32343)

上级 7bae5e9a
...@@ -103,7 +103,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -103,7 +103,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
if in_name not in {'X', 'Z'}: if in_name not in {'X', 'Z'}:
continue continue
for in_var_name in op.input(in_name): for in_var_name in op.input(in_name):
in_var = block.var(in_var_name) in_var = block._find_var_recursive(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dest_dtype: if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
continue continue
if in_var.dtype == src_dtype: if in_var.dtype == src_dtype:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册