From dbcb7398c5e4f9900fa6d7ad6cb9f6252affa6db Mon Sep 17 00:00:00 2001 From: qqj1130247885 <51647379+qqj1130247885@users.noreply.github.com> Date: Mon, 18 Jul 2022 14:24:55 +0800 Subject: [PATCH] fix_resize_op (#830) * fix_resiz * check * fix-ci * fix split and resize * fix resize and split * re-lint Co-authored-by: SunAhong1993 --- .../op_mapper/onnx2paddle/opset9/opset.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 9c205a4..4db1033 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -335,7 +335,13 @@ class OpSet9(): return elif len(node.layer.input) == 3: # opset 11 - val_scales = self.graph.get_input_node(node, idx=2, copy=True) + try: + #to avoid the error causeed by NULL value of resize inputs. + val_scales = self.graph.get_input_node( + node, idx=2, copy=True) + except: + val_scales = self.graph.get_input_node( + node, idx=1, copy=True) # TODO(syf): paddle.nn.functional.interpolate will support the length # which is the same as the rank of input. attrs['scale_factor'] = self.weights[val_scales.name].tolist()[ @@ -1391,10 +1397,19 @@ class OpSet9(): axis = node.get_attr('axis', 0) if split is None: split_num = len(node.layer.output) - layer_attrs = { - 'num_or_sections': split_num, - 'axis': axis, - } + try: + #split is an input of this node + split_node = self.graph.get_input_node(node, idx=1, copy=True) + split_value = _const_weight_or_none(split_node) + layer_attrs = { + 'num_or_sections': split_value.tolist(), + 'axis': axis, + } + except: + layer_attrs = { + 'num_or_sections': split_num, + 'axis': axis, + } outputs_list = list() for i in range(len(node.layer.output)): if hasattr(node, 'index'): -- GitLab