From 9023a5c57ae57404c218bab4668760ca209a5712 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 3 Jun 2020 17:38:44 +0800 Subject: [PATCH] update infer doc and fix yml --- configs/rec/rec_chinese_lite_train.yml | 4 ++-- configs/rec/rec_icdar15_train.yml | 2 +- configs/rec/rec_mv3_none_bilstm_ctc.yml | 2 +- configs/rec/rec_mv3_tps_bilstm_attn.yml | 5 ++--- doc/inference.md | 6 ++++++ tools/infer/predict_rec.py | 12 ++++++++---- 6 files changed, 20 insertions(+), 11 deletions(-) diff --git a/configs/rec/rec_chinese_lite_train.yml b/configs/rec/rec_chinese_lite_train.yml index 90e9dab1..bbd2590c 100755 --- a/configs/rec/rec_chinese_lite_train.yml +++ b/configs/rec/rec_chinese_lite_train.yml @@ -1,6 +1,6 @@ Global: algorithm: CRNN - use_gpu: false + use_gpu: true epoch_num: 3000 log_smooth_window: 20 print_batch_step: 10 @@ -16,7 +16,7 @@ Global: character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt loss_type: ctc reader_yml: ./configs/rec/rec_chinese_reader.yml - pretrain_weights: output/rec_CRNN/rec_mv3_crnn/best_accuracy + pretrain_weights: checkpoints: save_inference_dir: infer_img: diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml index 35a1b17d..1cca2def 100755 --- a/configs/rec/rec_icdar15_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -15,7 +15,7 @@ Global: character_type: en loss_type: ctc reader_yml: ./configs/rec/rec_icdar15_reader.yml - pretrain_weights: + pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy checkpoints: save_inference_dir: infer_img: diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 6cbca16f..7bae1bd1 100755 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -1,6 +1,6 @@ Global: algorithm: CRNN - use_gpu: false + use_gpu: true epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 diff --git a/configs/rec/rec_mv3_tps_bilstm_attn.yml b/configs/rec/rec_mv3_tps_bilstm_attn.yml index 1787abc1..2dbb9d07 100755 --- a/configs/rec/rec_mv3_tps_bilstm_attn.yml +++ b/configs/rec/rec_mv3_tps_bilstm_attn.yml @@ -1,6 +1,6 @@ Global: algorithm: RARE - use_gpu: false + use_gpu: true epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 @@ -12,8 +12,7 @@ Global: test_batch_size_per_card: 256 image_shape: [3, 32, 100] max_text_length: 25 - character_type: ch - character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt + character_type: en loss_type: attention tps: true reader_yml: ./configs/rec/rec_benchmark_reader.yml diff --git a/doc/inference.md b/doc/inference.md index a15043b5..b16b89a9 100644 --- a/doc/inference.md +++ b/doc/inference.md @@ -165,6 +165,12 @@ STAR-Net文本识别模型推理,可以执行如下命令: ``` python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" ``` + +RARE 文本识别模型推理,可以执行如下命令: +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/sare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE" +``` + ![](imgs_words_en/word_336.png) 执行命令后,上面图像的识别结果如下: diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 914fcde3..e8b485fb 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -32,10 +32,14 @@ class TextRecognizer(object): self.rec_image_shape = image_shape self.character_type = args.rec_char_type self.rec_batch_num = args.rec_batch_num + self.rec_algorithm = args.rec_algorithm char_ops_params = {} char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_dict_path"] = args.rec_char_dict_path - char_ops_params['loss_type'] = 'ctc' + if self.rec_algorithm != "RARE": + char_ops_params['loss_type'] = 'ctc' + else: + char_ops_params['loss_type'] = 'attention' self.char_ops = CharacterOps(char_ops_params) def resize_norm_img(self, img, max_wh_ratio): @@ -81,7 +85,7 @@ class TextRecognizer(object): self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.zero_copy_run() - if args.rec_algorithm != "RARE": + if self.rec_algorithm != "RARE": rec_idx_batch = self.output_tensors[0].copy_to_cpu() rec_idx_lod = self.output_tensors[0].lod()[0] predict_batch = self.output_tensors[1].copy_to_cpu() @@ -104,6 +108,8 @@ class TextRecognizer(object): else: rec_idx_batch = self.output_tensors[0].copy_to_cpu() predict_batch = self.output_tensors[1].copy_to_cpu() + elapse = time.time() - starttime + predict_time += elapse for rno in range(len(rec_idx_batch)): end_pos = np.where(rec_idx_batch[rno, :] == 1)[0] if len(end_pos) <= 1: @@ -112,8 +118,6 @@ class TextRecognizer(object): else: preds = rec_idx_batch[rno, 1:end_pos[1]] score = np.mean(predict_batch[rno, 1:end_pos[1]]) - #attenton index has 2 offset: beg and end - preds = preds - 2 preds_text = self.char_ops.decode(preds) rec_res.append([preds_text, score]) -- GitLab