提交 06041f8a 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix trace warp_perspective

GitOrigin-RevId: 2071bb63a879f6d027995324e46003b1c789e15f
上级 f4860b93
...@@ -15,6 +15,7 @@ from ..core.ops._internal import param_defs as P ...@@ -15,6 +15,7 @@ from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import utils from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed from ..distributed import WORLD, is_distributed
from ..random import uniform from ..random import uniform
from ..tensor import Tensor from ..tensor import Tensor
...@@ -868,7 +869,8 @@ def warp_perspective( ...@@ -868,7 +869,8 @@ def warp_perspective(
imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val
) )
inp, M = utils.convert_inputs(inp, M) inp, M = utils.convert_inputs(inp, M)
(result,) = apply(op, inp, M, Tensor(dsize)) dsize = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, M, dsize)
return result return result
......
...@@ -13,6 +13,7 @@ import numpy as np ...@@ -13,6 +13,7 @@ import numpy as np
import pytest import pytest
import megengine.core.tensor.megbrain_graph as G import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
from megengine import cgtools, tensor from megengine import cgtools, tensor
from megengine.core._trace_option import set_tensor_shape from megengine.core._trace_option import set_tensor_shape
from megengine.core.ops import builtin as ops from megengine.core.ops import builtin as ops
...@@ -261,3 +262,36 @@ def test_trace_reshape(): ...@@ -261,3 +262,36 @@ def test_trace_reshape():
f(x1) f(x1)
f(x2) f(x2)
f(x3) f(x3)
def test_trace_topk():
x = tensor([5, 2, 7, 1, 0, 3, 2])
@trace(symbolic=True)
def f(x):
y = F.topk(x, 3)
np.testing.assert_equal(y[0].shape.numpy(), np.array([3,]))
return y
for i in range(3):
f(x)
def test_trace_warp_perspective():
inp_shape = (1, 1, 4, 4)
x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
M_shape = (1, 3, 3)
M = tensor(
np.array(
[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
).reshape(M_shape)
)
@trace(symbolic=True)
def f(x, M):
out = F.warp_perspective(x, M, (2, 2))
np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
return out
for i in range(1):
f(x, M)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册