未验证 提交 7f4cdd55 编写于 作者: C cyber-pioneer 提交者: GitHub

fix reduce_mean prim float16 bug (#54836)

上级 a35aa8cf
...@@ -248,7 +248,7 @@ def mean_composite(x, axis, keepdim): ...@@ -248,7 +248,7 @@ def mean_composite(x, axis, keepdim):
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
dtype = convert_dtype(x.dtype) dtype = convert_dtype(x.dtype)
if dtype == ["float16", "uint16"]: if dtype in ["float16", "uint16"]:
is_amp = True is_amp = True
x = cast(x, "float32") x = cast(x, "float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册