未验证 提交 6cd7fcaf 编写于 作者: Q Qi Li 提交者: GitHub

Zero-dim support of histogram kernel, test=develop (#49884)

上级 5fca45ea
...@@ -153,6 +153,14 @@ class TestHistogramOp(OpTest): ...@@ -153,6 +153,14 @@ class TestHistogramOp(OpTest):
self.check_output(check_eager=True) self.check_output(check_eager=True)
class TestHistogramOp_ZeroDim(TestHistogramOp):
def init_test_case(self):
self.in_shape = []
self.bins = 5
self.min = 1
self.max = 5
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
...@@ -917,6 +917,11 @@ class TestSundryAPI(unittest.TestCase): ...@@ -917,6 +917,11 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.grad.shape, [1]) self.assertEqual(out.grad.shape, [1])
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
def test_histogram(self):
x = paddle.rand([])
out = paddle.histogram(x, bins=5, min=1, max=5)
self.assertEqual(out.shape, [5])
def test_scale(self): def test_scale(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
...@@ -1658,6 +1663,16 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1658,6 +1663,16 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (1,)) self.assertEqual(res[2].shape, (1,))
@prog_scope()
def test_histogram(self):
x = paddle.full([], 1, 'float32')
out = paddle.histogram(x, bins=5, min=1, max=5)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, feed={}, fetch_list=[out])
self.assertEqual(res[0].shape, (5,))
@prog_scope() @prog_scope()
def test_scale(self): def test_scale(self):
x = paddle.rand([]) x = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册