# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.utils.image_util import *
import random
from PIL import Image
from PIL import ImageDraw
import numpy as np
import xml.etree.ElementTree
import os
import time
import copy
import six
25
from collections import deque
J
jerrywgz 已提交
26 27 28

from roidbs import JsonDataset
import data_utils
J
jerrywgz 已提交
29
from config import cfg
J
jerrywgz 已提交
30 31


J
jerrywgz 已提交
32
def coco(mode,
         batch_size=None,
         total_batch_size=None,
         padding_total=False,
         shuffle=False):
    """Create a batch reader over the COCO roidbs for the given mode.

    The function first fills in the COCO annotation/image locations on the
    global ``cfg`` (based on ``cfg.dataset`` and ``cfg.data_dir``) and sets
    ``cfg.mean_value`` from ``cfg.pixel_means``, then loads the roidbs through
    ``JsonDataset`` and returns a generator function.

    Args:
        mode: 'train', 'test' or 'infer'. Any value other than 'train' or
            'test' is iterated like 'infer' by the returned reader.
        batch_size: number of samples per yielded batch. Must be set for
            every mode except 'infer'.
        total_batch_size: number of samples that are padded together when
            ``padding_total`` is true; falls back to ``batch_size`` when
            falsy. Must be a multiple of ``batch_size``.
        padding_total: train mode only. When true, ``total_batch_size``
            samples are zero-padded to a common image size and then emitted
            as ``total_batch_size // batch_size`` consecutive batches.
            When false, each batch of ``batch_size`` is padded on its own.
        shuffle: accepted for interface compatibility; this function does
            not read it. Train mode always draws from a random permutation,
            the other modes iterate the roidbs in order.

    Returns:
        A zero-argument generator function yielding lists of sample tuples:
        ``(im, gt_boxes, gt_classes, is_crowd, im_info, im_id)`` in train
        mode and ``(im, im_info, im_id)`` otherwise.

    Raises:
        NotImplementedError: ``cfg.dataset`` names neither coco2014 nor
            coco2017.
        AssertionError: ``total_batch_size`` is not a multiple of
            ``batch_size`` (checked for every mode except 'infer').
    """
    if 'coco2014' in cfg.dataset:
        cfg.train_file_list = 'annotations/instances_train2014.json'
        cfg.train_data_dir = 'train2014'
        cfg.val_file_list = 'annotations/instances_val2014.json'
        cfg.val_data_dir = 'val2014'
    elif 'coco2017' in cfg.dataset:
        cfg.train_file_list = 'annotations/instances_train2017.json'
        cfg.train_data_dir = 'train2017'
        cfg.val_file_list = 'annotations/instances_val2017.json'
        cfg.val_data_dir = 'val2017'
    else:
        raise NotImplementedError('Dataset {} not supported'.format(
            cfg.dataset))
    # Reshape the per-channel means to (1, 1, C) and store them on cfg.
    cfg.mean_value = np.array(cfg.pixel_means)[np.newaxis,
                                               np.newaxis, :].astype('float32')
    total_batch_size = total_batch_size if total_batch_size else batch_size
    if mode != 'infer':
        assert total_batch_size % batch_size == 0
    if mode == 'train':
        cfg.train_file_list = os.path.join(cfg.data_dir, cfg.train_file_list)
        cfg.train_data_dir = os.path.join(cfg.data_dir, cfg.train_data_dir)
    elif mode == 'test' or mode == 'infer':
        cfg.val_file_list = os.path.join(cfg.data_dir, cfg.val_file_list)
        cfg.val_data_dir = os.path.join(cfg.data_dir, cfg.val_data_dir)
    json_dataset = JsonDataset(train=(mode == 'train'))
    roidbs = json_dataset.get_roidb()

    print("{} on {} with {} roidbs".format(mode, cfg.dataset, len(roidbs)))

    def roidb_reader(roidb, mode):
        """Load one roidb entry and return its sample tuple for ``mode``."""
        im, im_scales = data_utils.get_image_blob(roidb, mode)
        im_id = roidb['id']
        # im_info holds the scaled (rounded) height and width plus the scale.
        im_height = np.round(roidb['height'] * im_scales)
        im_width = np.round(roidb['width'] * im_scales)
        im_info = np.array([im_height, im_width, im_scales], dtype=np.float32)
        if mode == 'test' or mode == 'infer':
            return im, im_info, im_id
        gt_boxes = roidb['gt_boxes'].astype('float32')
        gt_classes = roidb['gt_classes'].astype('int32')
        is_crowd = roidb['is_crowd'].astype('int32')
        return im, gt_boxes, gt_classes, is_crowd, im_info, im_id

    def padding_minibatch(batch_data):
        """Zero-pad every image in the batch to the batch's largest H and W.

        The image is the first element of each sample and is unpacked as
        (channels, height, width). A single-sample batch is returned as is.
        """
        if len(batch_data) == 1:
            return batch_data

        max_shape = np.array([data[0].shape for data in batch_data]).max(axis=0)

        padding_batch = []
        for data in batch_data:
            im_c, im_h, im_w = data[0].shape[:]
            padding_im = np.zeros(
                (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
            # The original pixels go in the top-left corner; the rest stays 0.
            padding_im[:, :im_h, :im_w] = data[0]
            padding_batch.append((padding_im, ) + data[1:])
        return padding_batch

    def reader():
        if mode == "train":
            # Endless stream: walk a random permutation and draw a new one
            # each time every roidb has been visited once.
            roidb_perm = deque(np.random.permutation(roidbs))
            roidb_cur = 0
            batch_out = []
            while True:
                roidb = roidb_perm[0]
                roidb_cur += 1
                roidb_perm.rotate(-1)
                if roidb_cur >= len(roidbs):
                    roidb_perm = deque(np.random.permutation(roidbs))
                    roidb_cur = 0
                im, gt_boxes, gt_classes, is_crowd, im_info, im_id = roidb_reader(
                    roidb, mode)
                # Images without any ground-truth box are skipped.
                if gt_boxes.shape[0] == 0:
                    continue
                batch_out.append(
                    (im, gt_boxes, gt_classes, is_crowd, im_info, im_id))
                if not padding_total:
                    if len(batch_out) == batch_size:
                        yield padding_minibatch(batch_out)
                        batch_out = []
                else:
                    if len(batch_out) == total_batch_size:
                        batch_out = padding_minibatch(batch_out)
                        # Floor division keeps the count an int on Python 3;
                        # true division would hand range() a float.
                        for i in range(total_batch_size // batch_size):
                            start = i * batch_size
                            yield batch_out[start:start + batch_size]
                        batch_out = []

        elif mode == "test":
            batch_out = []
            for roidb in roidbs:
                im, im_info, im_id = roidb_reader(roidb, mode)
                batch_out.append((im, im_info, im_id))
                if len(batch_out) == batch_size:
                    yield batch_out
                    batch_out = []
            # Emit the trailing partial batch, if any.
            if len(batch_out) != 0:
                yield batch_out

        else:
            # Only roidbs whose image path contains cfg.image_name are
            # yielded, each as a single-sample batch.
            for roidb in roidbs:
                if cfg.image_name not in roidb['image']:
                    continue
                im, im_info, im_id = roidb_reader(roidb, mode)
                batch_out = [(im, im_info, im_id)]
                yield batch_out

    return reader


J
jerrywgz 已提交
150
def train(batch_size, total_batch_size=None, padding_total=False, shuffle=True):
    """Return the reader creator produced by ``coco`` in 'train' mode.

    All arguments are forwarded to ``coco`` unchanged.
    """
    return coco(
        mode='train',
        batch_size=batch_size,
        total_batch_size=total_batch_size,
        padding_total=padding_total,
        shuffle=shuffle)
J
jerrywgz 已提交
153 154


J
jerrywgz 已提交
155 156
def test(batch_size, total_batch_size=None, padding_total=False):
    """Return the reader creator produced by ``coco`` in 'test' mode.

    ``padding_total`` is accepted but not forwarded to ``coco``; shuffling
    is explicitly disabled.
    """
    return coco(
        mode='test',
        batch_size=batch_size,
        total_batch_size=total_batch_size,
        shuffle=False)
J
jerrywgz 已提交
157 158


J
jerrywgz 已提交
159 160
def infer():
    """Return the reader creator produced by ``coco`` in 'infer' mode."""
    return coco(mode='infer')