未验证 提交 dbcb7398 编写于 作者: Q qqj1130247885 提交者: GitHub

fix_resize_op (#830)

* fix_resiz

* check

* fix-ci

* fix split and resize

* fix resize and split

* re-lint
Co-authored-by: NSunAhong1993 <sunyanfang01@baidu.com>
上级 f2e933cf
...@@ -335,7 +335,13 @@ class OpSet9(): ...@@ -335,7 +335,13 @@ class OpSet9():
return return
elif len(node.layer.input) == 3: elif len(node.layer.input) == 3:
# opset 11 # opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True) try:
#to avoid the error causeed by NULL value of resize inputs.
val_scales = self.graph.get_input_node(
node, idx=2, copy=True)
except:
val_scales = self.graph.get_input_node(
node, idx=1, copy=True)
# TODO(syf): paddle.nn.functional.interpolate will support the length # TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input. # which is the same as the rank of input.
attrs['scale_factor'] = self.weights[val_scales.name].tolist()[ attrs['scale_factor'] = self.weights[val_scales.name].tolist()[
...@@ -1391,10 +1397,19 @@ class OpSet9(): ...@@ -1391,10 +1397,19 @@ class OpSet9():
axis = node.get_attr('axis', 0) axis = node.get_attr('axis', 0)
if split is None: if split is None:
split_num = len(node.layer.output) split_num = len(node.layer.output)
layer_attrs = { try:
'num_or_sections': split_num, #split is an input of this node
'axis': axis, split_node = self.graph.get_input_node(node, idx=1, copy=True)
} split_value = _const_weight_or_none(split_node)
layer_attrs = {
'num_or_sections': split_value.tolist(),
'axis': axis,
}
except:
layer_attrs = {
'num_or_sections': split_num,
'axis': axis,
}
outputs_list = list() outputs_list = list()
for i in range(len(node.layer.output)): for i in range(len(node.layer.output)):
if hasattr(node, 'index'): if hasattr(node, 'index'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册