# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV3 dataset."""
import os
import multiprocessing

from PIL import Image
from pycocotools.coco import COCO

import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as CV

from src.distributed_sampler import DistributedSampler
from src.transforms import reshape_fn, MultiScaleTrans

min_keypoints_per_image = 10


def _has_only_empty_bbox(anno):
    """Return True if every box in the annotation has (close to) zero width or height."""
    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)


def _count_visible_keypoints(anno):
    """Count keypoints whose visibility flag is greater than zero."""
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)


def has_valid_annotation(anno):
    """Check whether an image's annotation list is usable for training."""
    # if it's empty, there is no annotation
    if not anno:
        return False
    # if all boxes have close to zero area, there is no annotation
    if _has_only_empty_bbox(anno):
        return False
    # the keypoints task has a slightly different criterion for considering
    # an annotation valid
    if "keypoints" not in anno[0]:
        return True
    # for keypoint detection tasks, only consider as valid the images
    # containing at least min_keypoints_per_image visible keypoints
    if _count_visible_keypoints(anno) >= min_keypoints_per_image:
        return True
    return False


class COCOYoloDataset:
    """YOLOV3 Dataset for COCO."""

    def __init__(self, root, ann_file, remove_images_without_annotations=True,
                 filter_crowd_anno=True, is_training=True):
        self.coco = COCO(ann_file)
        self.root = root
        self.img_ids = list(sorted(self.coco.imgs.keys()))
        self.filter_crowd_anno = filter_crowd_anno
        self.is_training = is_training

        # filter out images without any valid annotation
        if remove_images_without_annotations:
            img_ids = []
            for img_id in self.img_ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                anno = self.coco.loadAnns(ann_ids)
                if has_valid_annotation(anno):
                    img_ids.append(img_id)
            self.img_ids = img_ids

        self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}

        # map the sparse COCO category ids to contiguous [0, num_classes) labels and back
        self.cat_ids_to_continuous_ids = {
            v: i for i, v in enumerate(self.coco.getCatIds())
        }
        self.continuous_ids_cat_ids = {
            v: k for k, v in self.cat_ids_to_continuous_ids.items()
        }

    def __getitem__(self, index):
        """
        Args:
            index (int): Index.

        Returns:
            In evaluation mode, a tuple (img, img_id) where img is a PIL image.
            In training mode, a tuple (img, out_target, [], [], [], [], [], []) where
            out_target is a list of [x_min, y_min, x_max, y_max, label] entries built
            from the image's annotation, and the empty lists are placeholders for the
            bbox/gt_box columns that the per-batch transform (MultiScaleTrans) fills in.
        """
        coco = self.coco
        img_id = self.img_ids[index]
        img_path = coco.loadImgs(img_id)[0]["file_name"]
        img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
        if not self.is_training:
            return img, img_id

        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)
        # filter crowd annotations
        if self.filter_crowd_anno:
            annos = [anno for anno in target if anno["iscrowd"] == 0]
        else:
            annos = [anno for anno in target]

        target = {}
        boxes = [anno["bbox"] for anno in annos]
        target["bboxes"] = boxes

        classes = [anno["category_id"] for anno in annos]
        classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
        target["labels"] = classes

        bboxes = target["bboxes"]
        labels = target["labels"]
        out_target = []
        for bbox, label in zip(bboxes, labels):
            tmp = []
            # convert [x_min, y_min, w, h] to [x_min, y_min, x_max, y_max]
            bbox = self._convetTopDown(bbox)
            tmp.extend(bbox)
            tmp.append(int(label))
            # tmp is [x_min, y_min, x_max, y_max, label]
            out_target.append(tmp)
        return img, out_target, [], [], [], [], [], []

    def __len__(self):
        return len(self.img_ids)

    def _convetTopDown(self, bbox):
        """Convert a COCO [x_min, y_min, w, h] box to [x_min, y_min, x_max, y_max]."""
        x_min = bbox[0]
        y_min = bbox[1]
        w = bbox[2]
        h = bbox[3]
        return [x_min, y_min, x_min + w, y_min + h]


def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
                        config=None, is_training=True, shuffle=True):
    """Create dataset for YOLOV3."""
    if is_training:
        filter_crowd = True
        remove_empty_anno = True
    else:
        filter_crowd = False
        remove_empty_anno = False

    yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path,
                                   filter_crowd_anno=filter_crowd,
                                   remove_images_without_annotations=remove_empty_anno,
                                   is_training=is_training)
    distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
    hwc_to_chw = CV.HWC2CHW()

    config.dataset_size = len(yolo_dataset)
    cores = multiprocessing.cpu_count()
    num_parallel_workers = int(cores / device_num)
    if is_training:
        multi_scale_trans = MultiScaleTrans(config, device_num)
        dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
                                "gt_box1", "gt_box2", "gt_box3"]
        if device_num != 8:
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names,
                                     num_parallel_workers=min(32, num_parallel_workers),
                                     sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans,
                          input_columns=dataset_column_names,
                          num_parallel_workers=min(32, num_parallel_workers),
                          drop_remainder=True)
        else:
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names,
                                     sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans,
                          input_columns=dataset_column_names,
                          num_parallel_workers=min(8, num_parallel_workers),
                          drop_remainder=True)
    else:
        ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
                                 sampler=distributed_sampler)
        compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
        ds = ds.map(input_columns=["image", "img_id"],
                    output_columns=["image", "image_shape", "img_id"],
                    column_order=["image", "image_shape", "img_id"],
                    operations=compose_map_func, num_parallel_workers=8)
        ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
        ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(max_epoch)

    return ds, len(yolo_dataset)