提交 06041f8a 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix trace warp_perspective

GitOrigin-RevId: 2071bb63a879f6d027995324e46003b1c789e15f
上级 f4860b93
......@@ -15,6 +15,7 @@ from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed
from ..random import uniform
from ..tensor import Tensor
......@@ -868,7 +869,8 @@ def warp_perspective(
imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val
)
inp, M = utils.convert_inputs(inp, M)
(result,) = apply(op, inp, M, Tensor(dsize))
dsize = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, M, dsize)
return result
......
......@@ -13,6 +13,7 @@ import numpy as np
import pytest
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
from megengine import cgtools, tensor
from megengine.core._trace_option import set_tensor_shape
from megengine.core.ops import builtin as ops
......@@ -261,3 +262,36 @@ def test_trace_reshape():
f(x1)
f(x2)
f(x3)
def test_trace_topk():
x = tensor([5, 2, 7, 1, 0, 3, 2])
@trace(symbolic=True)
def f(x):
y = F.topk(x, 3)
np.testing.assert_equal(y[0].shape.numpy(), np.array([3,]))
return y
for i in range(3):
f(x)
def test_trace_warp_perspective():
inp_shape = (1, 1, 4, 4)
x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
M_shape = (1, 3, 3)
M = tensor(
np.array(
[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
).reshape(M_shape)
)
@trace(symbolic=True)
def f(x, M):
out = F.warp_perspective(x, M, (2, 2))
np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
return out
for i in range(1):
f(x, M)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册