# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from scipy.optimize import linear_sum_assignment
from collections import abc, defaultdict
import numpy as np
import paddle

from ppdet.core.workspace import register, create, serializable
from .meta_arch import BaseArch
from .. import layers as L
from ..keypoint_utils import transpred

# Names exported by `from <this module> import *`.
__all__ = ['HigherHRNet']
# NOTE(review): `register` is imported by this file but never applied; this
# class presumably needs the `@register` decorator to be discoverable by the
# config system -- confirm before relying on config-driven construction.
class HigherHRNet(BaseArch):
    """Keypoint architecture: backbone -> HigherHRNet head -> post-process.

    HigherHRNet network, see https://arxiv.org/abs/1908.10357;
    HigherHRNet+swahr, see https://arxiv.org/abs/2012.15175
    """

    __category__ = 'architecture'

    # NOTE(review): the parameter names below are the ones the constructor
    # body reads; the default values should be confirmed against the project
    # configuration before being relied upon.
    def __init__(self,
                 backbone='Hrnet',
                 hrhrnet_head='HigherHrnetHead',
                 post_process='HrHrnetPostProcess',
                 eval_flip=True,
                 flip_perm=None,
                 max_num_people=30):
        """
        Args:
            backbone (nn.Layer): backbone instance
            hrhrnet_head (nn.Layer): keypoint_head instance
            post_process (object): callable invoked as
                `post_process(heatmap, tagmap, heat_k, inds_k, h, w)`
            eval_flip (bool): when True, evaluation also runs on the
                horizontally flipped image and merges both results
            flip_perm (list[int]): joint permutation applied to the outputs
                of the flipped image; converted with `paddle.to_tensor`
            max_num_people (int): number of top-k peaks kept per joint
        """
        super(HigherHRNet, self).__init__()
        self.backbone = backbone
        self.hrhrnet_head = hrhrnet_head
        self.post_process = post_process
        self.flip = eval_flip
        self.flip_perm = paddle.to_tensor(flip_perm)
        self.deploy = False
        self.interpolate = L.Upsample(2, mode='bilinear')
        self.pool = L.MaxPool(5, 1, 2)
        self.max_num_people = max_num_people

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        """Build the sub-modules named in `cfg` and return them as a dict
        keyed by the matching `__init__` argument names."""
        # backbone
        backbone = create(cfg['backbone'])
        # head: it is told the backbone's output shape
        head_kwargs = {'input_shape': backbone.out_shape}
        hrhrnet_head = create(cfg['hrhrnet_head'], **head_kwargs)
        post_process = create(cfg['post_process'])

        return {
            'backbone': backbone,
            "hrhrnet_head": hrhrnet_head,
            "post_process": post_process,
        }

    def _forward(self):
        """Run the network on `self.inputs`.

        Returns:
            * training: whatever `hrhrnet_head(body_feats, inputs)` returns;
            * deploy: `[heatmap, tagmap, heat_k, inds_k]` from `get_topk`;
            * otherwise: `[[kpts, scores]]` produced by `post_process`.
        """
        # `self.training` is assumed to be the train/eval flag inherited
        # through BaseArch -- TODO confirm.
        use_flip = self.flip and not self.training and not self.deploy
        if use_flip:
            # Stack the horizontally flipped image behind the original one
            # along the batch axis so both are processed in a single pass.
            image = self.inputs['image']
            self.inputs['image'] = paddle.concat(
                (image, paddle.flip(image, [3])))
        body_feats = self.backbone(self.inputs)

        if self.training:
            return self.hrhrnet_head(body_feats, self.inputs)

        outputs = self.hrhrnet_head(body_feats)

        if use_flip:
            # Separate original/flipped halves of every head output.
            outputs = [paddle.split(o, 2) for o in outputs]
            # Undo the flip: permute the joints, then mirror back along axis 3.
            output_rflip = [
                paddle.flip(paddle.gather(o[1], self.flip_perm, 1), [3])
                for o in outputs
            ]
            output1 = [o[0] for o in outputs]
            # Heatmaps are averaged; both tagmaps are kept side by side.
            heatmap = (output1[0] + output_rflip[0]) / 2.
            tagmaps = [output1[1], output_rflip[1]]
            outputs = [heatmap] + tagmaps
        outputs = self.get_topk(outputs)

        if self.deploy:
            return outputs

        h = self.inputs['im_shape'][0, 0].numpy().item()
        w = self.inputs['im_shape'][0, 1].numpy().item()
        kpts, scores = self.post_process(*outputs, h, w)
        return [[kpts, scores]]

    def get_loss(self):
        """Return the result of the forward pass (the head output when
        training)."""
        return self._forward()

    def get_pred(self):
        """Return the forward result under the `'keypoint'` key."""
        return {'keypoint': self._forward()}

    def get_topk(self, outputs):
        """Upsample the head outputs and pick the top-k heatmap peaks.

        Args:
            outputs (list): `[heatmap, tagmap]` or, with flip testing,
                `[heatmap, tagmap, tagmap_flipped]`.

        Returns:
            list: `[heatmap, tagmap, heat_k, inds_k]`, where `tagmap` has a
            new trailing axis (size 2 when two tagmaps were given) and
            `heat_k` / `inds_k` are the `max_num_people` largest local maxima
            per joint and their flattened positions.
        """
        # resize to image size
        outputs = [self.interpolate(x) for x in outputs]
        if len(outputs) == 3:
            tagmap = paddle.concat(
                (outputs[1].unsqueeze(4), outputs[2].unsqueeze(4)), axis=4)
        else:
            tagmap = outputs[1].unsqueeze(4)

        heatmap = outputs[0]
        N, J = 1, self.hrhrnet_head.num_joints
        heatmap_maxpool = self.pool(heatmap)
        # topk: keep only values equal to their pooled neighbourhood maximum
        maxmap = heatmap * (heatmap == heatmap_maxpool)
        maxmap = maxmap.reshape([N, J, -1])
        heat_k, inds_k = maxmap.topk(self.max_num_people, axis=2)

        return [heatmap, tagmap, heat_k, inds_k]

