diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index a5c6e1aca2018449ced2dc7f29ee046e18b90c55..b970d2167646f5983faf74d3d009cab7e0fe57d3 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1651,47 +1651,18 @@ class OpSet9(): outputs=[node.name], shape=shape_list) else: + # flatten + reshape self.paddle_graph.add_layer( - 'paddle.shape', + "paddle.flatten", inputs={"input": val_x.name}, - outputs=[val_x.name + "_shape"]) - self.paddle_graph.add_layer( - "paddle.slice", - inputs={"input": val_x.name + "_shape"}, - outputs=[val_x.name + "_shape_first"], - axes=[0], - starts=[0], - ends=[axis]) - self.paddle_graph.add_layer( - 'paddle.prod', - inputs={"x": val_x.name + "_shape_first"}, - outputs=[val_x.name + "_shape_first"]) - self.paddle_graph.add_layer( - "paddle.slice", - inputs={"input": val_x.name + "_shape"}, - outputs=[val_x.name + "_shape_second"], - axes=[0], - starts=[axis], - ends=[2147483647]) - self.paddle_graph.add_layer( - 'paddle.prod', - inputs={"x": val_x.name + "_shape_second"}, - outputs=[val_x.name + "_shape_second"]) - self.paddle_graph.add_layer( - 'paddle.concat', - inputs={ - "x": [ - val_x.name + "_shape_first", - val_x.name + "_shape_second" - ] - }, - outputs=[val_x.name + "_all_shape"], - axis=0) + outputs=[val_x.name + "_flatten"], + start_axis=[0], + stop_axis=[axis]) self.paddle_graph.add_layer( 'paddle.reshape', - inputs={'x': val_x.name}, + inputs={'x': val_x.name + "_flatten"}, outputs=[node.name], - shape=val_x.name + "_all_shape") + shape=[0, -1]) @print_mapping_info def Gemm(self, node):