From dd0112f52ba5b46aeb329a168db19f64c6138547 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Thu, 21 May 2020 11:11:36 +0800 Subject: [PATCH] fix infer_rec for benchmark --- configs/rec/rec_benchmark_reader.yml | 3 +- ppocr/data/rec/dataset_traversal.py | 74 ++++++++++++++++------------ 2 files changed, 45 insertions(+), 32 deletions(-) diff --git a/configs/rec/rec_benchmark_reader.yml b/configs/rec/rec_benchmark_reader.yml index d119c7aa..43cd514c 100755 --- a/configs/rec/rec_benchmark_reader.yml +++ b/configs/rec/rec_benchmark_reader.yml @@ -9,4 +9,5 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,LMDBReader - lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ \ No newline at end of file + lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ + infer_img: diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index c45273e0..357a89fb 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -42,9 +42,11 @@ class LMDBReader(object): self.mode = params['mode'] if params['mode'] == 'train': self.batch_size = params['train_batch_size_per_card'] - else: + elif params['mode'] == "eval": self.batch_size = params['test_batch_size_per_card'] - + elif params['mode'] == "test": + self.batch_size = 1 + self.infer_img = params["infer_img"] def load_hierarchical_lmdb_dataset(self): lmdb_sets = {} dataset_idx = 0 @@ -97,34 +99,42 @@ class LMDBReader(object): process_id = 0 def sample_iter_reader(): - lmdb_sets = self.load_hierarchical_lmdb_dataset() - if process_id == 0: - self.print_lmdb_sets_info(lmdb_sets) - cur_index_sets = [1 + process_id] * len(lmdb_sets) - while True: - finish_read_num = 0 - for dataset_idx in range(len(lmdb_sets)): - cur_index = cur_index_sets[dataset_idx] - if cur_index > lmdb_sets[dataset_idx]['num_samples']: - finish_read_num += 1 - else: - sample_info = self.get_lmdb_sample_info( - lmdb_sets[dataset_idx]['txn'], cur_index) - cur_index_sets[dataset_idx] += self.num_workers - if sample_info is None: - continue - img, label = sample_info - outs = process_image(img, self.image_shape, label, - self.char_ops, self.loss_type, - self.max_text_length) - if outs is None: - continue - yield outs - - if finish_read_num == len(lmdb_sets): - break - self.close_lmdb_dataset(lmdb_sets) - + if self.mode == 'test': + image_file_list = get_image_file_list(self.infer_img) + for single_img in image_file_list: + img = cv2.imread(single_img) + if img.shape[-1]==1 or len(list(img.shape))==2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + norm_img = process_image(img, self.image_shape) + yield norm_img + else: + lmdb_sets = self.load_hierarchical_lmdb_dataset() + if process_id == 0: + self.print_lmdb_sets_info(lmdb_sets) + cur_index_sets = [1 + process_id] * len(lmdb_sets) + while True: + finish_read_num = 0 + for dataset_idx in range(len(lmdb_sets)): + cur_index = cur_index_sets[dataset_idx] + if cur_index > lmdb_sets[dataset_idx]['num_samples']: + finish_read_num += 1 + else: + sample_info = self.get_lmdb_sample_info( + lmdb_sets[dataset_idx]['txn'], cur_index) + cur_index_sets[dataset_idx] += self.num_workers + if sample_info is None: + continue + img, label = sample_info + outs = process_image(img, self.image_shape, label, + self.char_ops, self.loss_type, + self.max_text_length) + if outs is None: + continue + yield outs + + if finish_read_num == len(lmdb_sets): + break + self.close_lmdb_dataset(lmdb_sets) def batch_iter_reader(): batch_outs = [] for outs in sample_iter_reader(): @@ -135,7 +145,9 @@ class LMDBReader(object): if len(batch_outs) != 0: yield batch_outs - return batch_iter_reader + if self.mode != 'test': + return batch_iter_reader + return sample_iter_reader class SimpleReader(object): -- GitLab