From d0ec9c38f0576e8b85e7fa7fba2e0c75d81e5d78 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Tue, 7 Jun 2022 20:18:25 +0800 Subject: [PATCH] fixed for ci --- .../onnx2paddle/onnx_custom_layer/roi_align.py | 17 ++++++----------- x2paddle/op_mapper/onnx2paddle/opset9/opset.py | 8 +++++--- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/onnx_custom_layer/roi_align.py b/x2paddle/op_mapper/onnx2paddle/onnx_custom_layer/roi_align.py index c29e744..cef9f13 100644 --- a/x2paddle/op_mapper/onnx2paddle/onnx_custom_layer/roi_align.py +++ b/x2paddle/op_mapper/onnx2paddle/onnx_custom_layer/roi_align.py @@ -21,18 +21,13 @@ from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_d @paddle.jit.not_to_static def roi_align(input, rois, - output_size, + pooled_height, + pooled_width, spatial_scale=1.0, sampling_ratio=-1, rois_num=None, aligned=True, name=None): - check_type(output_size, 'output_size', (int, tuple), 'roi_align') - if isinstance(output_size, int): - output_size = (output_size, output_size) - - pooled_height, pooled_width = output_size - if in_dynamic_mode(): assert rois_num is not None, "rois_num should not be None in dygraph mode." align_out = _C_ops.roi_align( @@ -71,15 +66,15 @@ def roi_align(input, class ROIAlign(object): def __init__(self, pooled_height, pooled_width, spatial_scale, - sampling_ratio, rois_num): + sampling_ratio): self.roialign_layer_attrs = { "pooled_height": pooled_height, "pooled_width": pooled_width, "spatial_scale": spatial_scale, 'sampling_ratio': sampling_ratio, - 'rois_num': rois_num, } - def __call__(self, x0, x1): - out = roi_align(input=x0, rois=x1, **self.roialign_layer_attrs) + def __call__(self, x0, x1, x2): + out = roi_align( + input=x0, rois=x1, rois_num=x2, **self.roialign_layer_attrs) return out diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 4e2d73c..f10fff7 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -538,12 +538,14 @@ class OpSet9(): 'pooled_width': pooled_width, 'spatial_scale': spatial_scale, 'sampling_ratio': sampling_ratio, - 'rois_num': val_rois_num, } self.paddle_graph.add_layer( 'custom_layer:ROIAlign', - inputs={'input': val_x.name, - 'rois': val_rois.name}, + inputs={ + 'input': val_x.name, + 'rois': val_rois.name, + 'rois_num': val_rois_num + }, outputs=[node.name], **layer_attrs) -- GitLab