coco_reader.py 10.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Data reader for COCO dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import functools
import numpy as np
import cv2
import random

from utils.transforms import fliplr_joints
from utils.transforms import get_affine_transform
from utils.transforms import affine_transform
from lib.base_reader import visualize, generate_target
from pycocotools.coco import COCO

# NOTE
# -- COCO Dataset --
# "keypoints": 
# {
#   0: "nose",
#   1: "left_eye",
#   2: "right_eye",
#   3: "left_ear",
#   4: "right_ear",
#   5: "left_shoulder",
#   6: "right_shoulder",
#   7: "left_elbow",
#   8: "right_elbow",
#   9: "left_wrist",
#   10: "right_wrist",
#   11: "left_hip",
#   12: "right_hip",
#   13: "left_knee",
#   14: "right_knee",
#   15: "left_ankle",
#   16: "right_ankle"
# },
#
# "skeleton" (1-based joint indices, unlike the 0-based "keypoints" list above):
# [
#   [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8],
#   [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]
# ]

u010070587's avatar
u010070587 已提交
62

63 64 65 66 67 68 69 70
class Config:
    """Static configuration for the COCO keypoint reader.

    Every setting is a class attribute; the module-level ``cfg`` instance
    created right after this class is what the reader functions consult.
    """
    # -- Debugging --
    DEBUG = False
    # NOTE(review): not referenced in this file; presumably consumed by the
    # debug visualizer in lib.base_reader -- confirm.
    TMPDIR = 'tmp_fold_for_debug'

    # -- Reader --
    BUF_SIZE = 102400
    THREAD = 8 if not DEBUG else 1  # must stay > 0; one thread when debugging

    # -- Fixed dataset facts --
    DATAROOT = 'data/coco'
    IMAGEDIR = 'images'
    NUM_JOINTS = 17
    # Left/right keypoint index pairs (0-based) swapped on a horizontal flip.
    FLIP_PAIRS = [
        [1, 2], [3, 4], [5, 6], [7, 8],
        [9, 10], [11, 12], [13, 14], [15, 16],
    ]
    PARENT_IDS = None

    # -- Augmentation and target generation --
    SCALE_FACTOR = 0.3
    ROT_FACTOR = 40
    FLIP = True
    TARGET_TYPE = 'gaussian'
    SIGMA = 3
    # Network input size, passed to cv2.warpAffine as (width, height).
    IMAGE_SIZE = [288, 384]
    HEATMAP_SIZE = [72, 96]  # presumably (width, height) like IMAGE_SIZE
    ASPECT_RATIO = IMAGE_SIZE[0] * 1.0 / IMAGE_SIZE[1]
    # Per-channel normalization applied after scaling pixels to [0, 1].
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]
    # Box sizes are stored as "scale" in units of PIXEL_STD pixels.
    PIXEL_STD = 200


cfg = Config()

u010070587's avatar
u010070587 已提交
97

98 99 100 101
def _box2cs(box):
    """Convert an ``[x, y, w, h, ...]`` box into ``(center, scale)``.

    Only the first four entries of *box* are read; the actual conversion is
    done by ``_xywh2cs``.
    """
    left, top, width, height = box[:4]
    return _xywh2cs(left, top, width, height)

u010070587's avatar
u010070587 已提交
102

103 104 105 106 107 108 109 110 111 112
def _xywh2cs(x, y, w, h):
    """Turn a top-left/width/height box into a ``(center, scale)`` pair.

    The box is first grown (never shrunk) along one axis so its aspect ratio
    matches ``cfg.ASPECT_RATIO``. Its size is then expressed in units of
    ``cfg.PIXEL_STD`` pixels and enlarged by 1.25.

    Returns:
        center: float32 array ``[cx, cy]``.
        scale: float32 array ``[w / PIXEL_STD, h / PIXEL_STD]`` (after the
            aspect fix and padding).
    """
    ratio = cfg.ASPECT_RATIO
    pixel_std = cfg.PIXEL_STD
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)

    if w > ratio * h:
        # Too wide for the target ratio: grow the height.
        h = w * 1.0 / ratio
    elif w < ratio * h:
        # Too tall for the target ratio: grow the width.
        w = h * ratio

    scale = np.array([w * 1.0 / pixel_std, h * 1.0 / pixel_std],
                     dtype=np.float32)
    # The 1.25 padding is skipped only when center x is exactly -1.
    # NOTE(review): looks like a sentinel inherited from a reference
    # implementation -- confirm whether it is still meaningful.
    if center[0] != -1:
        scale = scale * 1.25

    return center, scale

