From 422a6b56375b4ed9585a4ecd1e3f0d477ae184ef Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Thu, 12 May 2022 15:52:23 +0800 Subject: [PATCH] deal with comments --- .../op_mapper/onnx2paddle/opset9/opset.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index e2d0e10..0f9e0a4 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -740,30 +740,27 @@ class OpSet9(): def Unsqueeze(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) axes = node.get_attr('axes') - if axes is None: - axes = self.graph.get_input_node(node, idx=1, copy=True) - axes = _const_weight_or_none(axes) - if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0: - if node.name: + if axes is not None: + if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[ + 0] == 0: self.paddle_graph.add_layer( 'paddle.reshape', inputs={"x": val_x.name}, outputs=[node.name], shape=[1]) - else: - if isinstance(axes, list) or isinstance(axes, tuple) or isinstance( - axes, np.ndarray): + else: self.paddle_graph.add_layer( 'paddle.unsqueeze', inputs={"x": val_x.name}, axis=axes, outputs=[node.name]) - else: - self.paddle_graph.add_layer( - 'paddle.unsqueeze', - inputs={"x": val_x.name, - "axis": axes.name}, - outputs=[node.name]) + else: + axes = self.graph.get_input_node(node, idx=1, copy=True) + self.paddle_graph.add_layer( + 'paddle.unsqueeze', + inputs={"x": val_x.name, + "axis": axes.name}, + outputs=[node.name]) @print_mapping_info def Shrink(self, node): @@ -896,7 +893,7 @@ class OpSet9(): def Gather(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True) - indices_values = _const_weight_or_none(indices) + indices_values = _const_weight_or_none(indices, necessary=True) if isinstance(indices_values, np.ndarray): indices_values = indices_values.tolist() indices_shape = indices.out_shapes[0] -- GitLab