提交 0d8b9136 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/functional): reshape bias to (1, out_features) in linear

GitOrigin-RevId: a15880b7fc8917e36c5882ab09e8ce3f2c7b2727
上级 f60ab501
...@@ -44,7 +44,7 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor ...@@ -44,7 +44,7 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
ret = mgb.opr.matrix_mul(inp, weight, transposeB=True) ret = mgb.opr.matrix_mul(inp, weight, transposeB=True)
ret = ret.reshape(orig_shape[:-1], weight.shape[0]) ret = ret.reshape(orig_shape[:-1], weight.shape[0])
if bias is not None: if bias is not None:
ret += bias ret += bias.reshape(1, bias.shape[0])
return ret return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册