diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 9ababcebee4d4fc420072887084b9480396d2df4..643d2a0412d2a5e51008d2c34f3c0c48d182533b 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1889,7 +1889,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor: ) shape_1 = ( *high_dim, - shape_ori[-3] / square, + int(shape_ori[-3] / square), shape_ori[-2] * upscale_factor, shape_ori[-1] * upscale_factor, ) @@ -1898,8 +1898,8 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor: layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order) - shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device) - shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device) + shape_0 = convert_single_value(shape_0, device=inp.device) + shape_1 = convert_single_value(shape_1, device=inp.device) outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1) return outvar diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 1a82b30dbd1358f75577246d271690be1ec5c4f9..6208cba09affbda168618b362a3adc0a3c13578d 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -1218,30 +1218,49 @@ def test_pixel_shuffle(): out = F.pixel_shuffle(tensor(inp), upscale_factor=4) golden = pixel_shuffle(inp, 4) np.testing.assert_equal(out.numpy(), golden) + inp_float = np.float32(inp) + out = F.pixel_shuffle(inp_float, 2) + golden = pixel_shuffle(inp_float, 2) + np.testing.assert_equal(out.numpy(),golden) # ndim = 4 inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3) out = F.pixel_shuffle(tensor(inp), upscale_factor=3) golden = pixel_shuffle(inp, 3) np.testing.assert_equal(out.numpy(), golden) + inp_float = np.float32(inp) + out = F.pixel_shuffle(inp_float, 2) + golden = pixel_shuffle(inp_float, 2) + np.testing.assert_equal(out.numpy(),golden) # ndim = 5 inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4) out = F.pixel_shuffle(tensor(inp), upscale_factor=2) golden = pixel_shuffle(inp, 2) np.testing.assert_equal(out.numpy(), golden) - + inp_float = np.float32(inp) + out = F.pixel_shuffle(inp_float, 2) + golden = pixel_shuffle(inp_float, 2) + np.testing.assert_equal(out.numpy(),golden) # ndim = 6 inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4) out = F.pixel_shuffle(tensor(inp), upscale_factor=5) golden = pixel_shuffle(inp, 5) np.testing.assert_equal(out.numpy(), golden) + inp_float = np.float32(inp) + out = F.pixel_shuffle(inp_float, 2) + golden = pixel_shuffle(inp_float, 2) + np.testing.assert_equal(out.numpy(),golden) # ndim = 7 inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4) out = F.pixel_shuffle(tensor(inp), upscale_factor=2) golden = pixel_shuffle(inp, 2) np.testing.assert_equal(out.numpy(), golden) + inp_float = np.float32(inp) + out = F.pixel_shuffle(inp_float, 2) + golden = pixel_shuffle(inp_float, 2) + np.testing.assert_equal(out.numpy(),golden) @pytest.mark.parametrize("is_symbolic", [False, True]) @@ -1260,6 +1279,13 @@ def test_pixel_shuffle_symbolic(is_symbolic): if is_symbolic is None: break + inp = np.float32(inp) + golden = pixel_shuffle(inp, 2) + for _ in range(3): + out = fn(inp, 2) + np.testing.assert_equal(out.numpy(), golden) + if is_symbolic is None: + break def test_set_conv2d_config(): """check setting config by contextmanager is equal to manually converted result"""