Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
8308f332
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8308f332
编写于
9月 27, 2021
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
差异文件
fix conflicts
上级
9a44e279
1b2ca6e6
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
1261 addition
and
18 deletion
+1261
-18
configs/rec/rec_resnet_stn_bilstm_att.yml
configs/rec/rec_resnet_stn_bilstm_att.yml
+106
-0
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+1
-1
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+39
-3
ppocr/data/imaug/operators.py
ppocr/data/imaug/operators.py
+14
-1
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+25
-1
ppocr/losses/__init__.py
ppocr/losses/__init__.py
+5
-1
ppocr/losses/rec_aster_loss.py
ppocr/losses/rec_aster_loss.py
+99
-0
ppocr/metrics/rec_metric.py
ppocr/metrics/rec_metric.py
+11
-2
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+5
-1
ppocr/modeling/backbones/rec_resnet_aster.py
ppocr/modeling/backbones/rec_resnet_aster.py
+147
-0
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+3
-1
ppocr/modeling/heads/rec_aster_head.py
ppocr/modeling/heads/rec_aster_head.py
+390
-0
ppocr/modeling/transforms/__init__.py
ppocr/modeling/transforms/__init__.py
+2
-1
ppocr/modeling/transforms/stn.py
ppocr/modeling/transforms/stn.py
+108
-0
ppocr/modeling/transforms/tps.py
ppocr/modeling/transforms/tps.py
+27
-1
ppocr/modeling/transforms/tps_spatial_transformer.py
ppocr/modeling/transforms/tps_spatial_transformer.py
+157
-0
ppocr/optimizer/optimizer.py
ppocr/optimizer/optimizer.py
+31
-0
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+5
-3
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+82
-0
tools/program.py
tools/program.py
+4
-2
未找到文件。
configs/rec/rec_resnet_stn_bilstm_att.yml
0 → 100644
浏览文件 @
8308f332
Global
:
use_gpu
:
True
epoch_num
:
400
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec/seed
save_epoch_step
:
3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
True
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path
:
character_type
:
EN_symbol
max_text_length
:
100
infer_mode
:
False
use_space_char
:
False
eval_filter
:
True
save_res_path
:
./output/rec/predicts_seed.txt
Optimizer
:
name
:
Adadelta
weight_deacy
:
0.0
momentum
:
0.9
lr
:
name
:
Piecewise
decay_epochs
:
[
4
,
5
,
8
]
values
:
[
1.0
,
0.1
,
0.01
]
regularizer
:
name
:
'
L2'
factor
:
2.0e-05
Architecture
:
model_type
:
seed
algorithm
:
ASTER
Transform
:
name
:
STN_ON
tps_inputsize
:
[
32
,
64
]
tps_outputsize
:
[
32
,
100
]
num_control_points
:
20
tps_margins
:
[
0.05
,
0.05
]
stn_activation
:
none
Backbone
:
name
:
ResNet_ASTER
Head
:
name
:
AsterHead
# AttentionHead
sDim
:
512
attDim
:
512
max_len_labels
:
100
Loss
:
name
:
AsterLoss
PostProcess
:
name
:
SEEDLabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
is_filter
:
True
Train
:
dataset
:
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
-
Fasttext
:
path
:
"
./cc.en.300.bin"
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
SEEDLabelEncode
:
# Class handling label
-
SEEDResize
:
image_shape
:
[
3
,
64
,
256
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
,
'
fast_label'
]
# dataloader will return list in this order
loader
:
shuffle
:
True
batch_size_per_card
:
256
drop_last
:
True
num_workers
:
6
Eval
:
dataset
:
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/evaluation/
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
SEEDLabelEncode
:
# Class handling label
-
SEEDResize
:
image_shape
:
[
3
,
64
,
256
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
drop_last
:
True
batch_size_per_card
:
256
num_workers
:
4
ppocr/data/imaug/__init__.py
浏览文件 @
8308f332
...
...
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from
.make_shrink_map
import
MakeShrinkMap
from
.random_crop_data
import
EastRandomCropData
,
PSERandomCrop
from
.rec_img_aug
import
RecAug
,
RecResizeImg
,
ClsResizeImg
,
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
from
.rec_img_aug
import
RecAug
,
RecResizeImg
,
ClsResizeImg
,
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
SEEDResize
from
.randaugment
import
RandAugment
from
.copy_paste
import
CopyPaste
from
.operators
import
*
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
8308f332
...
...
@@ -106,6 +106,7 @@ class BaseRecLabelEncode(object):
self
.
max_text_len
=
max_text_length
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
unknown
=
"UNKNOWN"
if
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
...
...
@@ -174,6 +175,7 @@ class NRTRLabelEncode(BaseRecLabelEncode):
super
(
NRTRLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
def
__call__
(
self
,
data
):
text
=
data
[
'label'
]
text
=
self
.
encode
(
text
)
...
...
@@ -185,10 +187,12 @@ class NRTRLabelEncode(BaseRecLabelEncode):
text
=
text
+
[
0
]
*
(
self
.
max_text_len
-
len
(
text
))
data
[
'label'
]
=
np
.
array
(
text
)
return
data
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
,
'<unk>'
,
'<s>'
,
'</s>'
]
+
dict_character
dict_character
=
[
'blank'
,
'<unk>'
,
'<s>'
,
'</s>'
]
+
dict_character
return
dict_character
class
CTCLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
...
...
@@ -337,6 +341,39 @@ class AttnLabelEncode(BaseRecLabelEncode):
return
idx
class
SEEDLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
max_text_length
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
**
kwargs
):
super
(
SEEDLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
dict_character
=
dict_character
+
[
self
.
end_str
]
return
dict_character
def
__call__
(
self
,
data
):
text
=
data
[
'label'
]
text
=
self
.
encode
(
text
)
if
text
is
None
:
return
None
if
len
(
text
)
>=
self
.
max_text_len
:
return
None
data
[
'length'
]
=
np
.
array
(
len
(
text
))
+
1
# conclue eos
text
=
text
+
[
len
(
self
.
character
)
-
1
]
*
(
self
.
max_text_len
-
len
(
text
)
)
data
[
'label'
]
=
np
.
array
(
text
)
return
data
class
SRNLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
...
...
@@ -416,7 +453,6 @@ class TableLabelEncode(object):
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
).
split
(
"
\t
"
)
character_num
=
int
(
substr
[
0
])
elem_num
=
int
(
substr
[
1
])
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
)
list_character
.
append
(
character
)
...
...
@@ -588,7 +624,7 @@ class SARLabelEncode(BaseRecLabelEncode):
data
[
'length'
]
=
np
.
array
(
len
(
text
))
target
=
[
self
.
start_idx
]
+
text
+
[
self
.
end_idx
]
padded_text
=
[
self
.
padding_idx
for
_
in
range
(
self
.
max_text_len
)]
padded_text
[:
len
(
target
)]
=
target
data
[
'label'
]
=
np
.
array
(
padded_text
)
return
data
...
...
ppocr/data/imaug/operators.py
浏览文件 @
8308f332
...
...
@@ -23,6 +23,7 @@ import sys
import
six
import
cv2
import
numpy
as
np
import
fasttext
class
DecodeImage
(
object
):
...
...
@@ -83,12 +84,13 @@ class NRTRDecodeImage(object):
elif
self
.
img_mode
==
'RGB'
:
assert
img
.
shape
[
2
]
==
3
,
'invalid shape of image[%s]'
%
(
img
.
shape
)
img
=
img
[:,
:,
::
-
1
]
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
if
self
.
channel_first
:
img
=
img
.
transpose
((
2
,
0
,
1
))
data
[
'image'
]
=
img
return
data
class
NormalizeImage
(
object
):
""" normalize image such as substract mean, divide std
"""
...
...
@@ -133,6 +135,17 @@ class ToCHWImage(object):
return
data
class
Fasttext
(
object
):
def
__init__
(
self
,
path
=
"None"
,
**
kwargs
):
self
.
fast_model
=
fasttext
.
load_model
(
path
)
def
__call__
(
self
,
data
):
label
=
data
[
'label'
]
fast_label
=
self
.
fast_model
[
label
]
data
[
'fast_label'
]
=
fast_label
return
data
class
KeepKeys
(
object
):
def
__init__
(
self
,
keep_keys
,
**
kwargs
):
self
.
keep_keys
=
keep_keys
...
...
ppocr/data/imaug/rec_img_aug.py
浏览文件 @
8308f332
...
...
@@ -82,6 +82,18 @@ class RecResizeImg(object):
return
data
class
SEEDResize
(
object
):
def
__init__
(
self
,
image_shape
,
infer_mode
=
False
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
infer_mode
=
infer_mode
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
norm_img
=
resize_no_padding_img
(
img
,
self
.
image_shape
)
data
[
'image'
]
=
norm_img
return
data
class
SRNRecResizeImg
(
object
):
def
__init__
(
self
,
image_shape
,
num_heads
,
max_text_length
,
**
kwargs
):
self
.
image_shape
=
image_shape
...
...
@@ -109,7 +121,8 @@ class SARRecResizeImg(object):
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
norm_img
,
resize_shape
,
pad_shape
,
valid_ratio
=
resize_norm_img_sar
(
img
,
self
.
image_shape
,
self
.
width_downsample_ratio
)
norm_img
,
resize_shape
,
pad_shape
,
valid_ratio
=
resize_norm_img_sar
(
img
,
self
.
image_shape
,
self
.
width_downsample_ratio
)
data
[
'image'
]
=
norm_img
data
[
'resized_shape'
]
=
resize_shape
data
[
'pad_shape'
]
=
pad_shape
...
...
@@ -175,6 +188,17 @@ def resize_norm_img(img, image_shape):
return
padding_im
def
resize_no_padding_img
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
return
resized_image
def
resize_norm_img_chinese
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
# todo: change to 0 and modified image shape
...
...
ppocr/losses/__init__.py
浏览文件 @
8308f332
...
...
@@ -42,10 +42,14 @@ from .combined_loss import CombinedLoss
# table loss
from
.table_att_loss
import
TableAttentionLoss
from
.rec_aster_loss
import
AsterLoss
def
build_loss
(
config
):
support_dict
=
[
'DBLoss'
,
'EASTLoss'
,
'SASTLoss'
,
'CTCLoss'
,
'ClsLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'NRTRLoss'
,
'TableAttentionLoss'
,
'SARLoss'
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'NRTRLoss'
,
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
]
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/losses/rec_aster_loss.py
0 → 100644
浏览文件 @
8308f332
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
from
paddle
import
nn
class
CosineEmbeddingLoss
(
nn
.
Layer
):
def
__init__
(
self
,
margin
=
0.
):
super
(
CosineEmbeddingLoss
,
self
).
__init__
()
self
.
margin
=
margin
self
.
epsilon
=
1e-12
def
forward
(
self
,
x1
,
x2
,
target
):
similarity
=
paddle
.
fluid
.
layers
.
reduce_sum
(
x1
*
x2
,
dim
=-
1
)
/
(
paddle
.
norm
(
x1
,
axis
=-
1
)
*
paddle
.
norm
(
x2
,
axis
=-
1
)
+
self
.
epsilon
)
one_list
=
paddle
.
full_like
(
target
,
fill_value
=
1
)
out
=
paddle
.
fluid
.
layers
.
reduce_mean
(
paddle
.
where
(
paddle
.
equal
(
target
,
one_list
),
1.
-
similarity
,
paddle
.
maximum
(
paddle
.
zeros_like
(
similarity
),
similarity
-
self
.
margin
)))
return
out
class
AsterLoss
(
nn
.
Layer
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
True
,
ignore_index
=-
100
,
sequence_normalize
=
False
,
sample_normalize
=
True
,
**
kwargs
):
super
(
AsterLoss
,
self
).
__init__
()
self
.
weight
=
weight
self
.
size_average
=
size_average
self
.
ignore_index
=
ignore_index
self
.
sequence_normalize
=
sequence_normalize
self
.
sample_normalize
=
sample_normalize
self
.
loss_sem
=
CosineEmbeddingLoss
()
self
.
is_cosin_loss
=
True
self
.
loss_func_rec
=
nn
.
CrossEntropyLoss
(
weight
=
None
,
reduction
=
'none'
)
def
forward
(
self
,
predicts
,
batch
):
targets
=
batch
[
1
].
astype
(
"int64"
)
label_lengths
=
batch
[
2
].
astype
(
'int64'
)
sem_target
=
batch
[
3
].
astype
(
'float32'
)
embedding_vectors
=
predicts
[
'embedding_vectors'
]
rec_pred
=
predicts
[
'rec_pred'
]
if
not
self
.
is_cosin_loss
:
sem_loss
=
paddle
.
sum
(
self
.
loss_sem
(
embedding_vectors
,
sem_target
))
else
:
label_target
=
paddle
.
ones
([
embedding_vectors
.
shape
[
0
]])
sem_loss
=
paddle
.
sum
(
self
.
loss_sem
(
embedding_vectors
,
sem_target
,
label_target
))
# rec loss
batch_size
,
def_max_length
=
targets
.
shape
[
0
],
targets
.
shape
[
1
]
mask
=
paddle
.
zeros
([
batch_size
,
def_max_length
])
for
i
in
range
(
batch_size
):
mask
[
i
,
:
label_lengths
[
i
]]
=
1
mask
=
paddle
.
cast
(
mask
,
"float32"
)
max_length
=
max
(
label_lengths
)
assert
max_length
==
rec_pred
.
shape
[
1
]
targets
=
targets
[:,
:
max_length
]
mask
=
mask
[:,
:
max_length
]
rec_pred
=
paddle
.
reshape
(
rec_pred
,
[
-
1
,
rec_pred
.
shape
[
2
]])
input
=
nn
.
functional
.
log_softmax
(
rec_pred
,
axis
=
1
)
targets
=
paddle
.
reshape
(
targets
,
[
-
1
,
1
])
mask
=
paddle
.
reshape
(
mask
,
[
-
1
,
1
])
output
=
-
paddle
.
index_sample
(
input
,
index
=
targets
)
*
mask
output
=
paddle
.
sum
(
output
)
if
self
.
sequence_normalize
:
output
=
output
/
paddle
.
sum
(
mask
)
if
self
.
sample_normalize
:
output
=
output
/
batch_size
loss
=
output
+
sem_loss
*
0.1
return
{
'loss'
:
loss
}
ppocr/metrics/rec_metric.py
浏览文件 @
8308f332
...
...
@@ -13,13 +13,20 @@
# limitations under the License.
import
Levenshtein
import
string
class
RecMetric
(
object
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
**
kwargs
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
is_filter
=
False
,
**
kwargs
):
self
.
main_indicator
=
main_indicator
self
.
is_filter
=
is_filter
self
.
reset
()
def
_normalize_text
(
self
,
text
):
text
=
''
.
join
(
filter
(
lambda
x
:
x
in
(
string
.
digits
+
string
.
ascii_letters
),
text
))
return
text
.
lower
()
def
__call__
(
self
,
pred_label
,
*
args
,
**
kwargs
):
preds
,
labels
=
pred_label
correct_num
=
0
...
...
@@ -28,6 +35,9 @@ class RecMetric(object):
for
(
pred
,
pred_conf
),
(
target
,
_
)
in
zip
(
preds
,
labels
):
pred
=
pred
.
replace
(
" "
,
""
)
target
=
target
.
replace
(
" "
,
""
)
if
self
.
is_filter
:
pred
=
self
.
_normalize_text
(
pred
)
target
=
self
.
_normalize_text
(
target
)
norm_edit_dis
+=
Levenshtein
.
distance
(
pred
,
target
)
/
max
(
len
(
pred
),
len
(
target
),
1
)
if
pred
==
target
:
...
...
@@ -57,4 +67,3 @@ class RecMetric(object):
self
.
correct_num
=
0
self
.
all_num
=
0
self
.
norm_edit_dis
=
0
ppocr/modeling/backbones/__init__.py
浏览文件 @
8308f332
...
...
@@ -29,7 +29,8 @@ def build_backbone(config, model_type):
from
.rec_nrtr_mtb
import
MTB
from
.rec_resnet_31
import
ResNet31
support_dict
=
[
'MobileNetV1Enhance'
,
'MobileNetV3'
,
'ResNet'
,
'ResNetFPN'
,
'MTB'
,
"ResNet31"
'MobileNetV1Enhance'
,
'MobileNetV3'
,
'ResNet'
,
'ResNetFPN'
,
'MTB'
,
"ResNet31"
]
elif
model_type
==
"e2e"
:
from
.e2e_resnet_vd_pg
import
ResNet
...
...
@@ -38,6 +39,9 @@ def build_backbone(config, model_type):
from
.table_resnet_vd
import
ResNet
from
.table_mobilenet_v3
import
MobileNetV3
support_dict
=
[
"ResNet"
,
"MobileNetV3"
]
elif
model_type
==
"seed"
:
from
.rec_resnet_aster
import
ResNet_ASTER
support_dict
=
[
"ResNet_ASTER"
]
else
:
raise
NotImplementedError
...
...
ppocr/modeling/backbones/rec_resnet_aster.py
0 → 100644
浏览文件 @
8308f332
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
sys
import
math
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
):
"""3x3 convolution with padding"""
return
nn
.
Conv2D
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias_attr
=
False
)
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
"""1x1 convolution"""
return
nn
.
Conv2D
(
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
stride
,
bias_attr
=
False
)
def
get_sinusoid_encoding
(
n_position
,
feat_dim
,
wave_length
=
10000
):
# [n_position]
positions
=
paddle
.
arange
(
0
,
n_position
)
# [feat_dim]
dim_range
=
paddle
.
arange
(
0
,
feat_dim
)
dim_range
=
paddle
.
pow
(
wave_length
,
2
*
(
dim_range
//
2
)
/
feat_dim
)
# [n_position, feat_dim]
angles
=
paddle
.
unsqueeze
(
positions
,
axis
=
1
)
/
paddle
.
unsqueeze
(
dim_range
,
axis
=
0
)
angles
=
paddle
.
cast
(
angles
,
"float32"
)
angles
[:,
0
::
2
]
=
paddle
.
sin
(
angles
[:,
0
::
2
])
angles
[:,
1
::
2
]
=
paddle
.
cos
(
angles
[:,
1
::
2
])
return
angles
class
AsterBlock
(
nn
.
Layer
):
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
downsample
=
None
):
super
(
AsterBlock
,
self
).
__init__
()
self
.
conv1
=
conv1x1
(
inplanes
,
planes
,
stride
)
self
.
bn1
=
nn
.
BatchNorm2D
(
planes
)
self
.
relu
=
nn
.
ReLU
()
self
.
conv2
=
conv3x3
(
planes
,
planes
)
self
.
bn2
=
nn
.
BatchNorm2D
(
planes
)
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
residual
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
out
+=
residual
out
=
self
.
relu
(
out
)
return
out
class
ResNet_ASTER
(
nn
.
Layer
):
"""For aster or crnn"""
def
__init__
(
self
,
with_lstm
=
True
,
n_group
=
1
,
in_channels
=
3
):
super
(
ResNet_ASTER
,
self
).
__init__
()
self
.
with_lstm
=
with_lstm
self
.
n_group
=
n_group
self
.
layer0
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
,
32
,
kernel_size
=
(
3
,
3
),
stride
=
1
,
padding
=
1
,
bias_attr
=
False
),
nn
.
BatchNorm2D
(
32
),
nn
.
ReLU
())
self
.
inplanes
=
32
self
.
layer1
=
self
.
_make_layer
(
32
,
3
,
[
2
,
2
])
# [16, 50]
self
.
layer2
=
self
.
_make_layer
(
64
,
4
,
[
2
,
2
])
# [8, 25]
self
.
layer3
=
self
.
_make_layer
(
128
,
6
,
[
2
,
1
])
# [4, 25]
self
.
layer4
=
self
.
_make_layer
(
256
,
6
,
[
2
,
1
])
# [2, 25]
self
.
layer5
=
self
.
_make_layer
(
512
,
3
,
[
2
,
1
])
# [1, 25]
if
with_lstm
:
self
.
rnn
=
nn
.
LSTM
(
512
,
256
,
direction
=
"bidirect"
,
num_layers
=
2
)
self
.
out_channels
=
2
*
256
else
:
self
.
out_channels
=
512
def
_make_layer
(
self
,
planes
,
blocks
,
stride
):
downsample
=
None
if
stride
!=
[
1
,
1
]
or
self
.
inplanes
!=
planes
:
downsample
=
nn
.
Sequential
(
conv1x1
(
self
.
inplanes
,
planes
,
stride
),
nn
.
BatchNorm2D
(
planes
))
layers
=
[]
layers
.
append
(
AsterBlock
(
self
.
inplanes
,
planes
,
stride
,
downsample
))
self
.
inplanes
=
planes
for
_
in
range
(
1
,
blocks
):
layers
.
append
(
AsterBlock
(
self
.
inplanes
,
planes
))
return
nn
.
Sequential
(
*
layers
)
def
forward
(
self
,
x
):
x0
=
self
.
layer0
(
x
)
x1
=
self
.
layer1
(
x0
)
x2
=
self
.
layer2
(
x1
)
x3
=
self
.
layer3
(
x2
)
x4
=
self
.
layer4
(
x3
)
x5
=
self
.
layer5
(
x4
)
cnn_feat
=
x5
.
squeeze
(
2
)
# [N, c, w]
cnn_feat
=
paddle
.
transpose
(
cnn_feat
,
perm
=
[
0
,
2
,
1
])
if
self
.
with_lstm
:
rnn_feat
,
_
=
self
.
rnn
(
cnn_feat
)
return
rnn_feat
else
:
return
cnn_feat
if
__name__
==
"__main__"
:
x
=
paddle
.
randn
([
3
,
3
,
32
,
100
])
net
=
ResNet_ASTER
()
encoder_feat
=
net
(
x
)
print
(
encoder_feat
.
shape
)
ppocr/modeling/heads/__init__.py
浏览文件 @
8308f332
...
...
@@ -28,12 +28,14 @@ def build_head(config):
from
.rec_srn_head
import
SRNHead
from
.rec_nrtr_head
import
Transformer
from
.rec_sar_head
import
SARHead
from
.rec_aster_head
import
AsterHead
# cls head
from
.cls_head
import
ClsHead
support_dict
=
[
'DBHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'ClsHead'
,
'AttentionHead'
,
'SRNHead'
,
'PGHead'
,
'Transformer'
,
'TableAttentionHead'
,
'SARHead'
'SRNHead'
,
'PGHead'
,
'TableAttentionHead'
,
'SARHead'
,
'Transformer'
,
'AsterHead'
,
'SARHead'
]
#table head
...
...
ppocr/modeling/heads/rec_aster_head.py
0 → 100644
浏览文件 @
8308f332
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
sys
import
paddle
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
class
AsterHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
sDim
,
attDim
,
max_len_labels
,
time_step
=
25
,
beam_width
=
5
,
**
kwargs
):
super
(
AsterHead
,
self
).
__init__
()
self
.
num_classes
=
out_channels
self
.
in_planes
=
in_channels
self
.
sDim
=
sDim
self
.
attDim
=
attDim
self
.
max_len_labels
=
max_len_labels
self
.
decoder
=
AttentionRecognitionHead
(
in_channels
,
out_channels
,
sDim
,
attDim
,
max_len_labels
)
self
.
time_step
=
time_step
self
.
embeder
=
Embedding
(
self
.
time_step
,
in_channels
)
self
.
beam_width
=
beam_width
self
.
eos
=
self
.
num_classes
-
1
def
forward
(
self
,
x
,
targets
=
None
,
embed
=
None
):
return_dict
=
{}
embedding_vectors
=
self
.
embeder
(
x
)
if
self
.
training
:
rec_targets
,
rec_lengths
,
_
=
targets
rec_pred
=
self
.
decoder
([
x
,
rec_targets
,
rec_lengths
],
embedding_vectors
)
return_dict
[
'rec_pred'
]
=
rec_pred
return_dict
[
'embedding_vectors'
]
=
embedding_vectors
else
:
rec_pred
,
rec_pred_scores
=
self
.
decoder
.
beam_search
(
x
,
self
.
beam_width
,
self
.
eos
,
embedding_vectors
)
return_dict
[
'rec_pred'
]
=
rec_pred
return_dict
[
'rec_pred_scores'
]
=
rec_pred_scores
return_dict
[
'embedding_vectors'
]
=
embedding_vectors
return
return_dict
class
Embedding
(
nn
.
Layer
):
def
__init__
(
self
,
in_timestep
,
in_planes
,
mid_dim
=
4096
,
embed_dim
=
300
):
super
(
Embedding
,
self
).
__init__
()
self
.
in_timestep
=
in_timestep
self
.
in_planes
=
in_planes
self
.
embed_dim
=
embed_dim
self
.
mid_dim
=
mid_dim
self
.
eEmbed
=
nn
.
Linear
(
in_timestep
*
in_planes
,
self
.
embed_dim
)
# Embed encoder output to a word-embedding like
def
forward
(
self
,
x
):
x
=
paddle
.
reshape
(
x
,
[
paddle
.
shape
(
x
)[
0
],
-
1
])
x
=
self
.
eEmbed
(
x
)
return
x
class
AttentionRecognitionHead
(
nn
.
Layer
):
"""
input: [b x 16 x 64 x in_planes]
output: probability sequence: [b x T x num_classes]
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
sDim
,
attDim
,
max_len_labels
):
super
(
AttentionRecognitionHead
,
self
).
__init__
()
self
.
num_classes
=
out_channels
# this is the output classes. So it includes the <EOS>.
self
.
in_planes
=
in_channels
self
.
sDim
=
sDim
self
.
attDim
=
attDim
self
.
max_len_labels
=
max_len_labels
self
.
decoder
=
DecoderUnit
(
sDim
=
sDim
,
xDim
=
in_channels
,
yDim
=
self
.
num_classes
,
attDim
=
attDim
)
def
forward
(
self
,
x
,
embed
):
x
,
targets
,
lengths
=
x
batch_size
=
paddle
.
shape
(
x
)[
0
]
# Decoder
state
=
self
.
decoder
.
get_initial_state
(
embed
)
outputs
=
[]
for
i
in
range
(
max
(
lengths
)):
if
i
==
0
:
y_prev
=
paddle
.
full
(
shape
=
[
batch_size
],
fill_value
=
self
.
num_classes
)
else
:
y_prev
=
targets
[:,
i
-
1
]
output
,
state
=
self
.
decoder
(
x
,
state
,
y_prev
)
outputs
.
append
(
output
)
outputs
=
paddle
.
concat
([
_
.
unsqueeze
(
1
)
for
_
in
outputs
],
1
)
return
outputs
# inference stage.
def
sample
(
self
,
x
):
x
,
_
,
_
=
x
batch_size
=
x
.
size
(
0
)
# Decoder
state
=
paddle
.
zeros
([
1
,
batch_size
,
self
.
sDim
])
predicted_ids
,
predicted_scores
=
[],
[]
for
i
in
range
(
self
.
max_len_labels
):
if
i
==
0
:
y_prev
=
paddle
.
full
(
shape
=
[
batch_size
],
fill_value
=
self
.
num_classes
)
else
:
y_prev
=
predicted
output
,
state
=
self
.
decoder
(
x
,
state
,
y_prev
)
output
=
F
.
softmax
(
output
,
axis
=
1
)
score
,
predicted
=
output
.
max
(
1
)
predicted_ids
.
append
(
predicted
.
unsqueeze
(
1
))
predicted_scores
.
append
(
score
.
unsqueeze
(
1
))
predicted_ids
=
paddle
.
concat
([
predicted_ids
,
1
])
predicted_scores
=
paddle
.
concat
([
predicted_scores
,
1
])
# return predicted_ids.squeeze(), predicted_scores.squeeze()
return
predicted_ids
,
predicted_scores
def
beam_search
(
self
,
x
,
beam_width
,
eos
,
embed
):
def
_inflate
(
tensor
,
times
,
dim
):
repeat_dims
=
[
1
]
*
tensor
.
dim
()
repeat_dims
[
dim
]
=
times
output
=
paddle
.
tile
(
tensor
,
repeat_dims
)
return
output
# https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
batch_size
,
l
,
d
=
x
.
shape
# inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC
x
=
paddle
.
tile
(
paddle
.
transpose
(
x
.
unsqueeze
(
1
),
perm
=
[
1
,
0
,
2
,
3
]),
[
beam_width
,
1
,
1
,
1
])
inflated_encoder_feats
=
paddle
.
reshape
(
paddle
.
transpose
(
x
,
perm
=
[
1
,
0
,
2
,
3
]),
[
-
1
,
l
,
d
])
# Initialize the decoder
state
=
self
.
decoder
.
get_initial_state
(
embed
,
tile_times
=
beam_width
)
pos_index
=
paddle
.
reshape
(
paddle
.
arange
(
batch_size
)
*
beam_width
,
shape
=
[
-
1
,
1
])
# Initialize the scores
sequence_scores
=
paddle
.
full
(
shape
=
[
batch_size
*
beam_width
,
1
],
fill_value
=-
float
(
'Inf'
))
index
=
[
i
*
beam_width
for
i
in
range
(
0
,
batch_size
)]
sequence_scores
[
index
]
=
0.0
# Initialize the input vector
y_prev
=
paddle
.
full
(
shape
=
[
batch_size
*
beam_width
],
fill_value
=
self
.
num_classes
)
# Store decisions for backtracking
stored_scores
=
list
()
stored_predecessors
=
list
()
stored_emitted_symbols
=
list
()
for
i
in
range
(
self
.
max_len_labels
):
output
,
state
=
self
.
decoder
(
inflated_encoder_feats
,
state
,
y_prev
)
state
=
paddle
.
unsqueeze
(
state
,
axis
=
0
)
log_softmax_output
=
paddle
.
nn
.
functional
.
log_softmax
(
output
,
axis
=
1
)
sequence_scores
=
_inflate
(
sequence_scores
,
self
.
num_classes
,
1
)
sequence_scores
+=
log_softmax_output
scores
,
candidates
=
paddle
.
topk
(
paddle
.
reshape
(
sequence_scores
,
[
batch_size
,
-
1
]),
beam_width
,
axis
=
1
)
# Reshape input = (bk, 1) and sequence_scores = (bk, 1)
y_prev
=
paddle
.
reshape
(
candidates
%
self
.
num_classes
,
shape
=
[
batch_size
*
beam_width
])
sequence_scores
=
paddle
.
reshape
(
scores
,
shape
=
[
batch_size
*
beam_width
,
1
])
# Update fields for next timestep
pos_index
=
paddle
.
expand_as
(
pos_index
,
candidates
)
predecessors
=
paddle
.
cast
(
candidates
/
self
.
num_classes
+
pos_index
,
dtype
=
'int64'
)
predecessors
=
paddle
.
reshape
(
predecessors
,
shape
=
[
batch_size
*
beam_width
,
1
])
state
=
paddle
.
index_select
(
state
,
index
=
predecessors
.
squeeze
(),
axis
=
1
)
# Update sequence socres and erase scores for <eos> symbol so that they aren't expanded
stored_scores
.
append
(
sequence_scores
.
clone
())
y_prev
=
paddle
.
reshape
(
y_prev
,
shape
=
[
-
1
,
1
])
eos_prev
=
paddle
.
full_like
(
y_prev
,
fill_value
=
eos
)
mask
=
eos_prev
==
y_prev
mask
=
paddle
.
nonzero
(
mask
)
if
mask
.
dim
()
>
0
:
sequence_scores
=
sequence_scores
.
numpy
()
mask
=
mask
.
numpy
()
sequence_scores
[
mask
]
=
-
float
(
'inf'
)
sequence_scores
=
paddle
.
to_tensor
(
sequence_scores
)
# Cache results for backtracking
stored_predecessors
.
append
(
predecessors
)
y_prev
=
paddle
.
squeeze
(
y_prev
)
stored_emitted_symbols
.
append
(
y_prev
)
# Do backtracking to return the optimal values
#====== backtrak ======#
# Initialize return variables given different types
p
=
list
()
l
=
[[
self
.
max_len_labels
]
*
beam_width
for
_
in
range
(
batch_size
)
]
# Placeholder for lengths of top-k sequences
# the last step output of the beams are not sorted
# thus they are sorted here
sorted_score
,
sorted_idx
=
paddle
.
topk
(
paddle
.
reshape
(
stored_scores
[
-
1
],
shape
=
[
batch_size
,
beam_width
]),
beam_width
)
# initialize the sequence scores with the sorted last step beam scores
s
=
sorted_score
.
clone
()
batch_eos_found
=
[
0
]
*
batch_size
# the number of EOS found
# in the backward loop below for each batch
t
=
self
.
max_len_labels
-
1
# initialize the back pointer with the sorted order of the last step beams.
# add pos_index for indexing variable with b*k as the first dimension.
t_predecessors
=
paddle
.
reshape
(
sorted_idx
+
pos_index
.
expand_as
(
sorted_idx
),
shape
=
[
batch_size
*
beam_width
])
while
t
>=
0
:
# Re-order the variables with the back pointer
current_symbol
=
paddle
.
index_select
(
stored_emitted_symbols
[
t
],
index
=
t_predecessors
,
axis
=
0
)
t_predecessors
=
paddle
.
index_select
(
stored_predecessors
[
t
].
squeeze
(),
index
=
t_predecessors
,
axis
=
0
)
eos_indices
=
stored_emitted_symbols
[
t
]
==
eos
eos_indices
=
paddle
.
nonzero
(
eos_indices
)
if
eos_indices
.
dim
()
>
0
:
for
i
in
range
(
eos_indices
.
shape
[
0
]
-
1
,
-
1
,
-
1
):
# Indices of the EOS symbol for both variables
# with b*k as the first dimension, and b, k for
# the first two dimensions
idx
=
eos_indices
[
i
]
b_idx
=
int
(
idx
[
0
]
/
beam_width
)
# The indices of the replacing position
# according to the replacement strategy noted above
res_k_idx
=
beam_width
-
(
batch_eos_found
[
b_idx
]
%
beam_width
)
-
1
batch_eos_found
[
b_idx
]
+=
1
res_idx
=
b_idx
*
beam_width
+
res_k_idx
# Replace the old information in return variables
# with the new ended sequence information
t_predecessors
[
res_idx
]
=
stored_predecessors
[
t
][
idx
[
0
]]
current_symbol
[
res_idx
]
=
stored_emitted_symbols
[
t
][
idx
[
0
]]
s
[
b_idx
,
res_k_idx
]
=
stored_scores
[
t
][
idx
[
0
],
0
]
l
[
b_idx
][
res_k_idx
]
=
t
+
1
# record the back tracked results
p
.
append
(
current_symbol
)
t
-=
1
# Sort and re-order again as the added ended sequences may change
# the order (very unlikely)
s
,
re_sorted_idx
=
s
.
topk
(
beam_width
)
for
b_idx
in
range
(
batch_size
):
l
[
b_idx
]
=
[
l
[
b_idx
][
k_idx
.
item
()]
for
k_idx
in
re_sorted_idx
[
b_idx
,
:]
]
re_sorted_idx
=
paddle
.
reshape
(
re_sorted_idx
+
pos_index
.
expand_as
(
re_sorted_idx
),
[
batch_size
*
beam_width
])
# Reverse the sequences and re-order at the same time
# It is reversed because the backtracking happens in reverse time order
p
=
[
paddle
.
reshape
(
paddle
.
index_select
(
step
,
re_sorted_idx
,
0
),
shape
=
[
batch_size
,
beam_width
,
-
1
])
for
step
in
reversed
(
p
)
]
p
=
paddle
.
concat
(
p
,
-
1
)[:,
0
,
:]
return
p
,
paddle
.
ones_like
(
p
)
class
AttentionUnit
(
nn
.
Layer
):
def
__init__
(
self
,
sDim
,
xDim
,
attDim
):
super
(
AttentionUnit
,
self
).
__init__
()
self
.
sDim
=
sDim
self
.
xDim
=
xDim
self
.
attDim
=
attDim
self
.
sEmbed
=
nn
.
Linear
(
sDim
,
attDim
)
self
.
xEmbed
=
nn
.
Linear
(
xDim
,
attDim
)
self
.
wEmbed
=
nn
.
Linear
(
attDim
,
1
)
def
forward
(
self
,
x
,
sPrev
):
batch_size
,
T
,
_
=
x
.
shape
# [b x T x xDim]
x
=
paddle
.
reshape
(
x
,
[
-
1
,
self
.
xDim
])
# [(b x T) x xDim]
xProj
=
self
.
xEmbed
(
x
)
# [(b x T) x attDim]
xProj
=
paddle
.
reshape
(
xProj
,
[
batch_size
,
T
,
-
1
])
# [b x T x attDim]
sPrev
=
sPrev
.
squeeze
(
0
)
sProj
=
self
.
sEmbed
(
sPrev
)
# [b x attDim]
sProj
=
paddle
.
unsqueeze
(
sProj
,
1
)
# [b x 1 x attDim]
sProj
=
paddle
.
expand
(
sProj
,
[
batch_size
,
T
,
self
.
attDim
])
# [b x T x attDim]
sumTanh
=
paddle
.
tanh
(
sProj
+
xProj
)
sumTanh
=
paddle
.
reshape
(
sumTanh
,
[
-
1
,
self
.
attDim
])
vProj
=
self
.
wEmbed
(
sumTanh
)
# [(b x T) x 1]
vProj
=
paddle
.
reshape
(
vProj
,
[
batch_size
,
T
])
alpha
=
F
.
softmax
(
vProj
,
axis
=
1
)
# attention weights for each sample in the minibatch
return
alpha
class
DecoderUnit
(
nn
.
Layer
):
def
__init__
(
self
,
sDim
,
xDim
,
yDim
,
attDim
):
super
(
DecoderUnit
,
self
).
__init__
()
self
.
sDim
=
sDim
self
.
xDim
=
xDim
self
.
yDim
=
yDim
self
.
attDim
=
attDim
self
.
emdDim
=
attDim
self
.
attention_unit
=
AttentionUnit
(
sDim
,
xDim
,
attDim
)
self
.
tgt_embedding
=
nn
.
Embedding
(
yDim
+
1
,
self
.
emdDim
,
weight_attr
=
nn
.
initializer
.
Normal
(
std
=
0.01
))
# the last is used for <BOS>
self
.
gru
=
nn
.
GRUCell
(
input_size
=
xDim
+
self
.
emdDim
,
hidden_size
=
sDim
)
self
.
fc
=
nn
.
Linear
(
sDim
,
yDim
,
weight_attr
=
nn
.
initializer
.
Normal
(
std
=
0.01
),
bias_attr
=
nn
.
initializer
.
Constant
(
value
=
0
))
self
.
embed_fc
=
nn
.
Linear
(
300
,
self
.
sDim
)
def
get_initial_state
(
self
,
embed
,
tile_times
=
1
):
assert
embed
.
shape
[
1
]
==
300
state
=
self
.
embed_fc
(
embed
)
# N * sDim
if
tile_times
!=
1
:
state
=
state
.
unsqueeze
(
1
)
trans_state
=
paddle
.
transpose
(
state
,
perm
=
[
1
,
0
,
2
])
state
=
paddle
.
tile
(
trans_state
,
repeat_times
=
[
tile_times
,
1
,
1
])
trans_state
=
paddle
.
transpose
(
state
,
perm
=
[
1
,
0
,
2
])
state
=
paddle
.
reshape
(
trans_state
,
shape
=
[
-
1
,
self
.
sDim
])
state
=
state
.
unsqueeze
(
0
)
# 1 * N * sDim
return
state
def
forward
(
self
,
x
,
sPrev
,
yPrev
):
# x: feature sequence from the image decoder.
batch_size
,
T
,
_
=
x
.
shape
alpha
=
self
.
attention_unit
(
x
,
sPrev
)
context
=
paddle
.
squeeze
(
paddle
.
matmul
(
alpha
.
unsqueeze
(
1
),
x
),
axis
=
1
)
yPrev
=
paddle
.
cast
(
yPrev
,
dtype
=
"int64"
)
yProj
=
self
.
tgt_embedding
(
yPrev
)
concat_context
=
paddle
.
concat
([
yProj
,
context
],
1
)
concat_context
=
paddle
.
squeeze
(
concat_context
,
1
)
sPrev
=
paddle
.
squeeze
(
sPrev
,
0
)
output
,
state
=
self
.
gru
(
concat_context
,
sPrev
)
output
=
paddle
.
squeeze
(
output
,
axis
=
1
)
output
=
self
.
fc
(
output
)
return
output
,
state
\ No newline at end of file
ppocr/modeling/transforms/__init__.py
浏览文件 @
8308f332
...
...
@@ -17,8 +17,9 @@ __all__ = ['build_transform']
def
build_transform
(
config
):
from
.tps
import
TPS
from
.tps
import
STN_ON
support_dict
=
[
'TPS'
]
support_dict
=
[
'TPS'
,
'STN_ON'
]
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
...
...
ppocr/modeling/transforms/stn.py
0 → 100644
浏览文件 @
8308f332
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
from
paddle
import
nn
,
ParamAttr
from
paddle.nn
import
functional
as
F
import
numpy
as
np
def
conv3x3_block
(
in_channels
,
out_channels
,
stride
=
1
):
n
=
3
*
3
*
out_channels
w
=
math
.
sqrt
(
2.
/
n
)
conv_layer
=
nn
.
Conv2D
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
weight_attr
=
nn
.
initializer
.
Normal
(
mean
=
0.0
,
std
=
w
),
bias_attr
=
nn
.
initializer
.
Constant
(
0
))
block
=
nn
.
Sequential
(
conv_layer
,
nn
.
BatchNorm2D
(
out_channels
),
nn
.
ReLU
())
return
block
class
STN
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
num_ctrlpoints
,
activation
=
'none'
):
super
(
STN
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
num_ctrlpoints
=
num_ctrlpoints
self
.
activation
=
activation
self
.
stn_convnet
=
nn
.
Sequential
(
conv3x3_block
(
in_channels
,
32
),
#32x64
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
conv3x3_block
(
32
,
64
),
#16x32
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
conv3x3_block
(
64
,
128
),
# 8*16
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
conv3x3_block
(
128
,
256
),
# 4*8
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
conv3x3_block
(
256
,
256
),
# 2*4,
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
conv3x3_block
(
256
,
256
))
# 1*2
self
.
stn_fc1
=
nn
.
Sequential
(
nn
.
Linear
(
2
*
256
,
512
,
weight_attr
=
nn
.
initializer
.
Normal
(
0
,
0.001
),
bias_attr
=
nn
.
initializer
.
Constant
(
0
)),
nn
.
BatchNorm1D
(
512
),
nn
.
ReLU
())
fc2_bias
=
self
.
init_stn
()
self
.
stn_fc2
=
nn
.
Linear
(
512
,
num_ctrlpoints
*
2
,
weight_attr
=
nn
.
initializer
.
Constant
(
0.0
),
bias_attr
=
nn
.
initializer
.
Assign
(
fc2_bias
))
def
init_stn
(
self
):
margin
=
0.01
sampling_num_per_side
=
int
(
self
.
num_ctrlpoints
/
2
)
ctrl_pts_x
=
np
.
linspace
(
margin
,
1.
-
margin
,
sampling_num_per_side
)
ctrl_pts_y_top
=
np
.
ones
(
sampling_num_per_side
)
*
margin
ctrl_pts_y_bottom
=
np
.
ones
(
sampling_num_per_side
)
*
(
1
-
margin
)
ctrl_pts_top
=
np
.
stack
([
ctrl_pts_x
,
ctrl_pts_y_top
],
axis
=
1
)
ctrl_pts_bottom
=
np
.
stack
([
ctrl_pts_x
,
ctrl_pts_y_bottom
],
axis
=
1
)
ctrl_points
=
np
.
concatenate
(
[
ctrl_pts_top
,
ctrl_pts_bottom
],
axis
=
0
).
astype
(
np
.
float32
)
if
self
.
activation
==
'none'
:
pass
elif
self
.
activation
==
'sigmoid'
:
ctrl_points
=
-
np
.
log
(
1.
/
ctrl_points
-
1.
)
ctrl_points
=
paddle
.
to_tensor
(
ctrl_points
)
fc2_bias
=
paddle
.
reshape
(
ctrl_points
,
shape
=
[
ctrl_points
.
shape
[
0
]
*
ctrl_points
.
shape
[
1
]])
return
fc2_bias
def
forward
(
self
,
x
):
x
=
self
.
stn_convnet
(
x
)
batch_size
,
_
,
h
,
w
=
x
.
shape
x
=
paddle
.
reshape
(
x
,
shape
=
(
batch_size
,
-
1
))
img_feat
=
self
.
stn_fc1
(
x
)
x
=
self
.
stn_fc2
(
0.1
*
img_feat
)
if
self
.
activation
==
'sigmoid'
:
x
=
F
.
sigmoid
(
x
)
x
=
paddle
.
reshape
(
x
,
shape
=
[
-
1
,
self
.
num_ctrlpoints
,
2
])
return
img_feat
,
x
ppocr/modeling/transforms/tps.py
浏览文件 @
8308f332
...
...
@@ -22,6 +22,9 @@ from paddle import nn, ParamAttr
from
paddle.nn
import
functional
as
F
import
numpy
as
np
from
.tps_spatial_transformer
import
TPSSpatialTransformer
from
.stn
import
STN
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
...
...
@@ -231,7 +234,8 @@ class GridGenerator(nn.Layer):
""" Return inv_delta_C which is needed to calculate T """
F
=
self
.
F
hat_eye
=
paddle
.
eye
(
F
,
dtype
=
'float64'
)
# F x F
hat_C
=
paddle
.
norm
(
C
.
reshape
([
1
,
F
,
2
])
-
C
.
reshape
([
F
,
1
,
2
]),
axis
=
2
)
+
hat_eye
hat_C
=
paddle
.
norm
(
C
.
reshape
([
1
,
F
,
2
])
-
C
.
reshape
([
F
,
1
,
2
]),
axis
=
2
)
+
hat_eye
hat_C
=
(
hat_C
**
2
)
*
paddle
.
log
(
hat_C
)
delta_C
=
paddle
.
concat
(
# F+3 x F+3
[
...
...
@@ -301,3 +305,25 @@ class TPS(nn.Layer):
[
-
1
,
image
.
shape
[
2
],
image
.
shape
[
3
],
2
])
batch_I_r
=
F
.
grid_sample
(
x
=
image
,
grid
=
batch_P_prime
)
return
batch_I_r
class
STN_ON
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
tps_inputsize
,
tps_outputsize
,
num_control_points
,
tps_margins
,
stn_activation
):
super
(
STN_ON
,
self
).
__init__
()
self
.
tps
=
TPSSpatialTransformer
(
output_image_size
=
tuple
(
tps_outputsize
),
num_control_points
=
num_control_points
,
margins
=
tuple
(
tps_margins
))
self
.
stn_head
=
STN
(
in_channels
=
in_channels
,
num_ctrlpoints
=
num_control_points
,
activation
=
stn_activation
)
self
.
tps_inputsize
=
tps_inputsize
self
.
out_channels
=
in_channels
def
forward
(
self
,
image
):
stn_input
=
paddle
.
nn
.
functional
.
interpolate
(
image
,
self
.
tps_inputsize
,
mode
=
"bilinear"
,
align_corners
=
True
)
stn_img_feat
,
ctrl_points
=
self
.
stn_head
(
stn_input
)
x
,
_
=
self
.
tps
(
image
,
ctrl_points
)
return
x
ppocr/modeling/transforms/tps_spatial_transformer.py
0 → 100644
浏览文件 @
8308f332
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
from
paddle
import
nn
,
ParamAttr
from
paddle.nn
import
functional
as
F
import
numpy
as
np
import
itertools
def
grid_sample
(
input
,
grid
,
canvas
=
None
):
input
.
stop_gradient
=
False
output
=
F
.
grid_sample
(
input
,
grid
)
if
canvas
is
None
:
return
output
else
:
input_mask
=
paddle
.
ones
(
shape
=
input
.
shape
)
output_mask
=
F
.
grid_sample
(
input_mask
,
grid
)
padded_output
=
output
*
output_mask
+
canvas
*
(
1
-
output_mask
)
return
padded_output
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def
compute_partial_repr
(
input_points
,
control_points
):
N
=
input_points
.
shape
[
0
]
M
=
control_points
.
shape
[
0
]
pairwise_diff
=
paddle
.
reshape
(
input_points
,
shape
=
[
N
,
1
,
2
])
-
paddle
.
reshape
(
control_points
,
shape
=
[
1
,
M
,
2
])
# original implementation, very slow
# pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
pairwise_diff_square
=
pairwise_diff
*
pairwise_diff
pairwise_dist
=
pairwise_diff_square
[:,
:,
0
]
+
pairwise_diff_square
[:,
:,
1
]
repr_matrix
=
0.5
*
pairwise_dist
*
paddle
.
log
(
pairwise_dist
)
# fix numerical error for 0 * log(0), substitute all nan with 0
mask
=
repr_matrix
!=
repr_matrix
repr_matrix
[
mask
]
=
0
return
repr_matrix
# output_ctrl_pts are specified, according to our task.
def
build_output_control_points
(
num_control_points
,
margins
):
margin_x
,
margin_y
=
margins
num_ctrl_pts_per_side
=
num_control_points
//
2
ctrl_pts_x
=
np
.
linspace
(
margin_x
,
1.0
-
margin_x
,
num_ctrl_pts_per_side
)
ctrl_pts_y_top
=
np
.
ones
(
num_ctrl_pts_per_side
)
*
margin_y
ctrl_pts_y_bottom
=
np
.
ones
(
num_ctrl_pts_per_side
)
*
(
1.0
-
margin_y
)
ctrl_pts_top
=
np
.
stack
([
ctrl_pts_x
,
ctrl_pts_y_top
],
axis
=
1
)
ctrl_pts_bottom
=
np
.
stack
([
ctrl_pts_x
,
ctrl_pts_y_bottom
],
axis
=
1
)
# ctrl_pts_top = ctrl_pts_top[1:-1,:]
# ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
output_ctrl_pts_arr
=
np
.
concatenate
(
[
ctrl_pts_top
,
ctrl_pts_bottom
],
axis
=
0
)
output_ctrl_pts
=
paddle
.
to_tensor
(
output_ctrl_pts_arr
)
return
output_ctrl_pts
class
TPSSpatialTransformer
(
nn
.
Layer
):
def
__init__
(
self
,
output_image_size
=
None
,
num_control_points
=
None
,
margins
=
None
):
super
(
TPSSpatialTransformer
,
self
).
__init__
()
self
.
output_image_size
=
output_image_size
self
.
num_control_points
=
num_control_points
self
.
margins
=
margins
self
.
target_height
,
self
.
target_width
=
output_image_size
target_control_points
=
build_output_control_points
(
num_control_points
,
margins
)
N
=
num_control_points
# N = N - 4
# create padded kernel matrix
forward_kernel
=
paddle
.
zeros
(
shape
=
[
N
+
3
,
N
+
3
])
target_control_partial_repr
=
compute_partial_repr
(
target_control_points
,
target_control_points
)
target_control_partial_repr
=
paddle
.
cast
(
target_control_partial_repr
,
forward_kernel
.
dtype
)
forward_kernel
[:
N
,
:
N
]
=
target_control_partial_repr
forward_kernel
[:
N
,
-
3
]
=
1
forward_kernel
[
-
3
,
:
N
]
=
1
target_control_points
=
paddle
.
cast
(
target_control_points
,
forward_kernel
.
dtype
)
forward_kernel
[:
N
,
-
2
:]
=
target_control_points
forward_kernel
[
-
2
:,
:
N
]
=
paddle
.
transpose
(
target_control_points
,
perm
=
[
1
,
0
])
# compute inverse matrix
inverse_kernel
=
paddle
.
inverse
(
forward_kernel
)
# create target cordinate matrix
HW
=
self
.
target_height
*
self
.
target_width
target_coordinate
=
list
(
itertools
.
product
(
range
(
self
.
target_height
),
range
(
self
.
target_width
)))
target_coordinate
=
paddle
.
to_tensor
(
target_coordinate
)
# HW x 2
Y
,
X
=
paddle
.
split
(
target_coordinate
,
target_coordinate
.
shape
[
1
],
axis
=
1
)
#Y, X = target_coordinate.split(1, dim = 1)
Y
=
Y
/
(
self
.
target_height
-
1
)
X
=
X
/
(
self
.
target_width
-
1
)
target_coordinate
=
paddle
.
concat
(
[
X
,
Y
],
axis
=
1
)
# convert from (y, x) to (x, y)
target_coordinate_partial_repr
=
compute_partial_repr
(
target_coordinate
,
target_control_points
)
target_coordinate_repr
=
paddle
.
concat
(
[
target_coordinate_partial_repr
,
paddle
.
ones
(
shape
=
[
HW
,
1
]),
target_coordinate
],
axis
=
1
)
# register precomputed matrices
self
.
inverse_kernel
=
inverse_kernel
self
.
padding_matrix
=
paddle
.
zeros
(
shape
=
[
3
,
2
])
self
.
target_coordinate_repr
=
target_coordinate_repr
self
.
target_control_points
=
target_control_points
def
forward
(
self
,
input
,
source_control_points
):
assert
source_control_points
.
ndimension
()
==
3
assert
source_control_points
.
shape
[
1
]
==
self
.
num_control_points
assert
source_control_points
.
shape
[
2
]
==
2
#batch_size = source_control_points.shape[0]
batch_size
=
paddle
.
shape
(
source_control_points
)[
0
]
self
.
padding_matrix
=
paddle
.
expand
(
self
.
padding_matrix
,
shape
=
[
batch_size
,
3
,
2
])
Y
=
paddle
.
concat
([
source_control_points
,
self
.
padding_matrix
],
1
)
mapping_matrix
=
paddle
.
matmul
(
self
.
inverse_kernel
,
Y
)
source_coordinate
=
paddle
.
matmul
(
self
.
target_coordinate_repr
,
mapping_matrix
)
grid
=
paddle
.
reshape
(
source_coordinate
,
shape
=
[
-
1
,
self
.
target_height
,
self
.
target_width
,
2
])
grid
=
paddle
.
clip
(
grid
,
0
,
1
)
# the source_control_points may be out of [0, 1].
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
grid
=
2.0
*
grid
-
1.0
output_maps
=
grid_sample
(
input
,
grid
,
canvas
=
None
)
return
output_maps
,
source_coordinate
ppocr/optimizer/optimizer.py
浏览文件 @
8308f332
...
...
@@ -127,3 +127,34 @@ class RMSProp(object):
grad_clip
=
self
.
grad_clip
,
parameters
=
parameters
)
return
opt
class
Adadelta
(
object
):
def
__init__
(
self
,
learning_rate
=
0.001
,
epsilon
=
1e-08
,
rho
=
0.95
,
parameter_list
=
None
,
weight_decay
=
None
,
grad_clip
=
None
,
name
=
None
,
**
kwargs
):
self
.
learning_rate
=
learning_rate
self
.
epsilon
=
epsilon
self
.
rho
=
rho
self
.
parameter_list
=
parameter_list
self
.
learning_rate
=
learning_rate
self
.
weight_decay
=
weight_decay
self
.
grad_clip
=
grad_clip
self
.
name
=
name
def
__call__
(
self
,
parameters
):
opt
=
optim
.
Adadelta
(
learning_rate
=
self
.
learning_rate
,
epsilon
=
self
.
epsilon
,
rho
=
self
.
rho
,
weight_decay
=
self
.
weight_decay
,
grad_clip
=
self
.
grad_clip
,
name
=
self
.
name
,
parameters
=
parameters
)
return
opt
ppocr/postprocess/__init__.py
浏览文件 @
8308f332
...
...
@@ -24,17 +24,19 @@ __all__ = ['build_post_process']
from
.db_postprocess
import
DBPostProcess
,
DistillationDBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
DistillationCTCLabelDecode
,
NRTRLabelDecode
,
\
TableLabelDecode
,
SAR
LabelDecode
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
DistillationCTCLabelDecode
,
\
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
SEED
LabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
def
build_post_process
(
config
,
global_config
=
None
):
support_dict
=
[
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
]
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
8308f332
...
...
@@ -303,6 +303,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
return
idx
class
SEEDLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
**
kwargs
):
super
(
SEEDLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
dict_character
=
dict_character
dict_character
=
dict_character
+
[
self
.
end_str
]
return
dict_character
def
get_ignored_tokens
(
self
):
end_idx
=
self
.
get_beg_end_flag_idx
(
"eos"
)
return
[
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"sos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"eos"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
%
beg_or_end
return
idx
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
[
end_idx
]
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
int
(
text_index
[
batch_idx
][
idx
])
==
int
(
end_idx
):
break
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
)))
return
result_list
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds_idx
=
preds
[
"rec_pred"
]
if
isinstance
(
preds_idx
,
paddle
.
Tensor
):
preds_idx
=
preds_idx
.
numpy
()
if
"rec_pred_scores"
in
preds
:
preds_idx
=
preds
[
"rec_pred"
]
preds_prob
=
preds
[
"rec_pred_scores"
]
else
:
preds_idx
=
preds
[
"rec_pred"
].
argmax
(
axis
=
2
)
preds_prob
=
preds
[
"rec_pred"
].
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
class
SRNLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
...
...
tools/program.py
浏览文件 @
8308f332
...
...
@@ -188,10 +188,12 @@ def train(config,
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_nrtr
=
config
[
'Architecture'
][
'algorithm'
]
==
"NRTR"
use_sar
=
config
[
'Architecture'
][
'algorithm'
]
==
'SAR'
use_seed
=
config
[
'Architecture'
][
'algorithm'
]
==
'SEED'
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
model_type
=
None
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
if
'start_epoch'
in
best_model_dict
:
start_epoch
=
best_model_dict
[
'start_epoch'
]
...
...
@@ -215,7 +217,7 @@ def train(config,
images
=
batch
[
0
]
if
use_srn
:
model_average
=
True
if
use_srn
or
model_type
==
'table'
or
use_nrtr
or
use_sar
:
if
use_srn
or
model_type
==
'table'
or
use_nrtr
or
use_sar
or
use_seed
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
preds
=
model
(
images
)
...
...
@@ -402,7 +404,7 @@ def preprocess(is_train=False):
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'ASTER'
]
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录