diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index c8fc280d80056395bbc841a973004b06844b1214..19d7a69c7fb08a8e7fb36c3043aa211de19b9295 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -28,7 +28,9 @@ PaddleOCR开源的文本检测算法列表: | --- | --- | --- | --- | --- | --- | |SAST|ResNet50_vd|89.63%|78.44%|83.66%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)| -**说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载:[百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi) +**说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载: +* [百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi) +* [Google Drive下载地址](https://drive.google.com/drive/folders/1ll2-XEVyCQLpJjawLDiRlvo_i4BqHCJe?usp=sharing) PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训练/评估中的文本检测部分](./detection.md)。 diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 77b9642e3b880547b1df6620d931689982db6d29..d70f99bb5c5b0bdcb7d39209dfc9a77c56918260 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -31,7 +31,9 @@ On Total-Text dataset, the text detection result is as follows: | --- | --- | --- | --- | --- | --- | |SAST|ResNet50_vd|89.63%|78.44%|83.66%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)| -**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi). +**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from: +* [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi). +* [Google Drive](https://drive.google.com/drive/folders/1ll2-XEVyCQLpJjawLDiRlvo_i4BqHCJe?usp=sharing) For the training guide and use of PaddleOCR text detection algorithms, please refer to the document [Text detection model training/evaluation/prediction](./detection_en.md) diff --git a/paddleocr.py b/paddleocr.py index 7c126261eff1168a1888d72f71fb284e347f9ec9..c3741b264503534ef3e64531c2576273d8ccfd11 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -236,7 +236,9 @@ class PaddleOCR(predict_system.TextSystem): assert lang in model_urls[ 'rec'], 'param lang must in {}, but got {}'.format( model_urls['rec'].keys(), lang) + use_inner_dict = False if postprocess_params.rec_char_dict_path is None: + use_inner_dict = True postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ 'dict_path'] @@ -263,9 +265,9 @@ class PaddleOCR(predict_system.TextSystem): if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) sys.exit(0) - - postprocess_params.rec_char_dict_path = str( - Path(__file__).parent / postprocess_params.rec_char_dict_path) + if use_inner_dict: + postprocess_params.rec_char_dict_path = str( + Path(__file__).parent / postprocess_params.rec_char_dict_path) # init det_model and rec_model super().__init__(postprocess_params) @@ -282,8 +284,13 @@ class PaddleOCR(predict_system.TextSystem): if isinstance(img, list) and det == True: logger.error('When input a list of images, det must be false') exit(0) + if cls == False: + self.use_angle_cls = False + elif cls == True and self.use_angle_cls == False: + logger.warning( + 'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process' + ) - self.use_angle_cls = cls if isinstance(img, str): # download net image if img.startswith('http'): diff --git a/ppocr/data/imaug/randaugment.py b/ppocr/data/imaug/randaugment.py index 0bfac353906535464eaa6637c3edbc7f0c938502..56f114d2f665f9b326e96819ac3d606c87a6e142 100644 --- a/ppocr/data/imaug/randaugment.py +++ b/ppocr/data/imaug/randaugment.py @@ -117,13 +117,16 @@ class RawRandAugment(object): class RandAugment(RawRandAugment): """ RandAugment wrapper to auto fit different img types """ - def __init__(self, *args, **kwargs): + def __init__(self, prob=0.5, *args, **kwargs): + self.prob = prob if six.PY2: super(RandAugment, self).__init__(*args, **kwargs) else: super().__init__(*args, **kwargs) def __call__(self, data): + if np.random.rand() > self.prob: + return data img = data['image'] if not isinstance(img, Image.Image): img = np.ascontiguousarray(img) diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py index 0d222714ff7edebfc717daa81d48ce7424dfbd03..4286d7691d1abcf80c283d1c1ab76f8cd1f4a634 100644 --- a/ppocr/modeling/heads/rec_att_head.py +++ b/ppocr/modeling/heads/rec_att_head.py @@ -38,7 +38,7 @@ class AttentionHead(nn.Layer): return input_ont_hot def forward(self, inputs, targets=None, batch_max_length=25): - batch_size = inputs.shape[0] + batch_size = paddle.shape(inputs)[0] num_steps = batch_max_length hidden = paddle.zeros((batch_size, self.hidden_size)) diff --git a/setup.py b/setup.py index 58f6de48548d494a7fde8528130b8e881bc7620d..70400df484128ba751da5f97503cc7f84e260d86 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ setup( package_dir={'paddleocr': ''}, include_package_data=True, entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]}, - version='2.0.2', + version='2.0.3', install_requires=requirements, license='Apache License 2.0', description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 074172cc947cdc03b21392cf7b109971763f796a..d2592c6c95b0f466ea3ad5b45a35781282c9a492 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -98,10 +98,10 @@ class TextClassifier(object): norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() starttime = time.time() - self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.run() prob_out = self.output_tensors[0].copy_to_cpu() + self.predictor.try_shrink_memory() cls_result = self.postprocess_op(prob_out) elapse += time.time() - starttime for rno in range(len(cls_result)): diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index b14825bdd8bad55b709d84bdf6df6575d90c7d95..f5ea0504f97f3e40853d431061f7086653f2628e 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -180,7 +180,7 @@ class TextDetector(object): preds['maps'] = outputs[0] else: raise NotImplementedError - + self.predictor.try_shrink_memory() post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] if self.det_algorithm == "SAST" and self.det_sast_polygon: diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index b24e57dd973bc0216f2875232bcec6e36ab47e29..1cb6e01b087ff98efb0a57be3cc58a79425fea57 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -237,7 +237,7 @@ class TextRecognizer(object): output = output_tensor.copy_to_cpu() outputs.append(output) preds = outputs[0] - + self.predictor.try_shrink_memory() rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 6cb075e8be639fffbbc5376b2fd8c6ce3597e4aa..9019f003b44d9ecb69ed390fba8cc97d4d074cd5 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -143,7 +143,8 @@ def create_predictor(args, mode, logger): #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) args.rec_batch_num = 1 - # config.enable_memory_optim() + # enable memory optim + config.enable_memory_optim() config.disable_glog_info() config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")