From c4720557e8c99faf523d3a45c427ea6da390513a Mon Sep 17 00:00:00 2001 From: wangjiawei04 Date: Tue, 22 Sep 2020 19:30:26 +0800 Subject: [PATCH] fix all minor bugs --- deploy/pdserving/clas_local_server.py | 2 +- deploy/pdserving/clas_web_client.py | 1 - deploy/pdserving/det_local_server.py | 2 +- deploy/pdserving/det_rpc_server.py | 3 +-- deploy/pdserving/ocr_local_server.py | 7 ++----- deploy/pdserving/ocr_rpc_server.py | 7 ++++--- deploy/pdserving/ocr_web_client.py | 3 --- deploy/pdserving/rec_local_server.py | 2 -- deploy/pdserving/rec_rpc_server.py | 2 -- deploy/pdserving/rec_web_client.py | 1 - tools/infer/predict_cls.py | 2 +- tools/infer/predict_det.py | 2 +- tools/infer/predict_rec.py | 2 +- tools/infer/predict_system.py | 7 ++++++- tools/infer/utility.py | 12 ++++-------- 15 files changed, 22 insertions(+), 33 deletions(-) diff --git a/deploy/pdserving/clas_local_server.py b/deploy/pdserving/clas_local_server.py index abf5e2d8..4bceea77 100644 --- a/deploy/pdserving/clas_local_server.py +++ b/deploy/pdserving/clas_local_server.py @@ -117,7 +117,7 @@ class OCRService(WebService): if __name__ == "__main__": ocr_service = OCRService(name="ocr") - ocr_service.load_model_config("cls_server") + ocr_service.load_model_config(global_args.cls_model_dir) ocr_service.init_rec() if global_args.use_gpu: ocr_service.prepare_server( diff --git a/deploy/pdserving/clas_web_client.py b/deploy/pdserving/clas_web_client.py index 9bcd929e..576e073b 100644 --- a/deploy/pdserving/clas_web_client.py +++ b/deploy/pdserving/clas_web_client.py @@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir): data = {"feed": [{"image": image}], "fetch": ["res"]} r = requests.post(url=url, headers=headers, data=json.dumps(data)) print(r.json()) - break diff --git a/deploy/pdserving/det_local_server.py b/deploy/pdserving/det_local_server.py index f79b9994..d0e52cd0 100644 --- a/deploy/pdserving/det_local_server.py +++ b/deploy/pdserving/det_local_server.py @@ -96,7 +96,7 @@ class DetService(WebService): if __name__ == "__main__": ocr_service = DetService(name="ocr") - ocr_service.load_model_config("serving_server_dir") + ocr_service.load_model_config(global_args.det_model_dir) ocr_service.init_det() if global_args.use_gpu: ocr_service.prepare_server( diff --git a/deploy/pdserving/det_rpc_server.py b/deploy/pdserving/det_rpc_server.py index ef6d135b..6588a0d6 100644 --- a/deploy/pdserving/det_rpc_server.py +++ b/deploy/pdserving/det_rpc_server.py @@ -79,7 +79,6 @@ class TextDetectorHelper(TextDetector): class DetService(WebService): def init_det(self): self.text_detector = TextDetectorHelper(global_args) - print("init finish") def preprocess(self, feed=[], fetch=[]): data = base64.b64decode(feed[0]["image"].encode('utf8')) @@ -96,7 +95,7 @@ class DetService(WebService): if __name__ == "__main__": ocr_service = DetService(name="ocr") - ocr_service.load_model_config("serving_server_dir") + ocr_service.load_model_config(global_args.det_model_dir) ocr_service.init_det() if global_args.use_gpu: ocr_service.prepare_server( diff --git a/deploy/pdserving/ocr_local_server.py b/deploy/pdserving/ocr_local_server.py index dae71374..a4fb540b 100644 --- a/deploy/pdserving/ocr_local_server.py +++ b/deploy/pdserving/ocr_local_server.py @@ -44,17 +44,16 @@ class TextSystemHelper(TextSystem): if self.use_angle_cls: self.clas_client = Debugger() self.clas_client.load_model_config( - "ocr_clas_server", gpu=True, profile=False) + global_args.cls_model_dir, gpu=True, profile=False) self.text_classifier = TextClassifierHelper(args) self.det_client = Debugger() self.det_client.load_model_config( - "serving_server_dir", gpu=True, profile=False) + global_args.det_model_dir, gpu=True, profile=False) self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] def preprocess(self, img): feed, fetch, self.tmp_args = self.text_detector.preprocess(img) fetch_map = self.det_client.predict(feed, fetch) - print("det fetch_map", fetch_map) outputs = [fetch_map[x] for x in fetch] dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args) if dt_boxes is None: @@ -90,12 +89,10 @@ class OCRService(WebService): def preprocess(self, feed=[], fetch=[]): # TODO: to handle batch rec images - print("start preprocess") data = base64.b64decode(feed[0]["image"].encode('utf8')) data = np.fromstring(data, np.uint8) im = cv2.imdecode(data, cv2.IMREAD_COLOR) feed, fetch, self.tmp_args = self.text_system.preprocess(im) - print("ocr preprocess done") return feed, fetch def postprocess(self, feed={}, fetch=[], fetch_map=None): diff --git a/deploy/pdserving/ocr_rpc_server.py b/deploy/pdserving/ocr_rpc_server.py index 3ed8810e..873e8792 100644 --- a/deploy/pdserving/ocr_rpc_server.py +++ b/deploy/pdserving/ocr_rpc_server.py @@ -25,7 +25,7 @@ from clas_rpc_server import TextClassifierHelper from det_rpc_server import TextDetectorHelper from rec_rpc_server import TextRecognizerHelper import tools.infer.utility as utility -from tools.infer.predict_system import TextSystem +from tools.infer.predict_system import TextSystem, sorted_boxes import copy global_args = utility.parse_args() @@ -48,7 +48,7 @@ class TextSystemHelper(TextSystem): self.text_classifier = TextClassifierHelper(args) self.det_client = Client() self.det_client.load_client_config( - "ocr_det_server/serving_client_conf.prototxt") + "det_db_client/serving_client_conf.prototxt") self.det_client.connect(["127.0.0.1:9293"]) self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] @@ -57,10 +57,10 @@ class TextSystemHelper(TextSystem): fetch_map = self.det_client.predict(feed, fetch) outputs = [fetch_map[x] for x in fetch] dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args) + print(dt_boxes) if dt_boxes is None: return None, None img_crop_list = [] - sorted_boxes = SortedBoxes() dt_boxes = sorted_boxes(dt_boxes) for bno in range(len(dt_boxes)): tmp_box = copy.deepcopy(dt_boxes[bno]) @@ -70,6 +70,7 @@ class TextSystemHelper(TextSystem): feed, fetch, self.tmp_args = self.text_classifier.preprocess( img_crop_list) fetch_map = self.clas_client.predict(feed, fetch) + print(fetch_map) outputs = [fetch_map[x] for x in self.text_classifier.fetch] for x in fetch_map.keys(): if ".lod" in x: diff --git a/deploy/pdserving/ocr_web_client.py b/deploy/pdserving/ocr_web_client.py index 036f730e..4324406b 100644 --- a/deploy/pdserving/ocr_web_client.py +++ b/deploy/pdserving/ocr_web_client.py @@ -36,8 +36,5 @@ for img_file in os.listdir(test_img_dir): image = cv2_to_base64(image_data1) data = {"feed": [{"image": image}], "fetch": ["res"]} r = requests.post(url=url, headers=headers, data=json.dumps(data)) - print(r) rjson = r.json() print(rjson) - #for x in rjson["result"]["pred_text"]: - # print(x) diff --git a/deploy/pdserving/rec_local_server.py b/deploy/pdserving/rec_local_server.py index 5021cdd9..58df5b32 100644 --- a/deploy/pdserving/rec_local_server.py +++ b/deploy/pdserving/rec_local_server.py @@ -85,7 +85,6 @@ class TextRecognizerHelper(TextRecognizer): rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"] predict_lod = args["softmax_0.tmp_0.lod"] indices = args["indices"] - print("indices", indices, rec_idx_lod) rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1) for rno in range(len(rec_idx_lod) - 1): beg = rec_idx_lod[rno] @@ -155,7 +154,6 @@ class OCRService(WebService): if ".lod" in x: self.tmp_args[x] = fetch_map[x] rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args) - print("rec_res", rec_res) res = { "pred_text": [x[0] for x in rec_res], "score": [str(x[1]) for x in rec_res] diff --git a/deploy/pdserving/rec_rpc_server.py b/deploy/pdserving/rec_rpc_server.py index b1a9df9e..38251b3e 100644 --- a/deploy/pdserving/rec_rpc_server.py +++ b/deploy/pdserving/rec_rpc_server.py @@ -91,7 +91,6 @@ class TextRecognizerHelper(TextRecognizer): rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"] predict_lod = args["softmax_0.tmp_0.lod"] indices = args["indices"] - print("indices", indices, rec_idx_lod) rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1) for rno in range(len(rec_idx_lod) - 1): beg = rec_idx_lod[rno] @@ -161,7 +160,6 @@ class OCRService(WebService): if ".lod" in x: self.tmp_args[x] = fetch_map[x] rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args) - print("rec_res", rec_res) res = { "pred_text": [x[0] for x in rec_res], "score": [str(x[1]) for x in rec_res] diff --git a/deploy/pdserving/rec_web_client.py b/deploy/pdserving/rec_web_client.py index 9bcd929e..576e073b 100644 --- a/deploy/pdserving/rec_web_client.py +++ b/deploy/pdserving/rec_web_client.py @@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir): data = {"feed": [{"image": image}], "fetch": ["res"]} r = requests.post(url=url, headers=headers, data=json.dumps(data)) print(r.json()) - break diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 74029b24..d5c7830b 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -33,7 +33,7 @@ from paddle import fluid class TextClassifier(object): def __init__(self, args): - if args.use_serving is False: + if args.use_pdserving is False: self.predictor, self.input_tensor, self.output_tensors = \ utility.create_predictor(args, mode="cls") self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 18ea4bff..e658b9f6 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -75,7 +75,7 @@ class TextDetector(object): else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) - if args.use_gpu is False: + if args.use_pdserving is False: self.predictor, self.input_tensor, self.output_tensors =\ utility.create_predictor(args, mode="det") diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index c09a14f9..e9fa52ac 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -34,7 +34,7 @@ from ppocr.utils.character import CharacterOps class TextRecognizer(object): def __init__(self, args): - if args.use_serving is False: + if args.use_pdserving is False: self.predictor, self.input_tensor, self.output_tensors =\ utility.create_predictor(args, mode="rec") self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 29c4d7e8..dd6b76a6 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -161,7 +161,12 @@ def main(args): scores = [rec_res[i][1] for i in range(len(rec_res))] draw_img = draw_ocr( - image, boxes, txts, scores, drop_score=drop_score, font_path=font_path) + image, + boxes, + txts, + scores, + drop_score=drop_score, + font_path=font_path) draw_img_save = "./inference_results/" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index d85322d8..bd3919e8 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -37,7 +37,7 @@ def parse_args(): parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--gpu_mem", type=int, default=8000) - parser.add_argument("--use_serving", type=str2bool, default=False) + parser.add_argument("--use_pdserving", type=str2bool, default=False) # params for text detector parser.add_argument("--image_dir", type=str) @@ -73,9 +73,7 @@ def parse_args(): default="./ppocr/utils/ppocr_keys_v1.txt") parser.add_argument("--use_space_char", type=str2bool, default=True) parser.add_argument( - "--vis_font_path", - type=str, - default="./doc/simfang.ttf") + "--vis_font_path", type=str, default="./doc/simfang.ttf") # params for text classifier parser.add_argument("--use_angle_cls", type=str2bool, default=False) @@ -230,8 +228,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"): 1])**2) if box_height > 2 * box_width: font_size = max(int(box_width * 0.9), 10) - font = ImageFont.truetype( - font_path, font_size, encoding="utf-8") + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") cur_y = box[0][1] for c in txt: char_size = font.getsize(c) @@ -240,8 +237,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"): cur_y += char_size[1] else: font_size = max(int(box_height * 0.8), 10) - font = ImageFont.truetype( - font_path, font_size, encoding="utf-8") + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") draw_right.text( [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) img_left = Image.blend(image, img_left, 0.5) -- GitLab