提交 c67d04ad 编写于 作者: W wangjiawei04

adapt to mobile model

上级 e11b8aff
......@@ -41,7 +41,7 @@ class TextDetectorHelper(TextDetector):
elif self.det_algorithm == "EAST":
self.fetch = ["sigmoid_0.tmp_0", "tmp_2"]
elif self.det_algorithm == "DB":
self.fetch = ["sigmoid_0.tmp_0"]
self.fetch = ["save_infer_model/scale_0.tmp_0"]
def preprocess(self, img):
img = img.copy()
......
......@@ -41,7 +41,7 @@ class TextDetectorHelper(TextDetector):
elif self.det_algorithm == "EAST":
self.fetch = ["sigmoid_0.tmp_0", "tmp_2"]
elif self.det_algorithm == "DB":
self.fetch = ["sigmoid_0.tmp_0"]
self.fetch = ["save_infer_model/scale_0.tmp_0"]
def preprocess(self, img):
im, ratio_list = self.preprocess_op(img)
......
......@@ -14,7 +14,7 @@ def read_params():
#params for text detector
cfg.det_algorithm = "DB"
cfg.det_model_dir = "./det_mv_server/"
cfg.det_model_dir = "./det_infer_server/"
cfg.det_max_side_len = 960
#DB parmas
......@@ -29,7 +29,7 @@ def read_params():
#params for text recognizer
cfg.rec_algorithm = "CRNN"
cfg.rec_model_dir = "./ocr_rec_server/"
cfg.rec_model_dir = "./rec_infer_server/"
cfg.rec_image_shape = "3, 32, 320"
cfg.rec_char_type = 'ch'
......@@ -41,7 +41,7 @@ def read_params():
#params for text classifier
cfg.use_angle_cls = True
cfg.cls_model_dir = "./ocr_clas_server/"
cfg.cls_model_dir = "./cls_infer_server/"
cfg.cls_image_shape = "3, 48, 192"
cfg.label_list = ['0', '180']
cfg.cls_batch_num = 30
......
......@@ -36,7 +36,7 @@ class TextRecognizerHelper(TextRecognizer):
def __init__(self, args):
super(TextRecognizerHelper, self).__init__(args)
if self.loss_type == "ctc":
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
self.fetch = ["save_infer_model/scale_0.tmp_0", "save_infer_model/scale_1.tmp_0"]
def preprocess(self, img_list):
img_num = len(img_list)
......@@ -83,8 +83,8 @@ class TextRecognizerHelper(TextRecognizer):
if self.loss_type == "ctc":
rec_idx_batch = outputs[0]
predict_batch = outputs[1]
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"]
rec_idx_lod = args["save_infer_model/scale_0.tmp_0.lod"]
predict_lod = args["save_infer_model/scale_1.tmp_0.lod"]
indices = args["indices"]
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1):
......
......@@ -35,7 +35,7 @@ class TextRecognizerHelper(TextRecognizer):
def __init__(self, args):
super(TextRecognizerHelper, self).__init__(args)
if self.loss_type == "ctc":
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
self.fetch = ["save_infer_model/scale_0.tmp_0", "save_infer_model/scale_1.tmp_0"]
def preprocess(self, img_list):
img_num = len(img_list)
......@@ -88,8 +88,8 @@ class TextRecognizerHelper(TextRecognizer):
if self.loss_type == "ctc":
rec_idx_batch = outputs[0]
predict_batch = outputs[1]
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"]
rec_idx_lod = args["save_infer_model/scale_0.tmp_0.lod"]
predict_lod = args["save_infer_model/scale_1.tmp_0.lod"]
indices = args["indices"]
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册