diff --git a/paddle/fluid/operators/randint_op.cc b/paddle/fluid/operators/randint_op.cc
index bafb66b7ad0ceaa84c1a62e4f84afbc8a2185e8a..9752f21b7eca7d68df1599b062d199978b6aec26 100644
--- a/paddle/fluid/operators/randint_op.cc
+++ b/paddle/fluid/operators/randint_op.cc
@@ -77,14 +77,6 @@ class RandintOp : public framework::OperatorWithKernel {
       return;
     }
 
-    PADDLE_ENFORCE_EQ(shape.empty(),
-                      false,
-                      platform::errors::InvalidArgument(
-                          "if there is no Input(ShapeTensorList) and no "
-                          "Input(ShapeTensor),the "
-                          "attr(shape) information must "
-                          "be set by Attr(shape)."));
-
     std::vector<int64_t> tensor_shape;
     tensor_shape.reserve(shape.size());
     for (auto dim : shape) {
diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc
index 419590cbe67527bc2a11b2fe1dc9846c65ea980c..f9c432c2e0a790937718b12499d8d49563d40b05 100644
--- a/paddle/phi/infermeta/nullary.cc
+++ b/paddle/phi/infermeta/nullary.cc
@@ -112,11 +112,6 @@ void RandintInferMeta(
           high));
 
   auto& shape_vector = shape.GetData();
-  PADDLE_ENFORCE_EQ(
-      shape_vector.empty(),
-      false,
-      errors::InvalidArgument("The shape information should not be empty, it "
-                              "must be set by Attr(shape)."));
 
   std::vector<int64_t> tensor_shape;
   tensor_shape.reserve(shape_vector.size());
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 8df3eeead45c941307d8043e1358134b60eae1b2..3c66523aefffea990de1e9bdc30cdeca2f1f0f22 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -3108,16 +3108,29 @@ void SliceRawInferMeta(const MetaTensor& input,
 void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out) {
   auto dim_x = x.dims();
   auto rank_x = dim_x.size();
-  PADDLE_ENFORCE_GE(axis,
-                    -rank_x,
-                    phi::errors::InvalidArgument(
-                        "Attr(axis) value should be in range [-R, R-1], "
-                        "R is the rank of Input(X)."));
-  PADDLE_ENFORCE_LT(axis,
-                    rank_x,
-                    phi::errors::InvalidArgument(
-                        "Attr(axis) value should be in range [-R, R-1], "
-                        "R is the rank of Input(X)."));
+  if (rank_x > 0) {
+    PADDLE_ENFORCE_GE(axis,
+                      -rank_x,
+                      phi::errors::InvalidArgument(
+                          "Attr(axis) value should be in range [-R, R-1], "
+                          "R is the rank of Input(X)."));
+    PADDLE_ENFORCE_LT(axis,
+                      rank_x,
+                      phi::errors::InvalidArgument(
+                          "Attr(axis) value should be in range [-R, R-1], "
+                          "R is the rank of Input(X)."));
+  } else {
+    PADDLE_ENFORCE_GE(
+        axis,
+        -1,
+        phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
+                                     "0] when input is 0D Tensor "));
+    PADDLE_ENFORCE_LE(
+        axis,
+        0,
+        phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
+                                     "0] when input is 0D Tensor "));
+  }
 
   out->set_dims(x.dims());
   out->set_dtype(x.dtype());
@@ -3963,22 +3976,29 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x,
                                  int axis,
                                  MetaTensor* out) {
   auto rank = x.dims().size();
-  PADDLE_ENFORCE_GE(
-      axis,
-      -rank,
-      phi::errors::InvalidArgument(
-          "Attr(axis) value should be in range [-R, R-1], "
-          "R is the rank of Input(X). But received axis: %d, R: %d.",
-          axis,
-          rank));
-  PADDLE_ENFORCE_LT(
-      axis,
-      rank,
-      phi::errors::InvalidArgument(
-          "Attr(axis) value should be in range [-R, R-1], "
-          "R is the rank of Input(X). But received axis: %d, R: %d.",
-          axis,
-          rank));
+  if (rank > 0) {
+    PADDLE_ENFORCE_GE(axis,
+                      -rank,
+                      phi::errors::InvalidArgument(
+                          "Attr(axis) value should be in range [-R, R-1], "
+                          "R is the rank of Input(X)."));
+    PADDLE_ENFORCE_LT(axis,
+                      rank,
+                      phi::errors::InvalidArgument(
+                          "Attr(axis) value should be in range [-R, R-1], "
+                          "R is the rank of Input(X)."));
+  } else if (rank == 0) {
+    PADDLE_ENFORCE_GE(
+        axis,
+        -1,
+        phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
+                                     "0] when input is 0D Tensor "));
+    PADDLE_ENFORCE_LE(
+        axis,
+        0,
+        phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
+                                     "0] when input is 0D Tensor "));
+  }
 
   out->share_meta(x);
 }
diff --git a/paddle/phi/kernels/cpu/log_softmax_grad_kernel.cc b/paddle/phi/kernels/cpu/log_softmax_grad_kernel.cc
index d3e5e90fd17a37c61f07cc59255962fa094ec09e..07e6584b49f8dcf3481e6b26b7e9c0e9bbca1668 100644
--- a/paddle/phi/kernels/cpu/log_softmax_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/log_softmax_grad_kernel.cc
@@ -19,6 +19,7 @@
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -71,6 +72,11 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
   const int canonical_axis = funcs::CanonicalAxis(axis, rank);
 
   dev_ctx.template Alloc<T>(x_grad);
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
   if (out.numel() != 0) {
     LogSoftmaxGradFunctor<Context, T>()(
         dev_ctx, &out, &out_grad, x_grad, canonical_axis);
diff --git a/paddle/phi/kernels/cpu/log_softmax_kernel.cc b/paddle/phi/kernels/cpu/log_softmax_kernel.cc
index 0ba4aea78c3ca14b300bcf59025cffaa0cce9d77..1f4e5d9be462ee7ef56716aa0a3ca47ffcaf0ce2 100644
--- a/paddle/phi/kernels/cpu/log_softmax_kernel.cc
+++ b/paddle/phi/kernels/cpu/log_softmax_kernel.cc
@@ -19,6 +19,7 @@
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -109,6 +110,11 @@ void LogSoftmaxKernel(const Context& dev_ctx,
   const int canonical_axis = funcs::CanonicalAxis(axis, rank);
 
   dev_ctx.template Alloc<T>(out);
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 0.0);
+    return;
+  }
   if (x.numel() != 0) {
     LogSoftmaxFunctor<Context, T>()(dev_ctx, &x, out, canonical_axis);
   }
diff --git a/paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu b/paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu
index f7b282536558db524c082de11c7ca92b2bd98edc..f6a5b26960a62a93a61a963de9d44c55bd14448b 100644
--- a/paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/log_softmax_grad_kernel.cu
@@ -16,6 +16,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace phi {
 
@@ -27,6 +28,12 @@ void LogSoftmaxGradKernel(const Context &dev_ctx,
                           int axis,
                           DenseTensor *x_grad) {
   dev_ctx.template Alloc<T>(x_grad);
+  const int rank = out.dims().size();
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
   phi::SoftmaxBackwardCUDAKernelDriver<T, true>(
       dev_ctx, out, out_grad, axis, x_grad);
 }
diff --git a/paddle/phi/kernels/gpu/log_softmax_kernel.cu b/paddle/phi/kernels/gpu/log_softmax_kernel.cu
index d7e34c6c14e7a49f50c016d888f6fb875dca0776..6dfe3d2b6173d52824e818fffbdf61270516ca0c 100644
--- a/paddle/phi/kernels/gpu/log_softmax_kernel.cu
+++ b/paddle/phi/kernels/gpu/log_softmax_kernel.cu
@@ -16,6 +16,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace phi {
 
@@ -25,7 +26,14 @@ void LogSoftmaxKernel(const Context &dev_ctx,
                       const DenseTensor &x,
                       int axis,
                       DenseTensor *out) {
+  const int rank = x.dims().size();
+
   dev_ctx.template Alloc<T>(out);
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 0.0);
+    return;
+  }
   phi::SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, x, axis, out);
 }
 
diff --git a/paddle/phi/kernels/gpudnn/softmax_grad_kernel.cu b/paddle/phi/kernels/gpudnn/softmax_grad_kernel.cu
index 45ab645d3736734fb9ec4c6a7b949274c1f0a91e..72a5f37d140059d255d669409d322c5981240350 100644
--- a/paddle/phi/kernels/gpudnn/softmax_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/softmax_grad_kernel.cu
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace phi {
 
@@ -27,6 +28,14 @@ void SoftmaxGradGPUDNNKernel(const Context& dev_ctx,
                              int axis,
                              DenseTensor* x_grad) {
   dev_ctx.template Alloc<T>(x_grad);
+
+  const int rank = out.dims().size();
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
+
   SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx, out, out_grad, axis, x_grad);
 }
 
diff --git a/paddle/phi/kernels/gpudnn/softmax_kernel.cu b/paddle/phi/kernels/gpudnn/softmax_kernel.cu
index 37175c427ffe142c31b41c8356d160d203fd6d73..77ff99334d2a2159179ba460f7f07a9702597314 100644
--- a/paddle/phi/kernels/gpudnn/softmax_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/softmax_kernel.cu
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace phi {
 
@@ -26,6 +27,14 @@ void SoftmaxGPUDNNKernel(const Context& dev_ctx,
                          int axis,
                          DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
+
+  const int rank = x.dims().size();
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 1.0);
+    return;
+  }
+
   SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, x, axis, out);
 }
 
diff --git a/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h b/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h
index 3d57dd1002ac853093f089f8eaa7f78ac96de078..96ae00366e9134f32c089405f29766b10fbce11b 100644
--- a/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/gumbel_softmax_grad_kernel_impl.h
@@ -18,6 +18,7 @@
 #include "paddle/fluid/operators/math/softmax_impl.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -37,6 +38,12 @@ void GumbelSoftmaxGradKernel(const Context& ctx,
     return;
   }
 
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(ctx, dx, 0.0);
+    return;
+  }
+
   const int size_to_axis = funcs::SizeToAxis(axis, dx->dims());
   const int size_from_axis = funcs::SizeFromAxis(axis, dx->dims());
   DenseTensor dx_2d(*dx), out_2d(out), dout_2d(dout);
diff --git a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
index e310d4a1674bcb1c8c7136f70b88fc7c226e82f4..26dd121be2db6eb549268acc022e0b04b11096e9 100644
--- a/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
+++ b/paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
@@ -21,6 +21,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -66,6 +67,12 @@ void GumbelSoftmaxKernelHelper(const Context& ctx,
     return;
   }
 
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(ctx, out, 1.0);
+    return;
+  }
+
   const int size_to_axis = funcs::SizeToAxis(axis, x.dims());
   const int size_from_axis = funcs::SizeFromAxis(axis, x.dims());
   DenseTensor x_noise_2d, out_2d(*out);
diff --git a/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h b/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h
index 19df20c0d7cb6af17e3c85b24cf943a6c19c96f6..ef869195caf289c9b74271fca88653fa5e588edb 100644
--- a/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/softmax_grad_kernel.h"
 
 namespace phi {
 
@@ -26,16 +27,22 @@ void SoftmaxGradKernel(const Context& dev_ctx,
                        const DenseTensor& out_grad,
                        int axis,
                        DenseTensor* x_grad) {
-  const int rank = x_grad->dims().size();
-  const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
-  int axis_dim = x_grad->dims()[calc_axis];
-
-  // allocate memory on device.
   dev_ctx.template Alloc<T>(x_grad);
+
+  const int rank = x_grad->dims().size();
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
+
+  // For zero-sized Tensor
   if (x_grad->numel() == 0) {
     return;
   }
+
+  const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
+  int axis_dim = x_grad->dims()[calc_axis];
+
   const int n = phi::funcs::SizeToAxis(calc_axis, x_grad->dims());
   const int d = phi::funcs::SizeFromAxis(calc_axis, x_grad->dims());
   DenseTensor dX_2d, Out_2d, dOut_2d;
diff --git a/paddle/phi/kernels/impl/softmax_kernel_impl.h b/paddle/phi/kernels/impl/softmax_kernel_impl.h
index 5f7d097242003b67224a2f4d1f6d12e5799b056f..4114e1105191aab08617dc5d1e556bf9b091b975 100644
--- a/paddle/phi/kernels/impl/softmax_kernel_impl.h
+++ b/paddle/phi/kernels/impl/softmax_kernel_impl.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/softmax_kernel.h"
 
 namespace phi {
 
@@ -31,9 +32,15 @@ void SoftmaxKernel(const Context& dev_ctx,
 
   // allocate memory on device.
   dev_ctx.template Alloc<T>(out);
+  // For 0-Sized Tensor
   if (out->numel() == 0) {
     return;
   }
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 1.0);
+    return;
+  }
 
   const int n = phi::funcs::SizeToAxis(calc_axis, x.dims());
   const int d = phi::funcs::SizeFromAxis(calc_axis, x.dims());
diff --git a/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc b/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc
index 26f532f17b9dc2b3fb13c14f924c30a176a7e9a8..1f5d95c50f8e2a234f774014eee304471dcc763c 100644
--- a/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -30,6 +31,12 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
   const int rank = out.dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
 
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
+    return;
+  }
+
   if (out.numel() != 0) {
     auto out_shape = phi::vectorize(out.dims());
     dev_ctx.template Alloc<T>(x_grad);
diff --git a/paddle/phi/kernels/xpu/log_softmax_kernel.cc b/paddle/phi/kernels/xpu/log_softmax_kernel.cc
index 0250b08e50476b1362a7e886e9bc3dd39d690f34..a4feac7b2330716770ec97ec7147e27e2b12a22b 100644
--- a/paddle/phi/kernels/xpu/log_softmax_kernel.cc
+++ b/paddle/phi/kernels/xpu/log_softmax_kernel.cc
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -29,6 +30,11 @@ void LogSoftmaxKernel(const Context& dev_ctx,
   const int rank = x.dims().size();
   axis = funcs::CanonicalAxis(axis, rank);
 
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 0.0);
+    return;
+  }
   if (x.numel() != 0) {
     auto x_shape = phi::vectorize(x.dims());
     dev_ctx.template Alloc<T>(out);
diff --git a/paddle/phi/kernels/xpu/softmax_kernel.cc b/paddle/phi/kernels/xpu/softmax_kernel.cc
index 60b1c52ca5047f7d8e6cfd0266d14fe6902a7374..e8d8cd3cc77a3f9bc6f4e144a95313cc1ea3c212 100644
--- a/paddle/phi/kernels/xpu/softmax_kernel.cc
+++ b/paddle/phi/kernels/xpu/softmax_kernel.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -30,9 +31,15 @@ void SoftmaxKernel(const Context& dev_ctx,
 
   // allocate memory on device.
   dev_ctx.template Alloc<T>(out);
+  // For 0-Sized Tensor
   if (out->numel() == 0) {
     return;
   }
+  // For 0D Tensor
+  if (rank == 0) {
+    phi::funcs::set_constant(dev_ctx, out, 1.0);
+    return;
+  }
 
   std::vector<int> x_dims;
   for (int i = 0; i < rank; i++) {
diff --git a/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py b/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
index de7f5d46352368dcc94d8756d0d0039428656193..3494ccb5d16c3112dadd86130182a99ef04cfa70 100644
--- a/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
@@ -49,6 +49,24 @@ class TestGumbelSoftmaxOp(OpTest):
         self.check_grad(["X"], "Out")
 
 
+class TestGumbelSoftmax_ZeroDim(OpTest):
+    def setUp(self):
+        self.op_type = "gumbel_softmax"
+        self.dtype = "float64"
+        x = np.random.uniform(0.1, 1, []).astype(self.dtype)
+        out = np.array(1.0).astype(self.dtype)
+
+        self.inputs = {'X': x}
+        self.outputs = {'Out': out}
+        self.attrs = {"hard": True, "axis": -1}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
 class TestGumbelSoftmaxOp2(TestGumbelSoftmaxOp):
     def init_attrs(self):
         self.shape = [20, 10]
diff --git a/python/paddle/fluid/tests/unittests/test_log_softmax.py b/python/paddle/fluid/tests/unittests/test_log_softmax.py
index 4be4319d805f3fc534bb5f0aeb7f7126e04f0c06..ab371869c85c97e96fc272aba88cc85e1bcc3785 100644
--- a/python/paddle/fluid/tests/unittests/test_log_softmax.py
+++ b/python/paddle/fluid/tests/unittests/test_log_softmax.py
@@ -69,6 +69,26 @@ class TestLogSoftmaxOp(OpTest):
         )
 
 
+class TestLogSoftmaxOp_ZeroDim(TestLogSoftmaxOp):
+    def setUp(self):
+        self.op_type = 'log_softmax'
+        self.python_api = F.log_softmax
+        self.dtype = 'float64'
+
+        x = np.random.uniform(0.1, 1.0, []).astype(self.dtype)
+        out = np.array(0.0).astype(self.dtype)
+
+        self.inputs = {'X': x}
+        self.outputs = {'Out': out}
+        self.attrs = {'axis': -1}
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def test_check_grad(self):
+        self.check_grad(['X'], ['Out'], check_eager=True)
+
+
 class TestLogSoftmaxShape(TestLogSoftmaxOp):
     def set_attrs(self):
         self.shape = [12, 10]
diff --git a/python/paddle/fluid/tests/unittests/test_randint_op.py b/python/paddle/fluid/tests/unittests/test_randint_op.py
index 505cd43923a83f0690b13ffc78a3977bfb1a356a..dcb59a97d0f85901df7da9f0c47f128409208aa0 100644
--- a/python/paddle/fluid/tests/unittests/test_randint_op.py
+++ b/python/paddle/fluid/tests/unittests/test_randint_op.py
@@ -19,6 +19,7 @@ from op_test import OpTest
 from paddle.fluid import core
 from paddle.fluid.framework import _test_eager_guard
 from paddle.static import program_guard, Program
+import paddle.fluid as fluid
 
 paddle.enable_static()
 
@@ -239,5 +240,28 @@ class TestRandomValue(unittest.TestCase):
         np.testing.assert_array_equal(x[30, 2, 1000, 1000:1005], expect)
 
 
+# Test API shape
+class TestRandintAPI_ZeroDim(unittest.TestCase):
+    def test_dygraph(self):
+        paddle.disable_static()
+        x = paddle.randint(0, 2, [])
+        self.assertEqual(x.shape, [])
+        paddle.enable_static()
+
+    def test_static(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            x = paddle.randint(-10, 10, [])
+
+            # Test compile shape
+            self.assertEqual(x.shape, ())
+
+            # Test runtime shape
+            exe = fluid.Executor()
+            result = exe.run(fetch_list=[x])
+            self.assertEqual(result[0].shape, ())
+
+        paddle.enable_static()
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py
index 75324475f46de1b8308287fcc5263bb043ee72e9..c83f569cb11a1f2423fb806ac682e3b07507669c 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -17,6 +17,7 @@ import numpy as np
 from op_test import OpTest, convert_float_to_uint16
 import paddle.fluid.core as core
 import paddle
+import paddle.fluid as fluid
 import paddle.nn.functional as F
 
 np.random.seed(10)
@@ -103,6 +104,51 @@ class TestSoftmaxOp(OpTest):
         )
 
 
+class TestSoftmaxOp_ZeroDim1(TestSoftmaxOp):
+    def setUp(self):
+        self.op_type = "softmax"
+        self.use_cudnn = False
+        self.use_mkldnn = False
+        # explicitly use float32 for ROCm, as MIOpen does not yet support float64
+        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
+
+        np.random.seed(0)
+        x = np.random.uniform(0.1, 1, []).astype(self.dtype)
+        out = np.array(1.0).astype(self.dtype)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+        self.attrs = {
+            'axis': -1,
+            'use_cudnn': self.use_cudnn,
+            'use_mkldnn': self.use_mkldnn,
+        }
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestSoftmaxOp_ZeroDim2(TestSoftmaxOp):
+    def setUp(self):
+        self.op_type = "softmax"
+        self.use_cudnn = True
+        self.use_mkldnn = False
+        # explicitly use float32 for ROCm, as MIOpen does not yet support float64
+        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
+
+        np.random.seed(0)
+        x = np.random.uniform(0.1, 1, []).astype(self.dtype)
+        out = np.array(1.0).astype(self.dtype)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+        self.attrs = {
+            'axis': -1,
+            'use_cudnn': self.use_cudnn,
+            'use_mkldnn': self.use_mkldnn,
+        }
+
+
 class TestSoftmaxOp2(TestSoftmaxOp):
     def get_x_shape(self):
         return [2, 3, 4, 5]
@@ -442,6 +488,42 @@ class TestSoftmaxAPI(unittest.TestCase):
             self.softmax(x_fp16)
 
 
+class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
+    def test_dygraph(self):
+        paddle.disable_static()
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
+        x = paddle.rand([])
+        x.stop_gradient = False
+
+        out = paddle.nn.functional.softmax(x)
+        out.backward()
+        self.assertEqual(x.shape, [])
+        self.assertEqual(x.grad.shape, [])
+        self.assertEqual(out.shape, [])
+        self.assertEqual(out.grad.shape, [])
+
+        paddle.enable_static()
+
+    def test_static(self):
+        main_prog = fluid.Program()
+        with fluid.program_guard(main_prog, fluid.Program()):
+            x = paddle.rand([])
+            x.stop_gradient = False
+            out = paddle.nn.functional.softmax(x)
+            fluid.backward.append_backward(out)
+
+            # Test compile shape
+            self.assertEqual(x.shape, ())
+            self.assertEqual(out.shape, ())
+
+            exe = fluid.Executor()
+            result = exe.run(main_prog, fetch_list=[x, out])
+
+            # Test runtime shape
+            self.assertEqual(result[0].shape, ())
+            self.assertEqual(result[1].shape, ())
+
+
 class TestSoftmaxInplaceAPI(TestSoftmaxAPI):
     def executed_api(self):
         self.softmax = F.softmax_
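
Illustrative usage (not part of the patch): a minimal dygraph sketch of the behaviour exercised by the new unit tests above. Per the kernel changes, softmax and gumbel_softmax of a 0D tensor produce 1.0, log_softmax produces 0.0, the corresponding gradients stay 0D, and paddle.randint now accepts an empty shape.

import paddle
import paddle.nn.functional as F

paddle.disable_static()

# A 0D (scalar) tensor; rand([]) is assumed to return a 0D tensor here,
# matching the usage in TestSoftmaxAPI_ZeroDim above.
x = paddle.rand([])
x.stop_gradient = False

print(F.softmax(x))                    # 0D tensor holding 1.0
print(F.log_softmax(x))                # 0D tensor holding 0.0
print(F.gumbel_softmax(x))             # 0D tensor holding 1.0
print(paddle.randint(0, 2, []).shape)  # [] -> 0D output

out = F.softmax(x)
out.backward()
print(x.grad.shape)                    # [] -> the gradient is 0D as well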