diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..bd8587b906778f3b876f518e65233793f2368b2b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py @@ -0,0 +1,129 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core + +TOLERANCE = { + "float16": {"rtol": 1e-3, "atol": 1e-3}, + "float32": {"rtol": 1e-6, "atol": 1e-6}, + "float64": {"rtol": 1e-15, "atol": 1e-15}, +} + +approximate_conds = [True, False] + + +def apply_to_static(net, use_cinn): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static(net, build_strategy=build_strategy) + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class PrimeNet(paddle.nn.Layer): + def __init__(self, approximate): + super(PrimeNet, self).__init__() + self.fc = paddle.nn.Linear(4, 4) + self.approximate = approximate + + def forward(self, x): + # y = self.fc(x) + out = F.gelu(x, approximate=self.approximate) + return out + + +class TestPrimForwardAndBackward(unittest.TestCase): + """ + Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph + """ + + def setUp(self): + paddle.seed(2022) + self.shapes = [[2, 4], [64, 16, 4]] + self.dtypes = ["float16", "float32"] + + def train(self, use_prim, data): + for approximate in approximate_conds: + return self._train(use_prim, approximate, data) + + def _train(self, use_prim, approximate, data): + paddle.seed(2022) + net = PrimeNet(approximate) + sgd = paddle.optimizer.SGD( + learning_rate=0.1, parameters=net.parameters() + ) + core._set_prim_all_enabled(use_prim) + if use_prim: + net = apply_to_static(net, use_prim) + + res = [] + for _ in range(10): + out = net(data) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_grad() + + res.append(out.numpy()) + self.check_prim(net, use_prim) + + return res + + def check_prim(self, net, use_prim): + if not use_prim: + return + fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] + # Ensure that gelu is splitted into small ops + self.assertTrue('gelu' not in fwd_ops) + + def test_cinn_prim(self): + plat = platform.system() + if plat == "Linux": + for shape in self.shapes: + for dtype in self.dtypes: + if ( + paddle.device.get_device() == "cpu" + and dtype == "float16" + ): + print("need pass this case") + continue + data = generate_data(shape, dtype) + data_t = paddle.to_tensor(data) + data_t.stop_gradient = False + dy_res = self.train(use_prim=False, data=data_t) + cinn_res = self.train(use_prim=True, data=data_t) + for i in range(len(dy_res)): + np.testing.assert_allclose( + cinn_res[i], + dy_res[i], + rtol=TOLERANCE[dtype]['rtol'], + atol=TOLERANCE[dtype]['atol'], + ) + + else: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5c10f803ae34bd542102a3b1c1f1be69940ba0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu.py @@ -0,0 +1,133 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from utils import TOLERANCE + +np.random.seed(2013) + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class Attr: + def __init__(self) -> None: + self.dtype = "float32" + self.shape = None + self.approximate = False + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_shape(self, shape) -> None: + self.shape = shape + return + + def set_approximate(self, approximate) -> None: + self.approximate = approximate + return + + def get_rtol(self, flag): + rtol = TOLERANCE[self.dtype][flag].get("rtol") + return rtol + + def get_atol(self, flag): + atol = TOLERANCE[self.dtype][flag].get("atol") + return atol + + +attrs = Attr() + + +def fn(x): + return F.gelu(x, approximate=attrs.approximate) + + +def expect_forward(inputs): + return fn(inputs) + + +class TestCompositeGelu(unittest.TestCase): + def setUp(self): + self.dtypes = ["float16", "float32", "float64"] + self.shapes = [[16, 16, 64, 64], [2, 3, 4], [2, 3]] + self.approximate = [True, False] + + def cal_composite(self, inputs): + paddle.enable_static() + core._set_prim_forward_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + y = fn(x) + blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that gelu in original block + self.assertTrue('gelu' in fwd_ops) + + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that gelu is splitted into small ops + self.assertTrue('gelu' not in fwd_ops_new) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y]) + paddle.disable_static() + return res + + def compare_forward(self): + np_data = generate_data(attrs.shape, attrs.dtype) + tensor_data = paddle.to_tensor(np_data) + + expect = expect_forward(tensor_data).numpy() + actual = self.cal_composite(np_data)[0] + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("forward"), + atol=attrs.get_atol("forward"), + ) + + def test_forward(self): + for i in self.approximate: + for j in self.dtypes: + for t in self.shapes: + # gelu-kernel on cpu not support float16 + if paddle.device.get_device() == "cpu" and j == "float16": + print("need pass this case") + continue + attrs.set_approximate(i) + attrs.set_dtype(j) + attrs.set_shape(t) + self.compare_forward() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..dda900e2472723e5c52d48f79342734902c2eac5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu_grad.py @@ -0,0 +1,205 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from utils import TOLERANCE + +np.random.seed(2013) + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class Attr: + def __init__(self) -> None: + self.dtype = "float32" + self.shape = None + self.approximate = False + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_shape(self, shape) -> None: + self.shape = shape + return + + def set_approximate(self, approximate) -> None: + self.approximate = approximate + return + + def get_rtol(self, flag): + rtol = TOLERANCE[self.dtype][flag].get("rtol") + return rtol + + def get_atol(self, flag): + atol = TOLERANCE[self.dtype][flag].get("atol") + return atol + + +attrs = Attr() + + +def fn(x): + return F.gelu(x, approximate=attrs.approximate) + + +def expect_grad(inputs): + paddle.disable_static() + inputs.stop_gradient = False + res = fn(inputs) + + gradients = paddle.grad(res, inputs) + return gradients + + +class TestCompositeGelu(unittest.TestCase): + "test composite gelu: prim forward" + + def setUp(self): + self.dtypes = ["float16", "float32", "float64"] + self.shapes = [[16, 16, 64, 64], [2, 3, 4], [2, 3]] + self.approximates = [True, False] + + def cal_composite_grad(self, inputs): + paddle.enable_static() + core._set_prim_forward_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + y = fn(x) + blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that gelu in original block + self.assertTrue('gelu' in fwd_ops) + + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that gelu is splitted into small ops + self.assertTrue('gelu' not in fwd_ops_new) + + z = paddle.static.gradients([y], x) + fwd_ops_grad = [op.type for op in blocks[0].ops] + # Ensure that gelu_grad not in grad block + + self.assertTrue('gelu_grad' not in fwd_ops_grad) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) + paddle.disable_static() + core._set_prim_forward_enabled(False) + return res + + def compare_backward(self): + np_data = generate_data(attrs.shape, attrs.dtype) + tensor_data = paddle.to_tensor(np_data) + + expect = expect_grad(tensor_data)[0].numpy() + actual = self.cal_composite_grad(np_data)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("backward"), + atol=attrs.get_atol("backward"), + ) + + def test_backward(self): + for i in self.approximates: + for j in self.dtypes: + for t in self.shapes: + if paddle.device.get_device() == "cpu" and j == "float16": + print("need pass this case") + continue + attrs.set_approximate(i) + attrs.set_dtype(j) + attrs.set_shape(t) + self.compare_backward() + + +class TestCompositeGeluPrimBackward(unittest.TestCase): + "test composite gelu: prim forward and backward" + + def setUp(self): + self.dtypes = ["float16", "float32", "float64"] + self.shapes = [[16, 16, 64, 64], [2, 3, 4], [2, 3]] + self.approximates = [True, False] + + def cal_composite_grad(self, inputs): + paddle.enable_static() + core._set_prim_all_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + y = fn(x) + blocks = main_program.blocks + paddle.incubate.autograd.to_prim(blocks) + z = paddle.static.gradients([y], x) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) + paddle.disable_static() + core._set_prim_all_enabled(False) + return res + + def compare_backward(self): + np_data = generate_data(attrs.shape, attrs.dtype) + tensor_data = paddle.to_tensor(np_data) + expect = expect_grad(tensor_data)[0].numpy() + actual = self.cal_composite_grad(np_data)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("prim_backward"), + atol=attrs.get_rtol("prim_backward"), + ) + + def test_prim_backward(self): + for i in self.approximates: + for j in self.dtypes: + for t in self.shapes: + if paddle.device.get_device() == "cpu" and j == "float16": + print("need pass this case") + continue + attrs.set_approximate(i) + attrs.set_dtype(j) + attrs.set_shape(t) + self.compare_backward() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py index ed2365adf7c4f48d10afd3fc3e0c8e1db480afca..32358375f05d219f11b86d8fd5a8d02687930bab 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py @@ -14,6 +14,11 @@ # default tolerance TOLERANCE = { + "float16": { + "forward": {"rtol": 1e-3, "atol": 1e-3}, + "backward": {"rtol": 1e-3, "atol": 1e-3}, + "prim_backward": {"rtol": 1e-3, "atol": 1e-3}, + }, "float32": { "forward": {"rtol": 1e-6, "atol": 1e-6}, "backward": {"rtol": 1e-6, "atol": 1e-6}, diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index d7eb54ba571974b7a127a674546bd347264ef2b0..bcf7ab2a4a4cd8b098f9bcd76b6b7c32c5e3c6ca 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -101,3 +101,27 @@ def composite_batchnorm( return run_mean_, None, batch_mean_, batch_var_, run_var_, y else: return run_mean_, batch_mean_, batch_var_, run_var_, y + + +@REGISTER_COMPOSITE('gelu') +def gelu_composite(x, approximate): + """define composite rule of op gelu""" + M_SQRT1_2 = ( + 0.70710678118654752440 # /* 1/sqrt(2) */ copy from gelu-kernel.cc + ) + M_2_SQRTPI = 1.12837916709551257390 # /* 2/sqrt(pi) */ + one = ones(x.shape, x.dtype) + half = full(x.shape, 0.5, x.dtype) + if approximate: + # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3}))) + kAlpha = full(x.shape, M_2_SQRTPI * M_SQRT1_2, x.dtype) + GELU_CONSTANT = full(x.shape, 0.044715, x.dtype) + tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x)) + out = x * half * (one + tanh_out) + return out + + else: + # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + cdf = half * (one + erf(x * full(x.shape, M_SQRT1_2, x.dtype))) + out = x * cdf + return out diff --git a/python/paddle/incubate/autograd/primitives.py b/python/paddle/incubate/autograd/primitives.py index 3eb3ce5d0838965b7d2c9cd88d3e9ea20e0ddad5..35687279affda7037dddbecd336ae2ebd70c5b08 100644 --- a/python/paddle/incubate/autograd/primitives.py +++ b/python/paddle/incubate/autograd/primitives.py @@ -34,6 +34,7 @@ from paddle.tensor import erf # noqa: F401 from paddle.tensor import erfinv # noqa: F401 from paddle.tensor import exp # noqa: F401 from paddle.tensor import expm1 # noqa: F401 +from paddle.tensor import full # noqa: F401 from paddle.tensor import lgamma # noqa: F401 from paddle.tensor import log # noqa: F401 from paddle.tensor import log1p # noqa: F401 @@ -113,6 +114,7 @@ others = [ 'assign', 'fill_constant', 'reshape', + 'full', ] __all__ = []