dataset_traversal.py 6.4 KB
Newer Older
L
LDOUBLEV 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
T
tink2123 已提交
16
import sys
L
LDOUBLEV 已提交
17 18 19 20 21 22 23 24 25
import math
import random
import functools
import numpy as np
import cv2
import string
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.utility import create_module
L
LDOUBLEV 已提交
26
from ppocr.utils.utility import get_image_file_list
L
LDOUBLEV 已提交
27 28 29 30 31 32 33
import time


class TrainReader(object):
    """Training data reader that yields fixed-size batches for one worker.

    ``params`` keys read here:
        num_workers: number of reader workers sharing the label data.
        label_file_path: a single label-file path, or a list of paths when
            samples should be mixed from several datasets.
        data_ratio_list: per-file sampling ratios, one per label file; only
            read when ``label_file_path`` is a list.
        train_batch_size_per_card: number of samples in each yielded batch.
        process_function: spec passed to ``create_module``; the returned
            factory is called with ``params`` and the resulting object is
            applied to every raw (bytes) label line. A ``None`` result drops
            that line.
    """

    def __init__(self, params):
        self.num_workers = params['num_workers']
        self.label_file_path = params['label_file_path']
        print(self.label_file_path)
        # A list of label files switches to mixed multi-dataset sampling.
        self.use_mul_data = isinstance(self.label_file_path, list)
        if self.use_mul_data:
            self.data_ratio_list = params['data_ratio_list']
        self.batch_size = params['train_batch_size_per_card']
        assert 'process_function' in params,\
            "absence process_function in Reader"
        self.process = create_module(params['process_function'])(params)

    def __call__(self, process_id):
        """Return a generator function that yields batches for one worker.

        Args:
            process_id: index of this worker, expected in ``[0, num_workers)``.

        Returns:
            A zero-argument generator function. Each yielded item is a list of
            exactly ``batch_size`` processed samples; a trailing partial batch
            (fewer than ``batch_size`` samples) is dropped.

        Raises:
            ValueError: (multi-dataset mode) when the number of ratios does
                not match the number of label files, or a label file is empty.
        """

        def fallback_to_single_worker_on_windows():
            # Multiprocess reading is not fully supported on Windows, so the
            # worker count is forced down to 1 there.
            if sys.platform == "win32" and self.num_workers != 1:
                print("multiprocess is not fully compatible with Windows. "
                      "num_workers will be 1.")
                self.num_workers = 1

        def sample_iter_reader():
            """Yield processed samples from this worker's shard of one file."""
            with open(self.label_file_path, "rb") as fin:
                label_infor_list = fin.readlines()
            img_num = len(label_infor_list)
            img_id_list = list(range(img_num))
            # NOTE(review): each worker shuffles on its own; the strided
            # process_id/num_workers slice below only gives disjoint shards if
            # all workers share the same random state — confirm.
            random.shuffle(img_id_list)
            fallback_to_single_worker_on_windows()
            for img_id in range(process_id, img_num, self.num_workers):
                outs = self.process(label_infor_list[img_id_list[img_id]])
                if outs is None:
                    continue
                yield outs

        def sample_iter_reader_mul():
            """Yield processed samples drawn from several files by ratio.

            Each call draws ``int(max(1.0, 1000 * ratio))`` shuffled lines from
            every label file, then processes them in file order.
            """
            pool_size = 1000
            data_source_list = self.label_file_path
            if len(self.data_ratio_list) != len(data_source_list):
                raise ValueError(
                    "data_ratio_list has {} entries but label_file_path has "
                    "{}".format(len(self.data_ratio_list),
                                len(data_source_list)))
            batch_size_list = [int(max(1.0, pool_size * ratio))
                               for ratio in self.data_ratio_list]
            print(self.data_ratio_list, batch_size_list)

            data_filename_list, data_size_list, fetch_record_list = [], [], []
            for data_source in data_source_list:
                with open(data_source, "rb") as fin:
                    image_files = fin.readlines()
                if not image_files:
                    # Without this guard the modulo below divides by zero.
                    raise ValueError(
                        "label file {} is empty".format(data_source))
                random.shuffle(image_files)
                data_filename_list.append(image_files)
                data_size_list.append(len(image_files))
                fetch_record_list.append(0)

            # NOTE(review): process_id is not used here, so workers do not
            # shard this pool; confirm that duplicated samples across workers
            # are acceptable.
            image_batch = []
            for i, (bs, ds) in enumerate(zip(batch_size_list, data_size_list)):
                image_names = data_filename_list[i]
                fetch_record = fetch_record_list[i]
                # Take bs consecutive lines, wrapping past the end of the file.
                image_batch.extend(
                    image_names[j % ds]
                    for j in range(fetch_record, fetch_record + bs))
                # The cursor/reshuffle bookkeeping only lives for this call,
                # since all state is rebuilt each time the generator starts.
                if fetch_record + bs > ds:
                    fetch_record_list[i] = 0
                    random.shuffle(image_names)
                else:
                    fetch_record_list[i] = fetch_record + bs

            fallback_to_single_worker_on_windows()

            for label_infor in image_batch:
                outs = self.process(label_infor)
                if outs is None:
                    continue
                yield outs

        def batch_iter_reader():
            if self.use_mul_data:
                print("Sample data from multiple datasets!")
                sample_reader = sample_iter_reader_mul
            else:
                sample_reader = sample_iter_reader
            batch_outs = []
            for outs in sample_reader():
                batch_outs.append(outs)
                if len(batch_outs) == self.batch_size:
                    yield batch_outs
                    batch_outs = []
            # Any leftover partial batch is dropped.

        return batch_iter_reader


