diff --git a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py index 14cba9a4f422cd042cb06e33a4f678291d74a564..89c6525953208d247307ebcfa17781bbd438f164 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py +++ b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py @@ -174,14 +174,15 @@ class OpSet9(object): inputs=[op.input('X')[0], temp_value], outputs=op.output('Out')) return [shape_node, y_node, node] - elif len(x_shape) == len(y_shape): + elif axis == -1 or axis == (len(x_shape) - 1 + ) or len(x_shape) == len(y_shape): node = helper.make_node( 'Add', inputs=[op.input('X')[0], op.input('Y')[0]], outputs=op.output('Out')) return node else: - raise Excpetion("Unexpected situation happend in elementwise_add") + raise Exception("Unexpected situation happend in elementwise_add") def elementwise_sub(self, op, block): axis = op.attr('axis') @@ -203,14 +204,15 @@ class OpSet9(object): inputs=[op.input('X')[0], temp_value], outputs=op.output('Out')) return [shape_node, y_node, node] - elif len(x_shape) == len(y_shape): + elif axis == -1 or axis == (len(x_shape) - 1 + ) or len(x_shape) == len(y_shape): node = helper.make_node( 'Sub', inputs=[op.input('X')[0], op.input('Y')[0]], outputs=op.output('Out')) return node else: - raise Excpetion("Unexpected situation happend in elementwise_sub") + raise Exception("Unexpected situation happend in elementwise_sub") def pool2d(self, op, block): pool_type = { @@ -565,7 +567,7 @@ class OpSet9(object): input_shape = block.vars[op.input('X')[0]].shape if op.attr('align_corners') or op.attr('align_mode') == 0: raise Exception( - "Resize in onnx(opset<=10) only support coordinate_transformation_mode: 'asymmetric', Try converting with --onnx_opest 11" + "Resize in onnx(opset<=10) only support coordinate_transformation_mode: 'asymmetric', Try converting with --onnx_opset 11" ) if ('OutSize' in input_names and len(op.input('OutSize')) > 0) or ( 'SizeTensor' in input_names and @@ -763,14 +765,15 @@ class OpSet9(object): inputs=[op.input('X')[0], temp_value], outputs=op.output('Out')) return [shape_node, y_node, node] - elif len(x_shape) == len(y_shape): + elif axis == -1 or axis == (len(x_shape) - 1 + ) or len(x_shape) == len(y_shape): node = helper.make_node( 'Mul', inputs=[op.input('X')[0], op.input('Y')[0]], outputs=op.output('Out')) return node else: - raise Excpetion("Unexpected situation happend in elementwise_add") + raise Exception("Unexpected situation happend in elementwise_mul") return node def feed(self, op, block):