anchor.py 5.2 KB
Newer Older
M
MegEngine Team 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
# -*- coding: utf-8 -*-
# Copyright 2018-2019 Open-MMLab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
from abc import ABCMeta, abstractmethod

import megengine.functional as F
import numpy as np

from megengine.core import tensor, Tensor


class BaseAnchorGenerator(metaclass=ABCMeta):
    """base class for anchor generator.
    """

    def __init__(self):
        pass

    @abstractmethod
    def get_anchors_by_feature(self) -> Tensor:
        pass


class DefaultAnchorGenerator(BaseAnchorGenerator):
    """default retinanet anchor generator.
    This class generate anchors by feature map in level.

    Args:
        base_size (int): anchor base size.
        anchor_scales (np.ndarray): anchor scales based on stride.
            The practical anchor scale is anchor_scale * stride
        anchor_ratios(np.ndarray): anchor aspect ratios.
        offset (float): center point offset.default is 0.

    """

    def __init__(
        self,
        base_size=8,
        anchor_scales: np.ndarray = np.array([2, 3, 4]),
        anchor_ratios: np.ndarray = np.array([0.5, 1, 2]),
        offset: float = 0,
    ):
        super().__init__()
        self.base_size = base_size
        self.anchor_scales = anchor_scales
        self.anchor_ratios = anchor_ratios
        self.offset = offset

    def _whctrs(self, anchor):
        """convert anchor box into (w, h, ctr_x, ctr_y)
        """
        w = anchor[:, 2] - anchor[:, 0] + 1
        h = anchor[:, 3] - anchor[:, 1] + 1
        x_ctr = anchor[:, 0] + 0.5 * (w - 1)
        y_ctr = anchor[:, 1] + 0.5 * (h - 1)

        return w, h, x_ctr, y_ctr

    def get_plane_anchors(self, anchor_scales: np.ndarray):
        """get anchors per location on feature map.
        The anchor number is anchor_scales x anchor_ratios
        """
        base_anchor = tensor([0, 0, self.base_size - 1, self.base_size - 1])
        base_anchor = F.add_axis(base_anchor, 0)
        w, h, x_ctr, y_ctr = self._whctrs(base_anchor)
        # ratio enumerate
        size = w * h
        size_ratios = size / self.anchor_ratios

        ws = size_ratios.sqrt().round()
        hs = (ws * self.anchor_ratios).round()

        # scale enumerate
        anchor_scales = anchor_scales[None, ...]
        ws = F.add_axis(ws, 1)
        hs = F.add_axis(hs, 1)
        ws = (ws * anchor_scales).reshape(-1, 1)
        hs = (hs * anchor_scales).reshape(-1, 1)

        anchors = F.concat(
            [
                x_ctr - 0.5 * (ws - 1),
                y_ctr - 0.5 * (hs - 1),
                x_ctr + 0.5 * (ws - 1),
                y_ctr + 0.5 * (hs - 1),
            ],
            axis=1,
        )

        return anchors.astype(np.float32)

    def get_center_offsets(self, featmap, stride):
        f_shp = featmap.shape
        fm_height, fm_width = f_shp[-2], f_shp[-1]

        shift_x = F.linspace(0, fm_width - 1, fm_width) * stride
        shift_y = F.linspace(0, fm_height - 1, fm_height) * stride

        # make the mesh grid of shift_x and shift_y
        mesh_shape = (fm_height, fm_width)
        broad_shift_x = shift_x.reshape(-1, shift_x.shape[0]).broadcast(*mesh_shape)
        broad_shift_y = shift_y.reshape(shift_y.shape[0], -1).broadcast(*mesh_shape)

        flatten_shift_x = F.add_axis(broad_shift_x.reshape(-1), 1)
        flatten_shift_y = F.add_axis(broad_shift_y.reshape(-1), 1)

        centers = F.concat(
            [flatten_shift_x, flatten_shift_y, flatten_shift_x, flatten_shift_y,],
            axis=1,
        )
        if self.offset > 0:
            centers = centers + self.offset * stride
        return centers

    def get_anchors_by_feature(self, featmap, stride):
        # shifts shape: [A, 4]
        shifts = self.get_center_offsets(featmap, stride)
        # plane_anchors shape: [B, 4], e.g. B=9
        plane_anchors = self.get_plane_anchors(self.anchor_scales * stride)

        all_anchors = F.add_axis(plane_anchors, 0) + F.add_axis(shifts, 1)
        all_anchors = all_anchors.reshape(-1, 4)

        return all_anchors

    def __call__(self, featmap, stride):
        return self.get_anchors_by_feature(featmap, stride)