class EvalTestReader(object):
    """Reader for evaluation (label-file driven) and inference ("test") runs.

    ``params`` keys read here:
        process_function: spec passed to ``create_module``; the returned
            factory is called with ``params`` and the resulting object is
            applied to each decoded image.
        test_batch_size_per_card: maximum number of samples per batch.
        img_set_dir, label_file_path: used when ``mode != "test"``.
        infer_img: used when ``mode == "test"``.
    """

    def __init__(self, params):
        self.params = params
        assert 'process_function' in params,\
            "absence process_function in EvalTestReader"

    def __call__(self, mode):
        """Return a generator function that yields batches of processed images.

        Args:
            mode: ``"test"`` reads images from ``params['infer_img']`` via
                ``get_image_file_list``. Any other value reads
                ``params['label_file_path']``, using the first tab-separated
                field of each line as an image name relative to
                ``params['img_set_dir']``.

        Returns:
            A zero-argument generator function yielding lists of at most
            ``test_batch_size_per_card`` items; the final partial batch is
            also yielded. Each item is ``process_function(img)`` with the
            image path appended (presumably a list — it must support
            ``append``).
        """
        process_function = create_module(self.params['process_function'])(
            self.params)
        batch_size = self.params['test_batch_size_per_card']

        img_list = []
        if mode != "test":
            img_set_dir = self.params['img_set_dir']
            img_name_list_path = self.params['label_file_path']
            with open(img_name_list_path, "rb") as fin:
                for line in fin:
                    # Strip CR too, so CRLF label files without a tab still
                    # produce a clean image name.
                    img_name = line.decode().rstrip("\r\n").split("\t")[0]
                    if not img_name:
                        # A blank line would otherwise resolve to img_set_dir.
                        continue
                    img_list.append(os.path.join(img_set_dir, img_name))
        else:
            img_list = get_image_file_list(self.params['infer_img'])

        def batch_iter_reader():
            batch_outs = []
            for img_path in img_list:
                img = cv2.imread(img_path)
                if img is None:
                    # imread returns None for missing, unreadable or
                    # undecodable files alike.
                    logger.info(
                        "{} does not exist or cannot be read!".format(img_path))
                    continue
                if img.ndim == 2 or img.shape[2] == 1:
                    # Promote single-channel images to 3-channel BGR.
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                outs = process_function(img)
                outs.append(img_path)
                batch_outs.append(outs)
                if len(batch_outs) == batch_size:
                    yield batch_outs
                    batch_outs = []
            if batch_outs:
                yield batch_outs

        return batch_iter_reader