diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index e2d0e10429a8e974f96e893a4b23bbe555a41840..0f9e0a4465fc6ec745740e3f486821abf54f145c 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -740,30 +740,27 @@ class OpSet9(): def Unsqueeze(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) axes = node.get_attr('axes') - if axes is None: - axes = self.graph.get_input_node(node, idx=1, copy=True) - axes = _const_weight_or_none(axes) - if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0: - if node.name: + if axes is not None: + if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[ + 0] == 0: self.paddle_graph.add_layer( 'paddle.reshape', inputs={"x": val_x.name}, outputs=[node.name], shape=[1]) - else: - if isinstance(axes, list) or isinstance(axes, tuple) or isinstance( - axes, np.ndarray): + else: self.paddle_graph.add_layer( 'paddle.unsqueeze', inputs={"x": val_x.name}, axis=axes, outputs=[node.name]) - else: - self.paddle_graph.add_layer( - 'paddle.unsqueeze', - inputs={"x": val_x.name, - "axis": axes.name}, - outputs=[node.name]) + else: + axes = self.graph.get_input_node(node, idx=1, copy=True) + self.paddle_graph.add_layer( + 'paddle.unsqueeze', + inputs={"x": val_x.name, + "axis": axes.name}, + outputs=[node.name]) @print_mapping_info def Shrink(self, node): @@ -896,7 +893,7 @@ class OpSet9(): def Gather(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True) - indices_values = _const_weight_or_none(indices) + indices_values = _const_weight_or_none(indices, necessary=True) if isinstance(indices_values, np.ndarray): indices_values = indices_values.tolist() indices_shape = indices.out_shapes[0]