未验证 提交 ac3b882f 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] support input 0D Tensor for softmax/log_softmax/gumbel_softmax (#47251)

上级 afd5a96b
......@@ -77,14 +77,6 @@ class RandintOp : public framework::OperatorWithKernel {
return;
}
PADDLE_ENFORCE_EQ(shape.empty(),
false,
platform::errors::InvalidArgument(
"if there is no Input(ShapeTensorList) and no "
"Input(ShapeTensor),the "
"attr(shape) information must "
"be set by Attr(shape)."));
std::vector<int64_t> tensor_shape;
tensor_shape.reserve(shape.size());
for (auto dim : shape) {
......
......@@ -112,11 +112,6 @@ void RandintInferMeta(
high));
auto& shape_vector = shape.GetData();
PADDLE_ENFORCE_EQ(
shape_vector.empty(),
false,
errors::InvalidArgument("The shape information should not be empty, it "
"must be set by Attr(shape)."));
std::vector<int64_t> tensor_shape;
tensor_shape.reserve(shape_vector.size());
......
......@@ -3108,16 +3108,29 @@ void SliceRawInferMeta(const MetaTensor& input,
void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out) {
auto dim_x = x.dims();
auto rank_x = dim_x.size();
PADDLE_ENFORCE_GE(axis,
-rank_x,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
PADDLE_ENFORCE_LT(axis,
rank_x,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
if (rank_x > 0) {
PADDLE_ENFORCE_GE(axis,
-rank_x,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
PADDLE_ENFORCE_LT(axis,
rank_x,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
} else {
PADDLE_ENFORCE_GE(
axis,
-1,
phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
"0] when input is 0D Tensor "));
PADDLE_ENFORCE_LE(
axis,
0,
phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
"0] when input is 0D Tensor "));
}
out->set_dims(x.dims());
out->set_dtype(x.dtype());
......@@ -3963,22 +3976,29 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x,
int axis,
MetaTensor* out) {
auto rank = x.dims().size();
PADDLE_ENFORCE_GE(
axis,
-rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X). But received axis: %d, R: %d.",
axis,
rank));
PADDLE_ENFORCE_LT(
axis,
rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X). But received axis: %d, R: %d.",
axis,
rank));
if (rank > 0) {
PADDLE_ENFORCE_GE(axis,
-rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
PADDLE_ENFORCE_LT(axis,
rank,
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
} else if (rank == 0) {
PADDLE_ENFORCE_GE(
axis,
-1,
phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
"0] when input is 0D Tensor "));
PADDLE_ENFORCE_LE(
axis,
0,
phi::errors::InvalidArgument("Attr(axis) value should be in range [-1, "
"0] when input is 0D Tensor "));
}
out->share_meta(x);
}
......
......@@ -19,6 +19,7 @@
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -71,6 +72,11 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
const int canonical_axis = funcs::CanonicalAxis(axis, rank);
dev_ctx.template Alloc<T>(x_grad);
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
return;
}
if (out.numel() != 0) {
LogSoftmaxGradFunctor<Context, T>()(
dev_ctx, &out, &out_grad, x_grad, canonical_axis);
......
......@@ -19,6 +19,7 @@
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -109,6 +110,11 @@ void LogSoftmaxKernel(const Context& dev_ctx,
const int canonical_axis = funcs::CanonicalAxis(axis, rank);
dev_ctx.template Alloc<T>(out);
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, out, 0.0);
return;
}
if (x.numel() != 0) {
LogSoftmaxFunctor<Context, T>()(dev_ctx, &x, out, canonical_axis);
}
......
......@@ -16,6 +16,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace phi {
......@@ -27,6 +28,12 @@ void LogSoftmaxGradKernel(const Context &dev_ctx,
int axis,
DenseTensor *x_grad) {
dev_ctx.template Alloc<T>(x_grad);
const int rank = out.dims().size();
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
return;
}
phi::SoftmaxBackwardCUDAKernelDriver<T, true>(
dev_ctx, out, out_grad, axis, x_grad);
}
......
......@@ -16,6 +16,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace phi {
......@@ -25,7 +26,14 @@ void LogSoftmaxKernel(const Context &dev_ctx,
const DenseTensor &x,
int axis,
DenseTensor *out) {
const int rank = x.dims().size();
dev_ctx.template Alloc<T>(out);
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, out, 0.0);
return;
}
phi::SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, x, axis, out);
}
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace phi {
......@@ -27,6 +28,14 @@ void SoftmaxGradGPUDNNKernel(const Context& dev_ctx,
int axis,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
const int rank = out.dims().size();
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
return;
}
SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx, out, out_grad, axis, x_grad);
}
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace phi {
......@@ -26,6 +27,14 @@ void SoftmaxGPUDNNKernel(const Context& dev_ctx,
int axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
const int rank = x.dims().size();
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, out, 1.0);
return;
}
SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, x, axis, out);
}
......
......@@ -18,6 +18,7 @@
#include "paddle/fluid/operators/math/softmax_impl.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -37,6 +38,12 @@ void GumbelSoftmaxGradKernel(const Context& ctx,
return;
}
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(ctx, dx, 0.0);
return;
}
const int size_to_axis = funcs::SizeToAxis(axis, dx->dims());
const int size_from_axis = funcs::SizeFromAxis(axis, dx->dims());
DenseTensor dx_2d(*dx), out_2d(out), dout_2d(dout);
......
......@@ -21,6 +21,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -66,6 +67,12 @@ void GumbelSoftmaxKernelHelper(const Context& ctx,
return;
}
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(ctx, out, 1.0);
return;
}
const int size_to_axis = funcs::SizeToAxis(axis, x.dims());
const int size_from_axis = funcs::SizeFromAxis(axis, x.dims());
DenseTensor x_noise_2d, out_2d(*out);
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/softmax_grad_kernel.h"
namespace phi {
......@@ -26,16 +27,22 @@ void SoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad) {
const int rank = x_grad->dims().size();
const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = x_grad->dims()[calc_axis];
// allocate memory on device.
dev_ctx.template Alloc<T>(x_grad);
const int rank = x_grad->dims().size();
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
return;
}
// For zero-sized Tensor
if (x_grad->numel() == 0) {
return;
}
const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = x_grad->dims()[calc_axis];
const int n = phi::funcs::SizeToAxis(calc_axis, x_grad->dims());
const int d = phi::funcs::SizeFromAxis(calc_axis, x_grad->dims());
DenseTensor dX_2d, Out_2d, dOut_2d;
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/softmax_kernel.h"
namespace phi {
......@@ -31,9 +32,15 @@ void SoftmaxKernel(const Context& dev_ctx,
// allocate memory on device.
dev_ctx.template Alloc<T>(out);
// For 0-Sized Tensor
if (out->numel() == 0) {
return;
}
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, out, 1.0);
return;
}
const int n = phi::funcs::SizeToAxis(calc_axis, x.dims());
const int d = phi::funcs::SizeFromAxis(calc_axis, x.dims());
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -30,6 +31,12 @@ void LogSoftmaxGradKernel(const Context& dev_ctx,
const int rank = out.dims().size();
axis = funcs::CanonicalAxis(axis, rank);
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, x_grad, 0.0);
return;
}
if (out.numel() != 0) {
auto out_shape = phi::vectorize<int>(out.dims());
dev_ctx.template Alloc<T>(x_grad);
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -29,6 +30,11 @@ void LogSoftmaxKernel(const Context& dev_ctx,
const int rank = x.dims().size();
axis = funcs::CanonicalAxis(axis, rank);
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, out, 0.0);
return;
}
if (x.numel() != 0) {
auto x_shape = phi::vectorize<int>(x.dims());
dev_ctx.template Alloc<T>(out);
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -30,9 +31,15 @@ void SoftmaxKernel(const Context& dev_ctx,
// allocate memory on device.
dev_ctx.template Alloc<T>(out);
// For 0-Sized Tensor
if (out->numel() == 0) {
return;
}
// For 0D Tensor
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, out, 1.0);
return;
}
std::vector<int> x_dims;
for (int i = 0; i < rank; i++) {
......
......@@ -49,6 +49,24 @@ class TestGumbelSoftmaxOp(OpTest):
self.check_grad(["X"], "Out")
class TestGumbelSoftmax_ZeroDim(OpTest):
def setUp(self):
self.op_type = "gumbel_softmax"
self.dtype = "float64"
x = np.random.uniform(0.1, 1, []).astype(self.dtype)
out = np.array(1.0).astype(self.dtype)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {"hard": True, "axis": -1}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestGumbelSoftmaxOp2(TestGumbelSoftmaxOp):
def init_attrs(self):
self.shape = [20, 10]
......
......@@ -69,6 +69,26 @@ class TestLogSoftmaxOp(OpTest):
)
class TestLogSoftmaxOp_ZeroDim(TestLogSoftmaxOp):
def setUp(self):
self.op_type = 'log_softmax'
self.python_api = F.log_softmax
self.dtype = 'float64'
x = np.random.uniform(0.1, 1.0, []).astype(self.dtype)
out = np.array(0.0).astype(self.dtype)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'axis': -1}
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], ['Out'], check_eager=True)
class TestLogSoftmaxShape(TestLogSoftmaxOp):
def set_attrs(self):
self.shape = [12, 10]
......
......@@ -19,6 +19,7 @@ from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.framework import _test_eager_guard
from paddle.static import program_guard, Program
import paddle.fluid as fluid
paddle.enable_static()
......@@ -239,5 +240,28 @@ class TestRandomValue(unittest.TestCase):
np.testing.assert_array_equal(x[30, 2, 1000, 1000:1005], expect)
# Test API shape
class TestRandintAPI_ZeroDim(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
x = paddle.randint(0, 2, [])
self.assertEqual(x.shape, [])
paddle.enable_static()
def test_static(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
x = paddle.randint(-10, 10, [])
# Test compile shape
self.assertEqual(x.shape, ())
# Test runtime shape
exe = fluid.Executor()
result = exe.run(fetch_list=[x])
self.assertEqual(result[0].shape, ())
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -17,6 +17,7 @@ import numpy as np
from op_test import OpTest, convert_float_to_uint16
import paddle.fluid.core as core
import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
np.random.seed(10)
......@@ -103,6 +104,51 @@ class TestSoftmaxOp(OpTest):
)
class TestSoftmaxOp_ZeroDim1(TestSoftmaxOp):
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = False
self.use_mkldnn = False
# explicilty use float32 for ROCm, as MIOpen does not yet support float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
np.random.seed(0)
x = np.random.uniform(0.1, 1, []).astype(self.dtype)
out = np.array(1.0).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {
'axis': -1,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
}
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestSoftmaxOp_ZeroDim2(TestSoftmaxOp):
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = True
self.use_mkldnn = False
# explicilty use float32 for ROCm, as MIOpen does not yet support float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
np.random.seed(0)
x = np.random.uniform(0.1, 1, []).astype(self.dtype)
out = np.array(1.0).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {
'axis': -1,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
}
class TestSoftmaxOp2(TestSoftmaxOp):
def get_x_shape(self):
return [2, 3, 4, 5]
......@@ -442,6 +488,42 @@ class TestSoftmaxAPI(unittest.TestCase):
self.softmax(x_fp16)
class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
x = paddle.rand([])
x.stop_gradient = False
out = paddle.nn.functional.softmax(x)
out.backward()
self.assertEqual(x.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
paddle.enable_static()
def test_static(self):
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.nn.functional.softmax(x)
fluid.backward.append_backward(out)
# Test compile shape
self.assertEqual(x.shape, ())
self.assertEqual(out.shape, ())
exe = fluid.Executor()
result = exe.run(main_prog, fetch_list=[x, out])
# Test runtime shape
self.assertEqual(result[0].shape, ())
self.assertEqual(result[1].shape, ())
class TestSoftmaxInplaceAPI(TestSoftmaxAPI):
def executed_api(self):
self.softmax = F.softmax_
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册