diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index e17eb800aad165a6a236a010d04689eb7ecf2d0f..d1cc7286efbf59a4923a8fb21cab547c59ff0524 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1843,7 +1843,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor: ) shape_1 = ( *high_dim, - shape_ori[-3] / square, + int(shape_ori[-3] / square), shape_ori[-2] * upscale_factor, shape_ori[-1] * upscale_factor, ) @@ -1852,8 +1852,8 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor: layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order) - shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device) - shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device) + shape_0 = convert_single_value(shape_0, device=inp.device) + shape_1 = convert_single_value(shape_1, device=inp.device) outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1) return outvar diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 2c9968aceb0c0ee943ac16b0749216748bd34be8..e38c10ebe4ae8e354fe9063de740ae17fc569f62 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -1207,41 +1207,61 @@ def test_pixel_shuffle(): out = F.pixel_shuffle(tensor(inp), upscale_factor=4) golden = pixel_shuffle(inp, 4) np.testing.assert_equal(out.numpy(), golden) + inp_float = np.float32(inp) + out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2) + golden = pixel_shuffle(inp_float, 2) + np.testing.assert_equal(out.numpy(), golden) # ndim = 4 inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3) out = F.pixel_shuffle(tensor(inp), upscale_factor=3) golden = pixel_shuffle(inp, 3) np.testing.assert_equal(out.numpy(), golden) + inp_float = np.float32(inp) + out = F.pixel_shuffle(tensor(inp_float), upscale_factor=3) + golden = pixel_shuffle(inp_float, 3) + np.testing.assert_equal(out.numpy(), golden) # ndim = 5 inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4) out = F.pixel_shuffle(tensor(inp), upscale_factor=2) golden = pixel_shuffle(inp, 2) np.testing.assert_equal(out.numpy(), golden) - + inp_float = np.float32(inp) + out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2) + golden = pixel_shuffle(inp_float, 2) + np.testing.assert_equal(out.numpy(), golden) # ndim = 6 inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4) out = F.pixel_shuffle(tensor(inp), upscale_factor=5) golden = pixel_shuffle(inp, 5) np.testing.assert_equal(out.numpy(), golden) + inp_float = np.float32(inp) + out = F.pixel_shuffle(tensor(inp_float), upscale_factor=5) + golden = pixel_shuffle(inp_float, 5) + np.testing.assert_equal(out.numpy(), golden) # ndim = 7 inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4) out = F.pixel_shuffle(tensor(inp), upscale_factor=2) golden = pixel_shuffle(inp, 2) np.testing.assert_equal(out.numpy(), golden) + inp_float = np.float32(inp) + out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2) + golden = pixel_shuffle(inp_float, 2) + np.testing.assert_equal(out.numpy(), golden) +@pytest.mark.parametrize("type", ["int32", "float32"]) @pytest.mark.parametrize("is_symbolic", [False, True]) -def test_pixel_shuffle_symbolic(is_symbolic): +def test_pixel_shuffle_symbolic(is_symbolic, type): def fn(inp, upscale_factor): return F.pixel_shuffle(inp, upscale_factor=upscale_factor) if is_symbolic is not None: fn = jit.trace(symbolic=is_symbolic)(fn) - inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5)) + inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5).astype(type)) golden = pixel_shuffle(inp, 2) for _ in range(3): out = fn(inp, 2)