提交 44742e32 编写于 作者: M Megvii Engine Team

fix(mge/api): check input dim of dot and mark output as scalar

GitOrigin-RevId: a3ba7e099ce44a1269d9363406748bc49e2b3cf6
上级 697f70c0
...@@ -16,7 +16,7 @@ from ..core.ops import builtin ...@@ -16,7 +16,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm from ..core.ops.builtin import BatchNorm
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils from ..core.tensor import megbrain_graph, utils
from ..core.tensor.utils import astensor1d from ..core.tensor.utils import astensor1d, setscalar
from ..distributed import WORLD, is_distributed from ..distributed import WORLD, is_distributed
from ..jit.tracing import is_tracing from ..jit.tracing import is_tracing
from ..random import uniform from ..random import uniform
...@@ -1133,7 +1133,8 @@ def matmul( ...@@ -1133,7 +1133,8 @@ def matmul(
def dot(inp1: Tensor, inp2: Tensor) -> Tensor: def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
""" """
Computes dot-product of two vectors ``inp1`` and ``inp2``. Computes dot-product of two vectors ``inp1`` and ``inp2``.
inputs must be 1-dimensional, scalar input can be automatically broadcasted. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
Refer to :func:`~.matmul` for more general usage.
:param inp1: first vector. :param inp1: first vector.
:param inp2: second vector. :param inp2: second vector.
...@@ -1156,12 +1157,16 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor: ...@@ -1156,12 +1157,16 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
.. testoutput:: .. testoutput::
[55.] 55.
""" """
op = builtin.Dot() op = builtin.Dot()
inp1, inp2 = utils.convert_inputs(inp1, inp2) inp1, inp2 = utils.convert_inputs(inp1, inp2)
assert (
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2) (result,) = apply(op, inp1, inp2)
setscalar(result)
return result return result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册