# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import List

import megengine.module as M

from megengine.core import Tensor

from official.vision.detection.layers import basic


class RetinaNetHead(M.Module):
    """
    Classification and box-regression head of RetinaNet.

    Each input feature map is passed through two independent towers of
    3x3 convolutions (every convolution followed by a ReLU). The output of
    one tower feeds the ``cls_score`` convolution, the output of the other
    feeds the ``bbox_pred`` convolution. The same weights are applied to
    every feature map handed to :meth:`forward`.
    """

    def __init__(self, cfg, input_shape: List[basic.ShapeSpec]):
        """
        Build the two towers and the two prediction layers.

        Args:
            cfg: configuration object; the attributes ``num_classes``,
                ``cls_prior_prob``, ``anchor_ratios`` and ``anchor_scales``
                are read from it.
            input_shape: shape specs of the incoming feature maps; only the
                ``channels`` attribute of the first entry is used.
        """
        super().__init__()

        channels = input_shape[0].channels
        class_count = cfg.num_classes
        tower_depth = 4
        prior = cfg.cls_prior_prob

        # The per-location anchor count is replicated five times (presumably
        # one entry per feature level) and then checked for uniformity.
        anchors_per_level = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * 5
        assert (
            len(set(anchors_per_level)) == 1
        ), "not support different number of anchors between levels"
        num_anchors = anchors_per_level[0]

        # Every convolution in this head is 3x3, stride 1, padding 1.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)

        cls_layers, bbox_layers = [], []
        for _ in range(tower_depth):
            cls_layers += [M.Conv2d(channels, channels, **conv_kwargs), M.ReLU()]
            bbox_layers += [M.Conv2d(channels, channels, **conv_kwargs), M.ReLU()]

        self.cls_subnet = M.Sequential(*cls_layers)
        self.bbox_subnet = M.Sequential(*bbox_layers)
        self.cls_score = M.Conv2d(channels, num_anchors * class_count, **conv_kwargs)
        self.bbox_pred = M.Conv2d(channels, num_anchors * 4, **conv_kwargs)

        # Initialization: N(0, 0.01) weights and zero bias for every
        # convolution found in the towers and the two prediction layers.
        for part in (
            self.cls_subnet,
            self.bbox_subnet,
            self.cls_score,
            self.bbox_pred,
        ):
            for sub in part.modules():
                if not isinstance(sub, M.Conv2d):
                    continue
                M.init.normal_(sub.weight, mean=0, std=0.01)
                M.init.fill_(sub.bias, 0)

        # Use prior in model initialization to improve stability: the
        # classification bias is overwritten with -log((1 - p) / p).
        odds_against = (1 - prior) / prior
        M.init.fill_(self.cls_score.bias, -math.log(odds_against))

    def forward(self, features: List[Tensor]):
        """
        Apply the head to every feature map.

        Args:
            features: feature maps to process, one after another.

        Returns:
            A pair ``(logits, bbox_reg)`` of lists holding, for each input
            feature map in order, the ``cls_score`` output and the
            ``bbox_pred`` output respectively.
        """
        # For each feature map the classification branch is evaluated first,
        # then the regression branch.
        per_level = [
            (
                self.cls_score(self.cls_subnet(fmap)),
                self.bbox_pred(self.bbox_subnet(fmap)),
            )
            for fmap in features
        ]
        logits = [cls_out for cls_out, _ in per_level]
        bbox_reg = [box_out for _, box_out in per_level]
        return logits, bbox_reg