未验证 提交 ecf586a5 编写于 作者: J Jiabin Yang 提交者: GitHub

support amp logic (#52397)

上级 63efdaee
......@@ -2864,7 +2864,6 @@ class Operator:
self._type = type
self.attrs = attrs if attrs else {}
else:
self.legacy_attrs = attrs if attrs else {}
self.block = block
self.desc = desc
......@@ -3097,10 +3096,6 @@ class Operator:
self.desc.check_attrs()
# record all attrs needed by creating op
for item in self.desc.attr_names():
self.legacy_attrs[item] = self.desc.attr(item)
if self._has_kernel(type):
self.desc.infer_var_type(self.block.desc)
self.desc.infer_shape(self.block.desc)
......@@ -3108,10 +3103,6 @@ class Operator:
def _has_kernel(self, op_type):
return op_type not in self.OP_WITHOUT_KERNEL_SET
def _get_runtime_attrs(self):
"""Record all attrs needed by creating op. This api is only for to_prim process."""
return self.legacy_attrs
def to_string(self, throw_on_error):
"""
Get debug string.
......
......@@ -206,6 +206,13 @@ def gelu_composite(x, approximate):
@REGISTER_COMPOSITE('reduce_mean')
def mean_composite(x, axis, keepdim):
"""define composite rule of op mean"""
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
axes = axis or list(range(0, len(x.shape)))
axes = [axes] if isinstance(axes, int) else axes
sum_x = sum(x, axis=axes, keepdim=keepdim)
......@@ -217,7 +224,10 @@ def mean_composite(x, axis, keepdim):
value=value_to_fill,
dtype=sum_x.dtype,
)
return divide(sum_x, norm)
res = divide(sum_x, norm)
if is_amp:
res = cast(res, "float16")
return res
@REGISTER_COMPOSITE('expand_v2')
......@@ -424,9 +434,16 @@ def sigmoid_composite(x):
define composite rule of op sigmoid
res = 1 / (1 + exp(-x))
"""
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
sum_temp = 1 + exp(-x)
res = 1 / sum_temp
return res
return res if not is_amp else cast(res, "float16")
@REGISTER_COMPOSITE('silu')
......@@ -435,9 +452,16 @@ def silu_composite(x):
define composite rule of op silu
res = x / (1 + exp(-x))
"""
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
sum_temp = 1 + exp(-x)
res = x / sum_temp
return res
return res if not is_amp else cast(res, "float16")
@REGISTER_COMPOSITE('meshgrid')
......@@ -505,9 +529,16 @@ def sqrt_composite(x):
define composite rule of op sqrt
res = pow(x, 0.5)
"""
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype)
res = pow(x, y)
return res
return res if not is_amp else cast(res, "float16")
@REGISTER_COMPOSITE('pow')
......@@ -516,9 +547,18 @@ def pow_composite(x, y):
define composite rule of op pow
res = x^y
"""
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
if isinstance(y, (int, float)):
y = full(x.shape if len(x.shape) == 0 else [1], y, x.dtype)
res = pow(x, y)
if is_amp:
res = cast(res, "float16")
return res
......@@ -556,8 +596,15 @@ def unsqueeze_composite(x, axis):
def rsqrt_composite(x):
    """Define the composite rule of op rsqrt.

    rsqrt(x) = x^(-0.5)

    When ``x`` is float16, it is cast to float32 before the power is taken
    and the result is cast back to float16, so the caller gets the same
    dtype it passed in. For any other dtype, ``x`` is used unchanged.
    """
    # Function-scope import, as in the original block.
    from paddle.fluid.data_feeder import convert_dtype

    is_amp = convert_dtype(x.dtype) == "float16"
    if is_amp:
        x = cast(x, "float32")

    # The exponent is built with x's (possibly promoted) dtype; its shape is
    # x.shape when x.shape is empty, otherwise [1].
    y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
    res = pow(x, y)

    # Single exit point: the earlier unconditional `return pow(x, y)` would
    # have skipped this cast-back and returned float32 for float16 input.
    return cast(res, "float16") if is_amp else res
@REGISTER_COMPOSITE('group_norm')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册