提交 c4720557 · 作者: wangjiawei04

fix all minor bugs

上级 7cacfc97
......@@ -117,7 +117,7 @@ class OCRService(WebService):
if __name__ == "__main__":
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("cls_server")
ocr_service.load_model_config(global_args.cls_model_dir)
ocr_service.init_rec()
if global_args.use_gpu:
ocr_service.prepare_server(
......
......@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
break
......@@ -96,7 +96,7 @@ class DetService(WebService):
if __name__ == "__main__":
ocr_service = DetService(name="ocr")
ocr_service.load_model_config("serving_server_dir")
ocr_service.load_model_config(global_args.det_model_dir)
ocr_service.init_det()
if global_args.use_gpu:
ocr_service.prepare_server(
......
......@@ -79,7 +79,6 @@ class TextDetectorHelper(TextDetector):
class DetService(WebService):
def init_det(self):
self.text_detector = TextDetectorHelper(global_args)
print("init finish")
def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8'))
......@@ -96,7 +95,7 @@ class DetService(WebService):
if __name__ == "__main__":
ocr_service = DetService(name="ocr")
ocr_service.load_model_config("serving_server_dir")
ocr_service.load_model_config(global_args.det_model_dir)
ocr_service.init_det()
if global_args.use_gpu:
ocr_service.prepare_server(
......
......@@ -44,17 +44,16 @@ class TextSystemHelper(TextSystem):
if self.use_angle_cls:
self.clas_client = Debugger()
self.clas_client.load_model_config(
"ocr_clas_server", gpu=True, profile=False)
global_args.cls_model_dir, gpu=True, profile=False)
self.text_classifier = TextClassifierHelper(args)
self.det_client = Debugger()
self.det_client.load_model_config(
"serving_server_dir", gpu=True, profile=False)
global_args.det_model_dir, gpu=True, profile=False)
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
def preprocess(self, img):
feed, fetch, self.tmp_args = self.text_detector.preprocess(img)
fetch_map = self.det_client.predict(feed, fetch)
print("det fetch_map", fetch_map)
outputs = [fetch_map[x] for x in fetch]
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
if dt_boxes is None:
......@@ -90,12 +89,10 @@ class OCRService(WebService):
def preprocess(self, feed=[], fetch=[]):
# TODO: to handle batch rec images
print("start preprocess")
data = base64.b64decode(feed[0]["image"].encode('utf8'))
data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
feed, fetch, self.tmp_args = self.text_system.preprocess(im)
print("ocr preprocess done")
return feed, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
......
......@@ -25,7 +25,7 @@ from clas_rpc_server import TextClassifierHelper
from det_rpc_server import TextDetectorHelper
from rec_rpc_server import TextRecognizerHelper
import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem
from tools.infer.predict_system import TextSystem, sorted_boxes
import copy
global_args = utility.parse_args()
......@@ -48,7 +48,7 @@ class TextSystemHelper(TextSystem):
self.text_classifier = TextClassifierHelper(args)
self.det_client = Client()
self.det_client.load_client_config(
"ocr_det_server/serving_client_conf.prototxt")
"det_db_client/serving_client_conf.prototxt")
self.det_client.connect(["127.0.0.1:9293"])
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
......@@ -57,10 +57,10 @@ class TextSystemHelper(TextSystem):
fetch_map = self.det_client.predict(feed, fetch)
outputs = [fetch_map[x] for x in fetch]
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
print(dt_boxes)
if dt_boxes is None:
return None, None
img_crop_list = []
sorted_boxes = SortedBoxes()
dt_boxes = sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
......@@ -70,6 +70,7 @@ class TextSystemHelper(TextSystem):
feed, fetch, self.tmp_args = self.text_classifier.preprocess(
img_crop_list)
fetch_map = self.clas_client.predict(feed, fetch)
print(fetch_map)
outputs = [fetch_map[x] for x in self.text_classifier.fetch]
for x in fetch_map.keys():
if ".lod" in x:
......
......@@ -36,8 +36,5 @@ for img_file in os.listdir(test_img_dir):
image = cv2_to_base64(image_data1)
data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r)
rjson = r.json()
print(rjson)
#for x in rjson["result"]["pred_text"]:
# print(x)
......@@ -85,7 +85,6 @@ class TextRecognizerHelper(TextRecognizer):
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"]
indices = args["indices"]
print("indices", indices, rec_idx_lod)
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
......@@ -155,7 +154,6 @@ class OCRService(WebService):
if ".lod" in x:
self.tmp_args[x] = fetch_map[x]
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
print("rec_res", rec_res)
res = {
"pred_text": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res]
......
......@@ -91,7 +91,6 @@ class TextRecognizerHelper(TextRecognizer):
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = args["softmax_0.tmp_0.lod"]
indices = args["indices"]
print("indices", indices, rec_idx_lod)
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
......@@ -161,7 +160,6 @@ class OCRService(WebService):
if ".lod" in x:
self.tmp_args[x] = fetch_map[x]
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
print("rec_res", rec_res)
res = {
"pred_text": [x[0] for x in rec_res],
"score": [str(x[1]) for x in rec_res]
......
......@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json())
break
......@@ -33,7 +33,7 @@ from paddle import fluid
class TextClassifier(object):
def __init__(self, args):
if args.use_serving is False:
if args.use_pdserving is False:
self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, mode="cls")
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
......
......@@ -75,7 +75,7 @@ class TextDetector(object):
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
if args.use_gpu is False:
if args.use_pdserving is False:
self.predictor, self.input_tensor, self.output_tensors =\
utility.create_predictor(args, mode="det")
......
......@@ -34,7 +34,7 @@ from ppocr.utils.character import CharacterOps
class TextRecognizer(object):
def __init__(self, args):
if args.use_serving is False:
if args.use_pdserving is False:
self.predictor, self.input_tensor, self.output_tensors =\
utility.create_predictor(args, mode="rec")
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
......
......@@ -161,7 +161,12 @@ def main(args):
scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr(
image, boxes, txts, scores, drop_score=drop_score, font_path=font_path)
image,
boxes,
txts,
scores,
drop_score=drop_score,
font_path=font_path)
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
......
......@@ -37,7 +37,7 @@ def parse_args():
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--use_serving", type=str2bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
# params for text detector
parser.add_argument("--image_dir", type=str)
......@@ -73,9 +73,7 @@ def parse_args():
default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=str2bool, default=True)
parser.add_argument(
"--vis_font_path",
type=str,
default="./doc/simfang.ttf")
"--vis_font_path", type=str, default="./doc/simfang.ttf")
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
......@@ -230,8 +228,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
1])**2)
if box_height > 2 * box_width:
font_size = max(int(box_width * 0.9), 10)
font = ImageFont.truetype(
font_path, font_size, encoding="utf-8")
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
cur_y = box[0][1]
for c in txt:
char_size = font.getsize(c)
......@@ -240,8 +237,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
cur_y += char_size[1]
else:
font_size = max(int(box_height * 0.8), 10)
font = ImageFont.truetype(
font_path, font_size, encoding="utf-8")
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
draw_right.text(
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册