import os
import copy
from collections import OrderedDict

import cv2
import paddle
import numpy as np
import paddle.nn as nn
import paddlehub as hub
from skimage.measure import label
from scipy.ndimage.filters import gaussian_filter
from paddlehub.module.module import moduleinfo
from paddlehub.process.functional import npmax
from paddlehub.process.transforms import HandDetect, ResizeScaling, PadDownRight, RemovePadding, DrawPose, DrawHandPose, Normalize


@moduleinfo(name="openpose_hands_estimation",
            type="CV/image_editing",
            author="paddlepaddle",
            author_email="",
            summary="Openpose_hands_estimation is a hand pose estimation model based on Hand Keypoint Detection in "
                    "Single Images using Multiview Bootstrapping.",
            version="1.0.0")
class HandposeModel(nn.Layer):
    """HandposeModel

    OpenPose-style hand keypoint network. A stack of 3x3 convolutions and
    max-pools (``model1_0``) produces 128-channel features, an initial stage
    (``model1_1``) predicts 22 heatmap channels, and five refinement stages
    (``model2`` .. ``model6``) each take the previous stage's 22 channels
    concatenated with the 128 feature channels.

    Args:
        load_checkpoint(str): Checkpoint save path, default is None. When None,
            ``hand_estimation.pdparams`` under ``self.directory`` is loaded and
            downloaded first if it does not exist yet.
        visualization (bool): Whether to save the estimation result. Default is True.
    """

    # Source of the pretrained weights used when no checkpoint path is given.
    _PRETRAINED_URL = 'https://bj.bcebos.com/paddlehub/model/image/keypoint_detection/hand_estimation.pdparams'
    # Multi-scale factors `predict` uses when the caller does not pass `scale`.
    _DEFAULT_SCALES = (0.5, 1.0, 1.5, 2.0)

    def __init__(self, load_checkpoint: str = None, visualization: bool = True):
        super(HandposeModel, self).__init__()
        self.visualization = visualization
        self.hand_detect = HandDetect()
        self.resize_func = ResizeScaling()
        self.pad_func = PadDownRight()
        self.remove_pad = RemovePadding()
        self.draw_pose = DrawPose()
        self.draw_hand = DrawHandPose()
        self.norm_func = Normalize(std=[1, 1, 1])
        # The last conv of every stage emits the heatmaps directly, so no ReLU follows it.
        no_relu_layers = {'conv6_2_CPM'} | {'Mconv7_stage%d' % i for i in range(2, 7)}

        # Layer specs: convs are [in_channels, out_channels, kernel_size, stride, padding];
        # pools (any name containing 'pool') are [kernel_size, stride, padding].
        block1_0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), ('conv1_2', [64, 64, 3, 1, 1]),
                                ('pool1_stage1', [2, 2, 0]), ('conv2_1', [64, 128, 3, 1, 1]),
                                ('conv2_2', [128, 128, 3, 1, 1]), ('pool2_stage1', [2, 2, 0]),
                                ('conv3_1', [128, 256, 3, 1, 1]), ('conv3_2', [256, 256, 3, 1, 1]),
                                ('conv3_3', [256, 256, 3, 1, 1]), ('conv3_4', [256, 256, 3, 1, 1]),
                                ('pool3_stage1', [2, 2, 0]), ('conv4_1', [256, 512, 3, 1, 1]),
                                ('conv4_2', [512, 512, 3, 1, 1]), ('conv4_3', [512, 512, 3, 1, 1]),
                                ('conv4_4', [512, 512, 3, 1, 1]), ('conv5_1', [512, 512, 3, 1, 1]),
                                ('conv5_2', [512, 512, 3, 1, 1]), ('conv5_3_CPM', [512, 128, 3, 1, 1])])

        block1_1 = OrderedDict([('conv6_1_CPM', [128, 512, 1, 1, 0]), ('conv6_2_CPM', [512, 22, 1, 1, 0])])

        # Insertion order fixes the order in which layers are created below.
        blocks = {'block1_0': block1_0, 'block1_1': block1_1}
        for i in range(2, 7):
            # 150 input channels = 22 heatmap channels from the previous stage + 128 feature channels.
            blocks['block%d' % i] = OrderedDict([('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
                                                 ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
                                                 ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
                                                 ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
                                                 ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
                                                 ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
                                                 ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])])

        for k in blocks:
            blocks[k] = self.make_layers(blocks[k], no_relu_layers)

        self.model1_0 = blocks['block1_0']
        self.model1_1 = blocks['block1_1']
        self.model2 = blocks['block2']
        self.model3 = blocks['block3']
        self.model4 = blocks['block4']
        self.model5 = blocks['block5']
        self.model6 = blocks['block6']

        if load_checkpoint is not None:
            checkpoint = load_checkpoint
            success_msg = "load custom checkpoint success"
        else:
            # NOTE(review): `self.directory` is not set in this class; presumably
            # provided by the PaddleHub module machinery — confirm.
            checkpoint = os.path.join(self.directory, 'hand_estimation.pdparams')
            if not os.path.exists(checkpoint):
                self._download_checkpoint(self._PRETRAINED_URL, checkpoint)
            success_msg = "load pretrained checkpoint success"

        # Only element [0] of what paddle.load returns is used as the parameter
        # dict; presumably (params, optimizer_state) for this checkpoint format — confirm.
        model_dict = paddle.load(checkpoint)[0]
        self.set_dict(model_dict)
        print(success_msg)

    @staticmethod
    def _download_checkpoint(url: str, dst: str):
        """Download `url` to `dst` without ever leaving a partial file at `dst`.

        The data is written to a temporary sibling path and renamed onto `dst`
        only after the download succeeds. An interrupted or failed download
        therefore cannot leave a truncated file that a later `os.path.exists`
        check would mistake for a valid checkpoint. Download errors propagate.
        """
        import urllib.request

        tmp_path = dst + '.part'
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, dst)
        finally:
            # Clean up after a failed download; after os.replace this path is gone.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def make_layers(self, block: dict, no_relu_layers: list):
        """Build an ``nn.Sequential`` from an ordered layer spec.

        Args:
            block (dict): ordered mapping layer_name -> spec. Names containing
                'pool' become ``nn.MaxPool2d`` from [kernel_size, stride, padding].
                Every other entry becomes ``nn.Conv2d`` from
                [in_channels, out_channels, kernel_size, stride, padding].
            no_relu_layers (collection of str): conv names that are NOT followed
                by a ReLU. Only membership (``in``) is used.

        Returns:
            nn.Sequential: the layers, each registered under its spec name. ReLUs
            are registered as ``'relu_' + layer_name``.
        """
        layers = []
        for layer_name, v in block.items():
            if 'pool' in layer_name:
                layers.append((layer_name, nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])))
                continue
            conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4])
            layers.append((layer_name, conv2d))
            if layer_name not in no_relu_layers:
                layers.append(('relu_' + layer_name, nn.ReLU()))
        return nn.Sequential(*layers)

    def forward(self, x: paddle.Tensor):
        """Run the feature extractor and all six prediction stages.

        Each refinement stage sees the previous stage's output concatenated
        (along axis 1) with the shared features from ``model1_0``.

        Returns:
            paddle.Tensor: output of the last stage (``model6``), which has 22 channels.
        """
        features = self.model1_0(x)
        out = self.model1_1(features)
        for stage in (self.model2, self.model3, self.model4, self.model5, self.model6):
            out = stage(paddle.concat([out, features], 1))
        return out

    def hand_estimation(self, handimg: np.ndarray, scale_search: list):
        """Locate the 21 hand keypoints in one hand crop.

        The crop is run through the network at every factor in `scale_search`,
        and the padding-removed heatmaps are averaged. For each of the first 21
        heatmap channels:

        * a Gaussian-smoothed copy (sigma=3) thresholded at 0.05 gives a binary mask;
        * the mask's connected component with the largest summed (unsmoothed)
          response is kept;
        * the maximum of the unsmoothed map inside that component is the keypoint.

        Args:
            handimg (np.ndarray): hand crop; presumably HxWx3 as read by cv2 — confirm.
            scale_search (list): scale factors passed to the resize transform.

        Returns:
            np.ndarray: (21, 2) array of [x, y] positions in crop coordinates.
            [0, 0] marks a channel with no response above the threshold.
        """
        heatmap_avg = np.zeros((handimg.shape[0], handimg.shape[1], 22))

        for scale in scale_search:
            process = self.resize_func(handimg, scale)
            imageToTest_padded, pad = self.pad_func(process)
            process = self.norm_func(imageToTest_padded)
            # (H, W, C) -> (1, C, H, W) batch of one.
            process = np.ascontiguousarray(np.transpose(process[:, :, :, np.newaxis], (3, 2, 0, 1))).astype("float32")
            data = self.forward(paddle.to_tensor(process))
            data = data.numpy()
            heatmap = self.remove_pad(data, imageToTest_padded, handimg, pad)
            heatmap_avg += heatmap / len(scale_search)

        all_peaks = []
        for part in range(21):
            # View into heatmap_avg; the masking below zeroes it in place.
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)
            binary = np.ascontiguousarray(one_heatmap > 0.05, dtype=np.uint8)
            if np.sum(binary) == 0:
                all_peaks.append([0, 0])
                continue
            label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
            # Labels start at 1 (0 is background), hence the +1.
            max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
            label_img[label_img != max_index] = 0
            map_ori[label_img == 0] = 0

            # Presumably npmax returns (row, col) of the maximum — confirm.
            y, x = npmax(map_ori)
            all_peaks.append([x, y])

        return np.array(all_peaks)

    def predict(self, img_path: str, save_path: str = 'result', scale: list = None):
        """Detect hands in an image and estimate their keypoints.

        Args:
            img_path (str): path of the input image.
            save_path (str): directory the visualization is written to when
                ``self.visualization`` is True. It is created if missing.
            scale (list): multi-scale factors for `hand_estimation`. Defaults to
                [0.5, 1.0, 1.5, 2.0].

        Returns:
            list[np.ndarray]: one (21, 2) array of [x, y] image coordinates per
            detected hand. [0, 0] marks a missing keypoint.

        Raises:
            FileNotFoundError: if `img_path` cannot be read as an image.
        """
        if scale is None:
            scale = list(self._DEFAULT_SCALES)

        # Build the body-pose model once and reuse it. Re-creating it on every
        # call would reload its weights each time.
        if getattr(self, 'body_model', None) is None:
            self.body_model = hub.Module(name='openpose_body_estimation')
        self.body_model.eval()

        org_img = cv2.imread(img_path)
        if org_img is None:
            raise FileNotFoundError("cannot read image: %s" % img_path)

        candidate, subset = self.body_model.predict(img_path)
        hands_list = self.hand_detect(candidate, subset, org_img)

        all_hand_peaks = []
        for x, y, w, _is_left in hands_list:
            peaks = self.hand_estimation(org_img[y:y + w, x:x + w, :], scale)
            # Shift keypoints from crop to image coordinates. A 0 coordinate
            # marks a missing keypoint and is left as 0.
            # NOTE(review): a genuine peak on crop row/column 0 is also treated as missing.
            peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x)
            peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y)
            all_hand_peaks.append(peaks)

        if self.visualization:
            canvas = copy.deepcopy(org_img)
            canvas = self.draw_pose(canvas, candidate, subset)
            canvas = self.draw_hand(canvas, all_hand_peaks)
            os.makedirs(save_path, exist_ok=True)
            cv2.imwrite(os.path.join(save_path, os.path.basename(img_path)), canvas)
        return all_hand_peaks


if __name__ == "__main__":
    # Manual smoke run: dygraph mode, load the default weights, and estimate
    # hand keypoints for a local sample image (written under ./result).
    paddle.disable_static()
    model = HandposeModel()
    model.eval()
    model.predict("detect_hand4.jpg")
