From ecf586a50ac56fe763682d4aece939211cbfa615 Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Tue, 4 Apr 2023 10:20:23 +0800 Subject: [PATCH] support amp logic (#52397) --- python/paddle/fluid/framework.py | 9 --- .../incubate/autograd/composite_rules.py | 57 +++++++++++++++++-- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 38eaa774041..6a81bc71724 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2864,7 +2864,6 @@ class Operator: self._type = type self.attrs = attrs if attrs else {} else: - self.legacy_attrs = attrs if attrs else {} self.block = block self.desc = desc @@ -3097,10 +3096,6 @@ class Operator: self.desc.check_attrs() - # record all attrs needed by creating op - for item in self.desc.attr_names(): - self.legacy_attrs[item] = self.desc.attr(item) - if self._has_kernel(type): self.desc.infer_var_type(self.block.desc) self.desc.infer_shape(self.block.desc) @@ -3108,10 +3103,6 @@ class Operator: def _has_kernel(self, op_type): return op_type not in self.OP_WITHOUT_KERNEL_SET - def _get_runtime_attrs(self): - """Record all attrs needed by creating op. This api is only for to_prim process.""" - return self.legacy_attrs - def to_string(self, throw_on_error): """ Get debug string. diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index ba5d4af6b4b..84b7d415638 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -206,6 +206,13 @@ def gelu_composite(x, approximate): @REGISTER_COMPOSITE('reduce_mean') def mean_composite(x, axis, keepdim): """define composite rule of op mean""" + is_amp = False + from paddle.fluid.data_feeder import convert_dtype + + if convert_dtype(x.dtype) == "float16": + is_amp = True + x = cast(x, "float32") + axes = axis or list(range(0, len(x.shape))) axes = [axes] if isinstance(axes, int) else axes sum_x = sum(x, axis=axes, keepdim=keepdim) @@ -217,7 +224,10 @@ def mean_composite(x, axis, keepdim): value=value_to_fill, dtype=sum_x.dtype, ) - return divide(sum_x, norm) + res = divide(sum_x, norm) + if is_amp: + res = cast(res, "float16") + return res @REGISTER_COMPOSITE('expand_v2') @@ -424,9 +434,16 @@ def sigmoid_composite(x): define composite rule of op sigmoid res = 1 / (1 + exp(-x)) """ + is_amp = False + from paddle.fluid.data_feeder import convert_dtype + + if convert_dtype(x.dtype) == "float16": + is_amp = True + x = cast(x, "float32") + sum_temp = 1 + exp(-x) res = 1 / sum_temp - return res + return res if not is_amp else cast(res, "float16") @REGISTER_COMPOSITE('silu') @@ -435,9 +452,16 @@ def silu_composite(x): define composite rule of op silu res = x / (1 + exp(-x)) """ + is_amp = False + from paddle.fluid.data_feeder import convert_dtype + + if convert_dtype(x.dtype) == "float16": + is_amp = True + x = cast(x, "float32") + sum_temp = 1 + exp(-x) res = x / sum_temp - return res + return res if not is_amp else cast(res, "float16") @REGISTER_COMPOSITE('meshgrid') @@ -505,9 +529,16 @@ def sqrt_composite(x): define composite rule of op sqrt res = pow(x, 0.5) """ + is_amp = False + from paddle.fluid.data_feeder import convert_dtype + + if convert_dtype(x.dtype) == "float16": + is_amp = True + x = cast(x, "float32") + y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype) res = pow(x, y) - return res + return res if not is_amp else cast(res, "float16") @REGISTER_COMPOSITE('pow') @@ -516,9 +547,18 @@ def pow_composite(x, y): define composite rule of op pow res = x^y """ + is_amp = False + from paddle.fluid.data_feeder import convert_dtype + + if convert_dtype(x.dtype) == "float16": + is_amp = True + x = cast(x, "float32") + if isinstance(y, (int, float)): y = full(x.shape if len(x.shape) == 0 else [1], y, x.dtype) res = pow(x, y) + if is_amp: + res = cast(res, "float16") return res @@ -556,8 +596,15 @@ def unsqueeze_composite(x, axis): def rsqrt_composite(x): """define composite rule of op rsqrt.""" # rsqrt(x) = x^(-0.5) + is_amp = False + from paddle.fluid.data_feeder import convert_dtype + + if convert_dtype(x.dtype) == "float16": + is_amp = True + x = cast(x, "float32") y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype) - return pow(x, y) + res = pow(x, y) + return res if not is_amp else cast(res, "float16") @REGISTER_COMPOSITE('group_norm') -- GitLab