提交 7cacfc97 编写于 作者: W wangjiawei04

first commit

上级 30648fbc
...@@ -18,60 +18,111 @@ import cv2 ...@@ -18,60 +18,111 @@ import cv2
import sys import sys
import numpy as np import numpy as np
import os import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
if sys.argv[1] == 'gpu':
from paddle_serving_server_gpu.web_service import WebService
elif sys.argv[1] == 'cpu':
from paddle_serving_server.web_service import WebService
import time import time
import re import re
import base64 import base64
from tools.infer.predict_cls import TextClassifier
import tools.infer.utility as utility
global_args = utility.parse_args()
if global_args.use_gpu:
from paddle_serving_server_gpu.web_service import WebService
else:
from paddle_serving_server.web_service import WebService
class TextClassifierHelper(TextClassifier):
def __init__(self, args):
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
self.cls_batch_num = args.rec_batch_num
self.label_list = args.label_list
self.cls_thresh = args.cls_thresh
self.fetch = [
"save_infer_model/scale_0.tmp_0", "save_infer_model/scale_1.tmp_0"
]
def preprocess(self, img_list):
args = {}
img_num = len(img_list)
args["img_list"] = img_list
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the cls process
indices = np.argsort(np.array(width_list))
args["indices"] = indices
cls_res = [['', 0.0]] * img_num
batch_num = self.cls_batch_num
predict_time = 0
beg_img_no, end_img_no = 0, img_num
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
feed = {"image": norm_img_batch.copy()}
return feed, self.fetch, args
def postprocess(self, outputs, args):
prob_out = outputs[0]
label_out = outputs[1]
indices = args["indices"]
cls_res = [['', 0.0]] * len(label_out)
if len(label_out.shape) != 1:
prob_out, label_out = label_out, prob_out
for rno in range(len(label_out)):
label_idx = label_out[rno]
score = prob_out[rno][label_idx]
label = self.label_list[label_idx]
cls_res[indices[rno]] = [label, score]
if '180' in label and score > self.cls_thresh:
img_list[indices[rno]] = cv2.rotate(img_list[indices[rno]], 1)
return args["img_list"], cls_res
class OCRService(WebService): class OCRService(WebService):
def init_rec(self): def init_rec(self):
self.ocr_reader = OCRReader() self.ocr_reader = OCRReader()
self.text_classifier = TextClassifierHelper(global_args)
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images
img_list = [] img_list = []
for feed_data in feed: for feed_data in feed:
data = base64.b64decode(feed_data["image"].encode('utf8')) data = base64.b64decode(feed_data["image"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img_list.append(im) img_list.append(im)
feed_list = [] feed, fetch, self.tmp_args = self.text_classifier.preprocess(img_list)
max_wh_ratio = 0 return feed, fetch
for i, boximg in enumerate(img_list):
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img}
feed_list.append(feed)
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed_list, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) outputs = [fetch_map[x] for x in self.text_classifier.fetch]
res_lst = [] for x in fetch_map.keys():
for res in rec_res: if ".lod" in x:
res_lst.append(res[0]) self.tmp_args[x] = fetch_map[x]
res = {"res": res_lst} _, rec_res = self.text_classifier.postprocess(outputs, self.tmp_args)
res = {
"pred_text": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res]
}
return res return res
ocr_service = OCRService(name="ocr") if __name__ == "__main__":
ocr_service.load_model_config("ocr_rec_model") ocr_service = OCRService(name="ocr")
ocr_service.init_rec() ocr_service.load_model_config("cls_server")
if sys.argv[1] == 'gpu': ocr_service.init_rec()
ocr_service.set_gpus("0") if global_args.use_gpu:
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) ocr_service.prepare_server(
elif sys.argv[1] == 'cpu': workdir="workdir", port=9292, device="gpu", gpuid=0)
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") else:
ocr_service.run_rpc_service() ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
ocr_service.run_web_service() ocr_service.run_debugger_service()
ocr_service.run_web_service()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
import time
import re
import base64
from tools.infer.predict_cls import TextClassifier
import tools.infer.utility as utility
global_args = utility.parse_args()
if global_args.use_gpu:
from paddle_serving_server_gpu.web_service import WebService
else:
from paddle_serving_server.web_service import WebService
class TextClassifierHelper(TextClassifier):
def __init__(self, args):
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
self.cls_batch_num = args.rec_batch_num
self.label_list = args.label_list
self.cls_thresh = args.cls_thresh
self.fetch = [
"save_infer_model/scale_0.tmp_0", "save_infer_model/scale_1.tmp_0"
]
def preprocess(self, img_list):
args = {}
img_num = len(img_list)
args["img_list"] = img_list
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the cls process
indices = np.argsort(np.array(width_list))
args["indices"] = indices
cls_res = [['', 0.0]] * img_num
batch_num = self.cls_batch_num
predict_time = 0
beg_img_no, end_img_no = 0, img_num
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
if img_num > 1:
feed = [{
"image": norm_img_batch[x]
} for x in range(norm_img_batch.shape[0])]
else:
feed = {"image": norm_img_batch[0]}
return feed, self.fetch, args
def postprocess(self, outputs, args):
prob_out = outputs[0]
label_out = outputs[1]
indices = args["indices"]
cls_res = [['', 0.0]] * len(label_out)
if len(label_out.shape) != 1:
prob_out, label_out = label_out, prob_out
for rno in range(len(label_out)):
label_idx = label_out[rno]
score = prob_out[rno][label_idx]
label = self.label_list[label_idx]
cls_res[indices[rno]] = [label, score]
if '180' in label and score > self.cls_thresh:
img_list[indices[rno]] = cv2.rotate(img_list[indices[rno]], 1)
return args["img_list"], cls_res
class OCRService(WebService):
def init_rec(self):
self.ocr_reader = OCRReader()
self.text_classifier = TextClassifierHelper(global_args)
def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images
img_list = []
for feed_data in feed:
data = base64.b64decode(feed_data["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img_list.append(im)
feed, fetch, self.tmp_args = self.text_classifier.preprocess(img_list)
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
outputs = [fetch_map[x] for x in self.text_classifier.fetch]
for x in fetch_map.keys():
if ".lod" in x:
self.tmp_args[x] = fetch_map[x]
_, rec_res = self.text_classifier.postprocess(outputs, self.tmp_args)
res = {
"direction": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res]
}
return res
if __name__ == "__main__":
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config(global_args.cls_model_dir)
ocr_service.init_rec()
if global_args.use_gpu:
ocr_service.prepare_server(
workdir="workdir", port=9292, device="gpu", gpuid=0)
else:
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
ocr_service.run_rpc_service()
ocr_service.run_web_service()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import requests
import json
import cv2
import base64
import os, sys
import time
def cv2_to_base64(image):
#data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(image).decode(
'utf8') #data.tostring()).decode('utf8')
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:9292/ocr/prediction"
test_img_dir = "../../doc/imgs_words/ch/"
for img_file in os.listdir(test_img_dir):
with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data1 = file.read()
image = cv2_to_base64(image_data1)
data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
break
...@@ -17,63 +17,91 @@ import cv2 ...@@ -17,63 +17,91 @@ import cv2
import sys import sys
import numpy as np import numpy as np
import os import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
if sys.argv[1] == 'gpu':
from paddle_serving_server_gpu.web_service import WebService
elif sys.argv[1] == 'cpu':
from paddle_serving_server.web_service import WebService
import time import time
import re import re
import base64 import base64
from tools.infer.predict_det import TextDetector
import tools.infer.utility as utility
global_args = utility.parse_args()
if global_args.use_gpu:
from paddle_serving_server_gpu.web_service import WebService
else:
from paddle_serving_server.web_service import WebService
class TextDetectorHelper(TextDetector):
def __init__(self, args):
super(TextDetectorHelper, self).__init__(args)
if self.det_algorithm == "SAST":
self.fetch = [
"bn_f_border4.output.tmp_2", "bn_f_tco4.output.tmp_2",
"bn_f_tvo4.output.tmp_2", "sigmoid_0.tmp_0"
]
elif self.det_algorithm == "EAST":
self.fetch = ["sigmoid_0.tmp_0", "tmp_2"]
elif self.det_algorithm == "DB":
self.fetch = ["sigmoid_0.tmp_0"]
def preprocess(self, img):
img = img.copy()
im, ratio_list = self.preprocess_op(img)
if im is None:
return None, 0
return {
"image": im.copy()
}, self.fetch, {
"ratio_list": [ratio_list],
"ori_im": img
}
def postprocess(self, outputs, args):
outs_dict = {}
if self.det_algorithm == "EAST":
outs_dict['f_geo'] = outputs[0]
outs_dict['f_score'] = outputs[1]
elif self.det_algorithm == 'SAST':
outs_dict['f_border'] = outputs[0]
outs_dict['f_score'] = outputs[1]
outs_dict['f_tco'] = outputs[2]
outs_dict['f_tvo'] = outputs[3]
else:
outs_dict['maps'] = outputs[0]
dt_boxes_list = self.postprocess_op(outs_dict, args["ratio_list"])
dt_boxes = dt_boxes_list[0]
if self.det_algorithm == "SAST" and self.det_sast_polygon:
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes,
args["ori_im"].shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, args["ori_im"].shape)
return dt_boxes
class OCRService(WebService): class DetService(WebService):
def init_det(self): def init_det(self):
self.det_preprocess = Sequential([ self.text_detector = TextDetectorHelper(global_args)
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.filter_func = FilterBoxes(10, 10)
self.post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8')) data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
self.ori_h, self.ori_w, _ = im.shape feed, fetch, self.tmp_args = self.text_detector.preprocess(im)
det_img = self.det_preprocess(im) return feed, fetch
_, self.new_h, self.new_w = det_img.shape
return {"image": det_img[np.newaxis, :].copy()}, ["concat_1.tmp_0"]
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"] outputs = [fetch_map[x] for x in fetch]
ratio_list = [ res = self.text_detector.postprocess(outputs, self.tmp_args)
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w return {"boxes": res.tolist()}
]
dt_boxes_list = self.post_func(det_out, [ratio_list])
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
return {"dt_boxes": dt_boxes.tolist()}
ocr_service = OCRService(name="ocr") if __name__ == "__main__":
ocr_service.load_model_config("ocr_det_model") ocr_service = DetService(name="ocr")
ocr_service.init_det() ocr_service.load_model_config("serving_server_dir")
if sys.argv[1] == 'gpu': ocr_service.init_det()
ocr_service.set_gpus("0") if global_args.use_gpu:
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) ocr_service.prepare_server(
ocr_service.run_debugger_service(gpu=True) workdir="workdir", port=9292, device="gpu", gpuid=0)
elif sys.argv[1] == 'cpu': else:
ocr_service.prepare_server(workdir="workdir", port=9292) ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
ocr_service.run_debugger_service() ocr_service.run_debugger_service()
ocr_service.init_det() ocr_service.run_web_service()
ocr_service.run_web_service()
...@@ -17,62 +17,91 @@ import cv2 ...@@ -17,62 +17,91 @@ import cv2
import sys import sys
import numpy as np import numpy as np
import os import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
if sys.argv[1] == 'gpu':
from paddle_serving_server_gpu.web_service import WebService
elif sys.argv[1] == 'cpu':
from paddle_serving_server.web_service import WebService
import time import time
import re import re
import base64 import base64
from tools.infer.predict_det import TextDetector
import tools.infer.utility as utility
global_args = utility.parse_args()
if global_args.use_gpu:
from paddle_serving_server_gpu.web_service import WebService
else:
from paddle_serving_server.web_service import WebService
class TextDetectorHelper(TextDetector):
def __init__(self, args):
super(TextDetectorHelper, self).__init__(args)
if self.det_algorithm == "SAST":
self.fetch = [
"bn_f_border4.output.tmp_2", "bn_f_tco4.output.tmp_2",
"bn_f_tvo4.output.tmp_2", "sigmoid_0.tmp_0"
]
elif self.det_algorithm == "EAST":
self.fetch = ["sigmoid_0.tmp_0", "tmp_2"]
elif self.det_algorithm == "DB":
self.fetch = ["sigmoid_0.tmp_0"]
def preprocess(self, img):
im, ratio_list = self.preprocess_op(img)
if im is None:
return None, 0
return {
"image": im[0]
}, self.fetch, {
"ratio_list": [ratio_list],
"ori_im": img
}
def postprocess(self, outputs, args):
outs_dict = {}
if self.det_algorithm == "EAST":
outs_dict['f_geo'] = outputs[0]
outs_dict['f_score'] = outputs[1]
elif self.det_algorithm == 'SAST':
outs_dict['f_border'] = outputs[0]
outs_dict['f_score'] = outputs[1]
outs_dict['f_tco'] = outputs[2]
outs_dict['f_tvo'] = outputs[3]
else:
outs_dict['maps'] = outputs[0]
dt_boxes_list = self.postprocess_op(outs_dict, args["ratio_list"])
dt_boxes = dt_boxes_list[0]
if self.det_algorithm == "SAST" and self.det_sast_polygon:
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes,
args["ori_im"].shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, args["ori_im"].shape)
return dt_boxes
class OCRService(WebService): class DetService(WebService):
def init_det(self): def init_det(self):
self.det_preprocess = Sequential([ self.text_detector = TextDetectorHelper(global_args)
ResizeByFactor(32, 960), Div(255), print("init finish")
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.filter_func = FilterBoxes(10, 10)
self.post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8')) data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
self.ori_h, self.ori_w, _ = im.shape feed, fetch, self.tmp_args = self.text_detector.preprocess(im)
det_img = self.det_preprocess(im) return feed, fetch
_, self.new_h, self.new_w = det_img.shape
print(det_img)
return {"image": det_img}, ["concat_1.tmp_0"]
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"] outputs = [fetch_map[x] for x in fetch]
ratio_list = [ res = self.text_detector.postprocess(outputs, self.tmp_args)
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w return {"boxes": res.tolist()}
]
dt_boxes_list = self.post_func(det_out, [ratio_list])
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
return {"dt_boxes": dt_boxes.tolist()}
ocr_service = OCRService(name="ocr") if __name__ == "__main__":
ocr_service.load_model_config("ocr_det_model") ocr_service = DetService(name="ocr")
if sys.argv[1] == 'gpu': ocr_service.load_model_config("serving_server_dir")
ocr_service.set_gpus("0") ocr_service.init_det()
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) if global_args.use_gpu:
elif sys.argv[1] == 'cpu': ocr_service.prepare_server(
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") workdir="workdir", port=9292, device="gpu", gpuid=0)
ocr_service.init_det() else:
ocr_service.run_rpc_service() ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
ocr_service.run_web_service() ocr_service.run_rpc_service()
ocr_service.run_web_service()
...@@ -18,97 +18,107 @@ import cv2 ...@@ -18,97 +18,107 @@ import cv2
import sys import sys
import numpy as np import numpy as np
import os import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
if sys.argv[1] == 'gpu':
from paddle_serving_server_gpu.web_service import WebService
elif sys.argv[1] == 'cpu':
from paddle_serving_server.web_service import WebService
from paddle_serving_app.local_predict import Debugger
import time import time
import re import re
import base64 import base64
from clas_local_server import TextClassifierHelper
from det_local_server import TextDetectorHelper
from rec_local_server import TextRecognizerHelper
import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem, sorted_boxes
from paddle_serving_app.local_predict import Debugger
import copy
global_args = utility.parse_args()
if global_args.use_gpu:
from paddle_serving_server_gpu.web_service import WebService
else:
from paddle_serving_server.web_service import WebService
class OCRService(WebService):
def init_det_debugger(self, det_model_config): class TextSystemHelper(TextSystem):
self.det_preprocess = Sequential([ def __init__(self, args):
ResizeByFactor(32, 960), Div(255), self.text_detector = TextDetectorHelper(args)
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose( self.text_recognizer = TextRecognizerHelper(args)
(2, 0, 1)) self.use_angle_cls = args.use_angle_cls
]) if self.use_angle_cls:
self.clas_client = Debugger()
self.clas_client.load_model_config(
"ocr_clas_server", gpu=True, profile=False)
self.text_classifier = TextClassifierHelper(args)
self.det_client = Debugger() self.det_client = Debugger()
if sys.argv[1] == 'gpu': self.det_client.load_model_config(
self.det_client.load_model_config( "serving_server_dir", gpu=True, profile=False)
det_model_config, gpu=True, profile=False) self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
elif sys.argv[1] == 'cpu':
self.det_client.load_model_config( def preprocess(self, img):
det_model_config, gpu=False, profile=False) feed, fetch, self.tmp_args = self.text_detector.preprocess(img)
self.ocr_reader = OCRReader() fetch_map = self.det_client.predict(feed, fetch)
print("det fetch_map", fetch_map)
outputs = [fetch_map[x] for x in fetch]
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
if dt_boxes is None:
return None, None
img_crop_list = []
dt_boxes = sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(img, tmp_box)
img_crop_list.append(img_crop)
if self.use_angle_cls:
feed, fetch, self.tmp_args = self.text_classifier.preprocess(
img_crop_list)
fetch_map = self.clas_client.predict(feed, fetch)
outputs = [fetch_map[x] for x in self.text_classifier.fetch]
for x in fetch_map.keys():
if ".lod" in x:
self.tmp_args[x] = fetch_map[x]
img_crop_list, _ = self.text_classifier.postprocess(outputs,
self.tmp_args)
feed, fetch, self.tmp_args = self.text_recognizer.preprocess(
img_crop_list)
return feed, self.fetch, self.tmp_args
def postprocess(self, outputs, args):
return self.text_recognizer.postprocess(outputs, args)
class OCRService(WebService):
def init_rec(self):
args = utility.parse_args()
self.text_system = TextSystemHelper(args)
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images
print("start preprocess")
data = base64.b64decode(feed[0]["image"].encode('utf8')) data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
ori_h, ori_w, _ = im.shape feed, fetch, self.tmp_args = self.text_system.preprocess(im)
det_img = self.det_preprocess(im) print("ocr preprocess done")
_, new_h, new_w = det_img.shape
det_img = det_img[np.newaxis, :]
det_img = det_img.copy()
det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"])
filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
sorted_boxes = SortedBoxes()
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes)
get_rotate_crop_image = GetRotateCropImage()
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
if len(img_list) == 0:
return [], []
_, w, h = self.ocr_reader.resize_norm_img(img_list[0],
max_wh_ratio).shape
imgs = np.zeros((len(img_list), 3, w, h)).astype('float32')
for id, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[id] = norm_img
feed = {"image": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed, fetch return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) outputs = [fetch_map[x] for x in self.text_system.fetch]
res_lst = [] for x in fetch_map.keys():
for res in rec_res: if ".lod" in x:
res_lst.append(res[0]) self.tmp_args[x] = fetch_map[x]
res = {"res": res_lst} rec_res = self.text_system.postprocess(outputs, self.tmp_args)
res = {
"pred_text": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res]
}
return res return res
ocr_service = OCRService(name="ocr") if __name__ == "__main__":
ocr_service.load_model_config("ocr_rec_model") ocr_service = OCRService(name="ocr")
ocr_service.init_det_debugger(det_model_config="ocr_det_model") ocr_service.load_model_config("ocr_rec_model")
if sys.argv[1] == 'gpu': ocr_service.init_rec()
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) if global_args.use_gpu:
ocr_service.run_debugger_service(gpu=True) ocr_service.prepare_server(
elif sys.argv[1] == 'cpu': workdir="workdir", port=9292, device="gpu", gpuid=0)
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") else:
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
ocr_service.run_debugger_service() ocr_service.run_debugger_service()
ocr_service.run_web_service() ocr_service.run_web_service()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
import time
import re
import base64
from clas_rpc_server import TextClassifierHelper
from det_rpc_server import TextDetectorHelper
from rec_rpc_server import TextRecognizerHelper
import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem
import copy
global_args = utility.parse_args()
if global_args.use_gpu:
from paddle_serving_server_gpu.web_service import WebService
else:
from paddle_serving_server.web_service import WebService
class TextSystemHelper(TextSystem):
def __init__(self, args):
self.text_detector = TextDetectorHelper(args)
self.text_recognizer = TextRecognizerHelper(args)
self.use_angle_cls = args.use_angle_cls
if self.use_angle_cls:
self.clas_client = Client()
self.clas_client.load_client_config(
"ocr_clas_client/serving_client_conf.prototxt")
self.clas_client.connect(["127.0.0.1:9294"])
self.text_classifier = TextClassifierHelper(args)
self.det_client = Client()
self.det_client.load_client_config(
"ocr_det_server/serving_client_conf.prototxt")
self.det_client.connect(["127.0.0.1:9293"])
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
def preprocess(self, img):
feed, fetch, self.tmp_args = self.text_detector.preprocess(img)
fetch_map = self.det_client.predict(feed, fetch)
outputs = [fetch_map[x] for x in fetch]
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
if dt_boxes is None:
return None, None
img_crop_list = []
sorted_boxes = SortedBoxes()
dt_boxes = sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(img, tmp_box)
img_crop_list.append(img_crop)
if self.use_angle_cls:
feed, fetch, self.tmp_args = self.text_classifier.preprocess(
img_crop_list)
fetch_map = self.clas_client.predict(feed, fetch)
outputs = [fetch_map[x] for x in self.text_classifier.fetch]
for x in fetch_map.keys():
if ".lod" in x:
self.tmp_args[x] = fetch_map[x]
img_crop_list, _ = self.text_classifier.postprocess(outputs,
self.tmp_args)
feed, fetch, self.tmp_args = self.text_recognizer.preprocess(
img_crop_list)
return feed, self.fetch, self.tmp_args
def postprocess(self, outputs, args):
return self.text_recognizer.postprocess(outputs, args)
class OCRService(WebService):
def init_rec(self):
args = utility.parse_args()
self.text_system = TextSystemHelper(args)
def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images
data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
feed, fetch, self.tmp_args = self.text_system.preprocess(im)
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
outputs = [fetch_map[x] for x in self.text_system.fetch]
for x in fetch_map.keys():
if ".lod" in x:
self.tmp_args[x] = fetch_map[x]
rec_res = self.text_system.postprocess(outputs, self.tmp_args)
res = {
"pred_text": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res]
}
return res
if __name__ == "__main__":
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config(global_args.rec_model_dir)
ocr_service.init_rec()
if global_args.use_gpu:
ocr_service.prepare_server(
workdir="workdir", port=9292, device="gpu", gpuid=0)
else:
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
ocr_service.run_rpc_service()
ocr_service.run_web_service()
...@@ -20,11 +20,13 @@ import base64 ...@@ -20,11 +20,13 @@ import base64
import os, sys import os, sys
import time import time
def cv2_to_base64(image): def cv2_to_base64(image):
#data = cv2.imencode('.jpg', image)[1] #data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(image).decode( return base64.b64encode(image).decode(
'utf8') #data.tostring()).decode('utf8') 'utf8') #data.tostring()).decode('utf8')
headers = {"Content-type": "application/json"} headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:9292/ocr/prediction" url = "http://127.0.0.1:9292/ocr/prediction"
test_img_dir = "../../doc/imgs/" test_img_dir = "../../doc/imgs/"
...@@ -34,4 +36,8 @@ for img_file in os.listdir(test_img_dir): ...@@ -34,4 +36,8 @@ for img_file in os.listdir(test_img_dir):
image = cv2_to_base64(image_data1) image = cv2_to_base64(image_data1)
data = {"feed": [{"image": image}], "fetch": ["res"]} data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data)) r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json()) print(r)
rjson = r.json()
print(rjson)
#for x in rjson["result"]["pred_text"]:
# print(x)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
if sys.argv[1] == 'gpu':
from paddle_serving_server_gpu.web_service import WebService
elif sys.argv[1] == 'cpu':
from paddle_serving_server.web_service import WebService
import time
import re
import base64
class OCRService(WebService):
def init_det_client(self, det_port, det_client_config):
self.det_preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.det_client = Client()
self.det_client.load_client_config(det_client_config)
self.det_client.connect(["127.0.0.1:{}".format(det_port)])
self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
ori_h, ori_w, _ = im.shape
det_img = self.det_preprocess(im)
det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"])
_, new_h, new_w = det_img.shape
filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
sorted_boxes = SortedBoxes()
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes)
get_rotate_crop_image = GetRotateCropImage()
feed_list = []
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img}
feed_list.append(feed)
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed_list, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
res = {"res": res_lst}
return res
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_rec_model")
if sys.argv[1] == 'gpu':
ocr_service.set_gpus("0")
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
elif sys.argv[1] == 'cpu':
ocr_service.prepare_server(workdir="workdir", port=9292)
ocr_service.init_det_client(
det_port=9293,
det_client_config="ocr_det_client/serving_client_conf.prototxt")
ocr_service.run_rpc_service()
ocr_service.run_web_service()
[English](readme_en.md) | 简体中文 [English](readme_en.md) | 简体中文
PaddleOCR提供2种服务部署方式: PaddleOCR提供2种服务部署方式:
- 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",使用方法参考[文档](../hubserving/readme.md) - 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",使用方法参考[文档](../hubserving/readme.md)
- 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",按照本教程使用。 - 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",按照本教程使用。
# Paddle Serving 服务部署 # Paddle Serving 服务部署
...@@ -11,7 +11,7 @@ PaddleOCR提供2种服务部署方式: ...@@ -11,7 +11,7 @@ PaddleOCR提供2种服务部署方式:
### 1. 准备环境 ### 1. 准备环境
我们先安装Paddle Serving相关组件 我们先安装Paddle Serving相关组件
我们推荐用户使用GPU来做Paddle Serving的OCR服务部署 我们推荐用户使用GPU来做Paddle Serving的OCR服务部署
**CUDA版本:9.0** **CUDA版本:9.0**
...@@ -39,7 +39,7 @@ python -m pip install paddle_serving_app paddle_serving_client ...@@ -39,7 +39,7 @@ python -m pip install paddle_serving_app paddle_serving_client
python -m paddle_serving_app.package --get_model ocr_rec python -m paddle_serving_app.package --get_model ocr_rec
tar -xzvf ocr_rec.tar.gz tar -xzvf ocr_rec.tar.gz
python -m paddle_serving_app.package --get_model ocr_det python -m paddle_serving_app.package --get_model ocr_det
tar -xzvf ocr_det.tar.gz tar -xzvf ocr_det.tar.gz
``` ```
执行上述命令会下载`db_crnn_mobile`的模型,如果想要下载规模更大的`db_crnn_server`模型,可以在下载预测模型并解压之后。参考[如何从Paddle保存的预测模型转为Paddle Serving格式可部署的模型](https://github.com/PaddlePaddle/Serving/blob/develop/doc/INFERENCE_TO_SERVING_CN.md) 执行上述命令会下载`db_crnn_mobile`的模型,如果想要下载规模更大的`db_crnn_server`模型,可以在下载预测模型并解压之后。参考[如何从Paddle保存的预测模型转为Paddle Serving格式可部署的模型](https://github.com/PaddlePaddle/Serving/blob/develop/doc/INFERENCE_TO_SERVING_CN.md)
...@@ -72,7 +72,7 @@ feed_var_names, fetch_var_names = inference_model_to_serving( ...@@ -72,7 +72,7 @@ feed_var_names, fetch_var_names = inference_model_to_serving(
``` ```
# cpu,gpu启动二选一,以下是cpu启动 # cpu,gpu启动二选一,以下是cpu启动
python -m paddle_serving_server.serve --model ocr_det_model --port 9293 python -m paddle_serving_server.serve --model ocr_det_model --port 9293
python ocr_web_server.py cpu python ocr_web_server.py cpu
# gpu启动 # gpu启动
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0 python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0
......
English | [简体中文](readme.md) English | [简体中文](readme.md)
PaddleOCR provides 2 service deployment methods: PaddleOCR provides 2 service deployment methods:
- Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please refer to the [tutorial](../hubserving/readme_en.md) for usage. - Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please refer to the [tutorial](../hubserving/readme_en.md) for usage.
- Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please follow this tutorial. - Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please follow this tutorial.
...@@ -37,7 +37,7 @@ You can directly use converted model provided by `paddle_serving_app` for conven ...@@ -37,7 +37,7 @@ You can directly use converted model provided by `paddle_serving_app` for conven
python -m paddle_serving_app.package --get_model ocr_rec python -m paddle_serving_app.package --get_model ocr_rec
tar -xzvf ocr_rec.tar.gz tar -xzvf ocr_rec.tar.gz
python -m paddle_serving_app.package --get_model ocr_det python -m paddle_serving_app.package --get_model ocr_det
tar -xzvf ocr_det.tar.gz tar -xzvf ocr_det.tar.gz
``` ```
Executing the above command will download the `db_crnn_mobile` model, which is in different format with inference model. If you want to use other models for deployment, you can refer to the [tutorial](https://github.com/PaddlePaddle/Serving/blob/develop/doc/INFERENCE_TO_SERVING_CN.md) to convert your inference model to a model which is deployable for Paddle Serving. Executing the above command will download the `db_crnn_mobile` model, which is in different format with inference model. If you want to use other models for deployment, you can refer to the [tutorial](https://github.com/PaddlePaddle/Serving/blob/develop/doc/INFERENCE_TO_SERVING_CN.md) to convert your inference model to a model which is deployable for Paddle Serving.
...@@ -71,7 +71,7 @@ Start the standard version or the fast version service according to your actual ...@@ -71,7 +71,7 @@ Start the standard version or the fast version service according to your actual
``` ```
# start with CPU # start with CPU
python -m paddle_serving_server.serve --model ocr_det_model --port 9293 python -m paddle_serving_server.serve --model ocr_det_model --port 9293
python ocr_web_server.py cpu python ocr_web_server.py cpu
# or, with GPU # or, with GPU
......
...@@ -18,62 +18,159 @@ import cv2 ...@@ -18,62 +18,159 @@ import cv2
import sys import sys
import numpy as np import numpy as np
import os import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
if sys.argv[1] == 'gpu':
from paddle_serving_server_gpu.web_service import WebService
elif sys.argv[1] == 'cpu':
from paddle_serving_server.web_service import WebService
import time import time
import re import re
import base64 import base64
from tools.infer.predict_rec import TextRecognizer
import tools.infer.utility as utility
global_args = utility.parse_args()
if global_args.use_gpu:
from paddle_serving_server_gpu.web_service import WebService
else:
from paddle_serving_server.web_service import WebService
class TextRecognizerHelper(TextRecognizer):
def __init__(self, args):
super(TextRecognizerHelper, self).__init__(args)
if self.loss_type == "ctc":
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
def preprocess(self, img_list):
img_num = len(img_list)
args = {}
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
indices = np.argsort(np.array(width_list))
args["indices"] = indices
predict_time = 0
beg_img_no = 0
end_img_no = img_num
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.loss_type != "srn":
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(img_list[indices[ino]],
self.rec_image_shape, 8, 25,
self.char_ops)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch, axis=0).copy()
feed = {"image": norm_img_batch.copy()}
return feed, self.fetch, args
def postprocess(self, outputs, args):
if self.loss_type == "ctc":
rec_idx_batch = outputs[0]
predict_batch = outputs[1]
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"]
indices = args["indices"]
print("indices", indices, rec_idx_lod)
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp)
beg = predict_lod[rno]
end = predict_lod[rno + 1]
probs = predict_batch[beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res[indices[rno]] = [preds_text, score]
elif self.loss_type == 'srn':
char_num = self.char_ops.get_char_num()
preds = rec_idx_batch.reshape(-1)
elapse = time.time() - starttime
predict_time += elapse
total_preds = preds.copy()
for ino in range(int(len(rec_idx_batch) / self.text_len)):
preds = total_preds[ino * self.text_len:(ino + 1) *
self.text_len]
ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != int(char_num - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
preds = preds[:valid_ind[-1] + 1]
preds_text = self.char_ops.decode(preds)
rec_res[indices[ino]] = [preds_text, score]
else:
for rno in range(len(rec_idx_batch)):
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
if len(end_pos) <= 1:
preds = rec_idx_batch[rno, 1:]
score = np.mean(predict_batch[rno, 1:])
else:
preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]])
preds_text = self.char_ops.decode(preds)
rec_res[indices[rno]] = [preds_text, score]
return rec_res
class OCRService(WebService): class OCRService(WebService):
def init_rec(self): def init_rec(self):
self.ocr_reader = OCRReader() self.ocr_reader = OCRReader()
self.text_recognizer = TextRecognizerHelper(global_args)
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images
img_list = [] img_list = []
for feed_data in feed: for feed_data in feed:
data = base64.b64decode(feed_data["image"].encode('utf8')) data = base64.b64decode(feed_data["image"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img_list.append(im) img_list.append(im)
max_wh_ratio = 0 feed, fetch, self.tmp_args = self.text_recognizer.preprocess(img_list)
for i, boximg in enumerate(img_list):
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
_, w, h = self.ocr_reader.resize_norm_img(img_list[0],
max_wh_ratio).shape
imgs = np.zeros((len(img_list), 3, w, h)).astype('float32')
for i, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[i] = norm_img
feed = {"image": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed, fetch return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) outputs = [fetch_map[x] for x in self.text_recognizer.fetch]
res_lst = [] for x in fetch_map.keys():
for res in rec_res: if ".lod" in x:
res_lst.append(res[0]) self.tmp_args[x] = fetch_map[x]
res = {"res": res_lst} rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
print("rec_res", rec_res)
res = {
"pred_text": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res]
}
return res return res
ocr_service = OCRService(name="ocr") if __name__ == "__main__":
ocr_service.load_model_config("ocr_rec_model") ocr_service = OCRService(name="ocr")
ocr_service.init_rec() ocr_service.load_model_config("ocr_rec_model")
if sys.argv[1] == 'gpu': ocr_service.init_rec()
ocr_service.set_gpus("0") if global_args.use_gpu:
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0) ocr_service.prepare_server(
ocr_service.run_debugger_service(gpu=True) workdir="workdir", port=9292, device="gpu", gpuid=0)
elif sys.argv[1] == 'cpu': else:
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu") ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
ocr_service.run_debugger_service() ocr_service.run_debugger_service()
ocr_service.run_web_service() ocr_service.run_web_service()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
import time
import re
import base64
from tools.infer.predict_rec import TextRecognizer
import tools.infer.utility as utility
global_args = utility.parse_args()
if global_args.use_gpu:
from paddle_serving_server_gpu.web_service import WebService
else:
from paddle_serving_server.web_service import WebService
class TextRecognizerHelper(TextRecognizer):
def __init__(self, args):
super(TextRecognizerHelper, self).__init__(args)
if self.loss_type == "ctc":
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
def preprocess(self, img_list):
img_num = len(img_list)
args = {}
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
indices = np.argsort(np.array(width_list))
args["indices"] = indices
predict_time = 0
beg_img_no = 0
end_img_no = img_num
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.loss_type != "srn":
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(img_list[indices[ino]],
self.rec_image_shape, 8, 25,
self.char_ops)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
if img_num > 1:
feed = [{
"image": norm_img_batch[x]
} for x in range(norm_img_batch.shape[0])]
else:
feed = {"image": norm_img_batch[0]}
return feed, self.fetch, args
def postprocess(self, outputs, args):
if self.loss_type == "ctc":
rec_idx_batch = outputs[0]
predict_batch = outputs[1]
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"]
indices = args["indices"]
print("indices", indices, rec_idx_lod)
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp)
beg = predict_lod[rno]
end = predict_lod[rno + 1]
probs = predict_batch[beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res[indices[rno]] = [preds_text, score]
elif self.loss_type == 'srn':
char_num = self.char_ops.get_char_num()
preds = rec_idx_batch.reshape(-1)
elapse = time.time() - starttime
predict_time += elapse
total_preds = preds.copy()
for ino in range(int(len(rec_idx_batch) / self.text_len)):
preds = total_preds[ino * self.text_len:(ino + 1) *
self.text_len]
ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != int(char_num - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
preds = preds[:valid_ind[-1] + 1]
preds_text = self.char_ops.decode(preds)
rec_res[indices[ino]] = [preds_text, score]
else:
for rno in range(len(rec_idx_batch)):
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
if len(end_pos) <= 1:
preds = rec_idx_batch[rno, 1:]
score = np.mean(predict_batch[rno, 1:])
else:
preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]])
preds_text = self.char_ops.decode(preds)
rec_res[indices[rno]] = [preds_text, score]
return rec_res
class OCRService(WebService):
def init_rec(self):
self.ocr_reader = OCRReader()
self.text_recognizer = TextRecognizerHelper(global_args)
def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images
img_list = []
for feed_data in feed:
data = base64.b64decode(feed_data["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img_list.append(im)
feed, fetch, self.tmp_args = self.text_recognizer.preprocess(img_list)
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
outputs = [fetch_map[x] for x in self.text_recognizer.fetch]
for x in fetch_map.keys():
if ".lod" in x:
self.tmp_args[x] = fetch_map[x]
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
print("rec_res", rec_res)
res = {
"pred_text": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res]
}
return res
if __name__ == "__main__":
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config(global_args.rec_model_dir)
ocr_service.init_rec()
if global_args.use_gpu:
ocr_service.prepare_server(
workdir="workdir", port=9292, device="gpu", gpuid=0)
else:
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
ocr_service.run_rpc_service()
ocr_service.run_web_service()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import requests
import json
import cv2
import base64
import os, sys
import time
def cv2_to_base64(image):
#data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(image).decode(
'utf8') #data.tostring()).decode('utf8')
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:9292/ocr/prediction"
test_img_dir = "../../doc/imgs_words/ch/"
for img_file in os.listdir(test_img_dir):
with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data1 = file.read()
image = cv2_to_base64(image_data1)
data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
break
# 使用Paddle Serving预测推理
阅读本文档之前,请先阅读文档 [基于Python预测引擎推理](./inference.md)
同本地执行预测一样,我们需要保存一份可以用于Paddle Serving的模型。
接下来首先介绍如何将训练的模型转换成Paddle Serving模型,然后将依次介绍文本检测、文本识别以及两者串联基于预测引擎推理。
## 一、训练模型转Serving模型
### 检测模型转Serving模型
下载超轻量级中文检测模型:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar && tar xf ./ch_lite/ch_det_mv3_db.tar -C ./ch_lite/
```
上述模型是以MobileNetV3为backbone训练的DB算法,将训练好的模型转换成Serving模型只需要运行如下命令:
```
# -c后面设置训练算法的yml配置文件
# -o配置可选参数
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python tools/export_serving_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./ch_lite/det_mv3_db/best_accuracy Global.save_inference_dir=./inference/det_db/
```
转Serving模型时,使用的配置文件和训练时使用的配置文件相同。另外,还需要设置配置文件中的`Global.checkpoints``Global.save_inference_dir`参数。 其中`Global.checkpoints`指向训练中保存的模型参数文件,`Global.save_inference_dir`是生成的inference模型要保存的目录。 转换成功后,在`save_inference_dir`目录下有两个文件:
```
inference/det_db/
├── serving_client_dir # 客户端配置文件夹
└── serving_server_dir # 服务端配置文件夹
```
### 识别模型转Serving模型
下载超轻量中文识别模型:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar && tar xf ./ch_lite/ch_rec_mv3_crnn.tar -C ./ch_lite/
```
识别模型转inference模型与检测的方式相同,如下:
```
# -c后面设置训练算法的yml配置文件
# -o配置可选参数
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_serving_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints=./ch_lite/rec_mv3_crnn/best_accuracy \
Global.save_inference_dir=./inference/rec_crnn/
```
**注意:**如果您是在自己的数据集上训练的模型,并且调整了中文字符的字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
转换成功后,在目录下有两个文件:
```
/inference/rec_crnn/
├── serving_client_dir # 客户端配置文件夹
└── serving_server_dir # 服务端配置文件夹
```
### 方向分类模型转Serving模型
下载方向分类模型:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/
```
方向分类模型转inference模型与检测的方式相同,如下:
```
# -c后面设置训练算法的yml配置文件
# -o配置可选参数
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_serving_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \
Global.save_inference_dir=./inference/cls/
```
转换成功后,在目录下有两个文件:
```
/inference/cls/
├── serving_client_dir # 客户端配置文件夹
└── serving_server_dir # 服务端配置文件夹
```
在接下来的教程中,我们将给出推理的demo模型下载链接。
```
wget --no-check-certificate ocr_serving_model_zoo.tar.gz
tar zxf ocr_serving_model_zoo.tar.gz
```
## 二、文本检测模型Serving推理
文本检测模型推理,默认使用DB模型的配置参数。当不使用DB模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
与本地预测不同的是,Serving预测需要一个客户端和一个服务端,因此接下来的教程都是两行代码。所有的
### 1. 超轻量中文检测模型推理
超轻量中文检测模型推理,可以执行如下命令启动服务端:
```
#根据环境只需要启动其中一个就可以
python det_rpc_server.py --use_serving True #标准版,Linux用户
python det_local_server.py --use_serving True #快速版,Windows/Linux用户
```
客户端
```
python det_web_client.py
```
Serving的推测和本地预测不同点在于,客户端发送请求到服务端,服务端需要检测到文字框之后返回框的坐标,此处没有后处理的图片,只能看到坐标值。
### 2. DB文本检测模型推理
首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)),可以使用如下命令进行转换:
```
# -c后面设置训练算法的yml配置文件
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_serving_model.py -c configs/det/det_r50_vd_db.yml -o Global.checkpoints="./models/det_r50_vd_db/best_accuracy" Global.save_inference_dir="./inference/det_db"
```
经过转换之后,会在`./inference/det_db` 目录下出现`serving_server_dir``serving_client_dir`,然后指定`det_model_dir`
## 三、文本识别模型Serving推理
下面将介绍超轻量中文识别模型推理、基于CTC损失的识别模型推理和基于Attention损失的识别模型推理。对于中文文本识别,建议优先选择基于CTC损失的识别模型,实践中也发现基于Attention损失的效果不如基于CTC损失的识别模型。此外,如果训练时修改了文本的字典,请参考下面的自定义文本识别字典的推理。
### 1. 超轻量中文识别模型推理
超轻量中文识别模型推理,可以执行如下命令启动服务端:
```
#根据环境只需要启动其中一个就可以
python rec_rpc_server.py --use_serving True #标准版,Linux用户
python rec_local_server.py --use_serving True #快速版,Windows/Linux用户
```
客户端
```
python rec_web_client.py
```
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
```
{u'result': {u'score': [u'0.89547354'], u'pred_text': ['实力活力']}}
```
## 四、方向分类模型推理
下面将介绍方向分类模型推理。
### 1. 方向分类模型推理
方向分类模型推理, 可以执行如下命令启动服务端:
```
#根据环境只需要启动其中一个就可以
python clas_rpc_server.py --use_serving True #标准版,Linux用户
python clas_local_server.py --use_serving True #快速版,Windows/Linux用户
```
客户端
```
python rec_web_client.py
```
![](../imgs_words/ch/word_4.jpg)
执行命令后,上面图像的预测结果(分类的方向和得分)会打印到屏幕上,示例如下:
```
{u'result': {u'direction': [u'0'], u'score': [u'0.9999963']}}
```
## 五、文本检测、方向分类和文字识别串联Serving推理
### 1. 超轻量中文OCR模型推理
在执行预测时,需要通过参数`image_dir`指定单张图像或者图像集合的路径、参数`det_model_dir`,`cls_model_dir``rec_model_dir`分别指定检测,方向分类和识别的inference模型路径。参数`use_angle_cls`用于控制是否启用方向分类模型。与本地预测不同的是,为了减少网络传输耗时,可视化识别结果目前不做处理,用户收到的是推理得到的文字字段。
执行如下命令启动服务端:
```
#标准版,Linux用户
#GPU用户
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0
python -m paddle_serving_server_gpu.serve --model ocr_cls_model --port 9294 --gpu_id 0
python ocr_rpc_server.py --use_serving True --use_gpu True
#CPU用户
python -m paddle_serving_server.serve --model ocr_det_model --port 9293
python -m paddle_serving_server.serve --model ocr_cls_model --port 9294
python ocr_rpc_server.py --use_serving True --use_gpu False
#快速版,Windows/Linux用户
python ocr_local_server.py --use_serving True
```
客户端
```
python rec_web_client.py
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
)
import program
from paddle import fluid
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.save_load import init_model
from paddle_serving_client.io import save_model
def main():
startup_prog, eval_program, place, config, _ = program.preprocess()
feeded_var_names, target_vars, fetches_var_name = program.build_export(
config, eval_program, startup_prog)
eval_program = eval_program.clone(for_test=True)
exe = fluid.Executor(place)
exe.run(startup_prog)
init_model(config, eval_program, exe)
save_inference_dir = config['Global']['save_inference_dir']
if not os.path.exists(save_inference_dir):
os.makedirs(save_inference_dir)
serving_client_dir = "{}/serving_client_dir".format(save_inference_dir)
serving_server_dir = "{}/serving_server_dir".format(save_inference_dir)
feed_dict = {
x: eval_program.global_block().var(x)
for x in feeded_var_names
}
fetch_dict = {x.name: x for x in target_vars}
save_model(serving_server_dir, serving_client_dir, feed_dict, fetch_dict,
eval_program)
print(
"paddle serving model saved in {}/serving_server_dir and {}/serving_client_dir".
format(save_inference_dir, save_inference_dir))
print("save success, output_name_list:", fetches_var_name)
if __name__ == '__main__':
main()
...@@ -33,8 +33,9 @@ from paddle import fluid ...@@ -33,8 +33,9 @@ from paddle import fluid
class TextClassifier(object): class TextClassifier(object):
def __init__(self, args): def __init__(self, args):
self.predictor, self.input_tensor, self.output_tensors = \ if args.use_serving is False:
utility.create_predictor(args, mode="cls") self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, mode="cls")
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
self.cls_batch_num = args.rec_batch_num self.cls_batch_num = args.rec_batch_num
self.label_list = args.label_list self.label_list = args.label_list
...@@ -103,7 +104,6 @@ class TextClassifier(object): ...@@ -103,7 +104,6 @@ class TextClassifier(object):
label_out = self.output_tensors[1].copy_to_cpu() label_out = self.output_tensors[1].copy_to_cpu()
if len(label_out.shape) != 1: if len(label_out.shape) != 1:
prob_out, label_out = label_out, prob_out prob_out, label_out = label_out, prob_out
elapse = time.time() - starttime elapse = time.time() - starttime
predict_time += elapse predict_time += elapse
for rno in range(len(label_out)): for rno in range(len(label_out)):
......
...@@ -75,9 +75,9 @@ class TextDetector(object): ...@@ -75,9 +75,9 @@ class TextDetector(object):
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
if args.use_gpu is False:
self.predictor, self.input_tensor, self.output_tensors =\ self.predictor, self.input_tensor, self.output_tensors =\
utility.create_predictor(args, mode="det") utility.create_predictor(args, mode="det")
def order_points_clockwise(self, pts): def order_points_clockwise(self, pts):
""" """
......
...@@ -34,8 +34,9 @@ from ppocr.utils.character import CharacterOps ...@@ -34,8 +34,9 @@ from ppocr.utils.character import CharacterOps
class TextRecognizer(object): class TextRecognizer(object):
def __init__(self, args): def __init__(self, args):
self.predictor, self.input_tensor, self.output_tensors =\ if args.use_serving is False:
utility.create_predictor(args, mode="rec") self.predictor, self.input_tensor, self.output_tensors =\
utility.create_predictor(args, mode="rec")
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
...@@ -320,7 +321,7 @@ def main(args): ...@@ -320,7 +321,7 @@ def main(args):
print(e) print(e)
logger.info( logger.info(
"ERROR!!!! \n" "ERROR!!!! \n"
"Please read the FAQhttps://github.com/PaddlePaddle/PaddleOCR#faq \n" "Please read the FAQ: https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: " "If your model has tps module: "
"TPS does not support variable shape.\n" "TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
......
...@@ -37,6 +37,7 @@ def parse_args(): ...@@ -37,6 +37,7 @@ def parse_args():
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--use_serving", type=str2bool, default=False)
# params for text detector # params for text detector
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册