diff --git a/python_module/megengine/functional/nn.py b/python_module/megengine/functional/nn.py index 5dc27f0dbb0aa801987cb3e38b3ae5a73e151aa4..e579bea2fe770271dd680cba1a52f0da04a3d071 100644 --- a/python_module/megengine/functional/nn.py +++ b/python_module/megengine/functional/nn.py @@ -44,7 +44,7 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor ret = mgb.opr.matrix_mul(inp, weight, transposeB=True) ret = ret.reshape(orig_shape[:-1], weight.shape[0]) if bias is not None: - ret += bias + ret += bias.reshape(1, bias.shape[0]) return ret