diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 283e566b6bb2a17943506e6b53b048be41865f11..4ca4403030fa9536a6c9beffa249f5005527252b 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -643,7 +643,7 @@ class OpSet9(): node.fluid_code.add_layer( 'squeeze', inputs={'input': node, - 'axes': [0]}, + 'axes': [axis]}, output=node, param_attr=None) elif axis == 0 and len(indices_shape) > 1: @@ -1132,10 +1132,22 @@ class OpSet9(): def MatMul(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True) + x_shape = val_x.out_shapes[0] + y_shape = val_y.out_shapes[0] inputs = {"x": val_x, "y": val_y} - attr = {"name": string(node.layer_name)} - node.fluid_code.add_layer( - "matmul", inputs=inputs, output=node, param_attr=attr) + if y_shape[0] == 1 and x_shape[-1] != 1: + y_squeeze = val_y.layer_name + '_squeeze' + node.fluid_code.add_layer( + "squeeze", + inputs=val_y, + output=y_squeeze, + param_attr={'axes': [0]}) + inputs['y'] = y_squeeze + node.fluid_code.add_layer( + "matmul", inputs=inputs, output=node, param_attr=None) + else: + node.fluid_code.add_layer( + "matmul", inputs=inputs, output=node, param_attr=None) @print_mapping_info def BatchNormalization(self, node):