未验证 提交 7dccfe57 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

improve system prediction and remove some hard code (#4643)

* fix center yaml

* rm init_center param

* fix typo

* improve pred system
上级 1c2c2698
...@@ -62,8 +62,7 @@ Loss: ...@@ -62,8 +62,7 @@ Loss:
weight: 0.05 weight: 0.05
num_classes: 6625 num_classes: 6625
feat_dim: 96 feat_dim: 96
init_center: false center_file_path:
center_file_path: "./train_center.pkl"
# you can also try to add ace loss on your own dataset # you can also try to add ace loss on your own dataset
# - ACELoss: # - ACELoss:
# weight: 0.1 # weight: 0.1
......
...@@ -33,8 +33,8 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训 ...@@ -33,8 +33,8 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训
|模型名称|模型简介|配置文件|推理模型大小|下载地址| |模型名称|模型简介|配置文件|推理模型大小|下载地址|
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
|ch_PP-OCRv2_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)| 3M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)| |ch_PP-OCRv2_det_slim|【最新】slim量化+蒸馏版超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)| 3M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
|ch_PP-OCRv2_det|【最新】原始超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)| |ch_PP-OCRv2_det|【最新】原始超轻量模型,支持中英文、多语种文本检测|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)|
|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)| 2.6M |[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar)| |ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)| 2.6M |[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar)|
|ch_ppocr_mobile_v2.0_det|原始超轻量模型,支持中英文、多语种文本检测|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)| |ch_ppocr_mobile_v2.0_det|原始超轻量模型,支持中英文、多语种文本检测|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|3M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)|
|ch_ppocr_server_v2.0_det|通用模型,支持中英文、多语种文本检测,比超轻量模型更大,但效果更好|[ch_det_res18_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml)|47M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar)| |ch_ppocr_server_v2.0_det|通用模型,支持中英文、多语种文本检测,比超轻量模型更大,但效果更好|[ch_det_res18_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml)|47M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar)|
......
...@@ -29,8 +29,8 @@ Relationship of the above models is as follows. ...@@ -29,8 +29,8 @@ Relationship of the above models is as follows.
|model name|description|config|model size|download| |model name|description|config|model size|download|
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
|ch_PP-OCRv2_det_slim|[New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)| |ch_PP-OCRv2_det_slim|[New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
|ch_PP-OCRv2_det|[New] Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)| |ch_PP-OCRv2_det|[New] Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)|
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|2.6M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar)| |ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|2.6M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar)|
|ch_ppocr_mobile_v2.0_det|Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)| |ch_ppocr_mobile_v2.0_det|Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)|
|ch_ppocr_server_v2.0_det|General model, which is larger than the lightweight model, but achieved better performance|[ch_det_res18_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml)|47M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar)| |ch_ppocr_server_v2.0_det|General model, which is larger than the lightweight model, but achieved better performance|[ch_det_res18_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml)|47M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar)|
......
...@@ -30,21 +30,17 @@ class CenterLoss(nn.Layer): ...@@ -30,21 +30,17 @@ class CenterLoss(nn.Layer):
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
""" """
def __init__(self, def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
num_classes=6625,
feat_dim=96,
init_center=False,
center_file_path=None):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.feat_dim = feat_dim self.feat_dim = feat_dim
self.centers = paddle.randn( self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype("float64") shape=[self.num_classes, self.feat_dim]).astype("float64")
if init_center: if center_file_path is not None:
assert os.path.exists( assert os.path.exists(
center_file_path center_file_path
), f"center path({center_file_path}) must exist when init_center is set as True." ), f"center path({center_file_path}) must exist when it is not None."
with open(center_file_path, 'rb') as f: with open(center_file_path, 'rb') as f:
char_dict = pickle.load(f) char_dict = pickle.load(f)
for key in char_dict.keys(): for key in char_dict.keys():
......
...@@ -49,11 +49,19 @@ class TextSystem(object): ...@@ -49,11 +49,19 @@ class TextSystem(object):
if self.use_angle_cls: if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args) self.text_classifier = predict_cls.TextClassifier(args)
def print_draw_crop_rec_res(self, img_crop_list, rec_res): self.args = args
self.crop_image_res_index = 0
def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
os.makedirs(output_dir, exist_ok=True)
bbox_num = len(img_crop_list) bbox_num = len(img_crop_list)
for bno in range(bbox_num): for bno in range(bbox_num):
cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) cv2.imwrite(
logger.info(bno, rec_res[bno]) os.path.join(output_dir,
f"mg_crop_{bno+self.crop_image_res_index}.jpg"),
img_crop_list[bno])
logger.debug(f"{bno}, {rec_res[bno]}")
self.crop_image_res_index += bbox_num
def __call__(self, img, cls=True): def __call__(self, img, cls=True):
ori_im = img.copy() ori_im = img.copy()
...@@ -80,7 +88,9 @@ class TextSystem(object): ...@@ -80,7 +88,9 @@ class TextSystem(object):
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format( logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse)) len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res) if self.args.save_crop_res:
self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
rec_res)
filter_boxes, filter_rec_res = [], [] filter_boxes, filter_rec_res = [], []
for box, rec_reuslt in zip(dt_boxes, rec_res): for box, rec_reuslt in zip(dt_boxes, rec_res):
text, score = rec_reuslt text, score = rec_reuslt
...@@ -135,17 +145,17 @@ def main(args): ...@@ -135,17 +145,17 @@ def main(args):
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.debug("error in loading image:{}".format(image_file))
continue continue
starttime = time.time() starttime = time.time()
dt_boxes, rec_res = text_sys(img) dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime elapse = time.time() - starttime
total_time += elapse total_time += elapse
logger.info( logger.debug(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res: for text, score in rec_res:
logger.info("{}, {:.3f}".format(text, score)) logger.debug("{}, {:.3f}".format(text, score))
if is_visualize: if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
...@@ -160,19 +170,17 @@ def main(args): ...@@ -160,19 +170,17 @@ def main(args):
scores, scores,
drop_score=drop_score, drop_score=drop_score,
font_path=font_path) font_path=font_path)
draw_img_save = "./inference_results/" draw_img_save_dir = args.draw_img_save_dir
if not os.path.exists(draw_img_save): os.makedirs(draw_img_save_dir, exist_ok=True)
os.makedirs(draw_img_save)
if flag: if flag:
image_file = image_file[:-3] + "png" image_file = image_file[:-3] + "png"
cv2.imwrite( cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)), os.path.join(draw_img_save_dir, os.path.basename(image_file)),
draw_img[:, :, ::-1]) draw_img[:, :, ::-1])
logger.info("The visualized image saved in {}".format( logger.debug("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file)))) os.path.join(draw_img_save_dir, os.path.basename(image_file))))
logger.info("The predict total time is {}".format(time.time() - _st)) logger.info("The predict total time is {}".format(time.time() - _st))
logger.info("\nThe predict total time is {}".format(total_time))
if args.benchmark: if args.benchmark:
text_sys.text_detector.autolog.report() text_sys.text_detector.autolog.report()
text_sys.text_recognizer.autolog.report() text_sys.text_recognizer.autolog.report()
......
...@@ -110,7 +110,13 @@ def init_args(): ...@@ -110,7 +110,13 @@ def init_args():
parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--cpu_threads", type=int, default=10) parser.add_argument("--cpu_threads", type=int, default=10)
parser.add_argument("--use_pdserving", type=str2bool, default=False) parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--warmup", type=str2bool, default=True) parser.add_argument("--warmup", type=str2bool, default=False)
#
parser.add_argument(
"--draw_img_save_dir", type=str, default="./inference_results")
parser.add_argument("--save_crop_res", type=str2bool, default=False)
parser.add_argument("--crop_res_save_dir", type=str, default="./output")
# multi-process # multi-process
parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--use_mp", type=str2bool, default=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册