From 536506c3f49463f2a1ec9b5dc6497060ab8a5801 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 23 Jul 2021 19:33:52 +0800 Subject: [PATCH] feat(functional): let interpolate support more modes GitOrigin-RevId: 9693a1ac638658ca1ee0b5d0eff507d74fcc996d --- .../python/megengine/functional/vision.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index 14044df1f..60c1823f8 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -17,7 +17,7 @@ from ..core.tensor.utils import astensor1d from ..tensor import Tensor from .elemwise import floor from .math import argsort -from .tensor import broadcast_to, concat, expand_dims, reshape +from .tensor import broadcast_to, concat, expand_dims, reshape, transpose def cvt_color(inp: Tensor, mode: str = ""): @@ -474,7 +474,7 @@ def interpolate( :param size: size of the output tensor. Default: None :param scale_factor: scaling factor of the output tensor. Default: None :param mode: interpolation methods, acceptable values are: - "bilinear", "linear". Default: "bilinear" + "bilinear", "linear", "bicubic" and "nearest". Default: "bilinear" :param align_corners: This only has an effect when `mode` is "bilinear" or "linear". Geometrically, we consider the pixels of the input and output as squares rather than points. If set to ``True``, the input @@ -511,8 +511,8 @@ def interpolate( """ mode = mode.lower() - if mode not in ["bilinear", "linear"]: - raise ValueError("interpolate only support linear or bilinear mode") + if mode not in ["bilinear", "linear", "bicubic", "nearest"]: + raise ValueError("unsupported interpolate mode: {}".format(mode)) if mode not in ["bilinear", "linear"]: if align_corners is not None: raise ValueError( @@ -625,9 +625,21 @@ def interpolate( weight = broadcast_to(weight, (inp.shape[0], 3, 3)) weight = weight.astype("float32") - ret = warp_perspective(inp, weight, dsize, interp_mode="linear") - if mode == "linear": - ret = reshape(ret, ret.shape[0:3]) + if mode in ["linear", "bilinear"]: + ret = warp_perspective(inp, weight, dsize, interp_mode="linear") + if mode == "linear": + ret = reshape(ret, ret.shape[0:3]) + else: + # only NHWC format support "cubic" and "nearest" mode + inp = transpose(inp, (0, 2, 3, 1)) + ret = warp_perspective( + inp, + weight, + dsize, + format="NHWC", + interp_mode="cubic" if mode == "bicubic" else mode, + ) + ret = transpose(ret, (0, 3, 1, 2)) return ret -- GitLab