Commit fe9c6e26 authored by Megvii Engine Team

feat(imperative/opr): deprecate resize op and make it a special case of interpolate

GitOrigin-RevId: a5668c5779000e6f0a1fce1694ab347624c3c20f
Parent 798ae5e5
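With `resize` dropped from the public functional namespace, existing callers migrate to `F.nn.interpolate`, which routes the same arguments to the `builtin.Resize` fastpath added below. A minimal migration sketch, assuming a MegEngine build that contains this commit; the shapes mirror the updated tests:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.Tensor(np.random.randn(10, 3, 32, 32).astype("float32"))

    # before this commit:
    #   out = F.resize(x, (16, 16), interp_mode="LINEAR")
    # after this commit, the same operation goes through interpolate:
    out = F.nn.interpolate(x, size=(16, 16), mode="BILINEAR")
    print(out.numpy().shape)  # (10, 3, 16, 16)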
......@@ -57,7 +57,6 @@ __all__ = [
"one_hot",
"prelu",
"remap",
"resize",
"softmax",
"softplus",
"warp_affine",
......@@ -984,41 +983,6 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
return result
def resize(
inp: Tensor, target_shape: Iterable[int], interp_mode: str = "LINEAR"
) -> Tensor:
r"""
Applies resize transformation to batched 2D images.
:param inp: `(N, C, H, W)` input tensor. Currently only the "NCHW" format is supported.
:param target_shape: `(H, W)` target image shape.
:param interp_mode: interpolation method. Default mode is "LINEAR"; currently only "LINEAR" is supported.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.random.randn(10, 3, 32, 32))
out = F.resize(x, (16, 16))
print(out.numpy().shape)
Outputs:
.. testoutput::
(10, 3, 16, 16)
"""
op = builtin.Resize(imode=interp_mode, format="NCHW")
shape = astensor1d(target_shape, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result
def warp_affine(
inp: Tensor,
weight: Tensor,
......@@ -1187,7 +1151,7 @@ def interpolate(
size: Optional[Union[int, Tuple[int, int]]] = None,
scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
mode: str = "BILINEAR",
align_corners: bool = None,
align_corners: Optional[bool] = None,
) -> Tensor:
r"""
Down/up samples the input tensor to either the given size or with the given scale_factor. ``size`` can not coexist with ``scale_factor``.
......@@ -1197,6 +1161,15 @@ def interpolate(
:param scale_factor: scaling factor of the output tensor. Default: None
:param mode: interpolation methods, acceptable values are:
"BILINEAR", "LINEAR". Default: "BILINEAR"
:param align_corners: This only has an effect when `mode`
is "BILINEAR" or "LINEAR". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input
and output tensors are aligned by the center points of their corner
pixels, preserving the values at the corner pixels. If set to ``False``,
the input and output tensors are aligned by the corner points of their
corner pixels, and the interpolation uses edge value padding for
out-of-boundary values, making this operation *independent* of input size
when `scale_factor` is kept the same. Default: None
:return: output tensor.
Examples:
......@@ -1235,6 +1208,19 @@ def interpolate(
if align_corners is None:
align_corners = False
if (
size is not None
and scale_factor is None
and not align_corners
and mode == "BILINEAR"
and inp.ndim in [4, 5]
):
# fastpath for interpolate
op = builtin.Resize(imode="LINEAR", format="NCHW")
shape = astensor1d(size, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result
if mode == "LINEAR":
inp = expand_dims(inp, 3)
......
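The new branch only fires when the call reduces to the old resize semantics: an explicit `size`, no `scale_factor`, `align_corners` left at its default (False), `mode == "BILINEAR"`, and a 4-D or 5-D input. Any other combination falls through to the generic interpolate path that follows. A short sketch of both cases, assuming the generic path accepts these arguments as the docstring above describes:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.Tensor(np.random.rand(3, 3, 32, 32).astype("float32"))

    # takes the fastpath: builtin.Resize is applied directly
    y_fast = F.nn.interpolate(x, size=(16, 16), mode="BILINEAR")

    # skips the fastpath: align_corners=True violates the condition,
    # so the call is handled by the generic interpolate implementation
    y_generic = F.nn.interpolate(x, size=(16, 16), mode="BILINEAR", align_corners=True)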
......@@ -367,12 +367,12 @@ def test_Broadcast():
np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())
def test_resize():
def test_interpolate_fastpath():
x_np = np.random.rand(3, 3, 32, 32).astype("float32")
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = F.resize(x, (16, 16))
y = F.nn.interpolate(x, size=(16, 16), mode="BILINEAR")
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy())
......
......@@ -325,7 +325,7 @@ def test_one_hot():
onehot_high_dimension()
def test_resize():
def test_interpolate_fastpath():
# check shape
test_cases = [
[(1, 1, 10, 10), (5, 5)],
......@@ -335,18 +335,18 @@ def test_resize():
]
for inp_shape, target_shape in test_cases:
x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
out = F.resize(x, target_shape, interp_mode="LINEAR")
out = F.nn.interpolate(x, target_shape, mode="BILINEAR")
assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]
# check value
x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
out = F.resize(x, (15, 5), interp_mode="LINEAR")
out = F.nn.interpolate(x, (15, 5), mode="BILINEAR")
np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))
np_x = np.arange(32)
x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
out = F.resize(x, (1, 1), interp_mode="LINEAR")
out = F.nn.interpolate(x, (1, 1), mode="BILINEAR")
np.testing.assert_equal(out.item(), np_x.mean())
......