Unverified commit 182e0904, authored by Charles-hit, committed by GitHub

[AMP Prim OP] support amp logic for some prim ops (#54608)

* fix api rename

* support amp logic for some prim ops

---------
Co-authored-by: kangguangli <kangguangli@hotmail.com>
Parent: cc91fa66
@@ -38,7 +38,8 @@ def softmax_composite(x, axis):
     from paddle.fluid.data_feeder import convert_dtype
 
     # Softmax need fp32 compute since it has sum op in
-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
     if not x.shape:
@@ -53,7 +54,7 @@ def softmax_composite(x, axis):
     denominator = sum(molecular, axis=axis, keepdim=True)
     res = divide(molecular, denominator)
     if is_amp:
-        res = cast(res, "float16")
+        res = cast(res, dtype)
     return res
@@ -246,7 +247,8 @@ def mean_composite(x, axis, keepdim):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype
 
-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
@@ -263,7 +265,7 @@ def mean_composite(x, axis, keepdim):
     )
     res = divide(sum_x, norm)
     if is_amp:
-        res = cast(res, "float16")
+        res = cast(res, dtype)
     return res
@@ -420,7 +422,9 @@ def bernoulli(shape, dtype, p, seed=0):
     from paddle.fluid.data_feeder import convert_dtype
 
     # TODO(jiabin) Fix uniform doesn't support float16 error in CINN
-    new_dtype = "float32" if convert_dtype(dtype) == "float16" else dtype
+    new_dtype = (
+        "float32" if convert_dtype(dtype) in ["float16", "uint16"] else dtype
+    )
     return cast(
         greater_equal(
             uniform(shape, new_dtype, min=0.0, max=1.0, seed=seed),
@@ -474,13 +478,14 @@ def sigmoid_composite(x):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype
 
-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
 
     sum_temp = 1 + exp(-x)
     res = 1 / sum_temp
-    return res if not is_amp else cast(res, "float16")
+    return res if not is_amp else cast(res, dtype)
 
 
 @REGISTER_COMPOSITE('silu')
@@ -492,13 +497,14 @@ def silu_composite(x):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype
 
-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
 
     sum_temp = 1 + exp(-x)
     res = x / sum_temp
-    return res if not is_amp else cast(res, "float16")
+    return res if not is_amp else cast(res, dtype)
 
 
 @REGISTER_COMPOSITE('meshgrid')
@@ -569,13 +575,14 @@ def sqrt_composite(x):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype
 
-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
 
     y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype)
     res = pow(x, y)
-    return res if not is_amp else cast(res, "float16")
+    return res if not is_amp else cast(res, dtype)
 
 
 @REGISTER_COMPOSITE('pow')
@@ -587,7 +594,8 @@ def pow_composite(x, y):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype
 
-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
@@ -595,7 +603,7 @@ def pow_composite(x, y):
         y = full(x.shape if len(x.shape) == 0 else [1], y, x.dtype)
     res = pow(x, y)
     if is_amp:
-        res = cast(res, "float16")
+        res = cast(res, dtype)
     return res