提交 c67d04ad 编写于 作者: W wangjiawei04

adapt to mobile model

上级 e11b8aff
...@@ -41,7 +41,7 @@ class TextDetectorHelper(TextDetector): ...@@ -41,7 +41,7 @@ class TextDetectorHelper(TextDetector):
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
self.fetch = ["sigmoid_0.tmp_0", "tmp_2"] self.fetch = ["sigmoid_0.tmp_0", "tmp_2"]
elif self.det_algorithm == "DB": elif self.det_algorithm == "DB":
self.fetch = ["sigmoid_0.tmp_0"] self.fetch = ["save_infer_model/scale_0.tmp_0"]
def preprocess(self, img): def preprocess(self, img):
img = img.copy() img = img.copy()
......
...@@ -41,7 +41,7 @@ class TextDetectorHelper(TextDetector): ...@@ -41,7 +41,7 @@ class TextDetectorHelper(TextDetector):
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
self.fetch = ["sigmoid_0.tmp_0", "tmp_2"] self.fetch = ["sigmoid_0.tmp_0", "tmp_2"]
elif self.det_algorithm == "DB": elif self.det_algorithm == "DB":
self.fetch = ["sigmoid_0.tmp_0"] self.fetch = ["save_infer_model/scale_0.tmp_0"]
def preprocess(self, img): def preprocess(self, img):
im, ratio_list = self.preprocess_op(img) im, ratio_list = self.preprocess_op(img)
......
...@@ -14,7 +14,7 @@ def read_params(): ...@@ -14,7 +14,7 @@ def read_params():
#params for text detector #params for text detector
cfg.det_algorithm = "DB" cfg.det_algorithm = "DB"
cfg.det_model_dir = "./det_mv_server/" cfg.det_model_dir = "./det_infer_server/"
cfg.det_max_side_len = 960 cfg.det_max_side_len = 960
#DB parmas #DB parmas
...@@ -29,7 +29,7 @@ def read_params(): ...@@ -29,7 +29,7 @@ def read_params():
#params for text recognizer #params for text recognizer
cfg.rec_algorithm = "CRNN" cfg.rec_algorithm = "CRNN"
cfg.rec_model_dir = "./ocr_rec_server/" cfg.rec_model_dir = "./rec_infer_server/"
cfg.rec_image_shape = "3, 32, 320" cfg.rec_image_shape = "3, 32, 320"
cfg.rec_char_type = 'ch' cfg.rec_char_type = 'ch'
...@@ -41,7 +41,7 @@ def read_params(): ...@@ -41,7 +41,7 @@ def read_params():
#params for text classifier #params for text classifier
cfg.use_angle_cls = True cfg.use_angle_cls = True
cfg.cls_model_dir = "./ocr_clas_server/" cfg.cls_model_dir = "./cls_infer_server/"
cfg.cls_image_shape = "3, 48, 192" cfg.cls_image_shape = "3, 48, 192"
cfg.label_list = ['0', '180'] cfg.label_list = ['0', '180']
cfg.cls_batch_num = 30 cfg.cls_batch_num = 30
......
...@@ -36,7 +36,7 @@ class TextRecognizerHelper(TextRecognizer): ...@@ -36,7 +36,7 @@ class TextRecognizerHelper(TextRecognizer):
def __init__(self, args): def __init__(self, args):
super(TextRecognizerHelper, self).__init__(args) super(TextRecognizerHelper, self).__init__(args)
if self.loss_type == "ctc": if self.loss_type == "ctc":
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] self.fetch = ["save_infer_model/scale_0.tmp_0", "save_infer_model/scale_1.tmp_0"]
def preprocess(self, img_list): def preprocess(self, img_list):
img_num = len(img_list) img_num = len(img_list)
...@@ -83,8 +83,8 @@ class TextRecognizerHelper(TextRecognizer): ...@@ -83,8 +83,8 @@ class TextRecognizerHelper(TextRecognizer):
if self.loss_type == "ctc": if self.loss_type == "ctc":
rec_idx_batch = outputs[0] rec_idx_batch = outputs[0]
predict_batch = outputs[1] predict_batch = outputs[1]
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"] rec_idx_lod = args["save_infer_model/scale_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"] predict_lod = args["save_infer_model/scale_1.tmp_0.lod"]
indices = args["indices"] indices = args["indices"]
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1) rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1): for rno in range(len(rec_idx_lod) - 1):
......
...@@ -35,7 +35,7 @@ class TextRecognizerHelper(TextRecognizer): ...@@ -35,7 +35,7 @@ class TextRecognizerHelper(TextRecognizer):
def __init__(self, args): def __init__(self, args):
super(TextRecognizerHelper, self).__init__(args) super(TextRecognizerHelper, self).__init__(args)
if self.loss_type == "ctc": if self.loss_type == "ctc":
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] self.fetch = ["save_infer_model/scale_0.tmp_0", "save_infer_model/scale_1.tmp_0"]
def preprocess(self, img_list): def preprocess(self, img_list):
img_num = len(img_list) img_num = len(img_list)
...@@ -88,8 +88,8 @@ class TextRecognizerHelper(TextRecognizer): ...@@ -88,8 +88,8 @@ class TextRecognizerHelper(TextRecognizer):
if self.loss_type == "ctc": if self.loss_type == "ctc":
rec_idx_batch = outputs[0] rec_idx_batch = outputs[0]
predict_batch = outputs[1] predict_batch = outputs[1]
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"] rec_idx_lod = args["save_infer_model/scale_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"] predict_lod = args["save_infer_model/scale_1.tmp_0.lod"]
indices = args["indices"] indices = args["indices"]
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1) rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1): for rno in range(len(rec_idx_lod) - 1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册