未验证 提交 ca14b865 编写于 作者: E Evezerest 提交者: GitHub

Merge branch 'dygraph' into dygraph

...@@ -60,8 +60,10 @@ PostProcess: ...@@ -60,8 +60,10 @@ PostProcess:
name: PGPostProcess name: PGPostProcess
score_thresh: 0.5 score_thresh: 0.5
mode: fast # fast or slow two ways mode: fast # fast or slow two ways
Metric: Metric:
name: E2EMetric name: E2EMetric
mode: A # two ways for eval, A: label from txt, B: label from gt_mat
gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat
character_dict_path: ppocr/utils/ic15_dict.txt character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e main_indicator: f_score_e2e
...@@ -70,13 +72,13 @@ Train: ...@@ -70,13 +72,13 @@ Train:
dataset: dataset:
name: PGDataSet name: PGDataSet
data_dir: ./train_data/total_text/train data_dir: ./train_data/total_text/train
label_file_list: [./train_data/total_text/train/] label_file_list: [./train_data/total_text/train/train.txt]
ratio_list: [1.0] ratio_list: [1.0]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- E2ELabelEncode: - E2ELabelEncodeTrain:
- PGProcessTrain: - PGProcessTrain:
batch_size: 14 # same as loader: batch_size_per_card batch_size: 14 # same as loader: batch_size_per_card
min_crop_size: 24 min_crop_size: 24
...@@ -94,11 +96,12 @@ Eval: ...@@ -94,11 +96,12 @@ Eval:
dataset: dataset:
name: PGDataSet name: PGDataSet
data_dir: ./train_data/total_text/test data_dir: ./train_data/total_text/test
label_file_list: [./train_data/total_text/test/] label_file_list: [./train_data/total_text/test/test.txt]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
channel_first: False channel_first: False
- E2ELabelEncodeTest:
- E2EResizeForTest: - E2EResizeForTest:
max_side_len: 768 max_side_len: 768
- NormalizeImage: - NormalizeImage:
...@@ -108,7 +111,7 @@ Eval: ...@@ -108,7 +111,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'image', 'shape', 'img_id'] keep_keys: [ 'image', 'shape', 'polys', 'texts', 'ignore_tags', 'img_id']
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: True use_space_char: True
save_res_path: ./output/rec/predicts_chinese_common_v2.0.txt
Optimizer: Optimizer:
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: True use_space_char: True
save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
Optimizer: Optimizer:
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_ic15.txt
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_mv3_none_bilstm_ctc.txt
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_mv3_none_none_ctc.txt
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_mv3_tps_bilstm_att.txt
Optimizer: Optimizer:
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_mv3_tps_bilstm_ctc.txt
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_r34_vd_none_bilstm_ctc.txt
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_r34_vd_none_none_ctc.txt
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
Optimizer: Optimizer:
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
max_text_length: 25 max_text_length: 25
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_r34_vd_tps_bilstm_ctc.txt
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -20,6 +20,7 @@ Global: ...@@ -20,6 +20,7 @@ Global:
num_heads: 8 num_heads: 8
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_srn.txt
Optimizer: Optimizer:
......
...@@ -51,6 +51,7 @@ public: ...@@ -51,6 +51,7 @@ public:
float &ssid); float &ssid);
float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred); float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred);
float PolygonScoreAcc(std::vector<cv::Point> contour, cv::Mat pred);
std::vector<std::vector<std::vector<int>>> std::vector<std::vector<std::vector<int>>>
BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap, BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
......
...@@ -159,6 +159,39 @@ std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box, ...@@ -159,6 +159,39 @@ std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
return array; return array;
} }
float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
cv::Mat pred){
int width = pred.cols;
int height = pred.rows;
std::vector<float> box_x;
std::vector<float> box_y;
for(int i=0; i<contour.size(); ++i){
box_x.push_back(contour[i].x);
box_y.push_back(contour[i].y);
}
int xmin = clamp(int(std::floor(*(std::min_element(box_x.begin(), box_x.end())))), 0, width - 1);
int xmax = clamp(int(std::ceil(*(std::max_element(box_x.begin(), box_x.end())))), 0, width - 1);
int ymin = clamp(int(std::floor(*(std::min_element(box_y.begin(), box_y.end())))), 0, height - 1);
int ymax = clamp(int(std::ceil(*(std::max_element(box_y.begin(), box_y.end())))), 0, height - 1);
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point rook_point[contour.size()];
for(int i=0; i<contour.size(); ++i){
rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
}
const cv::Point *ppt[1] = {rook_point};
int npt[] = {int(contour.size())};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1)).copyTo(croppedImg);
float score = cv::mean(croppedImg, mask)[0];
return score;
}
float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array, float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
cv::Mat pred) { cv::Mat pred) {
auto array = box_array; auto array = box_array;
...@@ -235,6 +268,8 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap, ...@@ -235,6 +268,8 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
float score; float score;
score = BoxScoreFast(array, pred); score = BoxScoreFast(array, pred);
/* compute using polygon*/
// score = PolygonScoreAcc(contours[_i], pred);
if (score < box_thresh) if (score < box_thresh)
continue; continue;
......
...@@ -77,19 +77,10 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, ...@@ -77,19 +77,10 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_h = int(float(h) * ratio); int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio); int resize_w = int(float(w) * ratio);
if (resize_h % 32 == 0)
resize_h = resize_h;
else if (resize_h / 32 < 1 + 1e-5)
resize_h = 32;
else
resize_h = (resize_h / 32) * 32;
if (resize_w % 32 == 0) resize_h = max(int(round(float(resize_h) / 32) * 32), 32);
resize_w = resize_w; resize_w = max(int(round(float(resize_w) / 32) * 32), 32);
else if (resize_w / 32 < 1 + 1e-5)
resize_w = 32;
else
resize_w = (resize_w / 32) * 32;
if (!use_tensorrt) { if (!use_tensorrt) {
cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
ratio_h = float(resize_h) / float(h); ratio_h = float(resize_h) / float(h);
......
...@@ -6,6 +6,7 @@ from __future__ import print_function ...@@ -6,6 +6,7 @@ from __future__ import print_function
import os import os
import sys import sys
sys.path.insert(0, ".") sys.path.insert(0, ".")
import copy
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
...@@ -14,6 +15,8 @@ import paddlehub as hub ...@@ -14,6 +15,8 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2 from tools.infer.utility import base64_to_cv2
from tools.infer.predict_cls import TextClassifier from tools.infer.predict_cls import TextClassifier
from tools.infer.utility import parse_args
from deploy.hubserving.ocr_cls.params import read_params
@moduleinfo( @moduleinfo(
...@@ -28,8 +31,7 @@ class OCRCls(hub.Module): ...@@ -28,8 +31,7 @@ class OCRCls(hub.Module):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
from ocr_cls.params import read_params cfg = self.merge_configs()
cfg = read_params()
cfg.use_gpu = use_gpu cfg.use_gpu = use_gpu
if use_gpu: if use_gpu:
...@@ -48,6 +50,20 @@ class OCRCls(hub.Module): ...@@ -48,6 +50,20 @@ class OCRCls(hub.Module):
self.text_classifier = TextClassifier(cfg) self.text_classifier = TextClassifier(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]): def read_images(self, paths=[]):
images = [] images = []
for img_path in paths: for img_path in paths:
......
...@@ -7,6 +7,8 @@ import os ...@@ -7,6 +7,8 @@ import os
import sys import sys
sys.path.insert(0, ".") sys.path.insert(0, ".")
import copy
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
import cv2 import cv2
...@@ -15,6 +17,8 @@ import paddlehub as hub ...@@ -15,6 +17,8 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2 from tools.infer.utility import base64_to_cv2
from tools.infer.predict_det import TextDetector from tools.infer.predict_det import TextDetector
from tools.infer.utility import parse_args
from deploy.hubserving.ocr_system.params import read_params
@moduleinfo( @moduleinfo(
...@@ -29,8 +33,7 @@ class OCRDet(hub.Module): ...@@ -29,8 +33,7 @@ class OCRDet(hub.Module):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
from ocr_det.params import read_params cfg = self.merge_configs()
cfg = read_params()
cfg.use_gpu = use_gpu cfg.use_gpu = use_gpu
if use_gpu: if use_gpu:
...@@ -49,6 +52,20 @@ class OCRDet(hub.Module): ...@@ -49,6 +52,20 @@ class OCRDet(hub.Module):
self.text_detector = TextDetector(cfg) self.text_detector = TextDetector(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]): def read_images(self, paths=[]):
images = [] images = []
for img_path in paths: for img_path in paths:
......
...@@ -22,6 +22,7 @@ def read_params(): ...@@ -22,6 +22,7 @@ def read_params():
cfg.det_db_box_thresh = 0.5 cfg.det_db_box_thresh = 0.5
cfg.det_db_unclip_ratio = 1.6 cfg.det_db_unclip_ratio = 1.6
cfg.use_dilation = False cfg.use_dilation = False
cfg.det_db_score_mode = "fast"
# #EAST parmas # #EAST parmas
# cfg.det_east_score_thresh = 0.8 # cfg.det_east_score_thresh = 0.8
......
...@@ -6,6 +6,7 @@ from __future__ import print_function ...@@ -6,6 +6,7 @@ from __future__ import print_function
import os import os
import sys import sys
sys.path.insert(0, ".") sys.path.insert(0, ".")
import copy
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
...@@ -14,6 +15,8 @@ import paddlehub as hub ...@@ -14,6 +15,8 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2 from tools.infer.utility import base64_to_cv2
from tools.infer.predict_rec import TextRecognizer from tools.infer.predict_rec import TextRecognizer
from tools.infer.utility import parse_args
from deploy.hubserving.ocr_rec.params import read_params
@moduleinfo( @moduleinfo(
...@@ -28,8 +31,7 @@ class OCRRec(hub.Module): ...@@ -28,8 +31,7 @@ class OCRRec(hub.Module):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
from ocr_rec.params import read_params cfg = self.merge_configs()
cfg = read_params()
cfg.use_gpu = use_gpu cfg.use_gpu = use_gpu
if use_gpu: if use_gpu:
...@@ -48,6 +50,20 @@ class OCRRec(hub.Module): ...@@ -48,6 +50,20 @@ class OCRRec(hub.Module):
self.text_recognizer = TextRecognizer(cfg) self.text_recognizer = TextRecognizer(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]): def read_images(self, paths=[]):
images = [] images = []
for img_path in paths: for img_path in paths:
......
...@@ -6,6 +6,7 @@ from __future__ import print_function ...@@ -6,6 +6,7 @@ from __future__ import print_function
import os import os
import sys import sys
sys.path.insert(0, ".") sys.path.insert(0, ".")
import copy
import time import time
...@@ -17,6 +18,8 @@ import paddlehub as hub ...@@ -17,6 +18,8 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2 from tools.infer.utility import base64_to_cv2
from tools.infer.predict_system import TextSystem from tools.infer.predict_system import TextSystem
from tools.infer.utility import parse_args
from deploy.hubserving.ocr_system.params import read_params
@moduleinfo( @moduleinfo(
...@@ -31,8 +34,7 @@ class OCRSystem(hub.Module): ...@@ -31,8 +34,7 @@ class OCRSystem(hub.Module):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
from ocr_system.params import read_params cfg = self.merge_configs()
cfg = read_params()
cfg.use_gpu = use_gpu cfg.use_gpu = use_gpu
if use_gpu: if use_gpu:
...@@ -51,6 +53,20 @@ class OCRSystem(hub.Module): ...@@ -51,6 +53,20 @@ class OCRSystem(hub.Module):
self.text_sys = TextSystem(cfg) self.text_sys = TextSystem(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]): def read_images(self, paths=[]):
images = [] images = []
for img_path in paths: for img_path in paths:
......
...@@ -22,6 +22,7 @@ def read_params(): ...@@ -22,6 +22,7 @@ def read_params():
cfg.det_db_box_thresh = 0.5 cfg.det_db_box_thresh = 0.5
cfg.det_db_unclip_ratio = 1.6 cfg.det_db_unclip_ratio = 1.6
cfg.use_dilation = False cfg.use_dilation = False
cfg.det_db_score_mode = "fast"
#EAST parmas #EAST parmas
cfg.det_east_score_thresh = 0.8 cfg.det_east_score_thresh = 0.8
......
...@@ -83,19 +83,19 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im ...@@ -83,19 +83,19 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。 本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
### 准备数据 ### 准备数据
下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构: 下载解压[totaltext](https://paddleocr.bj.bcebos.com/dataset/total_text.tar) 数据集到PaddleOCR/train_data/目录,数据集组织结构:
``` ```
/PaddleOCR/train_data/total_text/train/ /PaddleOCR/train_data/total_text/train/
|- rgb/ # total_text数据集的训练数据 |- rgb/ # total_text数据集的训练数据
|- gt_0.png |- img11.jpg
| ... | ...
|- total_text.txt # total_text数据集的训练标注 |- train.txt # total_text数据集的训练标注
``` ```
total_text.txt标注文件格式如下,文件名和标注信息中间用"\t"分隔: total_text.txt标注文件格式如下,文件名和标注信息中间用"\t"分隔:
``` ```
" 图像文件名 json.dumps编码的图像标注信息" " 图像文件名 json.dumps编码的图像标注信息"
rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}] rgb/img11.jpg [{"transcription": "ASRAMA", "points": [[214.0, 325.0], [235.0, 308.0], [259.0, 296.0], [286.0, 291.0], [313.0, 295.0], [338.0, 305.0], [362.0, 320.0], [349.0, 347.0], [330.0, 337.0], [310.0, 329.0], [290.0, 324.0], [269.0, 328.0], [249.0, 336.0], [231.0, 346.0]]}, {...}]
``` ```
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。
`transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。** `transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
......
...@@ -76,19 +76,19 @@ The visualized end-to-end results are saved to the `./inference_results` folder ...@@ -76,19 +76,19 @@ The visualized end-to-end results are saved to the `./inference_results` folder
This section takes the totaltext dataset as an example to introduce the training, evaluation and testing of the end-to-end model in PaddleOCR. This section takes the totaltext dataset as an example to introduce the training, evaluation and testing of the end-to-end model in PaddleOCR.
### Data Preparation ### Data Preparation
Download and unzip [totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) dataset to PaddleOCR/train_data/, dataset organization structure is as follow: Download and unzip [totaltext](https://paddleocr.bj.bcebos.com/dataset/total_text.tar) dataset to PaddleOCR/train_data/, dataset organization structure is as follow:
``` ```
/PaddleOCR/train_data/total_text/train/ /PaddleOCR/train_data/total_text/train/
|- rgb/ # total_text training data of dataset |- rgb/ # total_text training data of dataset
|- gt_0.png |- img11.png
| ... | ...
|- total_text.txt # total_text training annotation of dataset |- train.txt # total_text training annotation of dataset
``` ```
total_text.txt: the format of dimension file is as follows,the file name and annotation information are separated by "\t": total_text.txt: the format of dimension file is as follows,the file name and annotation information are separated by "\t":
``` ```
" Image file name Image annotation information encoded by json.dumps" " Image file name Image annotation information encoded by json.dumps"
rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}] rgb/img11.jpg [{"transcription": "ASRAMA", "points": [[214.0, 325.0], [235.0, 308.0], [259.0, 296.0], [286.0, 291.0], [313.0, 295.0], [338.0, 305.0], [362.0, 320.0], [349.0, 347.0], [330.0, 337.0], [310.0, 329.0], [290.0, 324.0], [269.0, 328.0], [249.0, 336.0], [231.0, 346.0]]}, {...}]
``` ```
The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries. The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries.
......
...@@ -193,6 +193,7 @@ def parse_args(mMain=True, add_help=True): ...@@ -193,6 +193,7 @@ def parse_args(mMain=True, add_help=True):
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--use_dilation", type=bool, default=False) parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
...@@ -241,6 +242,7 @@ def parse_args(mMain=True, add_help=True): ...@@ -241,6 +242,7 @@ def parse_args(mMain=True, add_help=True):
det_db_box_thresh=0.5, det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6, det_db_unclip_ratio=1.6,
use_dilation=False, use_dilation=False,
det_db_score_mode="fast",
det_east_score_thresh=0.8, det_east_score_thresh=0.8,
det_east_cover_thresh=0.1, det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2, det_east_nms_thresh=0.2,
......
...@@ -187,7 +187,51 @@ class CTCLabelEncode(BaseRecLabelEncode): ...@@ -187,7 +187,51 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character return dict_character
class E2ELabelEncode(object): class E2ELabelEncodeTest(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN',
use_space_char=False,
**kwargs):
super(E2ELabelEncodeTest,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def __call__(self, data):
import json
padnum = len(self.dict)
label = data['label']
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
box = label[bno]['points']
txt = label[bno]['transcription']
boxes.append(box)
txts.append(txt)
if txt in ['*', '###']:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
data['polys'] = boxes
data['ignore_tags'] = txt_tags
temp_texts = []
for text in txts:
text = text.lower()
text = self.encode(text)
if text is None:
return None
text = text + [padnum] * (self.max_text_len - len(text)
) # use 36 to pad
temp_texts.append(text)
data['texts'] = np.array(temp_texts)
return data
class E2ELabelEncodeTrain(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
pass pass
......
...@@ -72,6 +72,7 @@ class PGDataSet(Dataset): ...@@ -72,6 +72,7 @@ class PGDataSet(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx] file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx] data_line = self.data_lines[file_idx]
img_id = 0
try: try:
data_line = data_line.decode('utf-8') data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter) substr = data_line.strip("\n").split(self.delimiter)
...@@ -79,8 +80,9 @@ class PGDataSet(Dataset): ...@@ -79,8 +80,9 @@ class PGDataSet(Dataset):
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
if self.mode.lower() == 'eval': if self.mode.lower() == 'eval':
try:
img_id = int(data_line.split(".")[0][7:]) img_id = int(data_line.split(".")[0][7:])
else: except:
img_id = 0 img_id = 0
data = {'img_path': img_path, 'label': label, 'img_id': img_id} data = {'img_path': img_path, 'label': label, 'img_id': img_id}
if not os.path.exists(img_path): if not os.path.exists(img_path):
......
...@@ -18,16 +18,18 @@ from __future__ import print_function ...@@ -18,16 +18,18 @@ from __future__ import print_function
__all__ = ['E2EMetric'] __all__ = ['E2EMetric']
from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results
from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
class E2EMetric(object): class E2EMetric(object):
def __init__(self, def __init__(self,
mode,
gt_mat_dir, gt_mat_dir,
character_dict_path, character_dict_path,
main_indicator='f_score_e2e', main_indicator='f_score_e2e',
**kwargs): **kwargs):
self.mode = mode
self.gt_mat_dir = gt_mat_dir self.gt_mat_dir = gt_mat_dir
self.label_list = get_dict(character_dict_path) self.label_list = get_dict(character_dict_path)
self.max_index = len(self.label_list) self.max_index = len(self.label_list)
...@@ -35,12 +37,44 @@ class E2EMetric(object): ...@@ -35,12 +37,44 @@ class E2EMetric(object):
self.reset() self.reset()
def __call__(self, preds, batch, **kwargs): def __call__(self, preds, batch, **kwargs):
img_id = batch[2][0] if self.mode == 'A':
gt_polyons_batch = batch[2]
temp_gt_strs_batch = batch[3][0]
ignore_tags_batch = batch[4]
gt_strs_batch = []
for temp_list in temp_gt_strs_batch:
t = ""
for index in temp_list:
if index < self.max_index:
t += self.label_list[index]
gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip(
[preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': gt_str,
'ignore': ignore_tag
} for gt_polyon, gt_str, ignore_tag in
zip(gt_polyons, gt_strs, ignore_tags)]
# prepare det
e2e_info_list = [{
'points': det_polyon,
'texts': pred_str
} for det_polyon, pred_str in
zip(pred['points'], pred['texts'])]
result = get_socre_A(gt_info_list, e2e_info_list)
self.results.append(result)
else:
img_id = batch[5][0]
e2e_info_list = [{ e2e_info_list = [{
'points': det_polyon, 'points': det_polyon,
'texts': pred_str 'texts': pred_str
} for det_polyon, pred_str in zip(preds['points'], preds['texts'])] } for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
result = get_socre(self.gt_mat_dir, img_id, e2e_info_list) result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list)
self.results.append(result) self.results.append(result)
def get_metric(self): def get_metric(self):
......
...@@ -34,12 +34,18 @@ class DBPostProcess(object): ...@@ -34,12 +34,18 @@ class DBPostProcess(object):
max_candidates=1000, max_candidates=1000,
unclip_ratio=2.0, unclip_ratio=2.0,
use_dilation=False, use_dilation=False,
score_mode="fast",
**kwargs): **kwargs):
self.thresh = thresh self.thresh = thresh
self.box_thresh = box_thresh self.box_thresh = box_thresh
self.max_candidates = max_candidates self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.score_mode = score_mode
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
self.dilation_kernel = None if not use_dilation else np.array( self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]]) [[1, 1], [1, 1]])
...@@ -69,7 +75,10 @@ class DBPostProcess(object): ...@@ -69,7 +75,10 @@ class DBPostProcess(object):
if sside < self.min_size: if sside < self.min_size:
continue continue
points = np.array(points) points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2)) score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score: if self.box_thresh > score:
continue continue
...@@ -120,6 +129,9 @@ class DBPostProcess(object): ...@@ -120,6 +129,9 @@ class DBPostProcess(object):
return box, min(bounding_box[1]) return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box): def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
h, w = bitmap.shape[:2] h, w = bitmap.shape[:2]
box = _box.copy() box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
...@@ -133,6 +145,27 @@ class DBPostProcess(object): ...@@ -133,6 +145,27 @@ class DBPostProcess(object):
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps'] pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor): if isinstance(pred, paddle.Tensor):
......
...@@ -17,7 +17,144 @@ import scipy.io as io ...@@ -17,7 +17,144 @@ import scipy.io as io
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
def get_socre(gt_dir, img_id, pred_dict): def get_socre_A(gt_dir, pred_dict):
allInputs = 1
def input_reading_mod(pred_dict):
"""This helper reads input from txt files"""
det = []
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['texts']
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
def gt_reading_mod(gt_dict):
"""This helper reads groundtruths from mat files"""
gt = []
n = len(gt_dict)
for i in range(n):
points = gt_dict[i]['points'].tolist()
h = len(points)
text = gt_dict[i]['text']
xx = [
np.array(
['x:'], dtype='<U2'), 0, np.array(
['y:'], dtype='<U2'), 0, np.array(
['#'], dtype='<U1'), np.array(
['#'], dtype='<U1')
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
xx[1] = np.array([t_x], dtype='int16')
xx[3] = np.array([t_y], dtype='int16')
if text != "":
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
xx[5] = np.array(['c'], dtype='<U1')
gt.append(xx)
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
if (gt[5] == '#') and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
if det_gt_iou > threshold:
detections[det_id] = []
detections[:] = [item for item in detections if item != []]
return detections
def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(gt_x, gt_y)), 2)
def tau_calculation(det_x, det_y, gt_x, gt_y):
if area(det_x, det_y) == 0.0:
return 0
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(det_x, det_y)), 2)
##############################Initialization###################################
# global_sigma = []
# global_tau = []
# global_pred_str = []
# global_gt_str = []
###############################################################################
for input_id in range(allInputs):
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'):
detections = input_reading_mod(pred_dict)
groundtruths = gt_reading_mod(gt_dir)
detections = detection_filtering(
detections,
groundtruths) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
if groundtruths[i][5] == '#':
dc_id.append(i)
cnt = 0
for a in dc_id:
num = a - cnt
del groundtruths[num]
cnt += 1
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
local_tau_table = np.zeros((len(groundtruths), len(detections)))
local_pred_str = {}
local_gt_str = {}
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
det_y = detection[1::2]
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
det_x, det_y, gt_x, gt_y)
local_tau_table[gt_id, det_id] = tau_calculation(
det_x, det_y, gt_x, gt_y)
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
global_sigma = local_sigma_table
global_tau = local_tau_table
global_pred_str = local_pred_str
global_gt_str = local_gt_str
single_data = {}
single_data['sigma'] = global_sigma
single_data['global_tau'] = global_tau
single_data['global_pred_str'] = global_pred_str
single_data['global_gt_str'] = global_gt_str
return single_data
def get_socre_B(gt_dir, img_id, pred_dict):
allInputs = 1 allInputs = 1
def input_reading_mod(pred_dict): def input_reading_mod(pred_dict):
......
...@@ -39,7 +39,10 @@ class TextDetector(object): ...@@ -39,7 +39,10 @@ class TextDetector(object):
self.args = args self.args = args
self.det_algorithm = args.det_algorithm self.det_algorithm = args.det_algorithm
pre_process_list = [{ pre_process_list = [{
'DetResizeForTest': None 'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type
}
}, { }, {
'NormalizeImage': { 'NormalizeImage': {
'std': [0.229, 0.224, 0.225], 'std': [0.229, 0.224, 0.225],
...@@ -62,6 +65,7 @@ class TextDetector(object): ...@@ -62,6 +65,7 @@ class TextDetector(object):
postprocess_params["max_candidates"] = 1000 postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess' postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh postprocess_params["score_thresh"] = args.det_east_score_thresh
......
...@@ -48,6 +48,7 @@ def parse_args(): ...@@ -48,6 +48,7 @@ def parse_args():
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=bool, default=False) parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
......
...@@ -73,7 +73,14 @@ def main(): ...@@ -73,7 +73,14 @@ def main():
global_config['infer_mode'] = True global_config['infer_mode'] = True
ops = create_operators(transforms, global_config) ops = create_operators(transforms, global_config)
save_res_path = config['Global'].get('save_res_path',
"./output/rec/predicts_rec.txt")
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval() model.eval()
with open(save_res_path, "w") as fout:
for file in get_image_file_list(config['Global']['infer_img']): for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file)) logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f: with open(file, 'rb') as f:
...@@ -102,6 +109,9 @@ def main(): ...@@ -102,6 +109,9 @@ def main():
post_result = post_process_class(preds) post_result = post_process_class(preds)
for rec_reuslt in post_result: for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt)) logger.info('\t result: {}'.format(rec_reuslt))
if len(rec_reuslt) >= 2:
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str(
rec_reuslt[1]) + "\n")
logger.info("success!") logger.info("success!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册