Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
d2c11969
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d2c11969
编写于
4月 28, 2022
作者:
T
Topdu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm svtrlabeldecode resize
上级
ac703e56
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
14 addition
and
125 deletion
+14
-125
doc/doc_ch/algorithm_rec_nrtr.md
doc/doc_ch/algorithm_rec_nrtr.md
+3
-3
doc/doc_ch/algorithm_rec_svtr.md
doc/doc_ch/algorithm_rec_svtr.md
+8
-9
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+1
-1
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+0
-71
ppocr/modeling/transforms/stn.py
ppocr/modeling/transforms/stn.py
+0
-2
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+2
-2
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+0
-37
未找到文件。
doc/doc_ch/algorithm_rec_nrtr.md
浏览文件 @
d2c11969
...
@@ -62,7 +62,7 @@ python3 tools/eval.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model
...
@@ -62,7 +62,7 @@ python3 tools/eval.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model
使用如下命令进行单张图片预测:
使用如下命令进行单张图片预测:
```
shell
```
shell
# 注意将pretrained_model的路径设置为本地路径。
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py
-c
configs/rec/rec_mtb_nrtr.yml
-o
Global.infer_img
=
'./doc/imgs_words_en/word_10.png'
Global.pretrained_model
=
./rec_mtb_nrtr_train/best_accuracy
Global.load_static_weights
=
false
python3 tools/infer_rec.py
-c
configs/rec/rec_mtb_nrtr.yml
-o
Global.infer_img
=
'./doc/imgs_words_en/word_10.png'
Global.pretrained_model
=
./rec_mtb_nrtr_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
```
...
@@ -72,11 +72,11 @@ python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='
...
@@ -72,11 +72,11 @@ python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='
<a
name=
"4-1"
></a>
<a
name=
"4-1"
></a>
### 4.1 Python推理
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例(
[
模型下载地址
](
#model
)
),可以使用如下命令进行转换:
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例(
[
模型下载地址
](
https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar
)
),可以使用如下命令进行转换:
```
shell
```
shell
# 注意将pretrained_model的路径设置为本地路径。
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py
-c
configs/rec/rec_mtb_nrtr.yml
-o
Global.pretrained_model
=
./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir
=
./inference/rec_mtb_nrtr/
Global.load_static_weights
=
False
python3 tools/export_model.py
-c
configs/rec/rec_mtb_nrtr.yml
-o
Global.pretrained_model
=
./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir
=
./inference/rec_mtb_nrtr/
```
```
执行如下命令进行模型推理:
执行如下命令进行模型推理:
...
...
doc/doc_ch/algorithm_rec_svtr.md
浏览文件 @
d2c11969
...
@@ -26,10 +26,13 @@
...
@@ -26,10 +26,13 @@
1.
首次发现单视觉模型可以达到与视觉语言模型相媲美甚至更高的准确率,并且其具有效率高和适应多语言的优点,在实际应用中很有前景。
1.
首次发现单视觉模型可以达到与视觉语言模型相媲美甚至更高的准确率,并且其具有效率高和适应多语言的优点,在实际应用中很有前景。
2.
SVTR从字符组件的角度出发,逐渐的合并字符组件,自下而上地完成字符的识别。
2.
SVTR从字符组件的角度出发,逐渐的合并字符组件,自下而上地完成字符的识别。
3.
SVTR引入了局部和全局Mixing,分别用于提取字符组件特征和字符间依赖关系,与多尺度的特征一起,形成多粒度特征描述。
3.
SVTR引入了局部和全局Mixing,分别用于提取字符组件特征和字符间依赖关系,与多尺度的特征一起,形成多粒度特征描述。
4.
SVTR-L在识别英文和中文场景文本方面实现了最先进的性能。SVTR-T平衡精确度和效率,在一个NVIDIA 1080Ti GPU中,每个英文图像文本平均消耗4.5ms。
<a
name=
"model"
></a>
<a
name=
"model"
></a>
`SVTR`
在场景文本识别公开数据集上的精度(%)和模型文件如下:
SVTR在场景文本识别公开数据集上的精度(%)和模型文件如下:
*
中文数据集来自于
[
Chinese Benckmark
](
https://arxiv.org/abs/2112.15093
)
,SVTR的中文训练评估策略遵循该论文。
| SVTR |IC13
<br/>
857 | SVT |IIIT5k
<br/>
3000 |IC15
<br/>
1811| SVTP |CUTE80 | Avg_6 |IC15
<br/>
2077 |IC13
<br/>
1015 |IC03
<br/>
867|IC03
<br/>
860|Avg_10 |Chinese| 英文
<br/>
链接 | 中文
<br/>
链接 |
| SVTR |IC13
<br/>
857 | SVT |IIIT5k
<br/>
3000 |IC15
<br/>
1811| SVTP |CUTE80 | Avg_6 |IC15
<br/>
2077 |IC13
<br/>
1015 |IC03
<br/>
867|IC03
<br/>
860|Avg_10 |Chinese| 英文
<br/>
链接 | 中文
<br/>
链接 |
|:-----:|:------:|:-----:|:---------:|:------:|:-----:|:-----:|:-----:|:-------:|:-------:|:-----:|:-----:|:-----:|:-----:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|
|:-----:|:------:|:-----:|:---------:|:------:|:-----:|:-----:|:-----:|:-------:|:-------:|:-----:|:-----:|:-----:|:-----:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|
...
@@ -56,10 +59,6 @@
...
@@ -56,10 +59,6 @@
[
英文数据集下载
](
https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here
)
[
英文数据集下载
](
https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here
)
[
中文数据集下载
](
https://github.com/fudanvi/benchmarking-chinese-text-recognition#download
)
[
中文数据集下载
](
https://github.com/fudanvi/benchmarking-chinese-text-recognition#download
)
**注意:**
1.
训练
`SVTR`
时,需将将配置文件中的测试数据集路径设置为本地的评估数据集路径,例如将中文的
`scene_test`
数据集修改为
`scene_val`
。
2.
训练
`SVTR`
时,需将配置文件中的
`SVTRLableDecode`
修改为
`CTCLabelDecode`
,将
`SVTRRecResizeImg`
修改为
`RecResizeImg`
。
#### 启动训练
#### 启动训练
请参考
[
文本识别训练教程
](
./recognition.md
)
。PaddleOCR对代码进行了模块化,训练
`SVTR`
识别模型时需要
**更换配置文件**
为
`SVTR`
的
[
配置文件
](
../../configs/rec/rec_svtrnet.yml
)
。
请参考
[
文本识别训练教程
](
./recognition.md
)
。PaddleOCR对代码进行了模块化,训练
`SVTR`
识别模型时需要
**更换配置文件**
为
`SVTR`
的
[
配置文件
](
../../configs/rec/rec_svtrnet.yml
)
。
...
@@ -67,7 +66,7 @@
...
@@ -67,7 +66,7 @@
<a
name=
"3-2"
></a>
<a
name=
"3-2"
></a>
### 3.2 评估
### 3.2 评估
可下载
`SVTR`
提供模型文件和配置文件
[
模型下载
](
#model
)
,以
`SVTR-T`
为例,使用如下命令进行评估:
可下载
`SVTR`
提供模型文件和配置文件
:
[
下载地址
](
https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar
)
,以
`SVTR-T`
为例,使用如下命令进行评估:
```
shell
```
shell
# 注意将pretrained_model的路径设置为本地路径。
# 注意将pretrained_model的路径设置为本地路径。
...
@@ -81,7 +80,7 @@ python3 tools/eval.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_s
...
@@ -81,7 +80,7 @@ python3 tools/eval.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_s
使用如下命令进行单张图片预测:
使用如下命令进行单张图片预测:
```
shell
```
shell
# 注意将pretrained_model的路径设置为本地路径。
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py
-c
./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml
-o
Global.infer_img
=
'./doc/imgs_words_en/word_10.png'
Global.pretrained_model
=
./rec_svtr_tiny_none_ctc_en_train/best_accuracy
Global.load_static_weights
=
false
python3 tools/infer_rec.py
-c
./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml
-o
Global.infer_img
=
'./doc/imgs_words_en/word_10.png'
Global.pretrained_model
=
./rec_svtr_tiny_none_ctc_en_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
```
...
@@ -91,11 +90,11 @@ python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6glo
...
@@ -91,11 +90,11 @@ python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6glo
<a
name=
"4-1"
></a>
<a
name=
"4-1"
></a>
### 4.1 Python推理
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。下面以基于
`SVTR-T`
,在英文数据集训练的模型为例(
[
模型下载地址
](
#model
)
),可以使用如下命令进行转换:
首先将训练得到best模型,转换成inference model。下面以基于
`SVTR-T`
,在英文数据集训练的模型为例(
[
模型下载地址
](
https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar
)
),可以使用如下命令进行转换:
```
shell
```
shell
# 注意将pretrained_model的路径设置为本地路径。
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py
-c
./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml
-o
Global.pretrained_model
=
./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir
=
./inference/rec_svtr_tiny_stn_en
/ Global.load_static_weights
=
False
python3 tools/export_model.py
-c
./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml
-o
Global.pretrained_model
=
./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir
=
./inference/rec_svtr_tiny_stn_en
```
```
执行如下命令进行模型推理:
执行如下命令进行模型推理:
...
...
ppocr/data/imaug/__init__.py
浏览文件 @
d2c11969
...
@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
...
@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from
.make_pse_gt
import
MakePseGt
from
.make_pse_gt
import
MakePseGt
from
.rec_img_aug
import
RecAug
,
RecConAug
,
RecResizeImg
,
ClsResizeImg
,
\
from
.rec_img_aug
import
RecAug
,
RecConAug
,
RecResizeImg
,
ClsResizeImg
,
\
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
PRENResizeImg
,
SVTRRecResizeImg
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
PRENResizeImg
from
.ssl_img_aug
import
SSLRotateResize
from
.ssl_img_aug
import
SSLRotateResize
from
.randaugment
import
RandAugment
from
.randaugment
import
RandAugment
from
.copy_paste
import
CopyPaste
from
.copy_paste
import
CopyPaste
...
...
ppocr/data/imaug/rec_img_aug.py
浏览文件 @
d2c11969
...
@@ -207,25 +207,6 @@ class PRENResizeImg(object):
...
@@ -207,25 +207,6 @@ class PRENResizeImg(object):
return
data
return
data
class
SVTRRecResizeImg
(
object
):
def
__init__
(
self
,
image_shape
,
infer_mode
=
False
,
character_dict_path
=
'./ppocr/utils/ppocr_keys_v1.txt'
,
padding
=
True
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
infer_mode
=
infer_mode
self
.
character_dict_path
=
character_dict_path
self
.
padding
=
padding
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
norm_img
=
resize_norm_img_svtr
(
img
,
self
.
image_shape
,
self
.
padding
)
data
[
'image'
]
=
norm_img
return
data
def
resize_norm_img_sar
(
img
,
image_shape
,
width_downsample_ratio
=
0.25
):
def
resize_norm_img_sar
(
img
,
image_shape
,
width_downsample_ratio
=
0.25
):
imgC
,
imgH
,
imgW_min
,
imgW_max
=
image_shape
imgC
,
imgH
,
imgW_min
,
imgW_max
=
image_shape
h
=
img
.
shape
[
0
]
h
=
img
.
shape
[
0
]
...
@@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape):
...
@@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape):
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
resize_norm_img_svtr
(
img
,
image_shape
,
padding
=
False
):
imgC
,
imgH
,
imgW
=
image_shape
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
if
not
padding
:
if
h
>
2.0
*
w
:
image
=
Image
.
fromarray
(
img
)
image1
=
image
.
rotate
(
90
,
expand
=
True
)
image2
=
image
.
rotate
(
-
90
,
expand
=
True
)
img1
=
np
.
array
(
image1
)
img2
=
np
.
array
(
image2
)
else
:
img1
=
copy
.
deepcopy
(
img
)
img2
=
copy
.
deepcopy
(
img
)
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image1
=
cv2
.
resize
(
img1
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image2
=
cv2
.
resize
(
img2
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_w
=
imgW
else
:
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image1
=
resized_image1
.
astype
(
'float32'
)
resized_image2
=
resized_image2
.
astype
(
'float32'
)
if
image_shape
[
0
]
==
1
:
resized_image
=
resized_image
/
255
resized_image
=
resized_image
[
np
.
newaxis
,
:]
else
:
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image1
=
resized_image1
.
transpose
((
2
,
0
,
1
))
/
255
resized_image2
=
resized_image2
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
resized_image1
-=
0.5
resized_image1
/=
0.5
resized_image2
-=
0.5
resized_image2
/=
0.5
padding_im
=
np
.
zeros
((
3
,
imgC
,
imgH
,
imgW
),
dtype
=
np
.
float32
)
padding_im
[
0
,
:,
:,
0
:
resized_w
]
=
resized_image
padding_im
[
1
,
:,
:,
0
:
resized_w
]
=
resized_image1
padding_im
[
2
,
:,
:,
0
:
resized_w
]
=
resized_image2
return
padding_im
def
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
):
def
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
):
imgC
,
imgH
,
imgW
=
image_shape
imgC
,
imgH
,
imgW
=
image_shape
...
...
ppocr/modeling/transforms/stn.py
浏览文件 @
d2c11969
...
@@ -128,8 +128,6 @@ class STN_ON(nn.Layer):
...
@@ -128,8 +128,6 @@ class STN_ON(nn.Layer):
self
.
out_channels
=
in_channels
self
.
out_channels
=
in_channels
def
forward
(
self
,
image
):
def
forward
(
self
,
image
):
if
len
(
image
.
shape
)
==
5
:
image
=
image
.
reshape
([
0
,
image
.
shape
[
-
3
],
image
.
shape
[
-
2
],
image
.
shape
[
-
1
]])
stn_input
=
paddle
.
nn
.
functional
.
interpolate
(
stn_input
=
paddle
.
nn
.
functional
.
interpolate
(
image
,
self
.
tps_inputsize
,
mode
=
"bilinear"
,
align_corners
=
True
)
image
,
self
.
tps_inputsize
,
mode
=
"bilinear"
,
align_corners
=
True
)
stn_img_feat
,
ctrl_points
=
self
.
stn_head
(
stn_input
)
stn_img_feat
,
ctrl_points
=
self
.
stn_head
(
stn_input
)
...
...
ppocr/postprocess/__init__.py
浏览文件 @
d2c11969
...
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
...
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from
.fce_postprocess
import
FCEPostProcess
from
.fce_postprocess
import
FCEPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
\
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
\
DistillationCTCLabelDecode
,
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
\
DistillationCTCLabelDecode
,
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
\
SEEDLabelDecode
,
PRENLabelDecode
,
SVTRLabelDecode
SEEDLabelDecode
,
PRENLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.vqa_token_ser_layoutlm_postprocess
import
VQASerTokenLayoutLMPostProcess
from
.vqa_token_ser_layoutlm_postprocess
import
VQASerTokenLayoutLMPostProcess
...
@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
...
@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
,
'VQASerTokenLayoutLMPostProcess'
,
'SEEDLabelDecode'
,
'VQASerTokenLayoutLMPostProcess'
,
'VQAReTokenLayoutLMPostProcess'
,
'PRENLabelDecode'
,
'VQAReTokenLayoutLMPostProcess'
,
'PRENLabelDecode'
,
'DistillationSARLabelDecode'
,
'SVTRLabelDecode'
'DistillationSARLabelDecode'
]
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
if
config
[
'name'
]
==
'PSEPostProcess'
:
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
d2c11969
...
@@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode):
...
@@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode):
return
text
return
text
label
=
self
.
decode
(
label
)
label
=
self
.
decode
(
label
)
return
text
,
label
return
text
,
label
class
SVTRLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SVTRLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
tuple
):
preds
=
preds
[
-
1
]
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=-
1
)
preds_prob
=
preds
.
max
(
axis
=-
1
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
True
)
return_text
=
[]
for
i
in
range
(
0
,
len
(
text
),
3
):
text0
=
text
[
i
]
text1
=
text
[
i
+
1
]
text2
=
text
[
i
+
2
]
text_pred
=
[
text0
[
0
],
text1
[
0
],
text2
[
0
]]
text_prob
=
[
text0
[
1
],
text1
[
1
],
text2
[
1
]]
id_max
=
text_prob
.
index
(
max
(
text_prob
))
return_text
.
append
((
text_pred
[
id_max
],
text_prob
[
id_max
]))
if
label
is
None
:
return
return_text
label
=
self
.
decode
(
label
)
return
return_text
,
label
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
]
+
dict_character
return
dict_character
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录