提交 536506c3 编写于 作者: M Megvii Engine Team

feat(functional): let interpolate support more modes

GitOrigin-RevId: 9693a1ac638658ca1ee0b5d0eff507d74fcc996d
上级 d811dc54
...@@ -17,7 +17,7 @@ from ..core.tensor.utils import astensor1d ...@@ -17,7 +17,7 @@ from ..core.tensor.utils import astensor1d
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import floor from .elemwise import floor
from .math import argsort from .math import argsort
from .tensor import broadcast_to, concat, expand_dims, reshape from .tensor import broadcast_to, concat, expand_dims, reshape, transpose
def cvt_color(inp: Tensor, mode: str = ""): def cvt_color(inp: Tensor, mode: str = ""):
...@@ -474,7 +474,7 @@ def interpolate( ...@@ -474,7 +474,7 @@ def interpolate(
:param size: size of the output tensor. Default: None :param size: size of the output tensor. Default: None
:param scale_factor: scaling factor of the output tensor. Default: None :param scale_factor: scaling factor of the output tensor. Default: None
:param mode: interpolation methods, acceptable values are: :param mode: interpolation methods, acceptable values are:
"bilinear", "linear". Default: "bilinear" "bilinear", "linear", "bicubic" and "nearest". Default: "bilinear"
:param align_corners: This only has an effect when `mode` :param align_corners: This only has an effect when `mode`
is "bilinear" or "linear". Geometrically, we consider the pixels of the input is "bilinear" or "linear". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input and output as squares rather than points. If set to ``True``, the input
...@@ -511,8 +511,8 @@ def interpolate( ...@@ -511,8 +511,8 @@ def interpolate(
""" """
mode = mode.lower() mode = mode.lower()
if mode not in ["bilinear", "linear"]: if mode not in ["bilinear", "linear", "bicubic", "nearest"]:
raise ValueError("interpolate only support linear or bilinear mode") raise ValueError("unsupported interpolate mode: {}".format(mode))
if mode not in ["bilinear", "linear"]: if mode not in ["bilinear", "linear"]:
if align_corners is not None: if align_corners is not None:
raise ValueError( raise ValueError(
...@@ -625,9 +625,21 @@ def interpolate( ...@@ -625,9 +625,21 @@ def interpolate(
weight = broadcast_to(weight, (inp.shape[0], 3, 3)) weight = broadcast_to(weight, (inp.shape[0], 3, 3))
weight = weight.astype("float32") weight = weight.astype("float32")
ret = warp_perspective(inp, weight, dsize, interp_mode="linear") if mode in ["linear", "bilinear"]:
if mode == "linear": ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
ret = reshape(ret, ret.shape[0:3]) if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
else:
# only NHWC format support "cubic" and "nearest" mode
inp = transpose(inp, (0, 2, 3, 1))
ret = warp_perspective(
inp,
weight,
dsize,
format="NHWC",
interp_mode="cubic" if mode == "bicubic" else mode,
)
ret = transpose(ret, (0, 3, 1, 2))
return ret return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册