# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# Detectron
# Copyright (c) 2017-present, Facebook, Inc.
# Licensed under the Apache License, Version 2.0;
# Written by Ross Girshick
# --------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import cv2
import numpy as np
from config import cfg


def get_image_blob(roidb, mode):
    """Load the image of one roidb entry and build a network input blob.

    Args:
        roidb: a dict-like roidb entry. Only two keys are read:
            ``roidb['image']``, the path passed to ``cv2.imread``, and
            ``roidb['flipped']``, which when truthy flips the image
            horizontally before preprocessing.
        mode: ``'train'`` picks one target size at random from
            ``cfg.TRAIN.scales`` and caps the longer side at
            ``cfg.TRAIN.max_size``. Any other value uses
            ``cfg.TEST.scales[0]`` and ``cfg.TEST.max_size``.

    Returns:
        The ``(im, im_scale)`` tuple produced by ``prep_im_for_blob``.

    Raises:
        IOError: if ``cv2.imread`` cannot read the image file.
    """
    if mode == 'train':
        scales = cfg.TRAIN.scales
        # Sample one target size at random on every call.
        scale_ind = np.random.randint(0, high=len(scales))
        target_size = scales[scale_ind]
        max_size = cfg.TRAIN.max_size
    else:
        target_size = cfg.TEST.scales[0]
        max_size = cfg.TEST.max_size

    im = cv2.imread(roidb['image'])
    # cv2.imread reports failure by returning None instead of raising.
    # Check explicitly rather than with ``assert``, because assertions are
    # stripped under ``python -O`` and a None image would then crash later
    # with an unrelated-looking error.
    if im is None:
        raise IOError('Failed to read image \'{}\''.format(roidb['image']))
    if roidb['flipped']:
        # Reverse axis 1 (the column axis of the HxWxC array from
        # cv2.imread), i.e. a horizontal flip.
        im = im[:, ::-1, :]

    return prep_im_for_blob(im, cfg.pixel_means, target_size, max_size)


def prep_im_for_blob(im, pixel_means, target_size, max_size):
    """Prepare a single image for use as a network input blob.

    Specifically:
      - Convert to float32. This is always done on a copy, so the caller's
        array is never modified.
      - Subtract ``pixel_means``.
      - Rescale both axes by one factor, so the shorter side becomes
        ``target_size``. If that would make the (rounded) longer side
        exceed ``max_size``, scale the longer side to ``max_size`` instead.
      - Reorder axes with ``transpose((2, 0, 1))``, i.e. from
        (height, width, channel) to (channel, height, width).

    Args:
        im: image array, assumed (height, width, channel) as returned by
            ``cv2.imread`` (TODO confirm for other callers).
        pixel_means: values subtracted from every pixel. They must
            broadcast against ``im``; presumably one mean per channel
            (verify against cfg).
        target_size: desired length, in pixels, of the shorter image side.
        max_size: upper bound, in pixels, on the longer image side.

    Returns:
        A tuple ``(im, im_scale)``:
            - the transformed float32 image in (channel, height, width)
              order;
            - the scale factor applied to both spatial axes.
    """
    # Copy explicitly. With copy=False, a float32 input would be aliased,
    # and the in-place subtraction below would modify the caller's array.
    im = im.astype(np.float32)
    # In-place subtraction keeps the buffer float32 even if pixel_means is
    # float64; an out-of-place ``im - pixel_means`` could upcast.
    im -= pixel_means

    im_size_min = np.min(im.shape[0:2])
    im_size_max = np.max(im.shape[0:2])
    im_scale = float(target_size) / float(im_size_min)
    # Prevent the biggest axis from being more than max_size
    if np.round(im_scale * im_size_max) > max_size:
        im_scale = float(max_size) / float(im_size_max)
    im = cv2.resize(
        im,
        None,
        None,
        fx=im_scale,
        fy=im_scale,
        interpolation=cv2.INTER_LINEAR)
    # (height, width, channel) -> (channel, height, width). No batch axis
    # is added here.
    im = im.transpose((2, 0, 1))
    return im, im_scale