未验证 提交 96ead92e 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #83 from tink2123/fix_infer

fix infer_rec for benchmark
......@@ -42,9 +42,11 @@ class LMDBReader(object):
self.mode = params['mode']
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
else:
elif params['mode'] == "eval":
self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test":
self.batch_size = 1
self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {}
dataset_idx = 0
......@@ -97,6 +99,15 @@ class LMDBReader(object):
process_id = 0
def sample_iter_reader():
if self.mode == 'test':
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape)
yield norm_img
else:
lmdb_sets = self.load_hierarchical_lmdb_dataset()
if process_id == 0:
self.print_lmdb_sets_info(lmdb_sets)
......@@ -124,7 +135,6 @@ class LMDBReader(object):
if finish_read_num == len(lmdb_sets):
break
self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
......@@ -135,7 +145,9 @@ class LMDBReader(object):
if len(batch_outs) != 0:
yield batch_outs
if self.mode != 'test':
return batch_iter_reader
return sample_iter_reader
class SimpleReader(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册