提交 dd0112f5 编写于 作者: T tink2123

fix infer_rec for benchmark

上级 6de43fbb
...@@ -9,4 +9,5 @@ EvalReader: ...@@ -9,4 +9,5 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
\ No newline at end of file infer_img:
...@@ -42,9 +42,11 @@ class LMDBReader(object): ...@@ -42,9 +42,11 @@ class LMDBReader(object):
self.mode = params['mode'] self.mode = params['mode']
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
else: elif params['mode'] == "eval":
self.batch_size = params['test_batch_size_per_card'] self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test":
self.batch_size = 1
self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self): def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {} lmdb_sets = {}
dataset_idx = 0 dataset_idx = 0
...@@ -97,34 +99,42 @@ class LMDBReader(object): ...@@ -97,34 +99,42 @@ class LMDBReader(object):
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
lmdb_sets = self.load_hierarchical_lmdb_dataset() if self.mode == 'test':
if process_id == 0: image_file_list = get_image_file_list(self.infer_img)
self.print_lmdb_sets_info(lmdb_sets) for single_img in image_file_list:
cur_index_sets = [1 + process_id] * len(lmdb_sets) img = cv2.imread(single_img)
while True: if img.shape[-1]==1 or len(list(img.shape))==2:
finish_read_num = 0 img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
for dataset_idx in range(len(lmdb_sets)): norm_img = process_image(img, self.image_shape)
cur_index = cur_index_sets[dataset_idx] yield norm_img
if cur_index > lmdb_sets[dataset_idx]['num_samples']: else:
finish_read_num += 1 lmdb_sets = self.load_hierarchical_lmdb_dataset()
else: if process_id == 0:
sample_info = self.get_lmdb_sample_info( self.print_lmdb_sets_info(lmdb_sets)
lmdb_sets[dataset_idx]['txn'], cur_index) cur_index_sets = [1 + process_id] * len(lmdb_sets)
cur_index_sets[dataset_idx] += self.num_workers while True:
if sample_info is None: finish_read_num = 0
continue for dataset_idx in range(len(lmdb_sets)):
img, label = sample_info cur_index = cur_index_sets[dataset_idx]
outs = process_image(img, self.image_shape, label, if cur_index > lmdb_sets[dataset_idx]['num_samples']:
self.char_ops, self.loss_type, finish_read_num += 1
self.max_text_length) else:
if outs is None: sample_info = self.get_lmdb_sample_info(
continue lmdb_sets[dataset_idx]['txn'], cur_index)
yield outs cur_index_sets[dataset_idx] += self.num_workers
if sample_info is None:
if finish_read_num == len(lmdb_sets): continue
break img, label = sample_info
self.close_lmdb_dataset(lmdb_sets) outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
if outs is None:
continue
yield outs
if finish_read_num == len(lmdb_sets):
break
self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader(): def batch_iter_reader():
batch_outs = [] batch_outs = []
for outs in sample_iter_reader(): for outs in sample_iter_reader():
...@@ -135,7 +145,9 @@ class LMDBReader(object): ...@@ -135,7 +145,9 @@ class LMDBReader(object):
if len(batch_outs) != 0: if len(batch_outs) != 0:
yield batch_outs yield batch_outs
return batch_iter_reader if self.mode != 'test':
return batch_iter_reader
return sample_iter_reader
class SimpleReader(object): class SimpleReader(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册