未验证 提交 9a6c85a5 编写于 作者: W WJJ1995 提交者: GitHub

fixed resize op (#749)

上级 2958c54e
...@@ -296,14 +296,40 @@ class OpSet9(): ...@@ -296,14 +296,40 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
inputs = {'x': val_x.name} inputs = {'x': val_x.name}
attrs = dict() attrs = dict()
val_x_shape = val_x.out_shapes[0]
if node.layer_type == 'Resize': if node.layer_type == 'Resize':
if len(node.layer.input) == 2: if len(node.layer.input) == 2:
# opset 10 # opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
# TODO(syf): paddle.nn.functional.interpolate will support the length # TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input. # which is the same as the rank of input.
attrs['scale_factor'] = self.weights[val_scales.name].tolist()[ scale_values = _const_weight_or_none(val_scales)
2:] if scale_values is not None:
attrs['scale_factor'] = self.weights[
val_scales.name].tolist()[2:]
else:
var_nc, var_hw = val_scales.name + '_nc', val_scales.name + '_hw'
self.paddle_graph.add_layer(
'paddle.split',
inputs={"x": val_scales.name},
outputs=[var_nc, var_hw],
num_or_sections=[2, 2],
axis=0)
inputs['scale_factor'] = var_hw
mode = node.get_attr('mode', 'nearest')
attrs.update({
"align_corners": False,
"mode": string(mode),
"align_mode": 1
})
if mode == "linear" and len(val_x_shape) == 4:
attrs["mode"] = string("bilinear")
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate",
inputs=inputs,
outputs=[node.name],
**attrs)
return
elif len(node.layer.input) == 3: elif len(node.layer.input) == 3:
# opset 11 # opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True) val_scales = self.graph.get_input_node(node, idx=2, copy=True)
...@@ -315,7 +341,6 @@ class OpSet9(): ...@@ -315,7 +341,6 @@ class OpSet9():
# opset 11 # opset 11
val_sizes = self.graph.get_input_node(node, idx=3, copy=True) val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
size_values = _const_weight_or_none(val_sizes) size_values = _const_weight_or_none(val_sizes)
val_x_shape = val_x.out_shapes[0]
if len(val_x_shape) == 3: if len(val_x_shape) == 3:
var_n, var_hw = val_sizes.name + '_n', val_sizes.name + '_hw' var_n, var_hw = val_sizes.name + '_n', val_sizes.name + '_hw'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
...@@ -418,7 +443,6 @@ class OpSet9(): ...@@ -418,7 +443,6 @@ class OpSet9():
}) })
if len(node.layer.input) == 1: if len(node.layer.input) == 1:
attrs["scale_factor"] = val_scales attrs["scale_factor"] = val_scales
val_x_shape = val_x.out_shapes[0]
if mode == "linear" and len(val_x_shape) == 4: if mode == "linear" and len(val_x_shape) == 4:
attrs["mode"] = string("bilinear") attrs["mode"] = string("bilinear")
if node.get_attr('coordinate_transformation_mode', if node.get_attr('coordinate_transformation_mode',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册