未验证 提交 7f4cdd55 编写于 作者: C cyber-pioneer 提交者: GitHub

fix reduce_mean prim float16 bug (#54836)

上级 a35aa8cf
......@@ -248,7 +248,7 @@ def mean_composite(x, axis, keepdim):
from paddle.fluid.data_feeder import convert_dtype
dtype = convert_dtype(x.dtype)
if dtype == ["float16", "uint16"]:
if dtype in ["float16", "uint16"]:
is_amp = True
x = cast(x, "float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册