From dcfb6a537ef1941856d3c31d74944a6b62bfe8a4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 2 Mar 2021 12:00:12 +0800 Subject: [PATCH] refactor(mge/functional): move functional api GitOrigin-RevId: 9cd3e09996f77a00e114f0870eae55ad80b4ba8b --- .../python/megengine/distributed/helper.py | 2 +- .../python/megengine/functional/__init__.py | 3 +- .../python/megengine/functional/elemwise.py | 56 +- .../python/megengine/functional/img_proc.py | 50 -- .../python/megengine/functional/loss.py | 5 +- .../python/megengine/functional/math.py | 2 - .../functional/{utils.py => metric.py} | 41 +- imperative/python/megengine/functional/nn.py | 709 ++++++------------ .../python/megengine/functional/tensor.py | 36 +- .../python/megengine/functional/vision.py | 576 ++++++++++++++ .../python/megengine/module/identity.py | 2 +- .../python/test/unit/core/test_autodiff.py | 2 +- .../test/unit/functional/test_functional.py | 36 +- .../python/test/unit/jit/test_tracing.py | 6 +- .../test/unit/utils/test_network_node.py | 14 +- .../src/impl/ops/{img_proc.cpp => vision.cpp} | 4 +- 16 files changed, 887 insertions(+), 657 deletions(-) delete mode 100644 imperative/python/megengine/functional/img_proc.py rename imperative/python/megengine/functional/{utils.py => metric.py} (69%) create mode 100644 imperative/python/megengine/functional/vision.py rename imperative/src/impl/ops/{img_proc.cpp => vision.cpp} (95%) diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index bc365c2e..0a67f2dd 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -19,7 +19,7 @@ from megengine.device import get_default_device, get_device_count from ..core._imperative_rt.core2 import apply from ..core.ops.builtin import ParamPackConcat, ParamPackSplit -from ..functional.utils import copy +from ..functional.tensor import copy from ..tensor import Tensor from ..utils.future import Future from .functional import all_reduce_sum, broadcast diff --git a/imperative/python/megengine/functional/__init__.py b/imperative/python/megengine/functional/__init__.py index 976e96c1..fcd76b10 100644 --- a/imperative/python/megengine/functional/__init__.py +++ b/imperative/python/megengine/functional/__init__.py @@ -7,12 +7,11 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # pylint: disable=redefined-builtin +from . import metric, vision from .elemwise import * -from .img_proc import * from .math import * from .nn import * from .tensor import * -from .utils import * from . import distributed # isort:skip diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 9e943d5b..f6e876db 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -7,8 +7,6 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order -import functools - import numpy as np from ..core._imperative_rt.core2 import apply @@ -17,7 +15,7 @@ from ..core.ops import builtin from ..core.ops.builtin import Elemwise from ..core.tensor import utils from ..core.tensor.array_method import _elwise_apply -from ..core.tensor.utils import astype, isscalar, setscalar +from ..core.tensor.utils import astype from ..device import get_default_device from ..jit.tracing import is_tracing from ..tensor import Tensor @@ -44,8 +42,6 @@ __all__ = [ "floor_div", "greater", "greater_equal", - "hswish", - "hsigmoid", "left_shift", "less", "less_equal", @@ -62,11 +58,8 @@ __all__ = [ "neg", "not_equal", "pow", - "relu", - "relu6", "right_shift", "round", - "sigmoid", "sin", "sinh", "sqrt", @@ -523,53 +516,6 @@ def greater_equal(x, y): # other functions -def hswish(x): - """ - Element-wise `x * relu6(x + 3) / 6`. - - :param x: input tensor. - :return: computed tensor. - - Example: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - x = tensor(np.arange(5).astype(np.float32)) - out = F.hswish(x) - print(out.numpy().round(decimals=4)) - - .. testoutput:: - - [0. 0.6667 1.6667 3. 4. ] - - """ - return _elwise(x, mode=Elemwise.Mode.H_SWISH) - - -def hsigmoid(x): - """Element-wise `relu6(x + 3) / 6`.""" - return relu6(x + 3) / 6 - - -def relu(x): - """Element-wise `max(x, 0)`.""" - return _elwise(x, mode=Elemwise.Mode.RELU) - - -def relu6(x): - """Element-wise `min(max(x, 0), 6)`.""" - return minimum(maximum(x, 0), 6) - - -def sigmoid(x): - """Element-wise `1 / ( 1 + exp( -x ) )`.""" - return _elwise(x, mode=Elemwise.Mode.SIGMOID) - - def clip(x: Tensor, lower=None, upper=None) -> Tensor: r""" Clamps all elements in input tensor into the range `[` :attr:`lower`, :attr:`upper` `]` and returns diff --git a/imperative/python/megengine/functional/img_proc.py b/imperative/python/megengine/functional/img_proc.py deleted file mode 100644 index 5222b47d..00000000 --- a/imperative/python/megengine/functional/img_proc.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ..core._imperative_rt.core2 import apply -from ..core.ops import builtin -from ..tensor import Tensor - -__all__ = [ - "cvt_color", -] - - -def cvt_color(inp: Tensor, mode: str = ""): - r""" - Convert images from one format to another - - :param inp: input images. - :param mode: format mode. - :return: convert result. - - Examples: - - .. testcode:: - - import numpy as np - import megengine as mge - import megengine.functional as F - - x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32)) - y = F.img_proc.cvt_color(x, mode="RGB2GRAY") - print(y.numpy()) - - Outputs: - - .. testoutput:: - - [[[[0.86555195]]]] - - """ - assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color" - mode = getattr(builtin.CvtColor.Mode, mode) - assert isinstance(mode, builtin.CvtColor.Mode) - op = builtin.CvtColor(mode=mode) - (out,) = apply(op, inp) - return out diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py index 76814fdd..7711bf20 100644 --- a/imperative/python/megengine/functional/loss.py +++ b/imperative/python/megengine/functional/loss.py @@ -8,10 +8,9 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np -from ..core.tensor.utils import make_shape_tuple from ..tensor import Tensor -from .elemwise import abs, equal, exp, log, maximum, pow, relu -from .nn import indexing_one_hot, logsigmoid, logsumexp +from .elemwise import abs, log +from .nn import indexing_one_hot, logsigmoid, logsumexp, relu from .tensor import where __all__ = [ diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index f93e5cee..5941db2e 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -7,9 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections -import functools import math -import numbers from typing import Optional, Sequence, Tuple, Union from ..core._imperative_rt.core2 import apply diff --git a/imperative/python/megengine/functional/utils.py b/imperative/python/megengine/functional/metric.py similarity index 69% rename from imperative/python/megengine/functional/utils.py rename to imperative/python/megengine/functional/metric.py index 465ff20d..91a03e1d 100644 --- a/imperative/python/megengine/functional/utils.py +++ b/imperative/python/megengine/functional/metric.py @@ -6,23 +6,14 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import collections from typing import Iterable, Union import numpy as np -from ..core._imperative_rt.core2 import apply -from ..core._wrap import device as as_device -from ..core.ops.builtin import Copy, Identity from ..tensor import Tensor from .math import topk as _topk from .tensor import broadcast_to, transpose -__all__ = [ - "topk_accuracy", - "copy", -] - def topk_accuracy( logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 @@ -46,7 +37,7 @@ def topk_accuracy( logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10)) target = tensor(np.arange(8, dtype=np.int32)) - top1, top5 = F.topk_accuracy(logits, target, (1, 5)) + top1, top5 = F.metric.topk_accuracy(logits, target, (1, 5)) print(top1.numpy(), top5.numpy()) Outputs: @@ -67,33 +58,3 @@ def topk_accuracy( if len(topk) == 1: # type: ignore[arg-type] accs = accs[0] return accs - - -def copy(inp, device=None): - r""" - Copies tensor to another device. - - :param inp: input tensor. - :param device: destination device. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - x = tensor([1, 2, 3], np.int32) - y = F.copy(x, "xpu1") - print(y.numpy()) - - Outputs: - - .. testoutput:: - - [1 2 3] - """ - if device is None: - return apply(Identity(), inp)[0] - return apply(Copy(comp_node=as_device(device).to_c()), inp)[0] diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index dba6ddfb..bc1aac49 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -7,24 +7,25 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # pylint: disable=too-many-lines -from typing import Iterable, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Union -from ..core._imperative_rt import CompNode from ..core._imperative_rt.core2 import apply +from ..core._imperative_rt.graph import VarNode from ..core._trace_option import use_symbolic_shape from ..core.ops import builtin -from ..core.ops.builtin import BatchNorm +from ..core.ops.builtin import BatchNorm, Elemwise from ..core.ops.special import Const -from ..core.tensor import utils -from ..core.tensor.utils import astensor1d, setscalar +from ..core.tensor import megbrain_graph, utils +from ..core.tensor.array_method import _elwise_apply +from ..core.tensor.utils import astensor1d, astype, setscalar +from ..device import get_default_device from ..distributed import WORLD, is_distributed -from ..jit.tracing import is_tracing from ..random import uniform from ..tensor import Tensor from ..utils.tuple_function import _pair, _pair_nonzero -from .debug_param import get_execution_strategy +from .debug_param import get_conv_execution_strategy, get_execution_strategy from .distributed import all_reduce_sum -from .elemwise import exp, floor, log, log1p, maximum, minimum, relu +from .elemwise import exp, floor, log, log1p, maximum, minimum from .math import argsort, matmul, max, prod, sum from .tensor import ( broadcast_to, @@ -47,8 +48,10 @@ __all__ = [ "deformable_conv2d", "deformable_psroi_pooling", "dropout", + "embedding", "indexing_one_hot", "leaky_relu", + "linear", "local_conv2d", "logsigmoid", "logsumexp", @@ -56,12 +59,16 @@ __all__ = [ "max_pool2d", "one_hot", "prelu", - "remap", "softmax", "softplus", - "warp_affine", - "warp_perspective", + "svd", + "sync_batch_norm", "conv1d", + "sigmoid", + "hsigmoid", + "relu", + "relu6", + "hswish", ] @@ -983,79 +990,32 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: return result -def warp_affine( - inp: Tensor, - weight: Tensor, - out_shape, - border_mode="REPLICATE", - border_val=0, - format="NHWC", - imode="LINEAR", -): - """ - Batched affine transform on 2D images. - - :param inp: input image. - :param weight: weight tensor. - :param out_shape: output tensor shape. - :param border_mode: pixel extrapolation method. - Default: "WRAP". Currently "CONSTANT", "REFLECT", - "REFLECT_101", "ISOLATED", "WRAP", "REPLICATE", "TRANSPARENT" are supported. - :param border_val: value used in case of a constant border. Default: 0 - :param format: "NHWC" as default based on historical concerns, - "NCHW" is also supported. Default: "NCHW". - :param imode: interpolation methods. Could be "LINEAR", "NEAREST", "CUBIC", "AREA". - Default: "LINEAR". - :return: output tensor. - - .. note:: - - Here all available options for params are listed, - however it does not mean that you can use all the combinations. - On different platforms, different combinations are supported. +def matmul( + inp1: Tensor, + inp2: Tensor, + transpose_a=False, + transpose_b=False, + compute_mode="DEFAULT", + format="DEFAULT", +) -> Tensor: """ - op = builtin.WarpAffine( - border_mode=border_mode, border_val=border_val, format=format, imode=imode - ) - out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device) - (result,) = apply(op, inp, weight, out_shape) - return result + Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``. + With different inputs dim, this function behaves differently: -def warp_perspective( - inp: Tensor, - M: Tensor, - dsize: Union[Tuple[int, int], int, Tensor], - border_mode: str = "REPLICATE", - border_val: float = 0.0, - interp_mode: str = "LINEAR", -) -> Tensor: - r""" - Applies perspective transformation to batched 2D images. + - Both 1-D tensor, simply forward to ``dot``. + - Both 2-D tensor, normal matrix multiplication. + - If one input tensor is 1-D, matrix vector multiplication. + - If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will + be broadcasted. For example: + - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)` + - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)` + - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)` - The input images are transformed to the output images by the transformation matrix: - - .. math:: - \text{output}(n, c, h, w) = \text{input} \left( n, c, - \frac{M_{00}h + M_{01}w + M_{02}}{M_{20}h + M_{21}w + M_{22}}, - \frac{M_{10}h + M_{11}w + M_{12}}{M_{20}h + M_{21}w + M_{22}} - \right) - - :param inp: input image. - :param M: `(batch, 3, 3)` transformation matrix. - :param dsize: `(h, w)` size of the output image. - :param border_mode: pixel extrapolation method. - Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT", - "REFLECT_101", "WRAP". - :param border_val: value used in case of a constant border. Default: 0 - :param interp_mode: interpolation methods. - Default: "LINEAR". Currently only support "LINEAR" mode. + :param inp1: first matrix to be multiplied. + :param inp2: second matrix to be multiplied. :return: output tensor. - .. note:: - - The transformation matrix is the inverse of that used by `cv2.warpPerspective`. - Examples: .. testcode:: @@ -1064,55 +1024,111 @@ def warp_perspective( from megengine import tensor import megengine.functional as F - inp_shape = (1, 1, 4, 4) - x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) - M_shape = (1, 3, 3) - # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1) - M = tensor(np.array([[1., 0., 1.], - [0., 1., 1.], - [0., 0., 1.]], dtype=np.float32).reshape(M_shape)) - out = F.warp_perspective(x, M, (2, 2)) + data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) + data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2)) + out = F.matmul(data1, data2) print(out.numpy()) Outputs: .. testoutput:: - [[[[ 5. 6.] - [ 9. 10.]]]] + [[10. 13.] + [28. 40.]] """ - op = builtin.WarpPerspective( - imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val - ) - inp, M = utils.convert_inputs(inp, M) - dsize = astensor1d(dsize, inp, dtype="int32", device=inp.device) - (result,) = apply(op, inp, M, dsize) + remove_row, remove_col = False, False + inp1, inp2 = utils.convert_inputs(inp1, inp2) + + dim1, dim2 = inp1.ndim, inp2.ndim + # handle dim=1 cases, dot and matrix-vector multiplication + if dim1 == 1 and dim2 == 1: + return dot(inp1, inp2) + # the underlying matmul op requires input dims to be at least 2 + if dim1 == 1: + inp1 = expand_dims(inp1, 0) + dim1 = 2 + remove_row = True + if dim2 == 1: + inp2 = expand_dims(inp2, 1) + dim2 = 2 + remove_col = True + + batch_shape = None + shape1 = inp1.shape + shape2 = inp2.shape + + maxdim = dim1 if dim1 > dim2 else dim2 + if dim1 >= 3 or dim2 >= 3: + if use_symbolic_shape(): + if dim1 > dim2: + shape2 = concat([shape1[:-2], shape2[-2:]]) + inp2 = broadcast_to(inp2, shape2) + if dim1 < dim2: + shape1 = concat([shape2[:-2], shape1[-2:]]) + inp1 = broadcast_to(inp1, shape1) + if maxdim > 3: + batch_shape = shape1[:-2] + # compress inputs to 3d + (inp1,) = apply( + builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]]) + ) + (inp2,) = apply( + builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]]) + ) + else: + if dim1 > dim2: + shape2 = shape1[:-2] + shape2[-2:] + inp2 = broadcast_to(inp2, shape2) + if dim1 < dim2: + shape1 = shape2[:-2] + shape1[-2:] + inp1 = broadcast_to(inp1, shape1) + if maxdim > 3: + batch_shape = shape1[:-2] + # compress inputs to 3d + inp1 = inp1.reshape((-1, shape1[-2], shape1[-1])) + inp2 = inp2.reshape((-1, shape2[-2], shape2[-1])) + + op = builtin.BatchedMatrixMul( + transposeA=transpose_a, + transposeB=transpose_b, + compute_mode=compute_mode, + format=format, + strategy=get_conv_execution_strategy(), + ) + else: + op = builtin.MatrixMul( + transposeA=transpose_a, + transposeB=transpose_b, + compute_mode=compute_mode, + format=format, + strategy=get_conv_execution_strategy(), + ) + + (result,) = apply(op, inp1, inp2) + if maxdim > 3: + if use_symbolic_shape(): + (result,) = apply( + builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]]) + ) + else: + result = result.reshape(batch_shape + result.shape[-2:]) + if remove_row: + result = squeeze(result, axis=-2) + if remove_col: + result = squeeze(result, axis=-1) return result -def remap( - inp: Tensor, - map_xy: Tensor, - border_mode: str = "REPLICATE", - scalar: float = 0.0, - interp_mode: str = "LINEAR", -) -> Tensor: - r""" - Applies remap transformation to batched 2D images. - - The input images are transformed to the output images by the tensor map_xy. - The output's H and W are same as map_xy's H and W. - - :param inp: input image - :param map_xy: (batch, oh, ow, 2) transformation matrix - :param border_mode: pixel extrapolation method. - Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT", - "REFLECT_101", "WRAP". - :param scalar: value used in case of a constant border. Default: 0 - :param interp_mode: interpolation methods. - Default: "LINEAR". Currently only support "LINEAR" mode. - :return: output tensor. +def dot(inp1: Tensor, inp2: Tensor) -> Tensor: + """ + Computes dot-product of two vectors ``inp1`` and ``inp2``. + inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted. + Refer to :func:`~.matmul` for more general usage. + + :param inp1: first vector. + :param inp2: second vector. + :return: output value. Examples: @@ -1121,56 +1137,35 @@ def remap( import numpy as np from megengine import tensor import megengine.functional as F - inp_shape = (1, 1, 4, 4) - inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) - map_xy_shape = (1, 2, 2, 2) - map_xy = tensor(np.array([[[1., 0.],[0., 1.]], - [[0., 1.],[0., 1.]]], - dtype=np.float32).reshape(map_xy_shape)) - out = F.remap(inp, map_xy) + + data1 = tensor(np.arange(0, 6, dtype=np.float32)) + data2 = tensor(np.arange(0, 6, dtype=np.float32)) + out = F.dot(data1, data2) print(out.numpy()) Outputs: .. testoutput:: - [[[[1. 4.] - [4. 4.]]]] + 55. """ - - op = builtin.Remap( - imode=interp_mode, border_type=border_mode, format="NCHW", scalar=scalar - ) - (result,) = apply(op, inp, map_xy) + op = builtin.Dot() + inp1, inp2 = utils.convert_inputs(inp1, inp2) + assert ( + inp1.ndim <= 1 and inp2.ndim <= 1 + ), "Input tensors for dot must be 1-dimensional or scalar" + (result,) = apply(op, inp1, inp2) + setscalar(result) return result -def interpolate( - inp: Tensor, - size: Optional[Union[int, Tuple[int, int]]] = None, - scale_factor: Optional[Union[float, Tuple[float, float]]] = None, - mode: str = "BILINEAR", - align_corners: Optional[bool] = None, -) -> Tensor: - r""" - Down/up samples the input tensor to either the given size or with the given scale_factor. ``size`` can not coexist with ``scale_factor``. +def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: + """ + Computes the singular value decompositions of input matrix. - :param inp: input tensor. - :param size: size of the output tensor. Default: None - :param scale_factor: scaling factor of the output tensor. Default: None - :param mode: interpolation methods, acceptable values are: - "BILINEAR", "LINEAR". Default: "BILINEAR" - :param align_corners: This only has an effect when `mode` - is "BILINEAR" or "LINEAR". Geometrically, we consider the pixels of the input - and output as squares rather than points. If set to ``True``, the input - and output tensors are aligned by the center points of their corner - pixels, preserving the values at the corner pixels. If set to ``False``, - the input and output tensors are aligned by the corner points of their - corner pixels, and the interpolation uses edge value padding for - out-of-boundary values, making this operation *independent* of input size - when `scale_factor` is kept the same. Default: None - :return: output tensor. + :param inp: input matrix, must has shape `[..., M, N]`. + :return: output matrices, `(U, sigma, V)`. Examples: @@ -1180,141 +1175,20 @@ def interpolate( from megengine import tensor import megengine.functional as F - x = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) - out = F.nn.interpolate(x, [4, 4], align_corners=False) - print(out.numpy()) - out2 = F.nn.interpolate(x, scale_factor=2.) - np.testing.assert_allclose(out.numpy(), out2.numpy()) + x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3)) + _, y, _ = F.svd(x) + print(y.numpy().round(decimals=3)) Outputs: .. testoutput:: - [[[[1. 1.25 1.75 2. ] - [1.5 1.75 2.25 2.5 ] - [2.5 2.75 3.25 3.5 ] - [3. 3.25 3.75 4. ]]]] + [7.348 1. ] """ - mode = mode.upper() - if mode not in ["BILINEAR", "LINEAR"]: - raise ValueError("interpolate only support linear or bilinear mode") - if mode not in ["BILINEAR", "LINEAR"]: - if align_corners is not None: - raise ValueError( - "align_corners option can only be set in the bilinear/linear interpolating mode" - ) - else: - if align_corners is None: - align_corners = False - - if ( - size is not None - and scale_factor is None - and not align_corners - and mode == "BILINEAR" - and inp.ndim in [4, 5] - ): - # fastpath for interpolate - op = builtin.Resize(imode="LINEAR", format="NCHW") - shape = astensor1d(size, inp, dtype="int32", device=inp.device) - (result,) = apply(op, inp, shape) - return result - - if mode == "LINEAR": - inp = expand_dims(inp, 3) - - if inp.ndim != 4: - raise ValueError("shape of input tensor must correspond to the operartion mode") - - if size is None: - if scale_factor is None: - raise ValueError("scale_factor must not be None when size is None") - - if isinstance(scale_factor, (float, int)): - scale_factor = float(scale_factor) - if mode == "LINEAR": - scale_factor = (scale_factor, float(1)) - else: - scale_factor = (scale_factor, scale_factor) - else: - if mode == "LINEAR": - raise ValueError( - "under LINEAR mode, scale_factor can only be single value" - ) - - assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )" - assert isinstance(scale_factor[0], float) and isinstance( - scale_factor[1], float - ), "scale_factor must be float type" - dsize = tuple( - floor( - Tensor( - inp.shape[i + 2] * scale_factor[i], - dtype="float32", - device=inp.device, - ) - ) - for i in range(2) - ) - dsize = concat([dsize[0], dsize[1]], axis=0) - else: - if scale_factor is not None: - raise ValueError("scale_factor must be None when size is provided") - - if isinstance(size, int): - size = (size, 1) - else: - if mode == "LINEAR": - raise ValueError("under LINEAR mode, size can only be single value") - dsize = size - - oh, ow = dsize[0], dsize[1] - ih, iw = inp.shape[2], inp.shape[3] - - if align_corners: - hscale = (ih - 1.0) / (oh - 1.0) - wscale = 1.0 * iw / ow - if mode != "LINEAR": - wscale = (iw - 1.0) / (ow - 1.0) - row0 = concat( - [wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0 - ).reshape(1, 3) - row1 = concat( - [ - Tensor(0, dtype="float32", device=inp.device), - hscale, - Tensor(0, dtype="float32", device=inp.device), - ], - axis=0, - ).reshape(1, 3) - weight = concat( - [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)], - axis=0, - ).reshape(1, 3, 3) - weight = broadcast_to(weight, (inp.shape[0], 3, 3)) - else: - hscale = 1.0 * ih / oh - wscale = 1.0 * iw / ow - row0 = concat( - [wscale, Tensor(0, dtype="float32", device=inp.device), 0.5 * wscale - 0.5], - axis=0, - ).reshape(1, 3) - row1 = concat( - [Tensor(0, dtype="float32", device=inp.device), hscale, 0.5 * hscale - 0.5], - axis=0, - ).reshape(1, 3) - weight = concat( - [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)], - axis=0, - ).reshape(1, 3, 3) - weight = broadcast_to(weight, (inp.shape[0], 3, 3)) - - weight = weight.astype("float32") - ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR") - if mode == "LINEAR": - ret = reshape(ret, ret.shape[0:3]) - return ret + op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) + U, sigma, V = apply(op, inp) + return U, sigma, V def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: @@ -1385,127 +1259,6 @@ def embedding( return weight[inp.reshape(-1)].reshape(dest_shp) -def roi_pooling( - inp: Tensor, - rois: Tensor, - output_shape: Union[int, tuple, list], - mode: str = "max", - scale: float = 1.0, -) -> Tensor: - """ - Applies roi pooling on input feature. - - :param inp: tensor that represents the input feature, `(N, C, H, W)` images. - :param rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy. - :param output_shape: `(height, width)` of output rois feature. - :param mode: "max" or "average", use max/average align just like max/average pooling. Default: "max" - :param scale: scale the input boxes by this number. Default: 1.0 - :return: `(K, C, output_shape[0], output_shape[1])` feature of rois. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - np.random.seed(42) - inp = tensor(np.random.randn(1, 1, 128, 128)) - rois = tensor(np.random.random((4, 5))) - y = F.nn.roi_pooling(inp, rois, (2, 2)) - print(y.numpy()[0].round(decimals=4)) - - Outputs: - - .. testoutput:: - - [[[-0.1383 -0.1383] - [-0.5035 -0.5035]]] - - - """ - assert mode in ["max", "average"], "only max/average mode is supported" - if isinstance(output_shape, int): - output_shape = (output_shape, output_shape) - - op = builtin.ROIPooling(mode=mode, scale=scale) - inp, rois = utils.convert_inputs(inp, rois) - result, _ = apply( - op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device) - ) - return result - - -def roi_align( - inp: Tensor, - rois: Tensor, - output_shape: Union[int, tuple, list], - mode: str = "average", - spatial_scale: float = 1.0, - sample_points: Union[int, tuple, list] = 2, - aligned: bool = True, -) -> Tensor: - """ - Applies roi align on input feature. - - :param inp: tensor that represents the input feature, shape is `(N, C, H, W)`. - :param rois: `(N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``. - :param output_shape: `(height, width)` shape of output rois feature. - :param mode: "max" or "average", use max/average align just like max/average pooling. Default: "average" - :param spatial_scale: scale the input boxes by this number. Default: 1.0 - :param sample_points: number of inputs samples to take for each output sample. - 0 to take samples densely. Default: 2 - :param aligned: wheather to align the input feature, with `aligned=True`, - we first appropriately scale the ROI and then shift it by -0.5. Default: True - :return: output tensor. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - np.random.seed(42) - inp = tensor(np.random.randn(1, 1, 128, 128)) - rois = tensor(np.random.random((4, 5))) - y = F.nn.roi_align(inp, rois, (2, 2)) - print(y.numpy()[0].round(decimals=4)) - - Outputs: - - .. testoutput:: - - [[[0.175 0.175 ] - [0.1359 0.1359]]] - - """ - assert mode in ["max", "average"], "only max/average mode is supported" - if isinstance(output_shape, int): - output_shape = (output_shape, output_shape) - pooled_height, pooled_width = output_shape - if isinstance(sample_points, int): - sample_points = (sample_points, sample_points) - sample_height, sample_width = sample_points - offset = 0.5 if aligned else 0.0 - - op = builtin.ROIAlign( - mode=mode, - format="NCHW", - spatial_scale=spatial_scale, - offset=offset, - pooled_height=pooled_height, - pooled_width=pooled_width, - sample_height=sample_height, - sample_width=sample_width, - ) - inp, rois = utils.convert_inputs(inp, rois) - result, *_ = apply(op, inp, rois) - return result - - def indexing_one_hot( src: Tensor, index: Tensor, axis: int = 1, keepdims=False ) -> Tensor: @@ -1621,72 +1374,6 @@ def conv1d( return output -def nms( - boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None -) -> Tensor: - r""" - Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU). - - :param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format. - :param iou_thresh: IoU threshold for overlapping. - :param scores: tensor of shape `(N,)`, the score of boxes. - :param max_output: the maximum number of boxes to keep; it is optional if this operator is not traced - otherwise it required to be specified; if it is not specified, all boxes are kept. - :return: indices of the elements that have been kept by NMS. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - x = np.zeros((100,4)) - np.random.seed(42) - x[:,:2] = np.random.rand(100,2)*20 - x[:,2:] = np.random.rand(100,2)*20 + 100 - scores = tensor(np.random.rand(100)) - inp = tensor(x) - result = F.nn.nms(inp, scores, iou_thresh=0.7) - print(result.numpy()) - - Outputs: - - .. testoutput:: - - [75 69] - - """ - assert ( - boxes.ndim == 2 and boxes.shape[1] == 4 - ), "the expected shape of boxes is (N, 4)" - assert scores.ndim == 1, "the expected shape of scores is (N,)" - assert ( - boxes.shape[0] == scores.shape[0] - ), "number of boxes and scores are not matched" - - boxes = boxes.detach() - scores = scores.detach() - sorted_idx = argsort(scores, descending=True) - boxes = boxes[sorted_idx] - - if is_tracing(): - assert ( - max_output is not None and max_output > 0 - ), "max_output should be specified under tracing" - - if max_output is None: - max_output = boxes.shape[0] - - op = builtin.NMSKeep(iou_thresh, max_output) - inp = utils.convert_inputs(boxes.reshape(1, -1, 4)) - indices, count = apply(op, *inp) - indices = indices[0][: count[0]] - keep_inds = sorted_idx[indices] - return keep_inds - - def nvof(src: Tensor, precision: int = 1) -> Tensor: r""" Implements NVIDIA Optical Flow SDK. @@ -1717,5 +1404,89 @@ def nvof(src: Tensor, precision: int = 1) -> Tensor: return apply(op, src)[0] +def _elwise(*args, mode): + tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) + if len(tensor_args) == 0: + dtype = utils.dtype_promotion(args) + first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) + args = utils.convert_inputs(first_arg, *args[1:]) + else: + args = utils.convert_inputs(*args) + if mode in ( + Elemwise.Mode.TRUE_DIV, + Elemwise.Mode.EXP, + Elemwise.Mode.POW, + Elemwise.Mode.LOG, + Elemwise.Mode.EXPM1, + Elemwise.Mode.LOG1P, + Elemwise.Mode.TANH, + Elemwise.Mode.ACOS, + Elemwise.Mode.ASIN, + Elemwise.Mode.ATAN2, + Elemwise.Mode.CEIL, + Elemwise.Mode.COS, + Elemwise.Mode.FLOOR, + Elemwise.Mode.H_SWISH, + Elemwise.Mode.ROUND, + Elemwise.Mode.SIGMOID, + Elemwise.Mode.SIN, + ): + if mode in ( + Elemwise.Mode.CEIL, + Elemwise.Mode.FLOOR, + Elemwise.Mode.ROUND, + ) and np.issubdtype(args[0].dtype, np.integer): + return args[0] + args = tuple(map(lambda x: astype(x, "float32"), args)) + return _elwise_apply(args, mode) + + +def hswish(x): + """ + Element-wise `x * relu6(x + 3) / 6`. + + :param x: input tensor. + :return: computed tensor. + + Example: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + + x = tensor(np.arange(5).astype(np.float32)) + out = F.hswish(x) + print(out.numpy().round(decimals=4)) + + .. testoutput:: + + [0. 0.6667 1.6667 3. 4. ] + + """ + return _elwise(x, mode=Elemwise.Mode.H_SWISH) + + +def sigmoid(x): + """Element-wise `1 / ( 1 + exp( -x ) )`.""" + return _elwise(x, mode=Elemwise.Mode.SIGMOID) + + +def hsigmoid(x): + """Element-wise `relu6(x + 3) / 6`.""" + return relu6(x + 3) / 6 + + +def relu(x): + """Element-wise `max(x, 0)`.""" + return _elwise(x, mode=Elemwise.Mode.RELU) + + +def relu6(x): + """Element-wise `min(max(x, 0), 6)`.""" + return minimum(maximum(x, 0), 6) + + from .loss import * # isort:skip from .quantized import conv_bias_activation # isort:skip diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 6f7c63e2..c67b4428 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -6,10 +6,8 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import functools import math -from itertools import accumulate -from typing import Iterable, List, Optional, Sequence, Tuple, Union +from typing import Iterable, Optional, Sequence, Union import numpy as np @@ -17,6 +15,7 @@ from ..core._imperative_rt import CompNode from ..core._imperative_rt.core2 import apply from ..core._wrap import device as as_device from ..core.ops import builtin +from ..core.ops.builtin import Copy, Identity from ..core.ops.special import Const from ..core.tensor.array_method import _broadcast, _remove_axis from ..core.tensor.utils import ( @@ -51,6 +50,7 @@ __all__ = [ "stack", "scatter", "tile", + "copy", "transpose", "where", "zeros", @@ -1130,3 +1130,33 @@ def tile(inp: Tensor, reps: Iterable[int]): inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape) return inp + + +def copy(inp, device=None): + r""" + Copies tensor to another device. + + :param inp: input tensor. + :param device: destination device. + + Examples: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + + x = tensor([1, 2, 3], np.int32) + y = F.copy(x, "xpu1") + print(y.numpy()) + + Outputs: + + .. testoutput:: + + [1 2 3] + """ + if device is None: + return apply(Identity(), inp)[0] + return apply(Copy(comp_node=as_device(device).to_c()), inp)[0] diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py new file mode 100644 index 00000000..4a367248 --- /dev/null +++ b/imperative/python/megengine/functional/vision.py @@ -0,0 +1,576 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from typing import Iterable, Optional, Tuple, Union + +from ..core._imperative_rt.core2 import apply +from ..core.ops import builtin +from ..core.tensor import megbrain_graph, utils +from ..core.tensor.utils import astensor1d +from ..jit.tracing import is_tracing +from ..tensor import Tensor +from .elemwise import floor +from .math import argsort +from .tensor import broadcast_to, concat, expand_dims, reshape + + +def cvt_color(inp: Tensor, mode: str = ""): + r""" + Convert images from one format to another + + :param inp: input images. + :param mode: format mode. + :return: convert result. + + Examples: + + .. testcode:: + + import numpy as np + import megengine as mge + import megengine.functional as F + + x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32)) + y = F.vision.cvt_color(x, mode="RGB2GRAY") + print(y.numpy()) + + Outputs: + + .. testoutput:: + + [[[[0.86555195]]]] + + """ + assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color" + mode = getattr(builtin.CvtColor.Mode, mode) + assert isinstance(mode, builtin.CvtColor.Mode) + op = builtin.CvtColor(mode=mode) + (out,) = apply(op, inp) + return out + + +def roi_pooling( + inp: Tensor, + rois: Tensor, + output_shape: Union[int, tuple, list], + mode: str = "max", + scale: float = 1.0, +) -> Tensor: + """ + Applies roi pooling on input feature. + + :param inp: tensor that represents the input feature, `(N, C, H, W)` images. + :param rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy. + :param output_shape: `(height, width)` of output rois feature. + :param mode: "max" or "average", use max/average align just like max/average pooling. Default: "max" + :param scale: scale the input boxes by this number. Default: 1.0 + :return: `(K, C, output_shape[0], output_shape[1])` feature of rois. + + Examples: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + + np.random.seed(42) + inp = tensor(np.random.randn(1, 1, 128, 128)) + rois = tensor(np.random.random((4, 5))) + y = F.vision.roi_pooling(inp, rois, (2, 2)) + print(y.numpy()[0].round(decimals=4)) + + Outputs: + + .. testoutput:: + + [[[-0.1383 -0.1383] + [-0.5035 -0.5035]]] + + + """ + assert mode in ["max", "average"], "only max/average mode is supported" + if isinstance(output_shape, int): + output_shape = (output_shape, output_shape) + + op = builtin.ROIPooling(mode=mode, scale=scale) + inp, rois = utils.convert_inputs(inp, rois) + result, _ = apply( + op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device) + ) + return result + + +def roi_align( + inp: Tensor, + rois: Tensor, + output_shape: Union[int, tuple, list], + mode: str = "average", + spatial_scale: float = 1.0, + sample_points: Union[int, tuple, list] = 2, + aligned: bool = True, +) -> Tensor: + """ + Applies roi align on input feature. + + :param inp: tensor that represents the input feature, shape is `(N, C, H, W)`. + :param rois: `(N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``. + :param output_shape: `(height, width)` shape of output rois feature. + :param mode: "max" or "average", use max/average align just like max/average pooling. Default: "average" + :param spatial_scale: scale the input boxes by this number. Default: 1.0 + :param sample_points: number of inputs samples to take for each output sample. + 0 to take samples densely. Default: 2 + :param aligned: wheather to align the input feature, with `aligned=True`, + we first appropriately scale the ROI and then shift it by -0.5. Default: True + :return: output tensor. + + Examples: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + + np.random.seed(42) + inp = tensor(np.random.randn(1, 1, 128, 128)) + rois = tensor(np.random.random((4, 5))) + y = F.vision.roi_align(inp, rois, (2, 2)) + print(y.numpy()[0].round(decimals=4)) + + Outputs: + + .. testoutput:: + + [[[0.175 0.175 ] + [0.1359 0.1359]]] + + """ + assert mode in ["max", "average"], "only max/average mode is supported" + if isinstance(output_shape, int): + output_shape = (output_shape, output_shape) + pooled_height, pooled_width = output_shape + if isinstance(sample_points, int): + sample_points = (sample_points, sample_points) + sample_height, sample_width = sample_points + offset = 0.5 if aligned else 0.0 + + op = builtin.ROIAlign( + mode=mode, + format="NCHW", + spatial_scale=spatial_scale, + offset=offset, + pooled_height=pooled_height, + pooled_width=pooled_width, + sample_height=sample_height, + sample_width=sample_width, + ) + inp, rois = utils.convert_inputs(inp, rois) + result, *_ = apply(op, inp, rois) + return result + + +def nms( + boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None +) -> Tensor: + r""" + Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU). + + :param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format. + :param iou_thresh: IoU threshold for overlapping. + :param scores: tensor of shape `(N,)`, the score of boxes. + :param max_output: the maximum number of boxes to keep; it is optional if this operator is not traced + otherwise it required to be specified; if it is not specified, all boxes are kept. + :return: indices of the elements that have been kept by NMS. + + Examples: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + + x = np.zeros((100,4)) + np.random.seed(42) + x[:,:2] = np.random.rand(100,2)*20 + x[:,2:] = np.random.rand(100,2)*20 + 100 + scores = tensor(np.random.rand(100)) + inp = tensor(x) + result = F.vision.nms(inp, scores, iou_thresh=0.7) + print(result.numpy()) + + Outputs: + + .. testoutput:: + + [75 69] + + """ + assert ( + boxes.ndim == 2 and boxes.shape[1] == 4 + ), "the expected shape of boxes is (N, 4)" + assert scores.ndim == 1, "the expected shape of scores is (N,)" + assert ( + boxes.shape[0] == scores.shape[0] + ), "number of boxes and scores are not matched" + + boxes = boxes.detach() + scores = scores.detach() + sorted_idx = argsort(scores, descending=True) + boxes = boxes[sorted_idx] + + if is_tracing(): + assert ( + max_output is not None and max_output > 0 + ), "max_output should be specified under tracing" + + if max_output is None: + max_output = boxes.shape[0] + + op = builtin.NMSKeep(iou_thresh, max_output) + inp = utils.convert_inputs(boxes.reshape(1, -1, 4)) + indices, count = apply(op, *inp) + indices = indices[0][: count[0]] + keep_inds = sorted_idx[indices] + return keep_inds + + +def remap( + inp: Tensor, + map_xy: Tensor, + border_mode: str = "REPLICATE", + scalar: float = 0.0, + interp_mode: str = "LINEAR", +) -> Tensor: + r""" + Applies remap transformation to batched 2D images. + + The input images are transformed to the output images by the tensor map_xy. + The output's H and W are same as map_xy's H and W. + + :param inp: input image + :param map_xy: (batch, oh, ow, 2) transformation matrix + :param border_mode: pixel extrapolation method. + Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT", + "REFLECT_101", "WRAP". + :param scalar: value used in case of a constant border. Default: 0 + :param interp_mode: interpolation methods. + Default: "LINEAR". Currently only support "LINEAR" mode. + :return: output tensor. + + Examples: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + inp_shape = (1, 1, 4, 4) + inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) + map_xy_shape = (1, 2, 2, 2) + map_xy = tensor(np.array([[[1., 0.],[0., 1.]], + [[0., 1.],[0., 1.]]], + dtype=np.float32).reshape(map_xy_shape)) + out = F.vision.remap(inp, map_xy) + print(out.numpy()) + + Outputs: + + .. testoutput:: + + [[[[1. 4.] + [4. 4.]]]] + + """ + + op = builtin.Remap( + imode=interp_mode, border_type=border_mode, format="NCHW", scalar=scalar + ) + assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type" + (result,) = apply(op, inp, map_xy) + return result + + +def warp_affine( + inp: Tensor, + weight: Tensor, + out_shape, + border_mode="REPLICATE", + border_val=0, + format="NHWC", + imode="LINEAR", +): + """ + Batched affine transform on 2D images. + + :param inp: input image. + :param weight: weight tensor. + :param out_shape: output tensor shape. + :param border_mode: pixel extrapolation method. + Default: "WRAP". Currently "CONSTANT", "REFLECT", + "REFLECT_101", "ISOLATED", "WRAP", "REPLICATE", "TRANSPARENT" are supported. + :param border_val: value used in case of a constant border. Default: 0 + :param format: "NHWC" as default based on historical concerns, + "NCHW" is also supported. Default: "NCHW". + :param imode: interpolation methods. Could be "LINEAR", "NEAREST", "CUBIC", "AREA". + Default: "LINEAR". + :return: output tensor. + + .. note:: + + Here all available options for params are listed, + however it does not mean that you can use all the combinations. + On different platforms, different combinations are supported. + """ + op = builtin.WarpAffine( + border_mode=border_mode, border_val=border_val, format=format, imode=imode + ) + out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device) + (result,) = apply(op, inp, weight, out_shape) + return result + + +def warp_perspective( + inp: Tensor, + M: Tensor, + dsize: Union[Tuple[int, int], int, Tensor], + border_mode: str = "REPLICATE", + border_val: float = 0.0, + interp_mode: str = "LINEAR", +) -> Tensor: + r""" + Applies perspective transformation to batched 2D images. + + The input images are transformed to the output images by the transformation matrix: + + .. math:: + \text{output}(n, c, h, w) = \text{input} \left( n, c, + \frac{M_{00}h + M_{01}w + M_{02}}{M_{20}h + M_{21}w + M_{22}}, + \frac{M_{10}h + M_{11}w + M_{12}}{M_{20}h + M_{21}w + M_{22}} + \right) + + :param inp: input image. + :param M: `(batch, 3, 3)` transformation matrix. + :param dsize: `(h, w)` size of the output image. + :param border_mode: pixel extrapolation method. + Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT", + "REFLECT_101", "WRAP". + :param border_val: value used in case of a constant border. Default: 0 + :param interp_mode: interpolation methods. + Default: "LINEAR". Currently only support "LINEAR" mode. + :return: output tensor. + + Note: + + The transformation matrix is the inverse of that used by `cv2.warpPerspective`. + + Examples: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + + inp_shape = (1, 1, 4, 4) + x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) + M_shape = (1, 3, 3) + # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1) + M = tensor(np.array([[1., 0., 1.], + [0., 1., 1.], + [0., 0., 1.]], dtype=np.float32).reshape(M_shape)) + out = F.vision.warp_perspective(x, M, (2, 2)) + print(out.numpy()) + + Outputs: + + .. testoutput:: + + [[[[ 5. 6.] + [ 9. 10.]]]] + + """ + op = builtin.WarpPerspective( + imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val + ) + inp, M = utils.convert_inputs(inp, M) + dsize = astensor1d(dsize, inp, dtype="int32", device=inp.device) + (result,) = apply(op, inp, M, dsize) + return result + + +def interpolate( + inp: Tensor, + size: Optional[Union[int, Tuple[int, int]]] = None, + scale_factor: Optional[Union[float, Tuple[float, float]]] = None, + mode: str = "BILINEAR", + align_corners: Optional[bool] = None, +) -> Tensor: + r""" + Down/up samples the input tensor to either the given size or with the given scale_factor. ``size`` can not coexist with ``scale_factor``. + + :param inp: input tensor. + :param size: size of the output tensor. Default: None + :param scale_factor: scaling factor of the output tensor. Default: None + :param mode: interpolation methods, acceptable values are: + "BILINEAR", "LINEAR". Default: "BILINEAR" + :param align_corners: This only has an effect when `mode` + is "BILINEAR" or "LINEAR". Geometrically, we consider the pixels of the input + and output as squares rather than points. If set to ``True``, the input + and output tensors are aligned by the center points of their corner + pixels, preserving the values at the corner pixels. If set to ``False``, + the input and output tensors are aligned by the corner points of their + corner pixels, and the interpolation uses edge value padding for + out-of-boundary values, making this operation *independent* of input size + + :return: output tensor. + + Examples: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + + x = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) + out = F.vision.interpolate(x, [4, 4], align_corners=False) + print(out.numpy()) + out2 = F.vision.interpolate(x, scale_factor=2.) + np.testing.assert_allclose(out.numpy(), out2.numpy()) + + Outputs: + + .. testoutput:: + + [[[[1. 1.25 1.75 2. ] + [1.5 1.75 2.25 2.5 ] + [2.5 2.75 3.25 3.5 ] + [3. 3.25 3.75 4. ]]]] + + """ + mode = mode.upper() + if mode not in ["BILINEAR", "LINEAR"]: + raise ValueError("interpolate only support linear or bilinear mode") + if mode not in ["BILINEAR", "LINEAR"]: + if align_corners is not None: + raise ValueError( + "align_corners option can only be set in the bilinear/linear interpolating mode" + ) + else: + if align_corners is None: + align_corners = False + + if ( + size is not None + and scale_factor is None + and not align_corners + and mode == "BILINEAR" + and inp.ndim in [4, 5] + ): + # fastpath for interpolate + op = builtin.Resize(imode="LINEAR", format="NCHW") + shape = astensor1d(size, inp, dtype="int32", device=inp.device) + (result,) = apply(op, inp, shape) + return result + + if mode == "LINEAR": + inp = expand_dims(inp, 3) + + if inp.ndim != 4: + raise ValueError("shape of input tensor must correspond to the operartion mode") + + if size is None: + if scale_factor is None: + raise ValueError("scale_factor must not be None when size is None") + + if isinstance(scale_factor, (float, int)): + scale_factor = float(scale_factor) + if mode == "LINEAR": + scale_factor = (scale_factor, float(1)) + else: + scale_factor = (scale_factor, scale_factor) + else: + if mode == "LINEAR": + raise ValueError( + "under LINEAR mode, scale_factor can only be single value" + ) + + assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )" + assert isinstance(scale_factor[0], float) and isinstance( + scale_factor[1], float + ), "scale_factor must be float type" + dsize = tuple( + floor( + Tensor( + inp.shape[i + 2] * scale_factor[i], + dtype="float32", + device=inp.device, + ) + ) + for i in range(2) + ) + dsize = concat([dsize[0], dsize[1]], axis=0) + else: + if scale_factor is not None: + raise ValueError("scale_factor must be None when size is provided") + + if isinstance(size, int): + size = (size, 1) + else: + if mode == "LINEAR": + raise ValueError("under LINEAR mode, size can only be single value") + dsize = size + + oh, ow = dsize[0], dsize[1] + ih, iw = inp.shape[2], inp.shape[3] + + if align_corners: + hscale = (ih - 1.0) / (oh - 1.0) + wscale = 1.0 * iw / ow + if mode != "LINEAR": + wscale = (iw - 1.0) / (ow - 1.0) + row0 = concat( + [wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0 + ).reshape(1, 3) + row1 = concat( + [ + Tensor(0, dtype="float32", device=inp.device), + hscale, + Tensor(0, dtype="float32", device=inp.device), + ], + axis=0, + ).reshape(1, 3) + weight = concat( + [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)], + axis=0, + ).reshape(1, 3, 3) + weight = broadcast_to(weight, (inp.shape[0], 3, 3)) + else: + hscale = 1.0 * ih / oh + wscale = 1.0 * iw / ow + row0 = concat( + [wscale, Tensor(0, dtype="float32", device=inp.device), 0.5 * wscale - 0.5], + axis=0, + ).reshape(1, 3) + row1 = concat( + [Tensor(0, dtype="float32", device=inp.device), hscale, 0.5 * hscale - 0.5], + axis=0, + ).reshape(1, 3) + weight = concat( + [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)], + axis=0, + ).reshape(1, 3, 3) + weight = broadcast_to(weight, (inp.shape[0], 3, 3)) + + weight = weight.astype("float32") + ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR") + if mode == "LINEAR": + ret = reshape(ret, ret.shape[0:3]) + return ret diff --git a/imperative/python/megengine/module/identity.py b/imperative/python/megengine/module/identity.py index cdeb0e5b..a62d1ed9 100644 --- a/imperative/python/megengine/module/identity.py +++ b/imperative/python/megengine/module/identity.py @@ -6,7 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ..functional import copy +from ..functional.tensor import copy from .module import Module diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index be1f543e..b95359b5 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -372,7 +372,7 @@ def test_interpolate_fastpath(): x = mge.Tensor(x_np) grad = Grad().wrt(x, callback=save_to(x)) - y = F.nn.interpolate(x, size=(16, 16), mode="BILINEAR") + y = F.vision.interpolate(x, size=(16, 16), mode="BILINEAR") grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 13a80d24..84ce29bc 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -136,8 +136,8 @@ def test_interpolate(): def linear_interpolate(): inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) - out = F.nn.interpolate(inp, scale_factor=2.0, mode="LINEAR") - out2 = F.nn.interpolate(inp, 4, mode="LINEAR") + out = F.vision.interpolate(inp, scale_factor=2.0, mode="LINEAR") + out2 = F.vision.interpolate(inp, 4, mode="LINEAR") np.testing.assert_allclose( out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32) @@ -149,16 +149,16 @@ def test_interpolate(): def many_batch_interpolate(): inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2)) - out = F.nn.interpolate(inp, [4, 4]) - out2 = F.nn.interpolate(inp, scale_factor=2.0) + out = F.vision.interpolate(inp, [4, 4]) + out2 = F.vision.interpolate(inp, scale_factor=2.0) np.testing.assert_allclose(out.numpy(), out2.numpy()) def assign_corner_interpolate(): inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) - out = F.nn.interpolate(inp, [4, 4], align_corners=True) - out2 = F.nn.interpolate(inp, scale_factor=2.0, align_corners=True) + out = F.vision.interpolate(inp, [4, 4], align_corners=True) + out2 = F.vision.interpolate(inp, scale_factor=2.0, align_corners=True) np.testing.assert_allclose(out.numpy(), out2.numpy()) @@ -166,13 +166,13 @@ def test_interpolate(): inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) with pytest.raises(ValueError): - F.nn.interpolate(inp, scale_factor=2.0, mode="LINEAR") + F.vision.interpolate(inp, scale_factor=2.0, mode="LINEAR") def inappropriate_scale_linear_interpolate(): inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) with pytest.raises(ValueError): - F.nn.interpolate(inp, scale_factor=[2.0, 3.0], mode="LINEAR") + F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="LINEAR") linear_interpolate() many_batch_interpolate() @@ -205,7 +205,7 @@ def test_roi_align(): grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) output_shape = (7, 7) - out_feat = F.nn.roi_align( + out_feat = F.vision.roi_align( inp_feat, rois, output_shape=output_shape, @@ -228,7 +228,7 @@ def test_roi_pooling(): inp_feat, rois = _gen_roi_inp() grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) output_shape = (7, 7) - out_feat = F.nn.roi_pooling( + out_feat = F.vision.roi_pooling( inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, ) assert make_shape_tuple(out_feat.shape) == ( @@ -335,18 +335,18 @@ def test_interpolate_fastpath(): ] for inp_shape, target_shape in test_cases: x = tensor(np.random.randn(*inp_shape), dtype=np.float32) - out = F.nn.interpolate(x, target_shape, mode="BILINEAR") + out = F.vision.interpolate(x, target_shape, mode="BILINEAR") assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1] assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1] # check value x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32) - out = F.nn.interpolate(x, (15, 5), mode="BILINEAR") + out = F.vision.interpolate(x, (15, 5), mode="BILINEAR") np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32)) np_x = np.arange(32) x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1) - out = F.nn.interpolate(x, (1, 1), mode="BILINEAR") + out = F.vision.interpolate(x, (1, 1), mode="BILINEAR") np.testing.assert_equal(out.item(), np_x.mean()) @@ -360,7 +360,7 @@ def test_warp_perspective(): [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32 ).reshape(M_shape) ) - outp = F.warp_perspective(x, M, (2, 2)) + outp = F.vision.warp_perspective(x, M, (2, 2)) np.testing.assert_equal( outp.numpy(), np.array([[[[5.0, 6.0], [9.0, 10.0]]]], dtype=np.float32) ) @@ -370,7 +370,7 @@ def test_warp_affine(): inp_shape = (1, 3, 3, 3) x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape)) weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]] - outp = F.warp_affine(x, tensor(weightv), (2, 2), border_mode="WRAP") + outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="WRAP") res = np.array( [ [ @@ -393,7 +393,7 @@ def test_remap(): [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32 ).reshape(map_xy_shape) ) - outp = F.remap(inp, map_xy) + outp = F.vision.remap(inp, map_xy) np.testing.assert_equal( outp.numpy(), np.array([[[[1.0, 4.0], [4.0, 4.0]]]], dtype=np.float32) ) @@ -476,7 +476,7 @@ def test_nms(): ) inp = tensor(x) scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32) - result = F.nn.nms(inp, scores=scores, iou_thresh=0.5) + result = F.vision.nms(inp, scores=scores, iou_thresh=0.5) np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32)) @@ -737,7 +737,7 @@ def test_cvt_color(): inp = np.random.randn(3, 3, 3, 3).astype(np.float32) out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32) x = tensor(inp) - y = F.img_proc.cvt_color(x, mode="RGB2GRAY") + y = F.vision.cvt_color(x, mode="RGB2GRAY") np.testing.assert_allclose(y.numpy(), out, atol=1e-5) diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 1669899f..3a699441 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -360,7 +360,7 @@ def test_trace_warp_perspective(): @trace(symbolic=True) def f(x, M): - out = F.warp_perspective(x, M, (2, 2)) + out = F.vision.warp_perspective(x, M, (2, 2)) np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2])) return out @@ -429,10 +429,10 @@ def test_trace_nms(): @trace(symbolic=False) def f(boxes, scores): # with tracing, max_output must be specified - results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20) + results = F.vision.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20) # without tracing, max output can be inferred inside nms with exclude_from_trace(): - _ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5) + _ = F.vision.nms(boxes, scores=scores, iou_thresh=0.5) return results f(*make_inputs(10)) diff --git a/imperative/python/test/unit/utils/test_network_node.py b/imperative/python/test/unit/utils/test_network_node.py index 008be7bd..62896bad 100644 --- a/imperative/python/test/unit/utils/test_network_node.py +++ b/imperative/python/test/unit/utils/test_network_node.py @@ -226,7 +226,7 @@ def test_roipooling(): @trace(symbolic=True, capture_as_const=True) def fwd(inp, rois): - return F.nn.roi_pooling(inp, rois, (2, 2), scale=2.0) + return F.vision.roi_pooling(inp, rois, (2, 2), scale=2.0) output = fwd(inp, rois) check_pygraph_dump(fwd, [inp, rois], [output]) @@ -315,7 +315,7 @@ def test_roialign(): @trace(symbolic=True, capture_as_const=True) def fwd(inp, rois): - return F.nn.roi_align(inp, rois, (2, 2)) + return F.vision.roi_align(inp, rois, (2, 2)) output = fwd(inp, rois) check_pygraph_dump(fwd, [inp, rois], [output]) @@ -334,7 +334,7 @@ def test_warpperspective(): @trace(symbolic=True, capture_as_const=True) def fwd(x, M): - return F.warp_perspective(x, M, (2, 2)) + return F.vision.warp_perspective(x, M, (2, 2)) result = fwd(x, M) check_pygraph_dump(fwd, [x, M], [result]) @@ -347,7 +347,7 @@ def test_warpaffine(): @trace(symbolic=True, capture_as_const=True) def fwd(x, weightv): - return F.warp_affine(x, weightv, (2, 2), border_mode="WRAP") + return F.vision.warp_affine(x, weightv, (2, 2), border_mode="WRAP") outp = fwd(x, weightv) check_pygraph_dump(fwd, [x, weightv], [outp]) @@ -365,7 +365,7 @@ def test_remap(): @trace(symbolic=True, capture_as_const=True) def fwd(inp, map_xy): - return F.remap(inp, map_xy) + return F.vision.remap(inp, map_xy) out = fwd(inp, map_xy) check_pygraph_dump(fwd, [inp, map_xy], [out]) @@ -376,7 +376,7 @@ def test_resize(): @trace(symbolic=True, capture_as_const=True) def fwd(x): - return F.nn.interpolate(x, size=(16, 16), mode="BILINEAR") + return F.vision.interpolate(x, size=(16, 16), mode="BILINEAR") out = fwd(x) check_pygraph_dump(fwd, [x], [out]) @@ -706,7 +706,7 @@ def test_cvtcolor(): @trace(symbolic=True, capture_as_const=True) def fwd(inp): - return F.img_proc.cvt_color(inp, mode="RGB2GRAY") + return F.vision.cvt_color(inp, mode="RGB2GRAY") result = fwd(x) check_pygraph_dump(fwd, [x], [result]) diff --git a/imperative/src/impl/ops/img_proc.cpp b/imperative/src/impl/ops/vision.cpp similarity index 95% rename from imperative/src/impl/ops/img_proc.cpp rename to imperative/src/impl/ops/vision.cpp index 2e1aaceb..6cc0922b 100644 --- a/imperative/src/impl/ops/img_proc.cpp +++ b/imperative/src/impl/ops/vision.cpp @@ -1,5 +1,5 @@ /** - * \file imperative/src/impl/ops/img_proc.cpp + * \file imperative/src/impl/ops/vision.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -31,4 +31,4 @@ OP_TRAIT_REG(CvtColor, CvtColor) .fallback(); } } -} \ No newline at end of file +} -- GitLab