提交 35c48c75 编写于 作者: D DesmonDay

support 0D for paddle.sort/argsort

上级 72973d5a
...@@ -220,18 +220,26 @@ void ArgsortInferMeta(const MetaTensor& input, ...@@ -220,18 +220,26 @@ void ArgsortInferMeta(const MetaTensor& input,
MetaTensor* indices) { MetaTensor* indices) {
auto in_dims = input.dims(); auto in_dims = input.dims();
auto num_dims = in_dims.size(); auto num_dims = in_dims.size();
PADDLE_ENFORCE_GE( if (num_dims > 0) {
axis, PADDLE_ENFORCE_GE(axis,
-num_dims, -num_dims,
phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to" phi::errors::InvalidArgument(
" -num_dims(%d).", "'axis'(%d) must be greater than or equal to"
axis, " -num_dims(%d).",
-num_dims)); axis,
PADDLE_ENFORCE_LT( -num_dims));
axis, PADDLE_ENFORCE_LT(
num_dims, axis,
phi::errors::InvalidArgument( num_dims,
"'axis'(%d) must be less than num_dims(%d).", axis, num_dims)); phi::errors::InvalidArgument(
"'axis'(%d) must be less than num_dims(%d).", axis, num_dims));
} else { // 0-dim tensor
PADDLE_ENFORCE_EQ(
axis == 0 || axis == -1,
1,
phi::errors::InvalidArgument(
"'axis'(%d) must be 0 or -1 if input tensor is 0-dim.", axis));
}
output->share_dims(input); output->share_dims(input);
output->set_dtype(input.dtype()); output->set_dtype(input.dtype());
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h"
namespace phi { namespace phi {
...@@ -58,6 +59,7 @@ void ArgsortGradKernel(const Context& dev_ctx, ...@@ -58,6 +59,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
bool descending, bool descending,
DenseTensor* in_grad) { DenseTensor* in_grad) {
auto in_dims = indices.dims(); auto in_dims = indices.dims();
auto rank = input.dims().size();
axis = (axis < 0) ? (in_dims.size() + axis) : axis; axis = (axis < 0) ? (in_dims.size() + axis) : axis;
dev_ctx.template Alloc<T>(in_grad); dev_ctx.template Alloc<T>(in_grad);
auto dxt = EigenVector<T>::Flatten(*in_grad); auto dxt = EigenVector<T>::Flatten(*in_grad);
...@@ -65,6 +67,11 @@ void ArgsortGradKernel(const Context& dev_ctx, ...@@ -65,6 +67,11 @@ void ArgsortGradKernel(const Context& dev_ctx,
dxt.device(place) = dxt.constant(static_cast<T>(0)); dxt.device(place) = dxt.constant(static_cast<T>(0));
if (out_grad.numel() == 0) return; if (out_grad.numel() == 0) return;
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, in_grad, 1.0);
return;
}
// Do full assign // Do full assign
if (axis == -1 || axis + 1 == in_dims.size()) { if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = const int64_t input_height =
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h"
namespace phi { namespace phi {
...@@ -75,9 +76,17 @@ void ArgsortKernel(const Context& dev_ctx, ...@@ -75,9 +76,17 @@ void ArgsortKernel(const Context& dev_ctx,
DenseTensor* output, DenseTensor* output,
DenseTensor* indices) { DenseTensor* indices) {
auto in_dims = input.dims(); auto in_dims = input.dims();
auto rank = in_dims.size();
axis = (axis < 0) ? (in_dims.size() + axis) : axis; axis = (axis < 0) ? (in_dims.size() + axis) : axis;
T* out_data = dev_ctx.template Alloc<T>(output); T* out_data = dev_ctx.template Alloc<T>(output);
// For 0D Tensor
if (rank == 0) {
phi::Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, output);
phi::funcs::set_constant(dev_ctx, indices, 0);
return;
}
// Do full sort // Do full sort
if (axis == -1 || axis + 1 == in_dims.size()) { if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = const int64_t input_height =
......
...@@ -28,6 +28,7 @@ namespace cub = hipcub; ...@@ -28,6 +28,7 @@ namespace cub = hipcub;
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h" #include "paddle/phi/kernels/primitive/functor_primitives.h"
#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h"
...@@ -141,11 +142,18 @@ void ArgsortGradKernel(const Context& dev_ctx, ...@@ -141,11 +142,18 @@ void ArgsortGradKernel(const Context& dev_ctx,
bool descending, bool descending,
DenseTensor* in_grad) { DenseTensor* in_grad) {
dev_ctx.template Alloc<T>(in_grad); dev_ctx.template Alloc<T>(in_grad);
phi::funcs::set_constant(dev_ctx, in_grad, 0.0);
if (out_grad.numel() == 0) return; if (out_grad.numel() == 0) return;
auto in_dims = in_grad->dims(); auto in_dims = in_grad->dims();
auto rank = in_dims.size();
axis = (axis < 0) ? (in_dims.size() + axis) : axis; axis = (axis < 0) ? (in_dims.size() + axis) : axis;
int64_t size = in_grad->numel(); int64_t size = in_grad->numel();
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, in_grad, 1.0);
return;
}
// Parallel acceleration when the input size is equal to the length of the // Parallel acceleration when the input size is equal to the length of the
// ‘axis’ dimension. // ‘axis’ dimension.
// Compared to 'special case for full sort' below, the gradient calculation // Compared to 'special case for full sort' below, the gradient calculation
......
...@@ -30,6 +30,7 @@ namespace cub = hipcub; ...@@ -30,6 +30,7 @@ namespace cub = hipcub;
#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h" #include "paddle/phi/kernels/primitive/functor_primitives.h"
#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h"
...@@ -396,6 +397,7 @@ void ArgsortKernel(const Context &dev_ctx, ...@@ -396,6 +397,7 @@ void ArgsortKernel(const Context &dev_ctx,
DenseTensor *output, DenseTensor *output,
DenseTensor *indices) { DenseTensor *indices) {
auto in_dims = input.dims(); auto in_dims = input.dims();
auto rank = in_dims.size();
axis = (axis < 0) ? (in_dims.size() + axis) : axis; axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T *in_data = input.data<T>(); const T *in_data = input.data<T>();
...@@ -403,6 +405,12 @@ void ArgsortKernel(const Context &dev_ctx, ...@@ -403,6 +405,12 @@ void ArgsortKernel(const Context &dev_ctx,
T *out_data = dev_ctx.template Alloc<T>(output); T *out_data = dev_ctx.template Alloc<T>(output);
int64_t *ids_data = dev_ctx.template Alloc<int64_t>(indices); int64_t *ids_data = dev_ctx.template Alloc<int64_t>(indices);
if (rank == 0) {
phi::Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, output);
phi::funcs::set_constant(dev_ctx, indices, 0);
return;
}
// Use thrust for parallel acceleration when the input size is equal to the // Use thrust for parallel acceleration when the input size is equal to the
// length of the ‘axis’ dimension. // length of the ‘axis’ dimension.
// Compared to the following 'Special case for full sort', ascending sort is // Compared to the following 'Special case for full sort', ascending sort is
......
...@@ -747,6 +747,42 @@ class TestSundryAPI(unittest.TestCase): ...@@ -747,6 +747,42 @@ class TestSundryAPI(unittest.TestCase):
np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy()) np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1)) np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))
def test_sort(self):
    # paddle.sort on a 0-D tensor must work for both axis values accepted
    # for 0-D input (-1 and 0), in forward and backward.
    x1 = paddle.rand([])
    x2 = paddle.rand([])
    x1.stop_gradient = False
    x2.stop_gradient = False
    out1 = paddle.sort(x1, axis=-1)
    out2 = paddle.sort(x2, axis=0)

    out1.backward()
    out2.backward()

    # Shapes stay 0-D everywhere.
    self.assertEqual(out1.shape, [])
    self.assertEqual(out2.shape, [])
    self.assertEqual(out1.grad.shape, [])
    self.assertEqual(out2.grad.shape, [])
    self.assertEqual(x1.grad.shape, [])
    self.assertEqual(x2.grad.shape, [])

    # Values: sorting a single element returns that element unchanged,
    # and the gradient flowing back to the input is 1.
    np.testing.assert_array_equal(out1.numpy(), x1.numpy())
    np.testing.assert_array_equal(out2.numpy(), x2.numpy())
    np.testing.assert_array_equal(x1.grad.numpy(), np.asarray(1))
    np.testing.assert_array_equal(x2.grad.numpy(), np.asarray(1))
def test_argsort(self):
    # paddle.argsort on a 0-D tensor must work for both axis values
    # accepted for 0-D input (-1 and 0), in forward and backward.
    x1 = paddle.rand([])
    x2 = paddle.rand([])
    x1.stop_gradient = False
    x2.stop_gradient = False
    out1 = paddle.argsort(x1, axis=-1)
    out2 = paddle.argsort(x2, axis=0)

    out1.backward()
    out2.backward()

    # Shapes stay 0-D everywhere.
    self.assertEqual(out1.shape, [])
    self.assertEqual(out2.shape, [])
    self.assertEqual(out1.grad.shape, [])
    self.assertEqual(out2.grad.shape, [])
    self.assertEqual(x1.grad.shape, [])
    self.assertEqual(x2.grad.shape, [])

    # Values: the only element of a 0-D tensor sits at index 0.
    np.testing.assert_array_equal(out1.numpy(), np.asarray(0))
    np.testing.assert_array_equal(out2.numpy(), np.asarray(0))
class TestSundryAPIStatic(unittest.TestCase): class TestSundryAPIStatic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -990,6 +1026,42 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -990,6 +1026,42 @@ class TestSundryAPIStatic(unittest.TestCase):
np.testing.assert_array_equal(out3_1, out3_2) np.testing.assert_array_equal(out3_1, out3_2)
np.testing.assert_array_equal(out3_2, np.asarray(1)) np.testing.assert_array_equal(out3_2, np.asarray(1))
@prog_scope()
def test_sort(self):
    # Static-graph check: paddle.sort on a 0-D tensor, once per axis value
    # accepted for 0-D input, with a backward pass appended for each output.
    fetch_targets = []
    for axis in (-1, 0):
        x = paddle.rand([])
        x.stop_gradient = False
        sorted_x = paddle.sort(x, axis=axis)
        paddle.static.append_backward(sorted_x)
        fetch_targets.append(sorted_x)

    main_program = paddle.static.default_main_program()
    fetched = self.exe.run(main_program, fetch_list=fetch_targets)

    # Every fetched result must still be 0-D.
    for value in fetched:
        self.assertEqual(value.shape, ())
@prog_scope()
def test_argsort(self):
    # Static-graph check: paddle.argsort on a 0-D tensor for both axis
    # values accepted for 0-D input, with backward appended for each output.
    x1 = paddle.rand([])
    x1.stop_gradient = False
    out1 = paddle.argsort(x1, axis=-1)
    paddle.static.append_backward(out1)

    x2 = paddle.rand([])
    x2.stop_gradient = False
    out2 = paddle.argsort(x2, axis=0)
    paddle.static.append_backward(out2)

    prog = paddle.static.default_main_program()
    res = self.exe.run(prog, fetch_list=[out1, out2])

    # Results stay 0-D ...
    self.assertEqual(res[0].shape, ())
    self.assertEqual(res[1].shape, ())
    # ... and the only element of a 0-D tensor sits at index 0.
    np.testing.assert_array_equal(res[0], np.asarray(0))
    np.testing.assert_array_equal(res[1], np.asarray(0))
# Used to test APIs whose zero-dim input tensors don't have grad and don't need backward testing in OpTest. # Used to test APIs whose zero-dim input tensors don't have grad and don't need backward testing in OpTest.
class TestNoBackwardAPI(unittest.TestCase): class TestNoBackwardAPI(unittest.TestCase):
......
...@@ -556,6 +556,42 @@ class TestSundryAPI(unittest.TestCase): ...@@ -556,6 +556,42 @@ class TestSundryAPI(unittest.TestCase):
np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy()) np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1)) np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))
def test_sort(self):
    # Shape-only check of paddle.sort on 0-D tensors, once per axis value
    # accepted for 0-D input.
    axes = (-1, 0)
    inputs = [paddle.rand([]) for _ in axes]
    for tensor in inputs:
        tensor.stop_gradient = False

