未验证 提交 96ead92e 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #83 from tink2123/fix_infer

fix infer_rec for benchmark
......@@ -10,4 +10,4 @@ EvalReader:
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
infer_img: ./infer_img
\ No newline at end of file
infer_img: ./infer_img
......@@ -42,9 +42,11 @@ class LMDBReader(object):
self.mode = params['mode']
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
else:
elif params['mode'] == "eval":
self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test":
self.batch_size = 1
self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {}
dataset_idx = 0
......@@ -97,34 +99,42 @@ class LMDBReader(object):
process_id = 0
def sample_iter_reader():
lmdb_sets = self.load_hierarchical_lmdb_dataset()
if process_id == 0:
self.print_lmdb_sets_info(lmdb_sets)
cur_index_sets = [1 + process_id] * len(lmdb_sets)
while True:
finish_read_num = 0
for dataset_idx in range(len(lmdb_sets)):
cur_index = cur_index_sets[dataset_idx]
if cur_index > lmdb_sets[dataset_idx]['num_samples']:
finish_read_num += 1
else:
sample_info = self.get_lmdb_sample_info(
lmdb_sets[dataset_idx]['txn'], cur_index)
cur_index_sets[dataset_idx] += self.num_workers
if sample_info is None:
continue
img, label = sample_info
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
if outs is None:
continue
yield outs
if finish_read_num == len(lmdb_sets):
break
self.close_lmdb_dataset(lmdb_sets)
if self.mode == 'test':
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape)
yield norm_img
else:
lmdb_sets = self.load_hierarchical_lmdb_dataset()
if process_id == 0:
self.print_lmdb_sets_info(lmdb_sets)
cur_index_sets = [1 + process_id] * len(lmdb_sets)
while True:
finish_read_num = 0
for dataset_idx in range(len(lmdb_sets)):
cur_index = cur_index_sets[dataset_idx]
if cur_index > lmdb_sets[dataset_idx]['num_samples']:
finish_read_num += 1
else:
sample_info = self.get_lmdb_sample_info(
lmdb_sets[dataset_idx]['txn'], cur_index)
cur_index_sets[dataset_idx] += self.num_workers
if sample_info is None:
continue
img, label = sample_info
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
if outs is None:
continue
yield outs
if finish_read_num == len(lmdb_sets):
break
self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
......@@ -135,7 +145,9 @@ class LMDBReader(object):
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader
if self.mode != 'test':
return batch_iter_reader
return sample_iter_reader
class SimpleReader(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册