# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import ReLU
from paddle.nn.initializer import Normal, XavierUniform
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling import ops

from ..backbone.name_adapter import NameAdapter
from ..backbone.resnet import Blocks


@register
class TwoFCHead(nn.Layer):
    """Per-stage stack of two linear layers, each followed by a ReLU.

    For every stage index ``i`` in ``range(roi_stages)`` the constructor
    registers four sublayers named ``fc6_<i>``, ``fc6_<i>act``, ``fc7_<i>``
    and ``fc7_<i>act``.

    Args:
        in_dim: together with ``resolution`` sets the input width of the
            first linear layer (``in_dim * resolution * resolution``).
        mlp_dim: output width of both linear layers.
        resolution: see ``in_dim``.
        roi_stages: number of stages to build layers for.
    """

    __shared__ = ['roi_stages']

    def __init__(self, in_dim=256, mlp_dim=1024, resolution=7, roi_stages=1):
        super(TwoFCHead, self).__init__()
        self.in_dim = in_dim
        self.mlp_dim = mlp_dim
        self.roi_stages = roi_stages
        # Plain Python lists used to pick a layer by stage index in forward();
        # the layers themselves are registered through add_sublayer below.
        self.fc6_list, self.fc6_relu_list = [], []
        self.fc7_list, self.fc7_relu_list = [], []

        flat_dim = in_dim * resolution * resolution
        for idx in range(roi_stages):
            # Weight learning rate doubles with each stage; the bias uses
            # twice the weight learning rate and an L2Decay(0.) regularizer.
            weight_lr = 2**idx
            bias_lr = 2. * weight_lr
            fc6_key = 'fc6_{}'.format(idx)
            fc7_key = 'fc7_{}'.format(idx)

            first_fc = nn.Linear(
                flat_dim,
                mlp_dim,
                weight_attr=ParamAttr(
                    learning_rate=weight_lr,
                    initializer=XavierUniform(fan_out=flat_dim)),
                bias_attr=ParamAttr(
                    learning_rate=bias_lr, regularizer=L2Decay(0.)))
            self.fc6_list.append(self.add_sublayer(fc6_key, first_fc))
            self.fc6_relu_list.append(
                self.add_sublayer(fc6_key + 'act', ReLU()))

            second_fc = nn.Linear(
                mlp_dim,
                mlp_dim,
                weight_attr=ParamAttr(
                    learning_rate=weight_lr, initializer=XavierUniform()),
                bias_attr=ParamAttr(
                    learning_rate=bias_lr, regularizer=L2Decay(0.)))
            self.fc7_list.append(self.add_sublayer(fc7_key, second_fc))
            self.fc7_relu_list.append(
                self.add_sublayer(fc7_key + 'act', ReLU()))

    def forward(self, rois_feat, stage=0):
        """Flatten ``rois_feat`` from axis 1 onward and apply the two
        linear+ReLU pairs that belong to ``stage``.

        Returns:
            The output of the second ReLU for the selected stage.
        """
        flat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
        hidden = self.fc6_relu_list[stage](self.fc6_list[stage](flat))
        return self.fc7_relu_list[stage](self.fc7_list[stage](hidden))


@register
class Res5Head(nn.Layer):
    """Head that runs ROI features through a single ``Blocks`` sublayer.

    The sublayer is registered under the name ``res5_roi_feat`` and is built
    with ``count=3`` and ``stage_num=5``.

    Args:
        feat_in: first positional argument forwarded to ``Blocks``.
        feat_out: second positional argument forwarded to ``Blocks``;
            ``self.feat_out`` stores four times this value.
    """

    def __init__(self, feat_in=1024, feat_out=512):
        super(Res5Head, self).__init__()
        adapter = NameAdapter(self)
        # Not used anywhere in this class; kept as an instance attribute.
        self.res5_conv = []
        res5_blocks = Blocks(
            feat_in, feat_out, count=3, name_adapter=adapter, stage_num=5)
        self.res5 = self.add_sublayer('res5_roi_feat', res5_blocks)
        self.feat_out = feat_out * 4

    def forward(self, roi_feat, stage=0):
        """Return ``self.res5(roi_feat)``.

        ``stage`` is accepted but not used by this head.
        """
        return self.res5(roi_feat)


@register
class BBoxFeat(nn.Layer):
    """Chains an injected ROI extractor with an injected head-feature module.

    Args:
        roi_extractor: callable invoked as
            ``roi_extractor(body_feats, rois, spatial_scale)``.
        head_feat: callable invoked as ``head_feat(rois_feat, stage)``.
    """

    __inject__ = ['roi_extractor', 'head_feat']

    def __init__(self, roi_extractor, head_feat):
        super(BBoxFeat, self).__init__()
        # Not used anywhere in this class; kept as an instance attribute.
        self.rois_feat_list = []
        self.roi_extractor = roi_extractor
        self.head_feat = head_feat

    def forward(self, body_feats, rois, spatial_scale, stage=0):
        """Extract ROI features, then compute the head feature for ``stage``.

        Returns:
            Tuple ``(rois_feat, bbox_feat)``: the extractor output and the
            result of passing it (with ``stage``) to ``head_feat``.
        """
        extracted = self.roi_extractor(body_feats, rois, spatial_scale)
        return extracted, self.head_feat(extracted, stage)


