diff --git a/README.md b/README.md index 9bb055cd423ef526046f27687ef5586730a18ff4..5b6e4bd0b594d71edd3ab4f8da350475c3ac83b8 100644 --- a/README.md +++ b/README.md @@ -173,7 +173,7 @@ This project is released under visualize = bool(stoi(config_map_["visualize"])); + + this->use_tensorrt = bool(stoi(config_map_["use_tensorrt"])); + + this->use_fp16 = bool(stod(config_map_["use_fp16"])); } bool use_gpu = false; @@ -96,6 +100,10 @@ public: bool visualize = true; + bool use_tensorrt = false; + + bool use_fp16 = false; + void PrintConfigInfo(); private: diff --git a/deploy/cpp_infer/include/ocr_cls.h b/deploy/cpp_infer/include/ocr_cls.h index 87772cc109b18beb6a31940311389e2f0596b031..41494085a797c7a4490942741e6e888033c0be00 100644 --- a/deploy/cpp_infer/include/ocr_cls.h +++ b/deploy/cpp_infer/include/ocr_cls.h @@ -39,7 +39,8 @@ public: explicit Classifier(const std::string &model_dir, const bool &use_gpu, const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, - const bool &use_mkldnn, const double &cls_thresh) { + const bool &use_mkldnn, const double &cls_thresh, + const bool &use_tensorrt, const bool &use_fp16) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; @@ -47,6 +48,8 @@ public: this->use_mkldnn_ = use_mkldnn; this->cls_thresh = cls_thresh; + this->use_tensorrt_ = use_tensorrt; + this->use_fp16_ = use_fp16; LoadModel(model_dir); } @@ -69,7 +72,8 @@ private: std::vector mean_ = {0.5f, 0.5f, 0.5f}; std::vector scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; bool is_scale_ = true; - + bool use_tensorrt_ = false; + bool use_fp16_ = false; // pre-process ClsResizeImg resize_op_; Normalize normalize_op_; diff --git a/deploy/cpp_infer/include/ocr_det.h b/deploy/cpp_infer/include/ocr_det.h index d50fd70af5ec04105e993e358e459f1940d36c7f..bab9c95fa4a3f1cb160ccbf9ca4587fa4c2ba16a 100644 --- a/deploy/cpp_infer/include/ocr_det.h +++ b/deploy/cpp_infer/include/ocr_det.h @@ -44,8 +44,8 @@ public: const bool &use_mkldnn, const int &max_side_len, const double &det_db_thresh, const double &det_db_box_thresh, - const double &det_db_unclip_ratio, - const bool &visualize) { + const double &det_db_unclip_ratio, const bool &visualize, + const bool &use_tensorrt, const bool &use_fp16) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; @@ -59,6 +59,8 @@ public: this->det_db_unclip_ratio_ = det_db_unclip_ratio; this->visualize_ = visualize; + this->use_tensorrt_ = use_tensorrt; + this->use_fp16_ = use_fp16; LoadModel(model_dir); } @@ -85,6 +87,8 @@ private: double det_db_unclip_ratio_ = 2.0; bool visualize_ = true; + bool use_tensorrt_ = false; + bool use_fp16_ = false; std::vector mean_ = {0.485f, 0.456f, 0.406f}; std::vector scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h index 14b77b084a30ade71efe626430cb854d0bfbc1ce..94d605a96e1f43423b15b0d81c7cd88f618ea4d3 100644 --- a/deploy/cpp_infer/include/ocr_rec.h +++ b/deploy/cpp_infer/include/ocr_rec.h @@ -41,12 +41,15 @@ public: explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu, const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, - const bool &use_mkldnn, const string &label_path) { + const bool &use_mkldnn, const string &label_path, + const bool &use_tensorrt, const bool &use_fp16) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; this->use_mkldnn_ = use_mkldnn; + this->use_tensorrt_ = use_tensorrt; + this->use_fp16_ = use_fp16; this->label_list_ = Utility::ReadDict(label_path); this->label_list_.insert(this->label_list_.begin(), @@ -76,7 +79,8 @@ private: std::vector mean_ = {0.5f, 0.5f, 0.5f}; std::vector scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; bool is_scale_ = true; - + bool use_tensorrt_ = false; + bool use_fp16_ = false; // pre-process CrnnResizeImg resize_op_; Normalize normalize_op_; diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 21890d45ce8c6b13e280c87bdfad8ca8e48f8523..f40e5edfcc2c19e0a61894bed11aef636317e056 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -54,18 +54,20 @@ int main(int argc, char **argv) { config.gpu_mem, config.cpu_math_library_num_threads, config.use_mkldnn, config.max_side_len, config.det_db_thresh, config.det_db_box_thresh, config.det_db_unclip_ratio, - config.visualize); + config.visualize, config.use_tensorrt, config.use_fp16); Classifier *cls = nullptr; if (config.use_angle_cls == true) { cls = new Classifier(config.cls_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem, config.cpu_math_library_num_threads, - config.use_mkldnn, config.cls_thresh); + config.use_mkldnn, config.cls_thresh, + config.use_tensorrt, config.use_fp16); } CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem, config.cpu_math_library_num_threads, - config.use_mkldnn, config.char_list_file); + config.use_mkldnn, config.char_list_file, + config.use_tensorrt, config.use_fp16); auto start = std::chrono::system_clock::now(); std::vector>> boxes; @@ -75,11 +77,11 @@ int main(int argc, char **argv) { auto end = std::chrono::system_clock::now(); auto duration = std::chrono::duration_cast(end - start); - std::cout << "花费了" + std::cout << "Cost" << double(duration.count()) * std::chrono::microseconds::period::num / std::chrono::microseconds::period::den - << "秒" << std::endl; + << "s" << std::endl; return 0; } diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp index 9757b482d4f407cefd8db5bd611000062f754645..3aeda2ed0c286d1ec5e816e15ac5500f53c9a3a2 100644 --- a/deploy/cpp_infer/src/ocr_cls.cpp +++ b/deploy/cpp_infer/src/ocr_cls.cpp @@ -76,6 +76,13 @@ void Classifier::LoadModel(const std::string &model_dir) { if (this->use_gpu_) { config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + if (this->use_tensorrt_) { + config.EnableTensorRtEngine( + 1 << 20, 10, 3, + this->use_fp16_ ? paddle_infer::Config::Precision::kHalf + : paddle_infer::Config::Precision::kFloat32, + false, false); + } } else { config.DisableGpu(); if (this->use_mkldnn_) { diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index c6c93991743b28609e880a9534d3228daf2c5bef..3678f37dfb1c0c4aed392dd31830e732e2854899 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -24,10 +24,13 @@ void DBDetector::LoadModel(const std::string &model_dir) { if (this->use_gpu_) { config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); - // config.EnableTensorRtEngine( - // 1 << 20, 1, 3, - // AnalysisConfig::Precision::kFloat32, - // false, false); + if (this->use_tensorrt_) { + config.EnableTensorRtEngine( + 1 << 20, 10, 3, + this->use_fp16_ ? paddle_infer::Config::Precision::kHalf + : paddle_infer::Config::Precision::kFloat32, + false, false); + } } else { config.DisableGpu(); if (this->use_mkldnn_) { diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index e33695a74d72020f4397b84fcc07e9d9bf01486c..27cfe4c95009c6454514a43e304a23503fe5fa9a 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -76,7 +76,7 @@ void CRNNRecognizer::Run(std::vector>> boxes, float(*std::max_element(&predict_batch[n * predict_shape[2]], &predict_batch[(n + 1) * predict_shape[2]])); - if (argmax_idx > 0 && (not(i > 0 && argmax_idx == last_index))) { + if (argmax_idx > 0 && (!(i > 0 && argmax_idx == last_index))) { score += max_value; count += 1; str_res.push_back(label_list_[argmax_idx]); @@ -99,6 +99,13 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { if (this->use_gpu_) { config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + if (this->use_tensorrt_) { + config.EnableTensorRtEngine( + 1 << 20, 10, 3, + this->use_fp16_ ? paddle_infer::Config::Precision::kHalf + : paddle_infer::Config::Precision::kFloat32, + false, false); + } } else { config.DisableGpu(); if (this->use_mkldnn_) { @@ -176,4 +183,4 @@ cv::Mat CRNNRecognizer::GetRotateCropImage(const cv::Mat &srcimage, } } -} // namespace PaddleOCR \ No newline at end of file +} // namespace PaddleOCR diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt index 34f47ed82015b5c27a61a34d1de22f3251e0fd75..e185377e2f2c9cbd5c1d8ed09ba43df9c41c05d2 100644 --- a/deploy/cpp_infer/tools/config.txt +++ b/deploy/cpp_infer/tools/config.txt @@ -24,3 +24,7 @@ char_list_file ../../ppocr/utils/ppocr_keys_v1.txt # show the detection results visualize 1 +# use_tensorrt +use_tensorrt 0 +use_fp16 0 + diff --git a/deploy/slim/quantization/README.md b/deploy/slim/quantization/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ccd4d06b4f16165f968402751b63a8fe58773e0b --- /dev/null +++ b/deploy/slim/quantization/README.md @@ -0,0 +1,61 @@ + +## 介绍 +复杂的模型有利于提高模型的性能,但也导致模型中存在一定冗余,模型量化将全精度缩减到定点数减少这种冗余,达到减少模型计算复杂度,提高模型推理性能的目的。 +模型量化可以在基本不损失模型的精度的情况下,将FP32精度的模型参数转换为Int8精度,减小模型参数大小并加速计算,使用量化后的模型在移动端等部署时更具备速度优势。 + +本教程将介绍如何使用飞桨模型压缩库PaddleSlim做PaddleOCR模型的压缩。 +[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) 集成了模型剪枝、量化(包括量化训练和离线量化)、蒸馏和神经网络搜索等多种业界常用且领先的模型压缩功能,如果您感兴趣,可以关注并了解。 + +在开始本教程之前,建议先了解[PaddleOCR模型的训练方法](../../../doc/doc_ch/quickstart.md)以及[PaddleSlim](https://paddleslim.readthedocs.io/zh_CN/latest/index.html) + + +## 快速开始 +量化多适用于轻量模型在移动端的部署,当训练出一个模型后,如果希望进一步的压缩模型大小并加速预测,可使用量化的方法压缩模型。 + +模型量化主要包括五个步骤: +1. 安装 PaddleSlim +2. 准备训练好的模型 +3. 量化训练 +4. 导出量化推理模型 +5. 量化模型预测部署 + +### 1. 安装PaddleSlim + +```bash +git clone https://github.com/PaddlePaddle/PaddleSlim.git +cd Paddleslim +python setup.py install +``` + +### 2. 准备训练好的模型 + +PaddleOCR提供了一系列训练好的[模型](../../../doc/doc_ch/models_list.md),如果待量化的模型不在列表中,需要按照[常规训练](../../../doc/doc_ch/quickstart.md)方法得到训练好的模型。 + +### 3. 量化训练 +量化训练包括离线量化训练和在线量化训练,在线量化训练效果更好,需加载预训练模型,在定义好量化策略后即可对模型进行量化。 + + +量化训练的代码位于slim/quantization/quant.py 中,比如训练检测模型,训练指令如下: +```bash +python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights='your trained model' Global.save_model_dir=./output/quant_model + +# 比如下载提供的训练模型 +wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar +tar -xf ch_ppocr_mobile_v2.0_det_train.tar +python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model + +``` +如果要训练识别模型的量化,修改配置文件和加载的模型参数即可。 + +### 4. 导出模型 + +在得到量化训练保存的模型后,我们可以将其导出为inference_model,用于预测部署: + +```bash +python deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_model_dir=./output/quant_inference_model +``` + +### 5. 量化模型部署 + +上述步骤导出的量化模型,参数精度仍然是FP32,但是参数的数值范围是int8,导出的模型可以通过PaddleLite的opt模型转换工具完成模型转换。 +量化模型部署的可参考 [移动端模型部署](../../lite/readme.md) diff --git a/deploy/slim/quantization/README_en.md b/deploy/slim/quantization/README_en.md new file mode 100644 index 0000000000000000000000000000000000000000..7da0b3e7e7d5f72e45dc17864630b9725f6fc8ba --- /dev/null +++ b/deploy/slim/quantization/README_en.md @@ -0,0 +1,68 @@ + +## Introduction + +Generally, a more complex model would achive better performance in the task, but it also leads to some redundancy in the model. +Quantization is a technique that reduces this redundancy by reducing the full precision data to a fixed number, +so as to reduce model calculation complexity and improve model inference performance. + +This example uses PaddleSlim provided [APIs of Quantization](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/) to compress the OCR model. + +It is recommended that you could understand following pages before reading this example: +- [The training strategy of OCR model](../../../doc/doc_en/quickstart_en.md) +- [PaddleSlim Document](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/) + +## Quick Start +Quantization is mostly suitable for the deployment of lightweight models on mobile terminals. +After training, if you want to further compress the model size and accelerate the prediction, you can use quantization methods to compress the model according to the following steps. + +1. Install PaddleSlim +2. Prepare trained model +3. Quantization-Aware Training +4. Export inference model +5. Deploy quantization inference model + + +### 1. Install PaddleSlim + +```bash +git clone https://github.com/PaddlePaddle/PaddleSlim.git +cd Paddleslim +python setup.py install +``` + + +### 2. Download Pretrain Model +PaddleOCR provides a series of trained [models](../../../doc/doc_en/models_list_en.md). +If the model to be quantified is not in the list, you need to follow the [Regular Training](../../../doc/doc_en/quickstart_en.md) method to get the trained model. + + +### 3. Quant-Aware Training +Quantization training includes offline quantization training and online quantization training. +Online quantization training is more effective. It is necessary to load the pre-training model. +After the quantization strategy is defined, the model can be quantified. + +The code for quantization training is located in `slim/quantization/quant.py`. For example, to train a detection model, the training instructions are as follows: +```bash +python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights='your trained model' Global.save_model_dir=./output/quant_model + +# download provided model +wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar +tar -xf ch_ppocr_mobile_v2.0_det_train.tar +python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model + +``` + + +### 4. Export inference model + +After getting the model after pruning and finetuning we, can export it as inference_model for predictive deployment: + +```bash +python deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_model_dir=./output/quant_inference_model +``` + +### 5. Deploy +The numerical range of the quantized model parameters derived from the above steps is still FP32, but the numerical range of the parameters is int8. +The derived model can be converted through the `opt tool` of PaddleLite. + +For quantitative model deployment, please refer to [Mobile terminal model deployment](../../lite/readme_en.md) diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py new file mode 100755 index 0000000000000000000000000000000000000000..100b107a1deb1ce9932c9cefa50659c060f5803e --- /dev/null +++ b/deploy/slim/quantization/export_model.py @@ -0,0 +1,118 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..'))) +sys.path.append( + os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools'))) + +import argparse + +import paddle +from paddle.jit import to_static + +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import init_model +from ppocr.utils.logging import get_logger +from tools.program import load_config, merge_config, ArgsParser +from ppocr.metrics import build_metric +import tools.program as program +from paddleslim.dygraph.quant import QAT +from ppocr.data import build_dataloader + + +def main(): + ############################################################################################################ + # 1. quantization configs + ############################################################################################################ + quant_config = { + # weight preprocess type, default is None and no preprocessing is performed. + 'weight_preprocess_type': None, + # activation preprocess type, default is None and no preprocessing is performed. + 'activation_preprocess_type': None, + # weight quantize type, default is 'channel_wise_abs_max' + 'weight_quantize_type': 'channel_wise_abs_max', + # activation quantize type, default is 'moving_average_abs_max' + 'activation_quantize_type': 'moving_average_abs_max', + # weight quantize bit num, default is 8 + 'weight_bits': 8, + # activation quantize bit num, default is 8 + 'activation_bits': 8, + # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' + 'dtype': 'int8', + # window size for 'range_abs_max' quantization. default is 10000 + 'window_size': 10000, + # The decay coefficient of moving average, default is 0.9 + 'moving_rate': 0.9, + # for dygraph quantization, layers of type in quantizable_layer_type will be quantized + 'quantizable_layer_type': ['Conv2D', 'Linear'], + } + FLAGS = ArgsParser().parse_args() + config = load_config(FLAGS.config) + merge_config(FLAGS.opt) + logger = get_logger() + # build post process + + post_process_class = build_post_process(config['PostProcess'], + config['Global']) + + # build model + # for rec algorithm + if hasattr(post_process_class, 'character'): + char_num = len(getattr(post_process_class, 'character')) + config['Architecture']["Head"]['out_channels'] = char_num + model = build_model(config['Architecture']) + + # get QAT model + quanter = QAT(config=quant_config) + quanter.quantize(model) + + init_model(config, model, logger) + model.eval() + + # build metric + eval_class = build_metric(config['Metric']) + + # build dataloader + valid_dataloader = build_dataloader(config, 'Eval', device, logger) + + # start eval + metirc = program.eval(model, valid_dataloader, post_process_class, + eval_class) + logger.info('metric eval ***************') + for k, v in metirc.items(): + logger.info('{}:{}'.format(k, v)) + + save_path = '{}/inference'.format(config['Global']['save_inference_dir']) + infer_shape = [3, 32, 100] if config['Architecture'][ + 'model_type'] != "det" else [3, 640, 640] + + quanter.save_quantized_model( + model, + save_path, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + infer_shape, dtype='float32') + ]) + logger.info('inference QAT model is saved to {}'.format(save_path)) + + +if __name__ == "__main__": + config, device, logger, vdl_writer = program.preprocess() + main() diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py new file mode 100755 index 0000000000000000000000000000000000000000..7671e5f871ce6769fc51876d1fa2e5f0af63d904 --- /dev/null +++ b/deploy/slim/quantization/quant.py @@ -0,0 +1,166 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..'))) +sys.path.append( + os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools'))) + +import yaml +import paddle +import paddle.distributed as dist + +paddle.seed(2) + +from ppocr.data import build_dataloader +from ppocr.modeling.architectures import build_model +from ppocr.losses import build_loss +from ppocr.optimizer import build_optimizer +from ppocr.postprocess import build_post_process +from ppocr.metrics import build_metric +from ppocr.utils.save_load import init_model +import tools.program as program +from paddleslim.dygraph.quant import QAT + +dist.get_world_size() + + +class PACT(paddle.nn.Layer): + def __init__(self): + super(PACT, self).__init__() + alpha_attr = paddle.ParamAttr( + name=self.full_name() + ".pact", + initializer=paddle.nn.initializer.Constant(value=20), + learning_rate=1.0, + regularizer=paddle.regularizer.L2Decay(2e-5)) + + self.alpha = self.create_parameter( + shape=[1], attr=alpha_attr, dtype='float32') + + def forward(self, x): + out_left = paddle.nn.functional.relu(x - self.alpha) + out_right = paddle.nn.functional.relu(-self.alpha - x) + x = x - out_left + out_right + return x + + +quant_config = { + # weight preprocess type, default is None and no preprocessing is performed. + 'weight_preprocess_type': None, + # activation preprocess type, default is None and no preprocessing is performed. + 'activation_preprocess_type': None, + # weight quantize type, default is 'channel_wise_abs_max' + 'weight_quantize_type': 'channel_wise_abs_max', + # activation quantize type, default is 'moving_average_abs_max' + 'activation_quantize_type': 'moving_average_abs_max', + # weight quantize bit num, default is 8 + 'weight_bits': 8, + # activation quantize bit num, default is 8 + 'activation_bits': 8, + # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' + 'dtype': 'int8', + # window size for 'range_abs_max' quantization. default is 10000 + 'window_size': 10000, + # The decay coefficient of moving average, default is 0.9 + 'moving_rate': 0.9, + # for dygraph quantization, layers of type in quantizable_layer_type will be quantized + 'quantizable_layer_type': ['Conv2D', 'Linear'], +} + + +def main(config, device, logger, vdl_writer): + # init dist environment + if config['Global']['distributed']: + dist.init_parallel_env() + + global_config = config['Global'] + + # build dataloader + train_dataloader = build_dataloader(config, 'Train', device, logger) + if config['Eval']: + valid_dataloader = build_dataloader(config, 'Eval', device, logger) + else: + valid_dataloader = None + + # build post process + post_process_class = build_post_process(config['PostProcess'], + global_config) + + # build model + # for rec algorithm + if hasattr(post_process_class, 'character'): + char_num = len(getattr(post_process_class, 'character')) + config['Architecture']["Head"]['out_channels'] = char_num + model = build_model(config['Architecture']) + + # prepare to quant + quanter = QAT(config=quant_config, act_preprocess=PACT) + quanter.quantize(model) + + if config['Global']['distributed']: + model = paddle.DataParallel(model) + + # build loss + loss_class = build_loss(config['Loss']) + + # build optim + optimizer, lr_scheduler = build_optimizer( + config['Optimizer'], + epochs=config['Global']['epoch_num'], + step_each_epoch=len(train_dataloader), + parameters=model.parameters()) + + # build metric + eval_class = build_metric(config['Metric']) + # load pretrain model + pre_best_model_dict = init_model(config, model, logger, optimizer) + + logger.info('train dataloader has {} iters, valid dataloader has {} iters'. + format(len(train_dataloader), len(valid_dataloader))) + # start train + program.train(config, train_dataloader, valid_dataloader, device, model, + loss_class, optimizer, lr_scheduler, post_process_class, + eval_class, pre_best_model_dict, logger, vdl_writer) + + +def test_reader(config, device, logger): + loader = build_dataloader(config, 'Train', device, logger) + import time + starttime = time.time() + count = 0 + try: + for data in loader(): + count += 1 + if count % 1 == 0: + batch_time = time.time() - starttime + starttime = time.time() + logger.info("reader: {}, {}, {}".format( + count, len(data[0]), batch_time)) + except Exception as e: + logger.info(e) + logger.info("finish reader: {}, Success!".format(count)) + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess(is_train=True) + main(config, device, logger, vdl_writer) + # test_reader(config, device, logger) diff --git a/doc/doc_ch/angle_class.md b/doc/doc_ch/angle_class.md index 846be15f834952587e8d3b2533ff147375db7c31..4d7ff0d7aa839591df6e359d4f7295ab2f0cc445 100644 --- a/doc/doc_ch/angle_class.md +++ b/doc/doc_ch/angle_class.md @@ -21,9 +21,8 @@ ln -sf /train_data/cls/dataset ``` " 图像文件名 图像标注信息 " - -train_data/cls/word_001.jpg 0 -train_data/cls/word_002.jpg 180 +train/word_001.jpg 0 +train/word_002.jpg 180 ``` 最终训练集应有如下文件结构: @@ -55,6 +54,8 @@ train_data/cls/word_002.jpg 180 ### 启动训练 +将准备好的txt文件和图片文件夹路径分别写入配置文件的 `Train/Eval.dataset.label_file_list` 和 `Train/Eval.dataset.data_dir` 字段下,`Train/Eval.dataset.data_dir`字段下的路径和文件里记载的图片名构成了图片的绝对路径。 + PaddleOCR提供了训练脚本、评估脚本和预测脚本。 开始训练: diff --git a/doc/doc_ch/tree.md b/doc/doc_ch/tree.md index 5f048db022dbe422a78f87b0236d04e00ccc4d48..c222bcb447292fb3644c6d6fc6cf013a67b9dff3 100644 --- a/doc/doc_ch/tree.md +++ b/doc/doc_ch/tree.md @@ -211,6 +211,6 @@ PaddleOCR ├── README_ch.md // 中文说明文档 ├── README_en.md // 英文说明文档 ├── README.md // 主页说明文档 -├── requirements.txt // 安装依赖 +├── requirements.txt // 安装依赖 ├── setup.py // whl包打包脚本 ├── train.sh // 启动训练脚本 diff --git a/doc/doc_en/angle_class_en.md b/doc/doc_en/angle_class_en.md index e6157d1635431a45b3bc9392d1115dcdd917aeeb..8d9328700f3e638eb4576d132aa32fb93b3ad0c0 100644 --- a/doc/doc_en/angle_class_en.md +++ b/doc/doc_en/angle_class_en.md @@ -23,8 +23,8 @@ First put the training images in the same folder (train_images), and use a txt f ``` " Image file name Image annotation " -train_data/word_001.jpg 0 -train_data/word_002.jpg 180 +train/word_001.jpg 0 +train/word_002.jpg 180 ``` The final training set should have the following file structure: @@ -57,6 +57,7 @@ containing all images (test) and a cls_gt_test.txt. The structure of the test se ``` ### TRAINING +Write the prepared txt file and image folder path into the configuration file under the `Train/Eval.dataset.label_file_list` and `Train/Eval.dataset.data_dir` fields, the absolute path of the image consists of the `Train/Eval.dataset.data_dir` field and the image name recorded in the txt file. PaddleOCR provides training scripts, evaluation scripts, and prediction scripts. diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index bd0f92e0d759204b33b6cb9b261531d61134605e..a86fc8382f40b5b73edc7ec8e9d4dbe3e5822283 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -26,6 +26,8 @@ class RecMetric(object): all_num = 0 norm_edit_dis = 0.0 for (pred, pred_conf), (target, _) in zip(preds, labels): + pred = pred.replace(" ", "") + target = target.replace(" ", "") norm_edit_dis += Levenshtein.distance(pred, target) / max( len(pred), len(target)) if pred == target: diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index 28fbc2b1c6e79ade2ae0ba68b001b8a8e65a7f01..29576d971486326aec3c93601656d7b982ef3336 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -57,7 +57,7 @@ def get_image_file_list(img_file): elif os.path.isdir(img_file): for single_file in os.listdir(img_file): file_path = os.path.join(img_file, single_file) - if imghdr.what(file_path) in img_end: + if os.path.isfile(file_path) and imghdr.what(file_path) in img_end: imgs_lists.append(file_path) if len(imgs_lists) == 0: raise Exception("not found any img file in {}".format(img_file))