diff --git a/doc/doc_ch/inference_ppocr.md b/doc/doc_ch/inference_ppocr.md index 2b447701c3f10641d66f6bb65488a7b21c2d6450..3e46f17d3a781839dfe5e632f85aabcd03d0fd17 100644 --- a/doc/doc_ch/inference_ppocr.md +++ b/doc/doc_ch/inference_ppocr.md @@ -20,7 +20,7 @@ # 下载超轻量中文检测模型: wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar tar xf ch_PP-OCRv2_det_infer.tar -python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./ch_PP-OCRv2_det_infer.tar/" +python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./ch_PP-OCRv2_det_infer/" ``` diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index d57d6bb152fb3de80909cb04f1b9b10a70ebd0b7..b2d33346fe4f401cc8a02f777605851ceb7329d5 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -33,7 +33,7 @@ ln -sf /train_data/dataset mklink /d /train_data/dataset ``` - + ### 1.1 自定义数据集 下面以通用数据集为例, 介绍如何准备数据集: diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py index 41714dd2a3ae15eeedc62521d97935f68271c598..200a6d0486dbf6f76dd674eb58f641b31a70f31c 100644 --- a/ppocr/losses/rec_nrtr_loss.py +++ b/ppocr/losses/rec_nrtr_loss.py @@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer): log_prb = F.log_softmax(pred, axis=1) non_pad_mask = paddle.not_equal( tgt, paddle.zeros( - tgt.shape, dtype='int64')) + tgt.shape, dtype=tgt.dtype)) loss = -(one_hot * log_prb).sum(axis=1) loss = loss.masked_select(non_pad_mask).mean() else: diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 07efd972008bd37e7fd46549b58c1ce58a48cbc7..c0d8bab5fa689800b4b2b235c94501bc15284346 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode): character_type, use_space_char) def __call__(self, preds, label=None, *args, **kwargs): - if preds.dtype == paddle.int64: - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - if preds[0][0]==2: - preds_idx = preds[:,1:] - else: - preds_idx = preds - if len(preds) == 2: preds_id = preds[0] preds_prob = preds[1] diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index 7bb4c906d298af54ed56e2805f487a2c22d1894b..7d1faa3bd6f881bed60a4d43131b45d7f3a3b9fb 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -51,7 +51,7 @@ def get_image_file_list(img_file): if img_file is None or not os.path.exists(img_file): raise Exception("not found any img file in {}".format(img_file)) - img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'} + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF','webp','ppm'} if os.path.isfile(img_file) and imghdr.what(img_file) in img_end: imgs_lists.append(img_file) elif os.path.isdir(img_file): @@ -77,4 +77,4 @@ def check_and_read_gif(img_path): frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) imgvalue = frame[:, :, ::-1] return imgvalue, True - return None, False \ No newline at end of file + return None, False