未验证 提交 00228775 编写于 作者: L littletomatodonkey 提交者: GitHub

fix hub post (#2613)

* fix hub serving for db post process

* update hub serving config method
上级 c9841fe2
......@@ -6,6 +6,7 @@ from __future__ import print_function
import os
import sys
sys.path.insert(0, ".")
import copy
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving
......@@ -14,6 +15,7 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2
from tools.infer.predict_cls import TextClassifier
from tools.infer.utility import parse_args
@moduleinfo(
......@@ -28,8 +30,7 @@ class OCRCls(hub.Module):
"""
initialize with the necessary elements
"""
from ocr_cls.params import read_params
cfg = read_params()
cfg = self.merge_configs()
cfg.use_gpu = use_gpu
if use_gpu:
......@@ -48,6 +49,21 @@ class OCRCls(hub.Module):
self.text_classifier = TextClassifier(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]):
images = []
for img_path in paths:
......
......@@ -7,6 +7,8 @@ import os
import sys
sys.path.insert(0, ".")
import copy
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving
import cv2
......@@ -15,6 +17,7 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2
from tools.infer.predict_det import TextDetector
from tools.infer.utility import parse_args
@moduleinfo(
......@@ -29,8 +32,7 @@ class OCRDet(hub.Module):
"""
initialize with the necessary elements
"""
from ocr_det.params import read_params
cfg = read_params()
cfg = self.merge_configs()
cfg.use_gpu = use_gpu
if use_gpu:
......@@ -49,6 +51,21 @@ class OCRDet(hub.Module):
self.text_detector = TextDetector(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]):
images = []
for img_path in paths:
......
......@@ -22,6 +22,7 @@ def read_params():
cfg.det_db_box_thresh = 0.5
cfg.det_db_unclip_ratio = 1.6
cfg.use_dilation = False
cfg.det_db_score_mode = "fast"
# #EAST parmas
# cfg.det_east_score_thresh = 0.8
......
......@@ -6,6 +6,7 @@ from __future__ import print_function
import os
import sys
sys.path.insert(0, ".")
import copy
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving
......@@ -14,6 +15,7 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2
from tools.infer.predict_rec import TextRecognizer
from tools.infer.utility import parse_args
@moduleinfo(
......@@ -28,8 +30,7 @@ class OCRRec(hub.Module):
"""
initialize with the necessary elements
"""
from ocr_rec.params import read_params
cfg = read_params()
cfg = self.merge_configs()
cfg.use_gpu = use_gpu
if use_gpu:
......@@ -48,6 +49,21 @@ class OCRRec(hub.Module):
self.text_recognizer = TextRecognizer(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]):
images = []
for img_path in paths:
......
......@@ -6,6 +6,7 @@ from __future__ import print_function
import os
import sys
sys.path.insert(0, ".")
import copy
import time
......@@ -17,6 +18,7 @@ import paddlehub as hub
from tools.infer.utility import base64_to_cv2
from tools.infer.predict_system import TextSystem
from tools.infer.utility import parse_args
@moduleinfo(
......@@ -31,8 +33,7 @@ class OCRSystem(hub.Module):
"""
initialize with the necessary elements
"""
from ocr_system.params import read_params
cfg = read_params()
cfg = self.merge_configs()
cfg.use_gpu = use_gpu
if use_gpu:
......@@ -51,6 +52,21 @@ class OCRSystem(hub.Module):
self.text_sys = TextSystem(cfg)
def merge_configs(self, ):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
from ocr_det.params import read_params
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]):
images = []
for img_path in paths:
......
......@@ -22,6 +22,7 @@ def read_params():
cfg.det_db_box_thresh = 0.5
cfg.det_db_unclip_ratio = 1.6
cfg.use_dilation = False
cfg.det_db_score_mode = "fast"
#EAST parmas
cfg.det_east_score_thresh = 0.8
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册