From 08770385acefa7755f88ec1159542b21c820a990 Mon Sep 17 00:00:00 2001 From: Channingss Date: Tue, 11 Aug 2020 12:01:04 +0000 Subject: [PATCH] matmul support [m,n,k] * [1,k,z] --- .../op_mapper/onnx2paddle/opset9/opset.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 283e566..4ca4403 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -643,7 +643,7 @@ class OpSet9(): node.fluid_code.add_layer( 'squeeze', inputs={'input': node, - 'axes': [0]}, + 'axes': [axis]}, output=node, param_attr=None) elif axis == 0 and len(indices_shape) > 1: @@ -1132,10 +1132,22 @@ class OpSet9(): def MatMul(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True) + x_shape = val_x.out_shapes[0] + y_shape = val_y.out_shapes[0] inputs = {"x": val_x, "y": val_y} - attr = {"name": string(node.layer_name)} - node.fluid_code.add_layer( - "matmul", inputs=inputs, output=node, param_attr=attr) + if y_shape[0] == 1 and x_shape[-1] != 1: + y_squeeze = val_y.layer_name + '_squeeze' + node.fluid_code.add_layer( + "squeeze", + inputs=val_y, + output=y_squeeze, + param_attr={'axes': [0]}) + inputs['y'] = y_squeeze + node.fluid_code.add_layer( + "matmul", inputs=inputs, output=node, param_attr=None) + else: + node.fluid_code.add_layer( + "matmul", inputs=inputs, output=node, param_attr=None) @print_mapping_info def BatchNormalization(self, node): -- GitLab