From 0c37a588cafce1059e3c2d64d0fba469706b040e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 24 Feb 2021 11:39:16 +0800 Subject: [PATCH] fix(mge/functional): fix F.ones when input is a tensor of scalar type GitOrigin-RevId: 6d01d6b58d0445b42cc3f3e5f137ebd590af31a4 --- imperative/python/megengine/functional/tensor.py | 2 +- imperative/python/test/unit/functional/test_functional.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 3072837a..4218e52d 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -108,7 +108,7 @@ def full(shape, value, dtype="float32", device=None): if device is None: device = get_default_device() (x,) = Const(value, dtype=dtype, device=device)() - if len(shape) == 0: # scalar + if shape is (): # scalar.shape return x return broadcast_to(x, shape) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index dc43cd45..bd8b016b 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -739,3 +739,10 @@ def test_cvt_color(): x = tensor(inp) y = F.img_proc.cvt_color(x, mode="RGB2GRAY") np.testing.assert_allclose(y.numpy(), out, atol=1e-5) + + +@pytest.mark.parametrize("val", [2, [2,], [2, 3]]) +def test_ones(val): + shp = tensor(val) + np_shp = np.array(val) + np.testing.assert_equal(F.ones(shp), np.ones(np_shp)) -- GitLab