From f856b1708592c5f80320659654f8159c9a38c088 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 3 Jul 2020 16:09:44 +0800 Subject: [PATCH] fix(mge/functional): reshape bias to (1, out_features) in linear GitOrigin-RevId: a15880b7fc8917e36c5882ab09e8ce3f2c7b2727 --- python_module/megengine/functional/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python_module/megengine/functional/nn.py b/python_module/megengine/functional/nn.py index 5dc27f0d..e579bea2 100644 --- a/python_module/megengine/functional/nn.py +++ b/python_module/megengine/functional/nn.py @@ -44,7 +44,7 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor ret = mgb.opr.matrix_mul(inp, weight, transposeB=True) ret = ret.reshape(orig_shape[:-1], weight.shape[0]) if bias is not None: - ret += bias + ret += bias.reshape(1, bias.shape[0]) return ret -- GitLab