From 616ad6a179b0cc7aede699fe8cfe1dae92d3ef92 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 9 Jun 2021 19:15:57 +0800 Subject: [PATCH] fix tps and fix trt --- ppocr/modeling/transforms/tps.py | 13 ++++--------- tools/infer/utility.py | 32 +++++--------------------------- 2 files changed, 9 insertions(+), 36 deletions(-) diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py index 78338edf..13220991 100644 --- a/ppocr/modeling/transforms/tps.py +++ b/ppocr/modeling/transforms/tps.py @@ -230,15 +230,10 @@ class GridGenerator(nn.Layer): def build_inv_delta_C_paddle(self, C): """ Return inv_delta_C which is needed to calculate T """ F = self.F - hat_C = paddle.zeros((F, F), dtype='float64') # F x F - for i in range(0, F): - for j in range(i, F): - if i == j: - hat_C[i, j] = 1 - else: - r = paddle.norm(C[i] - C[j]) - hat_C[i, j] = r - hat_C[j, i] = r + hat_eye = paddle.eye(F, dtype='float64') # F x F + tmp1 = C.reshape([1, F, 2]) + tmp2 = C.reshape([F, 1, 2]) + hat_C = paddle.norm(tmp1 - tmp2, axis=2) + hat_eye hat_C = (hat_C**2) * paddle.log(hat_C) delta_C = paddle.concat( # F+3 x F+3 [ diff --git a/tools/infer/utility.py b/tools/infer/utility.py index d16a1923..2b70f36e 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -235,12 +235,13 @@ def create_predictor(args, mode, logger): config.enable_tensorrt_engine( precision_mode=inference.PrecisionType.Float32, max_batch_size=args.max_batch_size, - min_subgraph_size=10) # skip the minmum trt subgraph - if mode == "det" and "mobile" in model_file_path: + min_subgraph_size=3) # skip the minmum trt subgraph + if mode == "det": min_input_shape = { "x": [1, 3, 50, 50], "conv2d_92.tmp_0": [1, 96, 20, 20], "conv2d_91.tmp_0": [1, 96, 10, 10], + "conv2d_59.tmp_0": [1, 96, 20, 20], "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10], "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20], "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20], @@ -253,6 +254,7 @@ def create_predictor(args, mode, logger): "x": [1, 3, 2000, 2000], "conv2d_92.tmp_0": [1, 96, 400, 400], "conv2d_91.tmp_0": [1, 96, 200, 200], + "conv2d_59.tmp_0": [1, 96, 400, 400], "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200], "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400], "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400], @@ -265,6 +267,7 @@ def create_predictor(args, mode, logger): "x": [1, 3, 640, 640], "conv2d_92.tmp_0": [1, 96, 160, 160], "conv2d_91.tmp_0": [1, 96, 80, 80], + "conv2d_59.tmp_0": [1, 96, 160, 160], "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80], "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160], "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160], @@ -273,31 +276,6 @@ def create_predictor(args, mode, logger): "elementwise_add_7": [1, 56, 40, 40], "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40] } - if mode == "det" and "server" in model_file_path: - min_input_shape = { - "x": [1, 3, 50, 50], - "conv2d_59.tmp_0": [1, 96, 20, 20], - "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20], - "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20], - "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20], - "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20] - } - max_input_shape = { - "x": [1, 3, 2000, 2000], - "conv2d_59.tmp_0": [1, 96, 400, 400], - "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400], - "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400], - "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400], - "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400] - } - opt_input_shape = { - "x": [1, 3, 640, 640], - "conv2d_59.tmp_0": [1, 96, 160, 160], - "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160], - "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160], - "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160], - "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160] - } elif mode == "rec": min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]} max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} -- GitLab