diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index f8c55bb3b096921954c58ba4cc063a9ae9b60338..69d2395bcac054e098787c672dad23cc32e1fd01 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -688,28 +688,26 @@ class OpSet9(): axes = node.get_attr('axes') if axes is None: axes = self.graph.get_input_node(node, idx=1, copy=True) - if node.name in ["x2paddle_vis_local_cost_volume_3d_0_ExpandDims_5_0"]: - print("output_shape:", val_x.out_shapes[0]) - # if len(val_x.out_shapes[0]) == 0: - # if node.name: - # self.paddle_graph.add_layer( - # 'paddle.reshape', - # inputs={"x": val_x.name}, - # outputs=[node.name], - # shape=[1]) - # else: - if isinstance(axes, list) or isinstance(axes, tuple): - self.paddle_graph.add_layer( - 'paddle.unsqueeze', - inputs={"x": val_x.name}, - axis=axes, - outputs=[node.name]) + if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0: + if node.name: + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={"x": val_x.name}, + outputs=[node.name], + shape=[1]) else: - self.paddle_graph.add_layer( - 'paddle.unsqueeze', - inputs={"x": val_x.name, - "axis": axes.name}, - outputs=[node.name]) + if isinstance(axes, list) or isinstance(axes, tuple): + self.paddle_graph.add_layer( + 'paddle.unsqueeze', + inputs={"x": val_x.name}, + axis=axes, + outputs=[node.name]) + else: + self.paddle_graph.add_layer( + 'paddle.unsqueeze', + inputs={"x": val_x.name, + "axis": axes.name}, + outputs=[node.name]) @print_mapping_info def Shrink(self, node):