# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import os
import random
import time
import copy
import cv2

import box_utils
import image_utils
from pycocotools.coco import COCO
from data_utils import GeneratorEnqueuer
from config.config import cfg


class DataSetReader(object):
    """A class for parsing and reading the COCO dataset."""

    def __init__(self):
        self.has_parsed_category = False

    def _parse_dataset_dir(self, mode):
        # cfg.data_dir = "dataset/coco"
        # cfg.train_file_list = 'annotations/instances_val2017.json'
        # cfg.train_data_dir = 'val2017'
        # cfg.dataset = "coco2017"
        if 'coco2014' in cfg.dataset:
            cfg.train_file_list = 'annotations/instances_train2014.json'
            cfg.train_data_dir = 'train2014'
            cfg.val_file_list = 'annotations/instances_val2014.json'
            cfg.val_data_dir = 'val2014'
        elif 'coco2017' in cfg.dataset:
            cfg.train_file_list = 'annotations/instances_train2017.json'
            cfg.train_data_dir = 'train2017'
            cfg.val_file_list = 'annotations/instances_val2017.json'
            cfg.val_data_dir = 'val2017'
        else:
            raise NotImplementedError('Dataset {} not supported'.format(
                cfg.dataset))

        if mode == 'train':
            cfg.train_file_list = os.path.join(cfg.data_dir,
                                               cfg.train_file_list)
            cfg.train_data_dir = os.path.join(cfg.data_dir, cfg.train_data_dir)
            self.COCO = COCO(cfg.train_file_list)
            self.img_dir = cfg.train_data_dir
        elif mode == 'test' or mode == 'infer':
            cfg.val_file_list = os.path.join(cfg.data_dir, cfg.val_file_list)
            cfg.val_data_dir = os.path.join(cfg.data_dir, cfg.val_data_dir)
            self.COCO = COCO(cfg.val_file_list)
            self.img_dir = cfg.val_data_dir

    def _parse_dataset_category(self):
        self.categories = self.COCO.loadCats(self.COCO.getCatIds())
        self.num_category = len(self.categories)
        self.label_names = []
        self.label_ids = []
        for category in self.categories:
            self.label_names.append(category['name'])
            self.label_ids.append(int(category['id']))
        # Map the sparse COCO category ids to contiguous label indices.
        self.category_to_id_map = {
            v: i
            for i, v in enumerate(self.label_ids)
        }
        print("Loaded {} categories.".format(self.num_category))
        self.has_parsed_category = True

    def get_label_infos(self):
        if not self.has_parsed_category:
            self._parse_dataset_dir("test")
            self._parse_dataset_category()
        return (self.label_names, self.label_ids)

    def _parse_gt_annotations(self, img):
        img_height = img['height']
        img_width = img['width']
        anno = self.COCO.loadAnns(
            self.COCO.getAnnIds(imgIds=img['id'], iscrowd=None))
        gt_index = 0
        for target in anno:
            if target['area'] < cfg.gt_min_area:
                continue
            if 'ignore' in target and target['ignore']:
                continue

            # Convert the COCO [x, y, w, h] box to center-relative format.
            box = box_utils.coco_anno_box_to_center_relative(
                target['bbox'], img_height, img_width)
            if box[2] <= 0 and box[3] <= 0:
                continue

            img['gt_id'][gt_index] = np.int32(target['id'])
            img['gt_boxes'][gt_index] = box
            img['gt_labels'][gt_index] = \
                self.category_to_id_map[target['category_id']]
            gt_index += 1
            if gt_index >= cfg.max_box_num:
                break

    def _parse_images(self, is_train):
        image_ids = self.COCO.getImgIds()
        image_ids.sort()
        imgs = copy.deepcopy(self.COCO.loadImgs(image_ids))
        imgs = imgs[-8:]
        for img in imgs:
            img['image'] = os.path.join(self.img_dir, img['file_name'])
            assert os.path.exists(img['image']), \
                "image {} not found.".format(img['image'])
            box_num = cfg.max_box_num
            img['gt_id'] = np.zeros((cfg.max_box_num), dtype=np.int32)
            img['gt_boxes'] = np.zeros((cfg.max_box_num, 4), dtype=np.float32)
            img['gt_labels'] = np.zeros((cfg.max_box_num), dtype=np.int32)
            for k in ['date_captured', 'url', 'license', 'file_name']:
                if k in img:
                    del img[k]

            if is_train:
                self._parse_gt_annotations(img)

        print("Loaded {0} images from {1}.".format(len(imgs), cfg.dataset))
        return imgs

    def _parse_images_by_mode(self, mode):
        if mode == 'infer':
            return []
        else:
            return self._parse_images(is_train=(mode == 'train'))

    def get_reader(self,
                   mode,
                   size=416,
                   batch_size=None,
                   shuffle=False,
                   random_shape_iter=0,
                   random_sizes=[],
                   image=None):
        assert mode in ['train', 'test', 'infer'], "Unknown mode type!"
        if mode != 'infer':
            assert batch_size is not None, \
                "batch size cannot be None in mode {}".format(mode)
        self._parse_dataset_dir(mode)
        self._parse_dataset_category()

        def img_reader(img, size, mean, std):
            # Resize the image to the target size and normalize it.
            im_path = img['image']
            im = cv2.imread(im_path).astype('float32')
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

            h, w, _ = im.shape
            im_scale_x = size / float(w)
            im_scale_y = size / float(h)
            out_img = cv2.resize(
                im,
                None,
                None,
                fx=im_scale_x,
                fy=im_scale_y,
                interpolation=cv2.INTER_LINEAR)
            mean = np.array(mean).reshape((1, 1, -1))
            std = np.array(std).reshape((1, 1, -1))
            out_img = (out_img / 255.0 - mean) / std
            out_img = out_img.transpose((2, 0, 1))

            return (out_img, int(img['id']), (h, w))

        def img_reader_with_augment(img, size, mean, std, mixup_img):
            im_path = img['image']
            im = cv2.imread(im_path)
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            gt_boxes = img['gt_boxes'].copy()
            gt_labels = img['gt_labels'].copy()
            gt_scores = np.ones_like(gt_labels)

            im, gt_boxes, gt_labels, gt_scores = image_utils.image_augment(
                im, gt_boxes, gt_labels, gt_scores, size, mean)

            mean = np.array(mean).reshape((1, 1, -1))
            std = np.array(std).reshape((1, 1, -1))
            out_img = (im / 255.0 - mean) / std
            out_img = out_img.astype('float32').transpose((2, 0, 1))

            return (out_img, gt_boxes, gt_labels, gt_scores)

        def get_img_size(size, random_sizes=[]):
            if len(random_sizes):
                return np.random.choice(random_sizes)
            return size

        def reader():
            if mode == 'train':
                imgs = self._parse_images_by_mode(mode)
                if shuffle:
                    np.random.shuffle(imgs)
                read_cnt = 0
                total_iter = 0
                batch_out = []
                img_size = get_img_size(size, random_sizes)
                # img_ids = []
                while True:
                    img = imgs[read_cnt % len(imgs)]
                    mixup_img = None
                    read_cnt += 1
                    if read_cnt % len(imgs) == 0 and shuffle:
                        np.random.shuffle(imgs)
                    im, gt_boxes, gt_labels, gt_scores = \
                        img_reader_with_augment(img, img_size,
                                                cfg.pixel_means,
                                                cfg.pixel_stds, mixup_img)
                    batch_out.append([im, gt_boxes, gt_labels, gt_scores])
                    # img_ids.append((img['id'], mixup_img['id'] if mixup_img else -1))

                    if len(batch_out) == batch_size:
                        # print("img_ids: ", img_ids)
                        yield batch_out
                        batch_out = []
                        total_iter += 1
                        # Re-sample the input size every 10 batches.
                        if total_iter % 10 == 0:
                            img_size = get_img_size(size, random_sizes)
                        # img_ids = []
            elif mode == 'test':
                imgs = self._parse_images_by_mode(mode)
                batch_out = []
                for img in imgs:
                    im, im_id, im_shape = img_reader(img, size,
                                                     cfg.pixel_means,
                                                     cfg.pixel_stds)
                    batch_out.append((im, im_id, im_shape))
                    if len(batch_out) == batch_size:
                        yield batch_out
                        batch_out = []
                if len(batch_out) != 0:
                    yield batch_out
            else:
                img = {}
                img['image'] = image
                img['id'] = 0
                im, im_id, im_shape = img_reader(img, size, cfg.pixel_means,
                                                 cfg.pixel_stds)
                batch_out = [(im, im_id, im_shape)]
                yield batch_out

        return reader


dsr = DataSetReader()


def train(size=416,
          batch_size=64,
          shuffle=True,
          random_shape_iter=0,
          random_sizes=[],
          interval=10,
          pyreader_num=1,
          num_workers=16,
          max_queue=32,
          use_multiprocessing=True):
    generator = dsr.get_reader('train', size, batch_size, shuffle,
                               random_shape_iter, random_sizes)

    if not use_multiprocessing:
        return generator

    def infinite_reader():
        while True:
            for data in generator():
                yield data

    def reader():
        enqueuer = None
        try:
            enqueuer = GeneratorEnqueuer(
                infinite_reader(), use_multiprocessing=True)
            enqueuer.start(
                max_queue_size=max_queue,
                workers=num_workers,
                random_sizes=random_sizes)
            generator_out = None
            np.random.seed(1000)
            intervals = pyreader_num * interval
            total_random_iter = pyreader_num * random_shape_iter
            cnt = 0
            idx = len(random_sizes) - 1
            while True:
                # Wait for a batch from the queue of the current input size.
                while enqueuer.is_running():
                    if not enqueuer.queues[idx].empty():
                        generator_out = enqueuer.queues[idx].get()
                        break
                    else:
                        time.sleep(0.02)
                yield generator_out
                generator_out = None
                cnt += 1
                if cnt % intervals == 0:
                    idx = np.random.randint(len(random_sizes))
                    if cnt >= total_random_iter:
                        idx = -1
                    print("Resizing: ", random_sizes[idx])
        finally:
            if enqueuer is not None:
                enqueuer.stop()

    return reader


def test(size=416, batch_size=1):
    return dsr.get_reader('test', size, batch_size)


def infer(size=416, image=None):
    return dsr.get_reader('infer', size, image=image)


def get_label_infos():
    return dsr.get_label_infos()
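

# Minimal usage sketch. It assumes the COCO dataset is laid out under
# cfg.data_dir as expected by _parse_dataset_dir, and that cfg.pixel_means,
# cfg.pixel_stds, cfg.max_box_num, and cfg.gt_min_area are set by
# config/config.py; the batch size and loop bound below are illustrative
# values, not prescribed by this module.
if __name__ == '__main__':
    # Build a single-process training reader that yields batches of
    # [image, gt_boxes, gt_labels, gt_scores] samples.
    train_reader = train(
        size=416, batch_size=2, shuffle=True, use_multiprocessing=False)
    label_names, label_ids = get_label_infos()
    print("num categories:", len(label_names))
    for i, batch in enumerate(train_reader()):
        im, gt_boxes, gt_labels, gt_scores = batch[0]
        print("batch {}: image shape {}, boxes shape {}".format(
            i, im.shape, gt_boxes.shape))
        if i >= 1:
            break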