From 021c1132a94de11db92b6bbc9739d16224271520 Mon Sep 17 00:00:00 2001 From: MissPenguin Date: Wed, 9 Dec 2020 06:45:25 +0000 Subject: [PATCH] add east & sast --- MANIFEST.in | 7 +- configs/det/det_mv3_east.yml | 111 +++ configs/det/det_r50_vd_east.yml | 110 +++ configs/det/det_r50_vd_sast_icdar15.yml | 110 +++ configs/det/det_r50_vd_sast_totaltext.yml | 109 +++ .../rec_en_number_lite_train.yml | 2 +- .../multi_language/rec_french_lite_train.yml | 2 +- .../multi_language/rec_german_lite_train.yml | 2 +- .../multi_language/rec_japan_lite_train.yml | 2 +- .../multi_language/rec_korean_lite_train.yml | 2 +- doc/doc_ch/whl.md | 58 +- doc/doc_en/whl_en.md | 56 +- paddleocr.py | 248 +++++-- ppocr/data/imaug/__init__.py | 3 + ppocr/data/imaug/east_process.py | 439 +++++++++++ ppocr/data/imaug/label_ops.py | 18 +- ppocr/data/imaug/operators.py | 48 +- ppocr/data/imaug/sast_process.py | 689 ++++++++++++++++++ ppocr/losses/__init__.py | 4 +- ppocr/losses/det_east_loss.py | 63 ++ ppocr/losses/det_sast_loss.py | 121 +++ ppocr/modeling/backbones/__init__.py | 1 + .../modeling/backbones/det_resnet_vd_sast.py | 285 ++++++++ ppocr/modeling/heads/__init__.py | 4 +- ppocr/modeling/heads/det_east_head.py | 121 +++ ppocr/modeling/heads/det_sast_head.py | 128 ++++ ppocr/modeling/necks/__init__.py | 4 +- ppocr/modeling/necks/east_fpn.py | 188 +++++ ppocr/modeling/necks/sast_fpn.py | 284 ++++++++ ppocr/postprocess/__init__.py | 4 +- ppocr/postprocess/east_postprocess.py | 141 ++++ ppocr/postprocess/locality_aware_nms.py | 199 +++++ ppocr/postprocess/rec_postprocess.py | 8 +- ppocr/postprocess/sast_postprocess.py | 295 ++++++++ setup.py | 2 +- tools/infer/predict_system.py | 33 +- 36 files changed, 3798 insertions(+), 103 deletions(-) create mode 100644 configs/det/det_mv3_east.yml create mode 100644 configs/det/det_r50_vd_east.yml create mode 100644 configs/det/det_r50_vd_sast_icdar15.yml create mode 100644 configs/det/det_r50_vd_sast_totaltext.yml create mode 100644 ppocr/data/imaug/east_process.py create mode 100644 ppocr/data/imaug/sast_process.py create mode 100644 ppocr/losses/det_east_loss.py create mode 100644 ppocr/losses/det_sast_loss.py create mode 100644 ppocr/modeling/backbones/det_resnet_vd_sast.py create mode 100644 ppocr/modeling/heads/det_east_head.py create mode 100644 ppocr/modeling/heads/det_sast_head.py create mode 100644 ppocr/modeling/necks/east_fpn.py create mode 100644 ppocr/modeling/necks/sast_fpn.py create mode 100644 ppocr/postprocess/east_postprocess.py create mode 100644 ppocr/postprocess/locality_aware_nms.py create mode 100644 ppocr/postprocess/sast_postprocess.py diff --git a/MANIFEST.in b/MANIFEST.in index 388882df..4c16c09d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,7 @@ include LICENSE.txt include README.md -recursive-include ppocr/utils *.txt utility.py character.py check.py -recursive-include ppocr/data/det *.py +recursive-include ppocr/utils *.txt utility.py logging.py +recursive-include ppocr/data/ *.py recursive-include ppocr/postprocess *.py -recursive-include ppocr/postprocess/lanms *.* -recursive-include tools/infer *.py +recursive-include tools/infer *.py \ No newline at end of file diff --git a/configs/det/det_mv3_east.yml b/configs/det/det_mv3_east.yml new file mode 100644 index 00000000..05581a76 --- /dev/null +++ b/configs/det/det_mv3_east.yml @@ -0,0 +1,111 @@ +Global: + use_gpu: true + epoch_num: 10000 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/east_mv3/ + save_epoch_step: 1000 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [4000, 5000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: + save_res_path: ./output/det_east/predicts_east.txt + +Architecture: + model_type: det + algorithm: EAST + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + Neck: + name: EASTFPN + model_name: small + Head: + name: EASTHead + model_name: small + +Loss: + name: EASTLoss + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + # name: Cosine + learning_rate: 0.001 + # warmup_epoch: 0 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: EASTPostProcess + score_thresh: 0.8 + cover_thresh: 0.1 + nms_thresh: 0.2 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - EASTProcessTrain: + image_shape: [512, 512] + background_ratio: 0.125 + min_crop_side_ratio: 0.1 + min_text_size: 10 + - KeepKeys: + keep_keys: ['image', 'score_map', 'geo_map', 'training_mask'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 16 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + limit_side_len: 2400 + limit_type: max + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 \ No newline at end of file diff --git a/configs/det/det_r50_vd_east.yml b/configs/det/det_r50_vd_east.yml new file mode 100644 index 00000000..b8fe55d4 --- /dev/null +++ b/configs/det/det_r50_vd_east.yml @@ -0,0 +1,110 @@ +Global: + use_gpu: true + epoch_num: 10000 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/east_r50_vd/ + save_epoch_step: 1000 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [4000, 5000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet50_vd_pretrained/ + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: + save_res_path: ./output/det_east/predicts_east.txt + +Architecture: + model_type: det + algorithm: EAST + Transform: + Backbone: + name: ResNet + layers: 50 + Neck: + name: EASTFPN + model_name: large + Head: + name: EASTHead + model_name: large + +Loss: + name: EASTLoss + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + # name: Cosine + learning_rate: 0.001 + # warmup_epoch: 0 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: EASTPostProcess + score_thresh: 0.8 + cover_thresh: 0.1 + nms_thresh: 0.2 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - EASTProcessTrain: + image_shape: [512, 512] + background_ratio: 0.125 + min_crop_side_ratio: 0.1 + min_text_size: 10 + - KeepKeys: + keep_keys: ['image', 'score_map', 'geo_map', 'training_mask'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + limit_side_len: 2400 + limit_type: max + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 \ No newline at end of file diff --git a/configs/det/det_r50_vd_sast_icdar15.yml b/configs/det/det_r50_vd_sast_icdar15.yml new file mode 100644 index 00000000..7ca93cec --- /dev/null +++ b/configs/det/det_r50_vd_sast_icdar15.yml @@ -0,0 +1,110 @@ +Global: + use_gpu: true + epoch_num: 5000 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/sast_r50_vd_ic15/ + save_epoch_step: 1000 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [4000, 5000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: + save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt + +Architecture: + model_type: det + algorithm: SAST + Transform: + Backbone: + name: ResNet_SAST + layers: 50 + Neck: + name: SASTFPN + with_cab: True + Head: + name: SASTHead + +Loss: + name: SASTLoss + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + # name: Cosine + learning_rate: 0.001 + # warmup_epoch: 0 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: SASTPostProcess + score_thresh: 0.5 + sample_pts_num: 2 + nms_thresh: 0.2 + expand_scale: 1.0 + shrink_ratio_of_width: 0.3 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + label_file_path: [./train_data/art_latin_icdar_14pt/train_no_tt_test/train_label_json.txt, ./train_data/total_text_icdar_14pt/train_label_json.txt] + data_ratio_list: [0.5, 0.5] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - SASTProcessTrain: + image_shape: [512, 512] + min_crop_side_ratio: 0.3 + min_crop_size: 24 + min_text_size: 4 + max_text_size: 512 + - KeepKeys: + keep_keys: ['image', 'score_map', 'border_map', 'training_mask', 'tvo_map', 'tco_map'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 4 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + resize_long: 1536 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 \ No newline at end of file diff --git a/configs/det/det_r50_vd_sast_totaltext.yml b/configs/det/det_r50_vd_sast_totaltext.yml new file mode 100644 index 00000000..a9a037c8 --- /dev/null +++ b/configs/det/det_r50_vd_sast_totaltext.yml @@ -0,0 +1,109 @@ +Global: + use_gpu: true + epoch_num: 5000 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/sast_r50_vd_tt/ + save_epoch_step: 1000 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [4000, 5000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: + save_res_path: ./output/sast_r50_vd_tt/predicts_sast.txt + +Architecture: + model_type: det + algorithm: SAST + Transform: + Backbone: + name: ResNet_SAST + layers: 50 + Neck: + name: SASTFPN + with_cab: True + Head: + name: SASTHead + +Loss: + name: SASTLoss + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + # name: Cosine + learning_rate: 0.001 + # warmup_epoch: 0 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: SASTPostProcess + score_thresh: 0.5 + sample_pts_num: 6 + nms_thresh: 0.2 + expand_scale: 1.2 + shrink_ratio_of_width: 0.2 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + label_file_list: [./train_data/icdar2013/train_label_json.txt, ./train_data/icdar2015/train_label_json.txt, ./train_data/icdar17_mlt_latin/train_label_json.txt, ./train_data/coco_text_icdar_4pts/train_label_json.txt] + ratio_list: [0.1, 0.45, 0.3, 0.15] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - SASTProcessTrain: + image_shape: [512, 512] + min_crop_side_ratio: 0.3 + min_crop_size: 24 + min_text_size: 4 + max_text_size: 512 + - KeepKeys: + keep_keys: ['image', 'score_map', 'border_map', 'training_mask', 'tvo_map', 'tco_map'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 4 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + label_file_list: + - ./train_data/total_text_icdar_14pt/test_label_json.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + resize_long: 768 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 \ No newline at end of file diff --git a/configs/rec/multi_language/rec_en_number_lite_train.yml b/configs/rec/multi_language/rec_en_number_lite_train.yml index 9d0f1f00..70d825e6 100644 --- a/configs/rec/multi_language/rec_en_number_lite_train.yml +++ b/configs/rec/multi_language/rec_en_number_lite_train.yml @@ -15,7 +15,7 @@ Global: use_visualdl: False infer_img: # for data or label process - character_dict_path: ppocr/utils/ic15_dict.txt + character_dict_path: ppocr/utils/dict/ic15_dict.txt character_type: ch max_text_length: 25 infer_mode: False diff --git a/configs/rec/multi_language/rec_french_lite_train.yml b/configs/rec/multi_language/rec_french_lite_train.yml index da3aad58..0e8f4eb3 100644 --- a/configs/rec/multi_language/rec_french_lite_train.yml +++ b/configs/rec/multi_language/rec_french_lite_train.yml @@ -15,7 +15,7 @@ Global: use_visualdl: False infer_img: # for data or label process - character_dict_path: ppocr/utils/french_dict.txt + character_dict_path: ppocr/utils/dict/french_dict.txt character_type: french max_text_length: 25 infer_mode: False diff --git a/configs/rec/multi_language/rec_german_lite_train.yml b/configs/rec/multi_language/rec_german_lite_train.yml index 403be669..9978a21e 100644 --- a/configs/rec/multi_language/rec_german_lite_train.yml +++ b/configs/rec/multi_language/rec_german_lite_train.yml @@ -15,7 +15,7 @@ Global: use_visualdl: False infer_img: # for data or label process - character_dict_path: ppocr/utils/german_dict.txt + character_dict_path: ppocr/utils/dict/german_dict.txt character_type: german max_text_length: 25 infer_mode: False diff --git a/configs/rec/multi_language/rec_japan_lite_train.yml b/configs/rec/multi_language/rec_japan_lite_train.yml index 5ff61c01..938d377e 100644 --- a/configs/rec/multi_language/rec_japan_lite_train.yml +++ b/configs/rec/multi_language/rec_japan_lite_train.yml @@ -15,7 +15,7 @@ Global: use_visualdl: False infer_img: # for data or label process - character_dict_path: ppocr/utils/japan_dict.txt + character_dict_path: ppocr/utils/dict/japan_dict.txt character_type: japan max_text_length: 25 infer_mode: False diff --git a/configs/rec/multi_language/rec_korean_lite_train.yml b/configs/rec/multi_language/rec_korean_lite_train.yml index 2b2211ef..7b070c44 100644 --- a/configs/rec/multi_language/rec_korean_lite_train.yml +++ b/configs/rec/multi_language/rec_korean_lite_train.yml @@ -15,7 +15,7 @@ Global: use_visualdl: False infer_img: # for data or label process - character_dict_path: ppocr/utils/korean_dict.txt + character_dict_path: ppocr/utils/dict/korean_dict.txt character_type: korean max_text_length: 25 infer_mode: False diff --git a/doc/doc_ch/whl.md b/doc/doc_ch/whl.md index 1b04a9a8..c51f3277 100644 --- a/doc/doc_ch/whl.md +++ b/doc/doc_ch/whl.md @@ -261,6 +261,61 @@ im_show.save('result.jpg') paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true ``` +### 使用网络图片或者numpy数组作为输入 + +1. 网络图片 + +代码使用 +```python +from paddleocr import PaddleOCR, draw_ocr +# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换 +# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。 +ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory +img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg' +result = ocr.ocr(img_path, cls=True) +for line in result: + print(line) + +# 显示结果 +from PIL import Image +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` +命令行模式 +```bash +paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true +``` + +2. numpy数组 +仅通过代码使用时支持numpy数组作为输入 +```python +from paddleocr import PaddleOCR, draw_ocr +# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换 +# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。 +ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory +img_path = 'PaddleOCR/doc/imgs/11.jpg' +img = cv2.imread(img_path) +# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), 如果你自己训练的模型支持灰度图,可以将这句话的注释取消 +result = ocr.ocr(img_path, cls=True) +for line in result: + print(line) + +# 显示结果 +from PIL import Image +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` + ## 参数说明 | 字段 | 说明 | 默认值 | @@ -285,6 +340,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ | max_text_length | 识别算法能识别的最大文字长度 | 25 | | rec_char_dict_path | 识别模型字典路径,当rec_model_dir使用方式2传参时需要修改为自己的字典路径 | ./ppocr/utils/ppocr_keys_v1.txt | | use_space_char | 是否识别空格 | TRUE | +| drop_score | 对输出按照分数(来自于识别模型)进行过滤,低于此分数的不返回 | 0.5 | | use_angle_cls | 是否加载分类模型 | FALSE | | cls_model_dir | 分类模型所在文件夹。传参方式有两种,1. None: 自动下载内置模型到 `~/.paddleocr/cls`;2.自己转换好的inference模型路径,模型路径下必须包含model和params文件 | None | | cls_image_shape | 分类算法的输入图片尺寸 | "3, 48, 192" | @@ -295,4 +351,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ | lang | 模型语言类型,目前支持 中文(ch)和英文(en) | ch | | det | 前向时使用启动检测 | TRUE | | rec | 前向时是否启动识别 | TRUE | -| cls | 前向时是否启动分类 | FALSE | +| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE | diff --git a/doc/doc_en/whl_en.md b/doc/doc_en/whl_en.md index ffbced34..c25999d4 100644 --- a/doc/doc_en/whl_en.md +++ b/doc/doc_en/whl_en.md @@ -271,6 +271,59 @@ im_show.save('result.jpg') paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true ``` +### Use web images or numpy array as input + +1. Web image + +Use by code +```python +from paddleocr import PaddleOCR, draw_ocr +ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory +img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg' +result = ocr.ocr(img_path, cls=True) +for line in result: + print(line) + +# show result +from PIL import Image +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` +Use by command line +```bash +paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true +``` + +2. Numpy array +Support numpy array as input only when used by code + +```python +from paddleocr import PaddleOCR, draw_ocr +ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory +img_path = 'PaddleOCR/doc/imgs/11.jpg' +img = cv2.imread(img_path) +# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), If your own training model supports grayscale images, you can uncomment this line +result = ocr.ocr(img_path, cls=True) +for line in result: + print(line) + +# show result +from PIL import Image +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` + + ## Parameter Description | Parameter | Description | Default value | @@ -295,6 +348,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ | max_text_length | The maximum text length that the recognition algorithm can recognize | 25 | | rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt | | use_space_char | Whether to recognize spaces | TRUE | +| drop_score | Filter the output by score (from the recognition model), and those below this score will not be returned | 0.5 | | use_angle_cls | Whether to load classification model | FALSE | | cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None | | cls_image_shape | image shape of classification algorithm | "3,48,192" | @@ -305,4 +359,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ | lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch | | det | Enable detction when `ppocr.ocr` func exec | TRUE | | rec | Enable recognition when `ppocr.ocr` func exec | TRUE | -| cls | Enable classification when `ppocr.ocr` func exec | FALSE | +| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE | diff --git a/paddleocr.py b/paddleocr.py index d3d73cb1..17306e79 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -26,17 +26,50 @@ import requests from tqdm import tqdm from tools.infer import predict_system -from ppocr.utils.utility import initial_logger +from ppocr.utils.logging import get_logger -logger = initial_logger() +logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list __all__ = ['PaddleOCR'] -model_params = { - 'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar', - 'rec': - 'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar', +model_urls = { + 'det': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar', + 'rec': { + 'ch': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/ppocr_keys_v1.txt' + }, + 'en': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/en/en_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/ic15_dict.txt' + }, + 'french': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/fr/french_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/french_dict.txt' + }, + 'german': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/ge/german_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/german_dict.txt' + }, + 'korean': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/kr/korean_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/korean_dict.txt' + }, + 'japan': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/jp/japan_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/japan_dict.txt' + } + }, + 'cls': + 'https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar' } SUPPORT_DET_MODEL = ['DB'] @@ -54,8 +87,8 @@ def download_with_progressbar(url, save_path): progress_bar.update(len(data)) file.write(data) progress_bar.close() - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: - logger.error("ERROR, something went wrong") + if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes: + logger.error("Something went wrong while downloading models") sys.exit(0) @@ -63,7 +96,7 @@ def maybe_download(model_storage_directory, url): # using custom model if not os.path.exists(os.path.join( model_storage_directory, 'model')) or not os.path.exists( - os.path.join(model_storage_directory, 'params')): + os.path.join(model_storage_directory, 'params')): tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) print('download {} to {}'.format(url, tmp_path)) os.makedirs(model_storage_directory, exist_ok=True) @@ -84,53 +117,102 @@ def maybe_download(model_storage_directory, url): os.remove(tmp_path) -def parse_args(): +def parse_args(mMain=True, add_help=True): import argparse def str2bool(v): return v.lower() in ("true", "t", "1") - parser = argparse.ArgumentParser() - # params for prediction engine - parser.add_argument("--use_gpu", type=str2bool, default=True) - parser.add_argument("--ir_optim", type=str2bool, default=True) - parser.add_argument("--use_tensorrt", type=str2bool, default=False) - parser.add_argument("--gpu_mem", type=int, default=8000) - - # params for text detector - parser.add_argument("--image_dir", type=str) - parser.add_argument("--det_algorithm", type=str, default='DB') - parser.add_argument("--det_model_dir", type=str, default=None) - parser.add_argument("--det_max_side_len", type=float, default=960) - - # DB parmas - parser.add_argument("--det_db_thresh", type=float, default=0.3) - parser.add_argument("--det_db_box_thresh", type=float, default=0.5) - parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) - - # EAST parmas - parser.add_argument("--det_east_score_thresh", type=float, default=0.8) - parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) - parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) - - # params for text recognizer - parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument("--rec_model_dir", type=str, default=None) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') - parser.add_argument("--rec_batch_num", type=int, default=30) - parser.add_argument("--max_text_length", type=int, default=25) - parser.add_argument( - "--rec_char_dict_path", - type=str, - default="./ppocr/utils/ppocr_keys_v1.txt") - parser.add_argument("--use_space_char", type=bool, default=True) - parser.add_argument("--enable_mkldnn", type=bool, default=False) - - parser.add_argument("--det", type=str2bool, default=True) - parser.add_argument("--rec", type=str2bool, default=True) - parser.add_argument("--use_zero_copy_run", type=bool, default=False) - return parser.parse_args() + if mMain: + parser = argparse.ArgumentParser(add_help=add_help) + # params for prediction engine + parser.add_argument("--use_gpu", type=str2bool, default=True) + parser.add_argument("--ir_optim", type=str2bool, default=True) + parser.add_argument("--use_tensorrt", type=str2bool, default=False) + parser.add_argument("--gpu_mem", type=int, default=8000) + + # params for text detector + parser.add_argument("--image_dir", type=str) + parser.add_argument("--det_algorithm", type=str, default='DB') + parser.add_argument("--det_model_dir", type=str, default=None) + parser.add_argument("--det_limit_side_len", type=float, default=960) + parser.add_argument("--det_limit_type", type=str, default='max') + + # DB parmas + parser.add_argument("--det_db_thresh", type=float, default=0.3) + parser.add_argument("--det_db_box_thresh", type=float, default=0.5) + parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) + + # EAST parmas + parser.add_argument("--det_east_score_thresh", type=float, default=0.8) + parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) + parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) + + # params for text recognizer + parser.add_argument("--rec_algorithm", type=str, default='CRNN') + parser.add_argument("--rec_model_dir", type=str, default=None) + parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") + parser.add_argument("--rec_char_type", type=str, default='ch') + parser.add_argument("--rec_batch_num", type=int, default=30) + parser.add_argument("--max_text_length", type=int, default=25) + parser.add_argument("--rec_char_dict_path", type=str, default=None) + parser.add_argument("--use_space_char", type=bool, default=True) + parser.add_argument("--drop_score", type=float, default=0.5) + + # params for text classifier + parser.add_argument("--cls_model_dir", type=str, default=None) + parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") + parser.add_argument("--label_list", type=list, default=['0', '180']) + parser.add_argument("--cls_batch_num", type=int, default=30) + parser.add_argument("--cls_thresh", type=float, default=0.9) + + parser.add_argument("--enable_mkldnn", type=bool, default=False) + parser.add_argument("--use_zero_copy_run", type=bool, default=False) + parser.add_argument("--use_pdserving", type=str2bool, default=False) + + parser.add_argument("--lang", type=str, default='ch') + parser.add_argument("--det", type=str2bool, default=True) + parser.add_argument("--rec", type=str2bool, default=True) + parser.add_argument("--use_angle_cls", type=str2bool, default=False) + return parser.parse_args() + else: + return argparse.Namespace(use_gpu=True, + ir_optim=True, + use_tensorrt=False, + gpu_mem=8000, + image_dir='', + det_algorithm='DB', + det_model_dir=None, + det_limit_side_len=960, + det_limit_type='max', + det_db_thresh=0.3, + det_db_box_thresh=0.5, + det_db_unclip_ratio=2.0, + det_east_score_thresh=0.8, + det_east_cover_thresh=0.1, + det_east_nms_thresh=0.2, + rec_algorithm='CRNN', + rec_model_dir=None, + rec_image_shape="3, 32, 320", + rec_char_type='ch', + rec_batch_num=30, + max_text_length=25, + rec_char_dict_path=None, + use_space_char=True, + drop_score=0.5, + cls_model_dir=None, + cls_image_shape="3, 48, 192", + label_list=['0', '180'], + cls_batch_num=30, + cls_thresh=0.9, + enable_mkldnn=False, + use_zero_copy_run=False, + use_pdserving=False, + lang='ch', + det=True, + rec=True, + use_angle_cls=False + ) class PaddleOCR(predict_system.TextSystem): @@ -140,18 +222,31 @@ class PaddleOCR(predict_system.TextSystem): args: **kwargs: other params show in paddleocr --help """ - postprocess_params = parse_args() + postprocess_params = parse_args(mMain=False, add_help=False) postprocess_params.__dict__.update(**kwargs) + self.use_angle_cls = postprocess_params.use_angle_cls + lang = postprocess_params.lang + assert lang in model_urls[ + 'rec'], 'param lang must in {}, but got {}'.format( + model_urls['rec'].keys(), lang) + if postprocess_params.rec_char_dict_path is None: + postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ + 'dict_path'] # init model dir if postprocess_params.det_model_dir is None: postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det') if postprocess_params.rec_model_dir is None: - postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec') + postprocess_params.rec_model_dir = os.path.join( + BASE_DIR, 'rec/{}'.format(lang)) + if postprocess_params.cls_model_dir is None: + postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls') print(postprocess_params) # download model - maybe_download(postprocess_params.det_model_dir, model_params['det']) - maybe_download(postprocess_params.rec_model_dir, model_params['rec']) + maybe_download(postprocess_params.det_model_dir, model_urls['det']) + maybe_download(postprocess_params.rec_model_dir, + model_urls['rec'][lang]['url']) + maybe_download(postprocess_params.cls_model_dir, model_urls['cls']) if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) @@ -166,7 +261,7 @@ class PaddleOCR(predict_system.TextSystem): # init det_model and rec_model super().__init__(postprocess_params) - def ocr(self, img, det=True, rec=True): + def ocr(self, img, det=True, rec=True, cls=False): """ ocr with paddleocr args: @@ -175,7 +270,16 @@ class PaddleOCR(predict_system.TextSystem): rec: use text recognition or not, if false, only det will be exec. default is True """ assert isinstance(img, (np.ndarray, list, str)) + if isinstance(img, list) and det == True: + logger.error('When input a list of images, det must be false') + exit(0) + + self.use_angle_cls = cls if isinstance(img, str): + # download net image + if img.startswith('http'): + download_with_progressbar(img, 'tmp.jpg') + img = 'tmp.jpg' image_file = img img, flag = check_and_read_gif(image_file) if not flag: @@ -183,6 +287,8 @@ class PaddleOCR(predict_system.TextSystem): if img is None: logger.error("error in loading image:{}".format(image_file)) return None + if isinstance(img, np.ndarray) and len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if det and rec: dt_boxes, rec_res = self.__call__(img) return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] @@ -194,20 +300,34 @@ class PaddleOCR(predict_system.TextSystem): else: if not isinstance(img, list): img = [img] + if self.use_angle_cls: + img, cls_res, elapse = self.text_classifier(img) + if not rec: + return cls_res rec_res, elapse = self.text_recognizer(img) return rec_res def main(): - # for com - args = parse_args() - image_file_list = get_image_file_list(args.image_dir) + # for cmd + args = parse_args(mMain=True) + image_dir = args.image_dir + if image_dir.startswith('http'): + download_with_progressbar(image_dir, 'tmp.jpg') + image_file_list = ['tmp.jpg'] + else: + image_file_list = get_image_file_list(args.image_dir) if len(image_file_list) == 0: logger.error('no images find in {}'.format(args.image_dir)) return - ocr_engine = PaddleOCR() + + ocr_engine = PaddleOCR(**(args.__dict__)) for img_path in image_file_list: - print(img_path) - result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec) - for line in result: - print(line) \ No newline at end of file + logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10)) + result = ocr_engine.ocr(img_path, + det=args.det, + rec=args.rec, + cls=args.use_angle_cls) + if result is not None: + for line in result: + logger.info(line) diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index fd143b53..6ea4dd8e 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -26,6 +26,9 @@ from .randaugment import RandAugment from .operators import * from .label_ops import * +from .east_process import * +from .sast_process import * + def transform(data, ops=None): """ transform """ diff --git a/ppocr/data/imaug/east_process.py b/ppocr/data/imaug/east_process.py new file mode 100644 index 00000000..b1d7a5e5 --- /dev/null +++ b/ppocr/data/imaug/east_process.py @@ -0,0 +1,439 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import math +import cv2 +import numpy as np +import json +import sys +import os + +__all__ = ['EASTProcessTrain'] + + +class EASTProcessTrain(object): + def __init__(self, + image_shape = [512, 512], + background_ratio = 0.125, + min_crop_side_ratio = 0.1, + min_text_size = 10, + **kwargs): + self.input_size = image_shape[1] + self.random_scale = np.array([0.5, 1, 2.0, 3.0]) + self.background_ratio = background_ratio + self.min_crop_side_ratio = min_crop_side_ratio + self.min_text_size = min_text_size + + def preprocess(self, im): + input_size = self.input_size + im_shape = im.shape + im_size_min = np.min(im_shape[0:2]) + im_size_max = np.max(im_shape[0:2]) + im_scale = float(input_size) / float(im_size_max) + im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale) + img_mean = [0.485, 0.456, 0.406] + img_std = [0.229, 0.224, 0.225] + # im = im[:, :, ::-1].astype(np.float32) + im = im / 255 + im -= img_mean + im /= img_std + new_h, new_w, _ = im.shape + im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32) + im_padded[:new_h, :new_w, :] = im + im_padded = im_padded.transpose((2, 0, 1)) + im_padded = im_padded[np.newaxis, :] + return im_padded, im_scale + + def rotate_im_poly(self, im, text_polys): + """ + rotate image with 90 / 180 / 270 degre + """ + im_w, im_h = im.shape[1], im.shape[0] + dst_im = im.copy() + dst_polys = [] + rand_degree_ratio = np.random.rand() + rand_degree_cnt = 1 + if 0.333 < rand_degree_ratio < 0.666: + rand_degree_cnt = 2 + elif rand_degree_ratio > 0.666: + rand_degree_cnt = 3 + for i in range(rand_degree_cnt): + dst_im = np.rot90(dst_im) + rot_degree = -90 * rand_degree_cnt + rot_angle = rot_degree * math.pi / 180.0 + n_poly = text_polys.shape[0] + cx, cy = 0.5 * im_w, 0.5 * im_h + ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0] + for i in range(n_poly): + wordBB = text_polys[i] + poly = [] + for j in range(4): + sx, sy = wordBB[j][0], wordBB[j][1] + dx = math.cos(rot_angle) * (sx - cx)\ + - math.sin(rot_angle) * (sy - cy) + ncx + dy = math.sin(rot_angle) * (sx - cx)\ + + math.cos(rot_angle) * (sy - cy) + ncy + poly.append([dx, dy]) + dst_polys.append(poly) + dst_polys = np.array(dst_polys, dtype=np.float32) + return dst_im, dst_polys + + def polygon_area(self, poly): + """ + compute area of a polygon + :param poly: + :return: + """ + edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), + (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), + (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), + (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])] + return np.sum(edge) / 2. + + def check_and_validate_polys(self, polys, tags, img_height, img_width): + """ + check so that the text poly is in the same direction, + and also filter some invalid polygons + :param polys: + :param tags: + :return: + """ + h, w = img_height, img_width + if polys.shape[0] == 0: + return polys + polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) + polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) + + validated_polys = [] + validated_tags = [] + for poly, tag in zip(polys, tags): + p_area = self.polygon_area(poly) + #invalid poly + if abs(p_area) < 1: + continue + if p_area > 0: + #'poly in wrong direction' + if not tag: + tag = True #reversed cases should be ignore + poly = poly[(0, 3, 2, 1), :] + validated_polys.append(poly) + validated_tags.append(tag) + return np.array(validated_polys), np.array(validated_tags) + + def draw_img_polys(self, img, polys): + if len(img.shape) == 4: + img = np.squeeze(img, axis=0) + if img.shape[0] == 3: + img = img.transpose((1, 2, 0)) + img[:, :, 2] += 123.68 + img[:, :, 1] += 116.78 + img[:, :, 0] += 103.94 + cv2.imwrite("tmp.jpg", img) + img = cv2.imread("tmp.jpg") + for box in polys: + box = box.astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2) + import random + ino = random.randint(0, 100) + cv2.imwrite("tmp_%d.jpg" % ino, img) + return + + def shrink_poly(self, poly, r): + """ + fit a poly inside the origin poly, maybe bugs here... + used for generate the score map + :param poly: the text poly + :param r: r in the paper + :return: the shrinked poly + """ + # shrink ratio + R = 0.3 + # find the longer pair + dist0 = np.linalg.norm(poly[0] - poly[1]) + dist1 = np.linalg.norm(poly[2] - poly[3]) + dist2 = np.linalg.norm(poly[0] - poly[3]) + dist3 = np.linalg.norm(poly[1] - poly[2]) + if dist0 + dist1 > dist2 + dist3: + # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2) + ## p0, p1 + theta = np.arctan2((poly[1][1] - poly[0][1]), + (poly[1][0] - poly[0][0])) + poly[0][0] += R * r[0] * np.cos(theta) + poly[0][1] += R * r[0] * np.sin(theta) + poly[1][0] -= R * r[1] * np.cos(theta) + poly[1][1] -= R * r[1] * np.sin(theta) + ## p2, p3 + theta = np.arctan2((poly[2][1] - poly[3][1]), + (poly[2][0] - poly[3][0])) + poly[3][0] += R * r[3] * np.cos(theta) + poly[3][1] += R * r[3] * np.sin(theta) + poly[2][0] -= R * r[2] * np.cos(theta) + poly[2][1] -= R * r[2] * np.sin(theta) + ## p0, p3 + theta = np.arctan2((poly[3][0] - poly[0][0]), + (poly[3][1] - poly[0][1])) + poly[0][0] += R * r[0] * np.sin(theta) + poly[0][1] += R * r[0] * np.cos(theta) + poly[3][0] -= R * r[3] * np.sin(theta) + poly[3][1] -= R * r[3] * np.cos(theta) + ## p1, p2 + theta = np.arctan2((poly[2][0] - poly[1][0]), + (poly[2][1] - poly[1][1])) + poly[1][0] += R * r[1] * np.sin(theta) + poly[1][1] += R * r[1] * np.cos(theta) + poly[2][0] -= R * r[2] * np.sin(theta) + poly[2][1] -= R * r[2] * np.cos(theta) + else: + ## p0, p3 + # print poly + theta = np.arctan2((poly[3][0] - poly[0][0]), + (poly[3][1] - poly[0][1])) + poly[0][0] += R * r[0] * np.sin(theta) + poly[0][1] += R * r[0] * np.cos(theta) + poly[3][0] -= R * r[3] * np.sin(theta) + poly[3][1] -= R * r[3] * np.cos(theta) + ## p1, p2 + theta = np.arctan2((poly[2][0] - poly[1][0]), + (poly[2][1] - poly[1][1])) + poly[1][0] += R * r[1] * np.sin(theta) + poly[1][1] += R * r[1] * np.cos(theta) + poly[2][0] -= R * r[2] * np.sin(theta) + poly[2][1] -= R * r[2] * np.cos(theta) + ## p0, p1 + theta = np.arctan2((poly[1][1] - poly[0][1]), + (poly[1][0] - poly[0][0])) + poly[0][0] += R * r[0] * np.cos(theta) + poly[0][1] += R * r[0] * np.sin(theta) + poly[1][0] -= R * r[1] * np.cos(theta) + poly[1][1] -= R * r[1] * np.sin(theta) + ## p2, p3 + theta = np.arctan2((poly[2][1] - poly[3][1]), + (poly[2][0] - poly[3][0])) + poly[3][0] += R * r[3] * np.cos(theta) + poly[3][1] += R * r[3] * np.sin(theta) + poly[2][0] -= R * r[2] * np.cos(theta) + poly[2][1] -= R * r[2] * np.sin(theta) + return poly + + def generate_quad(self, im_size, polys, tags): + """ + Generate quadrangle. + """ + h, w = im_size + poly_mask = np.zeros((h, w), dtype=np.uint8) + score_map = np.zeros((h, w), dtype=np.uint8) + # (x1, y1, ..., x4, y4, short_edge_norm) + geo_map = np.zeros((h, w, 9), dtype=np.float32) + # mask used during traning, to ignore some hard areas + training_mask = np.ones((h, w), dtype=np.uint8) + for poly_idx, poly_tag in enumerate(zip(polys, tags)): + poly = poly_tag[0] + tag = poly_tag[1] + + r = [None, None, None, None] + for i in range(4): + dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4]) + dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4]) + r[i] = min(dist1, dist2) + # score map + shrinked_poly = self.shrink_poly( + poly.copy(), r).astype(np.int32)[np.newaxis, :, :] + cv2.fillPoly(score_map, shrinked_poly, 1) + cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1) + # if the poly is too small, then ignore it during training + poly_h = min( + np.linalg.norm(poly[0] - poly[3]), + np.linalg.norm(poly[1] - poly[2])) + poly_w = min( + np.linalg.norm(poly[0] - poly[1]), + np.linalg.norm(poly[2] - poly[3])) + if min(poly_h, poly_w) < self.min_text_size: + cv2.fillPoly(training_mask, + poly.astype(np.int32)[np.newaxis, :, :], 0) + + if tag: + cv2.fillPoly(training_mask, + poly.astype(np.int32)[np.newaxis, :, :], 0) + + xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1)) + # geo map. + y_in_poly = xy_in_poly[:, 0] + x_in_poly = xy_in_poly[:, 1] + poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w) + poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h) + for pno in range(4): + geo_channel_beg = pno * 2 + geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\ + x_in_poly - poly[pno, 0] + geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\ + y_in_poly - poly[pno, 1] + geo_map[y_in_poly, x_in_poly, 8] = \ + 1.0 / max(min(poly_h, poly_w), 1.0) + return score_map, geo_map, training_mask + + def crop_area(self, + im, + polys, + tags, + crop_background=False, + max_tries=50): + """ + make random crop from the input image + :param im: + :param polys: + :param tags: + :param crop_background: + :param max_tries: + :return: + """ + h, w, _ = im.shape + pad_h = h // 10 + pad_w = w // 10 + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w:maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h:maxy + pad_h] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return im, polys, tags + + for i in range(max_tries): + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if xmax - xmin < self.min_crop_side_ratio * w or \ + ymax - ymin < self.min_crop_side_ratio * h: + # area too small + continue + if polys.shape[0] != 0: + poly_axis_in_area = (polys[:, :, 0] >= xmin)\ + & (polys[:, :, 0] <= xmax)\ + & (polys[:, :, 1] >= ymin)\ + & (polys[:, :, 1] <= ymax) + selected_polys = np.where( + np.sum(poly_axis_in_area, axis=1) == 4)[0] + else: + selected_polys = [] + + if len(selected_polys) == 0: + # no text in this area + if crop_background: + im = im[ymin:ymax + 1, xmin:xmax + 1, :] + polys = [] + tags = [] + return im, polys, tags + else: + continue + + im = im[ymin:ymax + 1, xmin:xmax + 1, :] + polys = polys[selected_polys] + tags = tags[selected_polys] + polys[:, :, 0] -= xmin + polys[:, :, 1] -= ymin + return im, polys, tags + return im, polys, tags + + def crop_background_infor(self, im, text_polys, text_tags): + im, text_polys, text_tags = self.crop_area( + im, text_polys, text_tags, crop_background=True) + + if len(text_polys) > 0: + return None + # pad and resize image + input_size = self.input_size + im, ratio = self.preprocess(im) + score_map = np.zeros((input_size, input_size), dtype=np.float32) + geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32) + training_mask = np.ones((input_size, input_size), dtype=np.float32) + return im, score_map, geo_map, training_mask + + def crop_foreground_infor(self, im, text_polys, text_tags): + im, text_polys, text_tags = self.crop_area( + im, text_polys, text_tags, crop_background=False) + + if text_polys.shape[0] == 0: + return None + #continue for all ignore case + if np.sum((text_tags * 1.0)) >= text_tags.size: + return None + # pad and resize image + input_size = self.input_size + im, ratio = self.preprocess(im) + text_polys[:, :, 0] *= ratio + text_polys[:, :, 1] *= ratio + _, _, new_h, new_w = im.shape + # print(im.shape) + # self.draw_img_polys(im, text_polys) + score_map, geo_map, training_mask = self.generate_quad( + (new_h, new_w), text_polys, text_tags) + return im, score_map, geo_map, training_mask + + def __call__(self, data): + im = data['image'] + text_polys = data['polys'] + text_tags = data['ignore_tags'] + if im is None: + return None + if text_polys.shape[0] == 0: + return None + + #add rotate cases + if np.random.rand() < 0.5: + im, text_polys = self.rotate_im_poly(im, text_polys) + h, w, _ = im.shape + text_polys, text_tags = self.check_and_validate_polys(text_polys, + text_tags, h, w) + if text_polys.shape[0] == 0: + return None + + # random scale this image + rd_scale = np.random.choice(self.random_scale) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + if np.random.rand() < self.background_ratio: + outs = self.crop_background_infor(im, text_polys, text_tags) + else: + outs = self.crop_foreground_infor(im, text_polys, text_tags) + + if outs is None: + return None + im, score_map, geo_map, training_mask = outs + score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32) + geo_map = np.swapaxes(geo_map, 1, 2) + geo_map = np.swapaxes(geo_map, 1, 0) + geo_map = geo_map[:, ::4, ::4].astype(np.float32) + training_mask = training_mask[np.newaxis, ::4, ::4] + training_mask = training_mask.astype(np.float32) + + data['image'] = im[0] + data['score_map'] = score_map + data['geo_map'] = geo_map + data['training_mask'] = training_mask + # print(im.shape, score_map.shape, geo_map.shape, training_mask.shape) + return data \ No newline at end of file diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index f3c90050..9b99d2ec 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -52,6 +52,7 @@ class DetLabelEncode(object): txt_tags.append(True) else: txt_tags.append(False) + boxes = self.expand_points_num(boxes) boxes = np.array(boxes, dtype=np.float32) txt_tags = np.array(txt_tags, dtype=np.bool) @@ -70,6 +71,17 @@ class DetLabelEncode(object): rect[3] = pts[np.argmax(diff)] return rect + def expand_points_num(self, boxes): + max_points_num = 0 + for box in boxes: + if len(box) > max_points_num: + max_points_num = len(box) + ex_boxes = [] + for box in boxes: + ex_box = box + [box[-1]] * (max_points_num - len(box)) + ex_boxes.append(ex_box) + return ex_boxes + class BaseRecLabelEncode(object): """ Convert between text-label and text-index """ @@ -79,7 +91,9 @@ class BaseRecLabelEncode(object): character_dict_path=None, character_type='ch', use_space_char=False): - support_character_type = ['ch', 'en', 'en_sensitive'] + support_character_type = [ + 'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean' + ] assert character_type in support_character_type, "Only {} are supported now but get {}".format( support_character_type, self.character_str) @@ -87,7 +101,7 @@ class BaseRecLabelEncode(object): if character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) - elif character_type == "ch": + elif character_type in ["ch", "french", "german", "japan", "korean"]: self.character_str = "" assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch" with open(character_dict_path, "rb") as fin: diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 74b60de4..c1352927 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -120,26 +120,37 @@ class DetResizeForTest(object): if 'limit_side_len' in kwargs: self.limit_side_len = kwargs['limit_side_len'] self.limit_type = kwargs.get('limit_type', 'min') + if 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) else: self.limit_side_len = 736 self.limit_type = 'min' def __call__(self, data): img = data['image'] + src_h, src_w, _ = img.shape if self.resize_type == 0: - img, shape = self.resize_image_type0(img) + # img, shape = self.resize_image_type0(img) + img, [ratio_h, ratio_w] = self.resize_image_type0(img) + elif self.resize_type == 2: + img, [ratio_h, ratio_w] = self.resize_image_type2(img) else: - img, shape = self.resize_image_type1(img) + # img, shape = self.resize_image_type1(img) + img, [ratio_h, ratio_w] = self.resize_image_type1(img) data['image'] = img - data['shape'] = shape + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) return data def resize_image_type1(self, img): resize_h, resize_w = self.image_shape ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w img = cv2.resize(img, (int(resize_w), int(resize_h))) - return img, np.array([ori_h, ori_w]) + # return img, np.array([ori_h, ori_w]) + return img, [ratio_h, ratio_w] def resize_image_type0(self, img): """ @@ -182,4 +193,31 @@ class DetResizeForTest(object): except: print(img.shape, resize_w, resize_h) sys.exit(0) - return img, np.array([h, w]) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + # return img, np.array([h, w]) + return img, [ratio_h, ratio_w] + + def resize_image_type2(self, img): + h, w, _ = img.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(self.resize_long) / resize_h + else: + ratio = float(self.resize_long) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + img = cv2.resize(img, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return img, [ratio_h, ratio_w] diff --git a/ppocr/data/imaug/sast_process.py b/ppocr/data/imaug/sast_process.py new file mode 100644 index 00000000..b8d6ff89 --- /dev/null +++ b/ppocr/data/imaug/sast_process.py @@ -0,0 +1,689 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import math +import cv2 +import numpy as np +import json +import sys +import os + +__all__ = ['SASTProcessTrain'] + + +class SASTProcessTrain(object): + def __init__(self, + image_shape = [512, 512], + min_crop_size = 24, + min_crop_side_ratio = 0.3, + min_text_size = 10, + max_text_size = 512, + **kwargs): + self.input_size = image_shape[1] + self.min_crop_size = min_crop_size + self.min_crop_side_ratio = min_crop_side_ratio + self.min_text_size = min_text_size + self.max_text_size = max_text_size + + def quad_area(self, poly): + """ + compute area of a polygon + :param poly: + :return: + """ + edge = [ + (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), + (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), + (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), + (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]) + ] + return np.sum(edge) / 2. + + def gen_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + if True: + rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation) + center_point = rect[0] + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + if dist < min_dist: + min_dist = dist + first_point_idx = i + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad + + def check_and_validate_polys(self, polys, tags, xxx_todo_changeme): + """ + check so that the text poly is in the same direction, + and also filter some invalid polygons + :param polys: + :param tags: + :return: + """ + (h, w) = xxx_todo_changeme + if polys.shape[0] == 0: + return polys, np.array([]), np.array([]) + polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) + polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) + + validated_polys = [] + validated_tags = [] + hv_tags = [] + for poly, tag in zip(polys, tags): + quad = self.gen_quad_from_poly(poly) + p_area = self.quad_area(quad) + if abs(p_area) < 1: + print('invalid poly') + continue + if p_area > 0: + if tag == False: + print('poly in wrong direction') + tag = True # reversed cases should be ignore + poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :] + quad = quad[(0, 3, 2, 1), :] + + len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - quad[2]) + len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]) + hv_tag = 1 + + if len_w * 2.0 < len_h: + hv_tag = 0 + + validated_polys.append(poly) + validated_tags.append(tag) + hv_tags.append(hv_tag) + return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags) + + def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25): + """ + make random crop from the input image + :param im: + :param polys: + :param tags: + :param crop_background: + :param max_tries: 50 -> 25 + :return: + """ + h, w, _ = im.shape + pad_h = h // 10 + pad_w = w // 10 + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w: maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h: maxy + pad_h] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return im, polys, tags, hv_tags + for i in range(max_tries): + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + # if xmax - xmin < ARGS.min_crop_side_ratio * w or \ + # ymax - ymin < ARGS.min_crop_side_ratio * h: + if xmax - xmin < self.min_crop_size or \ + ymax - ymin < self.min_crop_size: + # area too small + continue + if polys.shape[0] != 0: + poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \ + & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax) + selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] + else: + selected_polys = [] + if len(selected_polys) == 0: + # no text in this area + if crop_background: + return im[ymin : ymax + 1, xmin : xmax + 1, :], \ + polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts + else: + continue + im = im[ymin: ymax + 1, xmin: xmax + 1, :] + polys = polys[selected_polys] + tags = tags[selected_polys] + hv_tags = hv_tags[selected_polys] + polys[:, :, 0] -= xmin + polys[:, :, 1] -= ymin + return im, polys, tags, hv_tags + + return im, polys, tags, hv_tags + + def generate_direction_map(self, poly_quads, direction_map): + """ + """ + width_list = [] + height_list = [] + for quad in poly_quads: + quad_w = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0 + quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0 + width_list.append(quad_w) + height_list.append(quad_h) + norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0) + average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0) + + for quad in poly_quads: + direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0 + direct_vector = direct_vector_full / (np.linalg.norm(direct_vector_full) + 1e-6) * norm_width + direction_label = tuple(map(float, [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)])) + cv2.fillPoly(direction_map, quad.round().astype(np.int32)[np.newaxis, :, :], direction_label) + return direction_map + + def calculate_average_height(self, poly_quads): + """ + """ + height_list = [] + for quad in poly_quads: + quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0 + height_list.append(quad_h) + average_height = max(sum(height_list) / len(height_list), 1.0) + return average_height + + def generate_tcl_label(self, hw, polys, tags, ds_ratio, + tcl_ratio=0.3, shrink_ratio_of_width=0.15): + """ + Generate polygon. + """ + h, w = hw + h, w = int(h * ds_ratio), int(w * ds_ratio) + polys = polys * ds_ratio + + score_map = np.zeros((h, w,), dtype=np.float32) + tbo_map = np.zeros((h, w, 5), dtype=np.float32) + training_mask = np.ones((h, w,), dtype=np.float32) + direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape([1, 1, 3]).astype(np.float32) + + for poly_idx, poly_tag in enumerate(zip(polys, tags)): + poly = poly_tag[0] + tag = poly_tag[1] + + # generate min_area_quad + min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) + min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2])) + min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3])) + + if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \ + or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio: + continue + + if tag: + # continue + cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15) + else: + tcl_poly = self.poly2tcl(poly, tcl_ratio) + tcl_quads = self.poly2quads(tcl_poly) + poly_quads = self.poly2quads(poly) + # stcl map + stcl_quads, quad_index = self.shrink_poly_along_width(tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width, + expand_height_ratio=1.0 / tcl_ratio) + # generate tcl map + cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0) + + # generate tbo map + for idx, quad in enumerate(stcl_quads): + quad_mask = np.zeros((h, w), dtype=np.float32) + quad_mask = cv2.fillPoly(quad_mask, np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0) + tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], quad_mask, tbo_map) + return score_map, tbo_map, training_mask + + def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25): + """ + Generate tcl map, tvo map and tbo map. + """ + h, w = hw + h, w = int(h * ds_ratio), int(w * ds_ratio) + polys = polys * ds_ratio + poly_mask = np.zeros((h, w), dtype=np.float32) + + tvo_map = np.ones((9, h, w), dtype=np.float32) + tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1)) + tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T + poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32) + + # tco map + tco_map = np.ones((3, h, w), dtype=np.float32) + tco_map[0] = np.tile(np.arange(0, w), (h, 1)) + tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T + poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32) + + poly_short_edge_map = np.ones((h, w), dtype=np.float32) + + for poly, poly_tag in zip(polys, tags): + + if poly_tag == True: + continue + + # adjust point order for vertical poly + poly = self.adjust_point(poly) + + # generate min_area_quad + min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) + min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2])) + min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3])) + + # generate tcl map and text, 128 * 128 + tcl_poly = self.poly2tcl(poly, tcl_ratio) + + # generate poly_tv_xy_map + for idx in range(4): + cv2.fillPoly(poly_tv_xy_map[2 * idx], + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(min(max(min_area_quad[idx, 0], 0), w))) + cv2.fillPoly(poly_tv_xy_map[2 * idx + 1], + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(min(max(min_area_quad[idx, 1], 0), h))) + + # generate poly_tc_xy_map + for idx in range(2): + cv2.fillPoly(poly_tc_xy_map[idx], + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), float(center_point[idx])) + + # generate poly_short_edge_map + cv2.fillPoly(poly_short_edge_map, + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(max(min(min_area_quad_h, min_area_quad_w), 1.0))) + + # generate poly_mask and training_mask + cv2.fillPoly(poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1) + + tvo_map *= poly_mask + tvo_map[:8] -= poly_tv_xy_map + tvo_map[-1] /= poly_short_edge_map + tvo_map = tvo_map.transpose((1, 2, 0)) + + tco_map *= poly_mask + tco_map[:2] -= poly_tc_xy_map + tco_map[-1] /= poly_short_edge_map + tco_map = tco_map.transpose((1, 2, 0)) + + return tvo_map, tco_map + + def adjust_point(self, poly): + """ + adjust point order. + """ + point_num = poly.shape[0] + if point_num == 4: + len_1 = np.linalg.norm(poly[0] - poly[1]) + len_2 = np.linalg.norm(poly[1] - poly[2]) + len_3 = np.linalg.norm(poly[2] - poly[3]) + len_4 = np.linalg.norm(poly[3] - poly[0]) + + if (len_1 + len_3) * 1.5 < (len_2 + len_4): + poly = poly[[1, 2, 3, 0], :] + + elif point_num > 4: + vector_1 = poly[0] - poly[1] + vector_2 = poly[1] - poly[2] + cos_theta = np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6) + theta = np.arccos(np.round(cos_theta, decimals=4)) + + if abs(theta) > (70 / 180 * math.pi): + index = list(range(1, point_num)) + [0] + poly = poly[np.array(index), :] + return poly + + def gen_min_area_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + if point_num == 4: + min_area_quad = poly + center_point = np.sum(poly, axis=0) / 4 + else: + rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation) + center_point = rect[0] + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + if dist < min_dist: + min_dist = dist + first_point_idx = i + + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad, center_point + + def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + def shrink_poly_along_width(self, quads, shrink_ratio_of_width, expand_height_ratio=1.0): + """ + shrink poly with given length. + """ + upper_edge_list = [] + + def get_cut_info(edge_len_list, cut_len): + for idx, edge_len in enumerate(edge_len_list): + cut_len -= edge_len + if cut_len <= 0.000001: + ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx] + return idx, ratio + + for quad in quads: + upper_edge_len = np.linalg.norm(quad[0] - quad[1]) + upper_edge_list.append(upper_edge_len) + + # length of left edge and right edge. + left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio + right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio + + shrink_length = min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width + # shrinking length + upper_len_left = shrink_length + upper_len_right = sum(upper_edge_list) - shrink_length + + left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left) + left_quad = self.shrink_quad_along_width(quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1) + right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right) + right_quad = self.shrink_quad_along_width(quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio) + + out_quad_list = [] + if left_idx == right_idx: + out_quad_list.append([left_quad[0], right_quad[1], right_quad[2], left_quad[3]]) + else: + out_quad_list.append(left_quad) + for idx in range(left_idx + 1, right_idx): + out_quad_list.append(quads[idx]) + out_quad_list.append(right_quad) + + return np.array(out_quad_list), list(range(left_idx, right_idx + 1)) + + def vector_angle(self, A, B): + """ + Calculate the angle between vector AB and x-axis positive direction. + """ + AB = np.array([B[1] - A[1], B[0] - A[0]]) + return np.arctan2(*AB) + + def theta_line_cross_point(self, theta, point): + """ + Calculate the line through given point and angle in ax + by + c =0 form. + """ + x, y = point + cos = np.cos(theta) + sin = np.sin(theta) + return [sin, -cos, cos * y - sin * x] + + def line_cross_two_point(self, A, B): + """ + Calculate the line through given point A and B in ax + by + c =0 form. + """ + angle = self.vector_angle(A, B) + return self.theta_line_cross_point(angle, A) + + def average_angle(self, poly): + """ + Calculate the average angle between left and right edge in given poly. + """ + p0, p1, p2, p3 = poly + angle30 = self.vector_angle(p3, p0) + angle21 = self.vector_angle(p2, p1) + return (angle30 + angle21) / 2 + + def line_cross_point(self, line1, line2): + """ + line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2 + """ + a1, b1, c1 = line1 + a2, b2, c2 = line2 + d = a1 * b2 - a2 * b1 + + if d == 0: + #print("line1", line1) + #print("line2", line2) + print('Cross point does not exist') + return np.array([0, 0], dtype=np.float32) + else: + x = (b1 * c2 - b2 * c1) / d + y = (a2 * c1 - a1 * c2) / d + + return np.array([x, y], dtype=np.float32) + + def quad2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. (4, 2) + """ + ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair + p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair + return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]]) + + def poly2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. + """ + ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + tcl_poly = np.zeros_like(poly) + point_num = poly.shape[0] + + for idx in range(point_num // 2): + point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair + tcl_poly[idx] = point_pair[0] + tcl_poly[point_num - 1 - idx] = point_pair[1] + return tcl_poly + + def gen_quad_tbo(self, quad, tcl_mask, tbo_map): + """ + Generate tbo_map for give quad. + """ + # upper and lower line function: ax + by + c = 0; + up_line = self.line_cross_two_point(quad[0], quad[1]) + lower_line = self.line_cross_two_point(quad[3], quad[2]) + + quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) + quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) + + # average angle of left and right line. + angle = self.average_angle(quad) + + xy_in_poly = np.argwhere(tcl_mask == 1) + for y, x in xy_in_poly: + point = (x, y) + line = self.theta_line_cross_point(angle, point) + cross_point_upper = self.line_cross_point(up_line, line) + cross_point_lower = self.line_cross_point(lower_line, line) + ##FIX, offset reverse + upper_offset_x, upper_offset_y = cross_point_upper - point + lower_offset_x, lower_offset_y = cross_point_lower - point + tbo_map[y, x, 0] = upper_offset_y + tbo_map[y, x, 1] = upper_offset_x + tbo_map[y, x, 2] = lower_offset_y + tbo_map[y, x, 3] = lower_offset_x + tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2 + return tbo_map + + def poly2quads(self, poly): + """ + Split poly into quads. + """ + quad_list = [] + point_num = poly.shape[0] + + # point pair + point_pair_list = [] + for idx in range(point_num // 2): + point_pair = [poly[idx], poly[point_num - 1 - idx]] + point_pair_list.append(point_pair) + + quad_num = point_num // 2 - 1 + for idx in range(quad_num): + # reshape and adjust to clock-wise + quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]]) + + return np.array(quad_list) + + def __call__(self, data): + im = data['image'] + text_polys = data['polys'] + text_tags = data['ignore_tags'] + if im is None: + return None + if text_polys.shape[0] == 0: + return None + + h, w, _ = im.shape + text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w)) + + if text_polys.shape[0] == 0: + return None + + #set aspect ratio and keep area fix + asp_scales = np.arange(1.0, 1.55, 0.1) + asp_scale = np.random.choice(asp_scales) + + if np.random.rand() < 0.5: + asp_scale = 1.0 / asp_scale + asp_scale = math.sqrt(asp_scale) + + asp_wx = asp_scale + asp_hy = 1.0 / asp_scale + im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy) + text_polys[:, :, 0] *= asp_wx + text_polys[:, :, 1] *= asp_hy + + h, w, _ = im.shape + if max(h, w) > 2048: + rd_scale = 2048.0 / max(h, w) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + h, w, _ = im.shape + if min(h, w) < 16: + return None + + #no background + im, text_polys, text_tags, hv_tags = self.crop_area(im, \ + text_polys, text_tags, hv_tags, crop_background=False) + + if text_polys.shape[0] == 0: + return None + #continue for all ignore case + if np.sum((text_tags * 1.0)) >= text_tags.size: + return None + new_h, new_w, _ = im.shape + if (new_h is None) or (new_w is None): + return None + #resize image + std_ratio = float(self.input_size) / max(new_w, new_h) + rand_scales = np.array([0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]) + rz_scale = std_ratio * np.random.choice(rand_scales) + im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale) + text_polys[:, :, 0] *= rz_scale + text_polys[:, :, 1] *= rz_scale + + #add gaussian blur + if np.random.rand() < 0.1 * 0.5: + ks = np.random.permutation(5)[0] + 1 + ks = int(ks/2)*2 + 1 + im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0) + #add brighter + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 + np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + #add darker + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 - np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + + # Padding the im to [input_size, input_size] + new_h, new_w, _ = im.shape + if min(new_w, new_h) < self.input_size * 0.5: + return None + + im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32) + im_padded[:, :, 2] = 0.485 * 255 + im_padded[:, :, 1] = 0.456 * 255 + im_padded[:, :, 0] = 0.406 * 255 + + # Random the start position + del_h = self.input_size - new_h + del_w = self.input_size - new_w + sh, sw = 0, 0 + if del_h > 1: + sh = int(np.random.rand() * del_h) + if del_w > 1: + sw = int(np.random.rand() * del_w) + + # Padding + im_padded[sh: sh + new_h, sw: sw + new_w, :] = im.copy() + text_polys[:, :, 0] += sw + text_polys[:, :, 1] += sh + + score_map, border_map, training_mask = self.generate_tcl_label((self.input_size, self.input_size), + text_polys, text_tags, 0.25) + + # SAST head + tvo_map, tco_map = self.generate_tvo_and_tco((self.input_size, self.input_size), text_polys, text_tags, tcl_ratio=0.3, ds_ratio=0.25) + # print("test--------tvo_map shape:", tvo_map.shape) + + im_padded[:, :, 2] -= 0.485 * 255 + im_padded[:, :, 1] -= 0.456 * 255 + im_padded[:, :, 0] -= 0.406 * 255 + im_padded[:, :, 2] /= (255.0 * 0.229) + im_padded[:, :, 1] /= (255.0 * 0.224) + im_padded[:, :, 0] /= (255.0 * 0.225) + im_padded = im_padded.transpose((2, 0, 1)) + + data['image'] = im_padded[::-1, :, :] + data['score_map'] = score_map[np.newaxis, :, :] + data['border_map'] = border_map.transpose((2, 0, 1)) + data['training_mask'] = training_mask[np.newaxis, :, :] + data['tvo_map'] = tvo_map.transpose((2, 0, 1)) + data['tco_map'] = tco_map.transpose((2, 0, 1)) + return data \ No newline at end of file diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 564956e0..4673d35c 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -18,6 +18,8 @@ import copy def build_loss(config): # det loss from .det_db_loss import DBLoss + from .det_east_loss import EASTLoss + from .det_sast_loss import SASTLoss # rec loss from .rec_ctc_loss import CTCLoss @@ -25,7 +27,7 @@ def build_loss(config): # cls loss from .cls_loss import ClsLoss - support_dict = ['DBLoss', 'CTCLoss', 'ClsLoss'] + support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss'] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/det_east_loss.py b/ppocr/losses/det_east_loss.py new file mode 100644 index 00000000..bcf5372b --- /dev/null +++ b/ppocr/losses/det_east_loss.py @@ -0,0 +1,63 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +from .det_basic_loss import DiceLoss + + +class EASTLoss(nn.Layer): + """ + """ + + def __init__(self, + eps=1e-6, + **kwargs): + super(EASTLoss, self).__init__() + self.dice_loss = DiceLoss(eps=eps) + + def forward(self, predicts, labels): + l_score, l_geo, l_mask = labels[1:] + f_score = predicts['f_score'] + f_geo = predicts['f_geo'] + + dice_loss = self.dice_loss(f_score, l_score, l_mask) + + #smoooth_l1_loss + channels = 8 + l_geo_split = paddle.split( + l_geo, num_or_sections=channels + 1, axis=1) + f_geo_split = paddle.split(f_geo, num_or_sections=channels, axis=1) + smooth_l1 = 0 + for i in range(0, channels): + geo_diff = l_geo_split[i] - f_geo_split[i] + abs_geo_diff = paddle.abs(geo_diff) + smooth_l1_sign = paddle.less_than(abs_geo_diff, l_score) + smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype='float32') + in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \ + (abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign) + out_loss = l_geo_split[-1] / channels * in_loss * l_score + smooth_l1 += out_loss + smooth_l1_loss = paddle.mean(smooth_l1 * l_score) + + dice_loss = dice_loss * 0.01 + total_loss = dice_loss + smooth_l1_loss + losses = {"loss":total_loss, \ + "dice_loss":dice_loss,\ + "smooth_l1_loss":smooth_l1_loss} + return losses diff --git a/ppocr/losses/det_sast_loss.py b/ppocr/losses/det_sast_loss.py new file mode 100644 index 00000000..a07af6a4 --- /dev/null +++ b/ppocr/losses/det_sast_loss.py @@ -0,0 +1,121 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +from .det_basic_loss import DiceLoss +import paddle.fluid as fluid +import numpy as np + + +class SASTLoss(nn.Layer): + """ + """ + + def __init__(self, + eps=1e-6, + **kwargs): + super(SASTLoss, self).__init__() + self.dice_loss = DiceLoss(eps=eps) + + def forward(self, predicts, labels): + """ + tcl_pos: N x 128 x 3 + tcl_mask: N x 128 x 1 + tcl_label: N x X list or LoDTensor + """ + + f_score = predicts['f_score'] + f_border = predicts['f_border'] + f_tvo = predicts['f_tvo'] + f_tco = predicts['f_tco'] + + l_score, l_border, l_mask, l_tvo, l_tco = labels[1:] + + #score_loss + intersection = paddle.sum(f_score * l_score * l_mask) + union = paddle.sum(f_score * l_mask) + paddle.sum(l_score * l_mask) + score_loss = 1.0 - 2 * intersection / (union + 1e-5) + + #border loss + l_border_split, l_border_norm = paddle.split(l_border, num_or_sections=[4, 1], axis=1) + f_border_split = f_border + border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1]) + l_border_norm_split = paddle.expand(x=l_border_norm, shape=border_ex_shape) + l_border_score = paddle.expand(x=l_score, shape=border_ex_shape) + l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape) + + border_diff = l_border_split - f_border_split + abs_border_diff = paddle.abs(border_diff) + border_sign = abs_border_diff < 1.0 + border_sign = paddle.cast(border_sign, dtype='float32') + border_sign.stop_gradient = True + border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \ + (abs_border_diff - 0.5) * (1.0 - border_sign) + border_out_loss = l_border_norm_split * border_in_loss + border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \ + (paddle.sum(l_border_score * l_border_mask) + 1e-5) + + #tvo_loss + l_tvo_split, l_tvo_norm = paddle.split(l_tvo, num_or_sections=[8, 1], axis=1) + f_tvo_split = f_tvo + tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1]) + l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape) + l_tvo_score = paddle.expand(x=l_score, shape=tvo_ex_shape) + l_tvo_mask = paddle.expand(x=l_mask, shape=tvo_ex_shape) + # + tvo_geo_diff = l_tvo_split - f_tvo_split + abs_tvo_geo_diff = paddle.abs(tvo_geo_diff) + tvo_sign = abs_tvo_geo_diff < 1.0 + tvo_sign = paddle.cast(tvo_sign, dtype='float32') + tvo_sign.stop_gradient = True + tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \ + (abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign) + tvo_out_loss = l_tvo_norm_split * tvo_in_loss + tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \ + (paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5) + + #tco_loss + l_tco_split, l_tco_norm = paddle.split(l_tco, num_or_sections=[2, 1], axis=1) + f_tco_split = f_tco + tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1]) + l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape) + l_tco_score = paddle.expand(x=l_score, shape=tco_ex_shape) + l_tco_mask = paddle.expand(x=l_mask, shape=tco_ex_shape) + + tco_geo_diff = l_tco_split - f_tco_split + abs_tco_geo_diff = paddle.abs(tco_geo_diff) + tco_sign = abs_tco_geo_diff < 1.0 + tco_sign = paddle.cast(tco_sign, dtype='float32') + tco_sign.stop_gradient = True + tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \ + (abs_tco_geo_diff - 0.5) * (1.0 - tco_sign) + tco_out_loss = l_tco_norm_split * tco_in_loss + tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \ + (paddle.sum(l_tco_score * l_tco_mask) + 1e-5) + + + # total loss + tvo_lw, tco_lw = 1.5, 1.5 + score_lw, border_lw = 1.0, 1.0 + total_loss = score_loss * score_lw + border_loss * border_lw + \ + tvo_loss * tvo_lw + tco_loss * tco_lw + + losses = {'loss':total_loss, "score_loss":score_loss,\ + "border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss} + return losses \ No newline at end of file diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 7085d3af..43103e53 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -19,6 +19,7 @@ def build_backbone(config, model_type): if model_type == 'det': from .det_mobilenet_v3 import MobileNetV3 from .det_resnet_vd import ResNet + from .det_resnet_vd_sast import ResNet_SAST support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST'] elif model_type == 'rec' or model_type == 'cls': from .rec_mobilenet_v3 import MobileNetV3 diff --git a/ppocr/modeling/backbones/det_resnet_vd_sast.py b/ppocr/modeling/backbones/det_resnet_vd_sast.py new file mode 100644 index 00000000..c9376a8d --- /dev/null +++ b/ppocr/modeling/backbones/det_resnet_vd_sast.py @@ -0,0 +1,285 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ["ResNet_SAST"] + + +class ConvBNLayer(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None, ): + super(ConvBNLayer, self).__init__() + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = nn.BatchNorm( + out_channels, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def forward(self, inputs): + if self.is_vd_mode: + inputs = self._pool2d_avg(inputs) + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class BottleneckBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1, + act=None, + name=name + "_branch2c") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv2) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BasicBlock, self).__init__() + self.stride = stride + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + act=None, + name=name + "_branch2b") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv1) + y = F.relu(y) + return y + + +class ResNet_SAST(nn.Layer): + def __init__(self, in_channels=3, layers=50, **kwargs): + super(ResNet_SAST, self).__init__() + + self.layers = layers + supported_layers = [18, 34, 50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + # depth = [3, 4, 6, 3] + depth = [3, 4, 6, 3, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + # num_channels = [64, 256, 512, + # 1024] if layers >= 50 else [64, 64, 128, 256] + # num_filters = [64, 128, 256, 512] + num_channels = [64, 256, 512, + 1024, 2048] if layers >= 50 else [64, 64, 128, 256] + num_filters = [64, 128, 256, 512, 512] + + self.conv1_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=32, + kernel_size=3, + stride=2, + act='relu', + name="conv1_1") + self.conv1_2 = ConvBNLayer( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + act='relu', + name="conv1_2") + self.conv1_3 = ConvBNLayer( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=1, + act='relu', + name="conv1_3") + self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + + self.stages = [] + self.out_channels = [3, 64] + if layers >= 50: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + bottleneck_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BottleneckBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(bottleneck_block) + self.out_channels.append(num_filters[block] * 4) + self.stages.append(nn.Sequential(*block_list)) + else: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + basic_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BasicBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block], + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(basic_block) + self.out_channels.append(num_filters[block]) + self.stages.append(nn.Sequential(*block_list)) + + def forward(self, inputs): + out = [inputs] + y = self.conv1_1(inputs) + y = self.conv1_2(y) + y = self.conv1_3(y) + out.append(y) + y = self.pool2d_max(y) + for block in self.stages: + y = block(y) + out.append(y) + return out \ No newline at end of file diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 673f6fb4..78074709 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -18,13 +18,15 @@ __all__ = ['build_head'] def build_head(config): # det head from .det_db_head import DBHead + from .det_east_head import EASTHead + from .det_sast_head import SASTHead # rec head from .rec_ctc_head import CTCHead # cls head from .cls_head import ClsHead - support_dict = ['DBHead', 'CTCHead', 'ClsHead'] + support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead'] module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( diff --git a/ppocr/modeling/heads/det_east_head.py b/ppocr/modeling/heads/det_east_head.py new file mode 100644 index 00000000..9d0c3c4c --- /dev/null +++ b/ppocr/modeling/heads/det_east_head.py @@ -0,0 +1,121 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + if_act=True, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance") + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class EASTHead(nn.Layer): + """ + """ + def __init__(self, in_channels, model_name, **kwargs): + super(EASTHead, self).__init__() + self.model_name = model_name + if self.model_name == "large": + num_outputs = [128, 64, 1, 8] + else: + num_outputs = [64, 32, 1, 8] + + self.det_conv1 = ConvBNLayer( + in_channels=in_channels, + out_channels=num_outputs[0], + kernel_size=3, + stride=1, + padding=1, + if_act=True, + act='relu', + name="det_head1") + self.det_conv2 = ConvBNLayer( + in_channels=num_outputs[0], + out_channels=num_outputs[1], + kernel_size=3, + stride=1, + padding=1, + if_act=True, + act='relu', + name="det_head2") + self.score_conv = ConvBNLayer( + in_channels=num_outputs[1], + out_channels=num_outputs[2], + kernel_size=1, + stride=1, + padding=0, + if_act=False, + act=None, + name="f_score") + self.geo_conv = ConvBNLayer( + in_channels=num_outputs[1], + out_channels=num_outputs[3], + kernel_size=1, + stride=1, + padding=0, + if_act=False, + act=None, + name="f_geo") + + def forward(self, x): + f_det = self.det_conv1(x) + f_det = self.det_conv2(f_det) + f_score = self.score_conv(f_det) + f_score = F.sigmoid(f_score) + f_geo = self.geo_conv(f_det) + f_geo = (F.sigmoid(f_geo) - 0.5) * 2 * 800 + + pred = {'f_score': f_score, 'f_geo': f_geo} + return pred diff --git a/ppocr/modeling/heads/det_sast_head.py b/ppocr/modeling/heads/det_sast_head.py new file mode 100644 index 00000000..263b2867 --- /dev/null +++ b/ppocr/modeling/heads/det_sast_head.py @@ -0,0 +1,128 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups=1, + if_act=True, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance") + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class SAST_Header1(nn.Layer): + def __init__(self, in_channels, **kwargs): + super(SAST_Header1, self).__init__() + out_channels = [64, 64, 128] + self.score_conv = nn.Sequential( + ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_score1'), + ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_score2'), + ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_score3'), + ConvBNLayer(out_channels[2], 1, 3, 1, act=None, name='f_score4') + ) + self.border_conv = nn.Sequential( + ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_border1'), + ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_border2'), + ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_border3'), + ConvBNLayer(out_channels[2], 4, 3, 1, act=None, name='f_border4') + ) + + def forward(self, x): + f_score = self.score_conv(x) + f_score = F.sigmoid(f_score) + f_border = self.border_conv(x) + return f_score, f_border + + +class SAST_Header2(nn.Layer): + def __init__(self, in_channels, **kwargs): + super(SAST_Header2, self).__init__() + out_channels = [64, 64, 128] + self.tvo_conv = nn.Sequential( + ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tvo1'), + ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tvo2'), + ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tvo3'), + ConvBNLayer(out_channels[2], 8, 3, 1, act=None, name='f_tvo4') + ) + self.tco_conv = nn.Sequential( + ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tco1'), + ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tco2'), + ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tco3'), + ConvBNLayer(out_channels[2], 2, 3, 1, act=None, name='f_tco4') + ) + + def forward(self, x): + f_tvo = self.tvo_conv(x) + f_tco = self.tco_conv(x) + return f_tvo, f_tco + + +class SASTHead(nn.Layer): + """ + """ + def __init__(self, in_channels, **kwargs): + super(SASTHead, self).__init__() + + self.head1 = SAST_Header1(in_channels) + self.head2 = SAST_Header2(in_channels) + + def forward(self, x): + f_score, f_border = self.head1(x) + f_tvo, f_tco = self.head2(x) + + predicts = {} + predicts['f_score'] = f_score + predicts['f_border'] = f_border + predicts['f_tvo'] = f_tvo + predicts['f_tco'] = f_tco + return predicts \ No newline at end of file diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index a9bf414b..405e062b 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -16,8 +16,10 @@ __all__ = ['build_neck'] def build_neck(config): from .db_fpn import DBFPN + from .east_fpn import EASTFPN + from .sast_fpn import SASTFPN from .rnn import SequenceEncoder - support_dict = ['DBFPN', 'SequenceEncoder'] + support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder'] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/ppocr/modeling/necks/east_fpn.py b/ppocr/modeling/necks/east_fpn.py new file mode 100644 index 00000000..120ff156 --- /dev/null +++ b/ppocr/modeling/necks/east_fpn.py @@ -0,0 +1,188 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + if_act=True, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance") + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class DeConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + if_act=True, + act=None, + name=None): + super(DeConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.deconv = nn.Conv2DTranspose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance") + + def forward(self, x): + x = self.deconv(x) + x = self.bn(x) + return x + + +class EASTFPN(nn.Layer): + def __init__(self, in_channels, model_name, **kwargs): + super(EASTFPN, self).__init__() + self.model_name = model_name + if self.model_name == "large": + self.out_channels = 128 + else: + self.out_channels = 64 + self.in_channels = in_channels[::-1] + self.h1_conv = ConvBNLayer( + in_channels=self.out_channels+self.in_channels[1], + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + if_act=True, + act='relu', + name="unet_h_1") + self.h2_conv = ConvBNLayer( + in_channels=self.out_channels+self.in_channels[2], + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + if_act=True, + act='relu', + name="unet_h_2") + self.h3_conv = ConvBNLayer( + in_channels=self.out_channels+self.in_channels[3], + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + if_act=True, + act='relu', + name="unet_h_3") + self.g0_deconv = DeConvBNLayer( + in_channels=self.in_channels[0], + out_channels=self.out_channels, + kernel_size=4, + stride=2, + padding=1, + if_act=True, + act='relu', + name="unet_g_0") + self.g1_deconv = DeConvBNLayer( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=4, + stride=2, + padding=1, + if_act=True, + act='relu', + name="unet_g_1") + self.g2_deconv = DeConvBNLayer( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=4, + stride=2, + padding=1, + if_act=True, + act='relu', + name="unet_g_2") + self.g3_conv = ConvBNLayer( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + if_act=True, + act='relu', + name="unet_g_3") + + def forward(self, x): + f = x[::-1] + + h = f[0] + g = self.g0_deconv(h) + h = paddle.concat([g, f[1]], axis=1) + h = self.h1_conv(h) + g = self.g1_deconv(h) + h = paddle.concat([g, f[2]], axis=1) + h = self.h2_conv(h) + g = self.g2_deconv(h) + h = paddle.concat([g, f[3]], axis=1) + h = self.h3_conv(h) + g = self.g3_conv(h) + + return g \ No newline at end of file diff --git a/ppocr/modeling/necks/sast_fpn.py b/ppocr/modeling/necks/sast_fpn.py new file mode 100644 index 00000000..9b602459 --- /dev/null +++ b/ppocr/modeling/necks/sast_fpn.py @@ -0,0 +1,284 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups=1, + if_act=True, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance") + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class DeConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups=1, + if_act=True, + act=None, + name=None): + super(DeConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.deconv = nn.Conv2DTranspose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance") + + def forward(self, x): + x = self.deconv(x) + x = self.bn(x) + return x + + +class FPN_Up_Fusion(nn.Layer): + def __init__(self, in_channels): + super(FPN_Up_Fusion, self).__init__() + in_channels = in_channels[::-1] + out_channels = [256, 256, 192, 192, 128] + + self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 1, 1, act=None, name='fpn_up_h0') + self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 1, 1, act=None, name='fpn_up_h1') + self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 1, 1, act=None, name='fpn_up_h2') + self.h3_conv = ConvBNLayer(in_channels[3], out_channels[3], 1, 1, act=None, name='fpn_up_h3') + self.h4_conv = ConvBNLayer(in_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_h4') + + self.g0_conv = DeConvBNLayer(out_channels[0], out_channels[1], 4, 2, act=None, name='fpn_up_g0') + + self.g1_conv = nn.Sequential( + ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_up_g1_1'), + DeConvBNLayer(out_channels[1], out_channels[2], 4, 2, act=None, name='fpn_up_g1_2') + ) + self.g2_conv = nn.Sequential( + ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_up_g2_1'), + DeConvBNLayer(out_channels[2], out_channels[3], 4, 2, act=None, name='fpn_up_g2_2') + ) + self.g3_conv = nn.Sequential( + ConvBNLayer(out_channels[3], out_channels[3], 3, 1, act='relu', name='fpn_up_g3_1'), + DeConvBNLayer(out_channels[3], out_channels[4], 4, 2, act=None, name='fpn_up_g3_2') + ) + + self.g4_conv = nn.Sequential( + ConvBNLayer(out_channels[4], out_channels[4], 3, 1, act='relu', name='fpn_up_fusion_1'), + ConvBNLayer(out_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_fusion_2') + ) + + def _add_relu(self, x1, x2): + x = paddle.add(x=x1, y=x2) + x = F.relu(x) + return x + + def forward(self, x): + f = x[2:][::-1] + h0 = self.h0_conv(f[0]) + h1 = self.h1_conv(f[1]) + h2 = self.h2_conv(f[2]) + h3 = self.h3_conv(f[3]) + h4 = self.h4_conv(f[4]) + + g0 = self.g0_conv(h0) + g1 = self._add_relu(g0, h1) + g1 = self.g1_conv(g1) + g2 = self.g2_conv(self._add_relu(g1, h2)) + g3 = self.g3_conv(self._add_relu(g2, h3)) + g4 = self.g4_conv(self._add_relu(g3, h4)) + + return g4 + + +class FPN_Down_Fusion(nn.Layer): + def __init__(self, in_channels): + super(FPN_Down_Fusion, self).__init__() + out_channels = [32, 64, 128] + + self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 3, 1, act=None, name='fpn_down_h0') + self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 3, 1, act=None, name='fpn_down_h1') + self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 3, 1, act=None, name='fpn_down_h2') + + self.g0_conv = ConvBNLayer(out_channels[0], out_channels[1], 3, 2, act=None, name='fpn_down_g0') + + self.g1_conv = nn.Sequential( + ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_down_g1_1'), + ConvBNLayer(out_channels[1], out_channels[2], 3, 2, act=None, name='fpn_down_g1_2') + ) + + self.g2_conv = nn.Sequential( + ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_down_fusion_1'), + ConvBNLayer(out_channels[2], out_channels[2], 1, 1, act=None, name='fpn_down_fusion_2') + ) + + def forward(self, x): + f = x[:3] + h0 = self.h0_conv(f[0]) + h1 = self.h1_conv(f[1]) + h2 = self.h2_conv(f[2]) + g0 = self.g0_conv(h0) + g1 = paddle.add(x=g0, y=h1) + g1 = F.relu(g1) + g1 = self.g1_conv(g1) + g2 = paddle.add(x=g1, y=h2) + g2 = F.relu(g2) + g2 = self.g2_conv(g2) + return g2 + + +class Cross_Attention(nn.Layer): + def __init__(self, in_channels): + super(Cross_Attention, self).__init__() + self.theta_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_theta') + self.phi_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_phi') + self.g_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_g') + + self.fh_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_weight') + self.fh_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_sc') + + self.fv_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_weight') + self.fv_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_sc') + + self.f_attn_conv = ConvBNLayer(in_channels * 2, in_channels, 1, 1, act='relu', name='f_attn') + + def _cal_fweight(self, f, shape): + f_theta, f_phi, f_g = f + #flatten + f_theta = paddle.transpose(f_theta, [0, 2, 3, 1]) + f_theta = paddle.reshape(f_theta, [shape[0] * shape[1], shape[2], 128]) + f_phi = paddle.transpose(f_phi, [0, 2, 3, 1]) + f_phi = paddle.reshape(f_phi, [shape[0] * shape[1], shape[2], 128]) + f_g = paddle.transpose(f_g, [0, 2, 3, 1]) + f_g = paddle.reshape(f_g, [shape[0] * shape[1], shape[2], 128]) + #correlation + f_attn = paddle.matmul(f_theta, paddle.transpose(f_phi, [0, 2, 1])) + #scale + f_attn = f_attn / (128**0.5) + f_attn = F.softmax(f_attn) + #weighted sum + f_weight = paddle.matmul(f_attn, f_g) + f_weight = paddle.reshape( + f_weight, [shape[0], shape[1], shape[2], 128]) + return f_weight + + def forward(self, f_common): + f_shape = paddle.shape(f_common) + # print('f_shape: ', f_shape) + + f_theta = self.theta_conv(f_common) + f_phi = self.phi_conv(f_common) + f_g = self.g_conv(f_common) + + ######## horizon ######## + fh_weight = self._cal_fweight([f_theta, f_phi, f_g], + [f_shape[0], f_shape[2], f_shape[3]]) + fh_weight = paddle.transpose(fh_weight, [0, 3, 1, 2]) + fh_weight = self.fh_weight_conv(fh_weight) + #short cut + fh_sc = self.fh_sc_conv(f_common) + f_h = F.relu(fh_weight + fh_sc) + + ######## vertical ######## + fv_theta = paddle.transpose(f_theta, [0, 1, 3, 2]) + fv_phi = paddle.transpose(f_phi, [0, 1, 3, 2]) + fv_g = paddle.transpose(f_g, [0, 1, 3, 2]) + fv_weight = self._cal_fweight([fv_theta, fv_phi, fv_g], + [f_shape[0], f_shape[3], f_shape[2]]) + fv_weight = paddle.transpose(fv_weight, [0, 3, 2, 1]) + fv_weight = self.fv_weight_conv(fv_weight) + #short cut + fv_sc = self.fv_sc_conv(f_common) + f_v = F.relu(fv_weight + fv_sc) + + ######## merge ######## + f_attn = paddle.concat([f_h, f_v], axis=1) + f_attn = self.f_attn_conv(f_attn) + return f_attn + + +class SASTFPN(nn.Layer): + def __init__(self, in_channels, with_cab=False, **kwargs): + super(SASTFPN, self).__init__() + self.in_channels = in_channels + self.with_cab = with_cab + self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels) + self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels) + self.out_channels = 128 + self.cross_attention = Cross_Attention(self.out_channels) + + def forward(self, x): + #down fpn + f_down = self.FPN_Down_Fusion(x) + + #up fpn + f_up = self.FPN_Up_Fusion(x) + + #fusion + f_common = paddle.add(x=f_down, y=f_up) + f_common = F.relu(f_common) + + if self.with_cab: + # print('enhence f_common with CAB.') + f_common = self.cross_attention(f_common) + + return f_common diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index e08a217d..c9b42e08 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -24,11 +24,13 @@ __all__ = ['build_post_process'] def build_post_process(config, global_config=None): from .db_postprocess import DBPostProcess + from .east_postprocess import EASTPostProcess + from .sast_postprocess import SASTPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode from .cls_postprocess import ClsPostProcess support_dict = [ - 'DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess' + 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/east_postprocess.py b/ppocr/postprocess/east_postprocess.py new file mode 100644 index 00000000..0b669405 --- /dev/null +++ b/ppocr/postprocess/east_postprocess.py @@ -0,0 +1,141 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from .locality_aware_nms import nms_locality +import cv2 + +import os +import sys +# __dir__ = os.path.dirname(os.path.abspath(__file__)) +# sys.path.append(__dir__) +# sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + + +class EASTPostProcess(object): + """ + The post process for EAST. + """ + def __init__(self, + score_thresh=0.8, + cover_thresh=0.1, + nms_thresh=0.2, + **kwargs): + + self.score_thresh = score_thresh + self.cover_thresh = cover_thresh + self.nms_thresh = nms_thresh + + # c++ la-nms is faster, but only support python 3.5 + self.is_python35 = False + if sys.version_info.major == 3 and sys.version_info.minor == 5: + self.is_python35 = True + + def restore_rectangle_quad(self, origin, geometry): + """ + Restore rectangle from quadrangle. + """ + # quad + origin_concat = np.concatenate( + (origin, origin, origin, origin), axis=1) # (n, 8) + pred_quads = origin_concat - geometry + pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2) + return pred_quads + + def detect(self, + score_map, + geo_map, + score_thresh=0.8, + cover_thresh=0.1, + nms_thresh=0.2): + """ + restore text boxes from score map and geo map + """ + score_map = score_map[0] + geo_map = np.swapaxes(geo_map, 1, 0) + geo_map = np.swapaxes(geo_map, 1, 2) + # filter the score map + xy_text = np.argwhere(score_map > score_thresh) + if len(xy_text) == 0: + return [] + # sort the text boxes via the y axis + xy_text = xy_text[np.argsort(xy_text[:, 0])] + #restore quad proposals + text_box_restored = self.restore_rectangle_quad( + xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :]) + boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32) + boxes[:, :8] = text_box_restored.reshape((-1, 8)) + boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]] + if self.is_python35: + import lanms + boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh) + else: + boxes = nms_locality(boxes.astype(np.float64), nms_thresh) + if boxes.shape[0] == 0: + return [] + # Here we filter some low score boxes by the average score map, + # this is different from the orginal paper. + for i, box in enumerate(boxes): + mask = np.zeros_like(score_map, dtype=np.uint8) + cv2.fillPoly(mask, box[:8].reshape( + (-1, 4, 2)).astype(np.int32) // 4, 1) + boxes[i, 8] = cv2.mean(score_map, mask)[0] + boxes = boxes[boxes[:, 8] > cover_thresh] + return boxes + + def sort_poly(self, p): + """ + Sort polygons. + """ + min_axis = np.argmin(np.sum(p, axis=1)) + p = p[[min_axis, (min_axis + 1) % 4,\ + (min_axis + 2) % 4, (min_axis + 3) % 4]] + if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]): + return p + else: + return p[[0, 3, 2, 1]] + + def __call__(self, outs_dict, shape_list): + score_list = outs_dict['f_score'] + geo_list = outs_dict['f_geo'] + img_num = len(shape_list) + dt_boxes_list = [] + for ino in range(img_num): + score = score_list[ino].numpy() + geo = geo_list[ino].numpy() + boxes = self.detect( + score_map=score, + geo_map=geo, + score_thresh=self.score_thresh, + cover_thresh=self.cover_thresh, + nms_thresh=self.nms_thresh) + boxes_norm = [] + if len(boxes) > 0: + h, w = score.shape[1:] + src_h, src_w, ratio_h, ratio_w = shape_list[ino] + boxes = boxes[:, :8].reshape((-1, 4, 2)) + boxes[:, :, 0] /= ratio_w + boxes[:, :, 1] /= ratio_h + for i_box, box in enumerate(boxes): + box = self.sort_poly(box.astype(np.int32)) + if np.linalg.norm(box[0] - box[1]) < 5 \ + or np.linalg.norm(box[3] - box[0]) < 5: + continue + boxes_norm.append(box) + dt_boxes_list.append({'points': np.array(boxes_norm)}) + return dt_boxes_list \ No newline at end of file diff --git a/ppocr/postprocess/locality_aware_nms.py b/ppocr/postprocess/locality_aware_nms.py new file mode 100644 index 00000000..53280cc1 --- /dev/null +++ b/ppocr/postprocess/locality_aware_nms.py @@ -0,0 +1,199 @@ +""" +Locality aware nms. +""" + +import numpy as np +from shapely.geometry import Polygon + + +def intersection(g, p): + """ + Intersection. + """ + g = Polygon(g[:8].reshape((4, 2))) + p = Polygon(p[:8].reshape((4, 2))) + g = g.buffer(0) + p = p.buffer(0) + if not g.is_valid or not p.is_valid: + return 0 + inter = Polygon(g).intersection(Polygon(p)).area + union = g.area + p.area - inter + if union == 0: + return 0 + else: + return inter / union + + +def intersection_iog(g, p): + """ + Intersection_iog. + """ + g = Polygon(g[:8].reshape((4, 2))) + p = Polygon(p[:8].reshape((4, 2))) + if not g.is_valid or not p.is_valid: + return 0 + inter = Polygon(g).intersection(Polygon(p)).area + #union = g.area + p.area - inter + union = p.area + if union == 0: + print("p_area is very small") + return 0 + else: + return inter / union + + +def weighted_merge(g, p): + """ + Weighted merge. + """ + g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8]) + g[8] = (g[8] + p[8]) + return g + + +def standard_nms(S, thres): + """ + Standard nms. + """ + order = np.argsort(S[:, 8])[::-1] + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) + + inds = np.where(ovr <= thres)[0] + order = order[inds + 1] + + return S[keep] + + +def standard_nms_inds(S, thres): + """ + Standard nms, retun inds. + """ + order = np.argsort(S[:, 8])[::-1] + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) + + inds = np.where(ovr <= thres)[0] + order = order[inds + 1] + + return keep + + +def nms(S, thres): + """ + nms. + """ + order = np.argsort(S[:, 8])[::-1] + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) + + inds = np.where(ovr <= thres)[0] + order = order[inds + 1] + + return keep + + +def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2): + """ + soft_nms + :para boxes_in, N x 9 (coords + score) + :para threshould, eliminate cases min score(0.001) + :para Nt_thres, iou_threshi + :para sigma, gaussian weght + :method, linear or gaussian + """ + boxes = boxes_in.copy() + N = boxes.shape[0] + if N is None or N < 1: + return np.array([]) + pos, maxpos = 0, 0 + weight = 0.0 + inds = np.arange(N) + tbox, sbox = boxes[0].copy(), boxes[0].copy() + for i in range(N): + maxscore = boxes[i, 8] + maxpos = i + tbox = boxes[i].copy() + ti = inds[i] + pos = i + 1 + #get max box + while pos < N: + if maxscore < boxes[pos, 8]: + maxscore = boxes[pos, 8] + maxpos = pos + pos = pos + 1 + #add max box as a detection + boxes[i, :] = boxes[maxpos, :] + inds[i] = inds[maxpos] + #swap + boxes[maxpos, :] = tbox + inds[maxpos] = ti + tbox = boxes[i].copy() + pos = i + 1 + #NMS iteration + while pos < N: + sbox = boxes[pos].copy() + ts_iou_val = intersection(tbox, sbox) + if ts_iou_val > 0: + if method == 1: + if ts_iou_val > Nt_thres: + weight = 1 - ts_iou_val + else: + weight = 1 + elif method == 2: + weight = np.exp(-1.0 * ts_iou_val**2 / sigma) + else: + if ts_iou_val > Nt_thres: + weight = 0 + else: + weight = 1 + boxes[pos, 8] = weight * boxes[pos, 8] + #if box score falls below thresold, discard the box by + #swaping last box update N + if boxes[pos, 8] < threshold: + boxes[pos, :] = boxes[N - 1, :] + inds[pos] = inds[N - 1] + N = N - 1 + pos = pos - 1 + pos = pos + 1 + + return boxes[:N] + + +def nms_locality(polys, thres=0.3): + """ + locality aware nms of EAST + :param polys: a N*9 numpy array. first 8 coordinates, then prob + :return: boxes after nms + """ + S = [] + p = None + for g in polys: + if p is not None and intersection(g, p) > thres: + p = weighted_merge(g, p) + else: + if p is not None: + S.append(p) + p = g + if p is not None: + S.append(p) + + if len(S) == 0: + return np.array([]) + return standard_nms(np.array(S), thres) + + +if __name__ == '__main__': + # 343,350,448,135,474,143,369,359 + print( + Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]])) + .area) \ No newline at end of file diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index eb9be686..6943f84f 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -23,14 +23,16 @@ class BaseRecLabelDecode(object): character_dict_path=None, character_type='ch', use_space_char=False): - support_character_type = ['ch', 'en', 'en_sensitive'] + support_character_type = [ + 'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean' + ] assert character_type in support_character_type, "Only {} are supported now but get {}".format( support_character_type, self.character_str) if character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) - elif character_type == "ch": + elif character_type in ["ch", "french", "german", "japan", "korean"]: self.character_str = "" assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch" with open(character_dict_path, "rb") as fin: @@ -150,4 +152,4 @@ class AttnLabelDecode(BaseRecLabelDecode): else: assert False, "unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end - return idx \ No newline at end of file + return idx diff --git a/ppocr/postprocess/sast_postprocess.py b/ppocr/postprocess/sast_postprocess.py new file mode 100644 index 00000000..03b0e8f1 --- /dev/null +++ b/ppocr/postprocess/sast_postprocess.py @@ -0,0 +1,295 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) + +import numpy as np +from .locality_aware_nms import nms_locality +# import lanms +import cv2 +import time + + +class SASTPostProcess(object): + """ + The post process for SAST. + """ + + def __init__(self, + score_thresh=0.5, + nms_thresh=0.2, + sample_pts_num=2, + shrink_ratio_of_width=0.3, + expand_scale=1.0, + tcl_map_thresh=0.5, + **kwargs): + + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.sample_pts_num = sample_pts_num + self.shrink_ratio_of_width = shrink_ratio_of_width + self.expand_scale = expand_scale + self.tcl_map_thresh = tcl_map_thresh + + # c++ la-nms is faster, but only support python 3.5 + self.is_python35 = False + if sys.version_info.major == 3 and sys.version_info.minor == 5: + self.is_python35 = True + + def point_pair2poly(self, point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. + """ + # constract poly + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2) + + def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. + """ + point_num = poly.shape[0] + left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32) + right_ratio = 1.0 + \ + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map): + """Restore quad.""" + xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) + xy_text = xy_text[:, ::-1] # (n, 2) + + # Sort the text boxes via the y axis + xy_text = xy_text[np.argsort(xy_text[:, 1])] + + scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] + scores = scores[:, np.newaxis] + + # Restore + point_num = int(tvo_map.shape[-1] / 2) + assert point_num == 4 + tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :] + xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2) + quads = xy_text_tile - tvo_map + + return scores, quads, xy_text + + def quad_area(self, quad): + """ + compute area of a quad. + """ + edge = [ + (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), + (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), + (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), + (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]) + ] + return np.sum(edge) / 2. + + def nms(self, dets): + if self.is_python35: + import lanms + dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh) + else: + dets = nms_locality(dets, self.nms_thresh) + return dets + + def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map): + """ + Cluster pixels in tcl_map based on quads. + """ + instance_count = quads.shape[0] + 1 # contain background + instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32) + if instance_count == 1: + return instance_count, instance_label_map + + # predict text center + xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) + n = xy_text.shape[0] + xy_text = xy_text[:, ::-1] # (n, 2) + tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2) + pred_tc = xy_text - tco + + # get gt text center + m = quads.shape[0] + gt_tc = np.mean(quads, axis=1) # (m, 2) + + pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2) + gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2) + dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m) + xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,) + + instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign + return instance_count, instance_label_map + + def estimate_sample_pts_num(self, quad, xy_text): + """ + Estimate sample points number. + """ + eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0 + ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0 + + dense_sample_pts_num = max(2, int(ew)) + dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num, + endpoint=True, dtype=np.float32).astype(np.int32)] + + dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1] + estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1)) + + sample_pts_num = max(2, int(estimate_arc_len / eh)) + return sample_pts_num + + def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h, + shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0): + """ + first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys + """ + # restore quad + scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map) + dets = np.hstack((quads, scores)).astype(np.float32, copy=False) + dets = self.nms(dets) + if dets.shape[0] == 0: + return [] + quads = dets[:, :-1].reshape(-1, 4, 2) + + # Compute quad area + quad_areas = [] + for quad in quads: + quad_areas.append(-self.quad_area(quad)) + + # instance segmentation + # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8) + instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map) + + # restore single poly with tcl instance. + poly_list = [] + for instance_idx in range(1, instance_count): + xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1] + quad = quads[instance_idx - 1] + q_area = quad_areas[instance_idx - 1] + if q_area < 5: + continue + + # + len1 = float(np.linalg.norm(quad[0] -quad[1])) + len2 = float(np.linalg.norm(quad[1] -quad[2])) + min_len = min(len1, len2) + if min_len < 3: + continue + + # filter small CC + if xy_text.shape[0] <= 0: + continue + + # filter low confidence instance + xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] + if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1: + # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: + continue + + # sort xy_text + left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0, + (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2) + right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0, + (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2) + proj_unit_vec = (right_center_pt - left_center_pt) / \ + (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6) + proj_value = np.sum(xy_text * proj_unit_vec, axis=1) + xy_text = xy_text[np.argsort(proj_value)] + + # Sample pts in tcl map + if self.sample_pts_num == 0: + sample_pts_num = self.estimate_sample_pts_num(quad, xy_text) + else: + sample_pts_num = self.sample_pts_num + xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num, + endpoint=True, dtype=np.float32).astype(np.int32)] + + point_pair_list = [] + for x, y in xy_center_line: + # get corresponding offset + offset = tbo_map[y, x, :].reshape(2, 2) + if offset_expand != 1.0: + offset_length = np.linalg.norm(offset, axis=1, keepdims=True) + expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0) + offset_detal = offset / offset_length * expand_length + offset = offset + offset_detal + # original point + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2) + point_pair_list.append(point_pair) + + # ndarry: (x, 2), expand poly along width + detected_poly = self.point_pair2poly(point_pair_list) + detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width) + detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + poly_list.append(detected_poly) + + return poly_list + + def __call__(self, outs_dict, shape_list): + score_list = outs_dict['f_score'] + border_list = outs_dict['f_border'] + tvo_list = outs_dict['f_tvo'] + tco_list = outs_dict['f_tco'] + + img_num = len(shape_list) + poly_lists = [] + for ino in range(img_num): + p_score = score_list[ino].transpose((1,2,0)).numpy() + p_border = border_list[ino].transpose((1,2,0)).numpy() + p_tvo = tvo_list[ino].transpose((1,2,0)).numpy() + p_tco = tco_list[ino].transpose((1,2,0)).numpy() + src_h, src_w, ratio_h, ratio_w = shape_list[ino] + + poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h, + shrink_ratio_of_width=self.shrink_ratio_of_width, + tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale) + poly_lists.append({'points': np.array(poly_list)}) + + return poly_lists + diff --git a/setup.py b/setup.py index 6b503ce3..bef6dbbf 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ setup( package_dir={'paddleocr': ''}, include_package_data=True, entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]}, - version='0.0.3', + version='2.0', install_requires=requirements, license='Apache License 2.0', description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index ae660fde..07dfc216 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -13,6 +13,7 @@ # limitations under the License. import os import sys + __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) @@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger from tools.infer.utility import draw_ocr_box_txt +logger = get_logger() + class TextSystem(object): def __init__(self, args): self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) self.use_angle_cls = args.use_angle_cls + self.drop_score = args.drop_score if self.use_angle_cls: self.text_classifier = predict_cls.TextClassifier(args) @@ -81,7 +85,8 @@ class TextSystem(object): def __call__(self, img): ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) - logger.info("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse)) + logger.info("dt_boxes num : {}, elapse : {}".format( + len(dt_boxes), elapse)) if dt_boxes is None: return None, None img_crop_list = [] @@ -99,9 +104,16 @@ class TextSystem(object): len(img_crop_list), elapse)) rec_res, elapse = self.text_recognizer(img_crop_list) - logger.info("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) + logger.info("rec_res num : {}, elapse : {}".format( + len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) - return dt_boxes, rec_res + filter_boxes, filter_rec_res = [], [] + for box, rec_reuslt in zip(dt_boxes, rec_res): + text, score = rec_reuslt + if score >= self.drop_score: + filter_boxes.append(box) + filter_rec_res.append(rec_reuslt) + return filter_boxes, filter_rec_res def sorted_boxes(dt_boxes): @@ -117,8 +129,8 @@ def sorted_boxes(dt_boxes): _boxes = list(sorted_boxes) for i in range(num_boxes - 1): - if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ - (_boxes[i + 1][0][0] < _boxes[i][0][0]): + if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): tmp = _boxes[i] _boxes[i] = _boxes[i + 1] _boxes[i + 1] = tmp @@ -143,12 +155,8 @@ def main(args): elapse = time.time() - starttime logger.info("Predict time of %s: %.3fs" % (image_file, elapse)) - dt_num = len(dt_boxes) - for dno in range(dt_num): - text, score = rec_res[dno] - if score >= drop_score: - text_str = "%s, %.3f" % (text, score) - logger.info(text_str) + for text, score in rec_res: + logger.info("{}, {:.3f}".format(text, score)) if is_visualize: image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) @@ -174,5 +182,4 @@ def main(args): if __name__ == "__main__": - logger = get_logger() - main(utility.parse_args()) + main(utility.parse_args()) \ No newline at end of file -- GitLab