提交 d0aa9b41 编写于 作者: M Megvii Engine Team

refactor(mge/functional): move nvof to vision, compatible with old usage

GitOrigin-RevId: cd9d9b4f5a4c6a55daa6d477bce5ca308805417a
上级 ff755451
...@@ -1409,36 +1409,6 @@ def conv1d( ...@@ -1409,36 +1409,6 @@ def conv1d(
return output return output
def nvof(src: Tensor, precision: int = 1) -> Tensor:
r"""
Implements NVIDIA Optical Flow SDK.
:src shape: input tensor with shape (n, t, h, w, c4).
:src dtype: uint8.
:param precision: 0:NV_OF_PERF_LEVEL_SLOW 1:NV_OF_PERF_LEVEL_MEDIUM 2:NV_OF_PERF_LEVEL_FAST.
:output shape: (n, t-1, h//4, w//4, c2).
:output dtype: int16.
.. code-block:: python
import numpy as np
from megengine import tensor
import megengine.functional as F
x = np.random.random_integers(0, 255, (1,2,224,244,4)).astype("uint8")
src = tensor(x)
result = F.nn.nvof(src, precision=1)
print(result.numpy())
"""
assert src.ndim == 5 and src.shape[4] == 4
src = src.detach()
op = builtin.NvOf(precision=precision)
return apply(op, src)[0]
def hswish(x): def hswish(x):
""" """
Element-wise `x * relu6(x + 3) / 6`. Element-wise `x * relu6(x + 3) / 6`.
...@@ -1492,6 +1462,7 @@ roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", T ...@@ -1492,6 +1462,7 @@ roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", T
nms = deprecated_func("1.3", "megengine.functional.vision", "nms", True) nms = deprecated_func("1.3", "megengine.functional.vision", "nms", True)
resize = deprecated_func("1.3", "megengine.functional.vision", "resize", True) resize = deprecated_func("1.3", "megengine.functional.vision", "resize", True)
remap = deprecated_func("1.3", "megengine.functional.vision", "remap", True) remap = deprecated_func("1.3", "megengine.functional.vision", "remap", True)
nvof = deprecated_func("1.3", "megengine.functional.vision", "nvof", True)
warp_affine = deprecated_func("1.3", "megengine.functional.vision", "warp_affine", True) warp_affine = deprecated_func("1.3", "megengine.functional.vision", "warp_affine", True)
warp_perspective = deprecated_func( warp_perspective = deprecated_func(
"1.3", "megengine.functional.vision", "warp_perspective", True "1.3", "megengine.functional.vision", "warp_perspective", True
......
...@@ -574,3 +574,33 @@ def interpolate( ...@@ -574,3 +574,33 @@ def interpolate(
if mode == "LINEAR": if mode == "LINEAR":
ret = reshape(ret, ret.shape[0:3]) ret = reshape(ret, ret.shape[0:3])
return ret return ret
def nvof(src: Tensor, precision: int = 1) -> Tensor:
r"""
Implements NVIDIA Optical Flow SDK.
:src shape: input tensor with shape (n, t, h, w, c4).
:src dtype: uint8.
:param precision: 0:NV_OF_PERF_LEVEL_SLOW 1:NV_OF_PERF_LEVEL_MEDIUM 2:NV_OF_PERF_LEVEL_FAST.
:output shape: (n, t-1, h//4, w//4, c2).
:output dtype: int16.
.. code-block:: python
import numpy as np
from megengine import tensor
import megengine.functional as F
x = np.random.random_integers(0, 255, (1,2,224,244,4)).astype("uint8")
src = tensor(x)
result = F.nn.nvof(src, precision=1)
print(result.numpy())
"""
assert src.ndim == 5 and src.shape[4] == 4
src = src.detach()
op = builtin.NvOf(precision=precision)
return apply(op, src)[0]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册