提交 02d96fbe 编写于 作者: W wangjiawei04

fix ocr

上级 eb08b056
......@@ -53,7 +53,9 @@ class OCRService(WebService):
self.ori_h, self.ori_w, _ = im.shape
det_img = self.det_preprocess(im)
_, self.new_h, self.new_w = det_img.shape
return {"image": det_img[np.newaxis, :].copy()}, ["concat_1.tmp_0"]
return {
"image": det_img[np.newaxis, :].copy()
}, ["concat_1.tmp_0"], True
def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"]
......
......@@ -42,10 +42,9 @@ class OCRService(WebService):
self.det_client = LocalPredictor()
if sys.argv[1] == 'gpu':
self.det_client.load_model_config(
det_model_config, gpu=True, profile=False)
det_model_config, use_gpu=True, gpu_id=1)
elif sys.argv[1] == 'cpu':
self.det_client.load_model_config(
det_model_config, gpu=False, profile=False)
self.det_client.load_model_config(det_model_config)
self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]):
......@@ -58,7 +57,7 @@ class OCRService(WebService):
det_img = det_img[np.newaxis, :]
det_img = det_img.copy()
det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"])
feed={"image": det_img}, fetch=["concat_1.tmp_0"], batch=True)
filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({
"thresh": 0.3,
......@@ -91,7 +90,7 @@ class OCRService(WebService):
imgs[id] = norm_img
feed = {"image": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed, fetch
return feed, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
......@@ -107,7 +106,8 @@ ocr_service.load_model_config("ocr_rec_model")
ocr_service.prepare_server(workdir="workdir", port=9292)
ocr_service.init_det_debugger(det_model_config="ocr_det_model")
if sys.argv[1] == 'gpu':
ocr_service.run_debugger_service(gpu=True)
ocr_service.set_gpus("2")
ocr_service.run_debugger_service()
elif sys.argv[1] == 'cpu':
ocr_service.run_debugger_service()
ocr_service.run_web_service()
......@@ -52,7 +52,7 @@ class OCRService(WebService):
imgs[i] = norm_img
feed = {"image": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed, fetch
return feed, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
......
......@@ -250,7 +250,7 @@ class WebService(object):
from paddle_serving_app.local_predict import LocalPredictor
self.client = LocalPredictor()
self.client.load_model_config(
"{}".format(self.model_config), gpu=gpu, profile=False)
"{}".format(self.model_config), use_gpu=True, gpu_id=self.gpus[0])
def run_web_service(self):
print("This API will be deprecated later. Please do not use it")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册