roi_extractor.py 4.0 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from ppdet.core.workspace import register
from ppdet.modeling import ops


20 21 22 23 24 25
def _to_list(v):
    if not isinstance(v, (list, tuple)):
        return [v]
    return v


Q
qingqing01 已提交
26 27
@register
class RoIAlign(object):
    """
    RoI Align module

    For more details, please refer to the document of roi_align in
    ppdet/modeling/ops.py

    Args:
        resolution (int): The output size, default 14
        spatial_scale (float|list[float]): Multiplicative spatial scale factor
            to translate ROI coords from their input scale to the scale used
            when pooling. A single float is wrapped into a one-element list;
            a list supplies one scale per feature level. default 0.0625
        sampling_ratio (int): The number of sampling points in the
            interpolation grid, forwarded to roi_align on both the
            single-feature and the multi-level (FPN) path. default 0
        canconical_level (int): The referring level of FPN layer with
            specified level. default 4 (the parameter keeps this historical
            spelling so existing configs keep working)
        canonical_size (int): The referring scale of FPN layer with
            specified scale. default 224
        start_level (int): The start level of FPN layer to extract RoI feature,
            default 0
        end_level (int): The end level of FPN layer to extract RoI feature,
            default 3
        aligned (bool): Whether to add offset to rois' coord in roi_align.
            default false
    """

    def __init__(self,
                 resolution=14,
                 spatial_scale=0.0625,
                 sampling_ratio=0,
                 canconical_level=4,
                 canonical_size=224,
                 start_level=0,
                 end_level=3,
                 aligned=False):
        super(RoIAlign, self).__init__()
        self.resolution = resolution
        # Stored as a list so it can be indexed per feature level.
        self.spatial_scale = _to_list(spatial_scale)
        self.sampling_ratio = sampling_ratio
        self.canconical_level = canconical_level
        self.canonical_size = canonical_size
        self.start_level = start_level
        self.end_level = end_level
        self.aligned = aligned

    @classmethod
    def from_config(cls, cfg, input_shape):
        # One scale per input feature: the reciprocal of that feature's stride.
        return {'spatial_scale': [1. / i.stride for i in input_shape]}

    def __call__(self, feats, roi, rois_num):
        """Extract fixed-size RoI features from one or more feature maps.

        Args:
            feats (list): Feature maps. With exactly one entry, every RoI is
                pooled from it using ``spatial_scale[0]``. Otherwise
                ``feats[lvl]`` / ``spatial_scale[lvl]`` are used for each
                ``lvl`` in ``[start_level, end_level]``.
            roi (list): RoI box tensors; concatenated into a single tensor
                when the list has more than one element (presumably one
                tensor per image — confirm against callers).
            rois_num: Forwarded as ``rois_num`` to the underlying ops.

        Returns:
            The pooled RoI features produced by ``ops.roi_align``.
        """
        roi = paddle.concat(roi) if len(roi) > 1 else roi[0]
        if len(feats) == 1:
            return self._roi_align_single(feats[0], roi, rois_num)
        return self._roi_align_fpn(feats, roi, rois_num)

    def _roi_align_single(self, feat, roi, rois_num):
        """Pool every RoI from a single feature map."""
        return ops.roi_align(
            feat,
            roi,
            self.resolution,
            self.spatial_scale[0],
            sampling_ratio=self.sampling_ratio,
            rois_num=rois_num,
            aligned=self.aligned)

    def _roi_align_fpn(self, feats, roi, rois_num):
        """Assign RoIs to FPN levels, pool per level, then restore RoI order."""
        # Level index ``lvl`` maps to FPN level ``lvl + offset`` as understood
        # by distribute_fpn_proposals (presumably feats[0] is P2 — confirm).
        offset = 2
        k_min = self.start_level + offset
        k_max = self.end_level + offset
        rois_dist, restore_index, rois_num_dist = ops.distribute_fpn_proposals(
            roi,
            k_min,
            k_max,
            self.canconical_level,
            self.canonical_size,
            rois_num=rois_num)
        rois_feat_list = []
        for lvl in range(self.start_level, self.end_level + 1):
            # distribute_fpn_proposals returns one entry per level in
            # [k_min, k_max], indexed from 0, whereas feats and spatial_scale
            # are indexed by absolute level. Offset by start_level so a
            # non-zero start_level picks the matching RoIs.
            dist_idx = lvl - self.start_level
            roi_feat = ops.roi_align(
                feats[lvl],
                rois_dist[dist_idx],
                self.resolution,
                self.spatial_scale[lvl],
                sampling_ratio=self.sampling_ratio,
                rois_num=rois_num_dist[dist_idx],
                aligned=self.aligned)
            rois_feat_list.append(roi_feat)
        rois_feat_shuffle = paddle.concat(rois_feat_list)
        # Features come out grouped by level; restore_index reorders them
        # back to the order of the concatenated input RoIs.
        return paddle.gather(rois_feat_shuffle, restore_index)