diff --git a/configs/rec/rec_benchmark_reader.yml b/configs/rec/rec_benchmark_reader.yml index 3d1e3e0b22ce04573c73f51cbef26133415b9aa3..524f2f68bac92ff6ffe3ff3b34e461d2adc81e41 100755 --- a/configs/rec/rec_benchmark_reader.yml +++ b/configs/rec/rec_benchmark_reader.yml @@ -10,4 +10,3 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,LMDBReader lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ - infer_img: ./infer_img diff --git a/configs/rec/rec_chinese_lite_train.yml b/configs/rec/rec_chinese_lite_train.yml index ec1b7a697d95d12f884fb8e1080a1493b9a30ad3..cbc43e06ad93ddfcbb92e3098425e2f5359b32a3 100755 --- a/configs/rec/rec_chinese_lite_train.yml +++ b/configs/rec/rec_chinese_lite_train.yml @@ -18,6 +18,8 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_chinese_reader.yml b/configs/rec/rec_chinese_reader.yml index f09a1ea72e6d929d0446fbbf51ca218e52ae5b3e..a44efd9911d4595cc519b660e868aa9a1e0f144b 100755 --- a/configs/rec/rec_chinese_reader.yml +++ b/configs/rec/rec_chinese_reader.yml @@ -11,4 +11,3 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,SimpleReader - infer_img: ./infer_img diff --git a/configs/rec/rec_icdar15_reader.yml b/configs/rec/rec_icdar15_reader.yml index 12facda1a2fd720765ccf8b39e21cff3a4d31129..322d5f25e0ef0fab167c0c39b38fa488a5546f1b 100755 --- a/configs/rec/rec_icdar15_reader.yml +++ b/configs/rec/rec_icdar15_reader.yml @@ -11,4 +11,3 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,SimpleReader - infer_img: ./infer_img diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml index 6596fc339398af20a9c9ce74a264e24c0a0bdd35..dacf32439abf4a2ebdf3c7bc5dff62473b28d2d7 100755 --- a/configs/rec/rec_icdar15_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -17,6 +17,8 @@ Global: pretrain_weights: 
./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 11a09ee927492154c46f82add1bcfae7c2bb787e..951c83cc496702bd8492dfcba3a9444eff4239f7 100755 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -1,6 +1,6 @@ Global: algorithm: CRNN - use_gpu: true + use_gpu: true epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 @@ -14,7 +14,7 @@ Global: character_type: en loss_type: ctc reader_yml: ./configs/rec/rec_benchmark_reader.yml - pretrain_weights: + pretrain_weights: checkpoints: save_inference_dir: diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml index bbbb6d1fabacbebaf1481260f34ef0e2cfed97f6..ceec09ce6f3b6cb2238d6fb2e15f510cb31e0fd8 100755 --- a/configs/rec/rec_mv3_none_none_ctc.yml +++ b/configs/rec/rec_mv3_none_none_ctc.yml @@ -17,6 +17,7 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_tps_bilstm_attn.yml b/configs/rec/rec_mv3_tps_bilstm_attn.yml index 03a2e901b4997a5cec0e01756b69b1fa0d04511b..d2fb512fe5b8cc93cc2019e21d2c035d19ca2246 100755 --- a/configs/rec/rec_mv3_tps_bilstm_attn.yml +++ b/configs/rec/rec_mv3_tps_bilstm_attn.yml @@ -17,7 +17,9 @@ Global: pretrain_weights: checkpoints: save_inference_dir: - + infer_img: + + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml index 47247b723a0cb3a145d6e87a3d76b1a8dcf1ea21..bc5780bd94aca3a7a73a36ca80aeba17d163cdb3 100755 --- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@
-17,6 +17,7 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index 1018193611855dd22ad54fb8fbc70b7f47d89c33..b71e8feae7ac8f235bf471101efd4383c61bfab2 100755 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -17,7 +17,9 @@ Global: pretrain_weights: checkpoints: save_inference_dir: - + infer_img: + + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml index ff4c57634aa12e6bbd88905a038260c75489d8f3..d9c9458d6d8fcdb9df590b0093d54b71e3e53fcc 100755 --- a/configs/rec/rec_r34_vd_none_none_ctc.yml +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -17,6 +17,7 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_r34_vd_tps_bilstm_attn.yml b/configs/rec/rec_r34_vd_tps_bilstm_attn.yml index 4d96e9e72927e3822137bf95e89164cc33b41db7..405082bdbec0f4b9ac0c885801963e9261c43e6e 100755 --- a/configs/rec/rec_r34_vd_tps_bilstm_attn.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_attn.yml @@ -17,6 +17,8 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml index 844721a2e44019382682e76d4f3f40954eaebc6b..517322c374a7faf80d6e2b69b7f3e8b2dbb5b5af 100755 --- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml @@ -17,6 +17,8 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/doc/recognition.md b/doc/recognition.md index 
a5a8119c56a34a221cb0441716b1a875caddb897..ea38c0f3a26f2bc22dc9f66dba3a7e7d825577da 100644 --- a/doc/recognition.md +++ b/doc/recognition.md @@ -184,7 +184,7 @@ python3 tools/eval.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkp ``` # 预测英文结果 -python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg +python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png ``` 预测图片: diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index f60b9fe36dda95dd6ce58a0c0e6b57c843280b35..9d0d2e9679b6d46c3eb396246d8fb9053b80d5d6 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -43,11 +43,10 @@ class LMDBReader(object): self.mode = params['mode'] if params['mode'] == 'train': self.batch_size = params['train_batch_size_per_card'] - elif params['mode'] == "eval": + else: self.batch_size = params['test_batch_size_per_card'] - elif params['mode'] == "test": - self.batch_size = 1 - self.infer_img = params["infer_img"] + self.infer_img = params['infer_img'] + def load_hierarchical_lmdb_dataset(self): lmdb_sets = {} dataset_idx = 0 @@ -100,11 +99,11 @@ class LMDBReader(object): process_id = 0 def sample_iter_reader(): - if self.mode == 'test': + if self.infer_img is not None: image_file_list = get_image_file_list(self.infer_img) for single_img in image_file_list: img = cv2.imread(single_img) - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) norm_img = process_image(img, self.image_shape) yield norm_img @@ -136,6 +135,7 @@ class LMDBReader(object): if finish_read_num == len(lmdb_sets): break self.close_lmdb_dataset(lmdb_sets) + def batch_iter_reader(): batch_outs = [] for outs in 
sample_iter_reader(): @@ -146,7 +146,7 @@ class LMDBReader(object): if len(batch_outs) != 0: yield batch_outs - if self.mode != 'test': + if self.infer_img is None: return batch_iter_reader return sample_iter_reader @@ -165,24 +165,22 @@ class SimpleReader(object): self.loss_type = params['loss_type'] self.max_text_length = params['max_text_length'] self.mode = params['mode'] + self.infer_img = params['infer_img'] if params['mode'] == 'train': self.batch_size = params['train_batch_size_per_card'] - elif params['mode'] == 'eval': - self.batch_size = params['test_batch_size_per_card'] else: - self.batch_size = 1 - self.infer_img = params['infer_img'] + self.batch_size = params['test_batch_size_per_card'] def __call__(self, process_id): if self.mode != 'train': process_id = 0 def sample_iter_reader(): - if self.mode == 'test': + if self.infer_img is not None: image_file_list = get_image_file_list(self.infer_img) for single_img in image_file_list: img = cv2.imread(single_img) - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) norm_img = process_image(img, self.image_shape) yield norm_img @@ -192,7 +190,7 @@ class SimpleReader(object): img_num = len(label_infor_list) img_id_list = list(range(img_num)) random.shuffle(img_id_list) - if sys.platform=="win32": + if sys.platform == "win32": print("multiprocess is not fully compatible with Windows." 
"num_workers will be 1.") self.num_workers = 1 @@ -204,7 +202,7 @@ class SimpleReader(object): if img is None: logger.info("{} does not exist!".format(img_path)) continue - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) label = substr[1] @@ -225,6 +223,6 @@ class SimpleReader(object): if len(batch_outs) != 0: yield batch_outs - if self.mode != 'test': + if self.infer_img is None: return batch_iter_reader return sample_iter_reader diff --git a/tools/eval_utils/eval_rec_utils.py b/tools/eval_utils/eval_rec_utils.py index 3ceaa159ce1a98940bbdf1127b96e82243e96658..2d7d7e1d4e200e12643f8cfcb812a3cba3229b8f 100644 --- a/tools/eval_utils/eval_rec_utils.py +++ b/tools/eval_utils/eval_rec_utils.py @@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): total_sample_num = 0 total_acc_num = 0 total_batch_num = 0 - if mode == "test": + if mode == "eval": is_remove_duplicate = False else: is_remove_duplicate = True @@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict): total_correct_number = 0 eval_data_acc_info = {} for eval_data in eval_data_list: - config['EvalReader']['lmdb_sets_dir'] = \ + config['TestReader']['lmdb_sets_dir'] = \ eval_data_dir + "/" + eval_data - eval_reader = reader_main(config=config, mode="eval") + eval_reader = reader_main(config=config, mode="test") eval_info_dict['reader'] = eval_reader - metrics = eval_rec_run(exe, config, eval_info_dict, "eval") + metrics = eval_rec_run(exe, config, eval_info_dict, "test") total_evaluation_data_number += metrics['total_sample_num'] total_correct_number += metrics['total_acc_num'] eval_data_acc_info[eval_data] = metrics diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 25bae1ca6c3034833c83975d9f47ab30388ca56e..67e61451fb1f1822ebbd3524254b4af14d0acb1a 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -21,6 +21,7 @@ import time import multiprocessing import numpy as np 
+ def set_paddle_flags(**kwargs): for key, value in kwargs.items(): if os.environ.get(key, None) is None: @@ -78,13 +79,13 @@ def main(): init_model(config, eval_prog, exe) blobs = reader_main(config, 'test')() - infer_img = config['TestReader']['infer_img'] + infer_img = config['Global']['infer_img'] infer_list = get_image_file_list(infer_img) max_img_num = len(infer_list) if len(infer_list) == 0: logger.info("Can not find img in infer_img dir.") for i in range(max_img_num): - print("infer_img:",infer_list[i]) + print("infer_img:", infer_list[i]) img = next(blobs) predict = exe.run(program=eval_prog, feed={"image": img}, @@ -105,8 +106,8 @@ def main(): preds_text = preds_text.reshape(-1) preds_text = char_ops.decode(preds_text) - print("\t index:",preds) - print("\t word :",preds_text) + print("\t index:", preds) + print("\t word :", preds_text) # save for inference model target_var = []