未验证 提交 86cc694f 编写于 作者: M mhy-666 提交者: GitHub

[Zero-Dim] support input 0D Tensor for std/var (#49735)

* add test_std

* add test_var

* fix std/var assertequal

* fix std/var assertequal

* fix std/var assertequal

* -madd api name to reduce_api

* fix

* fix var

* fix

* fix

* fix stat

* fix unitest

* fix stat/var

* fix stat/var, unittest

* fix stat/std, unittest

* add unittest of var,std, fix stat/var,std

* fix stat/var, unittest

* fix

* fix unittest

* fix

* fix

* fix

* fix unittest
上级 efef3035
......@@ -622,6 +622,45 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, 3.0)
def test_std(self):
x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.std(x)
out2 = paddle.std(x, [])
out1.backward()
out2.backward()
# checkout shape of out
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
# checkout value of out
self.assertEqual(out1, 0)
self.assertEqual(out2, 0)
# checkout backward
self.assertEqual(x.grad.shape, [])
def test_var(self):
x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.var(x)
out2 = paddle.var(x, [])
out1.backward()
out2.backward()
# checkout shape of out
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
# checkout value of out
self.assertEqual(out1, 0)
self.assertEqual(out2, 0)
# checkout backward
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, 0)
def test_quantile(self):
# 1) x is 0D
x = paddle.rand([])
......@@ -1708,6 +1747,48 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[2].shape, ())
np.testing.assert_allclose(res[2], 1.0)
@prog_scope()
def test_std(self):
x = paddle.rand([])
out1 = paddle.std(x)
out2 = paddle.std(x, [])
paddle.static.append_backward(out1)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[
x,
out1,
out2,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
@prog_scope()
def test_var(self):
x = paddle.rand([])
out1 = paddle.var(x)
out2 = paddle.var(x, [])
paddle.static.append_backward(out1)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[
x,
out1,
out2,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
@prog_scope()
def test_quantile(self):
x1 = paddle.rand([])
......
......@@ -148,7 +148,7 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'var')
u = mean(x, axis, True, name)
out = paddle.sum((x - u) ** 2, axis, keepdim=keepdim, name=name)
out = paddle.sum(paddle.pow((x - u), 2), axis, keepdim=keepdim, name=name)
dtype = x.dtype
n = paddle.cast(paddle.numel(x), paddle.int64) / paddle.cast(
......@@ -212,7 +212,6 @@ def std(x, axis=None, unbiased=True, keepdim=False, name=None):
"""
if not in_dygraph_mode():
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'std')
out = var(**locals())
return paddle.sqrt(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册