未验证 提交 7c61e521 编写于 作者: W WJJ1995 提交者: GitHub

[Bug] Fixed ONNX Gemm bug (#917)

* fixed Gemm bug

* re-lint
上级 49d90ae4
...@@ -1637,16 +1637,57 @@ class OpSet(): ...@@ -1637,16 +1637,57 @@ class OpSet():
"transpose_x": trans_a, "transpose_x": trans_a,
"transpose_y": trans_b, "transpose_y": trans_b,
} }
if abs(alpha - 1.0) < 1e-5:
if abs(beta - 0.0) < 1e-5:
self.paddle_graph.add_layer(
'paddle.matmul',
inputs=matmul_inputs,
outputs=[node.name],
**attr_matmul)
else:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.matmul', 'paddle.matmul',
inputs=matmul_inputs, inputs=matmul_inputs,
outputs=[val_mm], outputs=[val_mm],
**attr_matmul) **attr_matmul)
if abs(beta - 1.0) < 1e-5:
add_inputs = {"x": val_mm, "y": val_c.name}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.scale", inputs={"x": val_mm}, outputs=[val_mm], scale=alpha) "paddle.add", inputs=add_inputs, outputs=[node.name])
else:
if beta != 0: var_beta = node.name + '_beta'
if beta == 1.: self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_c.name},
outputs=[var_beta],
scale=beta)
add_inputs = {"x": val_mm, "y": var_beta}
self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name])
else:
if abs(beta - 0.0) < 1e-5:
self.paddle_graph.add_layer(
'paddle.matmul',
inputs=matmul_inputs,
outputs=[val_mm],
**attr_matmul)
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_mm},
outputs=[node.name],
scale=alpha)
else:
self.paddle_graph.add_layer(
'paddle.matmul',
inputs=[matmul_inputs],
outputs=[val_mm],
**attr_matmul)
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_mm},
outputs=[val_mm],
scale=alpha)
if abs(beta - 1.0) < 1e-5:
add_inputs = {"x": val_mm, "y": val_c.name} add_inputs = {"x": val_mm, "y": val_c.name}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name]) "paddle.add", inputs=add_inputs, outputs=[node.name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册