未验证 提交 da02c71d 编写于 作者: T topduke 提交者: GitHub

Merge branch 'PaddlePaddle:dygraph' into dygraph

......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle_api.h" // NOLINT
#include <chrono>
#include "paddle_api.h" // NOLINT
#include "paddle_place.h"
#include "cls_process.h"
#include "crnn_process.h"
#include "db_post_process.h"
#include "AutoLog/auto_log/lite_autolog.h"
using namespace paddle::lite_api; // NOLINT
using namespace std;
......@@ -27,7 +29,7 @@ void NeonMeanScale(const float *din, float *dout, int size,
const std::vector<float> mean,
const std::vector<float> scale) {
if (mean.size() != 3 || scale.size() != 3) {
std::cerr << "[ERROR] mean or scale size must equal to 3\n";
std::cerr << "[ERROR] mean or scale size must equal to 3" << std::endl;
exit(1);
}
float32x4_t vmean0 = vdupq_n_f32(mean[0]);
......@@ -159,7 +161,8 @@ void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
std::vector<float> &rec_text_score,
std::vector<std::string> charactor_dict,
std::shared_ptr<PaddlePredictor> predictor_cls,
int use_direction_classify) {
int use_direction_classify,
std::vector<double> *times) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
......@@ -226,7 +229,7 @@ void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
std::vector<std::vector<std::vector<int>>>
RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
std::map<std::string, double> Config) {
std::map<std::string, double> Config, std::vector<double> *times) {
// Read img
int max_side_len = int(Config["max_side_len"]);
int det_db_use_dilate = int(Config["det_db_use_dilate"]);
......@@ -234,6 +237,7 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
cv::Mat srcimg;
img.copyTo(srcimg);
auto preprocess_start = std::chrono::steady_clock::now();
std::vector<float> ratio_hw;
img = DetResizeImg(img, max_side_len, ratio_hw);
cv::Mat img_fp;
......@@ -248,8 +252,10 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
const float *dimg = reinterpret_cast<const float *>(img_fp.data);
NeonMeanScale(dimg, data0, img_fp.rows * img_fp.cols, mean, scale);
auto preprocess_end = std::chrono::steady_clock::now();
// Run predictor
auto inference_start = std::chrono::steady_clock::now();
predictor->Run();
// Get output and post process
......@@ -257,8 +263,10 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
std::move(predictor->GetOutput(0)));
auto *outptr = output_tensor->data<float>();
auto shape_out = output_tensor->shape();
auto inference_end = std::chrono::steady_clock::now();
// Save output
auto postprocess_start = std::chrono::steady_clock::now();
float pred[shape_out[2] * shape_out[3]];
unsigned char cbuf[shape_out[2] * shape_out[3]];
......@@ -287,14 +295,35 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
std::vector<std::vector<std::vector<int>>> filter_boxes =
FilterTagDetRes(boxes, ratio_hw[0], ratio_hw[1], srcimg);
auto postprocess_end = std::chrono::steady_clock::now();
std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
times->push_back(double(preprocess_diff.count() * 1000));
std::chrono::duration<float> inference_diff = inference_end - inference_start;
times->push_back(double(inference_diff.count() * 1000));
std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
times->push_back(double(postprocess_diff.count() * 1000));
return filter_boxes;
}
std::shared_ptr<PaddlePredictor> loadModel(std::string model_file) {
std::shared_ptr<PaddlePredictor> loadModel(std::string model_file, std::string power_mode, int num_threads) {
MobileConfig config;
config.set_model_from_file(model_file);
if (power_mode == "LITE_POWER_HIGH"){
config.set_power_mode(LITE_POWER_HIGH);
} else {
if (power_mode == "LITE_POWER_LOW") {
config.set_power_mode(LITE_POWER_HIGH);
} else {
std::cerr << "Only support LITE_POWER_HIGH or LITE_POWER_HIGH." << std::endl;
exit(1);
}
}
config.set_threads(num_threads);
std::shared_ptr<PaddlePredictor> predictor =
CreatePaddlePredictor<MobileConfig>(config);
return predictor;
......@@ -354,60 +383,255 @@ std::map<std::string, double> LoadConfigTxt(std::string config_path) {
return dict;
}
int main(int argc, char **argv) {
if (argc < 5) {
std::cerr << "[ERROR] usage: " << argv[0]
<< " det_model_file cls_model_file rec_model_file image_path "
"charactor_dict\n";
void check_params(int argc, char **argv) {
if (argc<=1 || (strcmp(argv[1], "det")!=0 && strcmp(argv[1], "rec")!=0 && strcmp(argv[1], "system")!=0)) {
std::cerr << "Please choose one mode of [det, rec, system] !" << std::endl;
exit(1);
}
std::string det_model_file = argv[1];
std::string rec_model_file = argv[2];
std::string cls_model_file = argv[3];
std::string img_path = argv[4];
std::string dict_path = argv[5];
if (strcmp(argv[1], "det") == 0) {
if (argc < 9){
std::cerr << "[ERROR] usage:" << argv[0]
<< " det det_model num_threads batchsize power_mode img_dir det_config lite_benchmark_value" << std::endl;
exit(1);
}
}
//// load config from txt file
auto Config = LoadConfigTxt("./config.txt");
int use_direction_classify = int(Config["use_direction_classify"]);
if (strcmp(argv[1], "rec") == 0) {
if (argc < 9){
std::cerr << "[ERROR] usage:" << argv[0]
<< " rec rec_model num_threads batchsize power_mode img_dir key_txt lite_benchmark_value" << std::endl;
exit(1);
}
}
if (strcmp(argv[1], "system") == 0) {
if (argc < 12){
std::cerr << "[ERROR] usage:" << argv[0]
<< " system det_model rec_model clas_model num_threads batchsize power_mode img_dir det_config key_txt lite_benchmark_value" << std::endl;
exit(1);
}
}
}
void system(char **argv){
std::string det_model_file = argv[2];
std::string rec_model_file = argv[3];
std::string cls_model_file = argv[4];
std::string precision = argv[5];
std::string num_threads = argv[6];
std::string batchsize = argv[7];
std::string power_mode = argv[8];
std::string img_dir = argv[9];
std::string det_config_path = argv[10];
std::string dict_path = argv[11];
if (strcmp(argv[5], "FP32") != 0 && strcmp(argv[5], "INT8") != 0) {
std::cerr << "Only support FP32 or INT8." << std::endl;
exit(1);
}
auto start = std::chrono::system_clock::now();
std::vector<cv::String> cv_all_img_names;
cv::glob(img_dir, cv_all_img_names);
auto det_predictor = loadModel(det_model_file);
auto rec_predictor = loadModel(rec_model_file);
auto cls_predictor = loadModel(cls_model_file);
//// load config from txt file
auto Config = LoadConfigTxt(det_config_path);
int use_direction_classify = int(Config["use_direction_classify"]);
auto charactor_dict = ReadDict(dict_path);
charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc
charactor_dict.push_back(" ");
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
auto boxes = RunDetModel(det_predictor, srcimg, Config);
auto det_predictor = loadModel(det_model_file, power_mode, std::stoi(num_threads));
auto rec_predictor = loadModel(rec_model_file, power_mode, std::stoi(num_threads));
auto cls_predictor = loadModel(cls_model_file, power_mode, std::stoi(num_threads));
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::cout << "The predict img: " << cv_all_img_names[i] << std::endl;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << std::endl;
exit(1);
}
std::vector<double> det_times;
auto boxes = RunDetModel(det_predictor, srcimg, Config, &det_times);
std::vector<std::string> rec_text;
std::vector<float> rec_text_score;
std::vector<double> rec_times;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
charactor_dict, cls_predictor, use_direction_classify);
charactor_dict, cls_predictor, use_direction_classify, &rec_times);
auto end = std::chrono::system_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
//// visualization
auto img_vis = Visualization(srcimg, boxes);
//// print recognized text
for (int i = 0; i < rec_text.size(); i++) {
std::cout << i << "\t" << rec_text[i] << "\t" << rec_text_score[i]
<< std::endl;
}
}
}
void det(int argc, char **argv) {
std::string det_model_file = argv[2];
std::string precision = argv[3];
std::string num_threads = argv[4];
std::string batchsize = argv[5];
std::string power_mode = argv[6];
std::string img_dir = argv[7];
std::string det_config_path = argv[8];
if (strcmp(argv[3], "FP32") != 0 && strcmp(argv[3], "INT8") != 0) {
std::cerr << "Only support FP32 or INT8." << std::endl;
exit(1);
}
std::vector<cv::String> cv_all_img_names;
cv::glob(img_dir, cv_all_img_names);
//// load config from txt file
auto Config = LoadConfigTxt(det_config_path);
auto det_predictor = loadModel(det_model_file, power_mode, std::stoi(num_threads));
std::vector<double> time_info = {0, 0, 0};
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::cout << "The predict img: " << cv_all_img_names[i] << std::endl;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << std::endl;
exit(1);
}
std::vector<double> times;
auto boxes = RunDetModel(det_predictor, srcimg, Config, &times);
//// visualization
auto img_vis = Visualization(srcimg, boxes);
std::cout << boxes.size() << " bboxes have detected:" << std::endl;
// for (int i=0; i<boxes.size(); i++){
// std::cout << "The " << i << " box:" << std::endl;
// for (int j=0; j<4; j++){
// for (int k=0; k<2; k++){
// std::cout << boxes[i][j][k] << "\t";
// }
// }
// std::cout << std::endl;
// }
time_info[0] += times[0];
time_info[1] += times[1];
time_info[2] += times[2];
}
if (strcmp(argv[9], "True") == 0) {
AutoLogger autolog(det_model_file,
0,
0,
0,
std::stoi(num_threads),
std::stoi(batchsize),
"dynamic",
precision,
power_mode,
time_info,
cv_all_img_names.size());
autolog.report();
}
}
void rec(int argc, char **argv) {
std::string rec_model_file = argv[2];
std::string precision = argv[3];
std::string num_threads = argv[4];
std::string batchsize = argv[5];
std::string power_mode = argv[6];
std::string img_dir = argv[7];
std::string dict_path = argv[8];
if (strcmp(argv[3], "FP32") != 0 && strcmp(argv[3], "INT8") != 0) {
std::cerr << "Only support FP32 or INT8." << std::endl;
exit(1);
}
std::vector<cv::String> cv_all_img_names;
cv::glob(img_dir, cv_all_img_names);
auto charactor_dict = ReadDict(dict_path);
charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc
charactor_dict.push_back(" ");
auto rec_predictor = loadModel(rec_model_file, power_mode, std::stoi(num_threads));
std::shared_ptr<PaddlePredictor> cls_predictor;
std::vector<double> time_info = {0, 0, 0};
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::cout << "The predict img: " << cv_all_img_names[i] << std::endl;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << std::endl;
exit(1);
}
int width = srcimg.cols;
int height = srcimg.rows;
std::vector<int> upper_left = {0, 0};
std::vector<int> upper_right = {width, 0};
std::vector<int> lower_right = {width, height};
std::vector<int> lower_left = {0, height};
std::vector<std::vector<int>> box = {upper_left, upper_right, lower_right, lower_left};
std::vector<std::vector<std::vector<int>>> boxes = {box};
std::vector<std::string> rec_text;
std::vector<float> rec_text_score;
std::vector<double> times;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
charactor_dict, cls_predictor, 0, &times);
//// print recognized text
for (int i = 0; i < rec_text.size(); i++) {
std::cout << i << "\t" << rec_text[i] << "\t" << rec_text_score[i]
<< std::endl;
}
}
// TODO: support autolog
if (strcmp(argv[9], "True") == 0) {
AutoLogger autolog(rec_model_file,
0,
0,
0,
std::stoi(num_threads),
std::stoi(batchsize),
"dynamic",
precision,
power_mode,
time_info,
cv_all_img_names.size());
autolog.report();
}
}
int main(int argc, char **argv) {
check_params(argc, argv);
std::cout << "mode: " << argv[1] << endl;
std::cout << "花费了"
<< double(duration.count()) *
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den
<< "秒" << std::endl;
if (strcmp(argv[1], "system") == 0) {
system(argv);
}
if (strcmp(argv[1], "det") == 0) {
det(argc, argv);
}
if (strcmp(argv[1], "rec") == 0) {
rec(argc, argv);
}
return 0;
}
......@@ -64,7 +64,7 @@ C-CTC Loss是CTC Loss + Center Loss的简称。 其中Center Loss出自论文 <
以配置文件`configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml`为例, center提取命令如下所示:
```
python tools/export_center.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml -o Global.pretrained_model: "./output/rec_mobile_pp-OCRv2/best_accuracy"
python tools/export_center.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml -o Global.pretrained_model="./output/rec_mobile_pp-OCRv2/best_accuracy"
```
运行完后,会在PaddleOCR主目录下生成`train_center.pkl`.
......
doc/joinus.PNG

