diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 3dec5b5e809dd9509132ef2250e0427d603f745d..db5530e0fe2cdd5cd7d41e66a06cb00ee9b738f4 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -296,14 +296,40 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=0, copy=True) inputs = {'x': val_x.name} attrs = dict() + val_x_shape = val_x.out_shapes[0] if node.layer_type == 'Resize': if len(node.layer.input) == 2: # opset 10 val_scales = self.graph.get_input_node(node, idx=1, copy=True) # TODO(syf): paddle.nn.functional.interpolate will support the length # which is the same as the rank of input. - attrs['scale_factor'] = self.weights[val_scales.name].tolist()[ - 2:] + scale_values = _const_weight_or_none(val_scales) + if scale_values is not None: + attrs['scale_factor'] = self.weights[ + val_scales.name].tolist()[2:] + else: + var_nc, var_hw = val_scales.name + '_nc', val_scales.name + '_hw' + self.paddle_graph.add_layer( + 'paddle.split', + inputs={"x": val_scales.name}, + outputs=[var_nc, var_hw], + num_or_sections=[2, 2], + axis=0) + inputs['scale_factor'] = var_hw + mode = node.get_attr('mode', 'nearest') + attrs.update({ + "align_corners": False, + "mode": string(mode), + "align_mode": 1 + }) + if mode == "linear" and len(val_x_shape) == 4: + attrs["mode"] = string("bilinear") + self.paddle_graph.add_layer( + kernel="paddle.nn.functional.interpolate", + inputs=inputs, + outputs=[node.name], + **attrs) + return elif len(node.layer.input) == 3: # opset 11 val_scales = self.graph.get_input_node(node, idx=2, copy=True) @@ -315,7 +341,6 @@ class OpSet9(): # opset 11 val_sizes = self.graph.get_input_node(node, idx=3, copy=True) size_values = _const_weight_or_none(val_sizes) - val_x_shape = val_x.out_shapes[0] if len(val_x_shape) == 3: var_n, var_hw = val_sizes.name + '_n', val_sizes.name + '_hw' self.paddle_graph.add_layer( @@ -418,7 +443,6 @@ class OpSet9(): }) if len(node.layer.input) == 1: attrs["scale_factor"] = val_scales - val_x_shape = val_x.out_shapes[0] if mode == "linear" and len(val_x_shape) == 4: attrs["mode"] = string("bilinear") if node.get_attr('coordinate_transformation_mode',