diff --git a/configs/rec/rec_chinese_lite_train.yml b/configs/rec/rec_chinese_lite_train.yml index 90e9dab18ec134d35bd3bdf82a65122f35d23358..bbd2590c4843377ee4dc954cfd487f067ca6faf0 100755 --- a/configs/rec/rec_chinese_lite_train.yml +++ b/configs/rec/rec_chinese_lite_train.yml @@ -1,6 +1,6 @@ Global: algorithm: CRNN - use_gpu: false + use_gpu: true epoch_num: 3000 log_smooth_window: 20 print_batch_step: 10 @@ -16,7 +16,7 @@ Global: character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt loss_type: ctc reader_yml: ./configs/rec/rec_chinese_reader.yml - pretrain_weights: output/rec_CRNN/rec_mv3_crnn/best_accuracy + pretrain_weights: checkpoints: save_inference_dir: infer_img: diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml index 35a1b17df72eb969c3941107fe1d89785332ff51..1cca2defb936a212eb262a402753534f387cbaa7 100755 --- a/configs/rec/rec_icdar15_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -15,7 +15,7 @@ Global: character_type: en loss_type: ctc reader_yml: ./configs/rec/rec_icdar15_reader.yml - pretrain_weights: + pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy checkpoints: save_inference_dir: infer_img: diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 6cbca16f867282847e6dc24396a83bc81049108d..7bae1bd1f136e27a900e1425c6d9231d331e1a68 100755 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -1,6 +1,6 @@ Global: algorithm: CRNN - use_gpu: false + use_gpu: true epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 diff --git a/configs/rec/rec_mv3_tps_bilstm_attn.yml b/configs/rec/rec_mv3_tps_bilstm_attn.yml index 1787abc17029eeed2371a5a694ca6d0c94325b6e..2dbb9d0780aee8e1d9b5eb93a46ea66b1fa54486 100755 --- a/configs/rec/rec_mv3_tps_bilstm_attn.yml +++ b/configs/rec/rec_mv3_tps_bilstm_attn.yml @@ -1,6 +1,6 @@ Global: algorithm: RARE - use_gpu: false + use_gpu: true epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 @@ -12,8 +12,7 @@ Global: test_batch_size_per_card: 256 image_shape: [3, 32, 100] max_text_length: 25 - character_type: ch - character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt + character_type: en loss_type: attention tps: true reader_yml: ./configs/rec/rec_benchmark_reader.yml diff --git a/doc/inference.md b/doc/inference.md index a15043b50bb29686e2fe4162cb1e0eefd8fabe3a..b16b89a963be1cd5a6b87d13b5ae696681036cf2 100644 --- a/doc/inference.md +++ b/doc/inference.md @@ -165,6 +165,12 @@ STAR-Net文本识别模型推理,可以执行如下命令: ``` python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" ``` + +RARE 文本识别模型推理,可以执行如下命令: +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/sare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE" +``` + ![](imgs_words_en/word_336.png) 执行命令后,上面图像的识别结果如下: diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 914fcde31e5ea3efa544a0a2fa6a29266f835e5d..e8b485fb000a4aa7386d4351d3ad2fa4e706bb5a 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -32,10 +32,14 @@ class TextRecognizer(object): self.rec_image_shape = image_shape self.character_type = args.rec_char_type self.rec_batch_num = args.rec_batch_num + self.rec_algorithm = args.rec_algorithm char_ops_params = {} char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_dict_path"] = args.rec_char_dict_path - char_ops_params['loss_type'] = 'ctc' + if self.rec_algorithm != "RARE": + char_ops_params['loss_type'] = 'ctc' + else: + char_ops_params['loss_type'] = 'attention' self.char_ops = CharacterOps(char_ops_params) def resize_norm_img(self, img, max_wh_ratio): @@ -81,7 +85,7 @@ class TextRecognizer(object): self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.zero_copy_run() - if args.rec_algorithm != "RARE": + if self.rec_algorithm != "RARE": rec_idx_batch = self.output_tensors[0].copy_to_cpu() rec_idx_lod = self.output_tensors[0].lod()[0] predict_batch = self.output_tensors[1].copy_to_cpu() @@ -104,6 +108,8 @@ class TextRecognizer(object): else: rec_idx_batch = self.output_tensors[0].copy_to_cpu() predict_batch = self.output_tensors[1].copy_to_cpu() + elapse = time.time() - starttime + predict_time += elapse for rno in range(len(rec_idx_batch)): end_pos = np.where(rec_idx_batch[rno, :] == 1)[0] if len(end_pos) <= 1: @@ -112,8 +118,6 @@ class TextRecognizer(object): else: preds = rec_idx_batch[rno, 1:end_pos[1]] score = np.mean(predict_batch[rno, 1:end_pos[1]]) - #attenton index has 2 offset: beg and end - preds = preds - 2 preds_text = self.char_ops.decode(preds) rec_res.append([preds_text, score])