diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
index 534d264a1f0e47e20de976e4bdf23dfc440b7606..71b51eb0f1037c6a51cc431d884414ba44a7fc25 100644
--- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
+++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
@@ -646,7 +646,7 @@ class OpSet9():
                 node.fluid_code.add_layer(
                     'squeeze',
                     inputs={'input': node,
-                            'axes': [0]},
+                            'axes': [axis]},
                     output=node,
                     param_attr=None)
             elif axis == 0 and len(indices_shape) > 1:
@@ -894,8 +894,6 @@ class OpSet9():
                     'this is not supported')
             if len(value) == 1:
                 value = value[0]
-            if dtype.name == 'int64':
-                dtype = 'int32'
             attr = {
                 'shape': val_shape.layer_name,
                 'dtype': string(dtype),
@@ -1040,12 +1038,16 @@ class OpSet9():
     @print_mapping_info
     def Concat(self, node):
         inputs = []
+        dtypes = set()
         for i in range(len(node.layer.input)):
             ipt = self.graph.get_input_node(node, idx=i, copy=True)
             if isinstance(ipt, str):
                 inputs.append(ipt)
             else:
                 inputs.append(ipt.layer_name)
+                dtypes.add(ipt.dtype)
+        if len(dtypes) > 1:
+            raise Exception('Unsupported situation happened, please create issue on https://github.com/PaddlePaddle/X2Paddle/issues.')
         axis = node.get_attr('axis')
         attr = {'axis': axis}
         node.fluid_code.add_layer(
@@ -1133,10 +1135,22 @@ class OpSet9():
     def MatMul(self, node):
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_y = self.graph.get_input_node(node, idx=1, copy=True)
+        x_shape = val_x.out_shapes[0]
+        y_shape = val_y.out_shapes[0]
         inputs = {"x": val_x, "y": val_y}
-        attr = {"name": string(node.layer_name)}
-        node.fluid_code.add_layer(
-            "matmul", inputs=inputs, output=node, param_attr=attr)
+        if y_shape[0] == 1 and x_shape[-1] != 1:
+            y_squeeze = val_y.layer_name + '_squeeze'
+            node.fluid_code.add_layer(
+                "squeeze",
+                inputs=val_y,
+                output=y_squeeze,
+                param_attr={'axes': [0]})
+            inputs['y'] = y_squeeze
+            node.fluid_code.add_layer(
+                "matmul", inputs=inputs, output=node, param_attr=None)
+        else:
+            node.fluid_code.add_layer(
+                "matmul", inputs=inputs, output=node, param_attr=None)
 
     @print_mapping_info
     def BatchNormalization(self, node):