提交 14499e83 编写于 作者: Q Qsingle

fix the type error of pixel shuffle

add the test of dtype float
上级 d9a46ea4
......@@ -1889,7 +1889,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
)
shape_1 = (
*high_dim,
shape_ori[-3] / square,
int(shape_ori[-3] / square),
shape_ori[-2] * upscale_factor,
shape_ori[-1] * upscale_factor,
)
......@@ -1898,8 +1898,8 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)
shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device)
shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device)
shape_0 = convert_single_value(shape_0, device=inp.device)
shape_1 = convert_single_value(shape_1, device=inp.device)
outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)
return outvar
......
......@@ -1218,30 +1218,49 @@ def test_pixel_shuffle():
out = F.pixel_shuffle(tensor(inp), upscale_factor=4)
golden = pixel_shuffle(inp, 4)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)
# ndim = 4
inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3)
out = F.pixel_shuffle(tensor(inp), upscale_factor=3)
golden = pixel_shuffle(inp, 3)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)
# ndim = 5
inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)
# ndim = 6
inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=5)
golden = pixel_shuffle(inp, 5)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)
# ndim = 7
inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
inp_float = np.float32(inp)
out = F.pixel_shuffle(inp_float, 2)
golden = pixel_shuffle(inp_float, 2)
np.testing.assert_equal(out.numpy(),golden)
@pytest.mark.parametrize("is_symbolic", [False, True])
......@@ -1260,6 +1279,13 @@ def test_pixel_shuffle_symbolic(is_symbolic):
if is_symbolic is None:
break
inp = np.float32(inp)
golden = pixel_shuffle(inp, 2)
for _ in range(3):
out = fn(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
if is_symbolic is None:
break
def test_set_conv2d_config():
"""check setting config by contextmanager is equal to manually converted result"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册