未验证 提交 7c61e521 编写于 作者: W WJJ1995 提交者: GitHub

[Bug] Fixed ONNX Gemm bug (#917)

* fixed Gemm bug

* re-lint
上级 49d90ae4
...@@ -1637,29 +1637,70 @@ class OpSet(): ...@@ -1637,29 +1637,70 @@ class OpSet():
"transpose_x": trans_a, "transpose_x": trans_a,
"transpose_y": trans_b, "transpose_y": trans_b,
} }
self.paddle_graph.add_layer( if abs(alpha - 1.0) < 1e-5:
'paddle.matmul', if abs(beta - 0.0) < 1e-5:
inputs=matmul_inputs,
outputs=[val_mm],
**attr_matmul)
self.paddle_graph.add_layer(
"paddle.scale", inputs={"x": val_mm}, outputs=[val_mm], scale=alpha)
if beta != 0:
if beta == 1.:
add_inputs = {"x": val_mm, "y": val_c.name}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name]) 'paddle.matmul',
inputs=matmul_inputs,
outputs=[node.name],
**attr_matmul)
else: else:
var_beta = node.name + '_beta' self.paddle_graph.add_layer(
'paddle.matmul',
inputs=matmul_inputs,
outputs=[val_mm],
**attr_matmul)
if abs(beta - 1.0) < 1e-5:
add_inputs = {"x": val_mm, "y": val_c.name}
self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name])
else:
var_beta = node.name + '_beta'
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_c.name},
outputs=[var_beta],
scale=beta)
add_inputs = {"x": val_mm, "y": var_beta}
self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name])
else:
if abs(beta - 0.0) < 1e-5:
self.paddle_graph.add_layer(
'paddle.matmul',
inputs=matmul_inputs,
outputs=[val_mm],
**attr_matmul)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.scale", "paddle.scale",
inputs={"x": val_c.name}, inputs={"x": val_mm},
outputs=[var_beta], outputs=[node.name],
scale=beta) scale=alpha)
add_inputs = {"x": val_mm, "y": var_beta} else:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name]) 'paddle.matmul',
inputs=[matmul_inputs],
outputs=[val_mm],
**attr_matmul)
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_mm},
outputs=[val_mm],
scale=alpha)
if abs(beta - 1.0) < 1e-5:
add_inputs = {"x": val_mm, "y": val_c.name}
self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name])
else:
var_beta = node.name + '_beta'
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_c.name},
outputs=[var_beta],
scale=beta)
add_inputs = {"x": val_mm, "y": var_beta}
self.paddle_graph.add_layer(
"paddle.add", inputs=add_inputs, outputs=[node.name])
@print_mapping_info @print_mapping_info
def Sum(self, node): def Sum(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册