未验证 提交 ecf586a5 编写于 作者: J Jiabin Yang 提交者: GitHub

support amp logic (#52397)

上级 63efdaee
...@@ -2864,7 +2864,6 @@ class Operator: ...@@ -2864,7 +2864,6 @@ class Operator:
self._type = type self._type = type
self.attrs = attrs if attrs else {} self.attrs = attrs if attrs else {}
else: else:
self.legacy_attrs = attrs if attrs else {}
self.block = block self.block = block
self.desc = desc self.desc = desc
...@@ -3097,10 +3096,6 @@ class Operator: ...@@ -3097,10 +3096,6 @@ class Operator:
self.desc.check_attrs() self.desc.check_attrs()
# record all attrs needed by creating op
for item in self.desc.attr_names():
self.legacy_attrs[item] = self.desc.attr(item)
if self._has_kernel(type): if self._has_kernel(type):
self.desc.infer_var_type(self.block.desc) self.desc.infer_var_type(self.block.desc)
self.desc.infer_shape(self.block.desc) self.desc.infer_shape(self.block.desc)
...@@ -3108,10 +3103,6 @@ class Operator: ...@@ -3108,10 +3103,6 @@ class Operator:
def _has_kernel(self, op_type): def _has_kernel(self, op_type):
return op_type not in self.OP_WITHOUT_KERNEL_SET return op_type not in self.OP_WITHOUT_KERNEL_SET
def _get_runtime_attrs(self):
"""Record all attrs needed by creating op. This api is only for to_prim process."""
return self.legacy_attrs
def to_string(self, throw_on_error): def to_string(self, throw_on_error):
""" """
Get debug string. Get debug string.
......
...@@ -206,6 +206,13 @@ def gelu_composite(x, approximate): ...@@ -206,6 +206,13 @@ def gelu_composite(x, approximate):
@REGISTER_COMPOSITE('reduce_mean') @REGISTER_COMPOSITE('reduce_mean')
def mean_composite(x, axis, keepdim): def mean_composite(x, axis, keepdim):
"""define composite rule of op mean""" """define composite rule of op mean"""
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
axes = axis or list(range(0, len(x.shape))) axes = axis or list(range(0, len(x.shape)))
axes = [axes] if isinstance(axes, int) else axes axes = [axes] if isinstance(axes, int) else axes
sum_x = sum(x, axis=axes, keepdim=keepdim) sum_x = sum(x, axis=axes, keepdim=keepdim)
...@@ -217,7 +224,10 @@ def mean_composite(x, axis, keepdim): ...@@ -217,7 +224,10 @@ def mean_composite(x, axis, keepdim):
value=value_to_fill, value=value_to_fill,
dtype=sum_x.dtype, dtype=sum_x.dtype,
) )
return divide(sum_x, norm) res = divide(sum_x, norm)
if is_amp:
res = cast(res, "float16")
return res
@REGISTER_COMPOSITE('expand_v2') @REGISTER_COMPOSITE('expand_v2')
...@@ -424,9 +434,16 @@ def sigmoid_composite(x): ...@@ -424,9 +434,16 @@ def sigmoid_composite(x):
define composite rule of op sigmoid define composite rule of op sigmoid
res = 1 / (1 + exp(-x)) res = 1 / (1 + exp(-x))
""" """
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
sum_temp = 1 + exp(-x) sum_temp = 1 + exp(-x)
res = 1 / sum_temp res = 1 / sum_temp
return res return res if not is_amp else cast(res, "float16")
@REGISTER_COMPOSITE('silu') @REGISTER_COMPOSITE('silu')
...@@ -435,9 +452,16 @@ def silu_composite(x): ...@@ -435,9 +452,16 @@ def silu_composite(x):
define composite rule of op silu define composite rule of op silu
res = x / (1 + exp(-x)) res = x / (1 + exp(-x))
""" """
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
sum_temp = 1 + exp(-x) sum_temp = 1 + exp(-x)
res = x / sum_temp res = x / sum_temp
return res return res if not is_amp else cast(res, "float16")
@REGISTER_COMPOSITE('meshgrid') @REGISTER_COMPOSITE('meshgrid')
...@@ -505,9 +529,16 @@ def sqrt_composite(x): ...@@ -505,9 +529,16 @@ def sqrt_composite(x):
define composite rule of op sqrt define composite rule of op sqrt
res = pow(x, 0.5) res = pow(x, 0.5)
""" """
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype) y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype)
res = pow(x, y) res = pow(x, y)
return res return res if not is_amp else cast(res, "float16")
@REGISTER_COMPOSITE('pow') @REGISTER_COMPOSITE('pow')
...@@ -516,9 +547,18 @@ def pow_composite(x, y): ...@@ -516,9 +547,18 @@ def pow_composite(x, y):
define composite rule of op pow define composite rule of op pow
res = x^y res = x^y
""" """
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
if isinstance(y, (int, float)): if isinstance(y, (int, float)):
y = full(x.shape if len(x.shape) == 0 else [1], y, x.dtype) y = full(x.shape if len(x.shape) == 0 else [1], y, x.dtype)
res = pow(x, y) res = pow(x, y)
if is_amp:
res = cast(res, "float16")
return res return res
...@@ -556,8 +596,15 @@ def unsqueeze_composite(x, axis): ...@@ -556,8 +596,15 @@ def unsqueeze_composite(x, axis):
def rsqrt_composite(x): def rsqrt_composite(x):
"""define composite rule of op rsqrt.""" """define composite rule of op rsqrt."""
# rsqrt(x) = x^(-0.5) # rsqrt(x) = x^(-0.5)
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
is_amp = True
x = cast(x, "float32")
y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype) y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
return pow(x, y) res = pow(x, y)
return res if not is_amp else cast(res, "float16")
@REGISTER_COMPOSITE('group_norm') @REGISTER_COMPOSITE('group_norm')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册