From d0aa9b41ee44480937aa3375f22cdd53af4e49d1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 31 Mar 2021 14:57:24 +0800 Subject: [PATCH] refactor(mge/functional): move nvof to vision, compatible with old usage GitOrigin-RevId: cd9d9b4f5a4c6a55daa6d477bce5ca308805417a --- imperative/python/megengine/functional/nn.py | 31 +------------------ .../python/megengine/functional/vision.py | 30 ++++++++++++++++++ 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 16bfaeb7d..ea3b0ae3e 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1409,36 +1409,6 @@ def conv1d( return output -def nvof(src: Tensor, precision: int = 1) -> Tensor: - r""" - Implements NVIDIA Optical Flow SDK. - - :src shape: input tensor with shape (n, t, h, w, c4). - :src dtype: uint8. - :param precision: 0:NV_OF_PERF_LEVEL_SLOW 1:NV_OF_PERF_LEVEL_MEDIUM 2:NV_OF_PERF_LEVEL_FAST. - :output shape: (n, t-1, h//4, w//4, c2). - :output dtype: int16. - - .. code-block:: python - - import numpy as np - from megengine import tensor - import megengine.functional as F - - x = np.random.random_integers(0, 255, (1,2,224,244,4)).astype("uint8") - src = tensor(x) - result = F.nn.nvof(src, precision=1) - print(result.numpy()) - - """ - assert src.ndim == 5 and src.shape[4] == 4 - - src = src.detach() - - op = builtin.NvOf(precision=precision) - return apply(op, src)[0] - - def hswish(x): """ Element-wise `x * relu6(x + 3) / 6`. @@ -1492,6 +1462,7 @@ roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", T nms = deprecated_func("1.3", "megengine.functional.vision", "nms", True) resize = deprecated_func("1.3", "megengine.functional.vision", "resize", True) remap = deprecated_func("1.3", "megengine.functional.vision", "remap", True) +nvof = deprecated_func("1.3", "megengine.functional.vision", "nvof", True) warp_affine = deprecated_func("1.3", "megengine.functional.vision", "warp_affine", True) warp_perspective = deprecated_func( "1.3", "megengine.functional.vision", "warp_perspective", True diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index 4a367248e..b2a975869 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -574,3 +574,33 @@ def interpolate( if mode == "LINEAR": ret = reshape(ret, ret.shape[0:3]) return ret + + +def nvof(src: Tensor, precision: int = 1) -> Tensor: + r""" + Implements NVIDIA Optical Flow SDK. + + :src shape: input tensor with shape (n, t, h, w, c4). + :src dtype: uint8. + :param precision: 0:NV_OF_PERF_LEVEL_SLOW 1:NV_OF_PERF_LEVEL_MEDIUM 2:NV_OF_PERF_LEVEL_FAST. + :output shape: (n, t-1, h//4, w//4, c2). + :output dtype: int16. + + .. code-block:: python + + import numpy as np + from megengine import tensor + import megengine.functional as F + + x = np.random.random_integers(0, 255, (1,2,224,244,4)).astype("uint8") + src = tensor(x) + result = F.nn.nvof(src, precision=1) + print(result.numpy()) + + """ + assert src.ndim == 5 and src.shape[4] == 4 + + src = src.detach() + + op = builtin.NvOf(precision=precision) + return apply(op, src)[0] -- GitLab