From 06041f8a7e7c8a98db78e06dd21798e5ee769e4e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 23 Sep 2020 11:42:51 +0800 Subject: [PATCH] fix(mge/functional): fix trace warp_perspective GitOrigin-RevId: 2071bb63a879f6d027995324e46003b1c789e15f --- imperative/python/megengine/functional/nn.py | 4 ++- imperative/python/test/unit/test_tracing.py | 34 ++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 2282956b1..a8f9397a0 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -15,6 +15,7 @@ from ..core.ops._internal import param_defs as P from ..core.ops.special import Const from ..core.tensor import utils from ..core.tensor.core import TensorBase, TensorWrapperBase, apply +from ..core.tensor.utils import astensor1d from ..distributed import WORLD, is_distributed from ..random import uniform from ..tensor import Tensor @@ -868,7 +869,8 @@ def warp_perspective( imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val ) inp, M = utils.convert_inputs(inp, M) - (result,) = apply(op, inp, M, Tensor(dsize)) + dsize = astensor1d(dsize, inp, dtype="int32", device=inp.device) + (result,) = apply(op, inp, M, dsize) return result diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 3a6a28d5f..ecc811bd2 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -13,6 +13,7 @@ import numpy as np import pytest import megengine.core.tensor.megbrain_graph as G +import megengine.functional as F from megengine import cgtools, tensor from megengine.core._trace_option import set_tensor_shape from megengine.core.ops import builtin as ops @@ -261,3 +262,36 @@ def test_trace_reshape(): f(x1) f(x2) f(x3) + + +def test_trace_topk(): + x = tensor([5, 2, 7, 1, 0, 3, 2]) + + @trace(symbolic=True) + def f(x): + y = F.topk(x, 3) + np.testing.assert_equal(y[0].shape.numpy(), np.array([3,])) + return y + + for i in range(3): + f(x) + + +def test_trace_warp_perspective(): + inp_shape = (1, 1, 4, 4) + x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) + M_shape = (1, 3, 3) + M = tensor( + np.array( + [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32 + ).reshape(M_shape) + ) + + @trace(symbolic=True) + def f(x, M): + out = F.warp_perspective(x, M, (2, 2)) + np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2])) + return out + + for i in range(1): + f(x, M) -- GitLab