提交 dd0112f5 编写于 作者: T tink2123

fix infer_rec for benchmark

上级 6de43fbb
...@@ -10,3 +10,4 @@ EvalReader: ...@@ -10,3 +10,4 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
infer_img:
...@@ -42,9 +42,11 @@ class LMDBReader(object): ...@@ -42,9 +42,11 @@ class LMDBReader(object):
self.mode = params['mode'] self.mode = params['mode']
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
else: elif params['mode'] == "eval":
self.batch_size = params['test_batch_size_per_card'] self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test":
self.batch_size = 1
self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self): def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {} lmdb_sets = {}
dataset_idx = 0 dataset_idx = 0
...@@ -97,6 +99,15 @@ class LMDBReader(object): ...@@ -97,6 +99,15 @@ class LMDBReader(object):
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test':
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape)
yield norm_img
else:
lmdb_sets = self.load_hierarchical_lmdb_dataset() lmdb_sets = self.load_hierarchical_lmdb_dataset()
if process_id == 0: if process_id == 0:
self.print_lmdb_sets_info(lmdb_sets) self.print_lmdb_sets_info(lmdb_sets)
...@@ -124,7 +135,6 @@ class LMDBReader(object): ...@@ -124,7 +135,6 @@ class LMDBReader(object):
if finish_read_num == len(lmdb_sets): if finish_read_num == len(lmdb_sets):
break break
self.close_lmdb_dataset(lmdb_sets) self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader(): def batch_iter_reader():
batch_outs = [] batch_outs = []
for outs in sample_iter_reader(): for outs in sample_iter_reader():
...@@ -135,7 +145,9 @@ class LMDBReader(object): ...@@ -135,7 +145,9 @@ class LMDBReader(object):
if len(batch_outs) != 0: if len(batch_outs) != 0:
yield batch_outs yield batch_outs
if self.mode != 'test':
return batch_iter_reader return batch_iter_reader
return sample_iter_reader
class SimpleReader(object): class SimpleReader(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册