# NOTE(review): `register` and `serializable` are imported by this file but
# never applied; this class presumably needs those decorators so that
# `create(cfg['post_process'])` can build it -- confirm.
class HrHRNetPostProcess(object):
    """
    HrHRNet postprocess contain:
        1) get topk keypoints in the output heatmap
        2) sample the tagmap's value corresponding to each of the topk coordinate
        3) match different joints to combine to some people with Hungary algorithm
        4) adjust the coordinate by +-0.25 to decrease error std
        5) salvage missing joints by check positivity of heatmap - tagdiff_norm
    Args:
        max_num_people (int): max number of people support in postprocess
        heat_thresh (float): value of topk below this threshold will be ignored
        tag_thresh (float): coord's value sampled in tagmap below this threshold belong to same people for init

    Call inputs:
        heatmap: 4-D `(N, J, H, W)` tensor, N must be 1
        tagmap: `(N, J, H, W, tag_dim)` tensor
        heat_k, inds_k: `(N, J, K)` top-k peak values and flattened indices
        original_height, original_width (float): the original image size
    """

    def __init__(self, max_num_people=30, heat_thresh=0.1, tag_thresh=1.):
        self.max_num_people = max_num_people
        self.heat_thresh = heat_thresh
        self.tag_thresh = tag_thresh

    def lerp(self, j, y, x, heatmap):
        """Return sub-pixel `(offset_y, offset_x)` for peaks at `(j, y, x)`.

        Each offset is 0.5 plus 0.25 when the next neighbour (down / right)
        is strictly larger than the previous one (up / left), else 0.5 minus
        0.25. Neighbour positions are clipped to the heatmap bounds.
        """
        H, W = heatmap.shape[-2:]
        left = np.clip(x - 1, 0, W - 1)
        right = np.clip(x + 1, 0, W - 1)
        up = np.clip(y - 1, 0, H - 1)
        down = np.clip(y + 1, 0, H - 1)
        offset_y = np.where(heatmap[j, down, x] > heatmap[j, up, x], 0.25,
                            -0.25)
        offset_x = np.where(heatmap[j, y, right] > heatmap[j, y, left], 0.25,
                            -0.25)
        return offset_y + 0.5, offset_x + 0.5

    def __call__(self, heatmap, tagmap, heat_k, inds_k, original_height,
                 original_width):
        """Group top-k peaks into poses.

        Returns:
            tuple: `(pose_kpts, mean_score)` where `pose_kpts` is a float32
            `(num_poses, J, 3)` array of `(x, y, score)` and `mean_score` is
            the per-pose mean joint score computed before salvaging.
        """
        N, J, H, W = heatmap.shape
        assert N == 1, "only support batch size 1"
        heatmap = heatmap[0].cpu().detach().numpy()
        tagmap = tagmap[0].cpu().detach().numpy()
        heats = heat_k[0].cpu().detach().numpy()
        inds_np = inds_k[0].cpu().detach().numpy()
        y = inds_np // W
        x = inds_np % W
        # Sample the tag vector at every peak -> (J, K, tag_dim). Broadcasting
        # the joint index uses the actual top-k width instead of assuming it
        # equals `self.max_num_people`.
        tags = tagmap[np.arange(J)[:, None], y, x]
        coords = np.stack((y, x), axis=2)
        # threshold
        mask = heats > self.heat_thresh
        # cluster: one entry per person, keyed by the first tag component
        cluster = defaultdict(lambda: {
            'coords': np.zeros((J, 2), dtype=np.float32),
            'scores': np.zeros(J, dtype=np.float32),
            'tags': []
        })
        for jid, m in enumerate(mask):
            num_valid = int(m.sum())
            if num_valid == 0:
                continue
            valid_inds = np.where(m)[0]
            valid_tags = tags[jid, m, :]
            if len(cluster) == 0:  # initialize
                for i in valid_inds:
                    tag = tags[jid, i]
                    key = tag[0]
                    cluster[key]['tags'].append(tag)
                    cluster[key]['scores'][jid] = heats[jid, i]
                    cluster[key]['coords'][jid] = coords[jid, i]
                continue
            candidates = list(cluster.keys())[:self.max_num_people]
            centroids = [
                np.mean(cluster[k]['tags'], axis=0) for k in candidates
            ]
            num_clusters = len(centroids)
            # shape is (num_valid, num_clusters, tag_dim)
            dist = valid_tags[:, None, :] - np.array(centroids)[None, ...]
            l2_dist = np.linalg.norm(dist, ord=2, axis=2)
            # modulate dist with heat value, see `use_detection_val`
            cost = np.round(l2_dist) * 100 - heats[jid, m, None]
            # pad the cost matrix, otherwise new pose are ignored
            if num_valid > num_clusters:
                cost = np.pad(cost, ((0, 0), (0, num_valid - num_clusters)),
                              'constant',
                              constant_values=((0, 0), (0, 1e-10)))
            rows, cols = linear_sum_assignment(cost)
            for row, col in zip(rows, cols):
                # Cost rows enumerate the valid peaks only; translate back to
                # the position inside the full top-k list.
                peak = valid_inds[row]
                tag = tags[jid, peak]
                if col < num_clusters and \
                   l2_dist[row, col] < self.tag_thresh:
                    key = candidates[col]  # merge to cluster
                else:
                    key = tag[0]  # initialize new cluster
                cluster[key]['tags'].append(tag)
                cluster[key]['scores'][jid] = heats[jid, peak]
                cluster[key]['coords'][jid] = coords[jid, peak]

        keys = list(cluster)
        # Tag lists differ in length between people, so keep a plain list
        # (building a ragged ndarray fails on recent NumPy releases).
        pose_tags = [cluster[k]['tags'] for k in keys]
        # shape is [k, J, 2] and [k, J]
        pose_coords = np.array([cluster[k]['coords'] for k in keys])
        pose_scores = np.array([cluster[k]['scores'] for k in keys])
        valid = pose_scores > 0

        pose_kpts = np.zeros((pose_scores.shape[0], J, 3), dtype=np.float32)
        if valid.sum() == 0:
            # NOTE(review): the second element is the keypoint array rather
            # than a score array; kept as-is for callers -- confirm intent.
            return pose_kpts, pose_kpts

        # refine coords
        valid_coords = pose_coords[valid].astype(np.int32)
        y = valid_coords[..., 0].flatten()
        x = valid_coords[..., 1].flatten()
        _, j = np.nonzero(valid)
        offsets = self.lerp(j, y, x, heatmap)
        pose_coords[valid, 0] += offsets[0]
        pose_coords[valid, 1] += offsets[1]

        # mean score before salvage
        mean_score = pose_scores.mean(axis=1)
        pose_kpts[valid, 2] = pose_scores[valid]

        # salvage missing joints
        for pid in range(len(pose_coords)):
            tag_mean = np.array(pose_tags[pid]).mean(axis=0)
            norm = np.sum((tagmap - tag_mean)**2, axis=3)**0.5
            score = heatmap - np.round(norm)  # (J, H, W)
            flat_score = score.reshape(J, -1)
            max_inds = np.argmax(flat_score, axis=1)
            max_scores = np.max(flat_score, axis=1)
            salvage_joints = (pose_scores[pid] == 0) & (max_scores > 0)
            if salvage_joints.sum() == 0:
                continue
            y = max_inds[salvage_joints] // W
            x = max_inds[salvage_joints] % W
            offsets = self.lerp(salvage_joints.nonzero()[0], y, x, heatmap)
            y = y.astype(np.float32) + offsets[0]
            x = x.astype(np.float32) + offsets[1]
            pose_coords[pid][salvage_joints, 0] = y
            pose_coords[pid][salvage_joints, 1] = x
            pose_kpts[pid][salvage_joints, 2] = max_scores[salvage_joints]
        # (y, x) -> (x, y), then map to the original image via `transpred`
        pose_kpts[..., :2] = transpred(pose_coords[..., :2][..., ::-1],
                                       original_height, original_width,
                                       min(H, W))
        return pose_kpts, mean_score