    outputs = [
        paddle.sort(tensor, axis=axis) for tensor, axis in zip(inputs, axes)
    ]
    for result in outputs:
        result.backward()

    # Outputs, their grads, and the input grads must all remain 0-D.
    for result in outputs:
        self.assertEqual(result.shape, [])
    for result in outputs:
        self.assertEqual(result.grad.shape, [])
    for tensor in inputs:
        self.assertEqual(tensor.grad.shape, [])
def test_argsort(self):
    # Shape-only check of paddle.argsort on 0-D tensors, once per axis
    # value accepted for 0-D input.
    axes = (-1, 0)
    inputs = [paddle.rand([]) for _ in axes]
    for tensor in inputs:
        tensor.stop_gradient = False

    outputs = [
        paddle.argsort(tensor, axis=axis)
        for tensor, axis in zip(inputs, axes)
    ]
    for result in outputs:
        result.backward()

    # Outputs, their grads, and the input grads must all remain 0-D.
    for result in outputs:
        self.assertEqual(result.shape, [])
    for result in outputs:
        self.assertEqual(result.grad.shape, [])
    for tensor in inputs:
        self.assertEqual(tensor.grad.shape, [])
# Used to test APIs whose zero-dim input tensors don't have grad and don't need backward testing in OpTest. # Used to test APIs whose zero-dim input tensors don't have grad and don't need backward testing in OpTest.
class TestNoBackwardAPI(unittest.TestCase): class TestNoBackwardAPI(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册