未验证 提交 b38f3534 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #2921 from LDOUBLEV/bm_trt_dyg

suport TensorRT inference
...@@ -30,6 +30,42 @@ void DBDetector::LoadModel(const std::string &model_dir) { ...@@ -30,6 +30,42 @@ void DBDetector::LoadModel(const std::string &model_dir) {
this->use_fp16_ ? paddle_infer::Config::Precision::kHalf this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
: paddle_infer::Config::Precision::kFloat32, : paddle_infer::Config::Precision::kFloat32,
false, false); false, false);
std::map<std::string, std::vector<int>> min_input_shape = {
{"x", {1, 3, 50, 50}},
{"conv2d_92.tmp_0", {1, 96, 20, 20}},
{"conv2d_91.tmp_0", {1, 96, 10, 10}},
{"nearest_interp_v2_1.tmp_0", {1, 96, 10, 10}},
{"nearest_interp_v2_2.tmp_0", {1, 96, 20, 20}},
{"nearest_interp_v2_3.tmp_0", {1, 24, 20, 20}},
{"nearest_interp_v2_4.tmp_0", {1, 24, 20, 20}},
{"nearest_interp_v2_5.tmp_0", {1, 24, 20, 20}},
{"elementwise_add_7", {1, 56, 2, 2}},
{"nearest_interp_v2_0.tmp_0", {1, 96, 2, 2}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"x", {1, 3, this->max_side_len_, this->max_side_len_}},
{"conv2d_92.tmp_0", {1, 96, 400, 400}},
{"conv2d_91.tmp_0", {1, 96, 200, 200}},
{"nearest_interp_v2_1.tmp_0", {1, 96, 200, 200}},
{"nearest_interp_v2_2.tmp_0", {1, 96, 400, 400}},
{"nearest_interp_v2_3.tmp_0", {1, 24, 400, 400}},
{"nearest_interp_v2_4.tmp_0", {1, 24, 400, 400}},
{"nearest_interp_v2_5.tmp_0", {1, 24, 400, 400}},
{"elementwise_add_7", {1, 56, 400, 400}},
{"nearest_interp_v2_0.tmp_0", {1, 96, 400, 400}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"x", {1, 3, 640, 640}},
{"conv2d_92.tmp_0", {1, 96, 160, 160}},
{"conv2d_91.tmp_0", {1, 96, 80, 80}},
{"nearest_interp_v2_1.tmp_0", {1, 96, 80, 80}},
{"nearest_interp_v2_2.tmp_0", {1, 96, 160, 160}},
{"nearest_interp_v2_3.tmp_0", {1, 24, 160, 160}},
{"nearest_interp_v2_4.tmp_0", {1, 24, 160, 160}},
{"nearest_interp_v2_5.tmp_0", {1, 24, 160, 160}},
{"elementwise_add_7", {1, 56, 40, 40}},
{"nearest_interp_v2_0.tmp_0", {1, 96, 40, 40}}};
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
} }
} else { } else {
config.DisableGpu(); config.DisableGpu();
...@@ -48,7 +84,7 @@ void DBDetector::LoadModel(const std::string &model_dir) { ...@@ -48,7 +84,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
config.SwitchIrOptim(true); config.SwitchIrOptim(true);
config.EnableMemoryOptim(); config.EnableMemoryOptim();
config.DisableGlogInfo(); // config.DisableGlogInfo();
this->predictor_ = CreatePredictor(config); this->predictor_ = CreatePredictor(config);
} }
......
...@@ -106,6 +106,15 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { ...@@ -106,6 +106,15 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
this->use_fp16_ ? paddle_infer::Config::Precision::kHalf this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
: paddle_infer::Config::Precision::kFloat32, : paddle_infer::Config::Precision::kFloat32,
false, false); false, false);
std::map<std::string, std::vector<int>> min_input_shape = {
{"x", {1, 3, 32, 10}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"x", {1, 3, 32, 2000}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"x", {1, 3, 32, 320}}};
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
} }
} else { } else {
config.DisableGpu(); config.DisableGpu();
......
...@@ -77,19 +77,13 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, ...@@ -77,19 +77,13 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_h = int(float(h) * ratio); int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio); int resize_w = int(float(w) * ratio);
resize_h = max(int(round(float(resize_h) / 32) * 32), 32); resize_h = max(int(round(float(resize_h) / 32) * 32), 32);
resize_w = max(int(round(float(resize_w) / 32) * 32), 32); resize_w = max(int(round(float(resize_w) / 32) * 32), 32);
if (!use_tensorrt) { cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); ratio_h = float(resize_h) / float(h);
ratio_h = float(resize_h) / float(h); ratio_w = float(resize_w) / float(w);
ratio_w = float(resize_w) / float(w);
} else {
cv::resize(img, resize_img, cv::Size(640, 640));
ratio_h = float(640) / float(h);
ratio_w = float(640) / float(w);
}
} }
void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
...@@ -108,23 +102,12 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, ...@@ -108,23 +102,12 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
resize_w = imgW; resize_w = imgW;
else else
resize_w = int(ceilf(imgH * ratio)); resize_w = int(ceilf(imgH * ratio));
if (!use_tensorrt) {
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR); cv::INTER_LINEAR);
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
int(imgW - resize_img.cols), cv::BORDER_CONSTANT, int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
{127, 127, 127}); {127, 127, 127});
} else {
int k = int(img.cols * 32 / img.rows);
if (k >= 100) {
cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f,
cv::INTER_LINEAR);
} else {
cv::resize(img, resize_img, cv::Size(k, 32), 0.f, 0.f, cv::INTER_LINEAR);
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(100 - k),
cv::BORDER_CONSTANT, {127, 127, 127});
}
}
} }
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
...@@ -142,15 +125,11 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, ...@@ -142,15 +125,11 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
else else
resize_w = int(ceilf(imgH * ratio)); resize_w = int(ceilf(imgH * ratio));
if (!use_tensorrt) { cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, cv::INTER_LINEAR);
cv::INTER_LINEAR); if (resize_w < imgW) {
if (resize_w < imgW) { cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w, cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
} else {
cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f, cv::INTER_LINEAR);
} }
} }
......
...@@ -12,7 +12,7 @@ cmake .. \ ...@@ -12,7 +12,7 @@ cmake .. \
-DWITH_MKL=ON \ -DWITH_MKL=ON \
-DWITH_GPU=OFF \ -DWITH_GPU=OFF \
-DWITH_STATIC_LIB=OFF \ -DWITH_STATIC_LIB=OFF \
-DUSE_TENSORRT=OFF \ -DWITH_TENSORRT=OFF \
-DOPENCV_DIR=${OPENCV_DIR} \ -DOPENCV_DIR=${OPENCV_DIR} \
-DCUDNN_LIB=${CUDNN_LIB_DIR} \ -DCUDNN_LIB=${CUDNN_LIB_DIR} \
-DCUDA_LIB=${CUDA_LIB_DIR} \ -DCUDA_LIB=${CUDA_LIB_DIR} \
......
...@@ -21,6 +21,9 @@ import json ...@@ -21,6 +21,9 @@ import json
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import math import math
from paddle import inference from paddle import inference
import time
from ppocr.utils.logging import get_logger
logger = get_logger()
def parse_args(): def parse_args():
...@@ -98,6 +101,7 @@ def parse_args(): ...@@ -98,6 +101,7 @@ def parse_args():
parser.add_argument("--cls_thresh", type=float, default=0.9) parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--cpu_threads", type=int, default=10)
parser.add_argument("--use_pdserving", type=str2bool, default=False) parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--use_mp", type=str2bool, default=False)
...@@ -135,19 +139,97 @@ def create_predictor(args, mode, logger): ...@@ -135,19 +139,97 @@ def create_predictor(args, mode, logger):
config.enable_use_gpu(args.gpu_mem, 0) config.enable_use_gpu(args.gpu_mem, 0)
if args.use_tensorrt: if args.use_tensorrt:
config.enable_tensorrt_engine( config.enable_tensorrt_engine(
precision_mode=inference.PrecisionType.Half precision_mode=inference.PrecisionType.Float32,
if args.use_fp16 else inference.PrecisionType.Float32, max_batch_size=args.max_batch_size,
max_batch_size=args.max_batch_size) min_subgraph_size=3) # skip the minmum trt subgraph
if mode == "det" and "mobile" in model_file_path:
min_input_shape = {
"x": [1, 3, 50, 50],
"conv2d_92.tmp_0": [1, 96, 20, 20],
"conv2d_91.tmp_0": [1, 96, 10, 10],
"nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
"nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
"nearest_interp_v2_5.tmp_0": [1, 24, 20, 20],
"elementwise_add_7": [1, 56, 2, 2],
"nearest_interp_v2_0.tmp_0": [1, 96, 2, 2]
}
max_input_shape = {
"x": [1, 3, 2000, 2000],
"conv2d_92.tmp_0": [1, 96, 400, 400],
"conv2d_91.tmp_0": [1, 96, 200, 200],
"nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
"nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
"nearest_interp_v2_5.tmp_0": [1, 24, 400, 400],
"elementwise_add_7": [1, 56, 400, 400],
"nearest_interp_v2_0.tmp_0": [1, 96, 400, 400]
}
opt_input_shape = {
"x": [1, 3, 640, 640],
"conv2d_92.tmp_0": [1, 96, 160, 160],
"conv2d_91.tmp_0": [1, 96, 80, 80],
"nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
"nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
"nearest_interp_v2_5.tmp_0": [1, 24, 160, 160],
"elementwise_add_7": [1, 56, 40, 40],
"nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
}
if mode == "det" and "server" in model_file_path:
min_input_shape = {
"x": [1, 3, 50, 50],
"conv2d_59.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
"nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
"nearest_interp_v2_5.tmp_0": [1, 24, 20, 20]
}
max_input_shape = {
"x": [1, 3, 2000, 2000],
"conv2d_59.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
"nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
"nearest_interp_v2_5.tmp_0": [1, 24, 400, 400]
}
opt_input_shape = {
"x": [1, 3, 640, 640],
"conv2d_59.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
"nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
"nearest_interp_v2_5.tmp_0": [1, 24, 160, 160]
}
elif mode == "rec":
min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
elif mode == "cls":
min_input_shape = {"x": [args.rec_batch_num, 3, 48, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
else:
min_input_shape = {"x": [1, 3, 10, 10]}
max_input_shape = {"x": [1, 3, 1000, 1000]}
opt_input_shape = {"x": [1, 3, 500, 500]}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
opt_input_shape)
else: else:
config.disable_gpu() config.disable_gpu()
config.set_cpu_math_library_num_threads(6) if hasattr(args, "cpu_threads"):
config.set_cpu_math_library_num_threads(args.cpu_threads)
else:
# default cpu threads as 10
config.set_cpu_math_library_num_threads(10)
if args.enable_mkldnn: if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak # cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10) config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn() config.enable_mkldnn()
# TODO LDOUBLEV: fix mkldnn bug when bach_size > 1
#config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
args.rec_batch_num = 1
# enable memory optim # enable memory optim
config.enable_memory_optim() config.enable_memory_optim()
...@@ -210,7 +292,7 @@ def draw_ocr(image, ...@@ -210,7 +292,7 @@ def draw_ocr(image,
txts=None, txts=None,
scores=None, scores=None,
drop_score=0.5, drop_score=0.5,
font_path="./doc/simfang.ttf"): font_path="./doc/fonts/simfang.ttf"):
""" """
Visualize the results of OCR detection and recognition Visualize the results of OCR detection and recognition
args: args:
...@@ -418,22 +500,4 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5): ...@@ -418,22 +500,4 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
if __name__ == '__main__': if __name__ == '__main__':
test_img = "./doc/test_v2" pass
predict_txt = "./doc/predict.txt"
f = open(predict_txt, 'r')
data = f.readlines()
img_path, anno = data[0].strip().split('\t')
img_name = os.path.basename(img_path)
img_path = os.path.join(test_img, img_name)
image = Image.open(img_path)
data = json.loads(anno)
boxes, txts, scores = [], [], []
for dic in data:
boxes.append(dic['points'])
txts.append(dic['transcription'])
scores.append(round(dic['scores'], 3))
new_img = draw_ocr(image, boxes, txts, scores)
cv2.imwrite(img_name, new_img)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册