提交 af94699c 编写于 作者: L LDOUBLEV

Merge branch 'release/2.1' of https://github.com/PaddlePaddle/PaddleOCR into fix_hub_21

......@@ -44,7 +44,7 @@ The above pictures are the visualizations of the English recognition model. For
- Scan the QR code below with your Wechat, you can access to official technical exchange group. Look forward to your participation.
<div align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/joinus.PNG" width = "200" height = "200" />
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/dygraph/doc/joinus.PNG" width = "200" height = "200" />
</div>
......
......@@ -8,7 +8,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式
- 静态图版本:develop分支
**近期更新**
- 2021.4.20 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数208个,每周一都会更新,欢迎大家持续关注。
- 2021.4.26 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数213个,每周一都会更新,欢迎大家持续关注。
- PaddleOCR研发团队对最新发版内容技术深入解读,4月13日晚上19:00,[直播地址](https://live.bilibili.com/21689802)
- 2021.4.8 release 2.1版本,新增AAAI 2021论文[端到端识别算法PGNet](./doc/doc_ch/pgnet.md)开源,[多语言模型](./doc/doc_ch/multi_languages.md)支持种类增加到80+。
- 2021.2.8 正式发布PaddleOCRv2.0(branch release/2.0)并设置为推荐用户使用的默认分支. 发布的详细内容,请参考: https://github.com/PaddlePaddle/PaddleOCR/releases/tag/v2.0.0
......@@ -45,7 +45,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式
- 微信扫描二维码加入官方交流群,获得更高效的问题答疑,与各行各业开发者充分交流,期待您的加入。
<div align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/joinus.PNG" width = "200" height = "200" />
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/dygraph/doc/joinus.PNG" width = "200" height = "200" />
</div>
## 快速体验
......@@ -78,7 +78,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式
- 算法介绍
- [文本检测](./doc/doc_ch/algorithm_overview.md)
- [文本识别](./doc/doc_ch/algorithm_overview.md)
- [PP-OCR Pipline](#PP-OCR)
- [PP-OCR Pipeline](#PP-OCR)
- [端到端PGNet算法](./doc/doc_ch/pgnet.md)
- 模型训练/评估
- [文本检测](./doc/doc_ch/detection.md)
......@@ -113,7 +113,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式
<a name="PP-OCR"></a>
## PP-OCR Pipline
## PP-OCR Pipeline
<div align="center">
<img src="./doc/ppocr_framework.png" width="800">
</div>
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: True
save_res_path: ./output/rec/predicts_chinese_common_v2.0.txt
Optimizer:
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: True
save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
Optimizer:
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_ic15.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_none_bilstm_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_none_none_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_tps_bilstm_att.txt
Optimizer:
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_tps_bilstm_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_r34_vd_none_bilstm_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_r34_vd_none_none_ctc.txt
Optimizer:
name: Adam
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
Optimizer:
......
......@@ -19,6 +19,7 @@ Global:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_r34_vd_tps_bilstm_ctc.txt
Optimizer:
name: Adam
......
......@@ -20,6 +20,7 @@ Global:
num_heads: 8
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_srn.txt
Optimizer:
......
......@@ -49,6 +49,8 @@ public:
this->det_db_unclip_ratio = stod(config_map_["det_db_unclip_ratio"]);
this->use_polygon_score = bool(stoi(config_map_["use_polygon_score"]));
this->det_model_dir.assign(config_map_["det_model_dir"]);
this->rec_model_dir.assign(config_map_["rec_model_dir"]);
......@@ -86,6 +88,8 @@ public:
double det_db_unclip_ratio = 2.0;
bool use_polygon_score = false;
std::string det_model_dir;
std::string rec_model_dir;
......
......@@ -44,7 +44,8 @@ public:
const bool &use_mkldnn, const int &max_side_len,
const double &det_db_thresh,
const double &det_db_box_thresh,
const double &det_db_unclip_ratio, const bool &visualize,
const double &det_db_unclip_ratio,
const bool &use_polygon_score, const bool &visualize,
const bool &use_tensorrt, const bool &use_fp16) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
......@@ -57,6 +58,7 @@ public:
this->det_db_thresh_ = det_db_thresh;
this->det_db_box_thresh_ = det_db_box_thresh;
this->det_db_unclip_ratio_ = det_db_unclip_ratio;
this->use_polygon_score_ = use_polygon_score;
this->visualize_ = visualize;
this->use_tensorrt_ = use_tensorrt;
......@@ -85,6 +87,7 @@ private:
double det_db_thresh_ = 0.3;
double det_db_box_thresh_ = 0.5;
double det_db_unclip_ratio_ = 2.0;
bool use_polygon_score_ = false;
bool visualize_ = true;
bool use_tensorrt_ = false;
......
......@@ -55,7 +55,8 @@ public:
std::vector<std::vector<std::vector<int>>>
BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
const float &box_thresh, const float &det_db_unclip_ratio);
const float &box_thresh, const float &det_db_unclip_ratio,
const bool &use_polygon_score);
std::vector<std::vector<std::vector<int>>>
FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
......
......@@ -183,7 +183,7 @@ cmake .. \
make -j
```
`OPENCV_DIR`为opencv编译安装的地址;`LIB_DIR`为下载(`paddle_inference`文件夹)或者编译生成的Paddle预测库地址(`build/paddle_inference_install_dir`文件夹);`CUDA_LIB_DIR`为cuda库文件地址,在docker中`/usr/local/cuda/lib64``CUDNN_LIB_DIR`为cudnn库文件地址,在docker中为`/usr/lib/x86_64-linux-gnu/`
`OPENCV_DIR`为opencv编译安装的地址;`LIB_DIR`为下载(`paddle_inference`文件夹)或者编译生成的Paddle预测库地址(`build/paddle_inference_install_dir`文件夹);`CUDA_LIB_DIR`为cuda库文件地址,在docker中为`/usr/local/cuda/lib64``CUDNN_LIB_DIR`为cudnn库文件地址,在docker中为`/usr/lib/x86_64-linux-gnu/`
* 编译完成之后,会在`build`文件夹下生成一个名为`ocr_system`的可执行文件。
......@@ -211,6 +211,7 @@ max_side_len 960 # 输入图像长宽大于960时,等比例缩放图像,使
det_db_thresh 0.3 # 用于过滤DB预测的二值化图像,设置为0.-0.3对结果影响不明显
det_db_box_thresh 0.5 # DB后处理过滤box的阈值,如果检测存在漏框情况,可酌情减小
det_db_unclip_ratio 1.6 # 表示文本框的紧致程度,越小则文本框更靠近文本
use_polygon_score 1 # 是否使用多边形框计算bbox score,0表示使用矩形框计算。矩形框计算速度更快,多边形框对弯曲文本区域计算更准确。
det_model_dir ./inference/det_db # 检测模型inference model地址
# cls config
......
......@@ -219,6 +219,7 @@ max_side_len 960 # Limit the maximum image height and width to 960
det_db_thresh 0.3 # Used to filter the binarized image of DB prediction, setting 0.-0.3 has no obvious effect on the result
det_db_box_thresh 0.5 # DDB post-processing filter box threshold, if there is a missing box detected, it can be reduced as appropriate
det_db_unclip_ratio 1.6 # Indicates the compactness of the text box, the smaller the value, the closer the text box to the text
use_polygon_score 1 # Whether to use polygon box to calculate bbox score, 0 means to use rectangle box to calculate. Use rectangular box to calculate faster, and polygonal box more accurate for curved text area.
det_model_dir ./inference/det_db # Address of detection inference model
# cls config
......
......@@ -59,7 +59,8 @@ int main(int argc, char **argv) {
config.gpu_mem, config.cpu_math_library_num_threads,
config.use_mkldnn, config.max_side_len, config.det_db_thresh,
config.det_db_box_thresh, config.det_db_unclip_ratio,
config.visualize, config.use_tensorrt, config.use_fp16);
config.use_polygon_score, config.visualize,
config.use_tensorrt, config.use_fp16);
Classifier *cls = nullptr;
if (config.use_angle_cls == true) {
......
......@@ -109,9 +109,9 @@ void DBDetector::Run(cv::Mat &img,
cv::Mat dilation_map;
cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, dilation_map, dila_ele);
boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map,
this->det_db_box_thresh_,
this->det_db_unclip_ratio_);
boxes = post_processor_.BoxesFromBitmap(
pred_map, dilation_map, this->det_db_box_thresh_,
this->det_db_unclip_ratio_, this->use_polygon_score_);
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
......
......@@ -160,35 +160,49 @@ std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
}
float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
cv::Mat pred){
cv::Mat pred) {
int width = pred.cols;
int height = pred.rows;
std::vector<float> box_x;
std::vector<float> box_y;
for(int i=0; i<contour.size(); ++i){
for (int i = 0; i < contour.size(); ++i) {
box_x.push_back(contour[i].x);
box_y.push_back(contour[i].y);
}
int xmin = clamp(int(std::floor(*(std::min_element(box_x.begin(), box_x.end())))), 0, width - 1);
int xmax = clamp(int(std::ceil(*(std::max_element(box_x.begin(), box_x.end())))), 0, width - 1);
int ymin = clamp(int(std::floor(*(std::min_element(box_y.begin(), box_y.end())))), 0, height - 1);
int ymax = clamp(int(std::ceil(*(std::max_element(box_y.begin(), box_y.end())))), 0, height - 1);
int xmin =
clamp(int(std::floor(*(std::min_element(box_x.begin(), box_x.end())))), 0,
width - 1);
int xmax =
clamp(int(std::ceil(*(std::max_element(box_x.begin(), box_x.end())))), 0,
width - 1);
int ymin =
clamp(int(std::floor(*(std::min_element(box_y.begin(), box_y.end())))), 0,
height - 1);
int ymax =
clamp(int(std::ceil(*(std::max_element(box_y.begin(), box_y.end())))), 0,
height - 1);
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point rook_point[contour.size()];
for(int i=0; i<contour.size(); ++i){
cv::Point* rook_point = new cv::Point[contour.size()];
for (int i = 0; i < contour.size(); ++i) {
rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
}
const cv::Point *ppt[1] = {rook_point};
int npt[] = {int(contour.size())};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1)).copyTo(croppedImg);
float score = cv::mean(croppedImg, mask)[0];
delete []rook_point;
return score;
}
......@@ -230,10 +244,9 @@ float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
return score;
}
std::vector<std::vector<std::vector<int>>>
PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
const float &box_thresh,
const float &det_db_unclip_ratio) {
std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
const float &det_db_unclip_ratio, const bool &use_polygon_score) {
const int min_size = 3;
const int max_candidates = 1000;
......@@ -267,9 +280,12 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
}
float score;
score = BoxScoreFast(array, pred);
/* compute using polygon*/
// score = PolygonScoreAcc(contours[_i], pred);
if (use_polygon_score)
/* compute using polygon*/
score = PolygonScoreAcc(contours[_i], pred);
else
score = BoxScoreFast(array, pred);
if (score < box_thresh)
continue;
......
......@@ -77,19 +77,10 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio);
if (resize_h % 32 == 0)
resize_h = resize_h;
else if (resize_h / 32 < 1 + 1e-5)
resize_h = 32;
else
resize_h = (resize_h / 32) * 32;
if (resize_w % 32 == 0)
resize_w = resize_w;
else if (resize_w / 32 < 1 + 1e-5)
resize_w = 32;
else
resize_w = (resize_w / 32) * 32;
resize_h = max(int(round(float(resize_h) / 32) * 32), 32);
resize_w = max(int(round(float(resize_w) / 32) * 32), 32);
if (!use_tensorrt) {
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
ratio_h = float(resize_h) / float(h);
......
......@@ -10,6 +10,7 @@ max_side_len 960
det_db_thresh 0.3
det_db_box_thresh 0.5
det_db_unclip_ratio 1.6
use_polygon_score 1
det_model_dir ./inference/ch_ppocr_mobile_v2.0_det_infer/
# cls config
......
......@@ -16,6 +16,7 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2
from tools.infer.predict_cls import TextClassifier
from tools.infer.utility import parse_args
from deploy.hubserving.ocr_cls.params import read_params
@moduleinfo(
......@@ -55,7 +56,6 @@ class OCRCls(hub.Module):
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
......
......@@ -9,41 +9,35 @@
## PaddleOCR常见问题汇总(持续更新)
* [近期更新(2021.4.20](#近期更新)
* [近期更新(2021.4.26](#近期更新)
* [【精选】OCR精选10个问题](#OCR精选10个问题)
* [【理论篇】OCR通用43个问题](#OCR通用问题)
* [基础知识13题](#基础知识)
* [数据集9题](#数据集2)
* [模型训练调优21题](#模型训练调优2)
* [【实战篇】PaddleOCR实战150个问题](#PaddleOCR实战问题)
* [使用咨询61](#使用咨询)
* [【实战篇】PaddleOCR实战160个问题](#PaddleOCR实战问题)
* [使用咨询63](#使用咨询)
* [数据集18题](#数据集3)
* [模型训练调优34](#模型训练调优3)
* [预测部署42](#预测部署3)
* [模型训练调优35](#模型训练调优3)
* [预测部署44](#预测部署3)
<a name="近期更新"></a>
## 近期更新(2021.4.20
## 近期更新(2021.4.26
#### Q3.1.58: 使用PGNet进行eval报错?
**A**: 需要注意,我们目前在release/2.1更新了评测代码,目前支持A,B两种评测模式:
* A模式:该模式主要为了方便用户使用,与训练集一样的标注文件就可以正常进行eval操作, 代码中默认是A模式。
* B模式:该模式主要为了保证我们的评测代码可以和Total Text官方的评测方式对齐,该模式下直接加载官方提供的mat文件进行eval。
#### Q3.1.62: 弯曲文本(如略微形变的文档图像)漏检问题
**A**: db后处理中计算文本框平均得分时,是求rectangle区域的平均分数,容易造成弯曲文本漏检,已新增求polygon区域的平均分数,会更准确,但速度有所降低,可按需选择,在相关pr中可查看[可视化对比效果](https://github.com/PaddlePaddle/PaddleOCR/pull/2604)。该功能通过参数 [det_db_score_mode](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/tools/infer/utility.py#L51)进行选择,参数值可选[`fast`(默认)、`slow`],`fast`对应原始的rectangle方式,`slow`对应polygon方式。感谢用户[buptlihang](https://github.com/buptlihang)[pr](https://github.com/PaddlePaddle/PaddleOCR/pull/2574)帮助解决该问题🌹。
#### Q3.1.59: 使用预训练模型进行预测,对于特定字符识别识别效果较差,怎么解决
**A**: 由于我们所提供的识别模型是基于通用大规模数据集进行训练的,部分字符可能在训练集中包含较少,因此您可以构建特定场景的数据集,基于我们提供的预训练模型进行微调。建议用于微调的数据集中,每个字符出现的样本数量不低于300,但同时需要注意不同字符的数量均衡。具体可以参考:[微调](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/recognition.md#2-%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83)
#### Q3.1.63: 请问端到端的pgnet相比于DB+CRNN在准确率上有优势吗?或者是pgnet最擅长的场景是什么场景呢
**A**: pgnet是端到端算法,检测识别一步到位,不用分开训练2个模型,也支持弯曲文本的识别,但是在中文上的效果还没有充分验证;db+crnn的验证更充分,应用相对成熟,常规非弯曲的文本都能解的不错。
#### Q3.1.60: PGNet有中文预训练模型吗?
**A**: 目前我们尚未提供针对中文的预训练模型,如有需要,可以尝试自己训练。具体需要修改的地方有:
1. [config文件中](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/configs/e2e/e2e_r50_vd_pg.yml#L23-L24),字典文件路径及语种设置;
1. [网络结构中](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/ppocr/modeling/heads/e2e_pg_head.py#L181)`out_channels`修改为字典中的字符数目+1(考虑到空格);
1. [loss中](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/ppocr/losses/e2e_pg_loss.py#L93),修改`37`为字典中的字符数目+1(考虑到空格);
#### Q3.3.35: SRN训练不收敛(loss不降)或SRN训练acc一直为0。
**A**: 如果loss下降不正常,需要确认没有修改yml文件中的image_shape,默认[1, 64, 256],代码中针对这个配置写死了,修改可能会造成无法收敛。如果确认参数无误,loss正常下降,可以多迭代一段时间观察下,开始acc为0是正常的。
#### Q3.1.61: 用于PGNet的训练集,文本框的标注有要求吗?
**A**: PGNet支持多点标注,比如4点、8点、14点等。但需要注意的是,标注点尽可能分布均匀(相邻标注点间隔距离均匀一致),且label文件中的标注点需要从标注框的左上角开始,按标注点顺时针顺序依次编写,以上问题都可能对训练精度造成影响。
我们提供的,基于Total Text数据集的PGNet预训练模型使用了14点标注方式。
#### Q3.4.43: 预测时显存爆炸、内存泄漏问题?
**A**: 打开显存/内存优化开关`enable_memory_optim`可以解决该问题,相关代码已合入,[查看详情](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/tools/infer/utility.py#L153)
#### Q3.4.42: 在使用PaddleLite进行预测部署时,启动预测后卡死/手机死机
**A**: 请检查模型转换时所用PaddleLite的版本,和预测库的版本是否对齐。即PaddleLite版本为2.8,则预测库版本也要为2.8
#### Q3.4.44: 如何多进程预测
**A**: 近期PaddleOCR新增了[多进程预测控制参数](https://github.com/PaddlePaddle/PaddleOCR/blob/a312647be716776c1aac33ff939ae358a39e8188/tools/infer/utility.py#L103)`use_mp`表示是否使用多进程,`total_process_num`表示在使用多进程时的进程数。具体使用方式请参考[文档](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/inference.md#1-%E8%B6%85%E8%BD%BB%E9%87%8F%E4%B8%AD%E6%96%87ocr%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86)
<a name="OCR精选10个问题"></a>
## 【精选】OCR精选10个问题
......@@ -638,6 +632,11 @@ repo中config.yml文件的前后处理参数和inference预测默认的超参数
**A**: PGNet支持多点标注,比如4点、8点、14点等。但需要注意的是,标注点尽可能分布均匀(相邻标注点间隔距离均匀一致),且label文件中的标注点需要从标注框的左上角开始,按标注点顺时针顺序依次编写,以上问题都可能对训练精度造成影响。
我们提供的,基于Total Text数据集的PGNet预训练模型使用了14点标注方式。
#### Q3.1.62: 弯曲文本(如略微形变的文档图像)漏检问题
**A**: db后处理中计算文本框平均得分时,是求rectangle区域的平均分数,容易造成弯曲文本漏检,已新增求polygon区域的平均分数,会更准确,但速度有所降低,可按需选择,在相关pr中可查看[可视化对比效果](https://github.com/PaddlePaddle/PaddleOCR/pull/2604)。该功能通过参数 [det_db_score_mode](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/tools/infer/utility.py#L51)进行选择,参数值可选[`fast`(默认)、`slow`],`fast`对应原始的rectangle方式,`slow`对应polygon方式。感谢用户[buptlihang](https://github.com/buptlihang)[pr](https://github.com/PaddlePaddle/PaddleOCR/pull/2574)帮助解决该问题🌹。
#### Q3.1.63: 请问端到端的pgnet相比于DB+CRNN在准确率上有优势吗?或者是pgnet最擅长的场景是什么场景呢?
**A**: pgnet是端到端算法,检测识别一步到位,不用分开训练2个模型,也支持弯曲文本的识别,但是在中文上的效果还没有充分验证;db+crnn的验证更充分,应用相对成熟,常规非弯曲的文本都能解的不错。
<a name="数据集3"></a>
......@@ -911,8 +910,10 @@ lr:
#### Q3.3.34: 表格识别中,如何提高单字的识别结果?
**A**: 首先需要确认一下检测模型有没有有效的检测出单个字符,如果没有的话,需要在训练集当中添加相应的单字数据集。
<a name="预测部署3"></a>
#### Q3.3.35: SRN训练不收敛(loss不降)或SRN训练acc一直为0。
**A**: 如果loss下降不正常,需要确认没有修改yml文件中的image_shape,默认[1, 64, 256],代码中针对这个配置写死了,修改可能会造成无法收敛。如果确认参数无误,loss正常下降,可以多迭代一段时间观察下,开始acc为0是正常的。
<a name="预测部署3"></a>
### 预测部署
......@@ -956,10 +957,6 @@ lr:
**A**:在安卓APK上无法设置,没有暴露这个接口,如果使用的是PaddledOCR/deploy/lite/的demo,可以修改config.txt中的对应参数来设置
#### Q3.4.9:PaddleOCR模型是否可以转换成ONNX模型?
**A**:目前暂不支持转ONNX,相关工作在研发中。
#### Q3.4.10:使用opt工具对检测模型转换时报错 can not found op arguments for node conv2_b_attr
**A**:这个问题大概率是编译opt工具的Paddle-Lite不是develop分支,建议使用Paddle-Lite 的develop分支编译opt工具。
......@@ -1114,3 +1111,9 @@ nvidia-smi --lock-gpu-clocks=1590 -i 0
#### Q3.4.42: 在使用PaddleLite进行预测部署时,启动预测后卡死/手机死机?
**A**: 请检查模型转换时所用PaddleLite的版本,和预测库的版本是否对齐。即PaddleLite版本为2.8,则预测库版本也要为2.8。
#### Q3.4.43: 预测时显存爆炸、内存泄漏问题?
**A**: 打开显存/内存优化开关`enable_memory_optim`可以解决该问题,相关代码已合入,[查看详情](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/tools/infer/utility.py#L153)
#### Q3.4.44: 如何多进程预测?
**A**: 近期PaddleOCR新增了[多进程预测控制参数](https://github.com/PaddlePaddle/PaddleOCR/blob/a312647be716776c1aac33ff939ae358a39e8188/tools/infer/utility.py#L103)`use_mp`表示是否使用多进程,`total_process_num`表示在使用多进程时的进程数。具体使用方式请参考[文档](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/inference.md#1-%E8%B6%85%E8%BD%BB%E9%87%8F%E4%B8%AD%E6%96%87ocr%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86)
......@@ -47,7 +47,7 @@ PaddleOCR 旨在打造一套丰富、领先、且实用的OCR工具库,不仅
pip install paddlepaddle
# gpu
pip instll paddlepaddle-gpu
pip install paddlepaddle-gpu
```
<a name="paddleocr_package_安装"></a>
......@@ -179,11 +179,11 @@ ppocr 支持使用自己的数据进行自定义训练或finetune, 其中识别
## 4 预测部署
除了安装whl包进行快速预测,ppocr 也提供了多种预测部署方式,如有需求可阅读相关文档:
- [基于Python脚本预测引擎推理](./doc/doc_ch/inference.md)
- [基于C++预测引擎推理](./deploy/cpp_infer/readme.md)
- [服务化部署](./deploy/hubserving/readme.md)
- [基于Python脚本预测引擎推理](./inference.md)
- [基于C++预测引擎推理](../../deploy/cpp_infer/readme.md)
- [服务化部署](../../deploy/hubserving/readme.md)
- [端侧部署](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/lite/readme.md)
- [Benchmark](./doc/doc_ch/benchmark.md)
- [Benchmark](./benchmark.md)
......
......@@ -48,7 +48,7 @@ This document will briefly introduce how to use the multilingual model.
pip install paddlepaddle
# gpu
pip instll paddlepaddle-gpu
pip install paddlepaddle-gpu
```
<a name="paddleocr_package_install"></a>
......@@ -181,11 +181,11 @@ In addition to installing the whl package for quick forecasting,
ppocr also provides a variety of forecasting deployment methods.
If necessary, you can read related documents:
- [Python Inference](./doc/doc_en/inference_en.md)
- [C++ Inference](./deploy/cpp_infer/readme_en.md)
- [Serving](./deploy/hubserving/readme_en.md)
- [Python Inference](./inference_en.md)
- [C++ Inference](../../deploy/cpp_infer/readme_en.md)
- [Serving](../../deploy/hubserving/readme_en.md)
- [Mobile](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/lite/readme_en.md)
- [Benchmark](./doc/doc_en/benchmark_en.md)
- [Benchmark](./benchmark_en.md)
<a name="language_abbreviations"></a>
......
......@@ -249,7 +249,7 @@ class ResNet(nn.Layer):
name=conv_name))
shortcut = True
self.block_list.append(bottleneck_block)
self.out_channels = num_filters[block]
self.out_channels = num_filters[block] * 4
else:
for block in range(len(depth)):
shortcut = False
......
......@@ -285,8 +285,7 @@ class PrePostProcessLayer(nn.Layer):
elif cmd == "n": # add layer normalization
self.functors.append(
self.add_sublayer(
"layer_norm_%d" % len(
self.sublayers(include_sublayers=False)),
"layer_norm_%d" % len(self.sublayers()),
paddle.nn.LayerNorm(
normalized_shape=d_model,
weight_attr=fluid.ParamAttr(
......@@ -320,9 +319,7 @@ class PrepareEncoder(nn.Layer):
self.src_emb_dim = src_emb_dim
self.src_max_len = src_max_len
self.emb = paddle.nn.Embedding(
num_embeddings=self.src_max_len,
embedding_dim=self.src_emb_dim,
sparse=True)
num_embeddings=self.src_max_len, embedding_dim=self.src_emb_dim)
self.dropout_rate = dropout_rate
def forward(self, src_word, src_pos):
......
......@@ -39,7 +39,10 @@ class TextDetector(object):
self.args = args
self.det_algorithm = args.det_algorithm
pre_process_list = [{
'DetResizeForTest': None
'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
......@@ -160,7 +163,6 @@ class TextDetector(object):
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
outputs = []
......
......@@ -73,35 +73,45 @@ def main():
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
save_res_path = config['Global'].get('save_res_path',
"./output/rec/predicts_rec.txt")
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
if config['Architecture']['algorithm'] == "SRN":
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
others = [
paddle.to_tensor(encoder_word_pos_list),
paddle.to_tensor(gsrm_word_pos_list),
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
else:
preds = model(images)
post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))
with open(save_res_path, "w") as fout:
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
if config['Architecture']['algorithm'] == "SRN":
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
others = [
paddle.to_tensor(encoder_word_pos_list),
paddle.to_tensor(gsrm_word_pos_list),
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
else:
preds = model(images)
post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))
if len(rec_reuslt) >= 2:
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str(
rec_reuslt[1]) + "\n")
logger.info("success!")
......
......@@ -18,6 +18,7 @@ from __future__ import print_function
import os
import sys
import platform
import yaml
import time
import shutil
......@@ -196,9 +197,11 @@ def train(config,
train_reader_cost = 0.0
batch_sum = 0
batch_start = time.time()
max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader):
train_reader_cost += time.time() - batch_start
if idx >= len(train_dataloader):
if idx >= max_iter:
break
lr = optimizer.get_lr()
images = batch[0]
......@@ -335,8 +338,10 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
total_frame = 0.0
total_time = 0.0
pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
max_iter = len(valid_dataloader) - 1 if platform.system(
) == "Windows" else len(valid_dataloader)
for idx, batch in enumerate(valid_dataloader):
if idx >= len(valid_dataloader):
if idx >= max_iter:
break
images = batch[0]
start = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册