From 8340c4a62a674b28e76dfc2587e778e66c32c33c Mon Sep 17 00:00:00 2001 From: Channingss Date: Fri, 7 Aug 2020 02:14:12 +0000 Subject: [PATCH] fix bug of elementwise_ops --- x2paddle/op_mapper/paddle2onnx/opset9/opset.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py index 14cba9a..ee6ab5b 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py +++ b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py @@ -155,6 +155,7 @@ class OpSet9(object): return node def elementwise_add(self, op, block): + print(op.input('Y')) axis = op.attr('axis') x_shape = block.var(op.input('X')[0]).shape y_shape = block.var(op.input('Y')[0]).shape @@ -174,14 +175,14 @@ class OpSet9(object): inputs=[op.input('X')[0], temp_value], outputs=op.output('Out')) return [shape_node, y_node, node] - elif len(x_shape) == len(y_shape): + elif axis == -1 or axis == 0 or axis == (len(x_shape) - 1): node = helper.make_node( 'Add', inputs=[op.input('X')[0], op.input('Y')[0]], outputs=op.output('Out')) return node else: - raise Excpetion("Unexpected situation happend in elementwise_add") + raise Exception("Unexpected situation happend in elementwise_add") def elementwise_sub(self, op, block): axis = op.attr('axis') @@ -203,14 +204,14 @@ class OpSet9(object): inputs=[op.input('X')[0], temp_value], outputs=op.output('Out')) return [shape_node, y_node, node] - elif len(x_shape) == len(y_shape): + elif axis == -1 or axis == 0 or axis == (len(x_shape) - 1): node = helper.make_node( 'Sub', inputs=[op.input('X')[0], op.input('Y')[0]], outputs=op.output('Out')) return node else: - raise Excpetion("Unexpected situation happend in elementwise_sub") + raise Exception("Unexpected situation happend in elementwise_sub") def pool2d(self, op, block): pool_type = { @@ -763,14 +764,14 @@ class OpSet9(object): inputs=[op.input('X')[0], temp_value], outputs=op.output('Out')) return [shape_node, y_node, node] - elif len(x_shape) == len(y_shape): + elif axis == -1 or axis == 0 or axis == (len(x_shape) - 1): node = helper.make_node( 'Mul', inputs=[op.input('X')[0], op.input('Y')[0]], outputs=op.output('Out')) return node else: - raise Excpetion("Unexpected situation happend in elementwise_add") + raise Exception("Unexpected situation happend in elementwise_mul") return node def feed(self, op, block): -- GitLab