未验证 提交 edd18b67 编写于 作者: u010070587's avatar u010070587 提交者: GitHub

Merge pull request #7206 from andyjpaddle/fix_vl

Fix visionlan predict format
...@@ -101,7 +101,7 @@ python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pre ...@@ -101,7 +101,7 @@ python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pre
执行如下命令进行模型推理: 执行如下命令进行模型推理:
```shell ```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt' python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt' --use_space_char=False
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。 # 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
``` ```
...@@ -110,7 +110,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' ...@@ -110,7 +110,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png'
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下: 执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
结果如下: 结果如下:
```shell ```shell
Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982) Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.9999493)
``` ```
**注意** **注意**
......
...@@ -90,7 +90,7 @@ After the conversion is successful, there are three files in the directory: ...@@ -90,7 +90,7 @@ After the conversion is successful, there are three files in the directory:
For VisionLAN text recognition model inference, the following commands can be executed: For VisionLAN text recognition model inference, the following commands can be executed:
``` ```
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt' python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt' --use_space_char=False
``` ```
![](../imgs_words/en/word_2.png) ![](../imgs_words/en/word_2.png)
...@@ -98,7 +98,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' ...@@ -98,7 +98,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png'
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows: After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
The result is as follows: The result is as follows:
```shell ```shell
Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982) Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.9999493)
``` ```
<a name="4-2"></a> <a name="4-2"></a>
......
...@@ -67,7 +67,7 @@ def build_loss(config): ...@@ -67,7 +67,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss','StrokeFocusLoss' 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
...@@ -780,7 +780,7 @@ class VLLabelDecode(BaseRecLabelDecode): ...@@ -780,7 +780,7 @@ class VLLabelDecode(BaseRecLabelDecode):
) + length[i])].topk(1)[0][:, 0] ) + length[i])].topk(1)[0][:, 0]
preds_prob = paddle.exp( preds_prob = paddle.exp(
paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6)) paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
text.append((preds_text, preds_prob)) text.append((preds_text, preds_prob.numpy()[0]))
if label is None: if label is None:
return text return text
label = self.decode(label) label = self.decode(label)
......
...@@ -490,7 +490,7 @@ def eval(model, ...@@ -490,7 +490,7 @@ def eval(model,
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
# use amp # use amp
if scaler: if scaler:
with paddle.amp.auto_cast(level='O2'): with paddle.amp.auto_cast(level='O2'):
...@@ -508,10 +508,10 @@ def eval(model, ...@@ -508,10 +508,10 @@ def eval(model,
1, 2, 0).astype(np.uint8) 1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose( fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8) 1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(sum_images, cv2.imwrite("output/images/{}_{}_sr.jpg".format(
i), fm_sr) sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(sum_images, cv2.imwrite("output/images/{}_{}_lr.jpg".format(
i), fm_lr) sum_images, i), fm_lr)
else: else:
preds = model(images) preds = model(images)
else: else:
...@@ -529,10 +529,10 @@ def eval(model, ...@@ -529,10 +529,10 @@ def eval(model,
1, 2, 0).astype(np.uint8) 1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose( fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8) 1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(sum_images, cv2.imwrite("output/images/{}_{}_sr.jpg".format(
i), fm_sr) sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(sum_images, cv2.imwrite("output/images/{}_{}_lr.jpg".format(
i), fm_lr) sum_images, i), fm_lr)
else: else:
preds = model(images) preds = model(images)
...@@ -652,7 +652,7 @@ def preprocess(is_train=False): ...@@ -652,7 +652,7 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt' 'Gestalt'
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册