未验证 提交 2a420846 编写于 作者: W WJJ1995 提交者: GitHub

Interpolate for dynamic (#720)

* fixed split bug

* fixed concat bug

* fixed node_name bugs

* fixed node_name bugs

* fixed CI bugs

* Interpolate for dynamic

* fixed for CI

* fixed for codestyle
上级 87f906ce
...@@ -302,6 +302,7 @@ class OpSet9(): ...@@ -302,6 +302,7 @@ class OpSet9():
elif len(node.layer.input) == 4: elif len(node.layer.input) == 4:
# opset 11 # opset 11
val_sizes = self.graph.get_input_node(node, idx=3, copy=True) val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
size_values = _const_weight_or_none(val_sizes)
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
if len(val_x_shape) == 3: if len(val_x_shape) == 3:
var_n, var_hw = val_sizes.name + '_n', val_sizes.name + '_hw' var_n, var_hw = val_sizes.name + '_n', val_sizes.name + '_hw'
...@@ -347,23 +348,26 @@ class OpSet9(): ...@@ -347,23 +348,26 @@ class OpSet9():
outputs=[node.name], outputs=[node.name],
axis=0) axis=0)
else: else:
var_nc, var_hw = val_sizes.name + '_nc', val_sizes.name + '_hw' if size_values is not None:
self.paddle_graph.add_layer( attrs["size"] = [size_values[2], size_values[3]]
'paddle.split', else:
inputs={"x": val_sizes.name}, var_nc, var_hw = val_sizes.name + '_nc', val_sizes.name + '_hw'
outputs=[var_nc, var_hw], self.paddle_graph.add_layer(
num_or_sections=[2, 2], 'paddle.split',
axis=0) inputs={"x": val_sizes.name},
self.paddle_graph.add_layer( outputs=[var_nc, var_hw],
"paddle.cast", num_or_sections=[2, 2],
inputs={"x": var_hw}, axis=0)
outputs=[var_hw], self.paddle_graph.add_layer(
dtype=string('int32')) "paddle.cast",
inputs['size'] = var_hw inputs={"x": var_hw},
attrs = { outputs=[var_hw],
dtype=string('int32'))
inputs['size'] = var_hw
attrs.update({
"align_corners": False, "align_corners": False,
"mode": string(node.get_attr('mode', 'nearest')) "mode": string(node.get_attr('mode', 'nearest'))
} })
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
if mode == "linear": if mode == "linear":
attrs["mode"] = string("bilinear") attrs["mode"] = string("bilinear")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册