提交 44742e32 编写于 作者: M Megvii Engine Team

fix(mge/api): check input dim of dot and mark output as scalar

GitOrigin-RevId: a3ba7e099ce44a1269d9363406748bc49e2b3cf6
上级 697f70c0
......@@ -16,7 +16,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.utils import astensor1d
from ..core.tensor.utils import astensor1d, setscalar
from ..distributed import WORLD, is_distributed
from ..jit.tracing import is_tracing
from ..random import uniform
......@@ -1133,7 +1133,8 @@ def matmul(
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
"""
Computes dot-product of two vectors ``inp1`` and ``inp2``.
inputs must be 1-dimensional, scalar input can be automatically broadcasted.
inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
Refer to :func:`~.matmul` for more general usage.
:param inp1: first vector.
:param inp2: second vector.
......@@ -1156,12 +1157,16 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
.. testoutput::
[55.]
55.
"""
op = builtin.Dot()
inp1, inp2 = utils.convert_inputs(inp1, inp2)
assert (
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2)
setscalar(result)
return result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册