未验证 提交 7c61e521 编写于 作者: W WJJ1995 提交者: GitHub

[Bug] Fixed ONNX Gemm bug (#917)

* fixed Gemm bug

* re-lint
上级 49d90ae4
......@@ -1637,16 +1637,57 @@ class OpSet():
"transpose_x": trans_a,
"transpose_y": trans_b,
}
if abs(alpha - 1.0) < 1e-5:
if abs(beta - 0.0) < 1e-5:
self.paddle_graph.add_layer(
'paddle.matmul',
inputs=matmul_inputs,
outputs=[node.name],
**attr_matmul)
else:
self.paddle_graph.add_layer(
'paddle.matmul',
inputs=matmul_inputs,
outputs=[val_mm],
**attr_matmul)
if abs(beta - 1.0) < 1e-5:
add_inputs = {"x": val_mm, "y": val_c.name}
self.paddle_graph.add_layer(
"paddle.scale", inputs={"x": val_mm}, outputs=[val_mm], scale=alpha)
if beta != 0:
if beta == 1.:
"paddle.add", inputs=add_inputs, outputs=[node.name])
else:
var_beta = node.name + '_beta'
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_c.name},
outputs=[var_beta],
scale=beta)
add_inputs = {"x": val_mm, "y": var_beta}
self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name])
else:
if abs(beta - 0.0) < 1e-5:
self.paddle_graph.add_layer(
'paddle.matmul',
inputs=matmul_inputs,
outputs=[val_mm],
**attr_matmul)
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_mm},
outputs=[node.name],
scale=alpha)
else:
self.paddle_graph.add_layer(
'paddle.matmul',
inputs=[matmul_inputs],
outputs=[val_mm],
**attr_matmul)
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_mm},
outputs=[val_mm],
scale=alpha)
if abs(beta - 1.0) < 1e-5:
add_inputs = {"x": val_mm, "y": val_c.name}
self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册