From fb16bdc7ed1cc3da48a487a8284f9dee2a366a3a Mon Sep 17 00:00:00 2001 From: ykkk2333 <77383312+ykkk2333@users.noreply.github.com> Date: Thu, 30 Mar 2023 14:13:58 +0800 Subject: [PATCH] add xpu cumprod, group norm grad (#52089) --- cmake/external/xpu.cmake | 2 +- paddle/phi/backends/xpu/xpu2_op_list.cc | 14 ++ paddle/phi/backends/xpu/xpu_context.cc | 7 +- paddle/phi/kernels/xpu/cumprod_kernel.cc | 57 ++++++ .../phi/kernels/xpu/group_norm_grad_kernel.cc | 114 +++++++++++ paddle/phi/kernels/xpu/pool_grad_kernel.cc | 6 +- .../unittests/xpu/test_cumprod_op_xpu.py | 181 ++++++++++++++++++ .../unittests/xpu/test_group_norm_op_xpu.py | 4 +- 8 files changed, 378 insertions(+), 7 deletions(-) create mode 100644 paddle/phi/kernels/xpu/cumprod_kernel.cc create mode 100644 paddle/phi/kernels/xpu/group_norm_grad_kernel.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_cumprod_op_xpu.py diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 5e9c939c4df..138f06c4ae8 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so") set(XPU_RT_LIB_NAME "libxpurt.so") set(XPU_XFT_LIB_NAME "libxft.so") -set(XPU_BASE_DATE "20230310") +set(XPU_BASE_DATE "20230323") set(XPU_XCCL_BASE_VERSION "1.0.13") set(XPU_XFT_BASE_VERSION "latest") diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 27a4e054a7b..eba37fdf4ff 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -160,6 +160,10 @@ XPUOpMap& get_kl2_ops() { phi::DataType::FLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, + {"cumprod", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, {"deformable_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"deformable_conv", XPUKernelSet({phi::DataType::FLOAT32})}, {"deformable_conv_v1_grad", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -714,6 +718,15 @@ XPUOpMap& get_kl2_ops() { {"tanh", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"temporal_shift", XPUKernelSet({phi::DataType::FLOAT32})}, {"temporal_shift_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"transfer_dtype", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT64, + phi::DataType::BOOL, + phi::DataType::UINT8, + phi::DataType::INT8, + phi::DataType::INT64, + phi::DataType::INT32})}, {"tril_triu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32, @@ -844,6 +857,7 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, {"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})}, {"group_norm", XPUKernelSet({phi::DataType::FLOAT32})}, + {"group_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"meshgrid", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32, diff --git a/paddle/phi/backends/xpu/xpu_context.cc b/paddle/phi/backends/xpu/xpu_context.cc index acb8ae8db3b..c1fc20761dc 100644 --- a/paddle/phi/backends/xpu/xpu_context.cc +++ b/paddle/phi/backends/xpu/xpu_context.cc @@ -62,7 +62,7 @@ struct XPUContext::Impl { std::string cur_thread_name = phi::GetCurrentThreadName(); VLOG(3) << "XPU Dataloader: current thread at Get Context = " << phi::GetCurrentThreadName(); - bool is_dataloader_thread = (cur_thread_name.substr(0, 10) == "Dataloader"); + bool is_dataloader_thread = (cur_thread_name != "MainThread"); return is_dataloader_thread; } @@ -146,6 +146,11 @@ struct XPUContext::Impl { backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId()); PD_CHECK(context_ != nullptr, "the xpu context is nullptr."); xpu_wait(context_->xpu_stream); + xpu::Context* ctx_t = GetXdlCtx(); + if (ctx_t) { + PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr."); + xpu_wait(ctx_t->xpu_stream); + } } void Init() { diff --git a/paddle/phi/kernels/xpu/cumprod_kernel.cc b/paddle/phi/kernels/xpu/cumprod_kernel.cc new file mode 100644 index 00000000000..c9b771c7bd3 --- /dev/null +++ b/paddle/phi/kernels/xpu/cumprod_kernel.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cumprod_kernel.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/cumprod.h" + +namespace phi { +template +void CumprodKernel(const Context& dev_ctx, + const DenseTensor& input, + int dim, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + const DenseTensor* x = &input; + auto* x_data = x->data(); + auto* out_data = dev_ctx.template Alloc(out); + DDim shape = x->dims(); + std::vector xshape = phi::vectorize(shape); + + if (dim < 0) dim += xshape.size(); + if (shape.size() == 0) { + int r = + xpu::copy(dev_ctx.x_context(), + reinterpret_cast(input.data()), + reinterpret_cast(out->data()), + input.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); + + return; + } + + int r = xpu::cumprod(dev_ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(out_data), + xshape, + dim); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumprod"); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + cumprod, XPU, ALL_LAYOUT, phi::CumprodKernel, float, int, int64_t) {} diff --git a/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc new file mode 100644 index 00000000000..08532e22d86 --- /dev/null +++ b/paddle/phi/kernels/xpu/group_norm_grad_kernel.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/group_norm_grad_kernel.h" + +#include +#include +#include +#include + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void GroupNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + const DenseTensor& y, + const DenseTensor& mean, + const DenseTensor& var, + const DenseTensor& d_y, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* d_x, + DenseTensor* d_scale, + DenseTensor* d_bias) { + using XPUType = typename XPUTypeTrait::Type; + const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); + const auto scale_ptr = scale.get_ptr(); + const auto bias_ptr = bias.get_ptr(); + const auto x_dims = phi::vectorize(x.dims()); + const int N = x_dims[0]; + const bool channel_first = + data_layout == DataLayout::kNCHW || data_layout == DataLayout::kNCDHW; + const int C = (channel_first ? x_dims[1] : x_dims[x_dims.size() - 1]); + const int L = + (channel_first + ? std::accumulate( + x_dims.begin() + 2, x_dims.end(), 1, std::multiplies()) + : std::accumulate(x_dims.begin() + 1, + x_dims.end() - 1, + 1, + std::multiplies())); + + dev_ctx.template Alloc(d_x); + phi::funcs::SetConstant set_zero; + + auto* x_data = x.data(); + auto* y_data = y.data(); + auto* d_x_data = d_x->data(); + auto* d_y_data = d_y.data(); + auto* mean_data = mean.data(); + auto* var_data = var.data(); + T* d_scale_data = nullptr; + if (d_scale) { + dev_ctx.template Alloc(d_scale); + set_zero(dev_ctx, d_scale, static_cast(0)); + d_scale_data = d_scale->data(); + } + T* d_bias_data = nullptr; + if (d_bias) { + dev_ctx.template Alloc(d_bias); + set_zero(dev_ctx, d_bias, static_cast(0)); + d_bias_data = d_bias->data(); + } + + const T* scale_data = nullptr; + if (scale_ptr) scale_data = scale_ptr->data(); + const T* bias_data = nullptr; + if (bias_ptr) bias_data = bias_ptr->data(); + + int r = xpu::group_norm_grad( + dev_ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + reinterpret_cast(d_y_data), + reinterpret_cast(d_x_data), + N, + C, + L, + 1, + groups, + static_cast(epsilon), + reinterpret_cast(scale_data), + reinterpret_cast(bias_data), + reinterpret_cast(mean_data), + reinterpret_cast(var_data), + reinterpret_cast(d_scale_data), + reinterpret_cast(d_bias_data), + channel_first); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "group_norm_grad"); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + group_norm_grad, XPU, ALL_LAYOUT, phi::GroupNormGradKernel, float) {} diff --git a/paddle/phi/kernels/xpu/pool_grad_kernel.cc b/paddle/phi/kernels/xpu/pool_grad_kernel.cc index a94a757dc8b..afc0bb8fbe1 100644 --- a/paddle/phi/kernels/xpu/pool_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_grad_kernel.cc @@ -116,8 +116,7 @@ void Pool2dGradKernel(const Context& ctx, // and broadcast kernels to get same output, but better performance. // Since the dim is special in particular models, // use 'export XPU_POOLING_GRAD_SPECIAL=1' to open this path - if (out_h == 1 && out_w == 1 && std::is_same::value && - std::getenv("XPU_POOLING_GRAD_SPECIAL") != nullptr) { + if (out_h == 1 && out_w == 1 && std::is_same::value) { xpu::ctx_guard RAII_GUARD(ctx.x_context()); float scale = 1.0 / (in_h * in_w); float* scaled_dy = RAII_GUARD.alloc_l3_or_gm(n * c); @@ -301,8 +300,7 @@ void Pool3dGradKernel(const Context& ctx, } else if (pooling_type == "avg") { if (out_d == 1 && out_h == 1 && out_w == 1 && - std::is_same::value && - std::getenv("XPU_POOLING_GRAD_SPECIAL") != nullptr) { + std::is_same::value) { xpu::ctx_guard RAII_GUARD(ctx.x_context()); float scale = 1.0 / (in_d * in_h * in_w); float* scaled_dy = RAII_GUARD.alloc_l3_or_gm(n * c); diff --git a/python/paddle/fluid/tests/unittests/xpu/test_cumprod_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_cumprod_op_xpu.py new file mode 100644 index 00000000000..3ea12d2bf9f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_cumprod_op_xpu.py @@ -0,0 +1,181 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys +import unittest + +import numpy as np + +sys.path.append("..") + +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import ( + XPUOpTestWrapper, + create_test_class, + get_xpu_op_support_types, +) + +import paddle + +np.random.seed(0) + + +# define cumprod grad function. +def cumprod_grad(x, y, dy, dx, shape, dim): + if dim < 0: + dim += len(shape) + mid_dim = shape[dim] + outer_dim = 1 + inner_dim = 1 + for i in range(0, dim): + outer_dim *= shape[i] + for i in range(dim + 1, len(shape)): + inner_dim *= shape[i] + for i in range(outer_dim): + for k in range(inner_dim): + for j in range(mid_dim): + index = i * mid_dim * inner_dim + j * inner_dim + k + for n in range(mid_dim): + pos = i * mid_dim * inner_dim + n * inner_dim + k + elem = 0 + if j == 0: + elem = dy[pos] + else: + elem = dy[pos] * y[index - inner_dim] + if pos > index: + for m in range( + index + inner_dim, pos + inner_dim, inner_dim + ): + elem *= x[m] + elif pos < index: + elem = 0 + dx[index] += elem + + +# test function. +class XPUTestCumprodOP(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'cumprod' + self.use_dynamic_create_class = False + + class TestCumprod(XPUOpTest): + def init_params(self): + self.shape = (2, 3, 4, 5) + self.zero_nums = [0, 10, 20, 30, int(np.prod(self.shape))] + + def init_dtype(self): + self.dtype = self.in_type + + def setUp(self): + paddle.enable_static() + self.place = paddle.XPUPlace(0) + self.init_params() + self.init_dtype() + self.op_type = "cumprod" + self.python_api = paddle.cumprod + self.inputs = {'X': None} + self.outputs = {'Out': None} + self.attrs = {'dim': None} + + def prepare_inputs_outputs_attrs(self, dim, zero_num): + self.x = np.random.random(self.shape).astype(self.dtype) + 0.5 + if zero_num > 0: + zero_num = min(zero_num, self.x.size) + shape = self.x.shape + self.x = self.x.flatten() + indices = random.sample(range(self.x.size), zero_num) + for i in indices: + self.x[i] = 0 + self.x = np.reshape(self.x, self.shape) + self.out = np.cumprod(self.x, axis=dim) + self.inputs = {'X': self.x} + self.outputs = {'Out': self.out} + self.attrs = {'dim': dim} + + def init_grad_input_output(self, dim): + reshape_x = self.x.reshape(self.x.size) + self.grad_out = np.ones(self.x.size, self.dtype) + self.grad_x = np.zeros(self.x.size, self.dtype) + out_data = self.out.reshape(self.x.size) + if self.dtype == np.complex128 or self.dtype == np.complex64: + reshape_x = np.conj(reshape_x) + out_data = np.conj(out_data) + cumprod_grad( + reshape_x, out_data, self.grad_out, self.grad_x, self.shape, dim + ) + self.grad_x = self.grad_x.reshape(self.shape) + self.grad_out = self.grad_out.reshape(self.shape) + + # test forward. + def test_check_output(self): + for dim in range(-len(self.shape), len(self.shape)): + for zero_num in self.zero_nums: + self.prepare_inputs_outputs_attrs(dim, zero_num) + self.check_output_with_place(self.place) + + # test backward. + def test_check_grad(self): + pass + + # test api. + class TestCumprodAPI(unittest.TestCase): + def init_dtype(self): + self.dtype = 'float32' + self.shape = [2, 3, 10, 10] + + def setUp(self): + paddle.enable_static() + self.init_dtype() + self.x = (np.random.rand(2, 3, 10, 10) + 0.5).astype(self.dtype) + self.place = [paddle.XPUPlace(0)] + + # test static graph api. + def test_static_api(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('X', self.shape, dtype=self.dtype) + out = paddle.cumprod(x, -2) + exe = paddle.static.Executor(place) + res = exe.run(feed={'X': self.x}, fetch_list=[out]) + out_ref = np.cumprod(self.x, -2) + + for r in res: + np.testing.assert_allclose(out_ref, r, rtol=1e-05) + + for place in self.place: + run(place) + + # test dynamic graph api. + def test_dygraph_api(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor(self.x) + out = paddle.cumprod(x, 1) + out_ref = np.cumprod(self.x, 1) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05) + paddle.enable_static() + + for place in self.place: + run(place) + + +support_types = get_xpu_op_support_types('cumprod') +for stype in support_types: + create_test_class(globals(), XPUTestCumprodOP, stype) + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_group_norm_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_group_norm_op_xpu.py index c583a952856..67161776f81 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_group_norm_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_group_norm_op_xpu.py @@ -92,7 +92,9 @@ class XPUTestGroupNormOp(XPUOpTestWrapper): self.check_output_with_place(paddle.XPUPlace(0)) def test_check_grad(self): - pass + self.check_grad_with_place( + paddle.XPUPlace(0), ['X', 'Scale', 'Bias'], 'Y' + ) class TestGroupNormOp2(TestGroupNormOp): def init_test_case(self): -- GitLab