From 00228775558b06660d73f444e7d3ef3c2a25446c Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Sat, 24 Apr 2021 14:30:49 +0800 Subject: [PATCH] fix hub post (#2613) * fix hub serving for db post process * update hub serving config method --- deploy/hubserving/ocr_cls/module.py | 20 ++++++++++++++++++-- deploy/hubserving/ocr_det/module.py | 21 +++++++++++++++++++-- deploy/hubserving/ocr_det/params.py | 1 + deploy/hubserving/ocr_rec/module.py | 20 ++++++++++++++++++-- deploy/hubserving/ocr_system/module.py | 20 ++++++++++++++++++-- deploy/hubserving/ocr_system/params.py | 1 + 6 files changed, 75 insertions(+), 8 deletions(-) diff --git a/deploy/hubserving/ocr_cls/module.py b/deploy/hubserving/ocr_cls/module.py index 1b91580c..803d5ac2 100644 --- a/deploy/hubserving/ocr_cls/module.py +++ b/deploy/hubserving/ocr_cls/module.py @@ -6,6 +6,7 @@ from __future__ import print_function import os import sys sys.path.insert(0, ".") +import copy from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving @@ -14,6 +15,7 @@ import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_cls import TextClassifier +from tools.infer.utility import parse_args @moduleinfo( @@ -28,8 +30,7 @@ class OCRCls(hub.Module): """ initialize with the necessary elements """ - from ocr_cls.params import read_params - cfg = read_params() + cfg = self.merge_configs() cfg.use_gpu = use_gpu if use_gpu: @@ -48,6 +49,21 @@ class OCRCls(hub.Module): self.text_classifier = TextClassifier(cfg) + def merge_configs(self, ): + # deafult cfg + backup_argv = copy.deepcopy(sys.argv) + sys.argv = sys.argv[:1] + cfg = parse_args() + + from ocr_det.params import read_params + update_cfg_map = vars(read_params()) + + for key in update_cfg_map: + cfg.__setattr__(key, update_cfg_map[key]) + + sys.argv = copy.deepcopy(backup_argv) + return cfg + def read_images(self, paths=[]): images = [] for img_path in paths: diff --git a/deploy/hubserving/ocr_det/module.py b/deploy/hubserving/ocr_det/module.py index 5f7bd6c4..595f4cea 100644 --- a/deploy/hubserving/ocr_det/module.py +++ b/deploy/hubserving/ocr_det/module.py @@ -7,6 +7,8 @@ import os import sys sys.path.insert(0, ".") +import copy + from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving import cv2 @@ -15,6 +17,7 @@ import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_det import TextDetector +from tools.infer.utility import parse_args @moduleinfo( @@ -29,8 +32,7 @@ class OCRDet(hub.Module): """ initialize with the necessary elements """ - from ocr_det.params import read_params - cfg = read_params() + cfg = self.merge_configs() cfg.use_gpu = use_gpu if use_gpu: @@ -49,6 +51,21 @@ class OCRDet(hub.Module): self.text_detector = TextDetector(cfg) + def merge_configs(self, ): + # deafult cfg + backup_argv = copy.deepcopy(sys.argv) + sys.argv = sys.argv[:1] + cfg = parse_args() + + from ocr_det.params import read_params + update_cfg_map = vars(read_params()) + + for key in update_cfg_map: + cfg.__setattr__(key, update_cfg_map[key]) + + sys.argv = copy.deepcopy(backup_argv) + return cfg + def read_images(self, paths=[]): images = [] for img_path in paths: diff --git a/deploy/hubserving/ocr_det/params.py b/deploy/hubserving/ocr_det/params.py index 7be88e9b..bc75cc40 100755 --- a/deploy/hubserving/ocr_det/params.py +++ b/deploy/hubserving/ocr_det/params.py @@ -22,6 +22,7 @@ def read_params(): cfg.det_db_box_thresh = 0.5 cfg.det_db_unclip_ratio = 1.6 cfg.use_dilation = False + cfg.det_db_score_mode = "fast" # #EAST parmas # cfg.det_east_score_thresh = 0.8 diff --git a/deploy/hubserving/ocr_rec/module.py b/deploy/hubserving/ocr_rec/module.py index 41a42104..70998241 100644 --- a/deploy/hubserving/ocr_rec/module.py +++ b/deploy/hubserving/ocr_rec/module.py @@ -6,6 +6,7 @@ from __future__ import print_function import os import sys sys.path.insert(0, ".") +import copy from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving @@ -14,6 +15,7 @@ import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_rec import TextRecognizer +from tools.infer.utility import parse_args @moduleinfo( @@ -28,8 +30,7 @@ class OCRRec(hub.Module): """ initialize with the necessary elements """ - from ocr_rec.params import read_params - cfg = read_params() + cfg = self.merge_configs() cfg.use_gpu = use_gpu if use_gpu: @@ -48,6 +49,21 @@ class OCRRec(hub.Module): self.text_recognizer = TextRecognizer(cfg) + def merge_configs(self, ): + # deafult cfg + backup_argv = copy.deepcopy(sys.argv) + sys.argv = sys.argv[:1] + cfg = parse_args() + + from ocr_det.params import read_params + update_cfg_map = vars(read_params()) + + for key in update_cfg_map: + cfg.__setattr__(key, update_cfg_map[key]) + + sys.argv = copy.deepcopy(backup_argv) + return cfg + def read_images(self, paths=[]): images = [] for img_path in paths: diff --git a/deploy/hubserving/ocr_system/module.py b/deploy/hubserving/ocr_system/module.py index 7f361733..7a65db09 100644 --- a/deploy/hubserving/ocr_system/module.py +++ b/deploy/hubserving/ocr_system/module.py @@ -6,6 +6,7 @@ from __future__ import print_function import os import sys sys.path.insert(0, ".") +import copy import time @@ -17,6 +18,7 @@ import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_system import TextSystem +from tools.infer.utility import parse_args @moduleinfo( @@ -31,8 +33,7 @@ class OCRSystem(hub.Module): """ initialize with the necessary elements """ - from ocr_system.params import read_params - cfg = read_params() + cfg = self.merge_configs() cfg.use_gpu = use_gpu if use_gpu: @@ -51,6 +52,21 @@ class OCRSystem(hub.Module): self.text_sys = TextSystem(cfg) + def merge_configs(self, ): + # deafult cfg + backup_argv = copy.deepcopy(sys.argv) + sys.argv = sys.argv[:1] + cfg = parse_args() + + from ocr_det.params import read_params + update_cfg_map = vars(read_params()) + + for key in update_cfg_map: + cfg.__setattr__(key, update_cfg_map[key]) + + sys.argv = copy.deepcopy(backup_argv) + return cfg + def read_images(self, paths=[]): images = [] for img_path in paths: diff --git a/deploy/hubserving/ocr_system/params.py b/deploy/hubserving/ocr_system/params.py index bd56dc2e..bee53bfd 100755 --- a/deploy/hubserving/ocr_system/params.py +++ b/deploy/hubserving/ocr_system/params.py @@ -22,6 +22,7 @@ def read_params(): cfg.det_db_box_thresh = 0.5 cfg.det_db_unclip_ratio = 1.6 cfg.use_dilation = False + cfg.det_db_score_mode = "fast" #EAST parmas cfg.det_east_score_thresh = 0.8 -- GitLab