From 292738f344eb6107b97f2ca2b9e8c419ebb12f02 Mon Sep 17 00:00:00 2001
From: JYChen
Date: Fri, 6 Jan 2023 16:26:04 +0800
Subject: [PATCH] [zero-dim] Support 0-d for kthvalue and mode (#49340)

* add 0-d support for paddle.kthvalue
* add 0-d support for paddle.mode
* fix coverage test for device
* fix check-bug in windows
* change axis check from LT to LE
* add shape & value check for grad when input is 0d tensor
---
 paddle/phi/infermeta/unary.cc                 | 59 +++++++++--------
 .../phi/kernels/cpu/kthvalue_grad_kernel.cc   | 10 ++-
 paddle/phi/kernels/cpu/kthvalue_kernel.cc     | 14 ++++
 paddle/phi/kernels/cpu/mode_grad_kernel.cc    | 10 ++-
 paddle/phi/kernels/cpu/mode_kernel.cc         |  8 +++
 .../phi/kernels/gpu/kthvalue_grad_kernel.cu   | 10 ++-
 paddle/phi/kernels/gpu/kthvalue_kernel.cu     | 13 ++++
 paddle/phi/kernels/gpu/mode_grad_kernel.cu    |  7 ++
 paddle/phi/kernels/gpu/mode_kernel.cu         |  8 +++
 .../fluid/tests/unittests/test_kthvalue_op.py |  6 ++
 .../tests/unittests/test_zero_dim_tensor.py   | 64 +++++++++++++++++++
 11 files changed, 180 insertions(+), 29 deletions(-)

diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index b92b26f643..da24fed946 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1785,20 +1785,22 @@ void KthvalueInferMeta(const MetaTensor& x,
                        MetaConfig config) {
   auto input_dims = x.dims();
   const int& dim_size = input_dims.size();
-  PADDLE_ENFORCE_LT(axis,
+  PADDLE_ENFORCE_LE(axis,
                     dim_size,
                     phi::errors::InvalidArgument(
                         "the axis must be [-%d, %d), but received %d .",
                         dim_size,
                         dim_size,
                         axis));
-  PADDLE_ENFORCE_GE(axis,
-                    -dim_size,
-                    phi::errors::InvalidArgument(
-                        "the axis must be [-%d, %d), but received %d .",
-                        dim_size,
-                        dim_size,
-                        axis));
+  if (dim_size > 0) {
+    PADDLE_ENFORCE_GE(axis,
+                      -dim_size,
+                      phi::errors::InvalidArgument(
+                          "the axis must be [-%d, %d), but received %d .",
+                          dim_size,
+                          dim_size,
+                          axis));
+  }
   if (axis < 0) axis += dim_size;
   PADDLE_ENFORCE_GE(
       k,
@@ -1807,9 +1809,9 @@ void KthvalueInferMeta(const MetaTensor& x,
       "the k in the kthvalue must >= 1, but received %d .", k));
   PADDLE_ENFORCE_GE(
       input_dims.size(),
-      1,
-      phi::errors::InvalidArgument("input of kthvalue must have >= 1d shape"));
-  if (config.is_runtime) {
+      0,
+      phi::errors::InvalidArgument("input of kthvalue must have >= 0d shape"));
+  if (dim_size > 0 && config.is_runtime) {
     PADDLE_ENFORCE_GE(
         input_dims[axis],
         k,
@@ -1822,7 +1824,7 @@ void KthvalueInferMeta(const MetaTensor& x,
   for (int64_t i = 0; i < axis; i++) {
     dimvec.emplace_back(input_dims[i]);
   }
-  if (keepdim) {
+  if (keepdim && dim_size > 0) {
     dimvec.emplace_back(static_cast<int64_t>(1));
   }
   for (int64_t i = axis + 1; i < dim_size; i++) {
@@ -2071,33 +2073,38 @@ void ModeInferMeta(const MetaTensor& x,
                    MetaTensor* indices) {
   auto input_dims = x.dims();
   const int& dim_size = input_dims.size();
-  PADDLE_ENFORCE_EQ(
-      (axis < dim_size) && (axis >= (-1 * dim_size)),
-      true,
-      errors::InvalidArgument(
-          "the axis of ModeOp must be [-%d, %d), but you set axis is %d",
-          dim_size,
-          dim_size,
-          axis));
+  PADDLE_ENFORCE_LE(axis,
+                    dim_size,
+                    phi::errors::InvalidArgument(
+                        "the axis must be [-%d, %d), but received %d .",
+                        dim_size,
+                        dim_size,
+                        axis));
+  if (dim_size > 0) {
+    PADDLE_ENFORCE_GE(axis,
+                      -dim_size,
+                      phi::errors::InvalidArgument(
+                          "the axis must be [-%d, %d), but received %d .",
+                          dim_size,
+                          dim_size,
+                          axis));
+  }
   PADDLE_ENFORCE_GE(
       input_dims.size(),
-      1,
-      errors::InvalidArgument("input of ModeOp must have >= 1d shape"));
+      0,
+      errors::InvalidArgument("input of ModeOp must have >= 0d shape"));
   if (axis < 0) axis += dim_size;
   std::vector<int64_t> dimvec;
   for (int64_t i = 0; i < axis; i++) {
     dimvec.emplace_back(input_dims[i]);
   }
-  if (keepdim) {
+  if (keepdim && dim_size > 0) {
     dimvec.emplace_back(static_cast<int64_t>(1));
   }
   for (int64_t i = axis + 1; i < dim_size; i++) {
     dimvec.emplace_back(input_dims[i]);
   }
   DDim dims = phi::make_ddim(dimvec);
-  PADDLE_ENFORCE_GE(input_dims.size(),
-                    1,
-                    errors::InvalidArgument("input shape should >= 1d"));
   out->set_dims(dims);
   out->share_lod(x);
   out->set_dtype(x.dtype());
diff --git a/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc b/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc
index 386d41984b..459c66fae0 100644
--- a/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/kthvalue_grad_kernel.cc
@@ -55,6 +55,14 @@ void KthvalueGradKernel(const Context& dev_ctx,
                         DenseTensor* d_x) {
   auto in_dims = x.dims();
   auto out_dims = indices.dims();
+  T* x_grad_data = dev_ctx.template Alloc<T>(d_x);
+
+  // For 0D Tensor
+  if (in_dims.size() == 0) {
+    phi::funcs::set_constant(dev_ctx, d_x, 1.0);
+    return;
+  }
+
   axis = (axis < 0) ? (in_dims.size() + axis) : axis;
   if (!keepdim) {
     std::vector<int64_t> tmp_out_shape;
@@ -67,7 +75,7 @@ void KthvalueGradKernel(const Context& dev_ctx,
     }
     out_dims = phi::make_ddim(tmp_out_shape);
   }
-  T* x_grad_data = dev_ctx.template Alloc<T>(d_x);
+
   if (axis == in_dims.size() - 1) {
     const int64_t input_height =
         phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
diff --git a/paddle/phi/kernels/cpu/kthvalue_kernel.cc b/paddle/phi/kernels/cpu/kthvalue_kernel.cc
index 5e436623ca..a39fc59dcf 100644
--- a/paddle/phi/kernels/cpu/kthvalue_kernel.cc
+++ b/paddle/phi/kernels/cpu/kthvalue_kernel.cc
@@ -82,8 +82,22 @@ void KthvalueKernel(const Context& dev_ctx,
                     DenseTensor* indices) {
   const auto& in_dims = x.dims();
   if (axis < 0) axis += in_dims.size();
+
   T* output_data = dev_ctx.template Alloc<T>(output);
   int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
+  // For 0D Tensor
+  if (in_dims.size() == 0) {
+    PADDLE_ENFORCE_EQ(k,
+                      1,
+                      phi::errors::InvalidArgument(
+                          "the k in the kthvalue must be less than or equal "
+                          "to the number of elements of the input X, but received %d .",
+                          k));
+
+    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
+    phi::funcs::set_constant(dev_ctx, indices, 0);
+    return;
+  }
   auto out_dims = output->dims();
   if (axis == in_dims.size() - 1) {
     const int64_t& input_height =
diff --git a/paddle/phi/kernels/cpu/mode_grad_kernel.cc b/paddle/phi/kernels/cpu/mode_grad_kernel.cc
index 05675cf1ab..2878e4f047 100644
--- a/paddle/phi/kernels/cpu/mode_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/mode_grad_kernel.cc
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/mode.h"
 
 namespace phi {
@@ -32,9 +33,17 @@ void ModeGradKernel(const Context& dev_ctx,
                     DenseTensor* x_grad) {
   auto in_dims = x.dims();
   auto out_dims = indices.dims();
+  T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
+
   // axis < 0, get the real axis
   axis = (axis < 0) ? (in_dims.size() + axis) : axis;
 
+  // For 0D Tensor
+  if (in_dims.size() == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 1.0);
+    return;
+  }
+
   if (!keepdim) {
     std::vector<int64_t> tmp_out_shape;
     for (int i = 0; i < axis; i++) {
@@ -46,7 +55,6 @@ void ModeGradKernel(const Context& dev_ctx,
     }
     out_dims = phi::make_ddim(tmp_out_shape);
   }
-  T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
 
   if (axis == in_dims.size() - 1) {
     // allocate the memory for the input_grad
diff --git a/paddle/phi/kernels/cpu/mode_kernel.cc b/paddle/phi/kernels/cpu/mode_kernel.cc
index 6535d1b89a..00958ccd86 100644
--- a/paddle/phi/kernels/cpu/mode_kernel.cc
+++ b/paddle/phi/kernels/cpu/mode_kernel.cc
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/mode.h"
 
 namespace phi {
@@ -34,6 +35,13 @@ void ModeKernel(const Context& dev_ctx,
 
   T* output_data = dev_ctx.template Alloc<T>(out);
   int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
+
+  if (in_dims.size() == 0) {
+    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    phi::funcs::set_constant(dev_ctx, indices, 0);
+    return;
+  }
+
   // if axis is not the last dim, transpose it to the last dim, do the
   // calculation, then tranpose it back to original axis.
   if (axis == in_dims.size() - 1) {
diff --git a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
index 044fe7c621..69c65aa839 100644
--- a/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
 
 namespace phi {
@@ -43,8 +44,15 @@ void KthvalueGradKernel(const Context& dev_ctx,
                         DenseTensor* d_x) {
   const auto& in_dims = x.dims();
   auto out_dims = indices.dims();
-  if (axis < 0) axis += in_dims.size();
   T* x_grad_data = dev_ctx.template Alloc<T>(d_x);
+  // For 0D Tensor
+  if (in_dims.size() == 0) {
+    phi::funcs::set_constant(dev_ctx, d_x, 1.0);
+    return;
+  }
+
+  if (axis < 0) axis += in_dims.size();
+
   const T* out_grad_data = d_out.data<T>();
   const int64_t* indices_data = indices.data<int64_t>();
   int pre, n, post;
diff --git a/paddle/phi/kernels/gpu/kthvalue_kernel.cu b/paddle/phi/kernels/gpu/kthvalue_kernel.cu
index b04cea2ceb..1340eb0d55 100644
--- a/paddle/phi/kernels/gpu/kthvalue_kernel.cu
+++ b/paddle/phi/kernels/gpu/kthvalue_kernel.cu
@@ -167,6 +167,19 @@ void KthvalueKernel(const Context& dev_ctx,
   T* output_data = dev_ctx.template Alloc<T>(output);
   int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
 
+  // For 0D Tensor
+  if (in_dims.size() == 0) {
+    PADDLE_ENFORCE_EQ(k,
+                      1,
+                      phi::errors::InvalidArgument(
+                          "the k in the kthvalue must be less than or equal "
+                          "to the number of elements of the input X, but received %d .",
+                          k));
+
+    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, output);
+    phi::funcs::set_constant(dev_ctx, indices, 0);
+    return;
+  }
   if (axis == in_dims.size() - 1) {
     const int64_t& input_height =
         phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
diff --git a/paddle/phi/kernels/gpu/mode_grad_kernel.cu b/paddle/phi/kernels/gpu/mode_grad_kernel.cu
index 43502621c2..e297eb88d0 100644
--- a/paddle/phi/kernels/gpu/mode_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/mode_grad_kernel.cu
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/mode.h"
 
 namespace phi {
@@ -61,6 +62,12 @@ void ModeGradKernel(const Context& dev_ctx,
   const T* out_grad_data = out_grad.data<T>();
   const int64_t* indices_data = indices.data<int64_t>();
 
+  // For 0D Tensor
+  if (in_dims.size() == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 1.0);
+    return;
+  }
+
   int pre, n, post;
   funcs::GetDims(in_dims, axis, &pre, &n, &post);
 
diff --git a/paddle/phi/kernels/gpu/mode_kernel.cu b/paddle/phi/kernels/gpu/mode_kernel.cu
index 629b9722cd..dfef96688a 100644
--- a/paddle/phi/kernels/gpu/mode_kernel.cu
+++ b/paddle/phi/kernels/gpu/mode_kernel.cu
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/mode.h"
 
 namespace phi {
@@ -38,6 +39,13 @@ void ModeKernel(const Context& dev_ctx,
   T* output_data = dev_ctx.template Alloc<T>(out);
   int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
 
+  // For 0D Tensor
+  if (in_dims.size() == 0) {
+    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    phi::funcs::set_constant(dev_ctx, indices, 0);
+    return;
+  }
+
   if (axis == in_dims.size() - 1) {
     const int64_t& input_height =
         phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
diff --git a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py
index 276e3f4b8a..93008fd773 100644
--- a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py
@@ -177,6 +177,12 @@ class TestKthvalueOpErrors(unittest.TestCase):
 
         self.assertRaises(ValueError, test_dim_range_error)
 
+        def test_k_error_0_dim_input():
+            x_0d = paddle.full([], 1)
+            x_0d.kthvalue(k=8)
+
+        self.assertRaises(ValueError, test_k_error_0_dim_input)
+
 
 class TestModeOpInStatic(unittest.TestCase):
     def setUp(self):
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index d4ee3e5019..10c9e1c6cb 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -721,6 +721,48 @@ class TestSundryAPI(unittest.TestCase):
         self.assertEqual(out.numpy()[3], 2)
         self.assertEqual(out.grad.shape, [5])
 
+    def test_kthvalue(self):
+        places = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            places.append('gpu')
+        for place in places:
+            paddle.set_device(place)
+
+            x = paddle.randn(())
+            x.stop_gradient = False
+
+            out = paddle.kthvalue(x, 1)
+            out[0].backward()
+
+            # check shape of output value and indices
+            self.assertEqual(out[0].shape, [])
+            self.assertEqual(out[1].shape, [])
+
+            # check grad shape and value
+            self.assertEqual(x.grad.shape, [])
+            self.assertTrue(x.grad.numpy() == 1)
+
+    def test_mode(self):
+        places = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            places.append('gpu')
+        for place in places:
+            paddle.set_device(place)
+
+            x = paddle.randn(())
+            x.stop_gradient = False
+
+            out = paddle.mode(x)
+            out[0].backward()
+
+            # check shape of output value and indices
+            self.assertEqual(out[0].shape, [])
+            self.assertEqual(out[1].shape, [])
+
+            # check grad shape and value
+            self.assertEqual(x.grad.shape, [])
+            self.assertTrue(x.grad.numpy() == 1)
+
     def test_flatten(self):
         x = paddle.rand([])
         x.stop_gradient = False
@@ -1126,6 +1168,28 @@ class TestSundryAPIStatic(unittest.TestCase):
         self.assertEqual(res[0].shape, (5,))
         self.assertEqual(res[0][3], 2)
 
+    @prog_scope()
+    def test_kthvalue(self):
+        x = paddle.full([], 1, 'float32')
+        out = paddle.kthvalue(x, 1)
+        paddle.static.append_backward(out[0])
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        self.assertEqual(len(res[0].shape), 0)
+        self.assertEqual(len(res[1].shape), 0)
+
+    @prog_scope()
+    def test_mode(self):
+        x = paddle.full([], 1, 'float32')
+        out = paddle.mode(x)
+        paddle.static.append_backward(out[0])
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        self.assertEqual(len(res[0].shape), 0)
+        self.assertEqual(len(res[1].shape), 0)
+
     @prog_scope()
     def test_flatten(self):
         x = paddle.full([], 1, 'float32')
-- 
GitLab
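
After this patch, paddle.kthvalue and paddle.mode accept 0-d inputs and return 0-d values and indices. A minimal dynamic-graph sketch of the new behaviour, assuming a Paddle build that already contains this change (it simply mirrors the unit tests added above):

import paddle

x = paddle.randn(())                  # 0-d tensor
x.stop_gradient = False

val, idx = paddle.kthvalue(x, k=1)    # only k == 1 is valid for a 0-d input
val.backward()
print(val.shape, idx.shape)           # [] []   -> both outputs are 0-d
print(x.grad.shape, float(x.grad))    # []  1.0 -> gradient is 0-d and filled with 1

m_val, m_idx = paddle.mode(x)         # the mode of a 0-d tensor is the value itself
print(m_val.shape, m_idx.shape)       # [] []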