188.2 KB | W: | H:

doc/joinus.PNG

209.7 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
......@@ -29,10 +29,7 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
if platform.system() != "Windows":
# pse is not support in Windows
from .pse_postprocess import PSEPostProcess
from .pse_postprocess import PSEPostProcess
def build_post_process(config, global_config=None):
......
......@@ -17,7 +17,12 @@ import subprocess
python_path = sys.executable
if subprocess.call('cd ppocr/postprocess/pse_postprocess/pse;{} setup.py build_ext --inplace;cd -'.format(python_path), shell=True) != 0:
raise RuntimeError('Cannot compile pse: {}'.format(os.path.dirname(os.path.realpath(__file__))))
ori_path = os.getcwd()
os.chdir('ppocr/postprocess/pse_postprocess/pse')
if subprocess.call(
'{} setup.py build_ext --inplace'.format(python_path), shell=True) != 0:
raise RuntimeError('Cannot compile pse: {}'.format(
os.path.dirname(os.path.realpath(__file__))))
os.chdir(ori_path)
from .pse import pse
===========================train_params===========================
model_name:ocr_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
gpu_list:0|0,1|10.21.226.181,10.21.226.133;0,1
Global.use_gpu:True|True|True
Global.auto_cast:fp32|amp
Global.epoch_num:lite_train_infer=1|whole_train_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_infer=2|whole_train_infer=4
......@@ -98,3 +98,13 @@ null:null
--benchmark:True
null:null
null:null
===========================lite_params===========================
inference:./ocr_db_crnn det
infer_model:./models/ch_ppocr_mobile_v2.0_det_opt.nb|./models/ch_ppocr_mobile_v2.0_det_slim_opt.nb
--cpu_threads:1|4
--batch_size:1
--power_mode:LITE_POWER_HIGH|LITE_POWER_LOW
--image_dir:./test_data/icdar2015_lite/text_localization/ch4_test_images/|./test_data/icdar2015_lite/text_localization/ch4_test_images/img_233.jpg
--config_dir:./config.txt
--rec_dict_dir:./ppocr_keys_v1.txt
--benchmark:True
......@@ -15,15 +15,15 @@ C++预测功能测试的主程序为`test_inference_cpp.sh`,可以测试基于
## 2. 测试流程
### 2.1 功能测试
先运行`prepare.sh`准备数据和模型,然后运行`test_inference_cpp.sh`进行测试,最终在```tests/output```目录下生成`cpp_infer_*.log`后缀的日志文件。
先运行`prepare.sh`准备数据和模型,然后运行`test_inference_cpp.sh`进行测试,最终在```test_tipc/output```目录下生成`cpp_infer_*.log`后缀的日志文件。
```shell
bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt "cpp_infer"
bash test_tipc/prepare.sh ./test_tipc/configs/ppocr_det_mobile_params.txt "cpp_infer"
# 用法1:
bash tests/test_inference_cpp.sh ./tests/configs/ppocr_det_mobile_params.txt
bash test_tipc/test_inference_cpp.sh ./test_tipc/configs/ppocr_det_mobile_params.txt
# 用法2: 指定GPU卡预测,第三个传入参数为GPU卡号
bash tests/test_inference_cpp.sh ./tests/configs/ppocr_det_mobile_params.txt '1'
bash test_tipc/test_inference_cpp.sh ./test_tipc/configs/ppocr_det_mobile_params.txt '1'
```
......@@ -37,12 +37,12 @@ bash tests/test_inference_cpp.sh ./tests/configs/ppocr_det_mobile_params.txt '1'
#### 使用方式
运行命令:
```shell
python3.7 tests/compare_results.py --gt_file=./tests/results/cpp_*.txt --log_file=./tests/output/cpp_*.log --atol=1e-3 --rtol=1e-3
python3.7 test_tipc/compare_results.py --gt_file=./test_tipc/results/cpp_*.txt --log_file=./test_tipc/output/cpp_*.log --atol=1e-3 --rtol=1e-3
```
参数介绍:
- gt_file: 指向事先保存好的预测结果路径,支持*.txt 结尾,会自动索引*.txt格式的文件,文件默认保存在tests/result/ 文件夹下
- log_file: 指向运行tests/test.sh 脚本的infer模式保存的预测日志,预测日志中打印的有预测结果,比如:文本框,预测文本,类别等等,同样支持infer_*.log格式传入
- gt_file: 指向事先保存好的预测结果路径,支持*.txt 结尾,会自动索引*.txt格式的文件,文件默认保存在test_tipc/result/ 文件夹下
- log_file: 指向运行test_tipc/test_inference_cpp.sh 脚本的infer模式保存的预测日志,预测日志中打印的有预测结果,比如:文本框,预测文本,类别等等,同样支持cpp_infer_*.log格式传入
- atol: 设置的绝对误差
- rtol: 设置的相对误差
......
# PaddleServing预测功能测试
PaddleServing预测功能测试的主程序为`test_serving.sh`,可以测试基于PaddleServing的部署功能。
## 1. 测试结论汇总
基于训练是否使用量化,进行本测试的模型可以分为`正常模型``量化模型`,这两类模型对应的C++预测功能汇总如下:
| 模型类型 |device | batchsize | tensorrt | mkldnn | cpu多线程 |
| ---- | ---- | ---- | :----: | :----: | :----: |
| 正常模型 | GPU | 1/6 | fp32/fp16 | - | - |
| 正常模型 | CPU | 1/6 | - | fp32 | 支持 |
| 量化模型 | GPU | 1/6 | int8 | - | - |
| 量化模型 | CPU | 1/6 | - | int8 | 支持 |
## 2. 测试流程
### 2.1 功能测试
先运行`prepare.sh`准备数据和模型,然后运行`test_serving.sh`进行测试,最终在```test_tipc/output```目录下生成`serving_infer_*.log`后缀的日志文件。
```shell
bash test_tipc/prepare.sh ./test_tipc/configs/ppocr_det_mobile_params.txt "serving_infer"
# 用法:
bash test_tipc/test_serving.sh ./test_tipc/configs/ppocr_det_mobile_params.txt
```
#### 运行结果
各测试的运行情况会打印在 `test_tipc/output/results_serving.log` 中:
运行成功时会输出:
```
Run successfully with command - python3.7 pipeline_http_client.py --image_dir=../../doc/imgs > ../../tests/output/server_infer_cpu_usemkldnn_True_threads_1_batchsize_1.log 2>&1 !
Run successfully with command - xxxxx
...
```
运行失败时会输出:
```
Run failed with command - python3.7 pipeline_http_client.py --image_dir=../../doc/imgs > ../../tests/output/server_infer_cpu_usemkldnn_True_threads_1_batchsize_1.log 2>&1 !
Run failed with command - python3.7 pipeline_http_client.py --image_dir=../../doc/imgs > ../../tests/output/server_infer_cpu_usemkldnn_True_threads_6_batchsize_1.log 2>&1 !
Run failed with command - xxxxx
...
```
详细的预测结果会存在 test_tipc/output/ 文件夹下,例如`server_infer_gpu_usetrt_True_precision_fp16_batchsize_1.log`中会返回检测框的坐标:
```
{'err_no': 0, 'err_msg': '', 'key': ['dt_boxes'], 'value': ['[[[ 78. 642.]\n [409. 640.]\n [409. 657.]\n
[ 78. 659.]]\n\n [[ 75. 614.]\n [211. 614.]\n [211. 635.]\n [ 75. 635.]]\n\n
[[103. 554.]\n [135. 554.]\n [135. 575.]\n [103. 575.]]\n\n [[ 75. 531.]\n
[347. 531.]\n [347. 549.]\n [ 75. 549.] ]\n\n [[ 76. 503.]\n [309. 498.]\n
[309. 521.]\n [ 76. 526.]]\n\n [[163. 462.]\n [317. 462.]\n [317. 493.]\n
[163. 493.]]\n\n [[324. 431.]\n [414. 431.]\n [414. 452.]\n [324. 452.]]\n\n
[[ 76. 412.]\n [208. 408.]\n [209. 424.]\n [ 76. 428.]]\n\n [[307. 409.]\n
[428. 409.]\n [428. 426.]\n [307 . 426.]]\n\n [[ 74. 385.]\n [217. 382.]\n
[217. 400.]\n [ 74. 403.]]\n\n [[308. 381.]\n [427. 380.]\n [427. 400.]\n
[308. 401.]]\n\n [[ 74. 363.]\n [195. 362.]\n [195. 378.]\n [ 74. 379.]]\n\n
[[303. 359.]\n [423. 357.]\n [423. 375.]\n [303. 377.]]\n\n [[ 70. 336.]\n
[239. 334.]\n [239. 354.]\ n [ 70. 356.]]\n\n [[ 70. 312.]\n [204. 310.]\n
[204. 327.]\n [ 70. 330.]]\n\n [[303. 308.]\n [419. 306.]\n [419. 326.]\n
[303. 328.]]\n\n [[113. 2 72.]\n [246. 270.]\n [247. 299.]\n [113. 301.]]\n\n
[[361. 269.]\n [384. 269.]\n [384. 296.]\n [361. 296.]]\n\n [[ 70. 250.]\n
[243. 246.]\n [243. 265.]\n [ 70. 269.]]\n\n [[ 65. 221.]\n [187. 220.]\n
[187. 240.]\n [ 65. 241.]]\n\n [[337. 216.]\n [382. 216.]\n [382. 240.]\n
[337. 240.]]\n\n [ [ 65. 196.]\n [247. 193.]\n [247. 213.]\n [ 65. 216.]]\n\n
[[296. 197.]\n [423. 191.]\n [424. 209.]\n [296. 215.]]\n\n [[ 65. 167.]\n [244. 167.]\n
[244. 186.]\n [ 65. 186.]]\n\n [[ 67. 139.]\n [290. 139.]\n [290. 159.]\n [ 67. 159.]]\n\n
[[ 68. 113.]\n [410. 113.]\n [410. 128.]\n [ 68. 129.] ]\n\n [[277. 87.]\n [416. 87.]\n
[416. 108.]\n [277. 108.]]\n\n [[ 79. 28.]\n [132. 28.]\n [132. 62.]\n [ 79. 62.]]\n\n
[[163. 17.]\n [410. 14.]\n [410. 50.]\n [163. 53.]]]']}
```
## 3. 更多教程
本文档为功能测试用,更详细的Serving预测使用教程请参考:[PPOCR 服务化部署](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/deploy/pdserving/README_CN.md)
......@@ -46,42 +46,42 @@
### 2.2 功能测试
先运行`prepare.sh`准备数据和模型,然后运行`test_train_inference_python.sh`进行测试,最终在```tests/output```目录下生成`python_infer_*.log`格式的日志文件。
先运行`prepare.sh`准备数据和模型,然后运行`test_train_inference_python.sh`进行测试,最终在```test_tipc/output```目录下生成`python_infer_*.log`格式的日志文件。
`test_train_inference_python.sh`包含5种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是:
- 模式1:lite_train_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度;
```shell
bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt 'lite_train_infer'
bash tests/test_train_inference_python.sh ./tests/configs/ppocr_det_mobile_params.txt 'lite_train_infer'
bash test_tipc/prepare.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'lite_train_infer'
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'lite_train_infer'
```
- 模式2:whole_infer,使用少量数据训练,一定量数据预测,用于验证训练后的模型执行预测,预测速度是否合理;
```shell
bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt 'whole_infer'
bash tests/test_train_inference_python.sh ./tests/configs/ppocr_det_mobile_params.txt 'whole_infer'
bash test_tipc/prepare.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'whole_infer'
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'whole_infer'
```
- 模式3:infer,不训练,全量数据预测,走通开源模型评估、动转静,检查inference model预测时间和精度;
```shell
bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt 'infer'
bash test_tipc/prepare.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'infer'
# 用法1:
bash tests/test_train_inference_python.sh ./tests/configs/ppocr_det_mobile_params.txt 'infer'
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'infer'
# 用法2: 指定GPU卡预测,第三个传入参数为GPU卡号
bash tests/test_train_inference_python.sh ./tests/configs/ppocr_det_mobile_params.txt 'infer' '1'
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'infer' '1'
```
- 模式4:whole_train_infer,CE: 全量数据训练,全量数据预测,验证模型训练精度,预测精度,预测速度;
```shell
bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt 'whole_train_infer'
bash tests/test_train_inference_python.sh ./tests/configs/ppocr_det_mobile_params.txt 'whole_train_infer'
bash test_tipc/prepare.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'whole_train_infer'
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'whole_train_infer'
```
- 模式5:klquant_infer,测试离线量化;
```shell
bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt 'klquant_infer'
bash tests/test_train_inference_python.sh tests/configs/ppocr_det_mobile_params.txt 'klquant_infer'
bash test_tipc/prepare.sh ./test_tipc/configs/ppocr_det_mobile_params.txt 'klquant_infer'
bash test_tipc/test_train_inference_python.sh test_tipc/configs/ppocr_det_mobile_params.txt 'klquant_infer'
```
......@@ -95,12 +95,12 @@ bash tests/test_train_inference_python.sh tests/configs/ppocr_det_mobile_params.
#### 使用方式
运行命令:
```shell
python3.7 tests/compare_results.py --gt_file=./tests/results/python_*.txt --log_file=./tests/output/python_*.log --atol=1e-3 --rtol=1e-3
python3.7 test_tipc/compare_results.py --gt_file=./test_tipc/results/python_*.txt --log_file=./test_tipc/output/python_*.log --atol=1e-3 --rtol=1e-3
```
参数介绍:
- gt_file: 指向事先保存好的预测结果路径,支持*.txt 结尾,会自动索引*.txt格式的文件,文件默认保存在tests/result/ 文件夹下
- log_file: 指向运行tests/test.sh 脚本的infer模式保存的预测日志,预测日志中打印的有预测结果,比如:文本框,预测文本,类别等等,同样支持infer_*.log格式传入
- gt_file: 指向事先保存好的预测结果路径,支持*.txt 结尾,会自动索引*.txt格式的文件,文件默认保存在test_tipc/result/ 文件夹下
- log_file: 指向运行test_tipc/test_train_inference_python.sh 脚本的infer模式保存的预测日志,预测日志中打印的有预测结果,比如:文本框,预测文本,类别等等,同样支持python_infer_*.log格式传入
- atol: 设置的绝对误差
- rtol: 设置的相对误差
......
......@@ -2,7 +2,7 @@
FILENAME=$1
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer',
# 'cpp_infer', 'serving_infer', 'klquant_infer']
# 'cpp_infer', 'serving_infer', 'klquant_infer', 'lite_infer']
MODE=$2
......@@ -136,3 +136,37 @@ if [ ${MODE} = "serving_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar && cd ../
fi
if [ ${MODE} = "lite_infer" ];then
# prepare lite nb model and test data
current_dir=${PWD}
wget -nc -P ./models https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_opt.nb
wget -nc -P ./models https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_slim_opt.nb
wget -nc -P ./test_data https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
cd ./test_data && tar -xf icdar2015_lite.tar && rm icdar2015_lite.tar && cd ../
# prepare lite env
export http_proxy=http://172.19.57.45:3128
export https_proxy=http://172.19.57.45:3128
paddlelite_url=https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9/inference_lite_lib.android.armv8.gcc.c++_shared.with_extra.with_cv.tar.gz
paddlelite_zipfile=$(echo $paddlelite_url | awk -F "/" '{print $NF}')
paddlelite_file=inference_lite_lib.android.armv8.gcc.c++_shared.with_extra.with_cv
wget ${paddlelite_url}
tar -xf ${paddlelite_zipfile}
mkdir -p ${paddlelite_file}/demo/cxx/ocr/test_lite
mv models test_data ${paddlelite_file}/demo/cxx/ocr/test_lite
cp ppocr/utils/ppocr_keys_v1.txt deploy/lite/config.txt ${paddlelite_file}/demo/cxx/ocr/test_lite
cp ./deploy/lite/* ${paddlelite_file}/demo/cxx/ocr/
cp ${paddlelite_file}/cxx/lib/libpaddle_light_api_shared.so ${paddlelite_file}/demo/cxx/ocr/test_lite
cp PTDN/configs/ppocr_det_mobile_params.txt PTDN/test_lite.sh PTDN/common_func.sh ${paddlelite_file}/demo/cxx/ocr/test_lite
cd ${paddlelite_file}/demo/cxx/ocr/
git clone https://github.com/LDOUBLEV/AutoLog.git
unset http_proxy
unset https_proxy
make -j
sleep 1
make -j
cp ocr_db_crnn test_lite && cp test_lite/libpaddle_light_api_shared.so test_lite/libc++_shared.so
tar -cf test_lite.tar ./test_lite && cp test_lite.tar ${current_dir} && cd ${current_dir}
fi
# 推理部署导航
# 飞桨训推一体认证
## 1. 简介
飞桨除了基本的模型训练和预测,还提供了支持多端多平台的高性能推理部署工具。本文档提供了PaddleOCR中所有模型的推理部署导航PTDN(Paddle Train Deploy Navigation),方便用户查阅每种模型的推理部署打通情况,并可以进行一键测试。
飞桨除了基本的模型训练和预测,还提供了支持多端多平台的高性能推理部署工具。本文档提供了PaddleOCR中所有模型的飞桨训推一体认证 (Training and Inference Pipeline Certification(TIPC)) 信息和测试工具,方便用户查阅每种模型的训练推理部署打通情况,并可以进行一键测试。
<div align="center">
<img src="docs/guide.png" width="1000">
......@@ -15,20 +15,23 @@
**字段说明:**
- 基础训练预测:包括模型训练、Paddle Inference Python预测。
- 其他:包括Paddle Inference C++预测、Paddle Serving部署、Paddle-Lite部署等。
- 更多训练方式:包括多机多卡、混合精度。
- 模型压缩:包括裁剪、离线/在线量化、蒸馏。
- 其他预测部署:包括Paddle Inference C++预测、Paddle Serving部署、Paddle-Lite部署等。
更详细的mkldnn、Tensorrt等预测加速相关功能的支持情况可以查看各测试工具的[更多教程](#more)
| 算法论文 | 模型名称 | 模型类型 | 基础训练预测 | 其他 |
| :--- | :--- | :----: | :--------: | :---- |
| DB |ch_ppocr_mobile_v2.0_det | 检测 | 支持 | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| DB |ch_ppocr_server_v2.0_det | 检测 | 支持 | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| 算法论文 | 模型名称 | 模型类型 | 基础<br>训练预测 | 更多<br>训练方式 | 模型压缩 | 其他预测部署 |
| :--- | :--- | :----: | :--------: | :---- | :---- | :---- |
| DB |ch_ppocr_mobile_v2.0_det | 检测 | 支持 | 多机多卡 <br> 混合精度 | FPGM裁剪 <br> 离线量化| Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| DB |ch_ppocr_server_v2.0_det | 检测 | 支持 | 多机多卡 <br> 混合精度 | FPGM裁剪 <br> 离线量化| Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| DB |ch_PP-OCRv2_det | 检测 |
| CRNN |ch_ppocr_mobile_v2.0_rec | 识别 | 支持 | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| CRNN |ch_ppocr_server_v2.0_rec | 识别 | 支持 | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| CRNN |ch_ppocr_mobile_v2.0_rec | 识别 | 支持 | 多机多卡 <br> 混合精度 | PACT量化 <br> 离线量化| Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| CRNN |ch_ppocr_server_v2.0_rec | 识别 | 支持 | 多机多卡 <br> 混合精度 | PACT量化 <br> 离线量化| Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| CRNN |ch_PP-OCRv2_rec | 识别 |
| PP-OCR |ch_ppocr_mobile_v2.0 | 检测+识别 | 支持 | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| PP-OCR |ch_ppocr_server_v2.0 | 检测+识别 | 支持 | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
|PP-OCRv2|ch_PP-OCRv2 | 检测+识别 | 支持 | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| PP-OCR |ch_ppocr_mobile_v2.0 | 检测+识别 | 支持 | 多机多卡 <br> 混合精度 | - | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
| PP-OCR |ch_ppocr_server_v2.0 | 检测+识别 | 支持 | 多机多卡 <br> 混合精度 | - | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
|PP-OCRv2|ch_PP-OCRv2 | 检测+识别 |
| DB |det_mv3_db_v2.0 | 检测 |
| DB |det_r50_vd_db_v2.0 | 检测 |
| EAST |det_mv3_east_v2.0 | 检测 |
......@@ -55,7 +58,7 @@
### 目录介绍
```shell
PTDN/
test_tipc/
├── configs/ # 配置文件目录
├── det_mv3_db.yml # 测试mobile版ppocr检测模型训练的yml文件
├── det_r50_vd_db.yml # 测试server版ppocr检测模型训练的yml文件
......@@ -98,6 +101,8 @@ PTDN/
- `test_serving.sh`:测试基于Paddle Serving的服务化部署功能。
- `test_lite.sh`:测试基于Paddle-Lite的端侧预测部署功能。
<a name="more"></a>
#### 更多教程
各功能测试中涉及混合精度、裁剪、量化等训练相关,及mkldnn、Tensorrt等多种预测相关参数配置,请点击下方相应链接了解更多细节和使用教程:
[test_train_inference_python 使用](docs/test_train_inference_python.md)
[test_inference_cpp 使用](docs/test_inference_cpp.md)
......
#!/bin/bash
source ./common_func.sh
export LD_LIBRARY_PATH=${PWD}:$LD_LIBRARY_PATH
FILENAME=$1
dataline=$(awk 'NR==101, NR==110{print}' $FILENAME)
echo $dataline
# parser params
IFS=$'\n'
lines=(${dataline})
# parser lite inference
lite_inference_cmd=$(func_parser_value "${lines[1]}")
lite_model_dir_list=$(func_parser_value "${lines[2]}")
lite_cpu_threads_list=$(func_parser_value "${lines[3]}")
lite_batch_size_list=$(func_parser_value "${lines[4]}")
lite_power_mode_list=$(func_parser_value "${lines[5]}")
lite_infer_img_dir_list=$(func_parser_value "${lines[6]}")
lite_config_dir=$(func_parser_value "${lines[7]}")
lite_rec_dict_dir=$(func_parser_value "${lines[8]}")
lite_benchmark_value=$(func_parser_value "${lines[9]}")
LOG_PATH="./output"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results.log"
function func_lite(){
IFS='|'
_script=$1
_lite_model=$2
_log_path=$3
_img_dir=$4
_config=$5
if [[ $lite_model =~ "slim" ]]; then
precision="INT8"
else
precision="FP32"
fi
is_single_img=$(echo $_img_dir | grep -E ".jpg|.jpeg|.png|.JPEG|.JPG")
if [[ "$is_single_img" != "" ]]; then
single_img="True"
else
single_img="False"
fi
# lite inference
for num_threads in ${lite_cpu_threads_list[*]}; do
for power_mode in ${lite_power_mode_list[*]}; do
for batchsize in ${lite_batch_size_list[*]}; do
model_name=$(echo $lite_model | awk -F "/" '{print $NF}')
_save_log_path="${_log_path}/lite_${model_name}_precision_${precision}_batchsize_${batchsize}_threads_${num_threads}_powermode_${power_mode}_singleimg_${single_img}.log"
command="${_script} ${lite_model} ${precision} ${num_threads} ${batchsize} ${power_mode} ${_img_dir} ${_config} ${lite_benchmark_value} > ${_save_log_path} 2>&1"
eval ${command}
status_check $? "${command}" "${status_log}"
done
done
done
}
echo "################### run test ###################"
IFS="|"
for lite_model in ${lite_model_dir_list[*]}; do
#run lite inference
for img_dir in ${lite_infer_img_dir_list[*]}; do
func_lite "${lite_inference_cmd}" "${lite_model}" "${LOG_PATH}" "${img_dir}" "${lite_config_dir}"
done
done
#!/bin/bash
source tests/common_func.sh
source PTDN/common_func.sh
FILENAME=$1
dataline=$(awk 'NR==67, NR==83{print}' $FILENAME)
......@@ -36,8 +36,8 @@ web_precision_key=$(func_parser_key "${lines[15]}")
web_precision_list=$(func_parser_value "${lines[15]}")
pipeline_py=$(func_parser_value "${lines[16]}")
LOG_PATH="../../tests/output"
mkdir -p ./tests/output
LOG_PATH="../../PTDN/output"
mkdir -p ./PTDN/output
status_log="${LOG_PATH}/results_serving.log"
function func_serving(){
......
......@@ -245,6 +245,7 @@ else
for gpu in ${gpu_list[*]}; do
use_gpu=${USE_GPU_KEY[Count]}
Count=$(($Count + 1))
ips=""
if [ ${gpu} = "-1" ];then
env=""
elif [ ${#gpu} -le 1 ];then
......@@ -264,6 +265,11 @@ else
env=" "
fi
for autocast in ${autocast_list[*]}; do
if [ ${autocast} = "amp" ]; then
set_amp_config="Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True"
else
set_amp_config=" "
fi
for trainer in ${trainer_list[*]}; do
flag_quant=False
if [ ${trainer} = ${pact_key} ]; then
......@@ -290,7 +296,6 @@ else
if [ ${run_train} = "null" ]; then
continue
fi
set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
......@@ -306,11 +311,11 @@ else
set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} "
elif [ ${#gpu} -le 15 ];then # train with multi-gpu
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1}"
cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config} "
elif [ ${#ips} -le 26 ];then # train with multi-gpu
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
else # train with multi-machine
cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1}"
cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${set_use_gpu} ${run_train} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
fi
# run train
eval "unset CUDA_VISIBLE_DEVICES"
......
......@@ -159,7 +159,8 @@ def train(config,
eval_class,
pre_best_model_dict,
logger,
vdl_writer=None):
vdl_writer=None,
scaler=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
log_smooth_window = config['Global']['log_smooth_window']
......@@ -226,12 +227,27 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
# use amp
if scaler:
with paddle.amp.auto_cast():
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
else:
preds = model(images)
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
else:
preds = model(images)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
if scaler:
scaled_avg_loss = scaler.scale(avg_loss)
scaled_avg_loss.backward()
scaler.minimize(optimizer, scaled_avg_loss)
else:
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
......@@ -480,11 +496,6 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED'
]
windows_not_support_list = ['PSE']
if platform.system() == "Windows" and alg in windows_not_support_list:
logger.warning('{} is not support in Windows now'.format(
windows_not_support_list))
sys.exit()
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
......
......@@ -102,10 +102,27 @@ def main(config, device, logger, vdl_writer):
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
len(valid_dataloader)))
use_amp = config["Global"].get("use_amp", False)
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
'FLAGS_max_inplace_grad_add': 8,
}
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
"use_dynamic_loss_scaling", False)
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
else:
scaler = None
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer)
eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
def test_reader(config, device, logger):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册