提交 c4720557 编写于 作者: W wangjiawei04

fix all minor bugs

上级 7cacfc97
...@@ -117,7 +117,7 @@ class OCRService(WebService): ...@@ -117,7 +117,7 @@ class OCRService(WebService):
if __name__ == "__main__": if __name__ == "__main__":
ocr_service = OCRService(name="ocr") ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("cls_server") ocr_service.load_model_config(global_args.cls_model_dir)
ocr_service.init_rec() ocr_service.init_rec()
if global_args.use_gpu: if global_args.use_gpu:
ocr_service.prepare_server( ocr_service.prepare_server(
......
...@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir): ...@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
data = {"feed": [{"image": image}], "fetch": ["res"]} data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data)) r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json()) print(r.json())
break
...@@ -96,7 +96,7 @@ class DetService(WebService): ...@@ -96,7 +96,7 @@ class DetService(WebService):
if __name__ == "__main__": if __name__ == "__main__":
ocr_service = DetService(name="ocr") ocr_service = DetService(name="ocr")
ocr_service.load_model_config("serving_server_dir") ocr_service.load_model_config(global_args.det_model_dir)
ocr_service.init_det() ocr_service.init_det()
if global_args.use_gpu: if global_args.use_gpu:
ocr_service.prepare_server( ocr_service.prepare_server(
......
...@@ -79,7 +79,6 @@ class TextDetectorHelper(TextDetector): ...@@ -79,7 +79,6 @@ class TextDetectorHelper(TextDetector):
class DetService(WebService): class DetService(WebService):
def init_det(self): def init_det(self):
self.text_detector = TextDetectorHelper(global_args) self.text_detector = TextDetectorHelper(global_args)
print("init finish")
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8')) data = base64.b64decode(feed[0]["image"].encode('utf8'))
...@@ -96,7 +95,7 @@ class DetService(WebService): ...@@ -96,7 +95,7 @@ class DetService(WebService):
if __name__ == "__main__": if __name__ == "__main__":
ocr_service = DetService(name="ocr") ocr_service = DetService(name="ocr")
ocr_service.load_model_config("serving_server_dir") ocr_service.load_model_config(global_args.det_model_dir)
ocr_service.init_det() ocr_service.init_det()
if global_args.use_gpu: if global_args.use_gpu:
ocr_service.prepare_server( ocr_service.prepare_server(
......
...@@ -44,17 +44,16 @@ class TextSystemHelper(TextSystem): ...@@ -44,17 +44,16 @@ class TextSystemHelper(TextSystem):
if self.use_angle_cls: if self.use_angle_cls:
self.clas_client = Debugger() self.clas_client = Debugger()
self.clas_client.load_model_config( self.clas_client.load_model_config(
"ocr_clas_server", gpu=True, profile=False) global_args.cls_model_dir, gpu=True, profile=False)
self.text_classifier = TextClassifierHelper(args) self.text_classifier = TextClassifierHelper(args)
self.det_client = Debugger() self.det_client = Debugger()
self.det_client.load_model_config( self.det_client.load_model_config(
"serving_server_dir", gpu=True, profile=False) global_args.det_model_dir, gpu=True, profile=False)
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
def preprocess(self, img): def preprocess(self, img):
feed, fetch, self.tmp_args = self.text_detector.preprocess(img) feed, fetch, self.tmp_args = self.text_detector.preprocess(img)
fetch_map = self.det_client.predict(feed, fetch) fetch_map = self.det_client.predict(feed, fetch)
print("det fetch_map", fetch_map)
outputs = [fetch_map[x] for x in fetch] outputs = [fetch_map[x] for x in fetch]
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args) dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
if dt_boxes is None: if dt_boxes is None:
...@@ -90,12 +89,10 @@ class OCRService(WebService): ...@@ -90,12 +89,10 @@ class OCRService(WebService):
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images # TODO: to handle batch rec images
print("start preprocess")
data = base64.b64decode(feed[0]["image"].encode('utf8')) data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
feed, fetch, self.tmp_args = self.text_system.preprocess(im) feed, fetch, self.tmp_args = self.text_system.preprocess(im)
print("ocr preprocess done")
return feed, fetch return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
......
...@@ -25,7 +25,7 @@ from clas_rpc_server import TextClassifierHelper ...@@ -25,7 +25,7 @@ from clas_rpc_server import TextClassifierHelper
from det_rpc_server import TextDetectorHelper from det_rpc_server import TextDetectorHelper
from rec_rpc_server import TextRecognizerHelper from rec_rpc_server import TextRecognizerHelper
import tools.infer.utility as utility import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem from tools.infer.predict_system import TextSystem, sorted_boxes
import copy import copy
global_args = utility.parse_args() global_args = utility.parse_args()
...@@ -48,7 +48,7 @@ class TextSystemHelper(TextSystem): ...@@ -48,7 +48,7 @@ class TextSystemHelper(TextSystem):
self.text_classifier = TextClassifierHelper(args) self.text_classifier = TextClassifierHelper(args)
self.det_client = Client() self.det_client = Client()
self.det_client.load_client_config( self.det_client.load_client_config(
"ocr_det_server/serving_client_conf.prototxt") "det_db_client/serving_client_conf.prototxt")
self.det_client.connect(["127.0.0.1:9293"]) self.det_client.connect(["127.0.0.1:9293"])
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
...@@ -57,10 +57,10 @@ class TextSystemHelper(TextSystem): ...@@ -57,10 +57,10 @@ class TextSystemHelper(TextSystem):
fetch_map = self.det_client.predict(feed, fetch) fetch_map = self.det_client.predict(feed, fetch)
outputs = [fetch_map[x] for x in fetch] outputs = [fetch_map[x] for x in fetch]
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args) dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
print(dt_boxes)
if dt_boxes is None: if dt_boxes is None:
return None, None return None, None
img_crop_list = [] img_crop_list = []
sorted_boxes = SortedBoxes()
dt_boxes = sorted_boxes(dt_boxes) dt_boxes = sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)): for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno]) tmp_box = copy.deepcopy(dt_boxes[bno])
...@@ -70,6 +70,7 @@ class TextSystemHelper(TextSystem): ...@@ -70,6 +70,7 @@ class TextSystemHelper(TextSystem):
feed, fetch, self.tmp_args = self.text_classifier.preprocess( feed, fetch, self.tmp_args = self.text_classifier.preprocess(
img_crop_list) img_crop_list)
fetch_map = self.clas_client.predict(feed, fetch) fetch_map = self.clas_client.predict(feed, fetch)
print(fetch_map)
outputs = [fetch_map[x] for x in self.text_classifier.fetch] outputs = [fetch_map[x] for x in self.text_classifier.fetch]
for x in fetch_map.keys(): for x in fetch_map.keys():
if ".lod" in x: if ".lod" in x:
......
...@@ -36,8 +36,5 @@ for img_file in os.listdir(test_img_dir): ...@@ -36,8 +36,5 @@ for img_file in os.listdir(test_img_dir):
image = cv2_to_base64(image_data1) image = cv2_to_base64(image_data1)
data = {"feed": [{"image": image}], "fetch": ["res"]} data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data)) r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r)
rjson = r.json() rjson = r.json()
print(rjson) print(rjson)
#for x in rjson["result"]["pred_text"]:
# print(x)
...@@ -85,7 +85,6 @@ class TextRecognizerHelper(TextRecognizer): ...@@ -85,7 +85,6 @@ class TextRecognizerHelper(TextRecognizer):
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"] rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"] predict_lod = args["softmax_0.tmp_0.lod"]
indices = args["indices"] indices = args["indices"]
print("indices", indices, rec_idx_lod)
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1) rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1): for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno] beg = rec_idx_lod[rno]
...@@ -155,7 +154,6 @@ class OCRService(WebService): ...@@ -155,7 +154,6 @@ class OCRService(WebService):
if ".lod" in x: if ".lod" in x:
self.tmp_args[x] = fetch_map[x] self.tmp_args[x] = fetch_map[x]
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args) rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
print("rec_res", rec_res)
res = { res = {
"pred_text": [x[0] for x in rec_res], "pred_text": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res] "score": [str(x[1]) for x in rec_res]
......
...@@ -91,7 +91,6 @@ class TextRecognizerHelper(TextRecognizer): ...@@ -91,7 +91,6 @@ class TextRecognizerHelper(TextRecognizer):
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"] rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"] predict_lod = args["softmax_0.tmp_0.lod"]
indices = args["indices"] indices = args["indices"]
print("indices", indices, rec_idx_lod)
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1) rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1): for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno] beg = rec_idx_lod[rno]
...@@ -161,7 +160,6 @@ class OCRService(WebService): ...@@ -161,7 +160,6 @@ class OCRService(WebService):
if ".lod" in x: if ".lod" in x:
self.tmp_args[x] = fetch_map[x] self.tmp_args[x] = fetch_map[x]
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args) rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
print("rec_res", rec_res)
res = { res = {
"pred_text": [x[0] for x in rec_res], "pred_text": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res] "score": [str(x[1]) for x in rec_res]
......
...@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir): ...@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
data = {"feed": [{"image": image}], "fetch": ["res"]} data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data)) r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json()) print(r.json())
break
...@@ -33,7 +33,7 @@ from paddle import fluid ...@@ -33,7 +33,7 @@ from paddle import fluid
class TextClassifier(object): class TextClassifier(object):
def __init__(self, args): def __init__(self, args):
if args.use_serving is False: if args.use_pdserving is False:
self.predictor, self.input_tensor, self.output_tensors = \ self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, mode="cls") utility.create_predictor(args, mode="cls")
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
......
...@@ -75,7 +75,7 @@ class TextDetector(object): ...@@ -75,7 +75,7 @@ class TextDetector(object):
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
if args.use_gpu is False: if args.use_pdserving is False:
self.predictor, self.input_tensor, self.output_tensors =\ self.predictor, self.input_tensor, self.output_tensors =\
utility.create_predictor(args, mode="det") utility.create_predictor(args, mode="det")
......
...@@ -34,7 +34,7 @@ from ppocr.utils.character import CharacterOps ...@@ -34,7 +34,7 @@ from ppocr.utils.character import CharacterOps
class TextRecognizer(object): class TextRecognizer(object):
def __init__(self, args): def __init__(self, args):
if args.use_serving is False: if args.use_pdserving is False:
self.predictor, self.input_tensor, self.output_tensors =\ self.predictor, self.input_tensor, self.output_tensors =\
utility.create_predictor(args, mode="rec") utility.create_predictor(args, mode="rec")
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
......
...@@ -161,7 +161,12 @@ def main(args): ...@@ -161,7 +161,12 @@ def main(args):
scores = [rec_res[i][1] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr( draw_img = draw_ocr(
image, boxes, txts, scores, drop_score=drop_score, font_path=font_path) image,
boxes,
txts,
scores,
drop_score=drop_score,
font_path=font_path)
draw_img_save = "./inference_results/" draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save): if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save) os.makedirs(draw_img_save)
......
...@@ -37,7 +37,7 @@ def parse_args(): ...@@ -37,7 +37,7 @@ def parse_args():
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--use_serving", type=str2bool, default=False) parser.add_argument("--use_pdserving", type=str2bool, default=False)
# params for text detector # params for text detector
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str)
...@@ -73,9 +73,7 @@ def parse_args(): ...@@ -73,9 +73,7 @@ def parse_args():
default="./ppocr/utils/ppocr_keys_v1.txt") default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=str2bool, default=True) parser.add_argument("--use_space_char", type=str2bool, default=True)
parser.add_argument( parser.add_argument(
"--vis_font_path", "--vis_font_path", type=str, default="./doc/simfang.ttf")
type=str,
default="./doc/simfang.ttf")
# params for text classifier # params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False) parser.add_argument("--use_angle_cls", type=str2bool, default=False)
...@@ -230,8 +228,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"): ...@@ -230,8 +228,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
1])**2) 1])**2)
if box_height > 2 * box_width: if box_height > 2 * box_width:
font_size = max(int(box_width * 0.9), 10) font_size = max(int(box_width * 0.9), 10)
font = ImageFont.truetype( font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
font_path, font_size, encoding="utf-8")
cur_y = box[0][1] cur_y = box[0][1]
for c in txt: for c in txt:
char_size = font.getsize(c) char_size = font.getsize(c)
...@@ -240,8 +237,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"): ...@@ -240,8 +237,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
cur_y += char_size[1] cur_y += char_size[1]
else: else:
font_size = max(int(box_height * 0.8), 10) font_size = max(int(box_height * 0.8), 10)
font = ImageFont.truetype( font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
font_path, font_size, encoding="utf-8")
draw_right.text( draw_right.text(
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5) img_left = Image.blend(image, img_left, 0.5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册