From 7c61e52187d7e859b88aacf23707797a82cf49f9 Mon Sep 17 00:00:00 2001 From: WJJ1995 Date: Thu, 8 Dec 2022 11:16:01 +0800 Subject: [PATCH] [Bug] Fixed ONNX Gemm bug (#917) * fixed Gemm bug * re-lint --- .../op_mapper/onnx2paddle/opset_legacy.py | 77 ++++++++++++++----- 1 file changed, 59 insertions(+), 18 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset_legacy.py b/x2paddle/op_mapper/onnx2paddle/opset_legacy.py index 808dd39..5f39b35 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset_legacy.py +++ b/x2paddle/op_mapper/onnx2paddle/opset_legacy.py @@ -1637,29 +1637,70 @@ class OpSet(): "transpose_x": trans_a, "transpose_y": trans_b, } - self.paddle_graph.add_layer( - 'paddle.matmul', - inputs=matmul_inputs, - outputs=[val_mm], - **attr_matmul) - self.paddle_graph.add_layer( - "paddle.scale", inputs={"x": val_mm}, outputs=[val_mm], scale=alpha) - - if beta != 0: - if beta == 1.: - add_inputs = {"x": val_mm, "y": val_c.name} + if abs(alpha - 1.0) < 1e-5: + if abs(beta - 0.0) < 1e-5: self.paddle_graph.add_layer( - "paddle.add", inputs=add_inputs, outputs=[node.name]) + 'paddle.matmul', + inputs=matmul_inputs, + outputs=[node.name], + **attr_matmul) else: - var_beta = node.name + '_beta' + self.paddle_graph.add_layer( + 'paddle.matmul', + inputs=matmul_inputs, + outputs=[val_mm], + **attr_matmul) + if abs(beta - 1.0) < 1e-5: + add_inputs = {"x": val_mm, "y": val_c.name} + self.paddle_graph.add_layer( + "paddle.add", inputs=add_inputs, outputs=[node.name]) + else: + var_beta = node.name + '_beta' + self.paddle_graph.add_layer( + "paddle.scale", + inputs={"x": val_c.name}, + outputs=[var_beta], + scale=beta) + add_inputs = {"x": val_mm, "y": var_beta} + self.paddle_graph.add_layer( + "paddle.add", inputs=add_inputs, outputs=[node.name]) + else: + if abs(beta - 0.0) < 1e-5: + self.paddle_graph.add_layer( + 'paddle.matmul', + inputs=matmul_inputs, + outputs=[val_mm], + **attr_matmul) self.paddle_graph.add_layer( "paddle.scale", - inputs={"x": val_c.name}, - outputs=[var_beta], - scale=beta) - add_inputs = {"x": val_mm, "y": var_beta} + inputs={"x": val_mm}, + outputs=[node.name], + scale=alpha) + else: self.paddle_graph.add_layer( - "paddle.add", inputs=add_inputs, outputs=[node.name]) + 'paddle.matmul', + inputs=[matmul_inputs], + outputs=[val_mm], + **attr_matmul) + self.paddle_graph.add_layer( + "paddle.scale", + inputs={"x": val_mm}, + outputs=[val_mm], + scale=alpha) + if abs(beta - 1.0) < 1e-5: + add_inputs = {"x": val_mm, "y": val_c.name} + self.paddle_graph.add_layer( + "paddle.add", inputs=add_inputs, outputs=[node.name]) + else: + var_beta = node.name + '_beta' + self.paddle_graph.add_layer( + "paddle.scale", + inputs={"x": val_c.name}, + outputs=[var_beta], + scale=beta) + add_inputs = {"x": val_mm, "y": var_beta} + self.paddle_graph.add_layer( + "paddle.add", inputs=add_inputs, outputs=[node.name]) @print_mapping_info def Sum(self, node): -- GitLab