diff --git a/doc/doc_ch/enhanced_ctc_loss.md b/doc/doc_ch/enhanced_ctc_loss.md
new file mode 100644
index 0000000000000000000000000000000000000000..309dc712dc0242b859f934338be96e6648f81031
--- /dev/null
+++ b/doc/doc_ch/enhanced_ctc_loss.md
@@ -0,0 +1,78 @@
+# Enhanced CTC Loss
+
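+In OCR text recognition, CRNN is a recognition algorithm widely used in industry. During training it uses CTCLoss to compute the network loss; during inference it uses CTCDecode to obtain the decoded result. Although CRNN has proven to deliver good recognition results in real applications, users' demands on recognition accuracy keep rising, so how can accuracy be improved further? Taking CTCLoss as the starting point, this document explores improvements to CTCLoss from three angles: hard example mining, multi-task learning, and metric learning. The result is EnhancedCTCLoss, which consists of three components: Focal-CTC Loss, A-CTC Loss, and C-CTC Loss.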
+
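+## 1. Focal-CTC Loss
+Focal Loss comes from the paper "Focal Loss for Dense Object Detection". It was originally proposed to address the severe imbalance between positive and negative samples in one-stage object detection. The loss reduces the weight that the large number of easy negative samples carry during training, and can also be understood as a form of hard example mining.
+Its form is as follows:
+
+![Focal Loss formula](./focal_loss_formula.png)
+
+
+Here y' is the output of the activation function, with a value between 0 and 1. Focal Loss adds a modulating factor (1 - y')^γ and a balancing factor α to the original cross-entropy loss. With α = 1 and y = 1, the comparison between this loss and the cross-entropy loss is shown in the figure below:
+
+![Focal Loss compared with cross-entropy loss](./focal_loss_image.png)
+
+
+As the figure shows, when γ > 0 the modulating factor (1 - y')^γ assigns a smaller weight to the loss of easily classified samples, so the network focuses more on hard, misclassified samples. The parameter γ controls how quickly the weight of easy samples is reduced: when γ is 0 the loss is exactly the cross-entropy loss, and as γ increases the effect of the modulating factor grows. Experiments found γ = 2 to work best. The balancing factor α offsets the imbalance between positive and negative samples; the paper uses α = 0.25.
+
+For the classic CTC algorithm, suppose that for a feature sequence (f1, f2, ..., ft) the probability that the CTC-decoded result equals the label is y'; the probability that it does not equal the label is then (1 - y'). It is easy to see that the CTCLoss value and y' are related as follows:
+
+![Relation between CTCLoss and y'](./equation_ctcloss.png)
+
+
+Following the idea of Focal Loss, giving hard samples a larger weight and easy samples a smaller one lets the network concentrate on mining hard samples and further improves recognition accuracy. On this basis we propose Focal-CTC Loss, defined as follows:
+
+![Focal-CTC Loss definition](./equation_focal_ctc.png)
+
+
+In our experiments γ = 2 and α = 1. For the implementation, see [rec_ctc_loss.py](../../ppocr/losses/rec_ctc_loss.py)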
+
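+## 2. A-CTC Loss
+A-CTC Loss is short for CTC Loss + ACE Loss. ACE Loss comes from the paper "Aggregation Cross-Entropy for Sequence Recognition". Compared with CTCLoss, ACE Loss has two main advantages:
++ ACE Loss can handle recognition of 2-D text, whereas CTCLoss can only handle 1-D text
++ ACE Loss has lower time and space complexity than CTC Loss
+
+A summary from prior work of the strengths and weaknesses of OCR recognition algorithms is shown in the figure below:
+
+![Comparison of OCR recognition algorithms](./rec_algo_compare.png)
+
+
+Although ACELoss can indeed handle 2-D predictions and has advantages in memory footprint and inference speed, as the figure indicates, in practice we found that using ACE Loss alone gives worse recognition results than CTCLoss. We therefore combined the two, keeping CTCLoss as the primary loss and treating ACELoss as an auxiliary supervision loss. This worked: on our internal experimental dataset, recognition accuracy improved by about 1% compared with using CTCLoss alone.
+A-CTC Loss is defined as follows:
+
+![A-CTC Loss definition](./equation_a_ctc.png)
+
+
+In our experiments λ = 0.1. For the ACE Loss implementation, see [ace_loss.py](../../ppocr/losses/ace_loss.py)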
+
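+## 3. C-CTC Loss
+C-CTC Loss is short for CTC Loss + Center Loss. Center Loss comes from the paper "A Discriminative Feature Learning Approach for Deep Face Recognition". It was first used for face recognition to increase inter-class distance and reduce intra-class distance, and is one of the earlier and more commonly used methods in metric learning.
+In Chinese OCR, analysis of bad cases showed that a major difficulty is the large number of similar-looking characters, which are easily confused. This led us to ask whether ideas from metric learning could be used to increase the inter-class distance between similar characters and so improve recognition accuracy. However, metric learning is mainly used in image recognition, where each training sample has a single fixed label; OCR recognition is essentially a sequence recognition task in which features and labels have no explicit alignment, so how to combine the two remains a direction worth exploring.
+After trying methods such as ArcMargin and CosMargin, we found that Center Loss helps to further improve recognition accuracy. C-CTC Loss is defined as follows:
+
+![C-CTC Loss definition](./equation_c_ctc.png)
+
+
+In our experiments we set λ = 0.25. For the Center Loss implementation, see [center_loss.py](../../ppocr/losses/center_loss.py)
+
+Note that in C-CTC Loss, randomly initializing the centers does not bring a noticeable improvement. We initialize the centers as follows:
++ Train a network N with the original CTCLoss
++ Select the training samples that are recognized completely correctly to form a set G
++ Feed each sample in G through the network in a forward pass, and record the correspondence between the input to the last FC layer (the feature) and the argmax of its output (the index)
++ Group the features by index and average each group to obtain the initial center of the corresponding character
+
+Taking the configuration file `configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml` as an example, the command for extracting the centers is:
+```
+python tools/export_center.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml -o Global.pretrained_model="./output/rec_mobile_pp-OCRv2/best_accuracy"
+```
+After it finishes, `train_center.pkl` is generated in the PaddleOCR root directory.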
+
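+## 4. Experiments
+We trained and evaluated the three schemes above on an internal Baidu dataset. The results are shown in the table below:
+| algorithm | Focal-CTC | A-CTC | C-CTC |
+|:------|:------:|:------:|:------:|
+| accuracy gain | +0.3% | +0.7% | +1.7% |
+
+Based on these results, PP-OCRv2 adopts the C-CTC strategy. Note that PP-OCRv2 handles a recognition task over 6625 Chinese characters, a large character set with many similar-looking characters, so C-CTC brings a comparatively large gain on this task. On other OCR recognition tasks the conclusion may differ. You can try Focal-CTC, A-CTC, C-CTC, and the combined EnhancedCTC scheme, which may bring different degrees of improvement.
+The unified implementation is in [rec_enhanced_ctc_loss.py](../../ppocr/losses/rec_enhanced_ctc_loss.py)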
diff --git a/doc/doc_ch/equation_a_ctc.png b/doc/doc_ch/equation_a_ctc.png
new file mode 100644
index 0000000000000000000000000000000000000000..ae097610d37a88e76edefdbeb81df8403e94215f
Binary files /dev/null and b/doc/doc_ch/equation_a_ctc.png differ
diff --git a/doc/doc_ch/equation_c_ctc.png b/doc/doc_ch/equation_c_ctc.png
new file mode 100644
index 0000000000000000000000000000000000000000..67207a9937481f4920af3cbafbe1bfe8d27ee5dc
Binary files /dev/null and b/doc/doc_ch/equation_c_ctc.png differ
diff --git a/doc/doc_ch/equation_ctcloss.png b/doc/doc_ch/equation_ctcloss.png
new file mode 100644
index 0000000000000000000000000000000000000000..33ad92c9e4567d2a4a0c8fc3b2a0bf3fba5ea8f2
Binary files /dev/null and b/doc/doc_ch/equation_ctcloss.png differ
diff --git a/doc/doc_ch/equation_focal_ctc.png b/doc/doc_ch/equation_focal_ctc.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ba1e8715d5876705ef429e48b5c94388fd41398
Binary files /dev/null and b/doc/doc_ch/equation_focal_ctc.png differ
diff --git a/doc/doc_ch/focal_loss_formula.png b/doc/doc_ch/focal_loss_formula.png
new file mode 100644
index 0000000000000000000000000000000000000000..971cebcd082cf5e19f9246f02216c0c14896bdc9
Binary files /dev/null and b/doc/doc_ch/focal_loss_formula.png differ
diff --git a/doc/doc_ch/focal_loss_image.png b/doc/doc_ch/focal_loss_image.png
new file mode 100644
index 0000000000000000000000000000000000000000..430550a732d4e2769151771bc85ae889dfc78fda
Binary files /dev/null and b/doc/doc_ch/focal_loss_image.png differ
diff --git a/doc/doc_ch/rec_algo_compare.png b/doc/doc_ch/rec_algo_compare.png
new file mode 100644
index 0000000000000000000000000000000000000000..2dde496c75f327ca1c0c9ccb0dbe6949215a4a1b
Binary files /dev/null and b/doc/doc_ch/rec_algo_compare.png differ
diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py
index 5d09802b46d7ddfa802461760b917267155b3923..063d68e30861e092e10fa3068e4b7f4755b6197f 100755
--- a/ppocr/losses/rec_ctc_loss.py
+++ b/ppocr/losses/rec_ctc_loss.py
@@ -38,7 +38,7 @@ class CTCLoss(nn.Layer):
         if self.use_focal_loss:
             weight = paddle.exp(-loss)
             weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
-            weight = paddle.square(weight) * self.focal_loss_alpha
+            weight = paddle.square(weight)
             loss = paddle.multiply(loss, weight)
-        loss = loss.mean()  # sum
+        loss = loss.mean()
         return {'loss': loss}
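Review note on the focal branch in this hunk: each per-sample CTC loss is multiplied by `(1 - exp(-loss))^2`. If the CTC loss is taken to be `-ln(y')`, where `y'` is the probability that decoding yields the label, then `exp(-loss)` recovers `y'`. The sketch below reproduces that weighting in plain Python under this assumption. The helper name `focal_ctc` and its `gamma` parameter are hypothetical; the patch itself hard-codes the square.

```python
import math


def focal_ctc(ctc_loss, gamma=2.0):
    """Re-weight one per-sample CTC loss (hypothetical helper, not in the patch).

    Assumes ctc_loss = -ln(y'), where y' is the probability that CTC
    decoding yields the label, so exp(-ctc_loss) recovers y'.
    """
    y_prime = math.exp(-ctc_loss)
    weight = (1.0 - y_prime) ** gamma
    return weight * ctc_loss


# An easy sample (y' = 0.9) is down-weighted far more than a hard one (y' = 0.1).
easy = -math.log(0.9)
hard = -math.log(0.1)
print(focal_ctc(easy) / easy)  # weight applied to the easy sample
print(focal_ctc(hard) / hard)  # weight applied to the hard sample
```

With `gamma=0` the weight is 1 for any positive loss, so the plain CTC loss is returned unchanged.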
diff --git a/ppocr/losses/rec_enhanced_ctc_loss.py b/ppocr/losses/rec_enhanced_ctc_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b57be6468e2ec75811442e7449525267e7d9e82e
--- /dev/null
+++ b/ppocr/losses/rec_enhanced_ctc_loss.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from .ace_loss import ACELoss
+from .center_loss import CenterLoss
+from .rec_ctc_loss import CTCLoss
+
+
+class EnhancedCTCLoss(nn.Layer):
+    def __init__(self,
+                 use_focal_loss=False,
+                 use_ace_loss=False,
+                 ace_loss_weight=0.1,
+                 use_center_loss=False,
+                 center_loss_weight=0.05,
+                 num_classes=6625,
+                 feat_dim=96,
+                 init_center=False,
+                 center_file_path=None,
+                 **kwargs):
+        super(EnhancedCTCLoss, self).__init__()
+        self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
+
+        self.use_ace_loss = False
+        if use_ace_loss:
+            self.use_ace_loss = use_ace_loss
+            self.ace_loss_func = ACELoss()
+            self.ace_loss_weight = ace_loss_weight
+
+        self.use_center_loss = False
+        if use_center_loss:
+            self.use_center_loss = use_center_loss
+            self.center_loss_func = CenterLoss(
+                num_classes=num_classes,
+                feat_dim=feat_dim,
+                init_center=init_center,
+                center_file_path=center_file_path)
+            self.center_loss_weight = center_loss_weight
+
+    def __call__(self, predicts, batch):
+        loss = self.ctc_loss_func(predicts, batch)["loss"]
+
+        if self.use_center_loss:
+            center_loss = self.center_loss_func(
+                predicts, batch)["loss_center"] * self.center_loss_weight
+            loss = loss + center_loss
+
+        if self.use_ace_loss:
+            ace_loss = self.ace_loss_func(
+                predicts, batch)["loss_ace"] * self.ace_loss_weight
+            loss = loss + ace_loss
+
+        return {'enhanced_ctc_loss': loss}
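Review note on `EnhancedCTCLoss.__call__`: the combined loss is the CTC term plus optional weighted Center and ACE terms. The plain-Python stand-in below mirrors that arithmetic with scalar inputs. The function name `combine_losses` is hypothetical, and the default weights are simply the constructor defaults from this patch (0.1 for ACE, 0.05 for Center).

```python
def combine_losses(ctc, ace=None, center=None,
                   ace_weight=0.1, center_weight=0.05):
    """Weighted sum mirroring EnhancedCTCLoss.__call__ (illustrative only).

    Passing None for a term corresponds to leaving use_ace_loss or
    use_center_loss disabled.
    """
    total = ctc
    if center is not None:
        total = total + center_weight * center
    if ace is not None:
        total = total + ace_weight * ace
    return total


print(combine_losses(2.0))                       # CTC only
print(combine_losses(2.0, ace=1.0, center=4.0))  # CTC + Center + ACE
```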
diff --git a/tests/ocr_det_params.txt b/tests/configs/ppocr_det_mobile_params.txt
similarity index 84%
rename from tests/ocr_det_params.txt
rename to tests/configs/ppocr_det_mobile_params.txt
index 6fd22e409a5219574b2f29285ff5ee5d2e1cf7ca..5edb14cdbf8eef87b5b5558cbd8d1a2ff54ae919 100644
--- a/tests/ocr_det_params.txt
+++ b/tests/configs/ppocr_det_mobile_params.txt
@@ -40,13 +40,13 @@ infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
---cpu_threads:6
+--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
---save_log_path:null
+null:null
--benchmark:True
null:null
===========================cpp_infer_params===========================
@@ -79,4 +79,20 @@ op.det.local_service_conf.thread_num:1|6
op.det.local_service_conf.use_trt:False|True
op.det.local_service_conf.precision:fp32|fp16|int8
pipline:pipeline_http_client.py --image_dir=../../doc/imgs
-
+===========================kl_quant_params===========================
+infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
+infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
+infer_quant:False
+inference:tools/infer/predict_det.py
+--use_gpu:True|False
+--enable_mkldnn:True|False
+--cpu_threads:1|6
+--rec_batch_num:1
+--use_tensorrt:False|True
+--precision:fp32|fp16|int8
+--det_model_dir:
+--image_dir:./inference/ch_det_data_50/all-sum-510/
+null:null
+--benchmark:True
+null:null
+null:null
\ No newline at end of file
diff --git a/tests/ocr_det_server_params.txt b/tests/configs/ppocr_det_server_params.txt
similarity index 77%
rename from tests/ocr_det_server_params.txt
rename to tests/configs/ppocr_det_server_params.txt
index 4a17fa683439fdc4716b4ed6b067a572fa3a5057..b3df1735e50d941b34eeb274c28eb4ce50d79292 100644
--- a/tests/ocr_det_server_params.txt
+++ b/tests/configs/ppocr_det_server_params.txt
@@ -12,10 +12,10 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
-trainer:norm_train|pact_train
-norm_train:tools/train.py -c tests/configs/det_r50_vd_db.yml -o Global.pretrained_model=""
-pact_train:null
-fpgm_train:null
+trainer:norm_train|pact_train|fpgm_export
+norm_train:tools/train.py -c tests/configs/det_r50_vd_db.yml -o
+quant_export:deploy/slim/quantization/export_model.py -c tests/configs/det_r50_vd_db.yml -o
+fpgm_export:deploy/slim/prune/export_prune_model.py -c tests/configs/det_r50_vd_db.yml -o
distill_train:null
null:null
null:null
@@ -34,8 +34,8 @@ distill_export:null
export1:null
export2:null
##
-infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
-infer_export:null
+train_model:./inference/ch_ppocr_server_v2.0_det_train/best_accuracy
+infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/tests/ocr_rec_params.txt b/tests/configs/ppocr_rec_mobile_params.txt
similarity index 100%
rename from tests/ocr_rec_params.txt
rename to tests/configs/ppocr_rec_mobile_params.txt
diff --git a/tests/ocr_rec_server_params.txt b/tests/configs/ppocr_rec_server_params.txt
similarity index 100%
rename from tests/ocr_rec_server_params.txt
rename to tests/configs/ppocr_rec_server_params.txt
diff --git a/tests/ocr_ppocr_mobile_params.txt b/tests/configs/ppocr_sys_mobile_params.txt
similarity index 100%
rename from tests/ocr_ppocr_mobile_params.txt
rename to tests/configs/ppocr_sys_mobile_params.txt
diff --git a/tests/ocr_ppocr_server_params.txt b/tests/configs/ppocr_sys_server_params.txt
similarity index 100%
rename from tests/ocr_ppocr_server_params.txt
rename to tests/configs/ppocr_sys_server_params.txt
diff --git a/tests/ocr_kl_quant_params.txt b/tests/ocr_kl_quant_params.txt
deleted file mode 100644
index c6ee97dca49bb7d942a339783af44053e6c79b00..0000000000000000000000000000000000000000
--- a/tests/ocr_kl_quant_params.txt
+++ /dev/null
@@ -1,51 +0,0 @@
-===========================train_params===========================
-model_name:ocr_system
-python:python3.7
-gpu_list:null
-Global.use_gpu:null
-Global.auto_cast:null
-Global.epoch_num:null
-Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:null
-Global.pretrained_model:null
-train_model_name:null
-train_infer_img_dir:null
-null:null
-##
-trainer:
-norm_train:null
-pact_train:null
-fpgm_train:null
-distill_train:null
-null:null
-null:null
-##
-===========================eval_params===========================
-eval:null
-null:null
-##
-===========================infer_params===========================
-Global.save_inference_dir:./output/
-Global.pretrained_model:
-norm_export:null
-quant_export:null
-fpgm_export:null
-distill_export:null
-export1:null
-export2:null
-##
-infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
-kl_quant:deploy/slim/quantization/quant_kl.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
-infer_quant:True
-inference:tools/infer/predict_det.py
---use_gpu:TrueFalse
---enable_mkldnn:True|False
---cpu_threads:1|6
---rec_batch_num:1
---use_tensorrt:False|True
---precision:fp32|fp16|int8
---det_model_dir:
---image_dir:./inference/ch_det_data_50/all-sum-510/
---save_log_path:null
---benchmark:True
-null:null
diff --git a/tests/prepare.sh b/tests/prepare.sh
index ef021fa385f16ae5c9c996bfcb607f73b4129f49..f43ddb56fcd615050f110fb0d05bb178b1621da0 100644
--- a/tests/prepare.sh
+++ b/tests/prepare.sh
@@ -1,7 +1,9 @@
#!/bin/bash
FILENAME=$1
-# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer', 'serving_infer']
+# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer',
+# 'cpp_infer', 'serving_infer', 'klquant_infer']
+
MODE=$2
dataline=$(cat ${FILENAME})
@@ -72,9 +74,9 @@ elif [ ${MODE} = "infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ocr_server_det" ]; then
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
- cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
+ cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ocr_system_mobile" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
@@ -98,6 +100,12 @@ elif [ ${MODE} = "infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar
cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../
fi
+elif [ ${MODE} = "klquant_infer" ];then
+    if [ ${model_name} = "ocr_det" ]; then
+        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
+        cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
+    fi
elif [ ${MODE} = "cpp_infer" ];then
if [ ${model_name} = "ocr_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
diff --git a/tests/results/det_results_gpu_trt_fp16.txt b/tests/results/ppocr_det_mobile_results_fp16.txt
similarity index 100%
rename from tests/results/det_results_gpu_trt_fp16.txt
rename to tests/results/ppocr_det_mobile_results_fp16.txt
diff --git a/tests/results/det_results_gpu_trt_fp16_cpp.txt b/tests/results/ppocr_det_mobile_results_fp16_cpp.txt
similarity index 100%
rename from tests/results/det_results_gpu_trt_fp16_cpp.txt
rename to tests/results/ppocr_det_mobile_results_fp16_cpp.txt
diff --git a/tests/results/det_results_gpu_fp32.txt b/tests/results/ppocr_det_mobile_results_fp32.txt
similarity index 100%
rename from tests/results/det_results_gpu_fp32.txt
rename to tests/results/ppocr_det_mobile_results_fp32.txt
diff --git a/tests/results/det_results_gpu_trt_fp32_cpp.txt b/tests/results/ppocr_det_mobile_results_fp32_cpp.txt
similarity index 100%
rename from tests/results/det_results_gpu_trt_fp32_cpp.txt
rename to tests/results/ppocr_det_mobile_results_fp32_cpp.txt
diff --git a/tests/test.sh b/tests/test.sh
index 5649e344b76cf4485db533eee4035e1cbdd5adae..3df0d52cc5cfa6fd8d7259d47178d8c26d2952fb 100644
--- a/tests/test.sh
+++ b/tests/test.sh
@@ -1,9 +1,16 @@
#!/bin/bash
FILENAME=$1
-# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer']
+# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer', 'serving_infer', 'klquant_infer']
MODE=$2
-
-dataline=$(cat ${FILENAME})
+if [ ${MODE} = "cpp_infer" ]; then
+    dataline=$(awk 'NR==67, NR==81{print}' $FILENAME)
+elif [ ${MODE} = "serving_infer" ]; then
+    dataline=$(awk 'NR==52, NR==66{print}' $FILENAME)
+elif [ ${MODE} = "klquant_infer" ]; then
+    dataline=$(awk 'NR==82, NR==98{print}' $FILENAME)
+else
+    dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
+fi
# parser params
IFS=$'\n'
@@ -144,61 +151,93 @@ benchmark_key=$(func_parser_key "${lines[49]}")
benchmark_value=$(func_parser_value "${lines[49]}")
infer_key1=$(func_parser_key "${lines[50]}")
infer_value1=$(func_parser_value "${lines[50]}")
-# parser serving
-trans_model_py=$(func_parser_value "${lines[67]}")
-infer_model_dir_key=$(func_parser_key "${lines[68]}")
-infer_model_dir_value=$(func_parser_value "${lines[68]}")
-model_filename_key=$(func_parser_key "${lines[69]}")
-model_filename_value=$(func_parser_value "${lines[69]}")
-params_filename_key=$(func_parser_key "${lines[70]}")
-params_filename_value=$(func_parser_value "${lines[70]}")
-serving_server_key=$(func_parser_key "${lines[71]}")
-serving_server_value=$(func_parser_value "${lines[71]}")
-serving_client_key=$(func_parser_key "${lines[72]}")
-serving_client_value=$(func_parser_value "${lines[72]}")
-serving_dir_value=$(func_parser_value "${lines[73]}")
-web_service_py=$(func_parser_value "${lines[74]}")
-web_use_gpu_key=$(func_parser_key "${lines[75]}")
-web_use_gpu_list=$(func_parser_value "${lines[75]}")
-web_use_mkldnn_key=$(func_parser_key "${lines[76]}")
-web_use_mkldnn_list=$(func_parser_value "${lines[76]}")
-web_cpu_threads_key=$(func_parser_key "${lines[77]}")
-web_cpu_threads_list=$(func_parser_value "${lines[77]}")
-web_use_trt_key=$(func_parser_key "${lines[78]}")
-web_use_trt_list=$(func_parser_value "${lines[78]}")
-web_precision_key=$(func_parser_key "${lines[79]}")
-web_precision_list=$(func_parser_value "${lines[79]}")
-pipeline_py=$(func_parser_value "${lines[80]}")
+# parser klquant_infer
+if [ ${MODE} = "klquant_infer" ]; then
+ # parser inference model
+ infer_model_dir_list=$(func_parser_value "${lines[1]}")
+ infer_export_list=$(func_parser_value "${lines[2]}")
+ infer_is_quant=$(func_parser_value "${lines[3]}")
+ # parser inference
+ inference_py=$(func_parser_value "${lines[4]}")
+ use_gpu_key=$(func_parser_key "${lines[5]}")
+ use_gpu_list=$(func_parser_value "${lines[5]}")
+ use_mkldnn_key=$(func_parser_key "${lines[6]}")
+ use_mkldnn_list=$(func_parser_value "${lines[6]}")
+ cpu_threads_key=$(func_parser_key "${lines[7]}")
+ cpu_threads_list=$(func_parser_value "${lines[7]}")
+ batch_size_key=$(func_parser_key "${lines[8]}")
+ batch_size_list=$(func_parser_value "${lines[8]}")
+ use_trt_key=$(func_parser_key "${lines[9]}")
+ use_trt_list=$(func_parser_value "${lines[9]}")
+ precision_key=$(func_parser_key "${lines[10]}")
+ precision_list=$(func_parser_value "${lines[10]}")
+ infer_model_key=$(func_parser_key "${lines[11]}")
+ image_dir_key=$(func_parser_key "${lines[12]}")
+ infer_img_dir=$(func_parser_value "${lines[12]}")
+ save_log_key=$(func_parser_key "${lines[13]}")
+ benchmark_key=$(func_parser_key "${lines[14]}")
+ benchmark_value=$(func_parser_value "${lines[14]}")
+ infer_key1=$(func_parser_key "${lines[15]}")
+ infer_value1=$(func_parser_value "${lines[15]}")
+fi
+# parser serving
+if [ ${MODE} = "serving_infer" ]; then
+ trans_model_py=$(func_parser_value "${lines[1]}")
+ infer_model_dir_key=$(func_parser_key "${lines[2]}")
+ infer_model_dir_value=$(func_parser_value "${lines[2]}")
+ model_filename_key=$(func_parser_key "${lines[3]}")
+ model_filename_value=$(func_parser_value "${lines[3]}")
+ params_filename_key=$(func_parser_key "${lines[4]}")
+ params_filename_value=$(func_parser_value "${lines[4]}")
+ serving_server_key=$(func_parser_key "${lines[5]}")
+ serving_server_value=$(func_parser_value "${lines[5]}")
+ serving_client_key=$(func_parser_key "${lines[6]}")
+ serving_client_value=$(func_parser_value "${lines[6]}")
+ serving_dir_value=$(func_parser_value "${lines[7]}")
+ web_service_py=$(func_parser_value "${lines[8]}")
+ web_use_gpu_key=$(func_parser_key "${lines[9]}")
+ web_use_gpu_list=$(func_parser_value "${lines[9]}")
+ web_use_mkldnn_key=$(func_parser_key "${lines[10]}")
+ web_use_mkldnn_list=$(func_parser_value "${lines[10]}")
+ web_cpu_threads_key=$(func_parser_key "${lines[11]}")
+ web_cpu_threads_list=$(func_parser_value "${lines[11]}")
+ web_use_trt_key=$(func_parser_key "${lines[12]}")
+ web_use_trt_list=$(func_parser_value "${lines[12]}")
+ web_precision_key=$(func_parser_key "${lines[13]}")
+ web_precision_list=$(func_parser_value "${lines[13]}")
+ pipeline_py=$(func_parser_value "${lines[14]}")
+fi
if [ ${MODE} = "cpp_infer" ]; then
# parser cpp inference model
- cpp_infer_model_dir_list=$(func_parser_value "${lines[53]}")
- cpp_infer_is_quant=$(func_parser_value "${lines[54]}")
+ cpp_infer_model_dir_list=$(func_parser_value "${lines[1]}")
+ cpp_infer_is_quant=$(func_parser_value "${lines[2]}")
# parser cpp inference
- inference_cmd=$(func_parser_value "${lines[55]}")
- cpp_use_gpu_key=$(func_parser_key "${lines[56]}")
- cpp_use_gpu_list=$(func_parser_value "${lines[56]}")
- cpp_use_mkldnn_key=$(func_parser_key "${lines[57]}")
- cpp_use_mkldnn_list=$(func_parser_value "${lines[57]}")
- cpp_cpu_threads_key=$(func_parser_key "${lines[58]}")
- cpp_cpu_threads_list=$(func_parser_value "${lines[58]}")
- cpp_batch_size_key=$(func_parser_key "${lines[59]}")
- cpp_batch_size_list=$(func_parser_value "${lines[59]}")
- cpp_use_trt_key=$(func_parser_key "${lines[60]}")
- cpp_use_trt_list=$(func_parser_value "${lines[60]}")
- cpp_precision_key=$(func_parser_key "${lines[61]}")
- cpp_precision_list=$(func_parser_value "${lines[61]}")
- cpp_infer_model_key=$(func_parser_key "${lines[62]}")
- cpp_image_dir_key=$(func_parser_key "${lines[63]}")
- cpp_infer_img_dir=$(func_parser_value "${lines[63]}")
- cpp_infer_key1=$(func_parser_key "${lines[64]}")
- cpp_infer_value1=$(func_parser_value "${lines[64]}")
- cpp_benchmark_key=$(func_parser_key "${lines[65]}")
- cpp_benchmark_value=$(func_parser_value "${lines[65]}")
+ inference_cmd=$(func_parser_value "${lines[3]}")
+ cpp_use_gpu_key=$(func_parser_key "${lines[4]}")
+ cpp_use_gpu_list=$(func_parser_value "${lines[4]}")
+ cpp_use_mkldnn_key=$(func_parser_key "${lines[5]}")
+ cpp_use_mkldnn_list=$(func_parser_value "${lines[5]}")
+ cpp_cpu_threads_key=$(func_parser_key "${lines[6]}")
+ cpp_cpu_threads_list=$(func_parser_value "${lines[6]}")
+ cpp_batch_size_key=$(func_parser_key "${lines[7]}")
+ cpp_batch_size_list=$(func_parser_value "${lines[7]}")
+ cpp_use_trt_key=$(func_parser_key "${lines[8]}")
+ cpp_use_trt_list=$(func_parser_value "${lines[8]}")
+ cpp_precision_key=$(func_parser_key "${lines[9]}")
+ cpp_precision_list=$(func_parser_value "${lines[9]}")
+ cpp_infer_model_key=$(func_parser_key "${lines[10]}")
+ cpp_image_dir_key=$(func_parser_key "${lines[11]}")
+ cpp_infer_img_dir=$(func_parser_value "${lines[12]}")
+ cpp_infer_key1=$(func_parser_key "${lines[13]}")
+ cpp_infer_value1=$(func_parser_value "${lines[13]}")
+ cpp_benchmark_key=$(func_parser_key "${lines[14]}")
+ cpp_benchmark_value=$(func_parser_value "${lines[14]}")
fi
+
LOG_PATH="./tests/output"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results.log"
@@ -414,7 +453,7 @@ function func_cpp_inference(){
done
}
-if [ ${MODE} = "infer" ]; then
+if [ ${MODE} = "infer" ] || [ ${MODE} = "klquant_infer" ]; then
GPUID=$3
if [ ${#GPUID} -le 0 ];then
env=" "
@@ -447,7 +486,6 @@ if [ ${MODE} = "infer" ]; then
func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
Count=$(($Count + 1))
done
-
elif [ ${MODE} = "cpp_infer" ]; then
GPUID=$3
if [ ${#GPUID} -le 0 ];then
@@ -481,6 +519,8 @@ elif [ ${MODE} = "serving_infer" ]; then
#run serving
func_serving "${web_service_cmd}"
+
+
else
IFS="|"
export Count=0
diff --git a/tools/export_center.py b/tools/export_center.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46e8b9d58997b9b66c6ce81b2558ecd4cad0e81
--- /dev/null
+++ b/tools/export_center.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import pickle
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+from ppocr.data import build_dataloader
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import init_model, load_dygraph_params
+from ppocr.utils.utility import print_dict
+import tools.program as program
+
+
+def main():
+    global_config = config['Global']
+    # build dataloader
+    config['Eval']['dataset']['name'] = config['Train']['dataset']['name']
+    config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][
+        'data_dir']
+    config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
+        'label_file_list']
+    eval_dataloader = build_dataloader(config, 'Eval', device, logger)
+
+    # build post process
+    post_process_class = build_post_process(config['PostProcess'],
+                                            global_config)
+
+    # build model
+    # for rec algorithm
+    if hasattr(post_process_class, 'character'):
+        char_num = len(getattr(post_process_class, 'character'))
+        config['Architecture']["Head"]['out_channels'] = char_num
+
+    #set return_features = True
+    config['Architecture']["Head"]["return_feats"] = True
+
+    model = build_model(config['Architecture'])
+
+    best_model_dict = load_dygraph_params(config, model, logger, None)
+    if len(best_model_dict):
+        logger.info('metric in ckpt ***************')
+        for k, v in best_model_dict.items():
+            logger.info('{}:{}'.format(k, v))
+
+    # get features from train data
+    char_center = program.get_center(model, eval_dataloader, post_process_class)
+
+    #serialize to disk
+    with open("train_center.pkl", 'wb') as f:
+        pickle.dump(char_center, f)
+        return
+
+
+if __name__ == '__main__':
+    config, device, logger, vdl_writer = program.preprocess()
+    main()
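Review note on the output of `tools/export_center.py`: the script pickles the dict returned by `program.get_center`, which maps a class index to its averaged feature vector. The sketch below shows a round trip through such a file. The sample values are made up, and plain lists stand in for what are assumed to be NumPy arrays in the real file.

```python
import os
import pickle
import tempfile

# Stand-in for the dict built by get_center: class index -> feature vector.
# Plain lists are used here; the real values are assumed to be NumPy arrays.
char_center = {0: [0.1, 0.2], 7: [0.5, 0.4]}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "train_center.pkl")
    # Write the centers the same way export_center.py does.
    with open(path, "wb") as f:
        pickle.dump(char_center, f)
    # Read them back, as a consumer of the file would.
    with open(path, "rb") as f:
        loaded = pickle.load(f)

print(sorted(loaded))  # class indices that have a center
```

Indices that never appear in a correctly recognized sample have no entry, so a consumer of the file should not assume every class index is present.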
diff --git a/tools/program.py b/tools/program.py
index ddf39e65c34012ae36efd2752946f737f365b1c1..72c3da495e86d32c91c508b35afd8d95cf1d0941 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -404,6 +404,57 @@ def eval(model,
return metric
+def update_center(char_center, post_result, preds):
+    result, label = post_result
+    feats, logits = preds
+    logits = paddle.argmax(logits, axis=-1)
+    feats = feats.numpy()
+    logits = logits.numpy()
+
+    for idx_sample in range(len(label)):
+        if result[idx_sample][0] == label[idx_sample][0]:
+            feat = feats[idx_sample]
+            logit = logits[idx_sample]
+            for idx_time in range(len(logit)):
+                index = logit[idx_time]
+                if index in char_center.keys():
+                    char_center[index][0] = (
+                        char_center[index][0] * char_center[index][1] +
+                        feat[idx_time]) / (char_center[index][1] + 1)
+                    char_center[index][1] += 1
+                else:
+                    char_center[index] = [feat[idx_time], 1]
+    return char_center
+
+
+def get_center(model, eval_dataloader, post_process_class):
+    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
+    max_iter = len(eval_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(eval_dataloader)
+    char_center = dict()
+    for idx, batch in enumerate(eval_dataloader):
+        if idx >= max_iter:
+            break
+        images = batch[0]
+        # update_center expects preds to unpack as (feats, logits)
+        preds = model(images)
+
+        batch = [item.numpy() for item in batch]
+        # Decode predictions and ground-truth labels for this batch;
+        # update_center then keeps only the samples whose decoded
+        # text matches the label
+        post_result = post_process_class(preds, batch[1])
+
+        #update char_center
+        char_center = update_center(char_center, post_result, preds)
+        pbar.update(1)
+
+    pbar.close()
+    for key in char_center.keys():
+        char_center[key] = char_center[key][0]
+    return char_center
+
+
def preprocess(is_train=False):
FLAGS = ArgsParser().parse_args()
profiler_options = FLAGS.profiler_options
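Review note on `update_center`: each center is kept as a `[mean, count]` pair and updated incrementally, so no per-class feature list has to be held in memory. The sketch below isolates that running-mean update in plain Python, with lists standing in for NumPy arrays. The helper name `update_running_mean` and the sample features are hypothetical.

```python
def update_running_mean(centers, index, feat):
    """Fold one feature vector into the running mean stored for `index`.

    `centers` maps index -> [mean_vector, count], as in update_center.
    Plain lists stand in for NumPy arrays.
    """
    if index in centers:
        mean, count = centers[index]
        centers[index] = [
            [(m * count + x) / (count + 1) for m, x in zip(mean, feat)],
            count + 1,
        ]
    else:
        centers[index] = [list(feat), 1]
    return centers


centers = {}
for feat in ([1.0, 2.0], [3.0, 4.0], [5.0, 9.0]):
    update_running_mean(centers, 3, feat)

mean, count = centers[3]
print(mean, count)  # running mean of the three vectors and how many were seen
```

The running mean after the last update equals the ordinary average of all features seen for that index, which is what the final loop in `get_center` extracts when it drops the count.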