@register
class BBoxHead(nn.Layer):
    """Box head producing class scores and box deltas for each ROI stage.

    For every stage index ``i`` in ``range(roi_stages)`` the constructor
    registers two linear sublayers: ``bbox_score_<i>`` (output width
    ``num_classes``) and ``bbox_delta_<i>`` (output width
    ``4 * delta_dim``).

    Args:
        bbox_feat: injected module. It is called as
            ``bbox_feat(body_feats, rois, spatial_scale, stage)`` and must
            return ``(rois_feat, bbox_feat)``; its ``head_feat`` attribute is
            also called directly as ``head_feat(rois_feat, stage)``.
        in_feat: input width of the score and delta linear layers.
        num_classes: output width of the score layer.
        cls_agnostic: when True ``delta_dim`` is 2, otherwise it equals
            ``num_classes``.
        roi_stages: number of stages to build layers for.
        with_pool: when True the head feature is adaptively average-pooled
            to output size 1 and squeezed on axes 2 and 3 before the linear
            layers.
        score_stage: stage indices whose softmax scores are averaged in
            ``get_cascade_prediction``.
        delta_stage: stage indices whose deltas and proposals are averaged
            in ``get_cascade_prediction``.
    """

    __shared__ = ['num_classes', 'roi_stages']
    __inject__ = ['bbox_feat']

    def __init__(self,
                 bbox_feat,
                 in_feat=1024,
                 num_classes=81,
                 cls_agnostic=False,
                 roi_stages=1,
                 with_pool=False,
                 score_stage=[0, 1, 2],
                 delta_stage=[2]):
        super(BBoxHead, self).__init__()
        self.num_classes = num_classes
        self.cls_agnostic = cls_agnostic
        self.delta_dim = 2 if cls_agnostic else num_classes
        self.bbox_feat = bbox_feat
        self.roi_stages = roi_stages
        self.bbox_score_list = []
        self.bbox_delta_list = []
        # One slot per stage; forward() overwrites a slot with the ROI
        # features it extracted for that stage.
        self.roi_feat_list = [[] for _ in range(roi_stages)]
        self.with_pool = with_pool
        # The list defaults in the signature are created once and shared by
        # every call. Store copies so that mutating one instance's
        # score_stage/delta_stage can never leak into the defaults or into
        # other instances. The signature itself is left unchanged.
        self.score_stage = list(score_stage)
        self.delta_stage = list(delta_stage)
        for stage in range(roi_stages):
            score_name = 'bbox_score_{}'.format(stage)
            delta_name = 'bbox_delta_{}'.format(stage)
            # Weight learning rate doubles with each stage; biases use twice
            # the weight learning rate and an L2Decay(0.) regularizer.
            lr_factor = 2**stage
            bbox_score = self.add_sublayer(
                score_name,
                nn.Linear(
                    in_feat,
                    1 * self.num_classes,
                    weight_attr=ParamAttr(
                        learning_rate=lr_factor,
                        initializer=Normal(
                            mean=0.0, std=0.01)),
                    bias_attr=ParamAttr(
                        learning_rate=2. * lr_factor, regularizer=L2Decay(0.))))

            bbox_delta = self.add_sublayer(
                delta_name,
                nn.Linear(
                    in_feat,
                    4 * self.delta_dim,
                    weight_attr=ParamAttr(
                        learning_rate=lr_factor,
                        initializer=Normal(
                            mean=0.0, std=0.001)),
                    bias_attr=ParamAttr(
                        learning_rate=2. * lr_factor, regularizer=L2Decay(0.))))
            self.bbox_score_list.append(bbox_score)
            self.bbox_delta_list.append(bbox_delta)

    def forward(self,
                body_feats=None,
                rois=None,
                spatial_scale=None,
                stage=0,
                roi_stage=-1):
        """Compute scores and deltas with the layers belonging to ``stage``.

        When ``rois`` is given, ROI features are extracted through
        ``self.bbox_feat`` and cached in ``self.roi_feat_list[stage]``.
        When ``rois`` is None, the features cached at
        ``self.roi_feat_list[roi_stage]`` are reused and only
        ``self.bbox_feat.head_feat`` is run on them.

        Returns:
            Tuple ``(bbox_feat, (scores, deltas), self.bbox_feat.head_feat)``.
        """
        if rois is not None:
            rois_feat, bbox_feat = self.bbox_feat(body_feats, rois,
                                                  spatial_scale, stage)
            self.roi_feat_list[stage] = rois_feat
        else:
            rois_feat = self.roi_feat_list[roi_stage]
            bbox_feat = self.bbox_feat.head_feat(rois_feat, stage)

        # Only the input to the linear layers is pooled; the un-pooled
        # bbox_feat is what gets returned to the caller.
        head_in = bbox_feat
        if self.with_pool:
            head_in = F.adaptive_avg_pool2d(bbox_feat, output_size=1)
            head_in = paddle.squeeze(head_in, axis=[2, 3])
        scores = self.bbox_score_list[stage](head_in)
        deltas = self.bbox_delta_list[stage](head_in)
        bbox_head_out = (scores, deltas)
        return bbox_feat, bbox_head_out, self.bbox_feat.head_feat

    def _get_head_loss(self, score, delta, target):
        """Return ``(loss_bbox_cls, loss_bbox_reg)`` for one stage.

        ``target`` is indexed with the keys ``'labels_int32'``,
        ``'bbox_targets'``, ``'bbox_inside_weights'`` and
        ``'bbox_outside_weights'``. Both losses are reduced with
        ``paddle.mean``.
        """
        # bbox cls
        labels_int64 = paddle.cast(x=target['labels_int32'], dtype='int64')
        labels_int64.stop_gradient = True
        loss_bbox_cls = F.softmax_with_cross_entropy(
            logits=score, label=labels_int64)
        loss_bbox_cls = paddle.mean(loss_bbox_cls)
        # bbox reg
        loss_bbox_reg = ops.smooth_l1(
            input=delta,
            label=target['bbox_targets'],
            inside_weight=target['bbox_inside_weights'],
            outside_weight=target['bbox_outside_weights'],
            sigma=1.0)
        loss_bbox_reg = paddle.mean(loss_bbox_reg)
        return loss_bbox_cls, loss_bbox_reg

    def get_loss(self, bbox_head_out, targets):
        """Build the loss dict over paired head outputs and targets.

        Each pair at position ``lvl`` is weighted by ``1 / 2**lvl``. With a
        single target the keys are ``'loss_bbox_cls'`` and
        ``'loss_bbox_reg'``; with more than one target each key gets a
        ``_<lvl>`` suffix.

        Returns:
            Dict mapping loss names to weighted loss values.
        """
        loss_bbox = {}
        cls_name = 'loss_bbox_cls'
        reg_name = 'loss_bbox_reg'
        for lvl, (bboxhead, target) in enumerate(zip(bbox_head_out, targets)):
            score, delta = bboxhead
            if len(targets) > 1:
                cls_name = 'loss_bbox_cls_{}'.format(lvl)
                reg_name = 'loss_bbox_reg_{}'.format(lvl)
            loss_bbox_cls, loss_bbox_reg = self._get_head_loss(score, delta,
                                                               target)
            loss_weight = 1. / 2**lvl
            loss_bbox[cls_name] = loss_bbox_cls * loss_weight
            loss_bbox[reg_name] = loss_bbox_reg * loss_weight
        return loss_bbox

    def get_prediction(self, bbox_head_out, rois):
        """Turn one head output into ``((delta, bbox_prob), rois)``.

        ``bbox_prob`` is the softmax of the score and ``delta`` is reshaped
        to ``(-1, self.delta_dim, 4)``. ``rois`` is returned unchanged.
        """
        # The unpacked names are unused; the unpack is kept so that a rois
        # value that is not a 2-item sequence still fails here as before.
        proposal, proposal_num = rois
        score, delta = bbox_head_out
        bbox_prob = F.softmax(score)
        delta = paddle.reshape(delta, (-1, self.delta_dim, 4))
        bbox_pred = (delta, bbox_prob)
        return bbox_pred, rois

    def get_cascade_prediction(self, bbox_head_out, rois):
        """Average per-stage outputs into a single prediction.

        Softmax scores are averaged over the stages listed in
        ``self.score_stage``; deltas and proposals are averaged over the
        stages listed in ``self.delta_stage``. When ``self.cls_agnostic`` is
        set, the delta slice at index 1 of axis 1 is expanded to
        ``self.num_classes`` entries.

        Returns:
            Tuple ``((delta, bbox_prob), (proposal, proposal_num))`` where
            ``proposal_num`` is taken from the last entry of ``rois``.
        """
        proposal_list = []
        prob_list = []
        delta_list = []
        for stage, proposals in enumerate(rois):
            score, delta = bbox_head_out[stage]
            proposal, proposal_num = proposals
            if stage in self.score_stage:
                if stage < 2:
                    # Re-score with this stage's layers on the ROI features
                    # cached in the last slot of roi_feat_list.
                    # NOTE(review): the literal 2 looks tied to a three-stage
                    # cascade; confirm whether it should follow roi_stages.
                    _, head_out, _ = self(stage=stage, roi_stage=-1)
                    score = head_out[0]

                bbox_prob = F.softmax(score)
                prob_list.append(bbox_prob)
            if stage in self.delta_stage:
                proposal_list.append(proposal)
                delta_list.append(delta)
        bbox_prob = paddle.mean(paddle.stack(prob_list), axis=0)
        delta = paddle.mean(paddle.stack(delta_list), axis=0)
        proposal = paddle.mean(paddle.stack(proposal_list), axis=0)
        delta = paddle.reshape(delta, (-1, self.delta_dim, 4))
        if self.cls_agnostic:
            N, _, M = delta.shape
            delta = delta[:, 1:2, :]
            delta = paddle.expand(delta, [N, self.num_classes, M])
        bboxes = (proposal, proposal_num)
        bbox_pred = (delta, bbox_prob)
        return bbox_pred, bboxes