From a91d37bdb0971b46d6b587665f9032b79afee283 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Fri, 13 May 2022 16:27:40 +0800 Subject: [PATCH] fixed Flatten --- .../op_mapper/onnx2paddle/opset9/opset.py | 70 +++++++++++++++---- 1 file changed, 58 insertions(+), 12 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index a4b3a15..104bd67 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1627,20 +1627,66 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=0, copy=True) output_shape = val_x.out_shapes[0] axis = node.get_attr('axis', 1) - shape_list = [1, 1] if axis == 0: - for s in output_shape: - shape_list[1] *= s + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={"x": val_x.name}, + outputs=[node.name], + shape=[1, -1]) else: - for s in output_shape[:axis]: - shape_list[0] *= s - for s in output_shape[axis:]: - shape_list[1] *= s - self.paddle_graph.add_layer( - 'paddle.reshape', - inputs={"x": val_x.name}, - outputs=[node.name], - shape=shape_list) + if len(output_shape) != 0: + shape_list = [1, 1] + for s in output_shape[:axis]: + shape_list[0] *= s + for s in output_shape[axis:]: + shape_list[1] *= s + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={"x": val_x.name}, + outputs=[node.name], + shape=shape_list) + else: + self.paddle_graph.add_layer( + 'paddle.shape', + inputs={"input": val_x.name}, + outputs=[val_x.name + "_shape"]) + self.paddle_graph.add_layer( + "paddle.slice", + inputs={"input": val_x.name + "_shape"}, + outputs=[val_x.name + "_shape_first"], + axes=[0], + starts=[0], + ends=[axis]) + self.paddle_graph.add_layer( + 'paddle.prod', + inputs={"x": val_x.name + "_shape_first"}, + outputs=[val_x.name + "_shape_first"]) + self.paddle_graph.add_layer( + "paddle.slice", + inputs={"input": val_x.name + "_shape"}, + outputs=[val_x.name + "_shape_second"], + axes=[0], + starts=[axis], + ends=[2147483647]) + self.paddle_graph.add_layer( + 'paddle.prod', + inputs={"x": val_x.name + "_shape_second"}, + outputs=[val_x.name + "_shape_second"]) + self.paddle_graph.add_layer( + 'paddle.concat', + inputs={ + "x": [ + val_x.name + "_shape_first", + val_x.name + "_shape_second" + ] + }, + outputs=[val_x.name + "_all_shape"], + axis=0) + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={'x': val_x.name}, + outputs=[node.name], + shape=val_x.name + "_all_shape") @print_mapping_info def Gemm(self, node): -- GitLab