提交 f856b170 编写于 作者: M Megvii Engine Team

fix(mge/functional): reshape bias to (1, out_features) in linear

GitOrigin-RevId: a15880b7fc8917e36c5882ab09e8ce3f2c7b2727
上级 486cbdea
......@@ -44,7 +44,7 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
ret = mgb.opr.matrix_mul(inp, weight, transposeB=True)
ret = ret.reshape(orig_shape[:-1], weight.shape[0])
if bias is not None:
ret += bias
ret += bias.reshape(1, bias.shape[0])
return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册