From 6cd7fcafc0a513351799ec7254af2c20221ead65 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Wed, 18 Jan 2023 12:23:40 +0800 Subject: [PATCH] Zero-dim support of histogram kernel, test=develop (#49884) --- .../fluid/tests/unittests/test_histogram_op.py | 8 ++++++++ .../fluid/tests/unittests/test_zero_dim_tensor.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_histogram_op.py b/python/paddle/fluid/tests/unittests/test_histogram_op.py index 71283aceaa4..dc52df4226a 100644 --- a/python/paddle/fluid/tests/unittests/test_histogram_op.py +++ b/python/paddle/fluid/tests/unittests/test_histogram_op.py @@ -153,6 +153,14 @@ class TestHistogramOp(OpTest): self.check_output(check_eager=True) +class TestHistogramOp_ZeroDim(TestHistogramOp): + def init_test_case(self): + self.in_shape = [] + self.bins = 5 + self.min = 1 + self.max = 5 + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index cc7a257e4c4..a97d2841202 100755 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -917,6 +917,11 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out.grad.shape, [1]) self.assertEqual(x.grad.shape, []) + def test_histogram(self): + x = paddle.rand([]) + out = paddle.histogram(x, bins=5, min=1, max=5) + self.assertEqual(out.shape, [5]) + def test_scale(self): x = paddle.rand([]) x.stop_gradient = False @@ -1658,6 +1663,16 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, (1,)) + @prog_scope() + def test_histogram(self): + x = paddle.full([], 1, 'float32') + out = paddle.histogram(x, bins=5, min=1, max=5) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, feed={}, fetch_list=[out]) + + self.assertEqual(res[0].shape, (5,)) + @prog_scope() def test_scale(self): x = paddle.rand([]) -- GitLab