data_feed.py 2.1 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6
# -*- coding:utf-8 -*-
import os
import time
from collections import OrderedDict

import cv2
jm_12138's avatar
jm_12138 已提交
7
import numpy as np
W
wuzewu 已提交
8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26

__all__ = ['reader']


def _make_entry(im, im_path, index, org_labels, target_labels):
    """
    Bundle one input image with its identifier and label vectors.

    Args:
        im (numpy.ndarray): image data.
        im_path (str): file path, or placeholder identifier, stored for the image.
        index (int): position of the image in its input list, used to look up labels.
        org_labels (list): original labels, indexed by image position.
        target_labels (list): target labels, indexed by image position. When
            empty or None, the original label is reused as the target.

    Returns:
        collections.OrderedDict: keys 'org_im', 'org_im_path', 'org_label',
            'target_label' (in that order); both labels are float32 arrays.

    Raises:
        TypeError: if org_labels is None.
    """
    if org_labels is None:
        # Indexing None would raise an opaque "'NoneType' object is not
        # subscriptable"; keep the exception type but say what is missing.
        raise TypeError(
            "org_labels must be provided: one label entry per input image.")

    entry = OrderedDict()
    entry['org_im'] = im
    entry['org_im_path'] = im_path
    entry['org_label'] = np.array(org_labels[index]).astype('float32')
    # Fall back to the original labels when no target labels are supplied.
    label_source = target_labels if target_labels else org_labels
    entry['target_label'] = np.array(label_source[index]).astype('float32')
    return entry


def _preprocess(im):
    """
    Convert one image into the array stored under the 'img' key.

    Steps: BGR -> RGB, bilinear resize to 128x128, map pixel values with
    (x / 255 - 0.5) / 0.5 (i.e. [0, 255] -> [-1, 1]), move the channel axis
    first and prepend a new leading axis.

    Args:
        im (numpy.ndarray): image of shape [H, W, C].

    Returns:
        numpy.ndarray: float32 array of shape [1, C, 128, 128].
    """
    rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (128, 128), interpolation=cv2.INTER_LINEAR)
    normalized = (resized.astype('float32') / 255.0 - 0.5) / 0.5
    channel_first = normalized.transpose([2, 0, 1])
    return channel_first[np.newaxis, :, :, :]


def reader(images=None, paths=None, org_labels=None, target_labels=None):
    """
    Preprocess to yield image.

    All inputs are collected and validated first; the image preprocessing
    itself runs lazily, one element per iteration step.

    Args:
        images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
        paths (list[str]): paths to images.
        org_labels (list): original labels, one entry per image. Required
            whenever at least one image or path is given.
        target_labels (list): target labels, one entry per image. If empty or
            None, the original labels are used as targets.

    Note:
        Labels are looked up by position within each input list, so when both
        `paths` and `images` are given, both start again at index 0.

    Yield:
        each (collections.OrderedDict): info of original image, preprocessed
            image. Keys: 'org_im', 'org_im_path', 'org_label', 'target_label'
            and 'img'.

    Raises:
        AssertionError: if a path is not an existing file, or `images` is not
            a list.
        TypeError: if `org_labels` is None while there is input to process.
        ValueError: if a file exists but cannot be decoded as an image.
    """
    component = []
    if paths:
        for i, im_path in enumerate(paths):
            # Raised explicitly instead of via `assert` so the check still
            # runs under `python -O`; the exception type is unchanged.
            if not os.path.isfile(im_path):
                raise AssertionError(
                    "The {} isn't a valid file path.".format(im_path))
            im = cv2.imread(im_path)
            if im is None:
                # cv2.imread signals failure by returning None; fail here
                # rather than later inside cv2.cvtColor.
                raise ValueError(
                    "Failed to read {} as an image.".format(im_path))
            component.append(
                _make_entry(im, im_path, i, org_labels, target_labels))
    if images is not None:
        if not isinstance(images, list):
            raise AssertionError("images should be a list.")
        for i, im in enumerate(images):
            # In-memory images have no path; store a timestamp-based tag.
            placeholder = 'ndarray_time={}'.format(
                round(time.time(), 6) * 1e6)
            component.append(
                _make_entry(im, placeholder, i, org_labels, target_labels))

    for element in component:
        element['img'] = _preprocess(element['org_im'])
        yield element