Unverified commit 94cd1ba2, authored by Jiabin Yang, committed by GitHub

【Prim】Fix dropout CINN amp error (#51688)

* support amp logic for layer_norm and softmax

* fix layer_norm amp

* fix layernorm api and dropout fp16

* fix layernorm api and dropout fp16

* fix bn, ln dtype in float16

* fix dropout fp16

* fix comment

* fix cinn dropout amp error
Parent c665400b
@@ -355,10 +355,14 @@ def dropout_composite(x, seed_tensor, p, is_test, mode, seed, fix_seed):

 def bernoulli(shape, dtype, p, seed=0):
+    from paddle.fluid.data_feeder import convert_dtype
+
+    # TODO(jiabin) Fix uniform doesn't support float16 error in CINN
+    new_dtype = "float32" if convert_dtype(dtype) == "float16" else dtype
     return cast(
         greater_equal(
-            uniform(shape, dtype, min=0.0, max=1.0, seed=seed),
-            fill_constant(shape, dtype, p),
+            uniform(shape, new_dtype, min=0.0, max=1.0, seed=seed),
+            fill_constant(shape, new_dtype, p),
         ),
         dtype,
     )
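For context, the fix works around a CINN limitation: uniform sampling is not supported in float16, so the random mask is computed in float32 and only the final cast returns to the requested dtype. Below is a minimal sketch of the same pattern written against Paddle's public tensor API rather than the prim composite primitives (cast/greater_equal/uniform/fill_constant) used in the diff; the helper name bernoulli_mask is illustrative, not part of the commit.

import paddle

def bernoulli_mask(shape, dtype, p, seed=0):
    # CINN cannot sample uniform noise in float16, so compute in float32
    compute_dtype = "float32" if dtype == "float16" else dtype
    noise = paddle.uniform(shape, dtype=compute_dtype, min=0.0, max=1.0, seed=seed)
    threshold = paddle.full(shape, p, dtype=compute_dtype)
    # 1.0 where noise >= p (element kept with probability 1 - p),
    # then cast back to the originally requested dtype
    return paddle.cast(paddle.greater_equal(noise, threshold), dtype)

# Example: a float16 keep-mask for dropout with p = 0.5
mask = bernoulli_mask([4, 4], "float16", p=0.5, seed=1)

Note that only the sampling and comparison run in float32; the returned mask still has the caller's dtype, so downstream dropout arithmetic is unchanged.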