# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import cv2
import json
import copy
import numpy as np

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

from ppdet.core.workspace import register, serializable
from ppdet.data.crop_utils.annotation_cropper import AnnoCropper
from .coco import COCODataSet
from .dataset import _make_dataset, _is_valid_file
from ppdet.utils.logger import setup_logger

logger = setup_logger('sniper_coco_dataset')


@register
@serializable
class SniperCOCODataSet(COCODataSet):
    """SniperCOCODataSet"""

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 proposals_file=None,
                 data_fields=['image'],
                 sample_num=-1,
                 load_crowd=False,
                 allow_empty=True,
                 empty_ratio=1.,
                 is_trainset=True,
                 image_target_sizes=[2000, 1000],
                 valid_box_ratio_ranges=[[-1, 0.1], [0.08, -1]],
                 chip_target_size=500,
                 chip_target_stride=200,
                 use_neg_chip=False,
                 max_neg_num_per_im=8,
                 max_per_img=-1,
                 nms_thresh=0.5):
        super(SniperCOCODataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            data_fields=data_fields,
            sample_num=sample_num,
            load_crowd=load_crowd,
            allow_empty=allow_empty,
            empty_ratio=empty_ratio
        )
        self.proposals_file = proposals_file
        self.proposals = None
        self.anno_cropper = None
        self.is_trainset = is_trainset
        self.image_target_sizes = image_target_sizes
        self.valid_box_ratio_ranges = valid_box_ratio_ranges
        self.chip_target_size = chip_target_size
        self.chip_target_stride = chip_target_stride
        self.use_neg_chip = use_neg_chip
        self.max_neg_num_per_im = max_neg_num_per_im
        self.max_per_img = max_per_img
        self.nms_thresh = nms_thresh

    def parse_dataset(self):
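        # Parse the COCO annotations (unless a previous call already built
        # self.roidbs), optionally merge external proposals for training,
        # keep a deep copy of the original records, then crop them into
        # SNIPER chips.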
        if not hasattr(self, "roidbs"):
            super(SniperCOCODataSet, self).parse_dataset()
        if self.is_trainset:
            self._parse_proposals()
            self._merge_anno_proposals()
        self.ori_roidbs = copy.deepcopy(self.roidbs)
        self.init_anno_cropper()
        self.roidbs = self.generate_chips_roidbs(self.roidbs, self.is_trainset)

    def set_proposals_file(self, file_path):
        self.proposals_file = file_path

    def init_anno_cropper(self):
        logger.info("Init AnnoCropper...")
        self.anno_cropper = AnnoCropper(
            image_target_sizes=self.image_target_sizes,
            valid_box_ratio_ranges=self.valid_box_ratio_ranges,
            chip_target_size=self.chip_target_size,
            chip_target_stride=self.chip_target_stride,
            use_neg_chip=self.use_neg_chip,
            max_neg_num_per_im=self.max_neg_num_per_im,
            max_per_img=self.max_per_img,
            nms_thresh=self.nms_thresh
        )

    def generate_chips_roidbs(self, roidbs, is_trainset):
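        # Training records are cropped with annotation-aware chip generation;
        # inference records go through the cropper's inference path.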
        if is_trainset:
            roidbs = self.anno_cropper.crop_anno_records(roidbs)
        else:
            roidbs = self.anno_cropper.crop_infer_anno_records(roidbs)
        return roidbs

    def _parse_proposals(self):
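        # Load proposals from a COCO-result-style JSON file and group them by
        # image_id, converting each bbox from [x, y, w, h] to [x1, y1, x2, y2].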
        if self.proposals_file:
            self.proposals = {}
            logger.info("Parse proposals file: {}".format(self.proposals_file))
            with open(self.proposals_file, 'r') as f:
                proposals = json.load(f)
            for prop in proposals:
                image_id = prop["image_id"]
                if image_id not in self.proposals:
                    self.proposals[image_id] = []
                x, y, w, h = prop["bbox"]
                self.proposals[image_id].append([x, y, x + w, y + h])

    def _merge_anno_proposals(self):
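        # Attach the parsed proposals to each record under the "proposals"
        # key; images without proposals get an empty array.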
        assert self.roidbs
        if self.proposals:
            logger.info("Merge proposals into annotations")
            for i, record in enumerate(self.roidbs):
                image_id = int(record["im_id"])
                if image_id not in self.proposals:
                    logger.info("image id: {} has no proposals".format(image_id))
                record["proposals"] = np.array(
                    self.proposals.get(image_id, []), dtype=np.float32)
                self.roidbs[i] = record

    def get_ori_roidbs(self):
        if not hasattr(self, "ori_roidbs"):
            return None
        return self.ori_roidbs

    def get_roidbs(self):
        if not hasattr(self, "roidbs"):
            self.parse_dataset()
        return self.roidbs

    def set_roidbs(self, roidbs):
        self.roidbs = roidbs

    def check_or_download_dataset(self):
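        # Overridden as a no-op: skip the parent class's dataset
        # existence/download check.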
        return

    def _parse(self):
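        # Collect image files: directory entries are joined with dataset_dir
        # and walked via _make_dataset, single valid image files are added
        # directly.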
        image_dir = self.image_dir
        if not isinstance(image_dir, Sequence):
            image_dir = [image_dir]
        images = []
        for im_dir in image_dir:
            if os.path.isdir(im_dir):
                im_dir = os.path.join(self.dataset_dir, im_dir)
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
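        # Build minimal records (im_id, file path, height, width) directly
        # from image files, e.g. for raw images passed in via set_images().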
        images = self._parse()
        ct = 0
        records = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                "Image {} not found".format(image)
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            im = cv2.imread(image)
            assert im is not None, "Failed to read image {}".format(image)
            h, w, c = im.shape
            rec = {'im_id': np.array([ct]), 'im_file': image, "h": h, "w": w}
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        return self._imid2path

    def set_images(self, images):
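        # Point the dataset at a list of raw image files (or directories)
        # and rebuild the records from them.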
        self._imid2path = {}
        self.image_dir = images
        self.roidbs = self._load_images()