From 2f4763eec87b5eb30753857164092f09d1464c52 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Tue, 21 Feb 2023 13:04:31 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90prim=E3=80=91Layer=20norm=20(#50422)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix composite mean op map * fix composite check output * init layer_norm * init layer_norm * map output from composite rule to origin op * add dropout op map * add input map check * polish log * modify rules * success test_forward * modify test without cinn * modify cinn test * modify cinn test * except fp64 * except fp64 * delete flatten * delete unused change * review * pass cpu test * code style * delete flatten fp16 error * modify flatten test --------- Co-authored-by: cyber-pioneer --- .../test_cinn_prim_layer_norm.py | 183 ++++++ .../test_flatten_contigous_range_op_mlu.py | 19 - .../test_flatten_contiguous_range_op_npu.py | 18 - .../test_composite_layer_norm.py | 206 ++++++ .../test_composite_layer_norm_grad.py | 613 ++++++++++++++++++ .../unittests/prim/composite_ops/utils.py | 5 + .../test_flatten_contiguous_range_op.py | 19 - .../test_flatten_contiguous_range_op_xpu.py | 19 - .../incubate/autograd/composite_rules.py | 29 + python/paddle/incubate/autograd/utils.py | 2 + python/paddle/tensor/manipulation.py | 13 +- 11 files changed, 1049 insertions(+), 77 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py create mode 100644 python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py create mode 100644 python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py new file mode 100644 index 00000000000..6460515c0a8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import platform +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core + +TOLERANCE = { + "float16": {"rtol": 1e-2, "atol": 1e-2}, + "float32": {"rtol": 1e-5, "atol": 1e-5}, + "float64": {"rtol": 1e-13, "atol": 1e-13}, +} + + +def generate_data(dtype="float32"): + np_data1 = np.random.random([2, 64]).astype(dtype) + np_data2 = np.random.random([64]).astype(dtype) + np_data3 = np.random.random([64]).astype(dtype) + return np_data1, np_data2, np_data3 + + +def apply_to_static(net, use_cinn): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static(net, build_strategy=build_strategy) + + +class PrimeNet(paddle.nn.Layer): + def __init__(self): + super(PrimeNet, self).__init__() + self.fc = paddle.nn.Linear(64, 64) + + def forward(self, x, w, b): + n_shape = x.shape[1:] + out = F.layer_norm(x, n_shape, w, b) + return out[0] + + +class TestPrimForward(unittest.TestCase): + """ + This case only tests prim_forward + to_static + cinn. Thus we need to + set this flag as False to avoid prim_backward. + core.set_prim_backward(False) + """ + + def setUp(self): + self.x = None + self.w = None + self.b = None + self.dtypes = ["float16", "float32"] + + def train(self, use_prim): + net = PrimeNet() + sgd = paddle.optimizer.SGD( + learning_rate=0.1, parameters=net.parameters() + ) + core._set_prim_forward_enabled(use_prim) + core._add_skip_comp_ops("sqrt") + # TODO(Ruting) delete this after modify sqrt + if use_prim: + net = apply_to_static(net, use_prim) + + out = net(self.x, self.w, self.b) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_grad() + + self.check_prim(net, use_prim) + + return out.numpy() + + def check_prim(self, net, use_prim): + if not use_prim: + return + fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] + # Ensure that layer_norm is splitted into small ops + self.assertTrue('layer_norm' not in fwd_ops) + + def test_cinn_prim_forward(self): + + for dtype in self.dtypes: + if paddle.device.get_device() == "cpu": + print("need pass this case") + continue + x_n, w_n, b_n = generate_data(dtype) + self.x = paddle.to_tensor(x_n) + self.w = paddle.to_tensor(w_n) + self.b = paddle.to_tensor(b_n) + self.x.stop_gradient = False + dy_res = self.train(use_prim=False) + cinn_res = self.train(use_prim=True) + + np.testing.assert_allclose( + cinn_res, + dy_res, + rtol=TOLERANCE[dtype]['rtol'], + atol=TOLERANCE[dtype]['atol'], + ) + + +class TestPrimForwardAndBackward(unittest.TestCase): + """ + Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph + """ + + def setUp(self): + self.x = None + self.w = None + self.b = None + self.dtypes = ["float16", "float32"] + + def train(self, use_prim): + net = PrimeNet() + sgd = paddle.optimizer.SGD( + learning_rate=0.1, parameters=net.parameters() + ) + core._set_prim_all_enabled(use_prim) + core._add_skip_comp_ops("sqrt") + # TODO(Ruting) delete this after modify sqrt + if use_prim: + net = apply_to_static(net, use_prim) + + out = net(self.x, self.w, self.b) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_grad() + + self.check_prim(net, use_prim) + + return out.numpy() + + def check_prim(self, net, use_prim): + if not use_prim: + return + fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] + # Ensure that layer_norm is splitted into small ops + self.assertTrue('layer_norm' not in fwd_ops) + + def test_cinn_prim(self): + plat = 
platform.system() + if plat == "Linux": + for dtype in self.dtypes: + if paddle.device.get_device() == "cpu": + print("need pass this case") + continue + x_n, w_n, b_n = generate_data(dtype) + self.x = paddle.to_tensor(x_n) + self.w = paddle.to_tensor(w_n) + self.b = paddle.to_tensor(b_n) + self.x.stop_gradient = False + dy_res = self.train(use_prim=False) + cinn_res = self.train(use_prim=True) + + np.testing.assert_allclose( + cinn_res, + dy_res, + rtol=TOLERANCE[dtype]['rtol'], + atol=TOLERANCE[dtype]['atol'], + ) + else: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_flatten_contigous_range_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_flatten_contigous_range_op_mlu.py index 942d5f9c897..8567c2d1979 100755 --- a/python/paddle/fluid/tests/unittests/mlu/test_flatten_contigous_range_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_flatten_contigous_range_op_mlu.py @@ -273,25 +273,6 @@ class TestFlatten2OpError(unittest.TestCase): self.assertRaises(ValueError, test_ValueError3) - def test_type(): - # dtype must be float32, float64, int8, int32, int64, uint8. - x2 = ( - np.arange( - image_shape[0] - * image_shape[1] - * image_shape[2] - * image_shape[3] - ).reshape(image_shape) - / 100.0 - ) - x2 = x2.astype('float16') - x2_var = paddle.fluid.data( - name='x2', shape=[3, 2, 4, 5], dtype='float16' - ) - paddle.flatten(x2_var) - - self.assertRaises(TypeError, test_type) - def test_InputError(): out = paddle.flatten(x) diff --git a/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py index 1ee72cbac8c..83a664f619c 100755 --- a/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py @@ -272,24 +272,6 @@ class TestFlatten2OpError(unittest.TestCase): self.assertRaises(ValueError, test_ValueError3) - def test_type(): - # dtype must be float32, float64, int8, int32, int64, uint8. - x2 = ( - np.arange( - image_shape[0] - * image_shape[1] - * image_shape[2] - * image_shape[3] - ).reshape(image_shape) - / 100.0 - ) - x2 = x2.astype('float16') - x2_var = paddle.fluid.data( - name='x2', shape=[3, 2, 4, 5], dtype='float16' - ) - paddle.flatten(x2_var) - - self.assertRaises(TypeError, test_type) def test_InputError(): out = paddle.flatten(x) diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py new file mode 100644 index 00000000000..d34003c5ae9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py @@ -0,0 +1,206 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from utils import SUB_TOLERANCE + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core + + +def generate_data(shape1, shape2, shape3, dtype="float32"): + np.random.seed(200) + np_data1 = np.random.random(shape1).astype(dtype) + np_data2 = np.random.random(shape2).astype(dtype) + np_data3 = np.random.random(shape3).astype(dtype) + return np_data1, np_data2, np_data3 + + +class Attr: + def __init__(self) -> None: + self.dtype = None + self.n_shape = None + self.shape1 = None + self.shape2 = None + self.shape3 = None + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_shape(self, n_shape, shape1, shape2, shape3) -> None: + self.n_shape = n_shape + self.shape1 = shape1 + self.shape2 = shape2 + self.shape3 = shape3 + return + + def get_rtol(self, flag): + rtol = SUB_TOLERANCE[self.dtype][flag].get("rtol") + return rtol + + def get_atol(self, flag): + atol = SUB_TOLERANCE[self.dtype][flag].get("atol") + return atol + + +attrs = Attr() + + +def fn(x, norm_shape, w, b): + return F.layer_norm(x, norm_shape, w, b) + + +def expect_forward(x, norm_shape, w, b): + return fn(x, norm_shape, w, b) + + +class TestCompositelayer_norm(unittest.TestCase): + def setUp(self): + self.dtypes = ["float16", "float32", "float64"] + self.n_shape = [[4], [64, 128], [64]] + self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]] + self.shape2s = [[4], [64 * 128], [64]] + self.shape3s = [[4], [64 * 128], [64]] + + def cal_composite(self, inputs, norm_shape, weight, bias): + paddle.enable_static() + core._set_prim_forward_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + w = paddle.static.data( + 'w', shape=weight.shape, dtype=str(weight.dtype) + ) + b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype)) + y = fn(x, norm_shape, w, b) + + blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that layer_norm in original block + self.assertTrue('layer_norm' in fwd_ops) + + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that layer_norm is splitted into small ops + self.assertTrue('layer_norm' not in fwd_ops_new) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={ + 'x': inputs, + 'w': weight, + 'b': bias, + }, + fetch_list=[y], + ) + paddle.disable_static() + core._set_prim_forward_enabled(False) + return res + + def cal2_composite(self, inputs, norm_shape, weight, bias): + paddle.enable_static() + core._set_prim_forward_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + + y = fn(x, norm_shape, weight, bias) + + blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that layer_norm in original block + self.assertTrue('layer_norm' in fwd_ops) + + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that layer_norm is splitted into small ops + self.assertTrue('layer_norm' not in fwd_ops_new) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={ + 'x': 
inputs, + }, + fetch_list=[y], + ) + paddle.disable_static() + core._set_prim_forward_enabled(False) + return res + + def compare_forward(self): + x, w, b = generate_data( + attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype + ) + n_shape = attrs.n_shape + x_p = paddle.to_tensor(x) + w_p = paddle.to_tensor(w) + b_p = paddle.to_tensor(b) + + expect = expect_forward(x_p, n_shape, w_p, b_p).numpy() + actual = self.cal_composite(x, n_shape, w, b)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("forward"), + atol=attrs.get_atol("forward"), + ) + + expect_2 = expect_forward(x_p, n_shape, None, None).numpy() + actual_2 = self.cal2_composite(x, n_shape, None, None)[0] + assert expect_2.dtype == actual_2.dtype + np.testing.assert_allclose( + expect_2, + actual_2, + rtol=attrs.get_rtol("forward"), + atol=attrs.get_atol("forward"), + ) + + def test_forward(self): + for j in self.dtypes: + if paddle.device.get_device() == "cpu" and j == "float16": + print("need pass this case") + continue + for t in range(0, len(self.shape1s)): + attrs.set_dtype(j) + attrs.set_shape( + self.n_shape[t], + self.shape1s[t], + self.shape2s[t], + self.shape3s[t], + ) + self.compare_forward() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py new file mode 100644 index 00000000000..a4551732033 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py @@ -0,0 +1,613 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from functools import reduce +from operator import mul + +import numpy as np +from utils import SUB_TOLERANCE + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core + +TOLERANCE_NUMPY = { + "float32": {"rtol": 2e-5, "atol": 2e-5}, + "float64": {"rtol": 1e-11, "atol": 1e-11}, +} + + +def generate_data(shape1, shape2, shape3, dtype="float32"): + np.random.seed(200) + np_data1 = np.random.random(shape1).astype(dtype) + np_data2 = np.random.random(shape2).astype(dtype) + np_data3 = np.random.random(shape3).astype(dtype) + return np_data1, np_data2, np_data3 + + +def _reference_layer_norm_naive( + x, scale, beta, epsilon=1e-5, begin_norm_axis=1 +): + x_shape = x.shape + N = reduce(mul, x_shape[0:begin_norm_axis], 1) + D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1) + x.shape = [N, D] + + mean = np.mean(x, axis=1) + difference = x - mean.reshape([N, 1]) + var_tmp1 = np.power(difference, 2.0) + variance = np.mean(var_tmp1, axis=1) + var = variance + epsilon + # var = np.var(x, axis=1) + epsilon + output = np.divide( + (x - mean.reshape([N, 1])), (np.sqrt(var)).reshape([N, 1]) + ) + if scale is not None: + output = scale.reshape([1, D]) * output + if beta is not None: + output = output + beta.reshape([1, D]) + + x.shape, output.shape = x_shape, x_shape + return output, mean, var + + +def _reference_layer_norm_grad( + x, grad_y, scale, bias, mean, var, begin_norm_axis=1 +): + x_shape = x.shape + N = reduce(mul, x_shape[0:begin_norm_axis], 1) + D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1) + + if scale is not None: + scale_shape = scale.shape + scale.shape = [1, D] + x.shape, grad_y.shape = [N, D], [N, D] + var.shape, mean.shape = [N, 1], [N, 1] + + # d_bias + if bias is not None: + d_bias = np.sum(grad_y, axis=0).reshape([1, D]) + else: + d_bias = None + # d_scale + if scale is not None: + d_scale = np.sum( + ((x - mean) * np.sqrt(1 / var)) * grad_y, axis=0 + ).reshape([1, D]) + else: + d_scale = None + # dx + if scale is not None: + dx_end = scale * np.sqrt(1.0 / var) * grad_y + d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * scale, axis=1).reshape( + [N, 1] + ) # the second part equals to zero. + d_mean = 1.0 / D * d_mean_0 + d_std = np.sum( + -(1.0 / var) * (x - mean) * grad_y * scale, axis=1 + ).reshape([N, 1]) * ( + 1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean) + ) + else: + dx_end = 1.0 * np.sqrt(1.0 / var) * grad_y + d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * 1.0, axis=1).reshape( + [N, 1] + ) # the second part equals to zero. 
+ d_mean = 1.0 / D * d_mean_0 + d_std = np.sum( + -(1.0 / var) * (x - mean) * grad_y * 1.0, axis=1 + ).reshape([N, 1]) * ( + 1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean) + ) + + grad_x = dx_end + d_mean + d_std + + grad_x.shape, x.shape, grad_y.shape = x_shape, x_shape, x_shape + var.shape, mean.shape = [N], [N] + + if scale is not None: + scale.shape = scale_shape + + return grad_x, d_scale, d_bias + + +class Attr: + def __init__(self) -> None: + self.dtype = None + self.n_shape = None + self.shape1 = None + self.shape2 = None + self.shape3 = None + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_shape(self, n_shape, shape1, shape2, shape3) -> None: + self.n_shape = n_shape + self.shape1 = shape1 + self.shape2 = shape2 + self.shape3 = shape3 + return + + def get_rtol(self, flag): + rtol = SUB_TOLERANCE[self.dtype][flag].get("rtol") + return rtol + + def get_atol(self, flag): + atol = SUB_TOLERANCE[self.dtype][flag].get("atol") + return atol + + +attrs = Attr() + + +def fn(x, norm_shape, w, b): + return F.layer_norm(x, norm_shape, w, b) + + +def expect_backward(x, norm_shape, w, b): + paddle.disable_static() + x.stop_gradient = False + res = fn(x, norm_shape, w, b) + gradients = paddle.grad(res, x) + return gradients + + +class TestCompositelayer_norm(unittest.TestCase): + def setUp(self): + self.dtypes = ["float16", "float32"] + self.n_shape = [[4], [64, 128], [64]] + self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]] + self.shape2s = [[4], [64 * 128], [64]] + self.shape3s = [[4], [64 * 128], [64]] + + def cal_composite_backward(self, inputs, norm_shape, weight, bias): + paddle.enable_static() + core._set_prim_forward_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + w = paddle.static.data( + 'w', shape=weight.shape, dtype=str(weight.dtype) + ) + b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype)) + y = fn(x, norm_shape, w, b) + + blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that layer_norm in original block + self.assertTrue('layer_norm' in fwd_ops) + + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that layer_norm is splitted into small ops + self.assertTrue('layer_norm' not in fwd_ops_new) + + z = paddle.static.gradients([y], x) + fwd_ops_grad = [op.type for op in blocks[0].ops] + # Ensure that layer_norm_grad not in grad block + + self.assertTrue('layer_norm_grad' not in fwd_ops_grad) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={ + 'x': inputs, + 'w': weight, + 'b': bias, + }, + fetch_list=[z], + ) + paddle.disable_static() + core._set_prim_forward_enabled(False) + return res + + def cal2_composite_backward(self, inputs, norm_shape, weight, bias): + paddle.enable_static() + core._set_prim_forward_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + + x.stop_gradient = False + y = fn(x, norm_shape, weight, bias) + + blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that layer_norm in original block + 
self.assertTrue('layer_norm' in fwd_ops) + + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that layer_norm is splitted into small ops + self.assertTrue('layer_norm' not in fwd_ops_new) + + z = paddle.static.gradients([y], x) + fwd_ops_grad = [op.type for op in blocks[0].ops] + # Ensure that layer_norm_grad not in grad block + + self.assertTrue('layer_norm_grad' not in fwd_ops_grad) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={ + 'x': inputs, + }, + fetch_list=[z], + ) + paddle.disable_static() + core._set_prim_forward_enabled(False) + return res + + def compare_backward(self): + x, w, b = generate_data( + attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype + ) + n_shape = attrs.n_shape + x_p = paddle.to_tensor(x) + w_p = paddle.to_tensor(w) + b_p = paddle.to_tensor(b) + + expect = expect_backward(x_p, n_shape, w_p, b_p)[0].numpy() + actual = self.cal_composite_backward(x, n_shape, w, b)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("backward"), + atol=attrs.get_atol("backward"), + ) + + expect_2 = expect_backward(x_p, n_shape, None, None)[0].numpy() + actual_2 = self.cal2_composite_backward(x, n_shape, None, None)[0] + assert expect_2.dtype == actual_2.dtype + np.testing.assert_allclose( + expect_2, + actual_2, + rtol=attrs.get_rtol("backward"), + atol=attrs.get_atol("backward"), + ) + + def test_backward(self): + for j in self.dtypes: + if paddle.device.get_device() == "cpu": + print("need pass this case") + continue + for t in range(0, len(self.shape1s)): + attrs.set_dtype(j) + attrs.set_shape( + self.n_shape[t], + self.shape1s[t], + self.shape2s[t], + self.shape3s[t], + ) + self.compare_backward() + + +class TestCompositelayer_normPrimBackward(unittest.TestCase): + def setUp(self): + core._set_prim_backward_enabled(True) + self.dtypes = ["float16", "float32"] + self.n_shape = [[4], [64, 128], [64]] + self.shape1s = [[3, 4], [64, 64, 128], [128, 64, 64]] + self.shape2s = [[4], [64 * 128], [64]] + self.shape3s = [[4], [64 * 128], [64]] + + def cal_composite_backward(self, inputs, norm_shape, weight, bias): + paddle.enable_static() + core._set_prim_all_enabled(True) + core._add_skip_comp_ops("sqrt") + # TODO(Ruting) delete this after modify sqrt + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + w = paddle.static.data( + 'w', shape=weight.shape, dtype=str(weight.dtype) + ) + b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype)) + y = fn(x, norm_shape, w, b) + + blocks = main_program.blocks + paddle.incubate.autograd.to_prim(blocks) + z = paddle.static.gradients([y], x) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={ + 'x': inputs, + 'w': weight, + 'b': bias, + }, + fetch_list=[z], + ) + paddle.disable_static() + core._set_prim_all_enabled(False) + return res + + def cal2_composite_backward(self, inputs, norm_shape, weight, bias): + paddle.enable_static() + core._set_prim_all_enabled(True) + core._add_skip_comp_ops("sqrt") + # TODO(Ruting) delete this after modify sqrt + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = 
paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + y = fn(x, norm_shape, weight, bias) + + blocks = main_program.blocks + paddle.incubate.autograd.to_prim(blocks) + z = paddle.static.gradients([y], x) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={ + 'x': inputs, + }, + fetch_list=[z], + ) + paddle.disable_static() + core._set_prim_all_enabled(False) + return res + + def compare_backward(self): + x, w, b = generate_data( + attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype + ) + n_shape = attrs.n_shape + x_p = paddle.to_tensor(x) + w_p = paddle.to_tensor(w) + b_p = paddle.to_tensor(b) + + expect = expect_backward(x_p, n_shape, w_p, b_p)[0].numpy() + actual = self.cal_composite_backward(x, n_shape, w, b)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("prim_backward"), + atol=attrs.get_atol("prim_backward"), + ) + + expect_2 = expect_backward(x_p, n_shape, None, None)[0].numpy() + actual_2 = self.cal2_composite_backward(x, n_shape, None, None)[0] + assert expect_2.dtype == actual_2.dtype + np.testing.assert_allclose( + expect_2, + actual_2, + rtol=attrs.get_rtol("prim_backward"), + atol=attrs.get_atol("prim_backward"), + ) + + def test_prim_backward(self): + for j in self.dtypes: + if paddle.device.get_device() == "cpu": + print("need pass this case") + continue + for t in range(0, len(self.shape1s)): + attrs.set_dtype(j) + attrs.set_shape( + self.n_shape[t], + self.shape1s[t], + self.shape2s[t], + self.shape3s[t], + ) + self.compare_backward() + + +class TestCompositeNumpylayer_norm(unittest.TestCase): + def setUp(self): + self.dtypes = ["float32", "float64"] + self.n_shape = [ + [4], + [64, 128], + ] + self.shape1s = [ + [3, 4], + [64, 64, 128], + ] + self.shape2s = [ + [4], + [64 * 128], + ] + self.shape3s = [ + [4], + [64 * 128], + ] + + def cal_composite_backward(self, inputs, norm_shape, weight, bias, y_grad): + paddle.enable_static() + core._set_prim_forward_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + w = paddle.static.data( + 'w', shape=weight.shape, dtype=str(weight.dtype) + ) + b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype)) + y = fn(x, norm_shape, w, b) + y_g = paddle.static.data( + 'y_g', shape=y_grad.shape, dtype=str(y_grad.dtype) + ) + blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that layer_norm in original block + self.assertTrue('layer_norm' in fwd_ops) + + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that layer_norm is splitted into small ops + self.assertTrue('layer_norm' not in fwd_ops_new) + + z = paddle.static.gradients([y], x, y_g) + fwd_ops_grad = [op.type for op in blocks[0].ops] + # Ensure that layer_norm_grad not in grad block + + self.assertTrue('layer_norm_grad' not in fwd_ops_grad) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={ + 'x': inputs, + 'w': weight, + 'b': bias, + 'y_g': y_grad, + }, + fetch_list=[y, z[0]], + ) + paddle.disable_static() + core._set_prim_forward_enabled(False) + return res[0], res[1] + + def cal_composite_backward_prim( + self, inputs, norm_shape, weight, bias, 
y_grad + ): + paddle.enable_static() + core._set_prim_all_enabled(True) + core._add_skip_comp_ops("sqrt") + # TODO(Ruting) delete this after modify sqrt + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + w = paddle.static.data( + 'w', shape=weight.shape, dtype=str(weight.dtype) + ) + b = paddle.static.data('b', shape=bias.shape, dtype=str(bias.dtype)) + y = fn(x, norm_shape, w, b) + y_g = paddle.static.data( + 'y_g', shape=y_grad.shape, dtype=str(y_grad.dtype) + ) + + blocks = main_program.blocks + paddle.incubate.autograd.to_prim(blocks) + z = paddle.static.gradients([y], x) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={'x': inputs, 'w': weight, 'b': bias, 'y_g': y_grad}, + fetch_list=[y, z[0]], + ) + paddle.disable_static() + core._set_prim_all_enabled(False) + return res[0], res[1] + + def compare_backward(self): + x, w, b = generate_data( + attrs.shape1, attrs.shape2, attrs.shape3, attrs.dtype + ) + y_grad = np.ones_like(x) + n_shape = attrs.n_shape + + composite1, composite2 = self.cal_composite_backward( + x, n_shape, w, b, y_grad + ) + composite_p1, composite_p2 = self.cal_composite_backward_prim( + x, n_shape, w, b, y_grad + ) + + numpy1, mean, variance = _reference_layer_norm_naive( + x, + w, + b, + ) + numpy2, _, _ = _reference_layer_norm_grad( + x, + y_grad, + w, + b, + mean, + variance, + ) + + # forward_prim + np.testing.assert_allclose( + composite1, + numpy1, + rtol=TOLERANCE_NUMPY[attrs.dtype]['rtol'], + atol=TOLERANCE_NUMPY[attrs.dtype]['atol'], + ) + # forward_prim + backward + np.testing.assert_allclose( + composite2, + numpy2, + rtol=TOLERANCE_NUMPY[attrs.dtype]['rtol'], + atol=TOLERANCE_NUMPY[attrs.dtype]['atol'], + ) + # forward_prim + backward_prim + np.testing.assert_allclose( + composite_p2, + numpy2, + rtol=TOLERANCE_NUMPY[attrs.dtype]['rtol'], + atol=TOLERANCE_NUMPY[attrs.dtype]['atol'], + ) + + def test_backward(self): + for j in self.dtypes: + for t in range(0, len(self.shape1s)): + attrs.set_dtype(j) + attrs.set_shape( + self.n_shape[t], + self.shape1s[t], + self.shape2s[t], + self.shape3s[t], + ) + self.compare_backward() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py index 32358375f05..f00b460b9d0 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/utils.py @@ -33,6 +33,11 @@ TOLERANCE = { # this tolerance is for big composite ops like batch_norm. 
SUB_TOLERANCE = { + "float16": { + "forward": {"rtol": 1e-2, "atol": 1e-2}, + "backward": {"rtol": 1e-2, "atol": 1e-2}, + "prim_backward": {"rtol": 1e-2, "atol": 1e-2}, + }, + "float32": { "forward": {"rtol": 1e-5, "atol": 1e-5}, "backward": {"rtol": 1e-5, "atol": 1e-5}, diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index df36af0f516..40d4407d81b 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -206,25 +206,6 @@ class TestFlatten2OpError(unittest.TestCase): self.assertRaises(ValueError, test_ValueError5) - def test_type(): - # dtype must be float32, float64, int8, int32, int64, uint8. - x2 = ( - np.arange( - image_shape[0] - * image_shape[1] - * image_shape[2] - * image_shape[3] - ).reshape(image_shape) - / 100.0 - ) - x2 = x2.astype('float16') - x2_var = paddle.fluid.data( - name='x2', shape=[3, 2, 4, 5], dtype='float16' - ) - paddle.flatten(x2_var) - - self.assertRaises(TypeError, test_type) - def test_InputError(): out = paddle.flatten(x) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_flatten_contiguous_range_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_flatten_contiguous_range_op_xpu.py index 7a177651f1e..af6f2095fc9 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_flatten_contiguous_range_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_flatten_contiguous_range_op_xpu.py @@ -264,25 +264,6 @@ class TestFlatten2OpError(unittest.TestCase): self.assertRaises(ValueError, test_ValueError3) - def test_type(): - # dtype must be float32, float64, int8, int32, int64 - x2 = ( - np.arange( - image_shape[0] - * image_shape[1] - * image_shape[2] - * image_shape[3] - ).reshape(image_shape) - / 100.0 - ) - x2 = x2.astype('float16') - x2_var = paddle.fluid.data( - name='x2', shape=[3, 2, 4, 5], dtype='float16' - ) - paddle.flatten(x2_var) - - self.assertRaises(TypeError, test_type) - def test_InputError(): out = paddle.flatten(x) diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 70bb8f8b804..215e5b1ee2e 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -110,6 +110,35 @@ def composite_batchnorm( return y, run_mean_, run_var_, batch_mean_, batch_var_, reserve_space +@REGISTER_COMPOSITE('layer_norm') +def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis): + """ + define composite rule of op layer_norm + out = (x - mean(x)) / sqrt(var + epsilon) + var = mean((x - mean(x))^2) + """ + + axis = tuple(range(begin_norm_axis, len(x.shape))) + mean_ = mean(x, axis=axis, keepdim=True) + difference = x - mean_ + var_tmp1 = difference * difference + variance = mean(var_tmp1, axis=axis, keepdim=True) + var_tmp3 = variance + epsilon + sqrt_var = sqrt(var_tmp3) + out = difference / sqrt_var + + if scale is not None: + scale = reshape(scale, x.shape[begin_norm_axis:]) + out = out * scale + if bias is not None: + bias = reshape(bias, x.shape[begin_norm_axis:]) + out = out + bias + + mean_ = reshape(mean_, [-1]) + variance = reshape(variance, [-1]) + return out, mean_, variance + + @REGISTER_COMPOSITE('gelu') +def gelu_composite(x, approximate): + """define composite rule of op gelu""" diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index 
b4a78dec869..ff367a8ec9b 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -184,7 +184,9 @@ def _get_args_values(op, phi_name): and arg_name in op_content["attrs"].keys() ): arg_name = op_content["attrs"][arg_name] + # Note: in some cases attrs may be optional and are thus assigned None; such cases must be recorded. + if arg_name not in op.attr_names: attrs.append(None) else: diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index efdd6b1a221..35293a30a0e 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1582,7 +1582,16 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None): check_variable_and_dtype( x, 'x', - ['float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8'], + [ + 'float16', + 'float32', + 'float64', + 'int8', + 'int16', + 'int32', + 'int64', + 'uint8', + ], 'flatten', ) helper = LayerHelper('flatten', **locals()) @@ -3285,7 +3294,7 @@ def broadcast_to(x, shape, name=None): check_variable_and_dtype( x, 'x', - ['bool', 'float32', 'float64', 'int32', 'int64'], + ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'broadcast_to', ) check_type(shape, 'shape', (list, tuple, Variable), 'broadcast_to') -- GitLab
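
The core of this patch is the 'layer_norm' composite rule (layernorm_composite in composite_rules.py), which rewrites layer_norm into mean, subtract, multiply, sqrt, and divide primitives so the prim/CINN pipeline can consume it. A minimal standalone sketch of that decomposition, checked in eager mode against paddle.nn.functional.layer_norm, is shown below; it is not part of the patch, the helper name is hypothetical, and the shapes, dtype, and epsilon are illustrative assumptions.

# Illustrative sketch (not part of the patch): mirror the layer_norm
# composite decomposition with primitive tensor ops and compare it with
# paddle.nn.functional.layer_norm. Shapes, dtype and epsilon are assumptions.
import numpy as np
import paddle
import paddle.nn.functional as F


def layer_norm_by_primitives(x, scale, bias, epsilon=1e-5, begin_norm_axis=1):
    # Normalize over the trailing axes, as the composite rule does.
    axis = tuple(range(begin_norm_axis, len(x.shape)))
    mean_ = paddle.mean(x, axis=axis, keepdim=True)
    difference = x - mean_
    variance = paddle.mean(difference * difference, axis=axis, keepdim=True)
    out = difference / paddle.sqrt(variance + epsilon)
    if scale is not None:
        out = out * paddle.reshape(scale, x.shape[begin_norm_axis:])
    if bias is not None:
        out = out + paddle.reshape(bias, x.shape[begin_norm_axis:])
    return out


x = paddle.to_tensor(np.random.random([3, 4]).astype("float32"))
w = paddle.to_tensor(np.random.random([4]).astype("float32"))
b = paddle.to_tensor(np.random.random([4]).astype("float32"))
np.testing.assert_allclose(
    layer_norm_by_primitives(x, w, b).numpy(),
    F.layer_norm(x, x.shape[1:], w, b).numpy(),
    rtol=1e-5,
    atol=1e-5,
)

The sketch uses epsilon=1e-5 because that is the default of paddle.nn.functional.layer_norm; with a different epsilon the two results would diverge.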
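The new test_composite_layer_norm_grad.py validates the composite backward against the hand-derived NumPy formulas in _reference_layer_norm_grad. A compact way to sanity-check those formulas is a central finite-difference comparison; the sketch below is illustrative only (tiny float64 inputs, the same 1e-5 epsilon, hypothetical helper names) and uses the same dx decomposition dx = dx_end + d_mean + d_std as the reference in the test.

# Illustrative NumPy sketch (not part of the patch): verify the layer_norm
# input-gradient formula used by _reference_layer_norm_grad with central
# finite differences on a small float64 example.
import numpy as np


def layer_norm_np(x, scale, bias, eps=1e-5):
    mu = x.mean(axis=1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=1, keepdims=True) + eps
    return (x - mu) / np.sqrt(var) * scale + bias


def layer_norm_grad_x(x, g, scale, eps=1e-5):
    # Analytic dL/dx for L = sum(layer_norm(x) * g), split into the same
    # three terms (dx_end, d_mean, d_std) as the reference in the test.
    d = x.shape[1]
    mu = x.mean(axis=1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=1, keepdims=True) + eps
    inv_std = 1.0 / np.sqrt(var)
    dx_end = scale * inv_std * g
    d_mean = -(scale * inv_std * g).sum(axis=1, keepdims=True) / d
    d_std = (-(x - mu) * g * scale / var).sum(axis=1, keepdims=True) * (
        inv_std * (x - mu) / d
    )
    return dx_end + d_mean + d_std


rng = np.random.default_rng(2023)
x = rng.random((3, 4))
scale, bias, g = rng.random(4), rng.random(4), rng.random((3, 4))
analytic = layer_norm_grad_x(x, g, scale)
numeric = np.zeros_like(x)
h = 1e-6
for idx in np.ndindex(*x.shape):
    x_plus, x_minus = x.copy(), x.copy()
    x_plus[idx] += h
    x_minus[idx] -= h
    # Central difference of L = sum(layer_norm(x) * g) w.r.t. x[idx].
    numeric[idx] = (
        (layer_norm_np(x_plus, scale, bias) - layer_norm_np(x_minus, scale, bias)) * g
    ).sum() / (2 * h)
np.testing.assert_allclose(analytic, numeric, rtol=1e-5, atol=1e-7)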