未验证 提交 afe8ed19 编写于 作者: 天涯古巷's avatar 天涯古巷 提交者: GitHub

Merge branch 'PaddlePaddle:release/2.3' into release/2.3

...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# 下载超轻量中文检测模型: # 下载超轻量中文检测模型:
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar
tar xf ch_PP-OCRv2_det_infer.tar tar xf ch_PP-OCRv2_det_infer.tar
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./ch_PP-OCRv2_det_infer.tar/" python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./ch_PP-OCRv2_det_infer/"
``` ```
......
...@@ -33,7 +33,7 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset ...@@ -33,7 +33,7 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset> mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset>
``` ```
<a name="准备数据集"></a> <a name="自定义数据集"></a>
### 1.1 自定义数据集 ### 1.1 自定义数据集
下面以通用数据集为例, 介绍如何准备数据集: 下面以通用数据集为例, 介绍如何准备数据集:
......
...@@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer): ...@@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
log_prb = F.log_softmax(pred, axis=1) log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal( non_pad_mask = paddle.not_equal(
tgt, paddle.zeros( tgt, paddle.zeros(
tgt.shape, dtype='int64')) tgt.shape, dtype=tgt.dtype))
loss = -(one_hot * log_prb).sum(axis=1) loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean() loss = loss.masked_select(non_pad_mask).mean()
else: else:
......
...@@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
character_type, use_space_char) character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if preds.dtype == paddle.int64:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if preds[0][0]==2:
preds_idx = preds[:,1:]
else:
preds_idx = preds
if len(preds) == 2: if len(preds) == 2:
preds_id = preds[0] preds_id = preds[0]
preds_prob = preds[1] preds_prob = preds[1]
......
...@@ -51,7 +51,7 @@ def get_image_file_list(img_file): ...@@ -51,7 +51,7 @@ def get_image_file_list(img_file):
if img_file is None or not os.path.exists(img_file): if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file)) raise Exception("not found any img file in {}".format(img_file))
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'} img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF','webp','ppm'}
if os.path.isfile(img_file) and imghdr.what(img_file) in img_end: if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
imgs_lists.append(img_file) imgs_lists.append(img_file)
elif os.path.isdir(img_file): elif os.path.isdir(img_file):
...@@ -77,4 +77,4 @@ def check_and_read_gif(img_path): ...@@ -77,4 +77,4 @@ def check_and_read_gif(img_path):
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1] imgvalue = frame[:, :, ::-1]
return imgvalue, True return imgvalue, True
return None, False return None, False
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册