From cad2e68de2c8f3e405d753884bb1c47d74983e3b Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Wed, 2 Nov 2022 13:57:29 +0800 Subject: [PATCH] [Zero-Dim] support input 0D Tensor for some binary api (#46909) --- .../operators/common_infer_shape_functions.cc | 2 +- .../operators/elementwise/elementwise_npu.h | 2 +- paddle/phi/kernels/funcs/common_shape.h | 2 +- paddle/phi/kernels/funcs/elementwise_base.h | 4 +- .../phi/kernels/funcs/elementwise_grad_base.h | 4 +- paddle/phi/kernels/xpu/elementwise.h | 4 +- .../fluid/tests/unittests/test_bitwise_op.py | 60 +++++++- .../fluid/tests/unittests/test_compare_op.py | 48 ++++++ .../unittests/test_elementwise_add_op.py | 21 +++ .../unittests/test_elementwise_div_op.py | 36 +++++ .../unittests/test_elementwise_floordiv_op.py | 21 +++ .../unittests/test_elementwise_max_op.py | 30 ++++ .../unittests/test_elementwise_min_op.py | 30 ++++ .../unittests/test_elementwise_mod_op.py | 21 +++ .../unittests/test_elementwise_mul_op.py | 21 +++ .../unittests/test_elementwise_pow_op.py | 33 +++++ .../unittests/test_elementwise_sub_op.py | 30 ++++ .../fluid/tests/unittests/test_logical_op.py | 5 +- .../tests/unittests/test_zero_dim_shape.py | 140 ++++++++++++++++++ 19 files changed, 503 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/common_infer_shape_functions.cc b/paddle/fluid/operators/common_infer_shape_functions.cc index 446a24a08b8..9dce94d16b4 100644 --- a/paddle/fluid/operators/common_infer_shape_functions.cc +++ b/paddle/fluid/operators/common_infer_shape_functions.cc @@ -40,7 +40,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, platform::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LT(axis, + PADDLE_ENFORCE_LE(axis, max_dim, platform::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", diff --git a/paddle/fluid/operators/elementwise/elementwise_npu.h b/paddle/fluid/operators/elementwise/elementwise_npu.h index 5266491d6f5..45e5a548f91 100644 --- a/paddle/fluid/operators/elementwise/elementwise_npu.h +++ b/paddle/fluid/operators/elementwise/elementwise_npu.h @@ -123,7 +123,7 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx, platform::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LT(axis, + PADDLE_ENFORCE_LE(axis, max_dim, platform::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index 2daf8ab0bd9..01b06120965 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -45,7 +45,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LT(axis, + PADDLE_ENFORCE_LE(axis, max_dim, phi::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 100d2dcd612..29da6174138 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -326,7 +326,7 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx, phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LT(axis, + PADDLE_ENFORCE_LE(axis, max_dim, phi::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", @@ -394,7 +394,7 @@ void ElementwiseCompute(const CPUContext &dev_ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LT(axis, + PADDLE_ENFORCE_LE(axis, max_dim, errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h index 62889b530af..e52c669c48d 100644 --- a/paddle/phi/kernels/funcs/elementwise_grad_base.h +++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h @@ -287,7 +287,7 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LT(axis, + PADDLE_ENFORCE_LE(axis, max_dim, errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", @@ -1725,7 +1725,7 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LT(axis, + PADDLE_ENFORCE_LE(axis, max_dim, errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", diff --git a/paddle/phi/kernels/xpu/elementwise.h b/paddle/phi/kernels/xpu/elementwise.h index 7c0c2a3497b..dfaaae59bb3 100644 --- a/paddle/phi/kernels/xpu/elementwise.h +++ b/paddle/phi/kernels/xpu/elementwise.h @@ -51,7 +51,7 @@ void XPUElementwise(const XPUContext& dev_ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LT(axis, + PADDLE_ENFORCE_LE(axis, max_dim, errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", @@ -121,7 +121,7 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LT(axis, + PADDLE_ENFORCE_LE(axis, max_dim, errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", diff --git a/python/paddle/fluid/tests/unittests/test_bitwise_op.py b/python/paddle/fluid/tests/unittests/test_bitwise_op.py index 6a7b039380b..bcd94392446 100644 --- a/python/paddle/fluid/tests/unittests/test_bitwise_op.py +++ b/python/paddle/fluid/tests/unittests/test_bitwise_op.py @@ -57,6 +57,24 @@ class TestBitwiseAnd(OpTest): self.high = 100 +class TestBitwiseAnd_ZeroDim1(TestBitwiseAnd): + def init_shape(self): + self.x_shape = [] + self.y_shape = [] + + +class TestBitwiseAnd_ZeroDim2(TestBitwiseAnd): + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [] + + +class TestBitwiseAnd_ZeroDim3(TestBitwiseAnd): + def init_shape(self): + self.x_shape = [] + self.y_shape = [2, 3, 4, 5] + + class TestBitwiseAndUInt8(TestBitwiseAnd): def init_dtype(self): self.dtype = np.uint8 @@ -143,6 +161,24 @@ class TestBitwiseOr(OpTest): self.high = 100 +class TestBitwiseOr_ZeroDim1(TestBitwiseOr): + def init_shape(self): + self.x_shape = [] + self.y_shape = [] + + +class TestBitwiseOr_ZeroDim2(TestBitwiseOr): + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [] + + +class TestBitwiseOr_ZeroDim3(TestBitwiseOr): + def init_shape(self): + self.x_shape = [] + self.y_shape = [2, 3, 4, 5] + + class TestBitwiseOrUInt8(TestBitwiseOr): def init_dtype(self): self.dtype = np.uint8 @@ -229,6 +265,24 @@ class TestBitwiseXor(OpTest): self.high = 100 +class TestBitwiseXor_ZeroDim1(TestBitwiseXor): + def init_shape(self): + self.x_shape = [] + self.y_shape = [] + + +class TestBitwiseXor_ZeroDim2(TestBitwiseXor): + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [] + + +class TestBitwiseXor_ZeroDim3(TestBitwiseXor): + def init_shape(self): + self.x_shape = [] + self.y_shape = [2, 3, 4, 5] + + class TestBitwiseXorUInt8(TestBitwiseXor): def init_dtype(self): self.dtype = np.uint8 @@ -311,6 +365,11 @@ class TestBitwiseNot(OpTest): self.high = 100 +class TestBitwiseNot_ZeroDim(TestBitwiseNot): + def init_shape(self): + self.x_shape = [] + + class TestBitwiseNotUInt8(TestBitwiseNot): def init_dtype(self): self.dtype = np.uint8 @@ -334,7 +393,6 @@ class TestBitwiseNotInt16(TestBitwiseNot): def init_shape(self): self.x_shape = [2, 3, 4, 5] - self.y_shape = [4, 1] class TestBitwiseNotInt64(TestBitwiseNot): diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index d9636972a13..c5b69f8c59a 100755 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -283,6 +283,54 @@ def create_paddle_case(op_type, callback): self.assertEqual((out.numpy() == self.real_result).all(), True) paddle.enable_static() + def test_zero_dim_api_1(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.randint(-3, 3, shape=[], dtype='int32') + y = paddle.randint(-3, 3, shape=[], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + ( + x_np, + y_np, + res, + ) = exe.run(fetch_list=[x, y, out]) + real_result = callback(x_np, y_np) + self.assertEqual((res == real_result).all(), True) + + def test_zero_dim_api_2(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') + y = paddle.randint(-3, 3, shape=[], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + ( + x_np, + y_np, + res, + ) = exe.run(fetch_list=[x, y, out]) + real_result = callback(x_np, y_np) + self.assertEqual((res == real_result).all(), True) + + def test_zero_dim_api_3(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.randint(-3, 3, shape=[], dtype='int32') + y = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + ( + x_np, + y_np, + res, + ) = exe.run(fetch_list=[x, y, out]) + real_result = callback(x_np, y_np) + self.assertEqual((res == real_result).all(), True) + def test_broadcast_api_1(self): paddle.enable_static() with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 6bfd14dc841..d9057ee4ca6 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -102,6 +102,27 @@ class TestElementwiseAddOp(OpTest): self.axis = -1 +class TestElementwiseAddOp_ZeroDim1(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, []).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, []).astype(self.dtype) + self.out = np.add(self.x, self.y) + + +class TestElementwiseAddOp_ZeroDim2(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, []).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.add(self.x, self.y) + + +class TestElementwiseAddOp_ZeroDim3(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, []).astype(self.dtype) + self.out = np.add(self.x, self.y) + + @unittest.skipIf( not core.is_compiled_with_cuda(), "core is not compiled with CUDA" ) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py index 16bf0df5af3..7a0c5d09fbf 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py @@ -112,6 +112,42 @@ class ElementwiseDivOp(OpTest): self.check_grad_with_place(*check_args, **check_kwargs) +class TestElementwiseDivOp_ZeroDim1(ElementwiseDivOp): + def init_shape(self): + self.x_shape = [] + self.y_shape = [] + + +class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp): + def init_shape(self): + self.x_shape = [13, 17] + self.y_shape = [] + + def compute_output(self, x, y): + return x / y.reshape([1, 1]) + + def compute_gradient_x(self, grad_out, y): + return grad_out / y.reshape([1, 1]) + + def compute_gradient_y(self, grad_out, out, y): + return np.sum(-1 * grad_out * out / y.reshape([1, 1])) + + +class TestElementwiseDivOp_ZeroDim3(ElementwiseDivOp): + def init_shape(self): + self.x_shape = [] + self.y_shape = [13, 17] + + def compute_output(self, x, y): + return x.reshape([1, 1]) / y + + def compute_gradient_x(self, grad_out, y): + return np.sum(grad_out / y) + + def compute_gradient_y(self, grad_out, out, y): + return -1 * grad_out * out / y + + @unittest.skipIf( not core.is_compiled_with_cuda() or not core.is_bfloat16_supported(core.CUDAPlace(0)), diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py index de058ed2b3b..022d5929f1b 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py @@ -57,6 +57,27 @@ class TestElementwiseModOp(OpTest): pass +class TestElementwiseFloorDivOp_ZeroDim1(TestElementwiseModOp): + def init_input_output(self): + self.x = np.random.uniform(0, 10000, []).astype(self.dtype) + self.y = np.random.uniform(0, 1000, []).astype(self.dtype) + self.out = np.floor_divide(self.x, self.y) + + +class TestElementwiseFloorDivOp_ZeroDim2(TestElementwiseModOp): + def init_input_output(self): + self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype) + self.y = np.random.uniform(0, 1000, []).astype(self.dtype) + self.out = np.floor_divide(self.x, self.y) + + +class TestElementwiseFloorDivOp_ZeroDim3(TestElementwiseModOp): + def init_input_output(self): + self.x = np.random.uniform(0, 10000, []).astype(self.dtype) + self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype) + self.out = np.floor_divide(self.x, self.y) + + class TestElementwiseModOp_scalar(TestElementwiseModOp): def init_input_output(self): scale_x = random.randint(0, 100000000) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py index 018a44c2be9..671b5a942b8 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py @@ -55,6 +55,36 @@ class TestElementwiseOp(OpTest): ) +class TestElementwiseMaxOp_ZeroDim1(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_max" + self.python_api = paddle.maximum + x = np.random.uniform(0.1, 1, []).astype("float64") + y = np.random.uniform(0.1, 1, []).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])} + + +class TestElementwiseMaxOp_ZeroDim2(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_max" + self.python_api = paddle.maximum + x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + y = np.random.uniform(0.1, 1, []).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])} + + +class TestElementwiseMaxOp_ZeroDim3(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_max" + self.python_api = paddle.maximum + x = np.random.uniform(0.1, 1, []).astype("float64") + y = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])} + + @unittest.skipIf( core.is_compiled_with_cuda() and ( diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py index 5a2cdc691fa..1fe78b79fb0 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py @@ -58,6 +58,36 @@ class TestElementwiseOp(OpTest): ) +class TestElementwiseMinOp_ZeroDim1(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_min" + self.python_api = paddle.minimum + x = np.random.uniform(0.1, 1, []).astype("float64") + y = np.random.uniform(0.1, 1, []).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} + + +class TestElementwiseMinOp_ZeroDim2(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_min" + self.python_api = paddle.minimum + x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + y = np.random.uniform(0.1, 1, []).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} + + +class TestElementwiseMinOp_ZeroDim3(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_min" + self.python_api = paddle.minimum + x = np.random.uniform(0.1, 1, []).astype("float64") + y = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} + + @skip_check_grad_ci( reason="[skip shape check] Use y_shape(1) to test broadcast." ) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py index 8969d76ce51..9c9d2d91209 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py @@ -59,6 +59,27 @@ class TestElementwiseModOp(OpTest): pass +class TestElementwiseModOp_ZeroDim1(TestElementwiseModOp): + def init_input_output(self): + self.x = np.random.uniform(0, 10000, []).astype(self.dtype) + self.y = np.random.uniform(0, 1000, []).astype(self.dtype) + self.out = np.mod(self.x, self.y) + + +class TestElementwiseModOp_ZeroDim2(TestElementwiseModOp): + def init_input_output(self): + self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype) + self.y = np.random.uniform(0, 1000, []).astype(self.dtype) + self.out = np.mod(self.x, self.y) + + +class TestElementwiseModOp_ZeroDim3(TestElementwiseModOp): + def init_input_output(self): + self.x = np.random.uniform(0, 10000, []).astype(self.dtype) + self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype) + self.out = np.mod(self.x, self.y) + + class TestElementwiseModOp_scalar(TestElementwiseModOp): def init_input_output(self): scale_x = random.randint(0, 100000000) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index 987a17ff1f5..263fb8a9981 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -85,6 +85,27 @@ class ElementwiseMulOp(OpTest): pass +class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, []).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, []).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + +class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, []).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + +class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, []).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + class TestBF16ElementwiseMulOp(OpTest): def setUp(self): self.op_type = "elementwise_mul" diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py index 53cb18f8aa3..1d53dbdb2fa 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py @@ -48,6 +48,39 @@ class TestElementwisePowOp(OpTest): self.check_grad(['X', 'Y'], 'Out', check_eager=True) +class TestElementwisePowOp_ZeroDim1(TestElementwisePowOp): + def setUp(self): + self.op_type = "elementwise_pow" + self.python_api = paddle.pow + self.inputs = { + 'X': np.random.uniform(1, 2, []).astype("float64"), + 'Y': np.random.uniform(1, 2, []).astype("float64"), + } + self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + + +class TestElementwisePowOp_ZeroDim2(TestElementwisePowOp): + def setUp(self): + self.op_type = "elementwise_pow" + self.python_api = paddle.pow + self.inputs = { + 'X': np.random.uniform(1, 2, [20, 5]).astype("float64"), + 'Y': np.random.uniform(1, 2, []).astype("float64"), + } + self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + + +class TestElementwisePowOp_ZeroDim3(TestElementwisePowOp): + def setUp(self): + self.op_type = "elementwise_pow" + self.python_api = paddle.pow + self.inputs = { + 'X': np.random.uniform(1, 2, []).astype("float64"), + 'Y': np.random.uniform(1, 2, [20, 5]).astype("float64"), + } + self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + + class TestElementwisePowOp_big_shape_1(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py index f8f050d6f6b..d89b3b22aa3 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py @@ -46,6 +46,36 @@ class TestElementwiseOp(OpTest): ) +class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.uniform(0.1, 1, []).astype("float64"), + 'Y': np.random.uniform(0.1, 1, []).astype("float64"), + } + self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + + +class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float64"), + 'Y': np.random.uniform(0.1, 1, []).astype("float64"), + } + self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + + +class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.uniform(0.1, 1, []).astype("float64"), + 'Y': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float64"), + } + self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + + class TestBF16ElementwiseOp(OpTest): def setUp(self): self.op_type = "elementwise_sub" diff --git a/python/paddle/fluid/tests/unittests/test_logical_op.py b/python/paddle/fluid/tests/unittests/test_logical_op.py index 31490961c84..c05d99a4d94 100755 --- a/python/paddle/fluid/tests/unittests/test_logical_op.py +++ b/python/paddle/fluid/tests/unittests/test_logical_op.py @@ -50,6 +50,9 @@ TEST_META_SHAPE_DATA = { 'Axis1InLargerDim': {'x_shape': [1, 4, 5], 'y_shape': [2, 3, 1, 5]}, 'EqualDim1': {'x_shape': [10, 7], 'y_shape': [10, 7]}, 'EqualDim2': {'x_shape': [1, 1, 4, 5], 'y_shape': [2, 3, 1, 5]}, + 'ZeroDim1': {'x_shape': [], 'y_shape': []}, + 'ZeroDim2': {'x_shape': [2, 3, 4, 5], 'y_shape': []}, + 'ZeroDim3': {'x_shape': [], 'y_shape': [2, 3, 4, 5]}, } TEST_META_WRONG_SHAPE_DATA = { @@ -116,7 +119,7 @@ def np_data_generator(np_shape, dtype, *args, **kwargs): if dtype == bool: return np.random.choice(a=[True, False], size=np_shape).astype(bool) else: - return np.random.randn(*np_shape).astype(dtype) + return np.random.normal(0, 1, np_shape).astype(dtype) def test(unit_test, use_gpu=False, test_error=False): diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py b/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py index 0cab423aa7b..90173712d42 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py @@ -210,5 +210,145 @@ class TestReduceAPI(unittest.TestCase): paddle.disable_static() +binary_api_list = [ + {'func': paddle.add, 'cls_method': '__add__'}, + {'func': paddle.subtract, 'cls_method': '__sub__'}, + {'func': paddle.multiply, 'cls_method': '__mul__'}, + {'func': paddle.divide, 'cls_method': '__div__'}, + {'func': paddle.subtract, 'cls_method': '__sub__'}, + paddle.pow, +] + +binary_api_list_without_grad = [ + {'func': paddle.add, 'cls_method': '__add__'}, + {'func': paddle.subtract, 'cls_method': '__sub__'}, + {'func': paddle.multiply, 'cls_method': '__mul__'}, + {'func': paddle.divide, 'cls_method': '__div__'}, + {'func': paddle.subtract, 'cls_method': '__sub__'}, + paddle.pow, + {'func': paddle.mod, 'cls_method': '__mod__'}, + paddle.floor_mod, + paddle.remainder, + {'func': paddle.equal, 'cls_method': '__eq__'}, + {'func': paddle.not_equal, 'cls_method': '__ne__'}, + {'func': paddle.greater_equal, 'cls_method': '__ge__'}, + {'func': paddle.greater_than, 'cls_method': '__gt__'}, + {'func': paddle.less_equal, 'cls_method': '__le__'}, + {'func': paddle.less_than, 'cls_method': '__lt__'}, + paddle.logical_and, + paddle.logical_or, + paddle.logical_xor, +] + + +class TestBinaryAPI(unittest.TestCase): + def test_dygraph_binary(self): + paddle.disable_static() + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + for api in binary_api_list + binary_api_list_without_grad: + # 1) x/y is 0D + x = paddle.rand([]) + y = paddle.rand([]) + x.stop_gradient = False + y.stop_gradient = False + if isinstance(api, dict): + out = api['func'](x, y) + out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y) + np.testing.assert_array_equal(out_cls.numpy(), out.numpy()) + else: + out = api(x, y) + + self.assertEqual(x.shape, []) + self.assertEqual(y.shape, []) + self.assertEqual(out.shape, []) + + if api not in binary_api_list_without_grad: + out.backward() + self.assertEqual(x.grad.shape, []) + self.assertEqual(y.grad.shape, []) + self.assertEqual(out.grad.shape, []) + + # 2) x is not 0D , y is 0D + x = paddle.rand([2, 3, 4]) + y = paddle.rand([]) + x.stop_gradient = False + y.stop_gradient = False + if isinstance(api, dict): + out = api['func'](x, y) + out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y) + np.testing.assert_array_equal(out_cls.numpy(), out.numpy()) + else: + out = api(x, y) + + self.assertEqual(x.shape, [2, 3, 4]) + self.assertEqual(y.shape, []) + self.assertEqual(out.shape, [2, 3, 4]) + + if api not in binary_api_list_without_grad: + out.backward() + self.assertEqual(x.grad.shape, [2, 3, 4]) + self.assertEqual(y.grad.shape, []) + self.assertEqual(out.grad.shape, [2, 3, 4]) + + # 3) x is 0D , y is not 0D + x = paddle.rand([]) + y = paddle.rand([2, 3, 4]) + x.stop_gradient = False + y.stop_gradient = False + if isinstance(api, dict): + out = api['func'](x, y) + out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y) + np.testing.assert_array_equal(out_cls.numpy(), out.numpy()) + else: + out = api(x, y) + out.backward() + + self.assertEqual(x.shape, []) + self.assertEqual(y.shape, [2, 3, 4]) + self.assertEqual(out.shape, [2, 3, 4]) + + if api not in binary_api_list_without_grad: + out.backward() + self.assertEqual(x.grad.shape, []) + self.assertEqual(y.grad.shape, [2, 3, 4]) + self.assertEqual(out.grad.shape, [2, 3, 4]) + + paddle.enable_static() + + def test_static_unary(self): + paddle.enable_static() + for api in binary_api_list: + main_prog = fluid.Program() + with fluid.program_guard(main_prog, fluid.Program()): + x = paddle.rand([]) + y = paddle.rand([]) + x.stop_gradient = False + y.stop_gradient = False + if isinstance(api, dict): + out = api['func'](x, y) + else: + out = api(x, y) + fluid.backward.append_backward(out) + + # append_backward always set grad shape to [1] + prog = paddle.static.default_main_program() + block = prog.global_block() + + # Test compile shape + self.assertEqual(x.shape, ()) + self.assertEqual(y.shape, ()) + self.assertEqual(out.shape, ()) + + exe = fluid.Executor() + result = exe.run(main_prog, fetch_list=[x, y, out]) + + # Test runtime shape + self.assertEqual(result[0].shape, ()) + self.assertEqual(result[1].shape, ()) + self.assertEqual(result[2].shape, ()) + + paddle.disable_static() + + if __name__ == "__main__": unittest.main() -- GitLab