未验证 提交 00228775 编写于 作者: L littletomatodonkey 提交者: GitHub

fix hub post (#2613)

* fix hub serving for db post process

* update hub serving config method
上级 c9841fe2
...@@ -6,6 +6,7 @@ from __future__ import print_function ...@@ -6,6 +6,7 @@ from __future__ import print_function
import os import os
import sys import sys
sys.path.insert(0, ".") sys.path.insert(0, ".")
import copy
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
...@@ -14,6 +15,7 @@ import paddlehub as hub ...@@ -14,6 +15,7 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2 from tools.infer.utility import base64_to_cv2
from tools.infer.predict_cls import TextClassifier from tools.infer.predict_cls import TextClassifier
from tools.infer.utility import parse_args
@moduleinfo( @moduleinfo(
...@@ -28,8 +30,7 @@ class OCRCls(hub.Module): ...@@ -28,8 +30,7 @@ class OCRCls(hub.Module):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
from ocr_cls.params import read_params cfg = self.merge_configs()
cfg = read_params()
cfg.use_gpu = use_gpu cfg.use_gpu = use_gpu
if use_gpu: if use_gpu:
...@@ -48,6 +49,21 @@ class OCRCls(hub.Module): ...@@ -48,6 +49,21 @@ class OCRCls(hub.Module):
self.text_classifier = TextClassifier(cfg) self.text_classifier = TextClassifier(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]): def read_images(self, paths=[]):
images = [] images = []
for img_path in paths: for img_path in paths:
......
...@@ -7,6 +7,8 @@ import os ...@@ -7,6 +7,8 @@ import os
import sys import sys
sys.path.insert(0, ".") sys.path.insert(0, ".")
import copy
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
import cv2 import cv2
...@@ -15,6 +17,7 @@ import paddlehub as hub ...@@ -15,6 +17,7 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2 from tools.infer.utility import base64_to_cv2
from tools.infer.predict_det import TextDetector from tools.infer.predict_det import TextDetector
from tools.infer.utility import parse_args
@moduleinfo( @moduleinfo(
...@@ -29,8 +32,7 @@ class OCRDet(hub.Module): ...@@ -29,8 +32,7 @@ class OCRDet(hub.Module):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
from ocr_det.params import read_params cfg = self.merge_configs()
cfg = read_params()
cfg.use_gpu = use_gpu cfg.use_gpu = use_gpu
if use_gpu: if use_gpu:
...@@ -49,6 +51,21 @@ class OCRDet(hub.Module): ...@@ -49,6 +51,21 @@ class OCRDet(hub.Module):
self.text_detector = TextDetector(cfg) self.text_detector = TextDetector(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]): def read_images(self, paths=[]):
images = [] images = []
for img_path in paths: for img_path in paths:
......
...@@ -22,6 +22,7 @@ def read_params(): ...@@ -22,6 +22,7 @@ def read_params():
cfg.det_db_box_thresh = 0.5 cfg.det_db_box_thresh = 0.5
cfg.det_db_unclip_ratio = 1.6 cfg.det_db_unclip_ratio = 1.6
cfg.use_dilation = False cfg.use_dilation = False
cfg.det_db_score_mode = "fast"
# #EAST parmas # #EAST parmas
# cfg.det_east_score_thresh = 0.8 # cfg.det_east_score_thresh = 0.8
......
...@@ -6,6 +6,7 @@ from __future__ import print_function ...@@ -6,6 +6,7 @@ from __future__ import print_function
import os import os
import sys import sys
sys.path.insert(0, ".") sys.path.insert(0, ".")
import copy
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
...@@ -14,6 +15,7 @@ import paddlehub as hub ...@@ -14,6 +15,7 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2 from tools.infer.utility import base64_to_cv2
from tools.infer.predict_rec import TextRecognizer from tools.infer.predict_rec import TextRecognizer
from tools.infer.utility import parse_args
@moduleinfo( @moduleinfo(
...@@ -28,8 +30,7 @@ class OCRRec(hub.Module): ...@@ -28,8 +30,7 @@ class OCRRec(hub.Module):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
from ocr_rec.params import read_params cfg = self.merge_configs()
cfg = read_params()
cfg.use_gpu = use_gpu cfg.use_gpu = use_gpu
if use_gpu: if use_gpu:
...@@ -48,6 +49,21 @@ class OCRRec(hub.Module): ...@@ -48,6 +49,21 @@ class OCRRec(hub.Module):
self.text_recognizer = TextRecognizer(cfg) self.text_recognizer = TextRecognizer(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]): def read_images(self, paths=[]):
images = [] images = []
for img_path in paths: for img_path in paths:
......
...@@ -6,6 +6,7 @@ from __future__ import print_function ...@@ -6,6 +6,7 @@ from __future__ import print_function
import os import os
import sys import sys
sys.path.insert(0, ".") sys.path.insert(0, ".")
import copy
import time import time
...@@ -17,6 +18,7 @@ import paddlehub as hub ...@@ -17,6 +18,7 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2 from tools.infer.utility import base64_to_cv2
from tools.infer.predict_system import TextSystem from tools.infer.predict_system import TextSystem
from tools.infer.utility import parse_args
@moduleinfo( @moduleinfo(
...@@ -31,8 +33,7 @@ class OCRSystem(hub.Module): ...@@ -31,8 +33,7 @@ class OCRSystem(hub.Module):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
from ocr_system.params import read_params cfg = self.merge_configs()
cfg = read_params()
cfg.use_gpu = use_gpu cfg.use_gpu = use_gpu
if use_gpu: if use_gpu:
...@@ -51,6 +52,21 @@ class OCRSystem(hub.Module): ...@@ -51,6 +52,21 @@ class OCRSystem(hub.Module):
self.text_sys = TextSystem(cfg) self.text_sys = TextSystem(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]): def read_images(self, paths=[]):
images = [] images = []
for img_path in paths: for img_path in paths:
......
...@@ -22,6 +22,7 @@ def read_params(): ...@@ -22,6 +22,7 @@ def read_params():
cfg.det_db_box_thresh = 0.5 cfg.det_db_box_thresh = 0.5
cfg.det_db_unclip_ratio = 1.6 cfg.det_db_unclip_ratio = 1.6
cfg.use_dilation = False cfg.use_dilation = False
cfg.det_db_score_mode = "fast"
#EAST parmas #EAST parmas
cfg.det_east_score_thresh = 0.8 cfg.det_east_score_thresh = 0.8
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册