提交 6b863cc5 编写于 作者: C chenjiahui 提交者: Megvii Engine Team

feat(imperative): add pixel_shuffle opr

Closes #211
GitOrigin-RevId: 18467147917d9631653a3bd5919c4718e530e69d
上级 5e345043
......@@ -15,6 +15,7 @@ from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.ops.builtin import (
BatchNorm,
Dimshuffle,
Elemwise,
GetVarShape,
Identity,
......@@ -86,6 +87,7 @@ __all__ = [
"sync_batch_norm",
"warp_affine",
"warp_perspective",
"pixel_shuffle",
]
......@@ -1733,6 +1735,69 @@ def pad(
return output
@lru_cache(maxsize=None)
def _get_layerPixelShuffle(device, dtype, dim_order):
@subgraph("LayerPixelShuffle", dtype, device, 3)
def layerPixelShuffle(inputs, f, c):
inp, shape_0, shape_1 = inputs
inp = f(Reshape(), inp, shape_0)
inp = f(Dimshuffle(dim_order), inp)
oup = f(Reshape(), inp, shape_1)
return (oup,), (True,)
return layerPixelShuffle
def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
"""
Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of
shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero
or more batch dimensions.
:param inp: input tensor.
:param upscale_factor: upscale factor of pixel_shuffle.
:return: output tensor.
"""
assert upscale_factor > 0, "upscale_factor should larger than 0"
assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
assert (
inp.shape[-3] % (upscale_factor ** 2) == 0
), "the -3 dimension should be divided by (upscale_factor ** 2)"
_device = inp.device
_dtype = inp.dtype
shape_ori = inp.shape
high_dim = shape_ori[:-3]
square = upscale_factor ** 2
n = 1
for item in high_dim:
n *= item
shape_0 = (
n,
int(shape_ori[-3] / square),
upscale_factor,
upscale_factor,
shape_ori[-2],
shape_ori[-1],
)
shape_1 = (
*high_dim,
shape_ori[-3] / square,
shape_ori[-2] * upscale_factor,
shape_ori[-1] * upscale_factor,
)
dim_order = (0, 1, 4, 2, 5, 3)
layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)
shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device)
shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device)
outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)
return outvar
from .quantized import conv_bias_activation # isort:skip
from .loss import * # isort:skip
from .metric import * # isort:skip
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..functional.nn import pixel_shuffle
from .module import Module
class PixelShuffle(Module):
r"""
Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of
shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero
or more batch dimensions.
"""
def __init__(self, upscale_factor: int, **kwargs):
super().__init__(**kwargs)
self.upscale_factor = upscale_factor
def forward(self, x):
return pixel_shuffle(x, self.upscale_factor)
......@@ -1177,3 +1177,74 @@ def test_pad():
dst = np.pad(src, ((2, 2), (2, 2)), "reflect")
res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "REFLECT")
np.testing.assert_allclose(res, dst, atol=1e-5)
def pixel_shuffle(data, r):
high_dim = data.shape[:-3]
data = data.reshape(-1, data.shape[-3], data.shape[-2], data.shape[-1])
inn, ic, ih, iw = data.shape
res = np.zeros((inn, int(ic / (r * r)), ih * r, iw * r))
for n in range(inn):
for c in range(ic):
for h in range(ih):
for w in range(iw):
res[
n,
int(c / r / r),
h * r + int((c % (r * r)) / r),
w * r + c % r,
] = data[n, c, h, w]
if len(high_dim) > 0:
res = res.reshape((*high_dim, int(ic / r / r), ih * r, iw * r))
else:
res = res[0]
return res
def test_pixel_shuffle():
# ndim = 3
inp = np.arange(16 * 3 * 3).reshape(16, 3, 3)
out = F.pixel_shuffle(tensor(inp), upscale_factor=4)
golden = pixel_shuffle(inp, 4)
np.testing.assert_equal(out.numpy(), golden)
# ndim = 4
inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3)
out = F.pixel_shuffle(tensor(inp), upscale_factor=3)
golden = pixel_shuffle(inp, 3)
np.testing.assert_equal(out.numpy(), golden)
# ndim = 5
inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
# ndim = 6
inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=5)
golden = pixel_shuffle(inp, 5)
np.testing.assert_equal(out.numpy(), golden)
# ndim = 7
inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4)
out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
golden = pixel_shuffle(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
@pytest.mark.parametrize("is_symbolic", [False, True])
def test_pixel_shuffle_symbolic(is_symbolic):
def fn(inp, upscale_factor):
return F.pixel_shuffle(inp, upscale_factor=upscale_factor)
if is_symbolic is not None:
fn = jit.trace(symbolic=is_symbolic)(fn)
inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5))
golden = pixel_shuffle(inp, 2)
for _ in range(3):
out = fn(inp, 2)
np.testing.assert_equal(out.numpy(), golden)
if is_symbolic is None:
break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册