diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index bfa3d9b867f69076574485ad67a8bb1b2f905c9f..7caefe90e76c55f9bb7c9f8ffd360b6a161aab26 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -16,7 +16,7 @@ from ..core.ops import builtin from ..core.ops.builtin import BatchNorm from ..core.ops.special import Const from ..core.tensor import megbrain_graph, utils -from ..core.tensor.utils import astensor1d +from ..core.tensor.utils import astensor1d, setscalar from ..distributed import WORLD, is_distributed from ..jit.tracing import is_tracing from ..random import uniform @@ -1133,7 +1133,8 @@ def matmul( def dot(inp1: Tensor, inp2: Tensor) -> Tensor: """ Computes dot-product of two vectors ``inp1`` and ``inp2``. - inputs must be 1-dimensional, scalar input can be automatically broadcasted. + inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted. + Refer to :func:`~.matmul` for more general usage. :param inp1: first vector. :param inp2: second vector. @@ -1156,12 +1157,16 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor: .. testoutput:: - [55.] + 55. """ op = builtin.Dot() inp1, inp2 = utils.convert_inputs(inp1, inp2) + assert ( + inp1.ndim <= 1 and inp2.ndim <= 1 + ), "Input tensors for dot must be 1-dimensional or scalar" (result,) = apply(op, inp1, inp2) + setscalar(result) return result