提交 602b6c18 编写于 作者: L LDOUBLEV

Merge branch 'fix_infer_mode' of https://github.com/LDOUBLEV/PaddleOCR into fix_infer_mode

...@@ -230,15 +230,8 @@ class GridGenerator(nn.Layer): ...@@ -230,15 +230,8 @@ class GridGenerator(nn.Layer):
def build_inv_delta_C_paddle(self, C): def build_inv_delta_C_paddle(self, C):
""" Return inv_delta_C which is needed to calculate T """ """ Return inv_delta_C which is needed to calculate T """
F = self.F F = self.F
hat_C = paddle.zeros((F, F), dtype='float64') # F x F hat_eye = paddle.eye(F, dtype='float64') # F x F
for i in range(0, F): hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
for j in range(i, F):
if i == j:
hat_C[i, j] = 1
else:
r = paddle.norm(C[i] - C[j])
hat_C[i, j] = r
hat_C[j, i] = r
hat_C = (hat_C**2) * paddle.log(hat_C) hat_C = (hat_C**2) * paddle.log(hat_C)
delta_C = paddle.concat( # F+3 x F+3 delta_C = paddle.concat( # F+3 x F+3
[ [
......
...@@ -235,3 +235,4 @@ if __name__ == "__main__": ...@@ -235,3 +235,4 @@ if __name__ == "__main__":
"det_res_{}".format(img_name_pure)) "det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im) cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path)) logger.info("The visualized image saved in {}".format(img_path))
...@@ -37,6 +37,7 @@ def init_args(): ...@@ -37,6 +37,7 @@ def init_args():
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--min_subgraph_size", type=int, default=3)
parser.add_argument("--precision", type=str, default="fp32") parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500) parser.add_argument("--gpu_mem", type=int, default=500)
...@@ -165,12 +166,14 @@ def create_predictor(args, mode, logger): ...@@ -165,12 +166,14 @@ def create_predictor(args, mode, logger):
config.enable_tensorrt_engine( config.enable_tensorrt_engine(
precision_mode=inference.PrecisionType.Float32, precision_mode=inference.PrecisionType.Float32,
max_batch_size=args.max_batch_size, max_batch_size=args.max_batch_size,
min_subgraph_size=3) # skip the minmum trt subgraph min_subgraph_size=args.min_subgraph_size)
if mode == "det" and "mobile" in model_file_path: # skip the minmum trt subgraph
if mode == "det":
min_input_shape = { min_input_shape = {
"x": [1, 3, 50, 50], "x": [1, 3, 50, 50],
"conv2d_92.tmp_0": [1, 96, 20, 20], "conv2d_92.tmp_0": [1, 96, 20, 20],
"conv2d_91.tmp_0": [1, 96, 10, 10], "conv2d_91.tmp_0": [1, 96, 10, 10],
"conv2d_59.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_1.tmp_0": [1, 96, 10, 10], "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20], "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20], "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
...@@ -183,6 +186,7 @@ def create_predictor(args, mode, logger): ...@@ -183,6 +186,7 @@ def create_predictor(args, mode, logger):
"x": [1, 3, 2000, 2000], "x": [1, 3, 2000, 2000],
"conv2d_92.tmp_0": [1, 96, 400, 400], "conv2d_92.tmp_0": [1, 96, 400, 400],
"conv2d_91.tmp_0": [1, 96, 200, 200], "conv2d_91.tmp_0": [1, 96, 200, 200],
"conv2d_59.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_1.tmp_0": [1, 96, 200, 200], "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400], "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400], "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
...@@ -195,6 +199,7 @@ def create_predictor(args, mode, logger): ...@@ -195,6 +199,7 @@ def create_predictor(args, mode, logger):
"x": [1, 3, 640, 640], "x": [1, 3, 640, 640],
"conv2d_92.tmp_0": [1, 96, 160, 160], "conv2d_92.tmp_0": [1, 96, 160, 160],
"conv2d_91.tmp_0": [1, 96, 80, 80], "conv2d_91.tmp_0": [1, 96, 80, 80],
"conv2d_59.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_1.tmp_0": [1, 96, 80, 80], "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160], "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160], "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
...@@ -203,31 +208,6 @@ def create_predictor(args, mode, logger): ...@@ -203,31 +208,6 @@ def create_predictor(args, mode, logger):
"elementwise_add_7": [1, 56, 40, 40], "elementwise_add_7": [1, 56, 40, 40],
"nearest_interp_v2_0.tmp_0": [1, 96, 40, 40] "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
} }
if mode == "det" and "server" in model_file_path:
min_input_shape = {
"x": [1, 3, 50, 50],
"conv2d_59.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
"nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
"nearest_interp_v2_5.tmp_0": [1, 24, 20, 20]
}
max_input_shape = {
"x": [1, 3, 2000, 2000],
"conv2d_59.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
"nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
"nearest_interp_v2_5.tmp_0": [1, 24, 400, 400]
}
opt_input_shape = {
"x": [1, 3, 640, 640],
"conv2d_59.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
"nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
"nearest_interp_v2_5.tmp_0": [1, 24, 160, 160]
}
elif mode == "rec": elif mode == "rec":
min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]} min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册