Unverified commit ca59c72b authored by zhouweiwei2014 and committed by GitHub

[Zero-Dim] add 0D test case (#54581)

Parent 49a45f71
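For context: a 0-D (zero-dimensional) tensor, the subject of this PR, is a scalar-valued tensor whose shape is the empty list [], as opposed to a 1-D tensor of shape [1]. A minimal illustration (not part of this commit):

import paddle

x = paddle.rand([])   # 0-D: a scalar tensor
print(x.shape)        # []
print(x.item())       # one Python float -- a 0-D tensor holds exactly one element

y = paddle.rand([1])  # 1-D with a single element: a different shape
print(y.shape)        # [1]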
......@@ -2162,32 +2162,32 @@ void LogspaceInferMeta(const MetaTensor& start,
                       MetaTensor* out) {
  auto s_dims = start.dims();
  PADDLE_ENFORCE_EQ(
-      (s_dims.size() == 1) && (s_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Start) must be [1],"
-                                   "but received input shape is [%s].",
-                                   s_dims));
+      phi::product(s_dims),
+      1,
+      phi::errors::InvalidArgument("The size of Input(Start) must be 1, "
+                                   "but received input size is %s.",
+                                   phi::product(s_dims)));
  auto e_dims = stop.dims();
  PADDLE_ENFORCE_EQ(
-      (e_dims.size() == 1) && (e_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Stop) must be [1],"
-                                   "but received input shape is [%s].",
-                                   e_dims));
+      phi::product(e_dims),
+      1,
+      phi::errors::InvalidArgument("The size of Input(Stop) must be 1, "
+                                   "but received input size is %s.",
+                                   phi::product(e_dims)));
  auto num_dims = number.dims();
  PADDLE_ENFORCE_EQ(
-      (num_dims.size() == 1) && (num_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Num) must be [1],"
-                                   "but received input shape is [%s].",
-                                   num_dims));
+      phi::product(num_dims),
+      1,
+      phi::errors::InvalidArgument("The size of Input(Num) must be 1, "
+                                   "but received input size is %s.",
+                                   phi::product(num_dims)));
  auto b_dims = base.dims();
-  PADDLE_ENFORCE_EQ(
-      (b_dims.size() == 1) && (b_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Base) must be [1],"
-                                   "but received input shape is [%s].",
-                                   b_dims));
+  PADDLE_ENFORCE_EQ(phi::product(b_dims),
+                    1,
+                    phi::errors::InvalidArgument(
+                        "The size of Input(Base) must be 1, "
+                        "but received input size is %s.",
+                        phi::product(b_dims)));
  out->set_dims(phi::make_ddim({-1}));
  out->set_dtype(dtype);
}
......
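The check above changes from requiring shape exactly [1] to requiring an element count of 1, which both a shape-[] (0-D) and a shape-[1] tensor satisfy: the product over an empty shape is 1. A quick Python analogue of phi::product (illustrative only):

import math

print(math.prod([]))   # 1 -- a 0-D tensor (empty shape) still has one element
print(math.prod([1]))  # 1 -- a shape-[1] tensor passes the same size check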
......@@ -394,15 +394,15 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None):
     Args:
         start(int|float|Tensor): The input :attr:`start` is the exponent of the first entry in \
-            the sequence. It is a scalar, or a Tensor of shape [1] with input data \
+            the sequence. It is a scalar, or a 0-D Tensor of shape [] with input data \
             type int32, int64, float32 or float64.
         stop(int|float|Tensor): The input :attr:`stop` is the exponent of the last entry in the \
-            sequence. It is a scalar, or a Tensor of shape [1] with input data \
+            sequence. It is a scalar, or a 0-D Tensor of shape [] with input data \
             type int32, int64, float32 or float64.
         num(int|Tensor): The input :attr:`num` is the given number of items in the sequence. \
-            It is an int scalar, or a Tensor of shape [1] with data type int32.
+            It is an int scalar, or a 0-D Tensor of shape [] with data type int32.
         base(int|float|Tensor): The input :attr:`base` is the base of the logarithm function. \
-            It is a scalar, or a Tensor of shape [1] with input data type int32, int64, \
+            It is a scalar, or a 0-D Tensor of shape [] with input data type int32, int64, \
             float32 or float64.
         dtype(np.dtype|str, optional): The data type of the output tensor; it can be \
             int32, int64, float32 or float64. Default: if None, the data type is float32. \
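A usage sketch of the relaxed signature, mirroring the new test added below; the expected values assume base 2 with exponents running from 1 to 3:

import paddle

start = paddle.full([], 1.0)       # 0-D tensors are now accepted for all four inputs
stop = paddle.full([], 3.0)
num = paddle.full([], 5, 'int32')
base = paddle.full([], 2.0)
out = paddle.logspace(start, stop, num, base)
print(out.shape)    # [5]
print(out.numpy())  # [2., 2.83, 4., 5.66, 8.] -- 2**1 through 2**3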
......@@ -1617,7 +1617,7 @@ def count_nonzero(x, axis=None, keepdim=False, name=None):
             # x is a 2-D Tensor:
             x = paddle.to_tensor([[0., 1.1, 1.2], [0., 0., 1.3], [0., 0., 0.]])
             out1 = paddle.count_nonzero(x)
-            # [3]
+            # 3
             out2 = paddle.count_nonzero(x, axis=0)
             # [0, 1, 2]
             out3 = paddle.count_nonzero(x, axis=0, keepdim=True)
......@@ -1638,17 +1638,8 @@ def count_nonzero(x, axis=None, keepdim=False, name=None):
             # [1, 3, 5]
     """
-    if axis is not None:
-        if isinstance(axis, int):
-            axis = [axis]
-        dims = len(x.shape)
-        for i in range(len(axis)):
-            if not isinstance(axis[i], int) or not (
-                axis[i] < dims and axis[i] >= -dims
-            ):
-                raise ValueError(
-                    "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
-                )
     bool_tensor = paddle.cast(x, 'bool')
     int_tensor = paddle.cast(bool_tensor, 'int64')
......
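The docstring edits above track the new reduction semantics: a full reduction now returns a 0-D tensor instead of a shape-[1] tensor. For example:

import paddle

x = paddle.to_tensor([[0., 1.1, 1.2], [0., 0., 1.3], [0., 0., 0.]])
print(paddle.count_nonzero(x).shape)          # [] -- full reduce is now 0-D
print(paddle.count_nonzero(x, axis=0).shape)  # [3] -- per-axis reduce keeps a dim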
......@@ -28,6 +28,14 @@ import paddle.nn.functional as F
unary_api_list = [
    paddle.nn.functional.elu,
    paddle.nn.functional.rrelu,
    paddle.frac,
    paddle.sgn,
    paddle.nan_to_num,
    paddle.i0,
    paddle.i0e,
    paddle.i1,
    paddle.i1e,
    paddle.nn.functional.gelu,
    paddle.nn.functional.hardsigmoid,
    paddle.nn.functional.hardswish,
......@@ -95,9 +103,15 @@ unary_api_list = [
    paddle.nn.functional.alpha_dropout,
]

-inplace_api_list = [
+inplace_unary_api_list = [
    paddle.nn.functional.relu_,
    paddle.nn.functional.tanh_,
    paddle.tensor.sigmoid_,
    paddle.tensor.ceil_,
    paddle.tensor.floor_,
    paddle.tensor.reciprocal_,
    paddle.tensor.exp_,
    paddle.tensor.sqrt_,
]
......@@ -119,7 +133,7 @@ class TestUnaryAPI(unittest.TestCase):
            self.assertEqual(x.grad.shape, [])
            self.assertEqual(out.grad.shape, [])

-        for api in inplace_api_list:
+        for api in inplace_unary_api_list:
            x = paddle.rand([])
            out = api(x)
            self.assertEqual(x.shape, [])
......@@ -173,6 +187,7 @@ reduce_api_list = [
    paddle.logsumexp,
    paddle.all,
    paddle.any,
    paddle.count_nonzero,
]
......@@ -194,6 +209,7 @@ class TestReduceAPI(unittest.TestCase):
            self.assertEqual(x.shape, [])
            self.assertEqual(out.shape, [])
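            # count_nonzero returns the number of nonzero entries rather than
            # x itself, so the elementwise value check is skipped for it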
            if api not in [paddle.count_nonzero]:
                np.testing.assert_allclose(out.numpy(), x.numpy())

            out_empty_list = api(x, [])
......@@ -286,6 +302,7 @@ class TestReduceAPI(unittest.TestCase):
            res = exe.run(main_prog, fetch_list=fetch_list)
            self.assertEqual(res[0].shape, ())
            self.assertEqual(res[1].shape, ())
            if api not in [paddle.count_nonzero]:
                np.testing.assert_allclose(res[0], res[1])

            if len(res) > 2:
......@@ -359,6 +376,11 @@ binary_api_list = [
    paddle.fmin,
    paddle.complex,
    paddle.kron,
    paddle.logaddexp,
    paddle.nextafter,
    paddle.ldexp,
    paddle.polar,
    paddle.heaviside,
]

binary_int_api_list = [
......@@ -370,6 +392,15 @@ binary_int_api_list = [
]

inplace_binary_api_list = [
    paddle.tensor.add_,
    paddle.tensor.subtract_,
    paddle.tensor.multiply_,
    paddle.tensor.remainder_,
]
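The in-place binary cases below run under paddle.no_grad(); a minimal sketch of the broadcasting rule they exercise (illustrative only):

import paddle

x = paddle.rand([])
y = paddle.rand([])
print(paddle.add(x, y).shape)  # [] -- 0-D with 0-D stays 0-D

a = paddle.rand([3, 5])
print(paddle.add(a, y).shape)  # [3, 5] -- a 0-D operand broadcasts like a scalar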
# Used to test the zero-dim behavior of binary APIs
class TestBinaryAPI(unittest.TestCase):
    def test_dygraph_binary(self):
......@@ -497,6 +528,20 @@ class TestBinaryAPI(unittest.TestCase):
            self.assertEqual(out.shape, [3, 5])
            np.testing.assert_array_equal(out.numpy(), out_np)

        for api in inplace_binary_api_list:
            with paddle.no_grad():
                x = paddle.rand([])
                y = paddle.rand([])
                out = api(x, y)
                self.assertEqual(x.shape, [])
                self.assertEqual(out.shape, [])

                x = paddle.rand([3, 5])
                y = paddle.rand([])
                out = api(x, y)
                self.assertEqual(x.shape, [3, 5])
                self.assertEqual(out.shape, [3, 5])

        paddle.enable_static()

    def test_static_binary(self):
......@@ -640,6 +685,65 @@ class TestSundryAPI(unittest.TestCase):
        paddle.disable_static()
        self.x = paddle.rand([])

    def test_polygamma(self):
        x = paddle.rand([])
        x.stop_gradient = False
        out = paddle.polygamma(x, 2)
        out.backward()

        self.assertEqual(out.shape, [])
        self.assertEqual(x.grad.shape, [])

    def test_frexp(self):
        x = paddle.rand([])
        x.stop_gradient = False
        out1, out2 = paddle.frexp(x)
        out1.backward()

        self.assertEqual(out1.shape, [])
        self.assertEqual(out2.shape, [])
        self.assertEqual(x.grad.shape, [])

    def test_pairwise_distance(self):
        x = paddle.rand([5])
        x.stop_gradient = False
        y = paddle.rand([5])
        y.stop_gradient = False

        out = paddle.nn.functional.pairwise_distance(x, y)
        out.backward()
        self.assertEqual(out.shape, [])
        self.assertEqual(x.grad.shape, [5])

    def test_take(self):
        x = paddle.rand([4, 5])
        x.stop_gradient = False
        out = paddle.take(x, paddle.to_tensor(2))
        out.backward()

        self.assertEqual(out.shape, [])
        self.assertEqual(x.grad.shape, [4, 5])
        np.testing.assert_allclose(x.grad[0, 2], 1.0)

        x = paddle.rand([])
        x.stop_gradient = False
        out = paddle.take(x, paddle.to_tensor(0))
        out.backward()

        self.assertEqual(out.shape, [])
        np.testing.assert_allclose(out, x)
        self.assertEqual(x.grad.shape, [])
        np.testing.assert_allclose(x.grad.numpy(), 1.0)

    def test_trapezoid(self):
        y = paddle.rand([5])
        y.stop_gradient = False
        out = paddle.trapezoid(y, dx=2.0)
        out.backward()

        self.assertEqual(out.shape, [])
        self.assertEqual(y.grad.shape, [5])

    def test_create_parameter_var(self):
        zero_dim_param = paddle.create_parameter(shape=[], dtype='float32')
        self.assertEqual(zero_dim_param.shape, [])
......@@ -1964,6 +2068,25 @@ class TestSundryAPI(unittest.TestCase):
        self.assertEqual(out3.grad.shape, [])
        self.assertTrue(out3.grad.numpy() == 1)

    def test_logcumsumexp(self):
        x = paddle.rand([])
        x.stop_gradient = False

        out1 = paddle.logcumsumexp(x)
        out2 = paddle.logcumsumexp(x, axis=0)
        out3 = paddle.logcumsumexp(x, axis=-1)

        out1.backward()
        out2.backward()
        out3.backward()

        self.assertEqual(out1.shape, [1])
        self.assertEqual(out2.shape, [])
        self.assertEqual(out3.shape, [])

        self.assertEqual(x.grad.shape, [])
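        # each of the three backward passes contributes 1.0 to x.grad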
        self.assertTrue(x.grad.numpy() == 3)

    def test_add_n(self):
        x1 = paddle.rand([])
        x1.stop_gradient = False
......@@ -2653,6 +2776,15 @@ class TestSundryAPI(unittest.TestCase):
        self.assertEqual(xt_1_out.shape, [])
        self.assertEqual(xt_1.grad.shape, [12])

    def test_corrcoef(self):
        x = paddle.randn((12,))
        x.stop_gradient = False
        out = paddle.linalg.corrcoef(x)
        out.backward()

        self.assertEqual(out.shape, [])
        self.assertEqual(x.grad.shape, [12])

    def test_det(self):
        xt = paddle.randn([3, 3, 3])
        xt.stop_gradient = False
......@@ -2851,6 +2983,81 @@ class TestSundryAPIStatic(unittest.TestCase):
        paddle.enable_static()
        self.exe = paddle.static.Executor()

    @prog_scope()
    def test_polygamma(self):
        x = paddle.rand([])
        x.stop_gradient = False
        out = paddle.polygamma(x, 2)
        paddle.static.append_backward(out)

        prog = paddle.static.default_main_program()
        res = self.exe.run(prog, fetch_list=[out, x.grad_name])
        self.assertEqual(res[0].shape, ())
        self.assertEqual(res[1].shape, ())

    @prog_scope()
    def test_frexp(self):
        x = paddle.rand([])
        x.stop_gradient = False
        out1, out2 = paddle.frexp(x)
        paddle.static.append_backward(out1)

        prog = paddle.static.default_main_program()
        res = self.exe.run(prog, fetch_list=[out1, out2, x.grad_name])
        self.assertEqual(res[0].shape, ())
        self.assertEqual(res[1].shape, ())
        self.assertEqual(res[2].shape, ())

    @prog_scope()
    def test_pairwise_distance(self):
        x = paddle.rand([5])
        x.stop_gradient = False
        y = paddle.rand([5])
        y.stop_gradient = False

        out = paddle.nn.functional.pairwise_distance(x, y)
        paddle.static.append_backward(out)

        prog = paddle.static.default_main_program()
        res = self.exe.run(prog, fetch_list=[out, x.grad_name, y.grad_name])
        self.assertEqual(res[0].shape, ())
        self.assertEqual(res[1].shape, (5,))
        self.assertEqual(res[2].shape, (5,))

    @prog_scope()
    def test_take(self):
        x1 = paddle.rand([4, 5])
        x1.stop_gradient = False
        out1 = paddle.take(x1, paddle.to_tensor(2))
        paddle.static.append_backward(out1)

        x2 = paddle.rand([])
        x2.stop_gradient = False
        out2 = paddle.take(x2, paddle.to_tensor(0))
        paddle.static.append_backward(out2)

        prog = paddle.static.default_main_program()
        res = self.exe.run(
            prog, fetch_list=[out1, x1.grad_name, out2, x2.grad_name]
        )
        self.assertEqual(res[0].shape, ())
        self.assertEqual(res[1].shape, (4, 5))
        self.assertEqual(res[2].shape, ())
        self.assertEqual(res[3].shape, ())
        np.testing.assert_allclose(res[3], 1.0)

    @prog_scope()
    def test_trapezoid(self):
        y = paddle.rand([5])
        y.stop_gradient = False
        out = paddle.trapezoid(y, dx=2.0)
        paddle.static.append_backward(out)

        prog = paddle.static.default_main_program()
        res = self.exe.run(prog, fetch_list=[out, y.grad_name])
        self.assertEqual(res[0].shape, ())
        self.assertEqual(res[1].shape, (5,))

    @prog_scope()
    def test_create_parameter_var(self):
        zero_dim_param = paddle.create_parameter(shape=[], dtype='float32')
......@@ -4107,16 +4314,45 @@ class TestSundryAPIStatic(unittest.TestCase):
        self.assertEqual(res[1].shape, ())
        self.assertEqual(res[2].shape, ())
        self.assertEqual(res[3].shape, ())
-        self.assertEqual(res[3], 1)
+        self.assertEqual(res[3], 1.0)
        self.assertEqual(res[4].shape, (1,))
-        self.assertEqual(res[4], 1)
+        self.assertEqual(res[4], 1.0)
        self.assertEqual(res[5].shape, ())
-        self.assertEqual(res[5], 1)
+        self.assertEqual(res[5], 1.0)
        self.assertEqual(res[6].shape, ())
-        self.assertEqual(res[6], 1)
+        self.assertEqual(res[6], 1.0)

        self.assertEqual(out2.shape, ())
        self.assertEqual(out3.shape, ())
    @prog_scope()
    def test_logcumsumexp(self):
        x = paddle.rand([])
        x.stop_gradient = False

        out1 = paddle.logcumsumexp(x)
        out2 = paddle.logcumsumexp(x, axis=0)
        out3 = paddle.logcumsumexp(x, axis=-1)

        paddle.static.append_backward(out1)
        paddle.static.append_backward(out2)
        paddle.static.append_backward(out3)

        prog = paddle.static.default_main_program()
        res = self.exe.run(
            prog,
            fetch_list=[
                out1,
                out2,
                out3,
                x.grad_name,
            ],
        )
        self.assertEqual(res[0].shape, (1,))
        self.assertEqual(res[1].shape, ())
        self.assertEqual(res[2].shape, ())
        self.assertEqual(res[3].shape, ())
        self.assertEqual(res[3], 1.0)

    @prog_scope()
    def test_add_n(self):
        x1 = paddle.rand([])
......@@ -4985,11 +5221,22 @@ class TestSundryAPIStatic(unittest.TestCase):
        paddle.static.append_backward(out)

        prog = paddle.static.default_main_program()
        res = self.exe.run(prog, fetch_list=[out, xt_1.grad_name])
        self.assertEqual(res[0].shape, ())
        self.assertEqual(res[1].shape, (12,))

    @prog_scope()
    def test_corrcoef(self):
        x = paddle.randn((12,))
        x.stop_gradient = False
        out = paddle.linalg.corrcoef(x)
        paddle.static.append_backward(out)

        prog = paddle.static.default_main_program()
        res = self.exe.run(prog, fetch_list=[out, x.grad_name])
        self.assertEqual(res[0].shape, ())
        self.assertEqual(res[1].shape, (12,))

    @prog_scope()
    def test_det(self):
        xt_1 = paddle.randn((3, 3))
......@@ -5260,6 +5507,14 @@ class TestNoBackwardAPI(unittest.TestCase):
        out = paddle.linspace(start, stop, num)
        np.testing.assert_array_equal(out.numpy(), [1.0, 2.0, 3.0, 4.0, 5.0])

    def test_logspace(self):
        start = paddle.full([], 1.0)
        stop = paddle.full([], 3.0)
        num = paddle.full([], 5, 'int32')
        base = paddle.full([], 2.0)
        out = paddle.logspace(start, stop, num, base)
        self.assertEqual(out.shape, [5])

    def test_arange(self):
        start = paddle.full([], 1.0)
        stop = paddle.full([], 6.0)
......@@ -5882,6 +6137,49 @@ class TestDistribution(unittest.TestCase):
    def setUp(self):
        self.x = paddle.full([], 2.0)

    def test_Bernoulli(self):
        d = paddle.distribution.Bernoulli(probs=0.3)
        self.assertEqual(d.mean.shape, [])
        self.assertEqual(d.variance.shape, [])
        self.assertEqual(d.entropy().shape, [])
        self.assertEqual(d.sample([]).shape, [])
        self.assertEqual(d.rsample([]).shape, [])
        self.assertEqual(d.cdf(self.x).shape, [])
        self.assertEqual(d.prob(self.x).shape, [])
        self.assertEqual(d.log_prob(self.x).shape, [])

        d_other = paddle.distribution.Bernoulli(probs=0.7)
        self.assertEqual(d.kl_divergence(d_other).shape, [])

    def test_Geometric(self):
        d = paddle.distribution.Geometric(0.5)
        self.assertEqual(d.mean.shape, [])
        self.assertEqual(d.variance.shape, [])
        self.assertEqual(d.entropy().shape, [])
        self.assertEqual(d.stddev.shape, [])
        self.assertEqual(d.pmf(self.x).shape, [])
        self.assertEqual(d.log_pmf(self.x).shape, [])
        self.assertEqual(d.sample([]).shape, [])
        self.assertEqual(d.rsample([]).shape, [])
        self.assertEqual(d.cdf(self.x).shape, [])

        d_other = paddle.distribution.Geometric(probs=0.7)
        self.assertEqual(d.kl_divergence(d_other).shape, [])

    def test_Cauchy(self):
        d = paddle.distribution.Cauchy(loc=0.1, scale=1.2)
        self.assertEqual(d.sample([]).shape, [])
        self.assertEqual(d.rsample([]).shape, [])
        self.assertEqual(d.prob(self.x).shape, [])
        self.assertEqual(d.log_prob(self.x).shape, [])
        self.assertEqual(d.cdf(self.x).shape, [])
        self.assertEqual(d.entropy().shape, [])

        d_other = paddle.distribution.Cauchy(
            loc=paddle.to_tensor(1.2), scale=paddle.to_tensor(2.3)
        )
        self.assertEqual(d.kl_divergence(d_other).shape, [])

    def test_Categorical(self):
        logits = paddle.rand([6])
        d = paddle.distribution.Categorical(logits)
......