u010070587's avatar
u010070587 已提交
119

120 121 122 123 124 125
def _select_data(db):
    """Keep only records whose visible joints are centered on their box.

    For each record:
    - Compute the mean (x, y) of the joints whose first visibility entry is
      positive, and its distance ``d`` to ``rec['center']``.
    - Score the record with ``exp(-d^2 / (2 * 0.2^2 * area))``, where
      ``area`` is the box area implied by ``rec['scale']`` (scale times
      ``cfg.PIXEL_STD``, squared).
    - Keep the record when the score exceeds
      ``0.45 + (0.2 / 16) * (num_visible - 1)``.

    Records with no visible joints are dropped.

    Args:
        db: list of record dicts with 'joints_3d', 'joints_3d_vis', 'center'
            and 'scale' entries.

    Returns:
        list of the kept records, in their original order.
    """
    kept = []
    for record in db:
        visible = 0
        sum_x = 0.0
        sum_y = 0.0
        for point, vis in zip(record['joints_3d'], record['joints_3d_vis']):
            if vis[0] <= 0:
                continue
            visible += 1
            sum_x += point[0]
            sum_y += point[1]

        if visible == 0:
            continue

        centroid = np.array([sum_x / visible, sum_y / visible])
        box_center = np.array(record['center'])
        box_area = record['scale'][0] * record['scale'][1] * (cfg.PIXEL_STD**2)

        dist = np.linalg.norm((centroid - box_center), 2)
        ks = np.exp(-1.0 * (dist**2) / ((0.2)**2 * 2.0 * box_area))

        threshold = (0.2 / 16) * visible + 0.45 - 0.2 / 16
        if ks > threshold:
            kept.append(record)

    print('=> num db: {}'.format(len(db)))
    print('=> num selected db: {}'.format(len(kept)))
    return kept

u010070587's avatar
u010070587 已提交
152 153 154

def _load_coco_keypoint_annotation(image_set_index, coco,
                                   _coco_ind_to_class_ind, image_set):
    """Build ground-truth records (box + keypoints) for each annotated person.

    Args:
        image_set_index: iterable of COCO image ids.
        coco: a ``pycocotools.coco.COCO`` object for the annotation file.
        _coco_ind_to_class_ind: maps a COCO category id to a contiguous class
            index, where 0 is background and 1 is the first category of the
            annotation file.
        image_set: split name (e.g. 'train', 'val'); images are looked up in
            ``<DATAROOT>/<IMAGEDIR>/<image_set>2017``.

    Returns:
        list of dicts with keys 'image', 'center', 'scale', 'joints_3d',
        'joints_3d_vis', 'filename' and 'imgnum'. ``joints_3d`` and
        ``joints_3d_vis`` are float64 arrays of shape ``(NUM_JOINTS, 3)``.
    """
    print('generating coco gt_db...')
    num_joints = cfg.NUM_JOINTS
    gt_db = []
    for index in image_set_index:
        im_ann = coco.loadImgs(index)[0]
        width = im_ann['width']
        height = im_ann['height']

        ann_ids = coco.getAnnIds(imgIds=index, iscrowd=False)
        objs = coco.loadAnns(ann_ids)

        # Sanitize bboxes: clip to the image and drop degenerate ones.
        valid_objs = []
        for obj in objs:
            x, y, w, h = obj['bbox']
            x1 = np.max((0, x))
            y1 = np.max((0, y))
            x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
            y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
            if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
                obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
                valid_objs.append(obj)

        filename = '%012d.jpg' % index
        image_path = os.path.join(cfg.DATAROOT, cfg.IMAGEDIR,
                                  image_set + '2017', filename)

        for obj in valid_objs:
            # Keep only class index 1 (the first category; presumably
            # "person" for the person_keypoints files).
            if _coco_ind_to_class_ind[obj['category_id']] != 1:
                continue

            # Ignore objs without keypoints annotation
            if max(obj['keypoints']) == 0:
                continue

            # Keypoints come as a flat [x, y, v, x, y, v, ...] list.
            # np.float64 replaces the removed np.float alias (NumPy >= 1.24).
            kpts = np.asarray(obj['keypoints'], dtype=np.float64)
            kpts = kpts[:num_joints * 3].reshape(num_joints, 3)

            joints_3d = np.zeros((num_joints, 3), dtype=np.float64)
            joints_3d[:, :2] = kpts[:, :2]

            # Clamp visibility flags above 1 down to 1; the z column stays 0.
            vis = np.minimum(kpts[:, 2], 1)
            joints_3d_vis = np.zeros((num_joints, 3), dtype=np.float64)
            joints_3d_vis[:, 0] = vis
            joints_3d_vis[:, 1] = vis

            center, scale = _box2cs(obj['clean_bbox'][:4])
            gt_db.append({
                'image': image_path,
                'center': center,
                'scale': scale,
                'joints_3d': joints_3d,
                'joints_3d_vis': joints_3d_vis,
                'filename': filename,
                'imgnum': 0,
            })

    return gt_db

