diff --git a/benchmark/analysis.py b/benchmark/analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4189b99d8ee082082a254718617a7e58bebe961
--- /dev/null
+++ b/benchmark/analysis.py
@@ -0,0 +1,273 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import json
+import os
+import re
+import traceback
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--filename", type=str, help="The name of the log file to analyze.")
+    parser.add_argument(
+        "--log_with_profiler", type=str, help="The path of the train log with profiler")
+    parser.add_argument(
+        "--profiler_path", type=str, help="The path of the profiler timeline log.")
+    parser.add_argument(
+        "--keyword", type=str, help="Keyword that marks the data field to analyze")
+    parser.add_argument(
+        "--separator", type=str, default=None, help="Separator between fields in the log")
+    parser.add_argument(
+        '--position', type=int, default=None, help='The position of the data field')
+    parser.add_argument(
+        '--range', type=str, default="", help='The substring range of the data field to intercept')
+    parser.add_argument(
+        '--base_batch_size', type=int, help='base batch size per GPU')
+    parser.add_argument(
+        '--skip_steps', type=int, default=0, help='The number of steps to be skipped')
+    parser.add_argument(
+        '--model_mode', type=int, default=-1, help='Analysis mode, default value is -1')
+    parser.add_argument(
+        '--ips_unit', type=str, default=None, help='IPS unit')
+    parser.add_argument(
+        '--model_name', type=str, default="", help='training model name, e.g. transformer_base')
+    parser.add_argument(
+        '--mission_name', type=str, default="", help='training mission name')
+    parser.add_argument(
+        '--direction_id', type=int, default=0, help='training direction_id')
+    parser.add_argument(
+        '--run_mode', type=str, default="sp", help='multi process (mp) or single process (sp)')
+    parser.add_argument(
+        '--index', type=int, default=1, help='{1: speed, 2: mem, 3: profiler, 6: max_batch_size}')
+    parser.add_argument(
+        '--gpu_num', type=int, default=1, help='number of training GPUs')
+    args = parser.parse_args()
+    args.separator = None if args.separator == "None" else args.separator
+    return args
+
+
+def _is_number(num):
+    pattern = re.compile(r'^[-+]?\d+(\.\d*)?$')
+    return pattern.match(num) is not None
+
+
+class TimeAnalyzer(object):
+    def __init__(self, filename, keyword=None, separator=None, position=None, range="-1"):
+        if filename is None:
+            raise Exception("Please specify the filename!")
+
+        if keyword is None:
+            raise Exception("Please specify the keyword!")
+
+        self.filename = filename
+        self.keyword = keyword
+        self.separator = separator
+        self.position = position
+        self.range = range
+        self.records = None
+        self._distil()
+
+    def _distil(self):
+        self.records = []
+        with open(self.filename, "r") as f_object:
+            lines = f_object.readlines()
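+            # Walk the captured lines: keep only those containing the keyword
+            # and extract the numeric field from each one.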
+            for line in lines:
+                if self.keyword not in line:
+                    continue
+                try:
+                    result = None
+
+                    # Distil the string from a line.
+                    line = line.strip()
+                    line_words = line.split(self.separator) if self.separator else line.split()
+                    if self.position:
+                        result = line_words[self.position]
+                    else:
+                        # Distil the string following the keyword.
+                        for i in range(len(line_words) - 1):
+                            if line_words[i] == self.keyword:
+                                result = line_words[i + 1]
+                                break
+
+                    # Distil the result from the picked string.
+                    if not self.range:
+                        result = result[0:]
+                    elif _is_number(self.range):
+                        result = result[0:int(self.range)]
+                    else:
+                        result = result[int(self.range.split(":")[0]):int(self.range.split(":")[1])]
+                    self.records.append(float(result))
+                except Exception as exc:
+                    print("failed to parse line: {}; separator={}; position={}; error={}".format(
+                        line, self.separator, self.position, exc))
+
+        print("Extracted {} records: separator={}; position={}".format(
+            len(self.records), self.separator, self.position))
+
+    def _get_fps(self, mode, batch_size, gpu_num, avg_of_records, run_mode, unit=None):
+        if mode == -1 and run_mode == 'sp':
+            assert unit, "Please set the unit when mode is -1."
+            fps = gpu_num * avg_of_records
+        elif mode == -1 and run_mode == 'mp':
+            assert unit, "Please set the unit when mode is -1."
+            fps = gpu_num * avg_of_records  # temporarily the same as sp; not used for now
+        elif mode == 0:
+            # s/step -> samples/s
+            fps = (batch_size * gpu_num) / avg_of_records
+            unit = "samples/s"
+        elif mode == 1:
+            # steps/s -> steps/s
+            fps = avg_of_records
+            unit = "steps/s"
+        elif mode == 2:
+            # s/step -> steps/s
+            fps = 1 / avg_of_records
+            unit = "steps/s"
+        elif mode == 3:
+            # steps/s -> samples/s
+            fps = batch_size * gpu_num * avg_of_records
+            unit = "samples/s"
+        elif mode == 4:
+            # s/epoch -> s/epoch
+            fps = avg_of_records
+            unit = "s/epoch"
+        else:
+            raise ValueError("Unsupported analysis mode.")
+
+        return fps, unit
+
+    def analysis(self, batch_size, gpu_num=1, skip_steps=0, mode=-1, run_mode='sp', unit=None):
+        if batch_size <= 0:
+            print("base_batch_size should be larger than 0.")
+            return 0, ''
+
+        if len(self.records) <= skip_steps:  # e.g. when the number of log records equals skip_steps
+            print("no records")
+            return 0, ''
+
+        sum_of_records = 0
+        sum_of_records_skipped = 0
+        skip_min = self.records[skip_steps]
+        skip_max = self.records[skip_steps]
+
+        count = len(self.records)
+        for i in range(count):
+            sum_of_records += self.records[i]
+            if i >= skip_steps:
+                sum_of_records_skipped += self.records[i]
+                if self.records[i] < skip_min:
+                    skip_min = self.records[i]
+                if self.records[i] > skip_max:
+                    skip_max = self.records[i]
+
+        avg_of_records = sum_of_records / float(count)
+        avg_of_records_skipped = sum_of_records_skipped / float(count - skip_steps)
+
+        fps, fps_unit = self._get_fps(mode, batch_size, gpu_num, avg_of_records, run_mode, unit)
+        fps_skipped, _ = self._get_fps(mode, batch_size, gpu_num, avg_of_records_skipped, run_mode, unit)
+        if mode == -1:
+            print("average ips of %d steps, skip 0 step:" % count)
+            print("\tAvg: %.3f %s" % (avg_of_records, fps_unit))
+            print("\tFPS: %.3f %s" % (fps, fps_unit))
+            if skip_steps > 0:
+                print("average ips of %d steps, skip %d steps:" % (count, skip_steps))
+                print("\tAvg: %.3f %s" % (avg_of_records_skipped, fps_unit))
+                print("\tMin: %.3f %s" % (skip_min, fps_unit))
+                print("\tMax: %.3f %s" % (skip_max, fps_unit))
+                print("\tFPS: %.3f %s" % (fps_skipped, fps_unit))
+        elif mode == 1 or mode == 3:
+            print("average latency of %d steps, skip 0 step:" % count)
print("\tAvg: %.3f steps/s" % avg_of_records) + print("\tFPS: %.3f %s" % (fps, fps_unit)) + if skip_steps > 0: + print("average latency of %d steps, skip %d steps:" % (count, skip_steps)) + print("\tAvg: %.3f steps/s" % avg_of_records_skipped) + print("\tMin: %.3f steps/s" % skip_min) + print("\tMax: %.3f steps/s" % skip_max) + print("\tFPS: %.3f %s" % (fps_skipped, fps_unit)) + elif mode == 0 or mode == 2: + print("average latency of %d steps, skip 0 step:" % count) + print("\tAvg: %.3f s/step" % avg_of_records) + print("\tFPS: %.3f %s" % (fps, fps_unit)) + if skip_steps > 0: + print("average latency of %d steps, skip %d steps:" % (count, skip_steps)) + print("\tAvg: %.3f s/step" % avg_of_records_skipped) + print("\tMin: %.3f s/step" % skip_min) + print("\tMax: %.3f s/step" % skip_max) + print("\tFPS: %.3f %s" % (fps_skipped, fps_unit)) + + return round(fps_skipped, 3), fps_unit + + +if __name__ == "__main__": + args = parse_args() + run_info = dict() + run_info["log_file"] = args.filename + run_info["model_name"] = args.model_name + run_info["mission_name"] = args.mission_name + run_info["direction_id"] = args.direction_id + run_info["run_mode"] = args.run_mode + run_info["index"] = args.index + run_info["gpu_num"] = args.gpu_num + run_info["FINAL_RESULT"] = 0 + run_info["JOB_FAIL_FLAG"] = 0 + + try: + if args.index == 1: + if args.gpu_num == 1: + run_info["log_with_profiler"] = args.log_with_profiler + run_info["profiler_path"] = args.profiler_path + analyzer = TimeAnalyzer(args.filename, args.keyword, args.separator, args.position, args.range) + run_info["FINAL_RESULT"], run_info["UNIT"] = analyzer.analysis( + batch_size=args.base_batch_size, + gpu_num=args.gpu_num, + skip_steps=args.skip_steps, + mode=args.model_mode, + run_mode=args.run_mode, + unit=args.ips_unit) + try: + if int(os.getenv('job_fail_flag')) == 1 or int(run_info["FINAL_RESULT"]) == 0: + run_info["JOB_FAIL_FLAG"] = 1 + except: + pass + elif args.index == 3: + run_info["FINAL_RESULT"] = {} + records_fo_total = TimeAnalyzer(args.filename, 'Framework overhead', None, 3, '').records + records_fo_ratio = TimeAnalyzer(args.filename, 'Framework overhead', None, 5).records + records_ct_total = TimeAnalyzer(args.filename, 'Computation time', None, 3, '').records + records_gm_total = TimeAnalyzer(args.filename, 'GpuMemcpy Calls', None, 4, '').records + records_gm_ratio = TimeAnalyzer(args.filename, 'GpuMemcpy Calls', None, 6).records + records_gmas_total = TimeAnalyzer(args.filename, 'GpuMemcpyAsync Calls', None, 4, '').records + records_gms_total = TimeAnalyzer(args.filename, 'GpuMemcpySync Calls', None, 4, '').records + run_info["FINAL_RESULT"]["Framework_Total"] = records_fo_total[0] if records_fo_total else 0 + run_info["FINAL_RESULT"]["Framework_Ratio"] = records_fo_ratio[0] if records_fo_ratio else 0 + run_info["FINAL_RESULT"]["ComputationTime_Total"] = records_ct_total[0] if records_ct_total else 0 + run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = records_gm_total[0] if records_gm_total else 0 + run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = records_gm_ratio[0] if records_gm_ratio else 0 + run_info["FINAL_RESULT"]["GpuMemcpyAsync_Total"] = records_gmas_total[0] if records_gmas_total else 0 + run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = records_gms_total[0] if records_gms_total else 0 + else: + print("Not support!") + except Exception: + traceback.print_exc() + print("{}".format(json.dumps(run_info))) # it's required, for the log file path insert to the database + diff --git a/benchmark/readme.md b/benchmark/readme.md 
new file mode 100644
index 0000000000000000000000000000000000000000..7f7704cca5341d495dfbcdc66ddfd29fbea1e1df
--- /dev/null
+++ b/benchmark/readme.md
@@ -0,0 +1,34 @@
+
+# PaddleOCR DB/EAST training benchmark
+
+The files under PaddleOCR/benchmark are used to collect and analyze training logs.
+Training uses the icdar2015 dataset (1000 training images and 500 test images). The model config uses resnet18_vd as the backbone, trained with both batch_size=8 and batch_size=16.
+
+## Running the training benchmark
+
+benchmark/run_det.sh covers four stages:
+- install dependencies
+- download data
+- run training
+- analyze the logs to obtain IPS
+
+The training stage runs both single-machine single-GPU (GPU 0 by default) and single-machine multi-GPU training, each with batch_size=8 and batch_size=16, so each model produces 4 log files in total.
+
+Run run_det.sh as follows:
+
+```
+# cd PaddleOCR/
+bash benchmark/run_det.sh
+```
+
+Taking DB as an example, the four resulting log files are:
+```
+det_res18_db_v2.0_sp_bs16_fp32_1
+det_res18_db_v2.0_sp_bs8_fp32_1
+det_res18_db_v2.0_mp_bs16_fp32_1
+det_res18_db_v2.0_mp_bs8_fp32_1
+```
+
+
+
+
diff --git a/benchmark/run_benchmark_det.sh b/benchmark/run_benchmark_det.sh
new file mode 100644
index 0000000000000000000000000000000000000000..26bcda5d20ba4e4d0498da28aafb93f29468169d
--- /dev/null
+++ b/benchmark/run_benchmark_det.sh
@@ -0,0 +1,56 @@
+#!/usr/bin/env bash
+set -xe
+# Example: CUDA_VISIBLE_DEVICES=0 bash benchmark/run_benchmark_det.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}
+# Parameter description
+function _set_params(){
+    run_mode=${1:-"sp"}          # sp: single-GPU | mp: multi-GPU
+    batch_size=${2:-"64"}
+    fp_item=${3:-"fp32"}         # fp32|fp16
+    max_iter=${4:-"500"}         # optional; lower it if you want the run to stop early
+    model_name=${5:-"model_name"}
+    run_log_path=${TRAIN_LOG_DIR:-$(pwd)}  # TRAIN_LOG_DIR is set later by QA
+
+    # No need to modify anything below
+    device=${CUDA_VISIBLE_DEVICES//,/ }
+    arr=(${device})
+    num_gpu_devices=${#arr[*]}
+    log_file=${run_log_path}/${model_name}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices}
+}
+function _train(){
+    echo "Train on ${num_gpu_devices} GPUs"
+    echo "current CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES, gpus=$num_gpu_devices, batch_size=$batch_size"
+
+    train_cmd="-c configs/det/${model_name}.yml -o Train.loader.batch_size_per_card=${batch_size} Global.epoch_num=${max_iter} "
+    case ${run_mode} in
+    sp)
+        train_cmd="python3.7 tools/train.py "${train_cmd}""
+        ;;
+    mp)
+        train_cmd="python3.7 -m paddle.distributed.launch --log_dir=./mylog --gpus=$CUDA_VISIBLE_DEVICES tools/train.py ${train_cmd}"
+        ;;
+    *) echo "choose run_mode(sp or mp)"; exit 1;
+    esac
+    # No need to modify anything below
+    timeout 15m ${train_cmd} > ${log_file} 2>&1
+    if [ $? -ne 0 ]; then
+        echo -e "${model_name}, FAIL"
+        export job_fail_flag=1
+    else
+        echo -e "${model_name}, SUCCESS"
+        export job_fail_flag=0
+    fi
+    kill -9 `ps -ef|grep 'python3.7'|grep -v grep|awk '{print $2}'`
+
+    if [ $run_mode = "mp" -a -d mylog ]; then
+        rm ${log_file}
+        cp mylog/workerlog.0 ${log_file}
+    fi
+
+    # run log analysis
+    analysis_cmd="python3.7 benchmark/analysis.py --filename ${log_file} --mission_name ${model_name} --run_mode ${run_mode} --direction_id 0 --keyword 'ips:' --base_batch_size ${batch_size} --skip_steps 1 --gpu_num ${num_gpu_devices} --index 1 --model_mode=-1 --ips_unit=samples/sec"
+    eval $analysis_cmd
+}
+
+_set_params $@
+_train
+
diff --git a/benchmark/run_det.sh b/benchmark/run_det.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c507510c615a60177e07300976947b010dbae990
--- /dev/null
+++ b/benchmark/run_det.sh
@@ -0,0 +1,28 @@
+# Script for stably reproducing performance numbers; by default it runs with py37 inside the standard docker image: paddlepaddle/paddle:latest-gpu-cuda10.1-cudnn7 paddle=2.1.2 py=37
+# Working directory: ./PaddleOCR
+# 1. Install the dependencies this model needs (note any optimization strategies you enable)
+python3.7 -m pip install -r requirements.txt
+# 2. Fetch the data and pretrained models this model needs
+wget -c -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar && cd train_data && tar xf icdar2015.tar && cd ../
+wget -c -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams
+# 3. Run all combinations in a batch (if batching is inconvenient, move steps 1 and 2 into each model)
+
+model_mode_list=(det_res18_db_v2.0 det_r50_vd_east)
+fp_item_list=(fp32)
+bs_list=(8 16)
+for model_mode in ${model_mode_list[@]}; do
+    for fp_item in ${fp_item_list[@]}; do
+        for bs_item in ${bs_list[@]}; do
+            echo "index is speed, 1gpus, begin, ${model_mode}"
+            run_mode=sp
+            CUDA_VISIBLE_DEVICES=0 bash benchmark/run_benchmark_det.sh ${run_mode} ${bs_item} ${fp_item} 10 ${model_mode}   # (5min)
+            sleep 60
+            echo "index is speed, 8gpus, run_mode is multi_process, begin, ${model_mode}"
+            run_mode=mp
+            CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash benchmark/run_benchmark_det.sh ${run_mode} ${bs_item} ${fp_item} 10 ${model_mode}
+            sleep 60
+        done
+    done
+done
+
+
diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml
index 0f08909add17d8c73ad6e1b00e17d4c351def7e5..ab484a44833a405513d7f2b4079a4da4c2e403c8 100644
--- a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml
+++ b/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml
@@ -141,6 +141,7 @@ Train:
           img_mode: BGR
           channel_first: False
       - DetLabelEncode: # Class handling label
+      - CopyPaste:
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml
index 1159d71bf94c330e26c3009b38c5c2b4a9c96f52..46daeeb86d004772a6fb964d602369dcd53b3a01 100644
--- a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml
+++ b/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_distill.yml
@@ -68,8 +68,7 @@ Loss:
         ohem_ratio: 3
   - DistillationDBLoss:
       weight: 1.0
-      model_name_list: ["Student", "Teacher"]
-      # key: maps
+      model_name_list: ["Student"]
      name: DBLoss
      balance_loss: true
      main_loss_type: DiceLoss
@@ -116,6 +115,7 @@ Train:
           img_mode: BGR
           channel_first: False
       - DetLabelEncode: # Class handling label
+      - CopyPaste:
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_dml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_dml.yml
index 7fe2d2e1a065b54d0e2479475f5f67ac5e38a166..bfbc3b6268cf521acb035be33ced9141046fc430 100644
--- a/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_dml.yml
+++ b/configs/det/ch_PP-OCRv2/ch_PP-OCR_det_dml.yml
@@ -118,6 +118,7 @@ Train:
           img_mode: BGR
           channel_first: False
       - DetLabelEncode: # Class handling label
+      - CopyPaste:
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
diff --git a/configs/det/det_r50_vd_east.yml b/configs/det/det_r50_vd_east.yml
index 0253c5bd9940fa6c0ec7da2c6639c1bc060842ca..e84a5fa7a7af34bde5e0abc6fed2e01f6ce42e6b 100644
--- a/configs/det/det_r50_vd_east.yml
+++ b/configs/det/det_r50_vd_east.yml
@@ -8,7 +8,7 @@ Global:
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [4000, 5000]
   cal_metric_during_train: False
-  pretrained_model: ./pretrain_models/ResNet50_vd_pretrained/
+  pretrained_model: ./pretrain_models/ResNet50_vd_pretrained
   checkpoints:
   save_inference_dir:
   use_visualdl: False
diff --git a/configs/det/det_res18_db_v2.0.yml b/configs/det/det_res18_db_v2.0.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7b07ef99648956a70b5a71f1e61f09b592226f90
--- /dev/null
+++ b/configs/det/det_res18_db_v2.0.yml
@@ -0,0 +1,131 @@
+Global:
+  use_gpu: true
+  epoch_num: 1200
+  log_smooth_window: 20
+  print_batch_step: 2
+  save_model_dir: ./output/ch_db_res18/
+  save_epoch_step: 1200
+  # evaluation is run every 2000 iterations after the 3000th iteration
+  eval_batch_step: [3000, 2000]
+  cal_metric_during_train: False
+  pretrained_model: ./pretrain_models/ResNet18_vd_pretrained
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./output/det_db/predicts_db.txt
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform:
+  Backbone:
+    name: ResNet
+    layers: 18
+    disable_se: True
+  Neck:
+    name: DBFPN
+    out_channels: 256
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - IaaAugment:
+          augmenter_args:
+            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+      - EastRandomCropData:
+          size: [960, 960]
+          max_tries: 50
+          keep_ratio: true
+      - MakeBorderMap:
+          shrink_ratio: 0.4
+          thresh_min: 0.3
+          thresh_max: 0.7
+      - MakeShrinkMap:
+          shrink_ratio: 0.4
+          min_text_size: 8
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - DetResizeForTest:
+#          image_shape: [736, 1280]
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1 # must be 1
+    num_workers: 2
diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml
index 4a6e19f4461c7236f3a9a5253437eff97fa72f67..c4c5226e796a42db723ce78ef65473e357c25dc6 100644
--- a/configs/e2e/e2e_r50_vd_pg.yml
+++ b/configs/e2e/e2e_r50_vd_pg.yml
@@ -94,7 +94,7 @@ Eval:
     label_file_list: [./train_data/total_text/test/test.txt]
     transforms:
       - DecodeImage: # load image
-          img_mode: RGB
+          img_mode: BGR
           channel_first: False
       - E2ELabelEncodeTest:
       - E2EResizeForTest:
@@ -111,4 +111,4 @@
     shuffle: False
     drop_last: False
     batch_size_per_card: 1 # must be 1
-    num_workers: 2
\ No newline at end of file
+    num_workers: 2
diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8b568637a189ac47438b84e89fc55ddc643ab297
--- /dev/null
+++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
@@ -0,0 +1,126 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 800
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec_mobile_pp-OCRv2_enhanced_ctc_loss
+  save_epoch_step: 3
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: true
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: false
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+  character_type: ch
+  max_text_length: 25
+  infer_mode: false
+  use_space_char: true
+  distributed: true
+  save_res_path: ./output/rec/predicts_mobile_pp-OCRv2_enhanced_ctc_loss.txt
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Piecewise
+    decay_epochs: [700, 800]
+    values: [0.001, 0.0001]
+    warmup_epoch: 5
+  regularizer:
+    name: L2
+    factor: 2.0e-05
+
+
+Architecture:
+  model_type: rec
+  algorithm: CRNN
+  Transform:
+  Backbone:
+    name: MobileNetV1Enhance
+    scale: 0.5
+  Neck:
+    name: SequenceEncoder
+    encoder_type: rnn
+    hidden_size: 64
+  Head:
+    name: CTCHead
+    mid_channels: 96
+    fc_decay: 0.00002
+    return_feats: true
+
+Loss:
+  name: CombinedLoss
+  loss_config_list:
+    - CTCLoss:
+        use_focal_loss: false
+        weight: 1.0
+    - CenterLoss:
+        weight: 0.05
+        num_classes: 6625
+        feat_dim: 96
+        init_center: false
+        center_file_path: "./train_center.pkl"
+    # you can also try to add ace loss on your own dataset
+    # - ACELoss:
+    #     weight: 0.1
+
+PostProcess:
+  name: CTCLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/
+    label_file_list:
+      - ./train_data/train_list.txt
+    transforms:
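+      # Applied in order; CTCLabelEncode also emits the per-class count vector
+      # kept below as 'label_ace' (see the label_ops.py change in this diff),
+      # which ACELoss consumes when enabled.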
+      - DecodeImage:
+          img_mode: BGR
+          channel_first: false
+      - RecAug:
+      - CTCLabelEncode:
+      - RecResizeImg:
+          image_shape: [3, 32, 320]
+      - KeepKeys:
+          keep_keys:
+            - image
+            - label
+            - length
+            - label_ace
+  loader:
+    shuffle: true
+    batch_size_per_card: 128
+    drop_last: true
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+      - ./train_data/val_list.txt
+    transforms:
+      - DecodeImage:
+          img_mode: BGR
+          channel_first: false
+      - CTCLabelEncode:
+      - RecResizeImg:
+          image_shape: [3, 32, 320]
+      - KeepKeys:
+          keep_keys:
+            - image
+            - label
+            - length
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 128
+    num_workers: 8
diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1f6e534a6878a7ae84fc7fa7e1d975077f164d80
--- /dev/null
+++ b/configs/rec/rec_resnet_stn_bilstm_att.yml
@@ -0,0 +1,109 @@
+Global:
+  use_gpu: True
+  epoch_num: 400
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec/seed
+  save_epoch_step: 3
+  # evaluation is run every 2000 iterations
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words_en/word_10.png
+  # for data or label process
+  character_dict_path:
+  character_type: EN_symbol
+  max_text_length: 100
+  infer_mode: False
+  use_space_char: False
+  save_res_path: ./output/rec/predicts_seed.txt
+
+
+Optimizer:
+  name: Adadelta
+  weight_decay: 0.0
+  momentum: 0.9
+  lr:
+    name: Piecewise
+    decay_epochs: [4, 5, 8]
+    values: [1.0, 0.1, 0.01]
+  regularizer:
+    name: 'L2'
+    factor: 2.0e-05
+
+
+Architecture:
+  model_type: rec
+  algorithm: SEED
+  Transform:
+    name: STN_ON
+    tps_inputsize: [32, 64]
+    tps_outputsize: [32, 100]
+    num_control_points: 20
+    tps_margins: [0.05, 0.05]
+    stn_activation: none
+  Backbone:
+    name: ResNet_ASTER
+  Head:
+    name: AsterHead # AttentionHead
+    sDim: 512
+    attDim: 512
+    max_len_labels: 100
+
+Loss:
+  name: AsterLoss
+
+PostProcess:
+  name: SEEDLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+  is_filter: True
+
+Train:
+  dataset:
+    name: LMDBDataSet
+    data_dir: ./train_data/data_lmdb_release/training/
+    transforms:
+      - Fasttext:
+          path: "./cc.en.300.bin"
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - SEEDLabelEncode: # Class handling label
+      - RecResizeImg:
+          character_type: en
+          image_shape: [3, 64, 256]
+          padding: False
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 256
+    drop_last: True
+    num_workers: 6
+
+Eval:
+  dataset:
+    name: LMDBDataSet
+    data_dir: ./train_data/data_lmdb_release/evaluation/
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - SEEDLabelEncode: # Class handling label
+      - RecResizeImg:
+          character_type: en
+          image_shape: [3, 64, 256]
+          padding: False
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: True
+    batch_size_per_card: 256
+    num_workers: 4
diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp
index b64dcea5ae2a68485296c02cdb7689c60ea504f8..3739a66ad802fd108df16bbbbe8c8695963b7693 100644
--- a/deploy/cpp_infer/src/ocr_rec.cpp
+++ b/deploy/cpp_infer/src/ocr_rec.cpp
@@ -112,12 +112,16 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
                             1 << 20, 10, 3,
                             precision, false, false);
+
     std::map<std::string, std::vector<int>> min_input_shape = {
-        {"x", {1, 3, 32, 10}}};
+        {"x", {1, 3, 32, 10}},
+        {"lstm_0.tmp_0", {10, 1, 96}}};
     std::map<std::string, std::vector<int>> max_input_shape = {
-        {"x", {1, 3, 32, 2000}}};
+        {"x", {1, 3, 32, 2000}},
+        {"lstm_0.tmp_0", {1000, 1, 96}}};
     std::map<std::string, std::vector<int>> opt_input_shape = {
-        {"x", {1, 3, 32, 320}}};
+        {"x", {1, 3, 32, 320}},
+        {"lstm_0.tmp_0", {25, 1, 96}}};
     config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
                                   opt_input_shape);
@@ -139,7 +143,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
   config.SwitchIrOptim(true);
   config.EnableMemoryOptim();
-  config.DisableGlogInfo();
+  // config.DisableGlogInfo();
   this->predictor_ = CreatePredictor(config);
 }
diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py
index f80ddd9fb424a4b0c8f30304530d7d20f982598a..0f0492af2f57eea9b9c1d13ec5ee1dad9fc2f1bc 100644
--- a/deploy/slim/prune/sensitivity_anal.py
+++ b/deploy/slim/prune/sensitivity_anal.py
@@ -110,25 +110,42 @@ def main(config, device, logger, vdl_writer):
         logger.info("metric['hmean']: {}".format(metric['hmean']))
         return metric['hmean']
 
-    params_sensitive = pruner.sensitive(
-        eval_func=eval_fn,
-        sen_file="./sen.pickle",
-        skip_vars=[
-            "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
-        ])
-
-    logger.info(
-        "The sensitivity analysis results of model parameters saved in sen.pickle"
-    )
-    # calculate pruned params's ratio
-    params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
-    for key in params_sensitive.keys():
-        logger.info("{}, {}".format(key, params_sensitive[key]))
-
-    #params_sensitive = {}
-    #for param in model.parameters():
-    #    if 'transpose' not in param.name and 'linear' not in param.name:
-    #        params_sensitive[param.name] = 0.1
+    run_sensitive_analysis = False
+    """
+    run_sensitive_analysis=True:
+        Automatically compute the sensitivity of each convolution in the model.
+        The sensitivity of a convolution is its loss of accuracy on the test
+        dataset at different pruning ratios. The sensitivities can then be used
+        to pick a good set of ratios under a given constraint.
+
+    run_sensitive_analysis=False:
+        Set the prune ratio to a fixed value, such as 10%. The larger the
+        value, the more convolution weights are pruned.
+    """
+
+    if run_sensitive_analysis:
+        params_sensitive = pruner.sensitive(
+            eval_func=eval_fn,
+            sen_file="./deploy/slim/prune/sen.pickle",
+            skip_vars=[
+                "conv2d_57.w_0", "conv2d_transpose_2.w_0",
+                "conv2d_transpose_3.w_0"
+            ])
+        logger.info(
+            "The sensitivity analysis results of the model parameters are saved in sen.pickle"
+        )
+        # calculate the pruned params' ratio
+        params_sensitive = pruner._get_ratios_by_loss(
+            params_sensitive, loss=0.02)
+        for key in params_sensitive.keys():
+            logger.info("{}, {}".format(key, params_sensitive[key]))
+    else:
+        params_sensitive = {}
+        for param in model.parameters():
+            if 'transpose' not in param.name and 'linear' not in param.name:
+                # Set the prune ratio to 10%; the larger the value, the more
+                # convolution weights are pruned.
+                params_sensitive[param.name] = 0.1
 
     plan = pruner.prune_vars(params_sensitive, [0])
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 6daacaf7f08e5bd0c5e5aba8b435c5bd26eaca7b..af883de86c798babe6ca1616710c0e13546e1045 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -50,6 +50,7 @@ Text recognition algorithms open-sourced by PaddleOCR on the dynamic graph:
 - [x] SRN([paper](https://arxiv.org/abs/2003.12294))
 - [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
 - [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))
+- [x] SEED([paper](https://arxiv.org/pdf/2005.10977.pdf))
 
 Following the [DTRB](https://arxiv.org/abs/1904.01906) training and evaluation protocol for text recognition, models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE, with the following results:
@@ -66,5 +67,5 @@
 |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
 |NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
 |SAR|Resnet31| 87.2% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
-
+|SEED| Aster_Resnet | 85.2% | rec_resnet_stn_bilstm_att | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar)|
 For training and using the PaddleOCR text recognition algorithms, see the [text recognition section of model training/evaluation](./recognition.md) in the tutorials.
diff --git a/doc/doc_ch/enhanced_ctc_loss.md b/doc/doc_ch/enhanced_ctc_loss.md
new file mode 100644
index 0000000000000000000000000000000000000000..309dc712dc0242b859f934338be96e6648f81031
--- /dev/null
+++ b/doc/doc_ch/enhanced_ctc_loss.md
@@ -0,0 +1,78 @@
+# Enhanced CTC Loss
+
+In OCR, CRNN is a text recognition algorithm widely used in industry. It computes the network loss with CTCLoss during training and decodes with CTCDecode at inference time. Although CRNN has proven to deliver good recognition results in real business scenarios, the demand for higher accuracy never ends, so how can text recognition accuracy be improved further? Taking CTCLoss as the starting point, this document explores improved fusion schemes for CTCLoss from three angles — hard-example mining, multi-task learning, and metric learning — and proposes EnhancedCTCLoss, which consists of three parts: Focal-CTC Loss, A-CTC Loss, and C-CTC Loss.
+
+## 1. Focal-CTC Loss
+Focal Loss was proposed in the paper "Focal Loss for Dense Object Detection". It was originally designed to address the severe imbalance between positive and negative samples in one-stage object detection. The loss down-weights the large number of easy negatives during training and can also be read as a form of hard-example mining. Its loss function takes the following form:
+![Focal Loss formula](./focal_loss_formula.png)
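+
+A plain-text transcription of the formula image above (a reconstruction assuming the standard Focal Loss form for the positive class, as described in the surrounding text):
+
+```
+FL(y') = -\alpha (1 - y')^{\gamma} \log(y')
+```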
+
+Here y' is the output of the activation function, taking values in (0, 1). On top of the original cross-entropy loss it adds a modulating factor (1 - y')^γ and a balancing factor α. With α = 1 and y = 1, the loss compares with the cross-entropy loss as shown in the figure below:
+![Focal Loss vs. cross-entropy loss](./focal_loss_image.png)
+
+As the figure shows, when γ > 0 the modulating factor (1 - y')^γ assigns well-classified samples a smaller loss weight, making the network concentrate on hard, misclassified samples. γ controls how quickly easy samples are down-weighted: with γ = 0 the loss reduces to the cross-entropy loss, and the effect of the modulating factor grows as γ increases. Experiments found γ = 2 to be optimal. The balancing factor α compensates for the imbalance between positive and negative samples; the paper uses α = 0.25.
+
+For the classic CTC algorithm, suppose the probability that a feature sequence (f1, f2, ..., ft) decodes to the label is y'; the probability that the decoded result is not the label is then (1 - y'). It is easy to see that the CTCLoss value and y' are related as follows:
+![Relation between CTC loss and y'](./equation_ctcloss.png)
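+
+In text form (this follows directly from CTC loss being the negative log-likelihood of the label):
+
+```
+CTC\_loss = -\log(y')
+```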
+
+Borrowing the idea of Focal Loss — giving hard samples larger weights and easy samples smaller ones — makes the network focus on mining hard samples and further improves recognition accuracy. This leads us to the Focal-CTC Loss, defined as follows:
+![Focal-CTC Loss definition](./equation_focal_ctc.png)
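+
+In text form (a reconstruction from the description above, plugging the CTC probability y' into the Focal Loss template; it matches the `use_focal_loss` branch added to rec_ctc_loss.py in this diff, where the weight is (1 - exp(-loss))² = (1 - y')²):
+
+```
+Focal\_CTC\_loss = \alpha (1 - y')^{\gamma} \cdot CTC\_loss = -\alpha (1 - y')^{\gamma} \log(y')
+```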
+
+In the experiments γ = 2 and α = 1; see [rec_ctc_loss.py](../../ppocr/losses/rec_ctc_loss.py) for the implementation.
+
+## 2. A-CTC Loss
+A-CTC Loss is short for CTC Loss + ACE Loss, where ACE Loss comes from the paper "Aggregation Cross-Entropy for Sequence Recognition". Compared with CTCLoss, ACE Loss has two main advantages:
++ ACE Loss can handle 2-D text recognition, while CTCLoss only handles 1-D text
++ ACE Loss beats CTC loss in both time and space complexity
+
+A prior comparison of the strengths and weaknesses of OCR recognition algorithms is shown in the figure below:
+![Comparison of OCR recognition algorithms](./rec_algo_compare.png)
+
+Although ACE Loss can indeed handle 2-D predictions and has advantages in memory footprint and inference speed as the figure indicates, in practice we found that using ACE Loss alone yields worse recognition than CTCLoss. We therefore combine CTCLoss and ACELoss, keeping CTCLoss as the main loss and using ACELoss as an auxiliary supervision loss. This worked: on our internal experimental dataset, recognition accuracy improves by about 1% compared with CTCLoss alone.
+A-CTC Loss is defined as follows:
+![A-CTC Loss definition](./equation_a_ctc.png)
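+
+In text form (reconstructed from the description; the image carries the original typesetting):
+
+```
+L_{A\text{-}CTC} = L_{CTC} + \lambda \cdot L_{ACE}, \quad \lambda = 0.1
+```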
+
+In the experiments λ = 0.1; see [ace_loss.py](../../ppocr/losses/ace_loss.py) for the ACE loss implementation.
+
+## 3. C-CTC Loss
+C-CTC Loss is short for CTC Loss + Center Loss, where Center Loss comes from the paper "A Discriminative Feature Learning Approach for Deep Face Recognition". First used for face recognition, it enlarges inter-class distances while shrinking intra-class distances, and is an early and widely used metric learning algorithm.
+Analyzing bad cases in Chinese OCR, we found that a major difficulty of Chinese recognition is the large number of similar characters, which are easily confused. This suggested borrowing the metric learning idea of enlarging the inter-class margin between similar characters to improve recognition accuracy. However, metric learning is mainly used in image classification, where the label of a training sample is a single fixed value, whereas OCR is essentially a sequence recognition task without an explicit alignment between features and labels, so how to combine the two remains a direction worth exploring.
+After trying ArcMargin, CosMargin and similar methods, we finally found that Center Loss helps further improve recognition accuracy. C-CTC Loss is defined as follows:
+![C-CTC Loss definition](./equation_c_ctc.png)
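+
+In text form (reconstructed from the description):
+
+```
+L_{C\text{-}CTC} = L_{CTC} + \lambda \cdot L_{Center}, \quad \lambda = 0.25
+```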
+
+In the experiments we set λ = 0.25; see [center_loss.py](../../ppocr/losses/center_loss.py) for the center loss implementation.
+
+It is worth noting that randomly initialized centers do not bring an obvious gain in C-CTC Loss. Our center initialization procedure is:
++ Train a network N with the original CTCLoss
++ Pick the samples in the training set that are recognized completely correctly, forming a set G
++ Run a forward pass on each sample in G and record the correspondence between the input of the last FC layer (the feature) and its argmax result (the index)
++ Aggregate the features sharing the same index and average them, giving the initial center of each character.
+
+Taking the config file `configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml` as an example, the center extraction command is:
+```
+python tools/export_center.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml -o Global.pretrained_model: "./output/rec_mobile_pp-OCRv2/best_accuracy"
+```
+After it finishes, `train_center.pkl` is generated in the PaddleOCR root directory.
+
+## 4. Experiments
+We trained and evaluated the three schemes above on Baidu-internal datasets; the results are listed below:
+
+|algorithm| Focal_CTC | A_CTC | C-CTC |
+|:------| :------| ------: | :------: |
+|gain| +0.3% | +0.7% | +1.7% |
+
+Based on these results, we adopted the C-CTC strategy in PP-OCRv2. Notably, since PP-OCRv2 recognizes 6625 Chinese characters — a large character set with many similar glyphs — C-CTC brings a sizable gain on this task. On other OCR recognition tasks the conclusion may differ. Feel free to try Focal-CTC, A-CTC, C-CTC and the combined EnhancedCTC scheme; they should bring improvements of varying degrees.
+See the following file for the unified fusion scheme: [rec_enhanced_ctc_loss.py](../../ppocr/losses/rec_enhanced_ctc_loss.py)
diff --git a/doc/doc_ch/equation_a_ctc.png b/doc/doc_ch/equation_a_ctc.png
new file mode 100644
index 0000000000000000000000000000000000000000..ae097610d37a88e76edefdbeb81df8403e94215f
Binary files /dev/null and b/doc/doc_ch/equation_a_ctc.png differ
diff --git a/doc/doc_ch/equation_c_ctc.png b/doc/doc_ch/equation_c_ctc.png
new file mode 100644
index 0000000000000000000000000000000000000000..67207a9937481f4920af3cbafbe1bfe8d27ee5dc
Binary files /dev/null and b/doc/doc_ch/equation_c_ctc.png differ
diff --git a/doc/doc_ch/equation_ctcloss.png b/doc/doc_ch/equation_ctcloss.png
new file mode 100644
index 0000000000000000000000000000000000000000..33ad92c9e4567d2a4a0c8fc3b2a0bf3fba5ea8f2
Binary files /dev/null and b/doc/doc_ch/equation_ctcloss.png differ
diff --git a/doc/doc_ch/equation_focal_ctc.png b/doc/doc_ch/equation_focal_ctc.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ba1e8715d5876705ef429e48b5c94388fd41398
Binary files /dev/null and b/doc/doc_ch/equation_focal_ctc.png differ
diff --git a/doc/doc_ch/focal_loss_formula.png b/doc/doc_ch/focal_loss_formula.png
new file mode 100644
index 0000000000000000000000000000000000000000..971cebcd082cf5e19f9246f02216c0c14896bdc9
Binary files /dev/null and b/doc/doc_ch/focal_loss_formula.png differ
diff --git a/doc/doc_ch/focal_loss_image.png b/doc/doc_ch/focal_loss_image.png
new file mode 100644
index 0000000000000000000000000000000000000000..430550a732d4e2769151771bc85ae889dfc78fda
Binary files /dev/null and b/doc/doc_ch/focal_loss_image.png differ
diff --git a/doc/doc_ch/rec_algo_compare.png b/doc/doc_ch/rec_algo_compare.png
new file mode 100644
index 0000000000000000000000000000000000000000..2dde496c75f327ca1c0c9ccb0dbe6949215a4a1b
Binary files /dev/null and b/doc/doc_ch/rec_algo_compare.png differ
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index c9fd63f71f143a2aa08a04c86f8e738163852c9e..52f978a734cbc750f4e2f36bb3ff28b2e67ab612 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -234,6 +234,9 @@ PaddleOCR supports alternating training and evaluation; it can be configured in `configs/rec/rec_icdar15_t
 | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
 | rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
 | rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
+| rec_resnet_stn_bilstm_att.yml | SEED | Aster_Resnet | STN | BiLSTM | att |
+
+*The SEED model additionally needs to load a FastText-trained [language model](https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz)
 
 For training on Chinese data, [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml) is recommended. If you want to try other algorithms on a Chinese dataset, modify the config file as described below:
@@ -460,5 +463,3 @@ python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_trai
 ```
 python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
 ```
-
-
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index d7b47a8ac8beac684192cd8245e519fd1f600e6b..ebf52ec4e1d8713fd4da407318b14e682952606d 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -215,6 +215,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
         data['length'] = np.array(len(text))
         text = text + [0] * (self.max_text_len - len(text))
         data['label'] = np.array(text)
+
+        label = [0] * len(self.character)
+        for x in text:
+            label[x] += 1
+        data['label_ace'] = np.array(label)
         return data
 
     def add_special_char(self, dict_character):
@@ -342,6 +347,38 @@ class AttnLabelEncode(BaseRecLabelEncode):
         return idx
 
 
+class SEEDLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(SEEDLabelEncode,
+              self).__init__(max_text_length, character_dict_path,
+                             character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.end_str = "eos"
+        dict_character = dict_character + [self.end_str]
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        data['length'] = np.array(len(text)) + 1  # count the appended eos
+        text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
+                                                   )
+        data['label'] = np.array(text)
+        return data
+
+
 class SRNLabelEncode(BaseRecLabelEncode):
     """ Convert between text-label and text-index """
 
@@ -421,7 +458,6 @@ class TableLabelEncode(object):
         substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
         character_num = int(substr[0])
         elem_num = int(substr[1])
-
         for cno in range(1, 1 + character_num):
             character = lines[cno].decode('utf-8').strip("\r\n")
             list_character.append(character)
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index bbdf8f9647999304ea586b4089cd99589ac1fb94..87e3088d07a8c5a2eea5d4deff87c69a753e215b 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -23,6 +23,7 @@ import sys
 import six
 import cv2
 import numpy as np
+import fasttext
 
 
 class DecodeImage(object):
@@ -83,12 +84,13 @@ class NRTRDecodeImage(object):
         elif self.img_mode == 'RGB':
             assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
             img = img[:, :, ::-1]
-        img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
         if self.channel_first:
             img = img.transpose((2, 0, 1))
         data['image'] = img
         return data
 
+
 class NormalizeImage(object):
     """ normalize image such as substract mean, divide std
     """
@@ -133,6 +135,17 @@ class ToCHWImage(object):
         return data
 
 
+class Fasttext(object):
+    def __init__(self, path="None", **kwargs):
+        self.fast_model = fasttext.load_model(path)
+
+    def __call__(self, data):
+        label = data['label']
+        fast_label = self.fast_model[label]
+        data['fast_label'] = fast_label
+        return data
+
+
 class KeepKeys(object):
     def __init__(self, keep_keys, **kwargs):
         self.keep_keys = keep_keys
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 2c6302386579df7024f6f8570870967b2483f283..71ed8976db7de24a489d1f75612a9a9a67995ba2 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -88,17 +88,19 @@ class RecResizeImg(object):
                  image_shape,
                  infer_mode=False,
                  character_type='ch',
+                 padding=True,
                  **kwargs):
         self.image_shape = image_shape
         self.infer_mode = infer_mode
         self.character_type = character_type
+        self.padding = padding
 
     def __call__(self, data):
         img = data['image']
         if self.infer_mode and self.character_type == "ch":
             norm_img = resize_norm_img_chinese(img, self.image_shape)
         else:
-            norm_img = resize_norm_img(img, self.image_shape)
+            norm_img = resize_norm_img(img, self.image_shape, self.padding)
         data['image'] = norm_img
         return data
 
@@ -174,16 +176,21 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     return padding_im, resize_shape, pad_shape, valid_ratio
 
 
-def resize_norm_img(img, image_shape):
+def resize_norm_img(img, image_shape, padding=True):
     imgC, imgH, imgW = image_shape
     h = img.shape[0]
     w = img.shape[1]
-    ratio = w / float(h)
-    if math.ceil(imgH * ratio) > imgW:
+    if not padding:
+        resized_image = cv2.resize(
+            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
         resized_w = imgW
     else:
-        resized_w = int(math.ceil(imgH * ratio))
-    resized_image = cv2.resize(img, (resized_w, imgH))
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = cv2.resize(img, (resized_w, imgH))
     resized_image = resized_image.astype('float32')
     if image_shape[0] == 1:
         resized_image = resized_image / 255
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 467d7b6579cef5eca2532cd30df43169184967fc..f3f4cd49332b605ec3a0e65e688d965fd91a5cdf 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -28,6 +28,8 @@ from .rec_att_loss import AttentionLoss
 from .rec_srn_loss import SRNLoss
 from .rec_nrtr_loss import NRTRLoss
 from .rec_sar_loss import SARLoss
+from .rec_aster_loss import AsterLoss
+
 # cls loss
 from .cls_loss import ClsLoss
@@ -48,9 +50,8 @@ def build_loss(config):
     support_dict = [
         'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
         'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
-        'TableAttentionLoss', 'SARLoss'
+        'TableAttentionLoss', 'SARLoss', 'AsterLoss'
     ]
-
     config = copy.deepcopy(config)
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('loss only support {}'.format(
diff --git a/ppocr/losses/ace_loss.py b/ppocr/losses/ace_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c868520e5bd7b398c7f248c416b70427baee0a6
--- /dev/null
+++ b/ppocr/losses/ace_loss.py
@@ -0,0 +1,50 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+
+class ACELoss(nn.Layer):
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.loss_func = nn.CrossEntropyLoss(
+            weight=None,
+            ignore_index=0,
+            reduction='none',
+            soft_label=True,
+            axis=-1)
+
+    def __call__(self, predicts, batch):
+        if isinstance(predicts, (list, tuple)):
+            predicts = predicts[-1]
+        B, N = predicts.shape[:2]
+        div = paddle.to_tensor([N]).astype('float32')
+
+        predicts = nn.functional.softmax(predicts, axis=-1)
+        aggregation_preds = paddle.sum(predicts, axis=1)
+        aggregation_preds = paddle.divide(aggregation_preds, div)
+
+        length = batch[2].astype("float32")
+        batch = batch[3].astype("float32")
+        batch[:, 0] = paddle.subtract(div, length)
+
+        batch = paddle.divide(batch, div)
+
+        loss = self.loss_func(aggregation_preds, batch)
+
+        return {"loss_ace": loss}
diff --git a/ppocr/losses/center_loss.py b/ppocr/losses/center_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..72149df19f9e9a864ec31239177dc574648da3d5
--- /dev/null
+++ b/ppocr/losses/center_loss.py
@@ -0,0 +1,89 @@
+#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import pickle
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class CenterLoss(nn.Layer):
+    """
+    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
+    """
+
+    def __init__(self,
+                 num_classes=6625,
+                 feat_dim=96,
+                 init_center=False,
+                 center_file_path=None):
+        super().__init__()
+        self.num_classes = num_classes
+        self.feat_dim = feat_dim
+        self.centers = paddle.randn(
+            shape=[self.num_classes, self.feat_dim]).astype(
+                "float64")  # random centers
+
+        if init_center:
+            assert os.path.exists(
+                center_file_path
+            ), f"center path({center_file_path}) must exist when init_center is set as True."
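+            # Assumed layout of the pickle (as produced by tools/export_center.py,
+            # referenced in doc/doc_ch/enhanced_ctc_loss.md): a dict mapping class
+            # index -> feature vector; each entry overwrites the matching random
+            # center row.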
+            with open(center_file_path, 'rb') as f:
+                char_dict = pickle.load(f)
+                for key in char_dict.keys():
+                    self.centers[key] = paddle.to_tensor(char_dict[key])
+
+    def __call__(self, predicts, batch):
+        assert isinstance(predicts, (list, tuple))
+        features, predicts = predicts
+
+        feats_reshape = paddle.reshape(
+            features, [-1, features.shape[-1]]).astype("float64")
+        label = paddle.argmax(predicts, axis=2)
+        label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
+
+        batch_size = feats_reshape.shape[0]
+
+        # calc feat * feat
+        dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
+        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
+
+        # dist2 of centers
+        dist2 = paddle.sum(paddle.square(self.centers), axis=1,
+                           keepdim=True)  # num_classes
+        dist2 = paddle.expand(dist2,
+                              [self.num_classes, batch_size]).astype("float64")
+        dist2 = paddle.transpose(dist2, [1, 0])
+
+        # first x * x + y * y
+        distmat = paddle.add(dist1, dist2)
+        tmp = paddle.matmul(feats_reshape,
+                            paddle.transpose(self.centers, [1, 0]))
+        distmat = distmat - 2.0 * tmp
+
+        # generate the mask
+        classes = paddle.arange(self.num_classes).astype("int64")
+        label = paddle.expand(
+            paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
+        mask = paddle.equal(
+            paddle.expand(classes, [batch_size, self.num_classes]),
+            label).astype("float64")  # get mask
+        dist = paddle.multiply(distmat, mask)
+        loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
+        return {'loss_center': loss}
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index f3bb36cf5ac751e6c27e4aa29a46fc5f913f7d05..72f706e37d6eb0c640cc30de80afe00bce82fd13 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -15,6 +15,10 @@ import paddle
 import paddle.nn as nn
 
+from .rec_ctc_loss import CTCLoss
+from .center_loss import CenterLoss
+from .ace_loss import ACELoss
+
 from .distillation_loss import DistillationCTCLoss
 from .distillation_loss import DistillationDMLLoss
 from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index 73d3ae2ad2499607f897a102f6ea25e4cb7f297f..06aa7fa8458a5deece75f1393fe7300e8227d3ca 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -112,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
             if isinstance(loss, dict):
                 for key in loss:
                     loss_dict["{}_{}_{}_{}_{}".format(key, pair[
-                        0], pair[1], map_name, idx)] = loss[key]
+                        0], pair[1], self.maps_name, idx)] = loss[key]
             else:
                 loss_dict["{}_{}_{}".format(self.name, self.maps_name[
                     _c], idx)] = loss
diff --git a/ppocr/losses/rec_aster_loss.py b/ppocr/losses/rec_aster_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbb99d29a638540b02649a8912051339c08b22dd
--- /dev/null
+++ b/ppocr/losses/rec_aster_loss.py
@@ -0,0 +1,99 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class CosineEmbeddingLoss(nn.Layer):
+    def __init__(self, margin=0.):
+        super(CosineEmbeddingLoss, self).__init__()
+        self.margin = margin
+        self.epsilon = 1e-12
+
+    def forward(self, x1, x2, target):
+        similarity = paddle.fluid.layers.reduce_sum(
+            x1 * x2, dim=-1) / (paddle.norm(
+                x1, axis=-1) * paddle.norm(
+                    x2, axis=-1) + self.epsilon)
+        one_list = paddle.full_like(target, fill_value=1)
+        out = paddle.fluid.layers.reduce_mean(
+            paddle.where(
+                paddle.equal(target, one_list), 1. - similarity,
+                paddle.maximum(
+                    paddle.zeros_like(similarity), similarity - self.margin)))
+
+        return out
+
+
+class AsterLoss(nn.Layer):
+    def __init__(self,
+                 weight=None,
+                 size_average=True,
+                 ignore_index=-100,
+                 sequence_normalize=False,
+                 sample_normalize=True,
+                 **kwargs):
+        super(AsterLoss, self).__init__()
+        self.weight = weight
+        self.size_average = size_average
+        self.ignore_index = ignore_index
+        self.sequence_normalize = sequence_normalize
+        self.sample_normalize = sample_normalize
+        self.loss_sem = CosineEmbeddingLoss()
+        self.is_cosin_loss = True
+        self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
+
+    def forward(self, predicts, batch):
+        targets = batch[1].astype("int64")
+        label_lengths = batch[2].astype('int64')
+        sem_target = batch[3].astype('float32')
+        embedding_vectors = predicts['embedding_vectors']
+        rec_pred = predicts['rec_pred']
+
+        if not self.is_cosin_loss:
+            sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
+        else:
+            label_target = paddle.ones([embedding_vectors.shape[0]])
+            sem_loss = paddle.sum(
+                self.loss_sem(embedding_vectors, sem_target, label_target))
+
+        # rec loss
+        batch_size, def_max_length = targets.shape[0], targets.shape[1]
+
+        mask = paddle.zeros([batch_size, def_max_length])
+        for i in range(batch_size):
+            mask[i, :label_lengths[i]] = 1
+        mask = paddle.cast(mask, "float32")
+        max_length = max(label_lengths)
+        assert max_length == rec_pred.shape[1]
+        targets = targets[:, :max_length]
+        mask = mask[:, :max_length]
+        rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
+        input = nn.functional.log_softmax(rec_pred, axis=1)
+        targets = paddle.reshape(targets, [-1, 1])
+        mask = paddle.reshape(mask, [-1, 1])
+        output = -paddle.index_sample(input, index=targets) * mask
+        output = paddle.sum(output)
+        if self.sequence_normalize:
+            output = output / paddle.sum(mask)
+        if self.sample_normalize:
+            output = output / batch_size
+
+        loss = output + sem_loss * 0.1
+        return {'loss': loss}
diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py
index 6c0b56ff84db4ff23786fb781d461bf9fbc86ef2..063d68e30861e092e10fa3068e4b7f4755b6197f 100755
--- a/ppocr/losses/rec_ctc_loss.py
+++ b/ppocr/losses/rec_ctc_loss.py
@@ -21,16 +21,24 @@ from paddle import nn
 
 class CTCLoss(nn.Layer):
-    def __init__(self, **kwargs):
+    def __init__(self, use_focal_loss=False, **kwargs):
         super(CTCLoss, self).__init__()
         self.loss_func = nn.CTCLoss(blank=0, reduction='none')
+        self.use_focal_loss = use_focal_loss
 
     def forward(self, predicts, batch):
+        if isinstance(predicts, (list, tuple)):
+            predicts = predicts[-1]
         predicts = predicts.transpose((1, 0, 2))
         N, B, _ = predicts.shape
         preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
         labels = batch[1].astype("int32")
         label_lengths = batch[2].astype('int64')
         loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
-        loss = loss.mean()  # sum
+        if self.use_focal_loss:
+            weight = paddle.exp(-loss)
+            weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
+            weight = paddle.square(weight)
+            loss = paddle.multiply(loss, weight)
+        loss = loss.mean()
         return {'loss': loss}
diff --git a/ppocr/losses/rec_enhanced_ctc_loss.py b/ppocr/losses/rec_enhanced_ctc_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b57be6468e2ec75811442e7449525267e7d9e82e
--- /dev/null
+++ b/ppocr/losses/rec_enhanced_ctc_loss.py
@@ -0,0 +1,70 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from .ace_loss import ACELoss
+from .center_loss import CenterLoss
+from .rec_ctc_loss import CTCLoss
+
+
+class EnhancedCTCLoss(nn.Layer):
+    def __init__(self,
+                 use_focal_loss=False,
+                 use_ace_loss=False,
+                 ace_loss_weight=0.1,
+                 use_center_loss=False,
+                 center_loss_weight=0.05,
+                 num_classes=6625,
+                 feat_dim=96,
+                 init_center=False,
+                 center_file_path=None,
+                 **kwargs):
+        super(EnhancedCTCLoss, self).__init__()
+        self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
+
+        self.use_ace_loss = False
+        if use_ace_loss:
+            self.use_ace_loss = use_ace_loss
+            self.ace_loss_func = ACELoss()
+            self.ace_loss_weight = ace_loss_weight
+
+        self.use_center_loss = False
+        if use_center_loss:
+            self.use_center_loss = use_center_loss
+            self.center_loss_func = CenterLoss(
+                num_classes=num_classes,
+                feat_dim=feat_dim,
+                init_center=init_center,
+                center_file_path=center_file_path)
+            self.center_loss_weight = center_loss_weight
+
+    def __call__(self, predicts, batch):
+        loss = self.ctc_loss_func(predicts, batch)["loss"]
+
+        if self.use_center_loss:
+            center_loss = self.center_loss_func(
+                predicts, batch)["loss_center"] * self.center_loss_weight
+            loss = loss + center_loss
+
+        if self.use_ace_loss:
+            ace_loss = self.ace_loss_func(
+                predicts, batch)["loss_ace"] * self.ace_loss_weight
+            loss = loss + ace_loss
+
+        return {'enhanced_ctc_loss': loss}
diff --git a/ppocr/losses/rec_sar_loss.py b/ppocr/losses/rec_sar_loss.py
index 9e1c6495fb5ffa274662a3d1031b174d7297ee30..c8bd8bb0ca395fa4658e57b8dcac52a3e94aadce 100644
--- a/ppocr/losses/rec_sar_loss.py
+++ b/ppocr/losses/rec_sar_loss.py
@@ -9,11 +9,14 @@ from paddle import nn
 class SARLoss(nn.Layer):
     def __init__(self, **kwargs):
         super(SARLoss, self).__init__()
-        self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)
+        self.loss_func = paddle.nn.loss.CrossEntropyLoss(
+            reduction="mean", ignore_index=92)
 
     def forward(self, predicts, batch):
-        predict = predicts[:, :-1, :]  # ignore last index of outputs to be in same seq_len with targets
-        label = batch[1].astype("int64")[:, 1:]  # ignore first index of target in loss calculation
+        predict = predicts[:, :
-1, :] # ignore last index of outputs to be in same seq_len with targets + label = batch[1].astype( + "int64")[:, 1:] # ignore first index of target in loss calculation batch_size, num_steps, num_classes = predict.shape[0], predict.shape[ 1], predict.shape[2] assert len(label.shape) == len(list(predict.shape)) - 1, \ diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 3e82fe756b1e42d6a77192af41e27e0fe30a86ab..db2f41c3a140ecebc42b71ee03f0ecb5cf50ca80 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -13,13 +13,20 @@ # limitations under the License. import Levenshtein +import string class RecMetric(object): - def __init__(self, main_indicator='acc', **kwargs): + def __init__(self, main_indicator='acc', is_filter=False, **kwargs): self.main_indicator = main_indicator + self.is_filter = is_filter self.reset() + def _normalize_text(self, text): + text = ''.join( + filter(lambda x: x in (string.digits + string.ascii_letters), text)) + return text.lower() + def __call__(self, pred_label, *args, **kwargs): preds, labels = pred_label correct_num = 0 @@ -28,6 +35,9 @@ class RecMetric(object): for (pred, pred_conf), (target, _) in zip(preds, labels): pred = pred.replace(" ", "") target = target.replace(" ", "") + if self.is_filter: + pred = self._normalize_text(pred) + target = self._normalize_text(target) norm_edit_dis += Levenshtein.distance(pred, target) / max( len(pred), len(target), 1) if pred == target: @@ -57,4 +67,3 @@ class RecMetric(object): self.correct_num = 0 self.all_num = 0 self.norm_edit_dis = 0 - diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 50438893c7b8511b699f0ebeaf6f4ff46e7d3106..169eb821f110d4a212068ebab4d46d636e241307 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -28,8 +28,10 @@ def build_backbone(config, model_type): from .rec_mv1_enhance import MobileNetV1Enhance from .rec_nrtr_mtb import MTB from .rec_resnet_31 import ResNet31 + from .rec_resnet_aster import ResNet_ASTER support_dict = [ - 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31" + 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', + "ResNet31", "ResNet_ASTER" ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py new file mode 100644 index 0000000000000000000000000000000000000000..bdecaf46af98f9b967d9a339f82d4e938abdc6d9 --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet_aster.py @@ -0,0 +1,140 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
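Editor's note: a minimal, hypothetical sketch (not part of this patch) of the focal weighting the patched CTCLoss above applies when use_focal_loss=True; the (1 - exp(-loss))^2 factor shrinks the contribution of easy samples before averaging:

    import paddle

    ctc_loss = paddle.to_tensor([0.1, 2.0, 5.0])           # assumed per-sample CTC losses
    weight = paddle.square(1.0 - paddle.exp(-ctc_loss))    # focal factor in [0, 1)
    focal_ctc_loss = paddle.multiply(ctc_loss, weight).mean()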
+ +import paddle +import paddle.nn as nn + +import sys +import math + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2D( + in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False) + + +def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000): + # [n_position] + positions = paddle.arange(0, n_position) + # [feat_dim] + dim_range = paddle.arange(0, feat_dim) + dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim) + # [n_position, feat_dim] + angles = paddle.unsqueeze( + positions, axis=1) / paddle.unsqueeze( + dim_range, axis=0) + angles = paddle.cast(angles, "float32") + angles[:, 0::2] = paddle.sin(angles[:, 0::2]) + angles[:, 1::2] = paddle.cos(angles[:, 1::2]) + return angles + + +class AsterBlock(nn.Layer): + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(AsterBlock, self).__init__() + self.conv1 = conv1x1(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2D(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2D(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class ResNet_ASTER(nn.Layer): + """For aster or crnn""" + + def __init__(self, with_lstm=True, n_group=1, in_channels=3): + super(ResNet_ASTER, self).__init__() + self.with_lstm = with_lstm + self.n_group = n_group + + self.layer0 = nn.Sequential( + nn.Conv2D( + in_channels, + 32, + kernel_size=(3, 3), + stride=1, + padding=1, + bias_attr=False), + nn.BatchNorm2D(32), + nn.ReLU()) + + self.inplanes = 32 + self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50] + self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25] + self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25] + self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25] + self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25] + + if with_lstm: + self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2) + self.out_channels = 2 * 256 + else: + self.out_channels = 512 + + def _make_layer(self, planes, blocks, stride): + downsample = None + if stride != [1, 1] or self.inplanes != planes: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes)) + + layers = [] + layers.append(AsterBlock(self.inplanes, planes, stride, downsample)) + self.inplanes = planes + for _ in range(1, blocks): + layers.append(AsterBlock(self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, x): + x0 = self.layer0(x) + x1 = self.layer1(x0) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + x5 = self.layer5(x4) + + cnn_feat = x5.squeeze(2) # [N, c, w] + cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1]) + if self.with_lstm: + rnn_feat, _ = self.rnn(cnn_feat) + return rnn_feat + else: + return cnn_feat diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 20eb150fc3f91c3abd9a3fd8ce655c24ccab8179..fdadfed5e3fe30b6bd311a07d6ba36869f175488 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -29,13 +29,14 @@ def 
build_head(config): from .rec_srn_head import SRNHead from .rec_nrtr_head import Transformer from .rec_sar_head import SARHead + from .rec_aster_head import AsterHead # cls head from .cls_head import ClsHead support_dict = [ 'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', - 'TableAttentionHead', 'SARHead' + 'TableAttentionHead', 'SARHead', 'AsterHead' ] #table head diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4961897b409020fe6cff72eb96f3257156fa33ac --- /dev/null +++ b/ppocr/modeling/heads/rec_aster_head.py @@ -0,0 +1,389 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import paddle +from paddle import nn +from paddle.nn import functional as F + + +class AsterHead(nn.Layer): + def __init__(self, + in_channels, + out_channels, + sDim, + attDim, + max_len_labels, + time_step=25, + beam_width=5, + **kwargs): + super(AsterHead, self).__init__() + self.num_classes = out_channels + self.in_planes = in_channels + self.sDim = sDim + self.attDim = attDim + self.max_len_labels = max_len_labels + self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim, + attDim, max_len_labels) + self.time_step = time_step + self.embeder = Embedding(self.time_step, in_channels) + self.beam_width = beam_width + self.eos = self.num_classes - 1 + + def forward(self, x, targets=None, embed=None): + return_dict = {} + embedding_vectors = self.embeder(x) + + if self.training: + rec_targets, rec_lengths, _ = targets + rec_pred = self.decoder([x, rec_targets, rec_lengths], + embedding_vectors) + return_dict['rec_pred'] = rec_pred + return_dict['embedding_vectors'] = embedding_vectors + else: + rec_pred, rec_pred_scores = self.decoder.beam_search( + x, self.beam_width, self.eos, embedding_vectors) + return_dict['rec_pred'] = rec_pred + return_dict['rec_pred_scores'] = rec_pred_scores + return_dict['embedding_vectors'] = embedding_vectors + + return return_dict + + +class Embedding(nn.Layer): + def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300): + super(Embedding, self).__init__() + self.in_timestep = in_timestep + self.in_planes = in_planes + self.embed_dim = embed_dim + self.mid_dim = mid_dim + self.eEmbed = nn.Linear( + in_timestep * in_planes, + self.embed_dim) # Embed encoder output to a word-embedding like + + def forward(self, x): + x = paddle.reshape(x, [paddle.shape(x)[0], -1]) + x = self.eEmbed(x) + return x + + +class AttentionRecognitionHead(nn.Layer): + """ + input: [b x 16 x 64 x in_planes] + output: probability sequence: [b x T x num_classes] + """ + + def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels): + super(AttentionRecognitionHead, self).__init__() + self.num_classes = 
out_channels # this is the output classes. So it includes the <EOS>. + self.in_planes = in_channels + self.sDim = sDim + self.attDim = attDim + self.max_len_labels = max_len_labels + + self.decoder = DecoderUnit( + sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim) + + def forward(self, x, embed): + x, targets, lengths = x + batch_size = paddle.shape(x)[0] + # Decoder + state = self.decoder.get_initial_state(embed) + outputs = [] + for i in range(max(lengths)): + if i == 0: + y_prev = paddle.full( + shape=[batch_size], fill_value=self.num_classes) + else: + y_prev = targets[:, i - 1] + output, state = self.decoder(x, state, y_prev) + outputs.append(output) + outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1) + return outputs + + # inference stage. + def sample(self, x): + x, _, _ = x + batch_size = paddle.shape(x)[0] + # Decoder + state = paddle.zeros([1, batch_size, self.sDim]) + + predicted_ids, predicted_scores = [], [] + for i in range(self.max_len_labels): + if i == 0: + y_prev = paddle.full( + shape=[batch_size], fill_value=self.num_classes) + else: + y_prev = predicted + + output, state = self.decoder(x, state, y_prev) + output = F.softmax(output, axis=1) + score, predicted = output.max(1) + predicted_ids.append(predicted.unsqueeze(1)) + predicted_scores.append(score.unsqueeze(1)) + predicted_ids = paddle.concat(predicted_ids, 1) + predicted_scores = paddle.concat(predicted_scores, 1) + # return predicted_ids.squeeze(), predicted_scores.squeeze() + return predicted_ids, predicted_scores + + def beam_search(self, x, beam_width, eos, embed): + def _inflate(tensor, times, dim): + repeat_dims = [1] * tensor.dim() + repeat_dims[dim] = times + output = paddle.tile(tensor, repeat_dims) + return output + + # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py + batch_size, l, d = x.shape + x = paddle.tile( + paddle.transpose( + x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1]) + inflated_encoder_feats = paddle.reshape( + paddle.transpose( + x, perm=[1, 0, 2, 3]), [-1, l, d]) + + # Initialize the decoder + state = self.decoder.get_initial_state(embed, tile_times=beam_width) + + pos_index = paddle.reshape( + paddle.arange(batch_size) * beam_width, shape=[-1, 1]) + + # Initialize the scores + sequence_scores = paddle.full( + shape=[batch_size * beam_width, 1], fill_value=-float('Inf')) + index = [i * beam_width for i in range(0, batch_size)] + sequence_scores[index] = 0.0 + + # Initialize the input vector + y_prev = paddle.full( + shape=[batch_size * beam_width], fill_value=self.num_classes) + + # Store decisions for backtracking + stored_scores = list() + stored_predecessors = list() + stored_emitted_symbols = list() + + for i in range(self.max_len_labels): + output, state = self.decoder(inflated_encoder_feats, state, y_prev) + state = paddle.unsqueeze(state, axis=0) + log_softmax_output = paddle.nn.functional.log_softmax( + output, axis=1) + + sequence_scores = _inflate(sequence_scores, self.num_classes, 1) + sequence_scores += log_softmax_output + scores, candidates = paddle.topk( + paddle.reshape(sequence_scores, [batch_size, -1]), + beam_width, + axis=1) + + # Reshape input = (bk, 1) and sequence_scores = (bk, 1) + y_prev = paddle.reshape( + candidates % self.num_classes, shape=[batch_size * beam_width]) + sequence_scores = paddle.reshape( + scores, shape=[batch_size * beam_width, 1]) + + # Update fields for next timestep + pos_index = paddle.expand_as(pos_index, candidates) + predecessors =
paddle.cast( + candidates / self.num_classes + pos_index, dtype='int64') + predecessors = paddle.reshape( + predecessors, shape=[batch_size * beam_width, 1]) + state = paddle.index_select( + state, index=predecessors.squeeze(), axis=1) + + # Update sequence scores and erase scores for the EOS symbol so that they aren't expanded + stored_scores.append(sequence_scores.clone()) + y_prev = paddle.reshape(y_prev, shape=[-1, 1]) + eos_prev = paddle.full_like(y_prev, fill_value=eos) + mask = eos_prev == y_prev + mask = paddle.nonzero(mask) + if mask.dim() > 0: + sequence_scores = sequence_scores.numpy() + mask = mask.numpy() + sequence_scores[mask] = -float('inf') + sequence_scores = paddle.to_tensor(sequence_scores) + + # Cache results for backtracking + stored_predecessors.append(predecessors) + y_prev = paddle.squeeze(y_prev) + stored_emitted_symbols.append(y_prev) + + # Do backtracking to return the optimal values + #====== backtrack ======# + # Initialize return variables given different types + p = list() + l = [[self.max_len_labels] * beam_width for _ in range(batch_size) + ] # Placeholder for lengths of top-k sequences + + # the last step output of the beams is not sorted + # thus it is sorted here + sorted_score, sorted_idx = paddle.topk( + paddle.reshape( + stored_scores[-1], shape=[batch_size, beam_width]), + beam_width) + + # initialize the sequence scores with the sorted last step beam scores + s = sorted_score.clone() + + batch_eos_found = [0] * batch_size # the number of EOS found + # in the backward loop below for each batch + t = self.max_len_labels - 1 + # initialize the back pointer with the sorted order of the last step beams. + # add pos_index for indexing variable with b*k as the first dimension. + t_predecessors = paddle.reshape( + sorted_idx + pos_index.expand_as(sorted_idx), + shape=[batch_size * beam_width]) + while t >= 0: + # Re-order the variables with the back pointer + current_symbol = paddle.index_select( + stored_emitted_symbols[t], index=t_predecessors, axis=0) + t_predecessors = paddle.index_select( + stored_predecessors[t].squeeze(), index=t_predecessors, axis=0) + eos_indices = stored_emitted_symbols[t] == eos + eos_indices = paddle.nonzero(eos_indices) + + if eos_indices.dim() > 0: + for i in range(eos_indices.shape[0] - 1, -1, -1): + # Indices of the EOS symbol for both variables + # with b*k as the first dimension, and b, k for + # the first two dimensions + idx = eos_indices[i] + b_idx = int(idx[0] / beam_width) + # The indices of the replacing position + # according to the replacement strategy noted above + res_k_idx = beam_width - (batch_eos_found[b_idx] % + beam_width) - 1 + batch_eos_found[b_idx] += 1 + res_idx = b_idx * beam_width + res_k_idx + + # Replace the old information in return variables + # with the new ended sequence information + t_predecessors[res_idx] = stored_predecessors[t][idx[0]] + current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]] + s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0] + l[b_idx][res_k_idx] = t + 1 + + # record the backtracked results + p.append(current_symbol) + t -= 1 + + # Sort and re-order again as the added ended sequences may change + # the order (very unlikely) + s, re_sorted_idx = s.topk(beam_width) + for b_idx in range(batch_size): + l[b_idx] = [ + l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :] + ] + + re_sorted_idx = paddle.reshape( + re_sorted_idx + pos_index.expand_as(re_sorted_idx), + [batch_size * beam_width]) + + # Reverse the sequences and re-order at the same time + # It is reversed
because the backtracking happens in reverse time order + p = [ + paddle.reshape( + paddle.index_select(step, re_sorted_idx, 0), + shape=[batch_size, beam_width, -1]) for step in reversed(p) + ] + p = paddle.concat(p, -1)[:, 0, :] + return p, paddle.ones_like(p) + + +class AttentionUnit(nn.Layer): + def __init__(self, sDim, xDim, attDim): + super(AttentionUnit, self).__init__() + + self.sDim = sDim + self.xDim = xDim + self.attDim = attDim + + self.sEmbed = nn.Linear(sDim, attDim) + self.xEmbed = nn.Linear(xDim, attDim) + self.wEmbed = nn.Linear(attDim, 1) + + def forward(self, x, sPrev): + batch_size, T, _ = x.shape # [b x T x xDim] + x = paddle.reshape(x, [-1, self.xDim]) # [(b x T) x xDim] + xProj = self.xEmbed(x) # [(b x T) x attDim] + xProj = paddle.reshape(xProj, [batch_size, T, -1]) # [b x T x attDim] + + sPrev = sPrev.squeeze(0) + sProj = self.sEmbed(sPrev) # [b x attDim] + sProj = paddle.unsqueeze(sProj, 1) # [b x 1 x attDim] + sProj = paddle.expand(sProj, + [batch_size, T, self.attDim]) # [b x T x attDim] + + sumTanh = paddle.tanh(sProj + xProj) + sumTanh = paddle.reshape(sumTanh, [-1, self.attDim]) + + vProj = self.wEmbed(sumTanh) # [(b x T) x 1] + vProj = paddle.reshape(vProj, [batch_size, T]) + alpha = F.softmax( + vProj, axis=1) # attention weights for each sample in the minibatch + return alpha + + +class DecoderUnit(nn.Layer): + def __init__(self, sDim, xDim, yDim, attDim): + super(DecoderUnit, self).__init__() + self.sDim = sDim + self.xDim = xDim + self.yDim = yDim + self.attDim = attDim + self.emdDim = attDim + + self.attention_unit = AttentionUnit(sDim, xDim, attDim) + self.tgt_embedding = nn.Embedding( + yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal( + std=0.01)) # the last is used for <BOS> + self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim) + self.fc = nn.Linear( + sDim, + yDim, + weight_attr=nn.initializer.Normal(std=0.01), + bias_attr=nn.initializer.Constant(value=0)) + self.embed_fc = nn.Linear(300, self.sDim) + + def get_initial_state(self, embed, tile_times=1): + assert embed.shape[1] == 300 + state = self.embed_fc(embed) # N * sDim + if tile_times != 1: + state = state.unsqueeze(1) + trans_state = paddle.transpose(state, perm=[1, 0, 2]) + state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1]) + trans_state = paddle.transpose(state, perm=[1, 0, 2]) + state = paddle.reshape(trans_state, shape=[-1, self.sDim]) + state = state.unsqueeze(0) # 1 * N * sDim + return state + + def forward(self, x, sPrev, yPrev): + # x: feature sequence from the image encoder.
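Editor's note (descriptive, not part of this patch): in the forward pass below, alpha from the attention unit has shape [b, T]; unsqueezing it to [b, 1, T] and matmul-ing with x ([b, T, xDim]) gives the attended context [b, 1, xDim], which is squeezed to [b, xDim] and concatenated with the target embedding as the GRU input.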
+ batch_size, T, _ = x.shape + alpha = self.attention_unit(x, sPrev) + context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1) + yPrev = paddle.cast(yPrev, dtype="int64") + yProj = self.tgt_embedding(yPrev) + + concat_context = paddle.concat([yProj, context], 1) + concat_context = paddle.squeeze(concat_context, 1) + sPrev = paddle.squeeze(sPrev, 0) + output, state = self.gru(concat_context, sPrev) + output = paddle.squeeze(output, axis=1) + output = self.fc(output) + return output, state \ No newline at end of file diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 9c38d31fa0abcf39a583e5edcebfc8f336f41c46..35d33d5f56b3b378286565cbfa9755f43343b278 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ b/ppocr/modeling/heads/rec_ctc_head.py @@ -38,6 +38,7 @@ class CTCHead(nn.Layer): out_channels, fc_decay=0.0004, mid_channels=None, + return_feats=False, **kwargs): super(CTCHead, self).__init__() if mid_channels is None: @@ -66,14 +67,22 @@ class CTCHead(nn.Layer): bias_attr=bias_attr2) self.out_channels = out_channels self.mid_channels = mid_channels + self.return_feats = return_feats def forward(self, x, targets=None): if self.mid_channels is None: predicts = self.fc(x) else: - predicts = self.fc1(x) - predicts = self.fc2(predicts) - + x = self.fc1(x) + predicts = self.fc2(x) + + if self.return_feats: + result = (x, predicts) + else: + result = predicts + if not self.training: predicts = F.softmax(predicts, axis=2) - return predicts + result = predicts + + return result diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py index de87b3d9895168657f8c9722177c026b992c2966..86e649028f8fbb76cb5a1fd85381bd361277c6ee 100644 --- a/ppocr/modeling/necks/rnn.py +++ b/ppocr/modeling/necks/rnn.py @@ -51,7 +51,7 @@ class EncoderWithFC(nn.Layer): super(EncoderWithFC, self).__init__() self.out_channels = hidden_size weight_attr, bias_attr = get_para_bias_attr( - l2_decay=0.00001, k=in_channels, name='reduce_encoder_fea') + l2_decay=0.00001, k=in_channels) self.fc = nn.Linear( in_channels, hidden_size, diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py index 78eaecccc55f77d6624aa0c5bdb839acc3462129..405ab3cc6c0380654f61e42e523ddc85839139b3 100755 --- a/ppocr/modeling/transforms/__init__.py +++ b/ppocr/modeling/transforms/__init__.py @@ -17,8 +17,9 @@ __all__ = ['build_transform'] def build_transform(config): from .tps import TPS + from .stn import STN_ON - support_dict = ['TPS'] + support_dict = ['TPS', 'STN_ON'] module_name = config.pop('name') assert module_name in support_dict, Exception( diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py new file mode 100644 index 0000000000000000000000000000000000000000..215895f4c4c719f407f4998f7429d965e0529ddc --- /dev/null +++ b/ppocr/modeling/transforms/stn.py @@ -0,0 +1,132 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
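Editor's note: a hypothetical usage sketch (not part of this patch) of the return_feats flag added to CTCHead above; in training mode the head then returns a (feats, logits) tuple so feature-based losses such as CenterLoss can consume fc1's output, while the patched CTCLabelDecode below simply keeps the last tuple element:

    import paddle
    from ppocr.modeling.heads.rec_ctc_head import CTCHead

    head = CTCHead(in_channels=64, out_channels=6625,
                   mid_channels=96, return_feats=True)
    x = paddle.rand([2, 25, 64])    # [batch, time steps, channels]
    feats, logits = head(x)         # feats from fc1, logits from fc2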
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import numpy as np + +from .tps_spatial_transformer import TPSSpatialTransformer + + +def conv3x3_block(in_channels, out_channels, stride=1): + n = 3 * 3 * out_channels + w = math.sqrt(2. / n) + conv_layer = nn.Conv2D( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + weight_attr=nn.initializer.Normal( + mean=0.0, std=w), + bias_attr=nn.initializer.Constant(0)) + block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU()) + return block + + +class STN(nn.Layer): + def __init__(self, in_channels, num_ctrlpoints, activation='none'): + super(STN, self).__init__() + self.in_channels = in_channels + self.num_ctrlpoints = num_ctrlpoints + self.activation = activation + self.stn_convnet = nn.Sequential( + conv3x3_block(in_channels, 32), #32x64 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(32, 64), #16x32 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(64, 128), # 8*16 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(128, 256), # 4*8 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(256, 256), # 2*4, + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(256, 256)) # 1*2 + self.stn_fc1 = nn.Sequential( + nn.Linear( + 2 * 256, + 512, + weight_attr=nn.initializer.Normal(0, 0.001), + bias_attr=nn.initializer.Constant(0)), + nn.BatchNorm1D(512), + nn.ReLU()) + fc2_bias = self.init_stn() + self.stn_fc2 = nn.Linear( + 512, + num_ctrlpoints * 2, + weight_attr=nn.initializer.Constant(0.0), + bias_attr=nn.initializer.Assign(fc2_bias)) + + def init_stn(self): + margin = 0.01 + sampling_num_per_side = int(self.num_ctrlpoints / 2) + ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side) + ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin + ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + ctrl_points = np.concatenate( + [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32) + if self.activation == 'none': + pass + elif self.activation == 'sigmoid': + ctrl_points = -np.log(1. / ctrl_points - 1.) 
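Editor's note (descriptive, not part of this patch): the -np.log(1. / ctrl_points - 1.) above is the inverse sigmoid (logit); because stn_fc2 is created with zero weights and this vector as its bias, sigmoid(fc2_bias) reproduces exactly the fiducial control points on the first forward pass.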
+ ctrl_points = paddle.to_tensor(ctrl_points) + fc2_bias = paddle.reshape( + ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]]) + return fc2_bias + + def forward(self, x): + x = self.stn_convnet(x) + batch_size, _, h, w = x.shape + x = paddle.reshape(x, shape=(batch_size, -1)) + img_feat = self.stn_fc1(x) + x = self.stn_fc2(0.1 * img_feat) + if self.activation == 'sigmoid': + x = F.sigmoid(x) + x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2]) + return img_feat, x + + +class STN_ON(nn.Layer): + def __init__(self, in_channels, tps_inputsize, tps_outputsize, + num_control_points, tps_margins, stn_activation): + super(STN_ON, self).__init__() + self.tps = TPSSpatialTransformer( + output_image_size=tuple(tps_outputsize), + num_control_points=num_control_points, + margins=tuple(tps_margins)) + self.stn_head = STN(in_channels=in_channels, + num_ctrlpoints=num_control_points, + activation=stn_activation) + self.tps_inputsize = tps_inputsize + self.out_channels = in_channels + + def forward(self, image): + stn_input = paddle.nn.functional.interpolate( + image, self.tps_inputsize, mode="bilinear", align_corners=True) + stn_img_feat, ctrl_points = self.stn_head(stn_input) + x, _ = self.tps(image, ctrl_points) + return x diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py index dcce6246ac64b4b84229cbd69a4dc53c658b4c7b..6cd68555369dd1ddbd6ccf5236688a4b957b8525 100644 --- a/ppocr/modeling/transforms/tps.py +++ b/ppocr/modeling/transforms/tps.py @@ -231,7 +231,8 @@ class GridGenerator(nn.Layer): """ Return inv_delta_C which is needed to calculate T """ F = self.F hat_eye = paddle.eye(F, dtype='float64') # F x F - hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye + hat_C = paddle.norm( + C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye hat_C = (hat_C**2) * paddle.log(hat_C) delta_C = paddle.concat( # F+3 x F+3 [ diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b510acb0d4012c9a4d90c7ca07cac895f0bf242e --- /dev/null +++ b/ppocr/modeling/transforms/tps_spatial_transformer.py @@ -0,0 +1,152 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
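Editor's note: a minimal, self-contained check (not part of this patch) of the thin-plate-spline kernel used by compute_partial_repr below; the code evaluates 0.5 * d * log(d) on squared distances d = r^2, which equals the textbook phi(r) = r^2 * log(r):

    import numpy as np

    p, q = np.array([0.2, 0.3]), np.array([0.7, 0.9])
    d = np.sum((p - q) ** 2)    # squared Euclidean distance r^2
    assert np.isclose(0.5 * d * np.log(d), d * np.log(np.sqrt(d)))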
+from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import numpy as np +import itertools + + +def grid_sample(input, grid, canvas=None): + input.stop_gradient = False + output = F.grid_sample(input, grid) + if canvas is None: + return output + else: + input_mask = paddle.ones(shape=input.shape) + output_mask = F.grid_sample(input_mask, grid) + padded_output = output * output_mask + canvas * (1 - output_mask) + return padded_output + + +# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 +def compute_partial_repr(input_points, control_points): + N = input_points.shape[0] + M = control_points.shape[0] + pairwise_diff = paddle.reshape( + input_points, shape=[N, 1, 2]) - paddle.reshape( + control_points, shape=[1, M, 2]) + # original implementation, very slow + # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance + pairwise_diff_square = pairwise_diff * pairwise_diff + pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, + 1] + repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist) + # fix numerical error for 0 * log(0), substitute all nan with 0 + mask = repr_matrix != repr_matrix + repr_matrix[mask] = 0 + return repr_matrix + + +# output_ctrl_pts are specified, according to our task. +def build_output_control_points(num_control_points, margins): + margin_x, margin_y = margins + num_ctrl_pts_per_side = num_control_points // 2 + ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) + ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y + ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + output_ctrl_pts_arr = np.concatenate( + [ctrl_pts_top, ctrl_pts_bottom], axis=0) + output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr) + return output_ctrl_pts + + +class TPSSpatialTransformer(nn.Layer): + def __init__(self, + output_image_size=None, + num_control_points=None, + margins=None): + super(TPSSpatialTransformer, self).__init__() + self.output_image_size = output_image_size + self.num_control_points = num_control_points + self.margins = margins + + self.target_height, self.target_width = output_image_size + target_control_points = build_output_control_points(num_control_points, + margins) + N = num_control_points + + # create padded kernel matrix + forward_kernel = paddle.zeros(shape=[N + 3, N + 3]) + target_control_partial_repr = compute_partial_repr( + target_control_points, target_control_points) + target_control_partial_repr = paddle.cast(target_control_partial_repr, + forward_kernel.dtype) + forward_kernel[:N, :N] = target_control_partial_repr + forward_kernel[:N, -3] = 1 + forward_kernel[-3, :N] = 1 + target_control_points = paddle.cast(target_control_points, + forward_kernel.dtype) + forward_kernel[:N, -2:] = target_control_points + forward_kernel[-2:, :N] = paddle.transpose( + target_control_points, perm=[1, 0]) + # compute inverse matrix + inverse_kernel = paddle.inverse(forward_kernel) + + # create target cordinate matrix + HW = self.target_height * self.target_width + target_coordinate = list( + itertools.product( + range(self.target_height), range(self.target_width))) + target_coordinate = paddle.to_tensor(target_coordinate) # HW x 2 + Y, X = paddle.split( + target_coordinate, 
target_coordinate.shape[1], axis=1) + Y = Y / (self.target_height - 1) + X = X / (self.target_width - 1) + target_coordinate = paddle.concat( + [X, Y], axis=1) # convert from (y, x) to (x, y) + target_coordinate_partial_repr = compute_partial_repr( + target_coordinate, target_control_points) + target_coordinate_repr = paddle.concat( + [ + target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]), + target_coordinate + ], + axis=1) + + # register precomputed matrices + self.inverse_kernel = inverse_kernel + self.padding_matrix = paddle.zeros(shape=[3, 2]) + self.target_coordinate_repr = target_coordinate_repr + self.target_control_points = target_control_points + + def forward(self, input, source_control_points): + assert source_control_points.ndimension() == 3 + assert source_control_points.shape[1] == self.num_control_points + assert source_control_points.shape[2] == 2 + batch_size = paddle.shape(source_control_points)[0] + + self.padding_matrix = paddle.expand( + self.padding_matrix, shape=[batch_size, 3, 2]) + Y = paddle.concat([source_control_points, self.padding_matrix], 1) + mapping_matrix = paddle.matmul(self.inverse_kernel, Y) + source_coordinate = paddle.matmul(self.target_coordinate_repr, + mapping_matrix) + + grid = paddle.reshape( + source_coordinate, + shape=[-1, self.target_height, self.target_width, 2]) + grid = paddle.clip(grid, 0, + 1) # the source_control_points may be out of [0, 1]. + # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] + grid = 2.0 * grid - 1.0 + output_maps = grid_sample(input, grid, canvas=None) + return output_maps, source_coordinate diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py index 8215b92d8c8d05c2b3c2e95ac989bf4ea011310b..34098c0fad553f7d39f6b5341e4da70a263eeaea 100644 --- a/ppocr/optimizer/optimizer.py +++ b/ppocr/optimizer/optimizer.py @@ -127,3 +127,33 @@ class RMSProp(object): grad_clip=self.grad_clip, parameters=parameters) return opt + + +class Adadelta(object): + def __init__(self, + learning_rate=0.001, + epsilon=1e-08, + rho=0.95, + parameter_list=None, + weight_decay=None, + grad_clip=None, + name=None, + **kwargs): + self.learning_rate = learning_rate + self.epsilon = epsilon + self.rho = rho + self.parameter_list = parameter_list + self.weight_decay = weight_decay + self.grad_clip = grad_clip + self.name = name + + def __call__(self, parameters): + opt = optim.Adadelta( + learning_rate=self.learning_rate, + epsilon=self.epsilon, + rho=self.rho, + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + name=self.name, + parameters=parameters) + return opt diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 3eb5e28da34bf4d9ed478c52ab38b1de21c0d1e3..3a4ebf52a3bd91ffd509b113103dab900588b0bd 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -18,17 +18,21 @@ from __future__ import print_function from __future__ import unicode_literals import copy +import platform __all__ = ['build_post_process'] from .db_postprocess import DBPostProcess, DistillationDBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess -from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \ - TableLabelDecode, SARLabelDecode +from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ + TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
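Editor's note: a hypothetical usage sketch (not part of this patch) of the Adadelta wrapper added above; like the other classes in ppocr/optimizer/optimizer.py, it is configured first and then called with the model parameters to build the underlying paddle.optimizer.Adadelta:

    import paddle
    from ppocr.optimizer.optimizer import Adadelta

    model = paddle.nn.Linear(8, 2)
    opt = Adadelta(learning_rate=1.0, rho=0.95, epsilon=1e-6)(model.parameters())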
from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess -from .pse_postprocess import PSEPostProcess + +if platform.system() != "Windows": + # pse is not supported on Windows + from .pse_postprocess import PSEPostProcess def build_post_process(config, global_config=None): @@ -36,7 +40,8 @@ def build_post_process(config, global_config=None): 'DBPostProcess', 'PSEPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode', - 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode' + 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', + 'SEEDLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 96b2169d28004d408bd567db2b3130b681fcb582..c06159ca55600e7afe01a68ab43acd1919cf742c 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -111,6 +111,8 @@ class CTCLabelDecode(BaseRecLabelDecode): character_type, use_space_char) def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, tuple): + preds = preds[-1] if isinstance(preds, paddle.Tensor): preds = preds.numpy() preds_idx = preds.argmax(axis=2) @@ -308,6 +310,87 @@ class AttnLabelDecode(BaseRecLabelDecode): return idx +class SEEDLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False, + **kwargs): + super(SEEDLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + [self.end_str] + return dict_character + + def get_ignored_tokens(self): + end_idx = self.get_beg_end_flag_idx("eos") + return [end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "sos": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "eos": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end + return idx + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label.
""" + result_list = [] + [end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list))) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ + text = self.decode(text) + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + preds_idx = preds["rec_pred"] + if isinstance(preds_idx, paddle.Tensor): + preds_idx = preds_idx.numpy() + if "rec_pred_scores" in preds: + preds_idx = preds["rec_pred"] + preds_prob = preds["rec_pred_scores"] + else: + preds_idx = preds["rec_pred"].argmax(axis=2) + preds_prob = preds["rec_pred"].max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + class SRNLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ diff --git a/ppocr/utils/dict90.txt b/ppocr/utils/dict90.txt new file mode 100644 index 0000000000000000000000000000000000000000..a945ae9c526e4faa68852eb3fb47d078a2f3f6ce --- /dev/null +++ b/ppocr/utils/dict90.txt @@ -0,0 +1,90 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ +: +; +< += +> +? +@ +[ +\ +] +_ +` +~ \ No newline at end of file diff --git a/ppocr/utils/profiler.py b/ppocr/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e28bc6bea9ca912a0786d879a48ec0349e7698 --- /dev/null +++ b/ppocr/utils/profiler.py @@ -0,0 +1,110 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import paddle + +# A global variable to record the number of calling times for profiler +# functions. It is used to specify the tracing range of training steps. +_profiler_step_id = 0 + +# A global variable to avoid parsing from string every time. +_profiler_options = None + + +class ProfilerOptions(object): + ''' + Use a string to initialize a ProfilerOptions. + The string should be in the format: "key1=value1;key2=value;key3=value3". 
+ For example: + "profile_path=model.profile" + "batch_range=[50, 60]; profile_path=model.profile" + "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile" + ProfilerOptions supports the following key-value pairs: + batch_range - an integer list, e.g. [100, 110]. + state - a string, the optional values are 'CPU', 'GPU' or 'All'. + sorted_key - a string, the optional values are 'calls', 'total', + 'max', 'min' or 'ave'. + tracer_option - a string, the optional values are 'Default', 'OpDetail', + 'AllOpDetail'. + profile_path - a string, the path to save the serialized profile data, + which can be used to generate a timeline. + exit_on_finished - a boolean. + ''' + + def __init__(self, options_str): + assert isinstance(options_str, str) + + self._options = { + 'batch_range': [10, 20], + 'state': 'All', + 'sorted_key': 'total', + 'tracer_option': 'Default', + 'profile_path': '/tmp/profile', + 'exit_on_finished': True + } + self._parse_from_string(options_str) + + def _parse_from_string(self, options_str): + for kv in options_str.replace(' ', '').split(';'): + key, value = kv.split('=') + if key == 'batch_range': + value_list = value.replace('[', '').replace(']', '').split(',') + value_list = list(map(int, value_list)) + if len(value_list) >= 2 and value_list[0] >= 0 and value_list[ + 1] > value_list[0]: + self._options[key] = value_list + elif key == 'exit_on_finished': + self._options[key] = value.lower() in ("yes", "true", "t", "1") + elif key in [ + 'state', 'sorted_key', 'tracer_option', 'profile_path' + ]: + self._options[key] = value + + def __getitem__(self, name): + if self._options.get(name, None) is None: + raise ValueError( + "ProfilerOptions does not have an option named %s." % name) + return self._options[name] + + +def add_profiler_step(options_str=None): + ''' + Enable the operator-level timing using PaddlePaddle's profiler. + The profiler uses an independent variable to count the profiler steps. + One call of this function is treated as a profiler step. + + Args: + options_str - a string to initialize the ProfilerOptions. + Default is None, and the profiler is disabled. + ''' + if options_str is None: + return + + global _profiler_step_id + global _profiler_options + + if _profiler_options is None: + _profiler_options = ProfilerOptions(options_str) + + if _profiler_step_id == _profiler_options['batch_range'][0]: + paddle.utils.profiler.start_profiler( + _profiler_options['state'], _profiler_options['tracer_option']) + elif _profiler_step_id == _profiler_options['batch_range'][1]: + paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'], + _profiler_options['profile_path']) + if _profiler_options['exit_on_finished']: + sys.exit(0) + + _profiler_step_id += 1 diff --git a/ppstructure/layout/train_layoutparser_model.md b/ppstructure/layout/train_layoutparser_model.md index 08f5ebbf1aa276e4a3ecf27af46442161afcda1f..58975d71606e45b2f68a7f68565459042ef32775 100644 --- a/ppstructure/layout/train_layoutparser_model.md +++ b/ppstructure/layout/train_layoutparser_model.md @@ -4,9 +4,9 @@ ​ [1.1 Requirements](#Requirements) -​ [1.2 Install PaddleDetection](#Install PaddleDetection) +​ [1.2 Install PaddleDetection](#Install_PaddleDetection) -[2. Data preparation](#Data preparation) +[2. Data preparation](#Data_preparation) [3. Configuration](#Configuration) @@ -16,7 +16,7 @@ [6.
Deployment](#Deployment) -​ [6.1 Export model](#Export model) +​ [6.1 Export model](#Export_model) ​ [6.2 Inference](#Inference) @@ -35,7 +35,7 @@ - CUDA >= 10.1 - cuDNN >= 7.6 -<a name="Install PaddleDetection"></a> +<a name="Install_PaddleDetection"></a> ### 1.2 Install PaddleDetection @@ -51,7 +51,7 @@ pip install -r requirements.txt For more installation tutorials, please refer to: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md) -<a name="Data preparation"></a> +<a name="Data_preparation"></a> ## 2. Data preparation @@ -165,7 +165,7 @@ python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer Use your trained model in Layout Parser -<a name="Export model"></a> +<a name="Export_model"></a> ### 6.1 Export model diff --git a/requirements.txt b/requirements.txt index 0b2366c5cd344260d7afab811b27e19499a89b26..311030f65f2dc2dad4a51821e64f2777e7621a0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ opencv-contrib-python==4.4.0.46 cython lxml premailer -openpyxl \ No newline at end of file +openpyxl +fasttext==0.9.1 \ No newline at end of file diff --git a/tests/common_func.sh b/tests/common_func.sh new file mode 100644 index 0000000000000000000000000000000000000000..3f0fa66b77ff50b23b1e83dea506580f549f8ecf --- /dev/null +++ b/tests/common_func.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +function func_parser_key(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[0]} + echo ${tmp} +} + +function func_parser_value(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[1]} + echo ${tmp} +} + +function func_set_params(){ + key=$1 + value=$2 + if [ ${key}x = "null"x ];then + echo " " + elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then + echo " " + else + echo "${key}=${value}" + fi +} + +function func_parser_params(){ + strs=$1 + IFS=":" + array=(${strs}) + key=${array[0]} + tmp=${array[1]} + IFS="|" + res="" + for _params in ${tmp[*]}; do + IFS="=" + array=(${_params}) + mode=${array[0]} + value=${array[1]} + if [[ ${mode} = ${MODE} ]]; then + IFS="|" + #echo $(func_set_params "${mode}" "${value}") + echo $value + break + fi + IFS="|" + done + echo ${res} +} + +function status_check(){ + last_status=$1 # the exit code + run_command=$2 + run_log=$3 + if [ $last_status -eq 0 ]; then + echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log} + else + echo -e "\033[33m Run failed with command - ${run_command}!
\033[0m" | tee -a ${run_log} + fi +} + diff --git a/tests/ocr_det_params.txt b/tests/configs/ppocr_det_mobile_params.txt similarity index 84% rename from tests/ocr_det_params.txt rename to tests/configs/ppocr_det_mobile_params.txt index 6fd22e409a5219574b2f29285ff5ee5d2e1cf7ca..5edb14cdbf8eef87b5b5558cbd8d1a2ff54ae919 100644 --- a/tests/ocr_det_params.txt +++ b/tests/configs/ppocr_det_mobile_params.txt @@ -40,13 +40,13 @@ infer_quant:False inference:tools/infer/predict_det.py --use_gpu:True|False --enable_mkldnn:True|False ---cpu_threads:6 +--cpu_threads:1|6 --rec_batch_num:1 --use_tensorrt:False|True --precision:fp32|fp16|int8 --det_model_dir: --image_dir:./inference/ch_det_data_50/all-sum-510/ ---save_log_path:null +null:null --benchmark:True null:null ===========================cpp_infer_params=========================== @@ -79,4 +79,20 @@ op.det.local_service_conf.thread_num:1|6 op.det.local_service_conf.use_trt:False|True op.det.local_service_conf.precision:fp32|fp16|int8 pipline:pipeline_http_client.py --image_dir=../../doc/imgs - +===========================kl_quant_params=========================== +infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/ +infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o +infer_quant:False +inference:tools/infer/predict_det.py +--use_gpu:True|False +--enable_mkldnn:True|False +--cpu_threads:1|6 +--rec_batch_num:1 +--use_tensorrt:False|True +--precision:fp32|fp16|int8 +--det_model_dir: +--image_dir:./inference/ch_det_data_50/all-sum-510/ +null:null +--benchmark:True +null:null +null:null \ No newline at end of file diff --git a/tests/ocr_det_server_params.txt b/tests/configs/ppocr_det_server_params.txt similarity index 77% rename from tests/ocr_det_server_params.txt rename to tests/configs/ppocr_det_server_params.txt index 4a17fa683439fdc4716b4ed6b067a572fa3a5057..b3df1735e50d941b34eeb274c28eb4ce50d79292 100644 --- a/tests/ocr_det_server_params.txt +++ b/tests/configs/ppocr_det_server_params.txt @@ -12,10 +12,10 @@ train_model_name:latest train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## -trainer:norm_train|pact_train -norm_train:tools/train.py -c tests/configs/det_r50_vd_db.yml -o Global.pretrained_model="" -pact_train:null -fpgm_train:null +trainer:norm_train|pact_train|fpgm_export +norm_train:tools/train.py -c tests/configs/det_r50_vd_db.yml -o +quant_export:deploy/slim/quantization/export_model.py -c tests/configs/det_r50_vd_db.yml -o +fpgm_export:deploy/slim/prune/export_prune_model.py -c tests/configs/det_r50_vd_db.yml -o distill_train:null null:null null:null @@ -34,8 +34,8 @@ distill_export:null export1:null export2:null ## -infer_model:./inference/ch_ppocr_server_v2.0_det_infer/ -infer_export:null +train_model:./inference/ch_ppocr_server_v2.0_det_train/best_accuracy +infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml -o infer_quant:False inference:tools/infer/predict_det.py --use_gpu:True|False diff --git a/tests/ocr_rec_params.txt b/tests/configs/ppocr_rec_mobile_params.txt similarity index 100% rename from tests/ocr_rec_params.txt rename to tests/configs/ppocr_rec_mobile_params.txt diff --git a/tests/ocr_rec_server_params.txt b/tests/configs/ppocr_rec_server_params.txt similarity index 100% rename from tests/ocr_rec_server_params.txt rename to tests/configs/ppocr_rec_server_params.txt diff --git a/tests/ocr_ppocr_mobile_params.txt b/tests/configs/ppocr_sys_mobile_params.txt similarity index 100% rename from 
tests/ocr_ppocr_mobile_params.txt
rename to tests/configs/ppocr_sys_mobile_params.txt
diff --git a/tests/ocr_ppocr_server_params.txt b/tests/configs/ppocr_sys_server_params.txt
similarity index 100%
rename from tests/ocr_ppocr_server_params.txt
rename to tests/configs/ppocr_sys_server_params.txt
diff --git a/tests/ocr_kl_quant_params.txt b/tests/ocr_kl_quant_params.txt
deleted file mode 100644
index c6ee97dca49bb7d942a339783af44053e6c79b00..0000000000000000000000000000000000000000
--- a/tests/ocr_kl_quant_params.txt
+++ /dev/null
@@ -1,51 +0,0 @@
-===========================train_params===========================
-model_name:ocr_system
-python:python3.7
-gpu_list:null
-Global.use_gpu:null
-Global.auto_cast:null
-Global.epoch_num:null
-Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:null
-Global.pretrained_model:null
-train_model_name:null
-train_infer_img_dir:null
-null:null
-##
-trainer:
-norm_train:null
-pact_train:null
-fpgm_train:null
-distill_train:null
-null:null
-null:null
-##
-===========================eval_params===========================
-eval:null
-null:null
-##
-===========================infer_params===========================
-Global.save_inference_dir:./output/
-Global.pretrained_model:
-norm_export:null
-quant_export:null
-fpgm_export:null
-distill_export:null
-export1:null
-export2:null
-##
-infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
-kl_quant:deploy/slim/quantization/quant_kl.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
-infer_quant:True
-inference:tools/infer/predict_det.py
---use_gpu:True|False
---enable_mkldnn:True|False
---cpu_threads:1|6
---rec_batch_num:1
---use_tensorrt:False|True
---precision:fp32|fp16|int8
---det_model_dir:
---image_dir:./inference/ch_det_data_50/all-sum-510/
---save_log_path:null
---benchmark:True
-null:null
diff --git a/tests/prepare.sh b/tests/prepare.sh
index ef021fa385f16ae5c9c996bfcb607f73b4129f49..abb84c881e52ca8076f218e926d41679b6578d09 100644
--- a/tests/prepare.sh
+++ b/tests/prepare.sh
@@ -1,7 +1,9 @@
 #!/bin/bash
 FILENAME=$1
-# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer', 'serving_infer']
+# MODE must be one of ['lite_train_infer', 'whole_infer', 'whole_train_infer', 'infer',
+# 'cpp_infer', 'serving_infer', 'klquant_infer']
+
 MODE=$2
 
 dataline=$(cat ${FILENAME})
@@ -72,9 +74,9 @@ elif [ ${MODE} = "infer" ];then
         wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
         cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
     elif [ ${model_name} = "ocr_server_det" ]; then
-        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar
+        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar
         wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
-        cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
+        cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
    elif [ ${model_name} = "ocr_system_mobile" ]; then
        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
@@ -98,6 +100,12 @@ elif [ ${MODE} = "infer" ];then
         wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar
         cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../
     fi
+elif [ ${MODE} = "klquant_infer" ];then
+    if [ ${model_name} = "ocr_det" ]; then
+        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+        wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
+        cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
+    fi
 elif [ ${MODE} = "cpp_infer" ];then
     if [ ${model_name} = "ocr_det" ]; then
         wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
@@ -128,77 +136,3 @@ if [ ${MODE} = "serving_infer" ];then
     wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar
     cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar
     cd ../
 fi
-
-if [ ${MODE} = "cpp_infer" ];then
-    cd deploy/cpp_infer
-    use_opencv=$(func_parser_value "${lines[52]}")
-    if [ ${use_opencv} = "True" ]; then
-        if [ -d "opencv-3.4.7/opencv3/" ] && [ $(md5sum opencv-3.4.7.tar.gz | awk -F ' ' '{print $1}') = "faa2b5950f8bee3f03118e600c74746a" ];then
-            echo "################### build opencv skipped ###################"
-        else
-            echo "################### build opencv ###################"
-            rm -rf opencv-3.4.7.tar.gz opencv-3.4.7/
-            wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/opencv-3.4.7.tar.gz
-            tar -xf opencv-3.4.7.tar.gz
-
-            cd opencv-3.4.7/
-            install_path=$(pwd)/opencv3
-
-            rm -rf build
-            mkdir build
-            cd build
-
-            cmake .. \
-                -DCMAKE_INSTALL_PREFIX=${install_path} \
-                -DCMAKE_BUILD_TYPE=Release \
-                -DBUILD_SHARED_LIBS=OFF \
-                -DWITH_IPP=OFF \
-                -DBUILD_IPP_IW=OFF \
-                -DWITH_LAPACK=OFF \
-                -DWITH_EIGEN=OFF \
-                -DCMAKE_INSTALL_LIBDIR=lib64 \
-                -DWITH_ZLIB=ON \
-                -DBUILD_ZLIB=ON \
-                -DWITH_JPEG=ON \
-                -DBUILD_JPEG=ON \
-                -DWITH_PNG=ON \
-                -DBUILD_PNG=ON \
-                -DWITH_TIFF=ON \
-                -DBUILD_TIFF=ON
-
-            make -j
-            make install
-            cd ../
-            echo "################### build opencv finished ###################"
-        fi
-    fi
-
-
-    echo "################### build PaddleOCR demo ####################"
-    if [ ${use_opencv} = "True" ]; then
-        OPENCV_DIR=$(pwd)/opencv-3.4.7/opencv3/
-    else
-        OPENCV_DIR=''
-    fi
-    LIB_DIR=$(pwd)/Paddle/build/paddle_inference_install_dir/
-    CUDA_LIB_DIR=$(dirname `find /usr -name libcudart.so`)
-    CUDNN_LIB_DIR=$(dirname `find /usr -name libcudnn.so`)
-
-    BUILD_DIR=build
-    rm -rf ${BUILD_DIR}
-    mkdir ${BUILD_DIR}
-    cd ${BUILD_DIR}
-    cmake .. \
-        -DPADDLE_LIB=${LIB_DIR} \
-        -DWITH_MKL=ON \
-        -DWITH_GPU=OFF \
-        -DWITH_STATIC_LIB=OFF \
-        -DWITH_TENSORRT=OFF \
-        -DOPENCV_DIR=${OPENCV_DIR} \
-        -DCUDNN_LIB=${CUDNN_LIB_DIR} \
-        -DCUDA_LIB=${CUDA_LIB_DIR} \
-        -DTENSORRT_DIR=${TENSORRT_DIR} \
-
-    make -j
-    echo "################### build PaddleOCR demo finished ###################"
-fi
diff --git a/tests/results/det_results_gpu_trt_fp16.txt b/tests/results/ppocr_det_mobile_results_fp16.txt
similarity index 100%
rename from tests/results/det_results_gpu_trt_fp16.txt
rename to tests/results/ppocr_det_mobile_results_fp16.txt
diff --git a/tests/results/det_results_gpu_trt_fp16_cpp.txt b/tests/results/ppocr_det_mobile_results_fp16_cpp.txt
similarity index 100%
rename from tests/results/det_results_gpu_trt_fp16_cpp.txt
rename to tests/results/ppocr_det_mobile_results_fp16_cpp.txt
diff --git a/tests/results/det_results_gpu_fp32.txt b/tests/results/ppocr_det_mobile_results_fp32.txt
similarity index 100%
rename from tests/results/det_results_gpu_fp32.txt
rename to tests/results/ppocr_det_mobile_results_fp32.txt
diff --git a/tests/results/det_results_gpu_trt_fp32_cpp.txt b/tests/results/ppocr_det_mobile_results_fp32_cpp.txt
similarity index 100%
rename from tests/results/det_results_gpu_trt_fp32_cpp.txt
rename to tests/results/ppocr_det_mobile_results_fp32_cpp.txt
diff --git a/tests/test.sh b/tests/test.sh
index 5649e344b76cf4485db533eee4035e1cbdd5adae..3df0d52cc5cfa6fd8d7259d47178d8c26d2952fb 100644
--- a/tests/test.sh
+++ b/tests/test.sh
@@ -1,9 +1,16 @@
 #!/bin/bash
 FILENAME=$1
-# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer']
+# MODE must be one of ['lite_train_infer', 'whole_infer', 'whole_train_infer', 'infer', 'cpp_infer', 'serving_infer', 'klquant_infer']
 MODE=$2
-
-dataline=$(cat ${FILENAME})
+if [ ${MODE} = "cpp_infer" ]; then
+    dataline=$(awk 'NR==67, NR==81{print}' $FILENAME)
+elif [ ${MODE} = "serving_infer" ]; then
+    dataline=$(awk 'NR==52, NR==66{print}' $FILENAME)
+elif [ ${MODE} = "klquant_infer" ]; then
+    dataline=$(awk 'NR==82, NR==98{print}' $FILENAME)
+else
+    dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
+fi
 
 # parser params
 IFS=$'\n'
@@ -144,61 +151,93 @@ benchmark_key=$(func_parser_key "${lines[49]}")
 benchmark_value=$(func_parser_value "${lines[49]}")
 infer_key1=$(func_parser_key "${lines[50]}")
 infer_value1=$(func_parser_value "${lines[50]}")
-# parser serving
-trans_model_py=$(func_parser_value "${lines[67]}")
-infer_model_dir_key=$(func_parser_key "${lines[68]}")
-infer_model_dir_value=$(func_parser_value "${lines[68]}")
-model_filename_key=$(func_parser_key "${lines[69]}")
-model_filename_value=$(func_parser_value "${lines[69]}")
-params_filename_key=$(func_parser_key "${lines[70]}")
-params_filename_value=$(func_parser_value "${lines[70]}")
-serving_server_key=$(func_parser_key "${lines[71]}")
-serving_server_value=$(func_parser_value "${lines[71]}")
-serving_client_key=$(func_parser_key "${lines[72]}")
-serving_client_value=$(func_parser_value "${lines[72]}")
-serving_dir_value=$(func_parser_value "${lines[73]}")
-web_service_py=$(func_parser_value "${lines[74]}")
-web_use_gpu_key=$(func_parser_key "${lines[75]}")
-web_use_gpu_list=$(func_parser_value "${lines[75]}")
-web_use_mkldnn_key=$(func_parser_key "${lines[76]}")
-web_use_mkldnn_list=$(func_parser_value "${lines[76]}")
-web_cpu_threads_key=$(func_parser_key "${lines[77]}")
-web_cpu_threads_list=$(func_parser_value "${lines[77]}")
-web_use_trt_key=$(func_parser_key "${lines[78]}")
-web_use_trt_list=$(func_parser_value "${lines[78]}")
-web_precision_key=$(func_parser_key "${lines[79]}")
-web_precision_list=$(func_parser_value "${lines[79]}")
-pipeline_py=$(func_parser_value "${lines[80]}")
+# parser klquant_infer params
+if [ ${MODE} = "klquant_infer" ]; then
+    # parser inference model
+    infer_model_dir_list=$(func_parser_value "${lines[1]}")
+    infer_export_list=$(func_parser_value "${lines[2]}")
+    infer_is_quant=$(func_parser_value "${lines[3]}")
+    # parser inference
+    inference_py=$(func_parser_value "${lines[4]}")
+    use_gpu_key=$(func_parser_key "${lines[5]}")
+    use_gpu_list=$(func_parser_value "${lines[5]}")
+    use_mkldnn_key=$(func_parser_key "${lines[6]}")
+    use_mkldnn_list=$(func_parser_value "${lines[6]}")
+    cpu_threads_key=$(func_parser_key "${lines[7]}")
+    cpu_threads_list=$(func_parser_value "${lines[7]}")
+    batch_size_key=$(func_parser_key "${lines[8]}")
+    batch_size_list=$(func_parser_value "${lines[8]}")
+    use_trt_key=$(func_parser_key "${lines[9]}")
+    use_trt_list=$(func_parser_value "${lines[9]}")
+    precision_key=$(func_parser_key "${lines[10]}")
+    precision_list=$(func_parser_value "${lines[10]}")
+    infer_model_key=$(func_parser_key "${lines[11]}")
+    image_dir_key=$(func_parser_key "${lines[12]}")
+    infer_img_dir=$(func_parser_value "${lines[12]}")
+    save_log_key=$(func_parser_key "${lines[13]}")
+    benchmark_key=$(func_parser_key "${lines[14]}")
+    benchmark_value=$(func_parser_value "${lines[14]}")
+    infer_key1=$(func_parser_key "${lines[15]}")
+    infer_value1=$(func_parser_value "${lines[15]}")
+fi
+# parser serving
+if [ ${MODE} = "serving_infer" ]; then
+    trans_model_py=$(func_parser_value "${lines[1]}")
+    infer_model_dir_key=$(func_parser_key "${lines[2]}")
+    infer_model_dir_value=$(func_parser_value "${lines[2]}")
+    model_filename_key=$(func_parser_key "${lines[3]}")
+    model_filename_value=$(func_parser_value "${lines[3]}")
+    params_filename_key=$(func_parser_key "${lines[4]}")
+    params_filename_value=$(func_parser_value "${lines[4]}")
+    serving_server_key=$(func_parser_key "${lines[5]}")
+    serving_server_value=$(func_parser_value "${lines[5]}")
+    serving_client_key=$(func_parser_key "${lines[6]}")
+    serving_client_value=$(func_parser_value "${lines[6]}")
+    serving_dir_value=$(func_parser_value "${lines[7]}")
+    web_service_py=$(func_parser_value "${lines[8]}")
+    web_use_gpu_key=$(func_parser_key "${lines[9]}")
+    web_use_gpu_list=$(func_parser_value "${lines[9]}")
+    web_use_mkldnn_key=$(func_parser_key "${lines[10]}")
+    web_use_mkldnn_list=$(func_parser_value "${lines[10]}")
+    web_cpu_threads_key=$(func_parser_key "${lines[11]}")
+    web_cpu_threads_list=$(func_parser_value "${lines[11]}")
+    web_use_trt_key=$(func_parser_key "${lines[12]}")
+    web_use_trt_list=$(func_parser_value "${lines[12]}")
+    web_precision_key=$(func_parser_key "${lines[13]}")
+    web_precision_list=$(func_parser_value "${lines[13]}")
+    pipeline_py=$(func_parser_value "${lines[14]}")
+fi
 if [ ${MODE} = "cpp_infer" ]; then
     # parser cpp inference model
-    cpp_infer_model_dir_list=$(func_parser_value "${lines[53]}")
-    cpp_infer_is_quant=$(func_parser_value "${lines[54]}")
+    cpp_infer_model_dir_list=$(func_parser_value "${lines[1]}")
+    cpp_infer_is_quant=$(func_parser_value "${lines[2]}")
     # parser cpp inference
-    inference_cmd=$(func_parser_value "${lines[55]}")
-    cpp_use_gpu_key=$(func_parser_key "${lines[56]}")
-    cpp_use_gpu_list=$(func_parser_value "${lines[56]}")
-    cpp_use_mkldnn_key=$(func_parser_key "${lines[57]}")
-    cpp_use_mkldnn_list=$(func_parser_value "${lines[57]}")
-    cpp_cpu_threads_key=$(func_parser_key "${lines[58]}")
-    cpp_cpu_threads_list=$(func_parser_value "${lines[58]}")
-    cpp_batch_size_key=$(func_parser_key "${lines[59]}")
-    cpp_batch_size_list=$(func_parser_value "${lines[59]}")
-    cpp_use_trt_key=$(func_parser_key "${lines[60]}")
-    cpp_use_trt_list=$(func_parser_value "${lines[60]}")
-    cpp_precision_key=$(func_parser_key "${lines[61]}")
-    cpp_precision_list=$(func_parser_value "${lines[61]}")
-    cpp_infer_model_key=$(func_parser_key "${lines[62]}")
-    cpp_image_dir_key=$(func_parser_key "${lines[63]}")
-    cpp_infer_img_dir=$(func_parser_value "${lines[63]}")
-    cpp_infer_key1=$(func_parser_key "${lines[64]}")
-    cpp_infer_value1=$(func_parser_value "${lines[64]}")
-    cpp_benchmark_key=$(func_parser_key "${lines[65]}")
-    cpp_benchmark_value=$(func_parser_value "${lines[65]}")
+    inference_cmd=$(func_parser_value "${lines[3]}")
+    cpp_use_gpu_key=$(func_parser_key "${lines[4]}")
+    cpp_use_gpu_list=$(func_parser_value "${lines[4]}")
+    cpp_use_mkldnn_key=$(func_parser_key "${lines[5]}")
+    cpp_use_mkldnn_list=$(func_parser_value "${lines[5]}")
+    cpp_cpu_threads_key=$(func_parser_key "${lines[6]}")
+    cpp_cpu_threads_list=$(func_parser_value "${lines[6]}")
+    cpp_batch_size_key=$(func_parser_key "${lines[7]}")
+    cpp_batch_size_list=$(func_parser_value "${lines[7]}")
+    cpp_use_trt_key=$(func_parser_key "${lines[8]}")
+    cpp_use_trt_list=$(func_parser_value "${lines[8]}")
+    cpp_precision_key=$(func_parser_key "${lines[9]}")
+    cpp_precision_list=$(func_parser_value "${lines[9]}")
+    cpp_infer_model_key=$(func_parser_key "${lines[10]}")
+    cpp_image_dir_key=$(func_parser_key "${lines[11]}")
+    cpp_infer_img_dir=$(func_parser_value "${lines[12]}")
+    cpp_infer_key1=$(func_parser_key "${lines[13]}")
+    cpp_infer_value1=$(func_parser_value "${lines[13]}")
+    cpp_benchmark_key=$(func_parser_key "${lines[14]}")
+    cpp_benchmark_value=$(func_parser_value "${lines[14]}")
 fi
+
 LOG_PATH="./tests/output"
 mkdir -p ${LOG_PATH}
 status_log="${LOG_PATH}/results.log"
@@ -414,7 +453,7 @@ function func_cpp_inference(){
     done
 }
 
-if [ ${MODE} = "infer" ]; then
+if [ ${MODE} = "infer" ] || [ ${MODE} = "klquant_infer" ]; then
     GPUID=$3
     if [ ${#GPUID} -le 0 ];then
         env=" "
@@ -447,7 +486,6 @@ if [ ${MODE} = "infer" ]; then
         func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
         Count=$(($Count + 1))
     done
-
 elif [ ${MODE} = "cpp_infer" ]; then
     GPUID=$3
     if [ ${#GPUID} -le 0 ];then
@@ -481,6 +519,8 @@ elif [ ${MODE} = "serving_infer" ]; then
     #run serving
     func_serving "${web_service_cmd}"
+
+
 else
     IFS="|"
     export Count=0
diff --git a/tests/test_cpp.sh b/tests/test_cpp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f755858cd0f781d7ef0de8089aefbee86ef83f82
--- /dev/null
+++ b/tests/test_cpp.sh
@@ -0,0 +1,204 @@
+#!/bin/bash
+source tests/common_func.sh
+
+FILENAME=$1
+dataline=$(awk 'NR==52, NR==66{print}' $FILENAME)
+
+# parser params
+IFS=$'\n'
+lines=(${dataline})
+
+# parser cpp inference model
+use_opencv=$(func_parser_value "${lines[1]}")
+cpp_infer_model_dir_list=$(func_parser_value "${lines[2]}")
+cpp_infer_is_quant=$(func_parser_value "${lines[3]}")
+# parser cpp inference
+inference_cmd=$(func_parser_value "${lines[4]}")
+cpp_use_gpu_key=$(func_parser_key "${lines[5]}")
+cpp_use_gpu_list=$(func_parser_value "${lines[5]}")
+cpp_use_mkldnn_key=$(func_parser_key "${lines[6]}")
+cpp_use_mkldnn_list=$(func_parser_value "${lines[6]}")
+cpp_cpu_threads_key=$(func_parser_key "${lines[7]}")
+cpp_cpu_threads_list=$(func_parser_value "${lines[7]}")
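+# NOTE (assumed config convention): each selected config line is "key:value",
+# e.g. "--cpu_threads:1|6"; func_parser_key yields the part before ":" and
+# func_parser_value the part after, and "|"-separated values become sweep
+# lists once IFS is set, e.g.:
+#   threads_list="1|6"; IFS='|'; for t in ${threads_list[*]}; do echo ${t}; done   # prints 1 then 6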
+cpp_batch_size_key=$(func_parser_key "${lines[8]}")
+cpp_batch_size_list=$(func_parser_value "${lines[8]}")
+cpp_use_trt_key=$(func_parser_key "${lines[9]}")
+cpp_use_trt_list=$(func_parser_value "${lines[9]}")
+cpp_precision_key=$(func_parser_key "${lines[10]}")
+cpp_precision_list=$(func_parser_value "${lines[10]}")
+cpp_infer_model_key=$(func_parser_key "${lines[11]}")
+cpp_image_dir_key=$(func_parser_key "${lines[12]}")
+cpp_infer_img_dir=$(func_parser_value "${lines[12]}")
+cpp_infer_key1=$(func_parser_key "${lines[13]}")
+cpp_infer_value1=$(func_parser_value "${lines[13]}")
+cpp_benchmark_key=$(func_parser_key "${lines[14]}")
+cpp_benchmark_value=$(func_parser_value "${lines[14]}")
+
+
+LOG_PATH="./tests/output"
+mkdir -p ${LOG_PATH}
+status_log="${LOG_PATH}/results_cpp.log"
+
+
+function func_cpp_inference(){
+    IFS='|'
+    _script=$1
+    _model_dir=$2
+    _log_path=$3
+    _img_dir=$4
+    _flag_quant=$5
+    # inference
+    for use_gpu in ${cpp_use_gpu_list[*]}; do
+        if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
+            for use_mkldnn in ${cpp_use_mkldnn_list[*]}; do
+                if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
+                    continue
+                fi
+                for threads in ${cpp_cpu_threads_list[*]}; do
+                    for batch_size in ${cpp_batch_size_list[*]}; do
+                        _save_log_path="${_log_path}/cpp_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
+                        set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}")
+                        set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}")
+                        set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}")
+                        set_cpu_threads=$(func_set_params "${cpp_cpu_threads_key}" "${threads}")
+                        set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
+                        set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
+                        command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${cpp_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+                        eval $command
+                        last_status=${PIPESTATUS[0]}
+                        eval "cat ${_save_log_path}"
+                        status_check $last_status "${command}" "${status_log}"
+                    done
+                done
+            done
+        elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
+            for use_trt in ${cpp_use_trt_list[*]}; do
+                for precision in ${cpp_precision_list[*]}; do
+                    if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
+                        continue
+                    fi
+                    if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
+                        continue
+                    fi
+                    if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
+                        continue
+                    fi
+                    for batch_size in ${cpp_batch_size_list[*]}; do
+                        _save_log_path="${_log_path}/cpp_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+                        set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}")
+                        set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}")
+                        set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}")
+                        set_tensorrt=$(func_set_params "${cpp_use_trt_key}" "${use_trt}")
+                        set_precision=$(func_set_params "${cpp_precision_key}" "${precision}")
+                        set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
+                        set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
+                        command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+                        eval $command
+                        last_status=${PIPESTATUS[0]}
+                        eval "cat ${_save_log_path}"
+                        status_check $last_status "${command}" "${status_log}"
+
+                    done
+                done
+            done
+        else
+            echo "Currently only CPU and GPU are supported!"
+        fi
+    done
+}
+
+
+cd deploy/cpp_infer
+if [ ${use_opencv} = "True" ]; then
+    if [ -d "opencv-3.4.7/opencv3/" ] && [ $(md5sum opencv-3.4.7.tar.gz | awk -F ' ' '{print $1}') = "faa2b5950f8bee3f03118e600c74746a" ];then
+        echo "################### build opencv skipped ###################"
+    else
+        echo "################### build opencv ###################"
+        rm -rf opencv-3.4.7.tar.gz opencv-3.4.7/
+        wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/opencv-3.4.7.tar.gz
+        tar -xf opencv-3.4.7.tar.gz
+
+        cd opencv-3.4.7/
+        install_path=$(pwd)/opencv3
+
+        rm -rf build
+        mkdir build
+        cd build
+
+        cmake .. \
+            -DCMAKE_INSTALL_PREFIX=${install_path} \
+            -DCMAKE_BUILD_TYPE=Release \
+            -DBUILD_SHARED_LIBS=OFF \
+            -DWITH_IPP=OFF \
+            -DBUILD_IPP_IW=OFF \
+            -DWITH_LAPACK=OFF \
+            -DWITH_EIGEN=OFF \
+            -DCMAKE_INSTALL_LIBDIR=lib64 \
+            -DWITH_ZLIB=ON \
+            -DBUILD_ZLIB=ON \
+            -DWITH_JPEG=ON \
+            -DBUILD_JPEG=ON \
+            -DWITH_PNG=ON \
+            -DBUILD_PNG=ON \
+            -DWITH_TIFF=ON \
+            -DBUILD_TIFF=ON
+
+        make -j
+        make install
+        cd ../
+        echo "################### build opencv finished ###################"
+    fi
+fi
+
+
+echo "################### build PaddleOCR demo ####################"
+if [ ${use_opencv} = "True" ]; then
+    OPENCV_DIR=$(pwd)/opencv-3.4.7/opencv3/
+else
+    OPENCV_DIR=''
+fi
+LIB_DIR=$(pwd)/Paddle/build/paddle_inference_install_dir/
+CUDA_LIB_DIR=$(dirname `find /usr -name libcudart.so`)
+CUDNN_LIB_DIR=$(dirname `find /usr -name libcudnn.so`)
+
+BUILD_DIR=build
+rm -rf ${BUILD_DIR}
+mkdir ${BUILD_DIR}
+cd ${BUILD_DIR}
+cmake .. \
+    -DPADDLE_LIB=${LIB_DIR} \
+    -DWITH_MKL=ON \
+    -DWITH_GPU=OFF \
+    -DWITH_STATIC_LIB=OFF \
+    -DWITH_TENSORRT=OFF \
+    -DOPENCV_DIR=${OPENCV_DIR} \
+    -DCUDNN_LIB=${CUDNN_LIB_DIR} \
+    -DCUDA_LIB=${CUDA_LIB_DIR} \
+    -DTENSORRT_DIR=${TENSORRT_DIR}
+
+make -j
+cd ../../../
+echo "################### build PaddleOCR demo finished ###################"
+
+
+# set cuda device
+GPUID=$2
+if [ ${#GPUID} -le 0 ];then
+    env=" "
+else
+    env="export CUDA_VISIBLE_DEVICES=${GPUID}"
+fi
+set CUDA_VISIBLE_DEVICES
+eval $env
+
+
+echo "################### run test ###################"
+export Count=0
+IFS="|"
+infer_quant_flag=(${cpp_infer_is_quant})
+for infer_model in ${cpp_infer_model_dir_list[*]}; do
+    #run inference
+    is_quant=${infer_quant_flag[Count]}
+    func_cpp_inference "${inference_cmd}" "${infer_model}" "${LOG_PATH}" "${cpp_infer_img_dir}" ${is_quant}
+    Count=$(($Count + 1))
+done
diff --git a/tests/test_python.sh b/tests/test_python.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2f5fdf4e8dc63a0ebb33c7880bc08c24be759740
--- /dev/null
+++ b/tests/test_python.sh
@@ -0,0 +1,173 @@
+#!/bin/bash
+source tests/common_func.sh
+
+FILENAME=$1
+dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
+
+# parser params
+IFS=$'\n'
+lines=(${dataline})
+
+# The training params
+model_name=$(func_parser_value "${lines[1]}")
+python=$(func_parser_value "${lines[2]}")
+gpu_list=$(func_parser_value "${lines[3]}")
+train_use_gpu_key=$(func_parser_key "${lines[4]}")
+train_use_gpu_value=$(func_parser_value "${lines[4]}")
+autocast_list=$(func_parser_value "${lines[5]}")
+autocast_key=$(func_parser_key "${lines[5]}")
+epoch_key=$(func_parser_key "${lines[6]}")
+epoch_num=$(func_parser_params "${lines[6]}")
+save_model_key=$(func_parser_key "${lines[7]}")
+train_batch_key=$(func_parser_key "${lines[8]}")
+train_batch_value=$(func_parser_params "${lines[8]}")
+pretrain_model_key=$(func_parser_key "${lines[9]}")
+pretrain_model_value=$(func_parser_value "${lines[9]}")
+train_model_name=$(func_parser_value "${lines[10]}")
+train_infer_img_dir=$(func_parser_value "${lines[11]}")
+train_param_key1=$(func_parser_key "${lines[12]}")
+train_param_value1=$(func_parser_value "${lines[12]}")
+
+trainer_list=$(func_parser_value "${lines[14]}")
+trainer_norm=$(func_parser_key "${lines[15]}")
+norm_trainer=$(func_parser_value "${lines[15]}")
+pact_key=$(func_parser_key "${lines[16]}")
+pact_trainer=$(func_parser_value "${lines[16]}")
+fpgm_key=$(func_parser_key "${lines[17]}")
+fpgm_trainer=$(func_parser_value "${lines[17]}")
+distill_key=$(func_parser_key "${lines[18]}")
+distill_trainer=$(func_parser_value "${lines[18]}")
+trainer_key1=$(func_parser_key "${lines[19]}")
+trainer_value1=$(func_parser_value "${lines[19]}")
+trainer_key2=$(func_parser_key "${lines[20]}")
+trainer_value2=$(func_parser_value "${lines[20]}")
+
+eval_py=$(func_parser_value "${lines[23]}")
+eval_key1=$(func_parser_key "${lines[24]}")
+eval_value1=$(func_parser_value "${lines[24]}")
+
+save_infer_key=$(func_parser_key "${lines[27]}")
+export_weight=$(func_parser_key "${lines[28]}")
+norm_export=$(func_parser_value "${lines[29]}")
+pact_export=$(func_parser_value "${lines[30]}")
+fpgm_export=$(func_parser_value "${lines[31]}")
+distill_export=$(func_parser_value "${lines[32]}")
+export_key1=$(func_parser_key "${lines[33]}")
+export_value1=$(func_parser_value "${lines[33]}")
+export_key2=$(func_parser_key "${lines[34]}")
+export_value2=$(func_parser_value "${lines[34]}")
+
+# parser inference model
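+# (assumed layout) lines[36]-[50] of the 51-line block read by the awk call
+# above carry the inference model paths and predictor flags parsed below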
+infer_model_dir_list=$(func_parser_value "${lines[36]}")
+infer_export_list=$(func_parser_value "${lines[37]}")
+infer_is_quant=$(func_parser_value "${lines[38]}")
+# parser inference
+inference_py=$(func_parser_value "${lines[39]}")
+use_gpu_key=$(func_parser_key "${lines[40]}")
+use_gpu_list=$(func_parser_value "${lines[40]}")
+use_mkldnn_key=$(func_parser_key "${lines[41]}")
+use_mkldnn_list=$(func_parser_value "${lines[41]}")
+cpu_threads_key=$(func_parser_key "${lines[42]}")
+cpu_threads_list=$(func_parser_value "${lines[42]}")
+batch_size_key=$(func_parser_key "${lines[43]}")
+batch_size_list=$(func_parser_value "${lines[43]}")
+use_trt_key=$(func_parser_key "${lines[44]}")
+use_trt_list=$(func_parser_value "${lines[44]}")
+precision_key=$(func_parser_key "${lines[45]}")
+precision_list=$(func_parser_value "${lines[45]}")
+infer_model_key=$(func_parser_key "${lines[46]}")
+image_dir_key=$(func_parser_key "${lines[47]}")
+infer_img_dir=$(func_parser_value "${lines[47]}")
+save_log_key=$(func_parser_key "${lines[48]}")
+benchmark_key=$(func_parser_key "${lines[49]}")
+benchmark_value=$(func_parser_value "${lines[49]}")
+infer_key1=$(func_parser_key "${lines[50]}")
+infer_value1=$(func_parser_value "${lines[50]}")
+
+
+LOG_PATH="./tests/output"
+mkdir -p ${LOG_PATH}
+status_log="${LOG_PATH}/results_python.log"
+
+
+function func_inference(){
+    IFS='|'
+    _python=$1
+    _script=$2
+    _model_dir=$3
+    _log_path=$4
+    _img_dir=$5
+    _flag_quant=$6
+    # inference
+    for use_gpu in ${use_gpu_list[*]}; do
+        if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
+            for use_mkldnn in ${use_mkldnn_list[*]}; do
+                if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
+                    continue
+                fi
+                for threads in ${cpu_threads_list[*]}; do
+                    for batch_size in ${batch_size_list[*]}; do
+                        _save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
+                        set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
+                        set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+                        set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+                        set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
+                        set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+                        set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+                        command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+                        eval $command
+                        last_status=${PIPESTATUS[0]}
+                        eval "cat ${_save_log_path}"
+                        status_check $last_status "${command}" "${status_log}"
+                    done
+                done
+            done
+        elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
+            for use_trt in ${use_trt_list[*]}; do
+                for precision in ${precision_list[*]}; do
+                    if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
+                        continue
+                    fi
+                    if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
+                        continue
+                    fi
+                    if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
+                        continue
+                    fi
+                    for batch_size in ${batch_size_list[*]}; do
+                        _save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+                        set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
+                        set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+                        set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+                        set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
+                        set_precision=$(func_set_params "${precision_key}" "${precision}")
+                        set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+                        set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+                        command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+                        eval $command
+                        last_status=${PIPESTATUS[0]}
+                        eval "cat ${_save_log_path}"
+                        status_check $last_status "${command}" "${status_log}"
+
+                    done
+                done
+            done
+        else
+            echo "Currently only CPU and GPU are supported!"
+        fi
+    done
+}
+
+
+# set cuda device
+GPUID=$2
+if [ ${#GPUID} -le 0 ];then
+    env=" "
+else
+    env="export CUDA_VISIBLE_DEVICES=${GPUID}"
+fi
+set CUDA_VISIBLE_DEVICES
+eval $env
+
+
+echo "################### run test ###################"
diff --git a/tests/test_serving.sh b/tests/test_serving.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8998ee7ee4c7f14c1a7df86611b709647c0d1a05
--- /dev/null
+++ b/tests/test_serving.sh
@@ -0,0 +1,131 @@
+#!/bin/bash
+source tests/common_func.sh
+
+FILENAME=$1
+dataline=$(awk 'NR==67, NR==81{print}' $FILENAME)
+
+# parser params
+IFS=$'\n'
+lines=(${dataline})
+
+# parser serving
+trans_model_py=$(func_parser_value "${lines[1]}")
+infer_model_dir_key=$(func_parser_key "${lines[2]}")
+infer_model_dir_value=$(func_parser_value "${lines[2]}")
+model_filename_key=$(func_parser_key "${lines[3]}")
+model_filename_value=$(func_parser_value "${lines[3]}")
+params_filename_key=$(func_parser_key "${lines[4]}")
+params_filename_value=$(func_parser_value "${lines[4]}")
+serving_server_key=$(func_parser_key "${lines[5]}")
+serving_server_value=$(func_parser_value "${lines[5]}")
+serving_client_key=$(func_parser_key "${lines[6]}")
+serving_client_value=$(func_parser_value "${lines[6]}")
+serving_dir_value=$(func_parser_value "${lines[7]}")
+web_service_py=$(func_parser_value "${lines[8]}")
+web_use_gpu_key=$(func_parser_key "${lines[9]}")
+web_use_gpu_list=$(func_parser_value "${lines[9]}")
+web_use_mkldnn_key=$(func_parser_key "${lines[10]}")
+web_use_mkldnn_list=$(func_parser_value "${lines[10]}")
+web_cpu_threads_key=$(func_parser_key "${lines[11]}")
+web_cpu_threads_list=$(func_parser_value "${lines[11]}")
+web_use_trt_key=$(func_parser_key "${lines[12]}")
+web_use_trt_list=$(func_parser_value "${lines[12]}")
+web_precision_key=$(func_parser_key "${lines[13]}")
+web_precision_list=$(func_parser_value "${lines[13]}")
+pipeline_py=$(func_parser_value "${lines[14]}")
+
+
+LOG_PATH="./tests/output"
+mkdir -p ${LOG_PATH}
+status_log="${LOG_PATH}/results_serving.log"
+
+
+function func_serving(){
+    IFS='|'
+    _python=$1
+    _script=$2
+    _model_dir=$3
+    # the serving config carries no quant flag; assume a non-quantized fp32 model
+    _flag_quant="False"
+    # pdserving
+    set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}")
+    set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
+    set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
+    set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
+    set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
+    trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
+    eval $trans_model_cmd
+    cd ${serving_dir_value}
+    echo $PWD
+    unset https_proxy
+    unset http_proxy
+    for use_gpu in ${web_use_gpu_list[*]}; do
+        echo ${use_gpu}
+        if [ ${use_gpu} = "null" ]; then
+            for use_mkldnn in ${web_use_mkldnn_list[*]}; do
+                if [ ${use_mkldnn} = "False" ]; then
+                    continue
+                fi
+                for threads in ${web_cpu_threads_list[*]}; do
+                    _save_log_path="${LOG_PATH}/server_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_1.log"
+                    set_cpu_threads=$(func_set_params "${web_cpu_threads_key}" "${threads}")
+                    web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} &>${_save_log_path} &"
+                    eval $web_service_cmd
+                    sleep 2s
+                    pipeline_cmd="${python} ${pipeline_py}"
+                    eval $pipeline_cmd
+                    last_status=${PIPESTATUS[0]}
+                    eval "cat ${_save_log_path}"
+                    status_check $last_status "${pipeline_cmd}" "${status_log}"
+                    PID=$!
+                    kill $PID
+                    sleep 2s
+                    ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
+                done
+            done
+        elif [ ${use_gpu} = "0" ]; then
+            for use_trt in ${web_use_trt_list[*]}; do
+                for precision in ${web_precision_list[*]}; do
+                    if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
+                        continue
+                    fi
+                    if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
+                        continue
+                    fi
+                    if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
+                        continue
+                    fi
+                    _save_log_path="${LOG_PATH}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_1.log"
+                    set_tensorrt=$(func_set_params "${web_use_trt_key}" "${use_trt}")
+                    set_precision=$(func_set_params "${web_precision_key}" "${precision}")
+                    web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} &>${_save_log_path} & "
+                    eval $web_service_cmd
+                    sleep 2s
+                    pipeline_cmd="${python} ${pipeline_py}"
+                    eval $pipeline_cmd
+                    last_status=${PIPESTATUS[0]}
+                    eval "cat ${_save_log_path}"
+                    status_check $last_status "${pipeline_cmd}" "${status_log}"
+                    PID=$!
+                    kill $PID
+                    sleep 2s
+                    ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
+                done
+            done
+        else
+            echo "Currently only CPU and GPU are supported!"
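+            # web_use_gpu_list is parsed from the config's --use_gpu line; only
+            # "null" (CPU) and "0" (GPU card 0) are exercised by the branches above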
+        fi
+    done
+}
+
+
+# set cuda device
+GPUID=$2
+if [ ${#GPUID} -le 0 ];then
+    env=" "
+else
+    env="export CUDA_VISIBLE_DEVICES=${GPUID}"
+fi
+set CUDA_VISIBLE_DEVICES
+eval $env
+
+
+echo "################### run test ###################"
diff --git a/tools/eval.py b/tools/eval.py
index 39a26ffefff46a9a3fe1465e874d501a334921c7..28247bc57450aaf067fcb405674098eacb990166 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -54,8 +54,7 @@ def main():
         config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])
 
-    use_srn = config['Architecture']['algorithm'] == "SRN"
-    use_sar = config['Architecture']['algorithm'] == "SAR"
+    extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
     if "model_type" in config['Architecture'].keys():
         model_type = config['Architecture']['model_type']
     else:
@@ -72,7 +71,7 @@ def main():
     # start eval
     metric = program.eval(model, valid_dataloader, post_process_class,
-                          eval_class, model_type, use_srn, use_sar)
+                          eval_class, model_type, extra_input)
     logger.info('metric eval ***************')
     for k, v in metric.items():
         logger.info('{}:{}'.format(k, v))
diff --git a/tools/export_center.py b/tools/export_center.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46e8b9d58997b9b66c6ce81b2558ecd4cad0e81
--- /dev/null
+++ b/tools/export_center.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
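+
+# This script runs the trained recognizer over the training set, averages the
+# head features of every correctly recognized character into a per-character
+# "center", and pickles the mapping to train_center.pkl (see update_center /
+# get_center in tools/program.py).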
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import pickle
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+from ppocr.data import build_dataloader
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import init_model, load_dygraph_params
+from ppocr.utils.utility import print_dict
+import tools.program as program
+
+
+def main():
+    global_config = config['Global']
+    # build dataloader; point the Eval loader at the training data so the
+    # centers are computed over the train set
+    config['Eval']['dataset']['name'] = config['Train']['dataset']['name']
+    config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][
+        'data_dir']
+    config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
+        'label_file_list']
+    eval_dataloader = build_dataloader(config, 'Eval', device, logger)
+
+    # build post process
+    post_process_class = build_post_process(config['PostProcess'],
+                                            global_config)
+
+    # build model
+    # for rec algorithm
+    if hasattr(post_process_class, 'character'):
+        char_num = len(getattr(post_process_class, 'character'))
+        config['Architecture']["Head"]['out_channels'] = char_num
+
+    # set return_feats = True so the head returns features next to logits
+    config['Architecture']["Head"]["return_feats"] = True
+
+    model = build_model(config['Architecture'])
+
+    best_model_dict = load_dygraph_params(config, model, logger, None)
+    if len(best_model_dict):
+        logger.info('metric in ckpt ***************')
+        for k, v in best_model_dict.items():
+            logger.info('{}:{}'.format(k, v))
+
+    # get features from train data
+    char_center = program.get_center(model, eval_dataloader, post_process_class)
+
+    # serialize to disk
+    with open("train_center.pkl", 'wb') as f:
+        pickle.dump(char_center, f)
+    return
+
+
+if __name__ == '__main__':
+    config, device, logger, vdl_writer = program.preprocess()
+    main()
diff --git a/tools/export_model.py b/tools/export_model.py
index d8fe297235b2f5de6861d387cff64e8737cd30c0..64a0d4036303716a632eb93c53f2478f32b42848 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
             ]
         ]
         model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "SAR":
+        other_shape = [
+            paddle.static.InputSpec(
+                shape=[None, 3, 48, 160], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
     else:
         infer_shape = [3, -1, -1]
         if arch_config["model_type"] == "rec":
diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py
index 8ff279d7437965f725082a9eb1c83e05a7ffc8a8..5029d6059346a00062418d8d1b6cb029b0110643 100755
--- a/tools/infer/predict_e2e.py
+++ b/tools/infer/predict_e2e.py
@@ -141,7 +141,6 @@ if __name__ == "__main__":
         img, flag = check_and_read_gif(image_file)
         if not flag:
             img = cv2.imread(image_file)
-            img = img[:, :, ::-1]
         if img is None:
             logger.info("error in loading image:{}".format(image_file))
             continue
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 332cffd5395f8f511089b0bfde762820af7bbe8c..dad70281ef7604f110d29963103068bba1c8fd9d 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -68,6 +68,13 @@ class TextRecognizer(object):
                 "character_dict_path": args.rec_char_dict_path,
                 "use_space_char": args.use_space_char
             }
+        elif self.rec_algorithm == "SAR":
+            postprocess_params = {
+                'name': 'SARLabelDecode',
+                "character_type": args.rec_char_type,
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
@@ -194,6 +201,41 @@ class TextRecognizer(object):
         return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                 gsrm_slf_attn_bias2)
 
+    def resize_norm_img_sar(self, img, image_shape,
+                            width_downsample_ratio=0.25):
+        imgC, imgH, imgW_min, imgW_max = image_shape
+        h = img.shape[0]
+        w = img.shape[1]
+        valid_ratio = 1.0
+        # make sure new_width is an integral multiple of width_divisor.
+        width_divisor = int(1 / width_downsample_ratio)
+        # resize
+        ratio = w / float(h)
+        resize_w = math.ceil(imgH * ratio)
+        if resize_w % width_divisor != 0:
+            resize_w = round(resize_w / width_divisor) * width_divisor
+        if imgW_min is not None:
+            resize_w = max(imgW_min, resize_w)
+        if imgW_max is not None:
+            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+            resize_w = min(imgW_max, resize_w)
+        resized_image = cv2.resize(img, (resize_w, imgH))
+        resized_image = resized_image.astype('float32')
+        # norm
+        if image_shape[0] == 1:
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+        else:
+            resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        resize_shape = resized_image.shape
+        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+        padding_im[:, :, 0:resize_w] = resized_image
+        pad_shape = padding_im.shape
+
+        return padding_im, resize_shape, pad_shape, valid_ratio
+
     def __call__(self, img_list):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
@@ -216,11 +258,19 @@ class TextRecognizer(object):
             wh_ratio = w * 1.0 / h
             max_wh_ratio = max(max_wh_ratio, wh_ratio)
+        # collect one valid_ratio per image in the batch for SAR padding
+        valid_ratios = []
         for ino in range(beg_img_no, end_img_no):
-            if self.rec_algorithm != "SRN":
+            if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR":
                 norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                 max_wh_ratio)
                 norm_img = norm_img[np.newaxis, :]
                 norm_img_batch.append(norm_img)
+            elif self.rec_algorithm == "SAR":
+                norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+                    img_list[indices[ino]], self.rec_image_shape)
+                norm_img = norm_img[np.newaxis, :]
+                valid_ratio = np.expand_dims(valid_ratio, axis=0)
+                valid_ratios.append(valid_ratio)
+                norm_img_batch.append(norm_img)
             else:
                 norm_img = self.process_image_srn(
                     img_list[indices[ino]], self.rec_image_shape, 8, 25)
@@ -266,6 +316,25 @@ class TextRecognizer(object):
             if self.benchmark:
                 self.autolog.times.stamp()
             preds = {"predict": outputs[2]}
+        elif self.rec_algorithm == "SAR":
+            valid_ratios = np.concatenate(valid_ratios)
+            inputs = [
+                norm_img_batch,
+                valid_ratios,
+            ]
+            input_names = self.predictor.get_input_names()
+            for i in range(len(input_names)):
+                input_tensor = self.predictor.get_input_handle(input_names[
+                    i])
+                input_tensor.copy_from_cpu(inputs[i])
+            self.predictor.run()
+            outputs = []
+            for output_tensor in self.output_tensors:
+                output = output_tensor.copy_to_cpu()
+                outputs.append(output)
+            if self.benchmark:
+                self.autolog.times.stamp()
+            preds = outputs[0]
         else:
             self.input_tensor.copy_from_cpu(norm_img_batch)
             self.predictor.run()
diff --git a/tools/program.py b/tools/program.py
index f484cf4a1f512107bd755a6b446935751563bfcb..798e6dff297ad1149942488cca1d5540f1924867 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -31,6 +31,7 @@
 from ppocr.utils.stats import TrainingStats
 from ppocr.utils.save_load import save_model
 from ppocr.utils.utility import print_dict
 from ppocr.utils.logging import get_logger
+from ppocr.utils import profiler
 from ppocr.data import build_dataloader
 
 import numpy as np
@@ -42,6 +43,13 @@ class ArgsParser(ArgumentParser):
         self.add_argument("-c", "--config", help="configuration file to use")
         self.add_argument(
             "-o", "--opt", nargs='+', help="set configuration options")
+        self.add_argument(
+            '-p',
+            '--profiler_options',
+            type=str,
+            default=None,
+            help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
+        )
 
     def parse_args(self, argv=None):
         args = super(ArgsParser, self).parse_args(argv)
@@ -158,6 +166,7 @@ def train(config,
     epoch_num = config['Global']['epoch_num']
     print_batch_step = config['Global']['print_batch_step']
     eval_batch_step = config['Global']['eval_batch_step']
+    profiler_options = config['profiler_options']
 
     global_step = 0
     if 'global_step' in pre_best_model_dict:
@@ -186,12 +195,13 @@
     model.train()
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
-    use_nrtr = config['Architecture']['algorithm'] == "NRTR"
-    use_sar = config['Architecture']['algorithm'] == 'SAR'
+    extra_input = config['Architecture'][
+        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
     try:
         model_type = config['Architecture']['model_type']
     except:
         model_type = None
+    algorithm = config['Architecture']['algorithm']
 
     if 'start_epoch' in best_model_dict:
         start_epoch = best_model_dict['start_epoch']
@@ -208,6 +218,7 @@ def train(config,
         max_iter = len(train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(train_dataloader)
         for idx, batch in enumerate(train_dataloader):
+            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - batch_start
            if idx >= max_iter:
                break
@@ -215,7 +226,7 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-            if use_srn or model_type == 'table' or use_nrtr or use_sar:
+            if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
@@ -279,8 +290,7 @@ def train(config,
                                   post_process_class,
                                   eval_class,
                                   model_type,
-                                  use_srn=use_srn,
-                                  use_sar=use_sar)
+                                  extra_input=extra_input)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                 logger.info(cur_metric_str)
@@ -351,9 +361,8 @@
 def eval(model,
          valid_dataloader,
          post_process_class,
         eval_class,
-         model_type,
-         use_srn=False,
-         use_sar=False):
+         model_type=None,
+         extra_input=False):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
@@ -366,7 +375,7 @@ def eval(model,
                 break
             images = batch[0]
             start = time.time()
-            if use_srn or model_type == 'table' or use_sar:
+            if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
@@ -390,10 +399,76 @@ def eval(model,
     return metric
 
 
+def update_center(char_center, post_result, preds):
+    result, label = post_result
+    feats, logits = preds
+    logits = paddle.argmax(logits, axis=-1)
+    feats = feats.numpy()
+    logits = logits.numpy()
+
+    for idx_sample in range(len(label)):
+        if result[idx_sample][0] == label[idx_sample][0]:
+            feat = feats[idx_sample]
+            logit = logits[idx_sample]
+            for idx_time in range(len(logit)):
+                index = logit[idx_time]
+                if index in char_center.keys():
+                    # running mean of the feature for this character class
+                    char_center[index][0] = (
+                        char_center[index][0] * char_center[index][1] +
+                        feat[idx_time]) / (char_center[index][1] + 1)
+                    char_center[index][1] += 1
+                else:
+                    char_center[index] = [feat[idx_time], 1]
+    return char_center
+
+
+def get_center(model, eval_dataloader, post_process_class):
+    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
+    max_iter = len(eval_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(eval_dataloader)
+    char_center = dict()
+    for idx, batch in enumerate(eval_dataloader):
+        if idx >= max_iter:
+            break
+        images = batch[0]
+        preds = model(images)
+
+        batch = [item.numpy() for item in batch]
+        # Obtain usable results from post-processing methods
+        post_result = post_process_class(preds, batch[1])
+
+        # update char_center with the current batch
+        char_center = update_center(char_center, post_result, preds)
+        pbar.update(1)
+
+    pbar.close()
+    for key in char_center.keys():
+        char_center[key] = char_center[key][0]
+    return char_center
+
+
 def preprocess(is_train=False):
     FLAGS = ArgsParser().parse_args()
+    profiler_options = FLAGS.profiler_options
     config = load_config(FLAGS.config)
     merge_config(FLAGS.opt)
+    profile_dic = {"profiler_options": FLAGS.profiler_options}
+    merge_config(profile_dic)
+
+    if is_train:
+        # save_config
+        save_model_dir = config['Global']['save_model_dir']
+        os.makedirs(save_model_dir, exist_ok=True)
+        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
+            yaml.dump(
+                dict(config), f, default_flow_style=False, sort_keys=False)
+        log_file = '{}/train.log'.format(save_model_dir)
+    else:
+        log_file = None
+    logger = get_logger(name='root', log_file=log_file)
 
     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
@@ -402,24 +477,20 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
-        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE'
+        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
+        'SEED'
     ]
+    windows_not_support_list = ['PSE']
+    if platform.system() == "Windows" and alg in windows_not_support_list:
+        logger.warning('{} is not supported on Windows now'.format(
+            windows_not_support_list))
+        sys.exit()
 
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
     device = paddle.set_device(device)
 
     config['Global']['distributed'] = dist.get_world_size() != 1
-    if is_train:
-        # save_config
-        save_model_dir = config['Global']['save_model_dir']
-        os.makedirs(save_model_dir, exist_ok=True)
-        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
-            yaml.dump(
-                dict(config), f, default_flow_style=False, sort_keys=False)
-        log_file = '{}/train.log'.format(save_model_dir)
-    else:
-        log_file = None
-    logger = get_logger(name='root', log_file=log_file)
+
     if config['Global']['use_visualdl']:
         from visualdl import LogWriter
         save_model_dir = config['Global']['save_model_dir']