diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index a4b3a1565f746e098ccbc77fc6506d345f9fb07c..104bd67025755298cd55d82285fb035aa9643533 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1627,20 +1627,66 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=0, copy=True) output_shape = val_x.out_shapes[0] axis = node.get_attr('axis', 1) - shape_list = [1, 1] if axis == 0: - for s in output_shape: - shape_list[1] *= s + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={"x": val_x.name}, + outputs=[node.name], + shape=[1, -1]) else: - for s in output_shape[:axis]: - shape_list[0] *= s - for s in output_shape[axis:]: - shape_list[1] *= s - self.paddle_graph.add_layer( - 'paddle.reshape', - inputs={"x": val_x.name}, - outputs=[node.name], - shape=shape_list) + if len(output_shape) != 0: + shape_list = [1, 1] + for s in output_shape[:axis]: + shape_list[0] *= s + for s in output_shape[axis:]: + shape_list[1] *= s + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={"x": val_x.name}, + outputs=[node.name], + shape=shape_list) + else: + self.paddle_graph.add_layer( + 'paddle.shape', + inputs={"input": val_x.name}, + outputs=[val_x.name + "_shape"]) + self.paddle_graph.add_layer( + "paddle.slice", + inputs={"input": val_x.name + "_shape"}, + outputs=[val_x.name + "_shape_first"], + axes=[0], + starts=[0], + ends=[axis]) + self.paddle_graph.add_layer( + 'paddle.prod', + inputs={"x": val_x.name + "_shape_first"}, + outputs=[val_x.name + "_shape_first"]) + self.paddle_graph.add_layer( + "paddle.slice", + inputs={"input": val_x.name + "_shape"}, + outputs=[val_x.name + "_shape_second"], + axes=[0], + starts=[axis], + ends=[2147483647]) + self.paddle_graph.add_layer( + 'paddle.prod', + inputs={"x": val_x.name + "_shape_second"}, + outputs=[val_x.name + "_shape_second"]) + self.paddle_graph.add_layer( + 'paddle.concat', + inputs={ + "x": [ + val_x.name + "_shape_first", + val_x.name + "_shape_second" + ] + }, + outputs=[val_x.name + "_all_shape"], + axis=0) + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={'x': val_x.name}, + outputs=[node.name], + shape=val_x.name + "_all_shape") @print_mapping_info def Gemm(self, node):