From feb0f15741271e383eb8457c36316986003ac29a Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Thu, 12 May 2022 20:11:12 +0800 Subject: [PATCH] deal with scalar tensor --- .../op_mapper/onnx2paddle/opset9/opset.py | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 0f9e0a4..3e7911c 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -740,26 +740,21 @@ class OpSet9(): def Unsqueeze(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) axes = node.get_attr('axes') - if axes is not None: - if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[ - 0] == 0: - self.paddle_graph.add_layer( - 'paddle.reshape', - inputs={"x": val_x.name}, - outputs=[node.name], - shape=[1]) - else: - self.paddle_graph.add_layer( - 'paddle.unsqueeze', - inputs={"x": val_x.name}, - axis=axes, - outputs=[node.name]) + if axes is None: + axes_node = self.graph.get_input_node(node, idx=1, copy=True) + axes = _const_weight_or_none(axes_node, necessary=True) + # deal with scalar(0D) tensor + if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0: + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={"x": val_x.name}, + outputs=[node.name], + shape=[1]) else: - axes = self.graph.get_input_node(node, idx=1, copy=True) self.paddle_graph.add_layer( 'paddle.unsqueeze', - inputs={"x": val_x.name, - "axis": axes.name}, + inputs={"x": val_x.name}, + axis=axes, outputs=[node.name]) @print_mapping_info -- GitLab