diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 93f431fa056c0a799c378787ecf4ba682809895d..6af74785a0033208ec72538ca8e9d8aa19090b18 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -118,7 +118,8 @@ XPUOpMap& get_kl2_ops() {
       {"concat",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
-                     phi::DataType::INT64})},
+                     phi::DataType::INT64,
+                     phi::DataType::INT32})},
       {"conv2d_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"conv2d",
@@ -159,7 +160,10 @@ XPUOpMap& get_kl2_ops() {
       {"elementwise_add_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"elementwise_add",
-       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16,
+                     phi::DataType::INT64,
+                     phi::DataType::INT32})},
       {"elementwise_div_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"elementwise_div",
@@ -300,7 +304,11 @@ XPUOpMap& get_kl2_ops() {
                      phi::DataType::INT64,
                      phi::DataType::FLOAT32})},
       {"gather",
-       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16,
+                     phi::DataType::INT32,
+                     phi::DataType::INT64,
+                     phi::DataType::BOOL})},
       {"gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})},
       {"gelu_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -491,7 +499,8 @@ XPUOpMap& get_kl2_ops() {
       {"scale",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
-                     phi::DataType::INT64})},
+                     phi::DataType::INT64,
+                     phi::DataType::INT32})},
       {"scatter",
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
diff --git a/paddle/phi/kernels/selected_rows/full_kernel.cc b/paddle/phi/kernels/selected_rows/full_kernel.cc
index a492c1c304bd2f631f9c107c659eebdaefa82e6b..e04139448dddc2f942886e2abd98b9c8c4431fd9 100644
--- a/paddle/phi/kernels/selected_rows/full_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/full_kernel.cc
@@ -71,7 +71,7 @@ PD_REGISTER_KERNEL(full_sr,
                    phi::dtype::complex<double>) {}
 #endif
 
-#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
+#if defined(PADDLE_WITH_XPU)
 PD_REGISTER_KERNEL(full_sr,
                    XPU,
                    ALL_LAYOUT,
diff --git a/paddle/phi/kernels/xpu/concat_kernel.cc b/paddle/phi/kernels/xpu/concat_kernel.cc
index 4e09f6ef85281d0ce8cbcf7f49473b197c83187c..0bd180b692b106e987bbfd53c70d37fe3c61bd84 100644
--- a/paddle/phi/kernels/xpu/concat_kernel.cc
+++ b/paddle/phi/kernels/xpu/concat_kernel.cc
@@ -117,4 +117,5 @@ PD_REGISTER_KERNEL(concat,
                    phi::ConcatKernel,
                    float,
                    phi::dtype::float16,
-                   int64_t) {}
+                   int64_t,
+                   int) {}
diff --git a/paddle/phi/kernels/xpu/elementwise_add_kernel.cc b/paddle/phi/kernels/xpu/elementwise_add_kernel.cc
index 0e19c59d26c91bb69daa08842751f3907f349dbb..d32553ac837638960bb15845bcc5acb2b6a88dc2 100644
--- a/paddle/phi/kernels/xpu/elementwise_add_kernel.cc
+++ b/paddle/phi/kernels/xpu/elementwise_add_kernel.cc
@@ -75,5 +75,11 @@ PD_REGISTER_KERNEL(grad_add,
                    phi::GradAddXPUKernel,
                    phi::dtype::float16,
                    float) {}
-PD_REGISTER_KERNEL(
-    add_raw, XPU, ALL_LAYOUT, phi::AddRawKernel, phi::dtype::float16, float) {}
+PD_REGISTER_KERNEL(add_raw,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::AddRawKernel,
+                   phi::dtype::float16,
+                   float,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc
index ae080d0dad07253f37adc2021c3c9606020dda3d..44f5f8b08b7abaf979ad83895a3a098f824dee33 100644
--- a/paddle/phi/kernels/xpu/full_kernel.cc
+++ b/paddle/phi/kernels/xpu/full_kernel.cc
@@ -28,32 +28,6 @@
 
 namespace phi {
 
-template <typename InType, typename OutType>
-void TensorSetConstantXPU(phi::DenseTensor* tensor,
-                          InType value,
-                          phi::Place place) {
-  auto* begin = tensor->mutable_data<OutType>(place);
-  int64_t numel = tensor->numel();
-  std::unique_ptr<OutType[]> data_cpu(new OutType[numel]);
-  std::fill(
-      data_cpu.get(), data_cpu.get() + numel, static_cast<OutType>(value));
-  paddle::memory::Copy(place,
-                       begin,
-                       phi::CPUPlace(),
-                       static_cast<void*>(data_cpu.get()),
-                       numel * sizeof(OutType));
-}
-
-template <typename Context, typename VType>
-void FullValueXPU(const Context& dev_ctx, DenseTensor* tensor, VType val) {
-  dev_ctx.template Alloc<VType>(tensor);
-
-  PD_VISIT_ALL_TYPES(tensor->dtype(), "FullValueXPU", ([&] {
-                       TensorSetConstantXPU<VType, data_t>(
-                           tensor, val, dev_ctx.GetPlace());
-                     }));
-}
-
 template <typename T, typename Context>
 void FullKernel(const Context& dev_ctx,
                 const IntArray& shape,
@@ -64,13 +38,12 @@ void FullKernel(const Context& dev_ctx,
   out->Resize(phi::make_ddim(shape.GetData()));
   int numel = out->numel();
   dev_ctx.template Alloc<T>(out);
-  auto value = val.to<float>();
   auto out_data = reinterpret_cast<XPUType*>(out->data<T>());
   if (numel > 0) {
     int r = xpu::constant(dev_ctx.x_context(),
                           out_data,
                           out->numel(),
-                          static_cast<XPUType>(value));
+                          static_cast<XPUType>(val.to<T>()));
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
   }
 }
diff --git a/paddle/phi/kernels/xpu/gather_kernel.cc b/paddle/phi/kernels/xpu/gather_kernel.cc
index 76b2f04ee52bab66470486e0b8b14f1d5cb31aa8..658999c5289d3e871044120b775f2fb67866eb75 100644
--- a/paddle/phi/kernels/xpu/gather_kernel.cc
+++ b/paddle/phi/kernels/xpu/gather_kernel.cc
@@ -83,5 +83,12 @@ void GatherKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    gather, XPU, ALL_LAYOUT, phi::GatherKernel, float, phi::dtype::float16) {}
+PD_REGISTER_KERNEL(gather,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::GatherKernel,
+                   float,
+                   phi::dtype::float16,
+                   int,
+                   int64_t,
+                   bool) {}
diff --git a/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc
index adf2cd78787ee9c2bf987886cd3f16543707fb73..641794dab0a4387cc8d2a2bf4787b751a338db46 100644
--- a/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/instance_norm_grad_kernel.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
 
 namespace phi {
 
@@ -24,46 +25,89 @@ void InstanceNormGradKernel(const Context& dev_ctx,
                             const paddle::optional<DenseTensor>& scale,
                             const DenseTensor& saved_mean,
                             const DenseTensor& saved_variance,
-                            const DenseTensor& y_grad,
+                            const DenseTensor& d_y,
                             float epsilon,
-                            DenseTensor* x_grad,
-                            DenseTensor* scale_grad,
-                            DenseTensor* bias_grad) {
+                            DenseTensor* d_x,
+                            DenseTensor* d_scale,
+                            DenseTensor* d_bias) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   const auto& x_dims = x.dims();
-  int n = x_dims[0];
-  int c = x_dims[1];
-  int h = x_dims[2];
-  int w = x_dims[3];
+  int N, C, H, W, D;
+  funcs::ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
+  PADDLE_ENFORCE_EQ(
+      x_dims.size() <= 5 && D == 1,
+      true,
+      phi::errors::InvalidArgument(
+          "The size of input's dimensions should be less than or equal to 5 "
+          "and the dimension of D should be equal to 1. "
+          "But received: the size of input's dimensions is [%d].",
+          x_dims.size()));
 
-  dev_ctx.template Alloc<T>(x_grad);
-  if (bias_grad != nullptr) {
-    dev_ctx.template Alloc<float>(bias_grad);
-  }
-  if (scale_grad != nullptr) {
-    dev_ctx.template Alloc<float>(scale_grad);
+  dev_ctx.template Alloc<T>(d_x);
+  T* d_scale_data = nullptr;
+  T* d_bias_data = nullptr;
+  if (d_scale && d_bias) {
+    dev_ctx.template Alloc<T>(d_scale);
+    dev_ctx.template Alloc<T>(d_bias);
+    d_scale_data = d_scale->data<T>();
+    d_bias_data = d_bias->data<T>();
   }
   const auto scale_ptr = scale.get_ptr();
+  if (scale_ptr) {
+    PADDLE_ENFORCE_EQ(
+        scale_ptr->dims().size(),
+        1UL,
+        phi::errors::InvalidArgument(
+            "The `shape` in InstanceNormOp is invalid: "
+            "the size of scale's dimensions must be equal to 1. But "
+            "received: the size of scale's dimensions is [%d].",
+            scale_ptr->dims().size()));
+    PADDLE_ENFORCE_EQ(scale_ptr->dims()[0],
+                      C,
+                      phi::errors::InvalidArgument(
+                          "The `shape` in InstanceNormOp is invalid: "
+                          "the first dimension of scale must be equal to "
+                          "Channels([%d]). But received: "
+                          "the first dimension of scale is [%d], "
+                          "the dimensions of scale is [%s].",
+                          C,
+                          scale_ptr->dims()[0],
+                          scale_ptr->dims()));
+  }
 
-  int r = xpu::instance_norm_grad(
-      dev_ctx.x_context(),
-      reinterpret_cast<const XPUType*>(x.data<T>()),
-      reinterpret_cast<const XPUType*>(y_grad.data<T>()),
-      reinterpret_cast<XPUType*>(x_grad->data<T>()),
-      scale_ptr->data<float>(),
-      saved_mean.data<float>(),
-      saved_variance.data<float>(),
-      scale_grad->data<float>(),
-      bias_grad->data<float>(),
-      n,
-      c,
-      h,
-      w,
-      epsilon,
-      true);
+  DenseTensor scale_tmp;
+  int r;
+  if (!scale_ptr) {
+    scale_tmp.Resize({C});
+    dev_ctx.template Alloc<float>(&scale_tmp);
+    r = xpu::constant(dev_ctx.x_context(),
+                      reinterpret_cast<float*>(scale_tmp.data<float>()),
+                      scale_tmp.numel(),
+                      static_cast<float>(1));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+  }
+  auto scale_ptr_tmp = scale_ptr ? scale_ptr : &scale_tmp;
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  auto d_x_data =
+      d_x ? d_x->data<T>() : RAII_GUARD.alloc_l3_or_gm<T>(x.numel());
+  r = xpu::instance_norm_grad(dev_ctx.x_context(),
+                              reinterpret_cast<const XPUType*>(x.data<T>()),
+                              reinterpret_cast<const XPUType*>(d_y.data<T>()),
+                              reinterpret_cast<XPUType*>(d_x_data),
+                              scale_ptr_tmp->data<float>(),
+                              saved_mean.data<float>(),
+                              saved_variance.data<float>(),
+                              d_scale_data,
+                              d_bias_data,
+                              N,
+                              C,
+                              H,
+                              W,
+                              epsilon,
+                              true);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "instance_norm_grad");
 }
diff --git a/paddle/phi/kernels/xpu/scale_kernel.cc b/paddle/phi/kernels/xpu/scale_kernel.cc
index a478dfddf1edb73b47091db831f996ace2751b05..cc8bdf6165e295238600353b2c25dbe654e18181 100644
--- a/paddle/phi/kernels/xpu/scale_kernel.cc
+++ b/paddle/phi/kernels/xpu/scale_kernel.cc
@@ -58,4 +58,5 @@ PD_REGISTER_KERNEL(scale,
                    phi::ScaleKernel,
                    float,
                    phi::dtype::float16,
+                   int,
                    int64_t) {}
diff --git a/paddle/phi/kernels/xpu/scatter_kernel.cc b/paddle/phi/kernels/xpu/scatter_kernel.cc
index 18e4e03dd27870c99b32b0fad5ab33ec138b156a..de48682c1ecd552badca494b6cf1ae3c4f1d8e7f 100644
--- a/paddle/phi/kernels/xpu/scatter_kernel.cc
+++ b/paddle/phi/kernels/xpu/scatter_kernel.cc
@@ -114,4 +114,4 @@ void ScatterKernel(const Context &ctx,
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
-    scatter, XPU, ALL_LAYOUT, phi::ScatterKernel, float, int64_t) {}
+    scatter, XPU, ALL_LAYOUT, phi::ScatterKernel, float, int, int64_t) {}
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py
index ef0f52d5c3f2dafbc6a480aa2c1497c87b793666..41e253a239a0f438a936633f0e4ac45f8f8d9afe 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py
@@ -195,7 +195,7 @@ class TestNewCustomOpXpuSetUpInstall(unittest.TestCase):
 
         self.custom_op = custom_relu_xpu_module_setup.custom_relu
 
-        self.dtypes = ['float32', 'float64']
+        self.dtypes = ['float32']
         self.device = 'xpu'
 
         # config seed
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_instance_norm_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_instance_norm_op_xpu.py
index 52f4d2b3992c2a4e06319f26ae487809f40c02c8..8968d69c7b7cf1bd38da4cf670f813ce37ac416b 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_instance_norm_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_instance_norm_op_xpu.py
@@ -18,6 +18,8 @@ import unittest
 import numpy as np
 
 import paddle
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
 
 sys.path.append("..")
 from op_test_xpu import XPUOpTest
@@ -69,6 +71,7 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
             self.dtype = self.in_type
             self.shape = [2, 3, 4, 5]
             self.epsilon = 1e-05
+            self.no_grad_set = None
             self.set_attrs()
 
             np.random.seed(12345)
@@ -112,7 +115,12 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
             self.check_output_with_place(paddle.XPUPlace(0))
 
         def test_check_grad(self):
-            self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Y')
+            self.check_grad_with_place(
+                paddle.XPUPlace(0),
+                ['X', 'Scale', 'Bias'],
+                ['Y'],
+                self.no_grad_set,
+            )
 
     class TestXPUInstanceNormOp1(XPUTestInstanceNormOp):
         def set_attrs(self):
@@ -134,6 +142,57 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
         def set_attrs(self):
             self.shape = [10, 3, 512, 1]
 
+    class TestXPUInstanceNormOp6(XPUTestInstanceNormOp):
+        def set_attrs(self):
+            self.shape = [10, 12, 32, 32]
+            self.no_grad_set = set(['Scale', 'Bias'])
+
+    class TestXPUInstanceNormOp7(XPUTestInstanceNormOp):
+        def set_attrs(self):
+            self.shape = [4, 5, 6, 7]
+            self.no_grad_set = set(['Scale', 'Bias'])
+
+    class TestXPUInstanceNormOp8(XPUTestInstanceNormOp):
+        def set_attrs(self):
+            self.shape = [1, 8, 16, 16]
+            self.no_grad_set = set(['Scale', 'Bias'])
+
+    class TestXPUInstanceNormOp9(XPUTestInstanceNormOp):
+        def set_attrs(self):
+            self.shape = [4, 16, 256, 128]
+            self.no_grad_set = set(['Scale', 'Bias'])
+
+    class TestXPUInstanceNormOp10(XPUTestInstanceNormOp):
+        def set_attrs(self):
+            self.shape = [10, 3, 512, 1]
+            self.no_grad_set = set(['Scale', 'Bias'])
+
+    class TestInstanceNormOpError(XPUOpTest):
+        def setUp(self):
+            self.__class__.op_type = "instance_norm"
+            self.__class__.no_need_check_grad = True
+            self.dtype = self.in_type
+
+        def test_errors(self):
+            with program_guard(Program(), Program()):
+                # the input of instance_norm must be a Variable
+                x1 = fluid.create_lod_tensor(
+                    np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0)
+                )
+                self.assertRaises(TypeError, paddle.static.nn.instance_norm, x1)
+
+                # the input dtype of instance_norm must be float32
+                x2 = paddle.static.data(
+                    name='x2', shape=[-1, 3, 4, 5, 6], dtype="int32"
+                )
+                self.assertRaises(TypeError, paddle.static.nn.instance_norm, x2)
+
+                # the rank of the input of instance_norm must be between 2 and 5
+                x3 = paddle.static.data(name='x', shape=[3], dtype="float32")
+                self.assertRaises(
+                    ValueError, paddle.static.nn.instance_norm, x3
+                )
+
 
 support_types = get_xpu_op_support_types('instance_norm')
 for stype in support_types: