提交 24757ba6 编写于 作者: T tink2123

modified infer single

上级 dab609b1
...@@ -10,4 +10,3 @@ EvalReader: ...@@ -10,4 +10,3 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
infer_img: ./infer_img
...@@ -18,6 +18,8 @@ Global: ...@@ -18,6 +18,8 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -11,4 +11,3 @@ EvalReader: ...@@ -11,4 +11,3 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img
...@@ -11,4 +11,3 @@ EvalReader: ...@@ -11,4 +11,3 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img
...@@ -17,6 +17,8 @@ Global: ...@@ -17,6 +17,8 @@ Global:
pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
Global: Global:
algorithm: CRNN algorithm: CRNN
use_gpu: true use_gpu: false
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
...@@ -14,7 +14,7 @@ Global: ...@@ -14,7 +14,7 @@ Global:
character_type: en character_type: en
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights: ./output/rec_CRNN/rec_mv3_none_bilstm_ctc/best_accuracy
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
......
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,7 +17,9 @@ Global: ...@@ -17,7 +17,9 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
......
...@@ -17,7 +17,9 @@ Global: ...@@ -17,7 +17,9 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,6 +17,8 @@ Global: ...@@ -17,6 +17,8 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,6 +17,8 @@ Global: ...@@ -17,6 +17,8 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -184,7 +184,7 @@ python3 tools/eval.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkp ...@@ -184,7 +184,7 @@ python3 tools/eval.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkp
``` ```
# 预测英文结果 # 预测英文结果
python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
``` ```
预测图片: 预测图片:
......
...@@ -43,11 +43,10 @@ class LMDBReader(object): ...@@ -43,11 +43,10 @@ class LMDBReader(object):
self.mode = params['mode'] self.mode = params['mode']
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == "eval": else:
self.batch_size = params['test_batch_size_per_card'] self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test": self.infer_img = params['infer_img']
self.batch_size = 1
self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self): def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {} lmdb_sets = {}
dataset_idx = 0 dataset_idx = 0
...@@ -100,11 +99,11 @@ class LMDBReader(object): ...@@ -100,11 +99,11 @@ class LMDBReader(object):
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img) image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list: for single_img in image_file_list:
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(img, self.image_shape)
yield norm_img yield norm_img
...@@ -136,6 +135,7 @@ class LMDBReader(object): ...@@ -136,6 +135,7 @@ class LMDBReader(object):
if finish_read_num == len(lmdb_sets): if finish_read_num == len(lmdb_sets):
break break
self.close_lmdb_dataset(lmdb_sets) self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader(): def batch_iter_reader():
batch_outs = [] batch_outs = []
for outs in sample_iter_reader(): for outs in sample_iter_reader():
...@@ -146,7 +146,7 @@ class LMDBReader(object): ...@@ -146,7 +146,7 @@ class LMDBReader(object):
if len(batch_outs) != 0: if len(batch_outs) != 0:
yield batch_outs yield batch_outs
if self.mode != 'test': if self.infer_img is None:
return batch_iter_reader return batch_iter_reader
return sample_iter_reader return sample_iter_reader
...@@ -165,24 +165,22 @@ class SimpleReader(object): ...@@ -165,24 +165,22 @@ class SimpleReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
self.infer_img = params['infer_img']
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == 'eval':
self.batch_size = params['test_batch_size_per_card']
else: else:
self.batch_size = 1 self.batch_size = params['test_batch_size_per_card']
self.infer_img = params['infer_img']
def __call__(self, process_id): def __call__(self, process_id):
if self.mode != 'train': if self.mode != 'train':
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img) image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list: for single_img in image_file_list:
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(img, self.image_shape)
yield norm_img yield norm_img
...@@ -192,7 +190,7 @@ class SimpleReader(object): ...@@ -192,7 +190,7 @@ class SimpleReader(object):
img_num = len(label_infor_list) img_num = len(label_infor_list)
img_id_list = list(range(img_num)) img_id_list = list(range(img_num))
random.shuffle(img_id_list) random.shuffle(img_id_list)
if sys.platform=="win32": if sys.platform == "win32":
print("multiprocess is not fully compatible with Windows." print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.") "num_workers will be 1.")
self.num_workers = 1 self.num_workers = 1
...@@ -204,7 +202,7 @@ class SimpleReader(object): ...@@ -204,7 +202,7 @@ class SimpleReader(object):
if img is None: if img is None:
logger.info("{} does not exist!".format(img_path)) logger.info("{} does not exist!".format(img_path))
continue continue
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
label = substr[1] label = substr[1]
...@@ -225,6 +223,6 @@ class SimpleReader(object): ...@@ -225,6 +223,6 @@ class SimpleReader(object):
if len(batch_outs) != 0: if len(batch_outs) != 0:
yield batch_outs yield batch_outs
if self.mode != 'test': if self.infer_img is None:
return batch_iter_reader return batch_iter_reader
return sample_iter_reader return sample_iter_reader
...@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): ...@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
total_sample_num = 0 total_sample_num = 0
total_acc_num = 0 total_acc_num = 0
total_batch_num = 0 total_batch_num = 0
if mode == "test": if mode == "eval":
is_remove_duplicate = False is_remove_duplicate = False
else: else:
is_remove_duplicate = True is_remove_duplicate = True
...@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict): ...@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict):
total_correct_number = 0 total_correct_number = 0
eval_data_acc_info = {} eval_data_acc_info = {}
for eval_data in eval_data_list: for eval_data in eval_data_list:
config['EvalReader']['lmdb_sets_dir'] = \ config['TestReader']['lmdb_sets_dir'] = \
eval_data_dir + "/" + eval_data eval_data_dir + "/" + eval_data
eval_reader = reader_main(config=config, mode="eval") eval_reader = reader_main(config=config, mode="test")
eval_info_dict['reader'] = eval_reader eval_info_dict['reader'] = eval_reader
metrics = eval_rec_run(exe, config, eval_info_dict, "eval") metrics = eval_rec_run(exe, config, eval_info_dict, "test")
total_evaluation_data_number += metrics['total_sample_num'] total_evaluation_data_number += metrics['total_sample_num']
total_correct_number += metrics['total_acc_num'] total_correct_number += metrics['total_acc_num']
eval_data_acc_info[eval_data] = metrics eval_data_acc_info[eval_data] = metrics
......
...@@ -21,6 +21,7 @@ import time ...@@ -21,6 +21,7 @@ import time
import multiprocessing import multiprocessing
import numpy as np import numpy as np
def set_paddle_flags(**kwargs): def set_paddle_flags(**kwargs):
for key, value in kwargs.items(): for key, value in kwargs.items():
if os.environ.get(key, None) is None: if os.environ.get(key, None) is None:
...@@ -78,13 +79,13 @@ def main(): ...@@ -78,13 +79,13 @@ def main():
init_model(config, eval_prog, exe) init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test')() blobs = reader_main(config, 'test')()
infer_img = config['TestReader']['infer_img'] infer_img = config['Global']['infer_img']
infer_list = get_image_file_list(infer_img) infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list) max_img_num = len(infer_list)
if len(infer_list) == 0: if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.") logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num): for i in range(max_img_num):
print("infer_img:",infer_list[i]) print("infer_img:", infer_list[i])
img = next(blobs) img = next(blobs)
predict = exe.run(program=eval_prog, predict = exe.run(program=eval_prog,
feed={"image": img}, feed={"image": img},
...@@ -105,8 +106,8 @@ def main(): ...@@ -105,8 +106,8 @@ def main():
preds_text = preds_text.reshape(-1) preds_text = preds_text.reshape(-1)
preds_text = char_ops.decode(preds_text) preds_text = char_ops.decode(preds_text)
print("\t index:",preds) print("\t index:", preds)
print("\t word :",preds_text) print("\t word :", preds_text)
# save for inference model # save for inference model
target_var = [] target_var = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册