提交 0af4e4a5 编写于 作者: W wangjiawei04

fix codestyle

上级 3e395027
...@@ -27,6 +27,7 @@ import time ...@@ -27,6 +27,7 @@ import time
import re import re
import base64 import base64
class OCRService(WebService): class OCRService(WebService):
def init_det_client(self, det_port, det_client_config): def init_det_client(self, det_port, det_client_config):
self.det_preprocess = Sequential([ self.det_preprocess = Sequential([
...@@ -46,7 +47,7 @@ class OCRService(WebService): ...@@ -46,7 +47,7 @@ class OCRService(WebService):
ori_h, ori_w, _ = im.shape ori_h, ori_w, _ = im.shape
det_img = self.det_preprocess(im) det_img = self.det_preprocess(im)
det_out = self.det_client.predict( det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"]) feed={"image": det_img}, fetch=["concat_1.tmp_0"])
_, new_h, new_w = det_img.shape _, new_h, new_w = det_img.shape
filter_func = FilterBoxes(10, 10) filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({ post_func = DBPostProcess({
......
...@@ -130,5 +130,6 @@ class Debugger(object): ...@@ -130,5 +130,6 @@ class Debugger(object):
fetch_map[name] = outputs[self.fetch_names_to_idx_[ fetch_map[name] = outputs[self.fetch_names_to_idx_[
name]].as_ndarray() name]].as_ndarray()
if len(outputs[self.fetch_names_to_idx_[name]].lod) > 0: if len(outputs[self.fetch_names_to_idx_[name]].lod) > 0:
fetch_map[name+".lod"] = outputs[self.fetch_names_to_idx_[name]].lod[0] fetch_map[name + ".lod"] = outputs[self.fetch_names_to_idx_[
name]].lod[0]
return fetch_map return fetch_map
...@@ -781,10 +781,12 @@ class Transpose(object): ...@@ -781,10 +781,12 @@ class Transpose(object):
"({})".format(self.transpose_target) "({})".format(self.transpose_target)
return format_string return format_string
class SortedBoxes(object): class SortedBoxes(object):
""" """
Sorted bounding boxes from Detection Sorted bounding boxes from Detection
""" """
def __init__(self): def __init__(self):
pass pass
...@@ -798,12 +800,14 @@ class SortedBoxes(object): ...@@ -798,12 +800,14 @@ class SortedBoxes(object):
tmp = _boxes[i] tmp = _boxes[i]
_boxes[i] = _boxes[i + 1] _boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp _boxes[i + 1] = tmp
return _boxes return _boxes
class GetRotateCropImage(object): class GetRotateCropImage(object):
""" """
Rotate and Crop image from OCR Det output Rotate and Crop image from OCR Det output
""" """
def __init__(self): def __init__(self):
pass pass
......
...@@ -120,7 +120,12 @@ class CharacterOps(object): ...@@ -120,7 +120,12 @@ class CharacterOps(object):
class OCRReader(object): class OCRReader(object):
def __init__(self, algorithm="CRNN", image_shape=[3,32,320], char_type="ch", batch_num=1, char_dict_path="./ppocr_keys_v1.txt"): def __init__(self,
algorithm="CRNN",
image_shape=[3, 32, 320],
char_type="ch",
batch_num=1,
char_dict_path="./ppocr_keys_v1.txt"):
self.rec_image_shape = image_shape self.rec_image_shape = image_shape
self.character_type = char_type self.character_type = char_type
self.rec_batch_num = batch_num self.rec_batch_num = batch_num
...@@ -129,7 +134,7 @@ class OCRReader(object): ...@@ -129,7 +134,7 @@ class OCRReader(object):
char_ops_params["character_dict_path"] = char_dict_path char_ops_params["character_dict_path"] = char_dict_path
char_ops_params['loss_type'] = 'ctc' char_ops_params['loss_type'] = 'ctc'
self.char_ops = CharacterOps(char_ops_params) self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
if self.character_type == "ch": if self.character_type == "ch":
...@@ -151,7 +156,6 @@ class OCRReader(object): ...@@ -151,7 +156,6 @@ class OCRReader(object):
padding_im[:, :, 0:resized_w] = resized_image padding_im[:, :, 0:resized_w] = resized_image
return padding_im return padding_im
def preprocess(self, img_list): def preprocess(self, img_list):
img_num = len(img_list) img_num = len(img_list)
norm_img_batch = [] norm_img_batch = []
...@@ -180,14 +184,15 @@ class OCRReader(object): ...@@ -180,14 +184,15 @@ class OCRReader(object):
end = rec_idx_lod[rno + 1] end = rec_idx_lod[rno + 1]
if isinstance(rec_idx_batch, list): if isinstance(rec_idx_batch, list):
rec_idx_tmp = [x[0] for x in rec_idx_batch[beg:end]] rec_idx_tmp = [x[0] for x in rec_idx_batch[beg:end]]
else: #nd array else: #nd array
rec_idx_tmp = rec_idx_batch[beg:end, 0] rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp) preds_text = self.char_ops.decode(rec_idx_tmp)
if with_score: if with_score:
beg = predict_lod[rno] beg = predict_lod[rno]
end = predict_lod[rno + 1] end = predict_lod[rno + 1]
if isinstance(outputs["softmax_0.tmp_0"], list): if isinstance(outputs["softmax_0.tmp_0"], list):
outputs["softmax_0.tmp_0"] = np.array(outputs["softmax_0.tmp_0"]).astype(np.float32) outputs["softmax_0.tmp_0"] = np.array(outputs[
"softmax_0.tmp_0"]).astype(np.float32)
probs = outputs["softmax_0.tmp_0"][beg:end, :] probs = outputs["softmax_0.tmp_0"][beg:end, :]
ind = np.argmax(probs, axis=1) ind = np.argmax(probs, axis=1)
blank = probs.shape[1] blank = probs.shape[1]
......
...@@ -189,7 +189,8 @@ class WebService(object): ...@@ -189,7 +189,8 @@ class WebService(object):
def _launch_local_predictor(self, gpu): def _launch_local_predictor(self, gpu):
from paddle_serving_app.local_predict import Debugger from paddle_serving_app.local_predict import Debugger
self.client = Debugger() self.client = Debugger()
self.client.load_model_config("{}".format(self.model_config), gpu=gpu, profile=False) self.client.load_model_config(
"{}".format(self.model_config), gpu=gpu, profile=False)
def run_web_service(self): def run_web_service(self):
self.app_instance.run(host="0.0.0.0", self.app_instance.run(host="0.0.0.0",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册