From f4860b93457017b2cceb428a291c974a459569c8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 23 Sep 2020 11:20:30 +0800 Subject: [PATCH] fix(mge/functional): fix trace topk GitOrigin-RevId: c88ca8219b0dc25e2f1a4d26c3fea86220c86962 --- imperative/python/megengine/functional/math.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 408fd8482..a8a2e28f8 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -14,8 +14,9 @@ from typing import Optional, Sequence, Tuple, Union from ..core.ops import builtin from ..core.ops._internal import param_defs as P +from ..core.ops.special import Const from ..core.tensor import utils -from ..core.tensor.core import apply +from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..tensor import Tensor from .elemwise import clamp, exp, log, log1p from .tensor import add_axis, remove_axis, reshape @@ -665,15 +666,18 @@ def topk( mode = Mode.VALUE_IDX_SORTED op = builtin.TopK(mode=mode) + if not isinstance(k, (TensorBase, TensorWrapperBase)): + (k,) = Const(k, dtype="int32", device=inp.device)(inp) + if len(inp.shape) == 1: inp = inp.reshape(1, -1) - res = apply(op, inp, Tensor(k, dtype="int32")) + res = apply(op, inp, k) if kth_only: tns = res[0] else: tns, ind = res[0][0], res[1][0] else: - res = apply(op, inp, Tensor(k, dtype="int32")) + res = apply(op, inp, k) if kth_only: tns = res else: -- GitLab