提交 0c287c41 编写于 作者: W WenmuZhou

python端预测完成

上级 903b102f
...@@ -31,6 +31,8 @@ from ppocr.postprocess import build_post_process ...@@ -31,6 +31,8 @@ from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
logger = get_logger()
class TextClassifier(object): class TextClassifier(object):
def __init__(self, args): def __init__(self, args):
...@@ -147,5 +149,4 @@ def main(args): ...@@ -147,5 +149,4 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
logger = get_logger()
main(utility.parse_args()) main(utility.parse_args())
...@@ -30,6 +30,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif ...@@ -30,6 +30,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
logger = get_logger()
class TextDetector(object): class TextDetector(object):
def __init__(self, args): def __init__(self, args):
...@@ -158,9 +160,7 @@ class TextDetector(object): ...@@ -158,9 +160,7 @@ class TextDetector(object):
if __name__ == "__main__": if __name__ == "__main__":
args = utility.parse_args() args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
logger = get_logger()
text_detector = TextDetector(args) text_detector = TextDetector(args)
count = 0 count = 0
total_time = 0 total_time = 0
......
...@@ -13,12 +13,12 @@ ...@@ -13,12 +13,12 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2 import cv2
import copy
import numpy as np import numpy as np
import math import math
import time import time
...@@ -30,6 +30,8 @@ from ppocr.postprocess import build_post_process ...@@ -30,6 +30,8 @@ from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
logger = get_logger()
class TextRecognizer(object): class TextRecognizer(object):
def __init__(self, args): def __init__(self, args):
...@@ -80,7 +82,7 @@ class TextRecognizer(object): ...@@ -80,7 +82,7 @@ class TextRecognizer(object):
# rec_res = [] # rec_res = []
rec_res = [['', 0.0]] * img_num rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num batch_num = self.rec_batch_num
predict_time = 0 elapse = 0
for beg_img_no in range(0, img_num, batch_num): for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num) end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = [] norm_img_batch = []
...@@ -110,7 +112,9 @@ class TextRecognizer(object): ...@@ -110,7 +112,9 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
preds = outputs[0] preds = outputs[0]
rec_res = self.postprocess_op(preds) rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
elapse = time.time() - starttime elapse = time.time() - starttime
return rec_res, elapse return rec_res, elapse
...@@ -147,5 +151,4 @@ def main(args): ...@@ -147,5 +151,4 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
logger = get_logger()
main(utility.parse_args()) main(utility.parse_args())
...@@ -17,20 +17,17 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -17,20 +17,17 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
import cv2 import cv2
import tools.infer.predict_det as predict_det
import tools.infer.predict_rec as predict_rec
import copy import copy
import numpy as np import numpy as np
import math
import time import time
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from PIL import Image from PIL import Image
import tools.infer.utility as utility
from tools.infer.utility import draw_ocr from tools.infer.utility import draw_ocr
from tools.infer.utility import draw_ocr_box_txt import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
class TextSystem(object): class TextSystem(object):
...@@ -153,11 +150,7 @@ def main(args): ...@@ -153,11 +150,7 @@ def main(args):
scores = [rec_res[i][1] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr( draw_img = draw_ocr(
image, image, boxes, txts, scores, drop_score=drop_score)
boxes,
txts,
scores,
drop_score=drop_score)
draw_img_save = "./inference_results/" draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save): if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save) os.makedirs(draw_img_save)
...@@ -169,4 +162,5 @@ def main(args): ...@@ -169,4 +162,5 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
logger = get_logger()
main(utility.parse_args()) main(utility.parse_args())
...@@ -39,7 +39,8 @@ def parse_args(): ...@@ -39,7 +39,8 @@ def parse_args():
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_max_side_len", type=float, default=960) parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
# DB parmas # DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_thresh", type=float, default=0.3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册