提交 5e17b3e4 编写于 作者: M Megvii Engine Team

Merge pull request #426 from Qsingle:fix-pixel_suffle

GitOrigin-RevId: db9a0f755113f7f3bcd2c278092ac465764835ff
......@@ -1843,7 +1843,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
)
shape_1 = (
*high_dim,
shape_ori[-3] / square,
int(shape_ori[-3] / square),
shape_ori[-2] * upscale_factor,
shape_ori[-1] * upscale_factor,
)
......@@ -1852,8 +1852,8 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)
shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device)
shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device)
shape_0 = convert_single_value(shape_0, device=inp.device)
shape_1 = convert_single_value(shape_1, device=inp.device)
outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)
return outvar
......
......@@ -1207,41 +1207,61 @@ def test_pixel_shuffle():
out = F.pixel_shuffle(tensor(inp), upscale_factor=4)
golden = pixel_shuffle(inp, 4)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(), golden)
# ndim = 4
inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3)
out = F.pixel_shuffle(tensor(inp), upscale_factor=3)
golden = pixel_shuffle(inp, 3)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=3)
golden = pixel_shuffle(inp_float, 3)
np.testing.assert_equal(out.numpy(), golden)
# ndim = 5
inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(), golden)
# ndim = 6
inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=5)
golden = pixel_shuffle(inp, 5)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=5)
golden = pixel_shuffle(inp_float, 5)
np.testing.assert_equal(out.numpy(), golden)
# ndim = 7
inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(), golden)
@pytest.mark.parametrize("type", ["int32", "float32"])
@pytest.mark.parametrize("is_symbolic", [False, True])
def test_pixel_shuffle_symbolic(is_symbolic):
def test_pixel_shuffle_symbolic(is_symbolic, type):
def fn(inp, upscale_factor):
return F.pixel_shuffle(inp, upscale_factor=upscale_factor)
if is_symbolic is not None:
fn = jit.trace(symbolic=is_symbolic)(fn)
inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5))
inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5).astype(type))
golden = pixel_shuffle(inp, 2)
for _ in range(3):
out = fn(inp, 2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册