diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 2e708d204192f62a44e8de04635fc43b2d3e29b6..1c82280099f17f6d3bf848669e47439505f10576 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -51,7 +51,7 @@ def init_args(): parser.add_argument("--det_db_box_thresh", type=float, default=0.6) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5) parser.add_argument("--max_batch_size", type=int, default=10) - parser.add_argument("--use_dilation", type=bool, default=False) + parser.add_argument("--use_dilation", type=str2bool, default=False) parser.add_argument("--det_db_score_mode", type=str, default="fast") # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) @@ -61,7 +61,7 @@ def init_args(): # SAST parmas parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) - parser.add_argument("--det_sast_polygon", type=bool, default=False) + parser.add_argument("--det_sast_polygon", type=str2bool, default=False) # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='CRNN') @@ -90,7 +90,7 @@ def init_args(): parser.add_argument( "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt") parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext') - parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True) + parser.add_argument("--e2e_pgnet_polygon", type=str2bool, default=True) parser.add_argument("--e2e_pgnet_mode", type=str, default='fast') # params for text classifier @@ -111,7 +111,7 @@ def init_args(): parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--process_id", type=int, default=0) - parser.add_argument("--benchmark", type=bool, default=False) + parser.add_argument("--benchmark", type=str2bool, default=False) parser.add_argument("--save_log_path", type=str, default="./log_output/") parser.add_argument("--show_log", type=str2bool, default=True) @@ -210,22 +210,22 @@ def create_predictor(args, mode, logger): "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40] } min_pact_shape = { - "nearest_interp_v2_26.tmp_0":[1,256,20,20], - "nearest_interp_v2_27.tmp_0":[1,64,20,20], - "nearest_interp_v2_28.tmp_0":[1,64,20,20], - "nearest_interp_v2_29.tmp_0":[1,64,20,20] + "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20], + "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20] } max_pact_shape = { - "nearest_interp_v2_26.tmp_0":[1,256,400,400], - "nearest_interp_v2_27.tmp_0":[1,64,400,400], - "nearest_interp_v2_28.tmp_0":[1,64,400,400], - "nearest_interp_v2_29.tmp_0":[1,64,400,400] + "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400] } opt_pact_shape = { - "nearest_interp_v2_26.tmp_0":[1,256,160,160], - "nearest_interp_v2_27.tmp_0":[1,64,160,160], - "nearest_interp_v2_28.tmp_0":[1,64,160,160], - "nearest_interp_v2_29.tmp_0":[1,64,160,160] + "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160], + "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160] } min_input_shape.update(min_pact_shape) max_input_shape.update(max_pact_shape)