提交 965aa978 编写于 作者: C Channingss

support Reisze opset 11

上级 562ffba4
...@@ -332,10 +332,37 @@ class OpSet9(): ...@@ -332,10 +332,37 @@ class OpSet9():
def _interpolate(self, node): def _interpolate(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
inputs = {'input': val_x}
if node.layer_type == 'Resize': if node.layer_type == 'Resize':
val_scales = self.graph.get_input_node(node, idx=2, copy=True) if len(node.layer.input) == 2:
# opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale'] = val_scales
elif len(node.layer.input) == 3:
# opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True)
inputs['scale'] = val_scales
elif len(node.layer.input) == 4:
# opset 11
val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
var_nc, var_hw = val_sizes.layer_name + '_nc', val_sizes.layer_name + '_hw'
node.fluid_code.add_layer(
'split',
inputs=val_sizes,
output=var_nc + ',' + var_hw,
param_attr={
'dim': 0,
'num_or_sections': [2, 2],
})
node.fluid_code.add_layer(
"cast",
inputs=var_hw,
output=var_hw,
param_attr={'dtype': string('int32')})
inputs['out_shape'] = var_hw
elif node.layer_type == 'Upsample': elif node.layer_type == 'Upsample':
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale'] = val_scales
attr = {'name': string(node.layer_name)} attr = {'name': string(node.layer_name)}
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
...@@ -345,13 +372,8 @@ class OpSet9(): ...@@ -345,13 +372,8 @@ class OpSet9():
'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear' 'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear'
) )
fluid_op = 'resize_bilinear' fluid_op = 'resize_bilinear'
node.fluid_code.add_layer( node.fluid_code.add_layer(
fluid_op, fluid_op, inputs=inputs, output=node, param_attr=attr)
inputs={'input': val_x,
'scale': val_scales},
output=node,
param_attr=attr)
@print_mapping_info @print_mapping_info
def RoiAlign(self, node): def RoiAlign(self, node):
...@@ -497,7 +519,6 @@ class OpSet9(): ...@@ -497,7 +519,6 @@ class OpSet9():
'attribute "shape" of %s not inferred, ' 'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails', 'using value as 1-D tensor may lead to fails',
val_output.layer_name, val_output.layer_name) val_output.layer_name, val_output.layer_name)
if len(value) == 1: if len(value) == 1:
value = value.tolist() value = value.tolist()
shape = [1] shape = [1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册