From 8d512b8feb7e026ce903b8125f8e836b7a35189b Mon Sep 17 00:00:00 2001 From: wangshengxiang <121413869+shengxiangwang@users.noreply.github.com> Date: Fri, 13 Jan 2023 11:47:47 +0800 Subject: [PATCH] add prelu & prelu_grad op for xpu (#49672) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 3 + paddle/phi/kernels/xpu/prelu_grad_kernel.cc | 97 ++++++++ paddle/phi/kernels/xpu/prelu_kernel.cc | 64 ++++++ .../tests/unittests/xpu/test_prelu_op_xpu.py | 217 ++++++++++++++++++ 4 files changed, 381 insertions(+) create mode 100644 paddle/phi/kernels/xpu/prelu_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/prelu_kernel.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_prelu_op_xpu.py diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index fe1b5989c01..77b113c61b0 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -418,6 +418,9 @@ XPUOpMap& get_kl2_ops() { {"pow_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"pow2_decay_with_linear_warmup", XPUKernelSet({phi::DataType::FLOAT32})}, {"prior_box", XPUKernelSet({phi::DataType::FLOAT32})}, + {"prelu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"prelu_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"range", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64})}, {"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})}, {"reciprocal_grad", diff --git a/paddle/phi/kernels/xpu/prelu_grad_kernel.cc b/paddle/phi/kernels/xpu/prelu_grad_kernel.cc new file mode 100644 index 00000000000..b786b13b3db --- /dev/null +++ b/paddle/phi/kernels/xpu/prelu_grad_kernel.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/prelu_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void PReluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& alpha, + const DenseTensor& out_grad, + const std::string& data_format, + const std::string& mode, + DenseTensor* x_grad, + DenseTensor* alpha_grad) { + using XPUType = typename XPUTypeTrait::Type; + + const T* x_ptr = x.data(); + const T* alpha_ptr = alpha.data(); + const T* out_grad_ptr = out_grad.data(); + + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + T* alpha_grad_ptr = dev_ctx.template Alloc(alpha_grad); + + auto x_dim = x.dims(); + auto x_rank = x_dim.size(); + + std::vector x_shape(x_rank); + for (int i = 0; i < x_rank; i++) { + x_shape[i] = x_dim[i]; + } + + auto alpha_dim = alpha.dims(); + auto alpha_rank = alpha_dim.size(); + + std::vector alpha_shape(alpha_rank); + for (int i = 0; i < x_rank; i++) { + alpha_shape[i] = alpha_dim[i]; + } + + // mode = 0: channel_nchw, slope_shape = {c}, default. meanwhile, xhsape = {n, + // c, h, w} + // mode = 1, channel_nhwc, slope_shape = {c}, meanwhile, xhsape = {n, h, w, c} + // mode = 2, elementwise, slope_shape = {c*h*w} + // mode = 3, single slope, slope_shape = {1} + + int xpu_mode = 0; + + if (mode == "channel") { + if (data_format == "NCHW") { + xpu_mode = 0; + } else { + // NHWC + xpu_mode = 1; + } + } else if (mode == "element") { + xpu_mode = 2; + } else { + xpu_mode = 3; + } + + int r = xpu::prelu_grad( + dev_ctx.x_context(), + reinterpret_cast(x_ptr), + reinterpret_cast( + out_grad_ptr), /* const T* y, not used in xpu kernel */ + reinterpret_cast(alpha_ptr), + reinterpret_cast(out_grad_ptr), + reinterpret_cast(x_grad_ptr), + reinterpret_cast(alpha_grad_ptr), + x_shape, + xpu_mode); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "prelu_grad"); +} +} // namespace phi + +PD_REGISTER_KERNEL(prelu_grad, + XPU, + ALL_LAYOUT, + phi::PReluGradKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/prelu_kernel.cc b/paddle/phi/kernels/xpu/prelu_kernel.cc new file mode 100644 index 00000000000..9a72cd79da3 --- /dev/null +++ b/paddle/phi/kernels/xpu/prelu_kernel.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/prelu_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void PReluKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& alpha, + const std::string& data_format, + const std::string& mode, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + + const T* x_ptr = x.data(); + const T* alpha_ptr = alpha.data(); + + T* y_ptr = dev_ctx.template Alloc(out); + + auto x_dim = x.dims(); + auto x_rank = x_dim.size(); + + std::vector x_shape(x_rank); + + for (int i = 0; i < x_rank; i++) { + x_shape[i] = x_dim[i]; + } + + auto alpha_dim = alpha.dims(); + auto alpha_rank = alpha_dim.size(); + + std::vector alpha_shape(x_rank, 1); // same size with x_shape + + for (int i = 0; i < alpha_rank; i++) { + alpha_shape[i] = alpha_dim[i]; + } + + int r = xpu::prelu(dev_ctx.x_context(), + reinterpret_cast(x_ptr), + reinterpret_cast(alpha_ptr), + reinterpret_cast(y_ptr), + x_shape, + alpha_shape); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "prelu"); +} +} // namespace phi + +PD_REGISTER_KERNEL(prelu, XPU, ALL_LAYOUT, phi::PReluKernel, float) {} diff --git a/python/paddle/fluid/tests/unittests/xpu/test_prelu_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_prelu_op_xpu.py new file mode 100644 index 00000000000..2f699ca3c02 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_prelu_op_xpu.py @@ -0,0 +1,217 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np + +sys.path.append("..") + +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import ( + XPUOpTestWrapper, + create_test_class, + get_xpu_op_support_types, +) + +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program + +paddle.enable_static() + + +class XPUTestPReluOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = "prelu" + self.use_dynamic_create_class = False + + class TestPReluOp(XPUOpTest): + def setUp(self): + self.set_xpu() + self.op_type = "prelu" + self.init_dtype() + self.eager_mode = True + + # override + self.init_input_shape() + self.init_attr() + + self.x = np.random.uniform(-10.0, 10.0, self.x_shape).astype( + self.dtype + ) + # Since zero point in prelu is not differentiable, avoid randomize zero. + self.x[np.abs(self.x) < 0.005] = 0.02 + + if self.attrs == { + 'mode': "all", + "data_format": "NCHW", + } or self.attrs == {'mode': "all", "data_format": "NHWC"}: + self.alpha = np.random.uniform(-1, -0.5, (1)) + elif self.attrs == {'mode': "channel", "data_format": "NCHW"}: + self.alpha = np.random.uniform( + -1, -0.5, [1, self.x_shape[1], 1, 1] + ) + elif self.attrs == {'mode': "channel", "data_format": "NHWC"}: + self.alpha = np.random.uniform( + -1, -0.5, [1, 1, 1, self.x_shape[-1]] + ) + else: + self.alpha = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:]) + # eager check don't support mode = 'all' + self.eager_mode = False + self.alpha = self.alpha.astype(self.dtype) + + self.inputs = {'X': self.x, 'Alpha': self.alpha} + + reshaped_alpha = self.inputs['Alpha'] + if self.attrs == {'mode': "channel", "data_format": "NCHW"}: + reshaped_alpha = np.reshape( + self.inputs['Alpha'], + [1, self.x_shape[1]] + [1] * len(self.x_shape[2:]), + ) + elif self.attrs == {'mode': "channel", "data_format": "NHWC"}: + reshaped_alpha = np.reshape( + self.inputs['Alpha'], + [1] + [1] * len(self.x_shape[1:-1]) + [self.x_shape[-1]], + ) + + self.alpha = np.random.uniform( + -10.0, 10.0, [1, self.x_shape[1], 1, 1] + ).astype(self.dtype) + + out_np = np.maximum(self.inputs['X'], 0.0) + out_np = out_np + np.minimum(self.inputs['X'], 0.0) * reshaped_alpha + assert out_np is not self.inputs['X'] + self.outputs = {'Out': out_np} + + def init_input_shape(self): + self.x_shape = [2, 3, 5, 6] + + def init_attr(self): + self.attrs = {'mode': "channel", 'data_format': "NCHW"} + + def set_xpu(self): + self.__class__.no_need_check_grad = False + self.place = paddle.XPUPlace(0) + + def init_dtype(self): + self.dtype = self.in_type + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place( + self.place, ['X', 'Alpha'], 'Out', check_eager=self.eager_mode + ) + + class TestModeChannelNHWC(TestPReluOp): + def init_input_shape(self): + self.x_shape = [2, 3, 4, 5] + + def init_attr(self): + self.attrs = {'mode': "channel", "data_format": "NHWC"} + + class TestModeAll(TestPReluOp): + def init_input_shape(self): + self.x_shape = [2, 3, 4, 5] + + def init_attr(self): + self.attrs = {'mode': "all", "data_format": "NCHW"} + + class TestModeAllNHWC(TestPReluOp): + def init_input_shape(self): + self.x_shape = [2, 3, 4, 50] + + def init_attr(self): + self.attrs = {'mode': "all", "data_format": "NHWC"} + + class TestModeElt(TestPReluOp): + def init_input_shape(self): + self.x_shape = [3, 2, 5, 10] + + def init_attr(self): + self.attrs = {'mode': "element", "data_format": "NCHW"} + + class TestModeEltNHWC(TestPReluOp): + def init_input_shape(self): + self.x_shape = [3, 2, 5, 10] + + def init_attr(self): + self.attrs = {'mode': "element", "data_format": "NHWC"} + + +def prelu_t(x, mode, param_attr=None, name=None, data_format='NCHW'): + helper = fluid.layer_helper.LayerHelper('prelu', **locals()) + alpha_shape = [1, x.shape[1], 1, 1] + dtype = helper.input_dtype(input_param_name='x') + alpha = helper.create_parameter( + attr=helper.param_attr, + shape=alpha_shape, + dtype='float32', + is_bias=False, + default_initializer=fluid.initializer.ConstantInitializer(0.25), + ) + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="prelu", + inputs={"X": x, 'Alpha': alpha}, + attrs={"mode": mode, 'data_format': data_format}, + outputs={"Out": out}, + ) + return out + + +# error message test if mode is not one of 'all', 'channel', 'element' +class TestModeError(unittest.TestCase): + def setUp(self): + self.place = paddle.XPUPlace(0) + self.x_np = np.ones([1, 2, 3, 4]).astype('float32') + + def test_mode_error(self): + main_program = Program() + with fluid.program_guard(main_program, Program()): + x = fluid.data(name='x', shape=[2, 3, 4, 5]) + try: + y = prelu_t(x, 'any') + except Exception as e: + assert e.args[0].find('InvalidArgument') != -1 + + def test_data_format_error1(self): + main_program = Program() + with fluid.program_guard(main_program, Program()): + x = fluid.data(name='x', shape=[2, 3, 4, 5]) + try: + y = prelu_t(x, 'channel', data_format='N') + except Exception as e: + assert e.args[0].find('InvalidArgument') != -1 + + def test_data_format_error2(self): + main_program = Program() + with fluid.program_guard(main_program, Program()): + x = fluid.data(name='x', shape=[2, 3, 4, 5]) + try: + y = paddle.static.nn.prelu(x, 'channel', data_format='N') + except ValueError as e: + pass + + +support_types = get_xpu_op_support_types("prelu") +for stype in support_types: + create_test_class(globals(), XPUTestPReluOp, stype) + +if __name__ == "__main__": + unittest.main() -- GitLab