提交 f4860b93 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix trace topk

GitOrigin-RevId: c88ca8219b0dc25e2f1a4d26c3fea86220c86962
上级 310c805f
......@@ -14,8 +14,9 @@ from typing import Optional, Sequence, Tuple, Union
from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor.core import apply
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor import Tensor
from .elemwise import clamp, exp, log, log1p
from .tensor import add_axis, remove_axis, reshape
......@@ -665,15 +666,18 @@ def topk(
mode = Mode.VALUE_IDX_SORTED
op = builtin.TopK(mode=mode)
if not isinstance(k, (TensorBase, TensorWrapperBase)):
(k,) = Const(k, dtype="int32", device=inp.device)(inp)
if len(inp.shape) == 1:
inp = inp.reshape(1, -1)
res = apply(op, inp, Tensor(k, dtype="int32"))
res = apply(op, inp, k)
if kth_only:
tns = res[0]
else:
tns, ind = res[0][0], res[1][0]
else:
res = apply(op, inp, Tensor(k, dtype="int32"))
res = apply(op, inp, k)
if kth_only:
tns = res
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册