diff --git a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py index 14cba9a4f422cd042cb06e33a4f678291d74a564..ee6ab5b7bb3eaadd040d13218c5cc6e9c8152784 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py +++ b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py @@ -155,6 +155,7 @@ class OpSet9(object): return node def elementwise_add(self, op, block): + print(op.input('Y')) axis = op.attr('axis') x_shape = block.var(op.input('X')[0]).shape y_shape = block.var(op.input('Y')[0]).shape @@ -174,14 +175,14 @@ class OpSet9(object): inputs=[op.input('X')[0], temp_value], outputs=op.output('Out')) return [shape_node, y_node, node] - elif len(x_shape) == len(y_shape): + elif axis == -1 or axis == 0 or axis == (len(x_shape) - 1): node = helper.make_node( 'Add', inputs=[op.input('X')[0], op.input('Y')[0]], outputs=op.output('Out')) return node else: - raise Excpetion("Unexpected situation happend in elementwise_add") + raise Exception("Unexpected situation happend in elementwise_add") def elementwise_sub(self, op, block): axis = op.attr('axis') @@ -203,14 +204,14 @@ class OpSet9(object): inputs=[op.input('X')[0], temp_value], outputs=op.output('Out')) return [shape_node, y_node, node] - elif len(x_shape) == len(y_shape): + elif axis == -1 or axis == 0 or axis == (len(x_shape) - 1): node = helper.make_node( 'Sub', inputs=[op.input('X')[0], op.input('Y')[0]], outputs=op.output('Out')) return node else: - raise Excpetion("Unexpected situation happend in elementwise_sub") + raise Exception("Unexpected situation happend in elementwise_sub") def pool2d(self, op, block): pool_type = { @@ -763,14 +764,14 @@ class OpSet9(object): inputs=[op.input('X')[0], temp_value], outputs=op.output('Out')) return [shape_node, y_node, node] - elif len(x_shape) == len(y_shape): + elif axis == -1 or axis == 0 or axis == (len(x_shape) - 1): node = helper.make_node( 'Mul', inputs=[op.input('X')[0], op.input('Y')[0]], outputs=op.output('Out')) return node else: - raise Excpetion("Unexpected situation happend in elementwise_add") + raise Exception("Unexpected situation happend in elementwise_mul") return node def feed(self, op, block):