u010070587's avatar
u010070587 已提交
218

219 220 221 222 223 224 225
def data_augmentation(sample, is_train):
    """Load a sample's image, optionally augment it, and build its targets.

    Args:
        sample: record dict from the reader. Reads 'image', 'joints_3d',
            'joints_3d_vis', 'center', 'scale', and optionally 'filename'
            and 'score'. The dict's arrays are not modified.
        is_train: when True, apply random scaling, rotation and (if
            ``cfg.FLIP``) horizontal flipping.

    Returns:
        When training: ``(image, target, target_weight)``.
        Otherwise: ``(image, target, target_weight, center, scale, score,
        image_file)``.
        ``image`` is a float32 CHW array normalized with ``cfg.MEAN`` and
        ``cfg.STD``.

    Raises:
        ValueError: if the image file cannot be read.
    """
    image_file = sample['image']
    filename = sample.get('filename', '')
    # Work on copies of the arrays modified below (flip and affine loop), so
    # the caller's record is never mutated.
    joints = np.array(sample['joints_3d'], copy=True)
    joints_vis = sample['joints_3d_vis']
    c = np.array(sample['center'], copy=True)
    s = sample['scale']
    score = sample.get('score', 1)
    r = 0

    # used for ce: reseed on every call so the augmentation is deterministic
    if 'ce_mode' in os.environ:
        random.seed(0)
        np.random.seed(0)

    data_numpy = cv2.imread(image_file, cv2.IMREAD_COLOR |
                            cv2.IMREAD_IGNORE_ORIENTATION)
    if data_numpy is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError('Fail to read {}'.format(image_file))
    # NOTE(review): cv2.imread yields BGR, while cfg.MEAN/STD look like
    # RGB-ordered ImageNet statistics -- confirm the channel order the model
    # was trained with.

    if is_train:
        sf = cfg.SCALE_FACTOR
        rf = cfg.ROT_FACTOR
        s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
        # Rotate with probability ~0.6; angle ~ N(0, rf) clipped to +/-2*rf.
        r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \
            if random.random() <= 0.6 else 0

        if cfg.FLIP and random.random() <= 0.5:
            data_numpy = data_numpy[:, ::-1, :]
            joints, joints_vis = fliplr_joints(
                joints, joints_vis, data_numpy.shape[1], cfg.FLIP_PAIRS)
            c[0] = data_numpy.shape[1] - c[0] - 1

    trans = get_affine_transform(c, s, r, cfg.IMAGE_SIZE)
    image = cv2.warpAffine(
        data_numpy,
        trans, (int(cfg.IMAGE_SIZE[0]), int(cfg.IMAGE_SIZE[1])),
        flags=cv2.INTER_LINEAR)

    # Move visible joints into the warped input's coordinate frame.
    for i in range(cfg.NUM_JOINTS):
        if joints_vis[i, 0] > 0.0:
            joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)

    # Numpy target
    target, target_weight = generate_target(cfg, joints, joints_vis)

    if cfg.DEBUG:
        visualize(cfg, filename, data_numpy, image.copy(), joints, target)

    # Normalization: HWC -> CHW float32 in [0, 1], then per-channel mean/std.
    image = image.astype('float32').transpose((2, 0, 1)) / 255
    image -= np.array(cfg.MEAN).reshape((3, 1, 1))
    image /= np.array(cfg.STD).reshape((3, 1, 1))

    if is_train:
        return image, target, target_weight
    return image, target, target_weight, c, s, score, image_file
276 277


