提交 d6778e25 编写于 作者: T tink2123

del lable, polish data_reader

上级 d27ae42a
...@@ -170,10 +170,14 @@ class SimpleReader(object): ...@@ -170,10 +170,14 @@ class SimpleReader(object):
image_file_list = [self.infer_img] image_file_list = [self.infer_img]
elif os.path.isdir(self.infer_img): elif os.path.isdir(self.infer_img):
for single_file in os.listdir(self.infer_img): for single_file in os.listdir(self.infer_img):
if single_file.endswith('png') or single_file.endswith('jpg'): if single_file.split('.')[
image_file_list.append(os.path.join(self.infer_img, single_file)) -1] not in ['bmp', 'jpg', 'jpeg', 'png', 'JPEG', 'JPG', 'PNG']:
continue
image_file_list.append(os.path.join(self.infer_img, single_file))
for single_img in image_file_list: for single_img in image_file_list:
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(img, self.image_shape)
yield norm_img yield norm_img
with open(self.label_file_path, "rb") as fin: with open(self.label_file_path, "rb") as fin:
......
...@@ -79,7 +79,11 @@ def main(): ...@@ -79,7 +79,11 @@ def main():
init_model(config, eval_prog, exe) init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test')() blobs = reader_main(config, 'test')()
infer_list = os.listdir(config['Global']['infer_img']) infer_img = config['Global']['infer_img']
if os.path.isfile(infer_img):
infer_list = [infer_img]
elif os.path.isdir(infer_img):
infer_list = os.listdir(config['Global']['infer_img'])
max_img_num = len(infer_list) max_img_num = len(infer_list)
if len(infer_list) == 0: if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.") logger.info("Can not find img in infer_img dir.")
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册