# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV3 dataset."""
import os

from PIL import Image
from pycocotools.coco import COCO
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as CV
from src.distributed_sampler import DistributedSampler
from src.transforms import reshape_fn, MultiScaleTrans

# Minimum number of visible keypoints that an image's annotations must contain
# for has_valid_annotation() to accept a keypoint-annotated image.
min_keypoints_per_image = 10

def _has_only_empty_bbox(anno):
    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

def _count_visible_keypoints(anno):
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

def has_valid_annotation(anno):
    """Decide whether one image's annotation list is usable.

    Args:
        anno (list): Annotation dicts belonging to a single image.

    Returns:
        bool: False when ``anno`` is empty or every box is degenerate.
        Otherwise True, unless the first annotation carries a ``"keypoints"``
        entry, in which case the number of visible keypoints must reach
        ``min_keypoints_per_image``.
    """
    # nothing annotated, or nothing but near-zero-area boxes
    if not anno or _has_only_empty_bbox(anno):
        return False
    # annotations without keypoints need no further checks
    if "keypoints" not in anno[0]:
        return True
    # keypoint annotations use a stricter criterion: enough visible keypoints
    return _count_visible_keypoints(anno) >= min_keypoints_per_image

class COCOYoloDataset:
    """YOLOV3 Dataset for COCO.

    Random-access source (``__getitem__`` / ``__len__``) built on a COCO
    annotation file.

    Args:
        root (str): Directory that each image's ``file_name`` is joined to.
        ann_file (str): Annotation file path handed to ``COCO``.
        remove_images_without_annotations (bool): Drop images whose
            annotations fail ``has_valid_annotation``. Default: True.
        filter_crowd_anno (bool): Skip annotations whose ``iscrowd`` is not 0
            when building training targets. Default: True.
        is_training (bool): When False, ``__getitem__`` returns the image id
            instead of box targets. Default: True.
    """

    def __init__(self, root, ann_file, remove_images_without_annotations=True,
                 filter_crowd_anno=True, is_training=True):
        self.coco = COCO(ann_file)
        self.root = root
        self.img_ids = sorted(self.coco.imgs.keys())
        self.filter_crowd_anno = filter_crowd_anno
        self.is_training = is_training

        # filter images without any annotations
        if remove_images_without_annotations:
            img_ids = []
            for img_id in self.img_ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                anno = self.coco.loadAnns(ann_ids)
                if has_valid_annotation(anno):
                    img_ids.append(img_id)
            self.img_ids = img_ids

        self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}

        # category id -> its position in getCatIds(), plus the inverse mapping
        self.cat_ids_to_continuous_ids = {
            v: i for i, v in enumerate(self.coco.getCatIds())
        }
        self.continuous_ids_cat_ids = {
            v: k for k, v in self.cat_ids_to_continuous_ids.items()
        }

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Position in ``self.img_ids``.

        Returns:
            tuple: ``(img, img_id)`` when ``is_training`` is False; otherwise
            ``(img, out_target)`` where ``out_target`` is a list of
            ``[x_min, y_min, x_max, y_max, label]`` rows. ``img`` is a PIL
            image converted to RGB.
        """
        coco = self.coco
        img_id = self.img_ids[index]
        img_path = coco.loadImgs(img_id)[0]["file_name"]
        # convert() returns an independent image, so the file can be closed here
        with Image.open(os.path.join(self.root, img_path)) as raw_img:
            img = raw_img.convert("RGB")
        if not self.is_training:
            return img, img_id

        ann_ids = coco.getAnnIds(imgIds=img_id)
        return img, self._build_target(coco.loadAnns(ann_ids))

    def __len__(self):
        """Return the number of images kept after optional filtering."""
        return len(self.img_ids)

    def _build_target(self, annos):
        """Turn annotation dicts into ``[x_min, y_min, x_max, y_max, label]`` rows."""
        # filter crowd annotations
        if self.filter_crowd_anno:
            annos = [anno for anno in annos if anno["iscrowd"] == 0]

        out_target = []
        for anno in annos:
            label = self.cat_ids_to_continuous_ids[anno["category_id"]]
            # convert to [x_min y_min x_max y_max], then append the label
            out_target.append(self._convetTopDown(anno["bbox"]) + [int(label)])
        return out_target

    def _convetTopDown(self, bbox):
        """Convert ``[x_min, y_min, w, h]`` to ``[x_min, y_min, x_max, y_max]``."""
        x_min, y_min, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
        return [x_min, y_min, x_min + w, y_min + h]

def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
                        config=None, is_training=True, shuffle=True):
    """Create dataset for YOLOV3.

    Args:
        image_dir (str): Image root directory passed to ``COCOYoloDataset``.
        anno_path (str): COCO annotation file path.
        batch_size (int): Samples per batch; incomplete batches are dropped.
        max_epoch (int): Repeat count applied to the batched dataset.
        device_num (int): Number of devices; passed to the sampler and used to
            scale worker counts.
        rank (int): Rank of this device, passed to ``DistributedSampler``.
        config: Configuration object. Required in practice despite the
            ``None`` default: ``dataset_size`` is written onto it, and it is
            forwarded to ``MultiScaleTrans`` and ``reshape_fn``.
        is_training (bool): Selects the training pipeline (image/annotation
            columns batched through ``MultiScaleTrans``) or the evaluation
            pipeline (image/image_shape/img_id columns). Default: True.
        shuffle (bool): Forwarded to ``DistributedSampler``. Default: True.

    Returns:
        tuple: ``(ds, dataset_size)`` - the batched, repeated dataset and the
        number of images in the underlying ``COCOYoloDataset``.
    """
    # Crowd filtering and empty-image removal apply to training only.
    if is_training:
        filter_crowd = True
        remove_empty_anno = True
    else:
        filter_crowd = False
        remove_empty_anno = False

    yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
                                   remove_images_without_annotations=remove_empty_anno, is_training=is_training)
    distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
    hwc_to_chw = CV.HWC2CHW()

    config.dataset_size = len(yolo_dataset)
    # Clamp to 1: for device_num > 16 the plain division would give a worker count of 0.
    num_parallel_workers1 = max(1, int(64 / device_num))
    num_parallel_workers2 = max(1, int(16 / device_num))
    if is_training:
        multi_scale_trans = MultiScaleTrans(config, device_num)
        if device_num != 8:
            ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"],
                                     num_parallel_workers=num_parallel_workers1,
                                     sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
                          num_parallel_workers=num_parallel_workers2, drop_remainder=True)
        else:
            ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
                          num_parallel_workers=8, drop_remainder=True)
    else:
        ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
                                 sampler=distributed_sampler)

        def compose_map_func(image, img_id):
            """Bind ``config`` so the map operation only receives the two columns."""
            return reshape_fn(image, img_id, config)

        ds = ds.map(input_columns=["image", "img_id"],
                    output_columns=["image", "image_shape", "img_id"],
                    column_order=["image", "image_shape", "img_id"],
                    operations=compose_map_func, num_parallel_workers=8)
        ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
        ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(max_epoch)

    return ds, len(yolo_dataset)