u010070587's avatar
u010070587 已提交
278 279 280 281 282 283
# Create a reader
def _reader_creator(root,
                    image_set,
                    shuffle=False,
                    is_train=False,
                    use_gt_bbox=False):
    """Create a ``(reader, mapper)`` pair for one COCO keypoint split.

    Args:
        root: dataset root; annotation JSON files are read from
            ``<root>/annotations``.
        image_set: one of 'train', 'val', 'test', 'test-dev'.
        shuffle: shuffle the records each time ``reader()`` is iterated.
        is_train: forwarded to ``data_augmentation`` (enables augmentation).
        use_gt_bbox: build records from ground-truth boxes. Only this path is
            implemented, so ``is_train`` or ``use_gt_bbox`` must be True.

    Returns:
        reader: a generator function yielding record dicts.
        mapper: ``data_augmentation`` with ``is_train`` bound.

    Raises:
        The following are raised when ``reader()`` is iterated:
        ValueError: if ``image_set`` is not supported.
        NotImplementedError: if neither ``is_train`` nor ``use_gt_bbox``
            is set.
    """
    def reader():
        if image_set in ['train', 'val']:
            file_name = os.path.join(
                root, 'annotations',
                'person_keypoints_' + image_set + '2017.json')
        elif image_set in ['test', 'test-dev']:
            file_name = os.path.join(root, 'annotations',
                                     'image_info_' + image_set + '2017.json')
        else:
            raise ValueError("The dataset '{}' is not supported".format(
                image_set))

        # Records from detection boxes are not implemented. Fail clearly here,
        # instead of with an unbound ``gt_db`` after loading the annotations.
        if not (is_train or use_gt_bbox):
            raise NotImplementedError(
                'Only ground-truth boxes are supported; '
                'set is_train=True or use_gt_bbox=True')

        # Load annotations
        coco = COCO(file_name)

        # Map COCO category ids to contiguous class indices (0 = background).
        cats = [cat['name'] for cat in coco.loadCats(coco.getCatIds())]
        classes = ['__background__'] + cats
        print('=> classes: {}'.format(classes))
        class_to_ind = {cls: ind for ind, cls in enumerate(classes)}
        class_to_coco_ind = dict(zip(cats, coco.getCatIds()))
        coco_ind_to_class_ind = {
            class_to_coco_ind[cls]: class_to_ind[cls] for cls in cats}

        # Load image file names
        image_set_index = coco.getImgIds()
        print('=> num_images: {}'.format(len(image_set_index)))

        gt_db = _load_coco_keypoint_annotation(
            image_set_index, coco, coco_ind_to_class_ind, image_set)
        gt_db = _select_data(gt_db)

        if shuffle:
            random.shuffle(gt_db)

        for db in gt_db:
            yield db

    mapper = functools.partial(data_augmentation, is_train=is_train)
    return reader, mapper

u010070587's avatar
u010070587 已提交
329

330
def train():
    """Return a generator function over augmented COCO ``train`` samples.

    Records are shuffled, except when the ``ce_mode`` environment variable
    is set. In that case their order is kept fixed, presumably for
    reproducible continuous-evaluation runs.
    """
    # Decide shuffling once instead of building the reader twice.
    shuffle = 'ce_mode' not in os.environ
    reader, mapper = _reader_creator(
        cfg.DATAROOT, 'train', shuffle=shuffle, is_train=True)

    def pop():
        for sample in reader():
            yield mapper(sample)

    return pop

u010070587's avatar
u010070587 已提交
345

346
def valid():
    """Return a generator function over COCO ``val`` samples.

    Records come from ground-truth boxes, in file order, with no
    augmentation. Each item yielded is the evaluation tuple produced by
    ``data_augmentation`` with ``is_train=False``.
    """
    options = dict(shuffle=False, is_train=False, use_gt_bbox=True)
    reader, mapper = _reader_creator(cfg.DATAROOT, 'val', **options)

    def pop():
        for record in reader():
            yield mapper(record)

    return pop

u010070587's avatar
u010070587 已提交
356

357
def test():
    """Return a generator function over COCO ``test`` samples.

    Uses the same settings as validation: ground-truth boxes, no shuffling,
    no augmentation. NOTE(review): the 'test' split reads
    ``image_info_test2017.json``, which may carry no keypoint annotations;
    if so, this yields nothing -- confirm intended use.
    """
    options = dict(shuffle=False, is_train=False, use_gt_bbox=True)
    reader, mapper = _reader_creator(cfg.DATAROOT, 'test', **options)

    def pop():
        for record in reader():
            yield mapper(record)

    return pop