From 6ae67e350b81eaa5b9632ea8b2baa13fd9dd8551 Mon Sep 17 00:00:00 2001 From: yeliang2258 <1047690002@qq.com> Date: Wed, 29 Dec 2021 07:12:03 +0000 Subject: [PATCH] fix --- .../op_mapper/onnx2paddle/opset9/opset.py | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index f8c55bb..69d2395 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -688,28 +688,26 @@ class OpSet9(): axes = node.get_attr('axes') if axes is None: axes = self.graph.get_input_node(node, idx=1, copy=True) - if node.name in ["x2paddle_vis_local_cost_volume_3d_0_ExpandDims_5_0"]: - print("output_shape:", val_x.out_shapes[0]) - # if len(val_x.out_shapes[0]) == 0: - # if node.name: - # self.paddle_graph.add_layer( - # 'paddle.reshape', - # inputs={"x": val_x.name}, - # outputs=[node.name], - # shape=[1]) - # else: - if isinstance(axes, list) or isinstance(axes, tuple): - self.paddle_graph.add_layer( - 'paddle.unsqueeze', - inputs={"x": val_x.name}, - axis=axes, - outputs=[node.name]) + if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0: + if node.name: + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={"x": val_x.name}, + outputs=[node.name], + shape=[1]) else: - self.paddle_graph.add_layer( - 'paddle.unsqueeze', - inputs={"x": val_x.name, - "axis": axes.name}, - outputs=[node.name]) + if isinstance(axes, list) or isinstance(axes, tuple): + self.paddle_graph.add_layer( + 'paddle.unsqueeze', + inputs={"x": val_x.name}, + axis=axes, + outputs=[node.name]) + else: + self.paddle_graph.add_layer( + 'paddle.unsqueeze', + inputs={"x": val_x.name, + "axis": axes.name}, + outputs=[node.name]) @print_mapping_info def Shrink(self, node): -- GitLab