#coding:utf-8
# cv_reader.py (4.2 KB). Blame: line 1 committed by Steffy-zxf; lines 2-19 committed by wuzewu.
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Blame: line 20 committed by wuzewu.
import paddle
# Blame: lines 21-23 committed by wuzewu.
import numpy as np
from PIL import Image

# Blame: line 24 committed by wuzewu.
import paddlehub.io.augmentation as image_augmentation
# Blame: line 25 committed by wuzewu.

# Supported output channel orders. Images are first converted to RGB and laid
# out channel-first (CHW); each list gives, for output channel i, the index of
# the RGB input channel it is taken from (0 = R, 1 = G, 2 = B).
channel_order_dict = {
    "RGB": [0, 1, 2],
    "RBG": [0, 2, 1],
    "GBR": [1, 2, 0],
    "GRB": [1, 0, 2],
    "BGR": [2, 1, 0],
    "BRG": [2, 0, 1],
}


# Blame: line 36 committed by wuzewu.
class ImageClassificationReader(object):
    """Turn an image-classification dataset into batched, preprocessed samples.

    Every image goes through the same pipeline:
    1. Resize to ``(image_width, image_height)``.
    2. Optionally apply random augmentation (training phase only).
    3. Convert to RGB and lay out as channel-first (CHW) float32.
    4. Scale to [0, 1].
    5. Standardize per channel with ``images_mean`` / ``images_std``.
    6. Reorder channels according to ``channel_order``.
    """

    def __init__(self,
                 image_width,
                 image_height,
                 dataset,
                 channel_order="RGB",
                 images_mean=None,
                 images_std=None,
                 data_augmentation=False):
        """Store and validate the preprocessing configuration.

        Args:
            image_width: Target width handed to the resize helper; must be > 0.
            image_height: Target height handed to the resize helper; must be
                > 0.
            dataset: Object exposing ``train_data(shuffle)``,
                ``test_data(shuffle)`` and ``validate_data(shuffle)``. It may
                also define ``images_mean`` / ``images_std`` attributes.
            channel_order: Key of ``channel_order_dict`` selecting the output
                channel order.
            images_mean: Three per-channel means in RGB order, subtracted
                after scaling to [0, 1]. Defaults to ``dataset.images_mean``
                when the dataset defines it (and it is not None), otherwise
                ``[0, 0, 0]``.
            images_std: Three per-channel divisors in RGB order. Defaults to
                ``dataset.images_std`` when the dataset defines it (and it is
                not None), otherwise ``[1, 1, 1]``.
            data_augmentation: Whether to randomly augment training images.

        Raises:
            ValueError: If a mean/std does not hold exactly three values, if
                ``channel_order`` is unknown, or if either image dimension is
                not positive.
        """
        self.image_width = image_width
        self.image_height = image_height
        self.channel_order = channel_order
        self.dataset = dataset
        self.data_augmentation = data_augmentation
        # Stored as (3, 1, 1) arrays so they broadcast over a CHW image.
        self.images_mean = self._resolve_stats(images_mean, dataset,
                                               "images_mean", [0, 0, 0])
        self.images_std = self._resolve_stats(images_std, dataset,
                                              "images_std", [1, 1, 1])

        if self.channel_order not in channel_order_dict:
            raise ValueError(
                "The channel_order should be one of %s, but got %r." %
                (sorted(channel_order_dict), self.channel_order))

        if self.image_width <= 0 or self.image_height <= 0:
            raise ValueError("Image width and height should be positive.")

    @staticmethod
    def _resolve_stats(value, dataset, attr_name, default):
        """Return per-channel statistics reshaped to (3, 1, 1).

        Uses ``value`` when given. Otherwise it uses ``dataset.<attr_name>``
        when that attribute exists and is not None, and ``default`` as a last
        resort.
        """
        if value is None:
            # getattr only suppresses AttributeError, so genuine errors raised
            # by the dataset are no longer silently swallowed.
            value = getattr(dataset, attr_name, None)
        if value is None:
            value = default
        return np.array(value).reshape(3, 1, 1)

    def _select_data(self, phase, shuffle, data):
        """Return the iterable of raw samples for ``phase``.

        "train" honours ``shuffle``. "test" and "val"/"dev" are never
        shuffled. "predict" returns the caller-supplied ``data``.

        Raises:
            ValueError: If ``phase`` is unknown, or ``data`` is None for
                "predict".
        """
        if phase == "train":
            return self.dataset.train_data(shuffle)
        if phase == "test":
            return self.dataset.test_data(False)
        if phase in ("val", "dev"):
            return self.dataset.validate_data(False)
        if phase == "predict":
            if data is None:
                raise ValueError(
                    "data should be provided when phase is 'predict'.")
            return data
        raise ValueError(
            "Unknown phase %r; expected 'train', 'test', 'val', 'dev' or "
            "'predict'." % (phase, ))

    def data_generator(self,
                       batch_size,
                       phase="train",
                       shuffle=False,
                       data=None):
        """Build a batched reader over the images of ``phase``.

        Args:
            batch_size: Number of samples per batch, passed to
                ``paddle.batch``.
            phase: "train", "test", "val"/"dev" or "predict".
            shuffle: Shuffle the training data. Ignored for every other phase.
            data: For "predict", an iterable of image paths. Ignored
                otherwise.

        Returns:
            The reader returned by ``paddle.batch``. Its underlying samples are
            ``(image, )`` for "predict" and ``(image, label)`` otherwise, where
            ``image`` is a float32 channel-first array.

        Raises:
            ValueError: If ``phase`` is unknown, or ``data`` is None for
                "predict".
        """
        data = self._select_data(phase, shuffle, data)

        def preprocess(image_path):
            # Read the file inside a ``with`` block so its handle is closed
            # once the pixels have been copied into a NumPy array.
            with Image.open(image_path) as source:
                image = image_augmentation.image_resize(
                    source, self.image_width, self.image_height)
                # Random augmentation belongs to training only; applying it
                # elsewhere makes evaluation and prediction non-deterministic.
                if self.data_augmentation and phase == "train":
                    image = image_augmentation.image_random_process(
                        image, enable_resize=False, enable_crop=False)
                # only support RGB
                image = np.array(image.convert('RGB')).astype('float32')

            # HWC to CHW; the RGB conversion guarantees three channels.
            image = np.transpose(image, (2, 0, 1))

            # Standardization, done in place so the array stays float32.
            image /= 255
            image -= self.images_mean
            image /= self.images_std
            return image[channel_order_dict[self.channel_order], :, :]

        def _data_reader():
            if phase == "predict":
                for image_path in data:
                    yield (preprocess(image_path), )
            else:
                for image_path, label in data:
                    yield (preprocess(image_path), label)

        return paddle.batch(_data_reader, batch_size=batch_size)