Commit 59cc4efd — PaddlePaddle / PaddleOCR
Authored July 22, 2021 by tink2123

add for SEED

Parent commit: 38801c7f
Showing 21 changed files with 1,868 additions and 54 deletions.
configs/rec/rec_resnet_stn_bilstm_att.yml              +101  -0
ppocr/data/imaug/label_ops.py                          +38   -26
ppocr/data/simple_dataset.py                           +1    -0
ppocr/losses/__init__.py                               +4    -1
ppocr/losses/rec_aster_loss.py                         +79   -0
ppocr/losses/rec_att_loss.py                           +2    -0
ppocr/modeling/backbones/__init__.py                   +3    -1
ppocr/modeling/backbones/levit.py                      +707  -0
ppocr/modeling/backbones/rec_resnet_aster.py           +147  -0
ppocr/modeling/heads/__init__.py                       +5    -1
ppocr/modeling/heads/rec_aster_head.py                 +258  -0
ppocr/modeling/heads/rec_att_head.py                   +5    -0
ppocr/modeling/transforms/__init__.py                  +2    -1
ppocr/modeling/transforms/stn.py                       +121  -0
ppocr/modeling/transforms/tps.py                       +28   -1
ppocr/modeling/transforms/tps_spatial_transformer.py   +178  -0
ppocr/modeling/transforms/tps_torch.py                 +149  -0
ppocr/postprocess/rec_postprocess.py                   +18   -11
ppocr/utils/save_load.py                               +10   -7
tools/program.py                                       +10   -5
tools/train.py                                         +2    -0
configs/rec/rec_resnet_stn_bilstm_att.yml (new file, mode 100644)

Global:
  use_gpu: False
  epoch_num: 400
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec/b3_rare_r34_none_gru/
  save_epoch_step: 3
  # evaluation is run every 5000 iterations after the 4000th iteration
  eval_batch_step: [0, 2000]
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words/ch/word_1.jpg
  # for data or label process
  character_dict_path:
  character_type: EN_symbol
  max_text_length: 25
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.0005
  regularizer:
    name: 'L2'
    factor: 0.00000

Architecture:
  model_type: rec
  algorithm: ASTER
  Transform:
    name: STN_ON
    tps_inputsize: [32, 64]
    tps_outputsize: [32, 100]
    num_control_points: 20
    tps_margins: [0.05, 0.05]
    stn_activation: none
  Backbone:
    name: ResNet_ASTER
  Head:
    name: AsterHead  # AttentionHead
    sDim: 512
    attDim: 512
    max_len_labels: 100

Loss:
  name: AsterLoss

PostProcess:
  name: AttnLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data/
    label_file_list: ["./train_data/ic15_data/1.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - AttnLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 2
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data/
    label_file_list: ["./train_data/ic15_data/1.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - AttnLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 2
    num_workers: 8
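As a reading aid (not part of this commit), the sketch below shows how the Architecture section of this config maps onto the modules introduced in the change: the STN_ON rectification transform, the ResNet_ASTER backbone, and AsterHead. The dummy input size and the charset size of 38 are illustrative assumptions; in the real pipeline these pieces are assembled from the YAML by the build_transform/build_backbone/build_head factories under ppocr/modeling.

# Minimal, hypothetical wiring sketch; values mirror the config above.
import paddle
from ppocr.modeling.transforms.tps import STN_ON
from ppocr.modeling.backbones.rec_resnet_aster import ResNet_ASTER
from ppocr.modeling.heads.rec_aster_head import AsterHead

transform = STN_ON(in_channels=3,
                   tps_inputsize=[32, 64],     # Transform.tps_inputsize
                   tps_outputsize=[32, 100],   # Transform.tps_outputsize
                   num_control_points=20,
                   tps_margins=[0.05, 0.05],
                   stn_activation="none")
backbone = ResNet_ASTER(with_lstm=True, in_channels=3)   # out_channels == 512
head = AsterHead(in_channels=backbone.out_channels,
                 out_channels=38,              # assumed charset size incl. sos/eos/unknown
                 sDim=512, attDim=512, max_len_labels=100)

# The backbone turns a rectified 32x100 crop into a 25-step feature sequence
# that the attention head decodes.
feats = backbone(paddle.randn([1, 3, 32, 100]))
print(feats.shape)  # [1, 25, 512]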
ppocr/data/imaug/label_ops.py

@@ -104,6 +104,7 @@ class BaseRecLabelEncode(object):
         self.max_text_len = max_text_length
         self.beg_str = "sos"
         self.end_str = "eos"
+        self.unknown = "UNKNOWN"
         if character_type == "en":
             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
             dict_character = list(self.character_str)

@@ -275,7 +276,9 @@ class AttnLabelEncode(BaseRecLabelEncode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
-        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        self.unknown = "UNKNOWN"
+        dict_character = [self.beg_str] + dict_character + [self.end_str] + [self.unknown]
         return dict_character

     def __call__(self, data):

@@ -288,6 +291,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
         data['length'] = np.array(len(text))
         text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len - len(text) - 2)
         data['label'] = np.array(text)
         return data

@@ -352,19 +356,22 @@ class SRNLabelEncode(BaseRecLabelEncode):
               % beg_or_end
         return idx


 class TableLabelEncode(object):
     """ Convert between text-label and text-index """

     def __init__(self, max_text_length, max_elem_length, max_cell_num,
                  character_dict_path, span_weight=1.0, **kwargs):
         self.max_text_length = max_text_length
         self.max_elem_length = max_elem_length
         self.max_cell_num = max_cell_num
         list_character, list_elem = self.load_char_elem_dict(character_dict_path)
         list_character = self.add_special_char(list_character)
         list_elem = self.add_special_char(list_elem)
         self.dict_character = {}

@@ -374,7 +381,7 @@ class TableLabelEncode(object):
         for i, elem in enumerate(list_elem):
             self.dict_elem[elem] = i
         self.span_weight = span_weight

     def load_char_elem_dict(self, character_dict_path):
         list_character = []
         list_elem = []

@@ -383,27 +390,27 @@ class TableLabelEncode(object):
             substr = lines[0].decode('utf-8').strip("\n").split("\t")
             character_num = int(substr[0])
             elem_num = int(substr[1])
             for cno in range(1, 1 + character_num):
                 character = lines[cno].decode('utf-8').strip("\n")
                 list_character.append(character)
             for eno in range(1 + character_num, 1 + character_num + elem_num):
                 elem = lines[eno].decode('utf-8').strip("\n")
                 list_elem.append(elem)
         return list_character, list_elem

     def add_special_char(self, list_character):
         self.beg_str = "sos"
         self.end_str = "eos"
         list_character = [self.beg_str] + list_character + [self.end_str]
         return list_character

     def get_span_idx_list(self):
         span_idx_list = []
         for elem in self.dict_elem:
             if 'span' in elem:
                 span_idx_list.append(self.dict_elem[elem])
         return span_idx_list

     def __call__(self, data):
         cells = data['cells']
         structure = data['structure']['tokens']

@@ -412,18 +419,22 @@ class TableLabelEncode(object):
             return None
         elem_num = len(structure)
         structure = [0] + structure + [len(self.dict_elem) - 1]
         structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
         structure = np.array(structure)
         data['structure'] = structure
         elem_char_idx1 = self.dict_elem['<td>']
         elem_char_idx2 = self.dict_elem['<td']
         span_idx_list = self.get_span_idx_list()
         td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
         td_idx_list = np.where(td_idx_list)[0]
         structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
         bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
         bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
         img_height, img_width, img_ch = data['image'].shape
         if len(span_idx_list) > 0:
             span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)

@@ -450,9 +461,11 @@ class TableLabelEncode(object):
         char_end_idx = self.get_beg_end_flag_idx('end', 'char')
         elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
         elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
         data['sp_tokens'] = np.array([
             char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
             elem_char_idx1, elem_char_idx2, self.max_text_length,
             self.max_elem_length, self.max_cell_num, elem_num
         ])
         return data

     def encode(self, text, char_or_elem):

@@ -504,9 +517,8 @@ class TableLabelEncode(object):
                 idx = np.array(self.dict_elem[self.end_str])
             else:
                 assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
                     % beg_or_end
         else:
             assert False, "Unsupport type %s in char_or_elem" \
                 % char_or_elem
         return idx
\ No newline at end of file
ppocr/data/simple_dataset.py

@@ -22,6 +22,7 @@ from .imaug import transform, create_operators
 class SimpleDataSet(Dataset):
     def __init__(self, config, mode, logger, seed=None):
+        print("===== simpledataset ========")
         super(SimpleDataSet, self).__init__()
         self.logger = logger
         self.mode = mode.lower()
ppocr/losses/__init__.py

@@ -41,10 +41,13 @@ from .combined_loss import CombinedLoss
 # table loss
 from .table_att_loss import TableAttentionLoss
+from .rec_aster_loss import AsterLoss


 def build_loss(config):
     support_dict = [
         'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
+        'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss', 'AsterLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
ppocr/losses/rec_aster_loss.py (new file, mode 100644)

# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
import fasttext


class AsterLoss(nn.Layer):
    def __init__(self,
                 weight=None,
                 size_average=True,
                 ignore_index=-100,
                 sequence_normalize=False,
                 sample_normalize=True,
                 **kwargs):
        super(AsterLoss, self).__init__()
        self.weight = weight
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.sequence_normalize = sequence_normalize
        self.sample_normalize = sample_normalize
        self.loss_func = paddle.nn.CosineSimilarity()

    def forward(self, predicts, batch):
        targets = batch[1].astype("int64")
        label_lengths = batch[2].astype('int64')
        # sem_target = batch[3].astype('float32')
        embedding_vectors = predicts['embedding_vectors']
        rec_pred = predicts['rec_pred']

        # semantic loss
        # print(embedding_vectors)
        # print(embedding_vectors.shape)
        # targets = fasttext[targets]
        # sem_loss = 1 - self.loss_func(embedding_vectors, targets)

        # rec loss
        batch_size, num_steps, num_classes = rec_pred.shape[0], rec_pred.shape[1], rec_pred.shape[2]
        assert len(targets.shape) == len(list(rec_pred.shape)) - 1, \
            "The target's shape and inputs's shape is [N, d] and [N, num_steps]"

        mask = paddle.zeros([batch_size, num_steps])
        for i in range(batch_size):
            mask[i, :label_lengths[i]] = 1
        mask = paddle.cast(mask, "float32")
        max_length = max(label_lengths)
        assert max_length == rec_pred.shape[1]
        targets = targets[:, :max_length]
        mask = mask[:, :max_length]
        rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[-1]])
        input = nn.functional.log_softmax(rec_pred, axis=1)
        targets = paddle.reshape(targets, [-1, 1])
        mask = paddle.reshape(mask, [-1, 1])
        # print("input:", input)
        output = -paddle.gather(input, index=targets, axis=1) * mask
        output = paddle.sum(output)
        if self.sequence_normalize:
            output = output / paddle.sum(mask)
        if self.sample_normalize:
            output = output / batch_size
        loss = output

        return {'loss': loss}  # , 'sem_loss':sem_loss}
ppocr/losses/rec_att_loss.py

@@ -35,5 +35,7 @@ class AttentionLoss(nn.Layer):
         inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
         targets = paddle.reshape(targets, [-1])
+        print("input:", paddle.argmax(inputs, axis=1))
+        print("targets:", targets)

         return {'loss': paddle.sum(self.loss_func(inputs, targets))}
ppocr/modeling/backbones/__init__.py

@@ -26,8 +26,10 @@ def build_backbone(config, model_type):
         from .rec_resnet_vd import ResNet
         from .rec_resnet_fpn import ResNetFPN
         from .rec_mv1_enhance import MobileNetV1Enhance
+        from .rec_resnet_aster import ResNet_ASTER
         support_dict = [
-            "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
+            "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN",
+            "ResNet_ASTER"
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
ppocr/modeling/backbones/levit.py (new file, mode 100644)

The diff for this file is collapsed in the web view; its contents are not shown here.
ppocr/modeling/backbones/rec_resnet_aster.py (new file, mode 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

import sys
import math


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2D(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias_attr=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2D(
        in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)


def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
    # [n_position]
    positions = paddle.arange(0, n_position)
    # [feat_dim]
    dim_range = paddle.arange(0, feat_dim)
    dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
    # [n_position, feat_dim]
    angles = paddle.unsqueeze(positions, axis=1) / paddle.unsqueeze(dim_range, axis=0)
    angles = paddle.cast(angles, "float32")
    angles[:, 0::2] = paddle.sin(angles[:, 0::2])
    angles[:, 1::2] = paddle.cos(angles[:, 1::2])
    return angles


class AsterBlock(nn.Layer):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(AsterBlock, self).__init__()
        self.conv1 = conv1x1(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2D(planes)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2D(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet_ASTER(nn.Layer):
    """For aster or crnn"""

    def __init__(self, with_lstm=True, n_group=1, in_channels=3):
        super(ResNet_ASTER, self).__init__()
        self.with_lstm = with_lstm
        self.n_group = n_group

        self.layer0 = nn.Sequential(
            nn.Conv2D(
                in_channels,
                32,
                kernel_size=(3, 3),
                stride=1,
                padding=1,
                bias_attr=False),
            nn.BatchNorm2D(32),
            nn.ReLU())

        self.inplanes = 32
        self.layer1 = self._make_layer(32, 3, [2, 2])  # [16, 50]
        self.layer2 = self._make_layer(64, 4, [2, 2])  # [8, 25]
        self.layer3 = self._make_layer(128, 6, [2, 1])  # [4, 25]
        self.layer4 = self._make_layer(256, 6, [2, 1])  # [2, 25]
        self.layer5 = self._make_layer(512, 3, [2, 1])  # [1, 25]

        if with_lstm:
            self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
            self.out_channels = 2 * 256
        else:
            self.out_channels = 512

    def _make_layer(self, planes, blocks, stride):
        downsample = None
        if stride != [1, 1] or self.inplanes != planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))

        layers = []
        layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(AsterBlock(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x0 = self.layer0(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x5 = self.layer5(x4)

        cnn_feat = x5.squeeze(2)  # [N, c, w]
        cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
        if self.with_lstm:
            rnn_feat, _ = self.rnn(cnn_feat)
            return rnn_feat
        else:
            return cnn_feat


if __name__ == "__main__":
    x = paddle.randn([3, 3, 32, 100])
    net = ResNet_ASTER()
    encoder_feat = net(x)
    print(encoder_feat.shape)
ppocr/modeling/heads/__init__.py

@@ -26,12 +26,15 @@ def build_head(config):
     from .rec_ctc_head import CTCHead
     from .rec_att_head import AttentionHead
     from .rec_srn_head import SRNHead
+    from .rec_aster_head import AttentionRecognitionHead, AsterHead

     # cls head
     from .cls_head import ClsHead
     support_dict = [
         'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
-        'SRNHead', 'PGHead', 'TableAttentionHead'
+        'SRNHead', 'PGHead', 'TableAttentionHead', 'AttentionRecognitionHead',
+        'AsterHead'
     ]

     #table head
     from .table_att_head import TableAttentionHead

@@ -39,5 +42,6 @@ def build_head(config):
     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
         'head only support {}'.format(support_dict))
+    print(config)
     module_class = eval(module_name)(**config)
     return module_class
ppocr/modeling/heads/rec_aster_head.py (new file, mode 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import paddle
from paddle import nn
from paddle.nn import functional as F


class AsterHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 sDim,
                 attDim,
                 max_len_labels,
                 time_step=25,
                 beam_width=5,
                 **kwargs):
        super(AsterHead, self).__init__()
        self.num_classes = out_channels
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels
        self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim,
                                                attDim, max_len_labels)
        self.time_step = time_step
        self.embeder = Embedding(self.time_step, in_channels)
        self.beam_width = beam_width

    def forward(self, x, targets=None, embed=None):
        return_dict = {}
        embedding_vectors = self.embeder(x)
        rec_targets, rec_lengths = targets

        if self.training:
            rec_pred = self.decoder([x, rec_targets, rec_lengths],
                                    embedding_vectors)
            return_dict['rec_pred'] = rec_pred
            return_dict['embedding_vectors'] = embedding_vectors
        else:
            rec_pred, rec_pred_scores = self.decoder.beam_search(
                x, self.beam_width, self.eos, embedding_vectors)
            return_dict['rec_pred'] = rec_pred
            return_dict['rec_pred_scores'] = rec_pred_scores
            return_dict['embedding_vectors'] = embedding_vectors

        return return_dict


class Embedding(nn.Layer):
    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super(Embedding, self).__init__()
        self.in_timestep = in_timestep
        self.in_planes = in_planes
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim
        self.eEmbed = nn.Linear(
            in_timestep * in_planes,
            self.embed_dim)  # Embed encoder output to a word-embedding like

    def forward(self, x):
        x = paddle.reshape(x, [paddle.shape(x)[0], -1])
        x = self.eEmbed(x)
        return x


class AttentionRecognitionHead(nn.Layer):
    """
    input: [b x 16 x 64 x in_planes]
    output: probability sequence: [b x T x num_classes]
    """

    def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
        super(AttentionRecognitionHead, self).__init__()
        self.num_classes = out_channels  # this is the output classes. So it includes the <EOS>.
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels

        self.decoder = DecoderUnit(
            sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)

    def forward(self, x, embed):
        x, targets, lengths = x
        batch_size = paddle.shape(x)[0]
        # Decoder
        state = self.decoder.get_initial_state(embed)
        outputs = []

        for i in range(max(lengths)):
            if i == 0:
                y_prev = paddle.full(
                    shape=[batch_size], fill_value=self.num_classes)
            else:
                y_prev = targets[:, i - 1]
            output, state = self.decoder(x, state, y_prev)
            outputs.append(output)
        outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
        return outputs

    # inference stage.
    def sample(self, x):
        x, _, _ = x
        batch_size = x.size(0)
        # Decoder
        state = paddle.zeros([1, batch_size, self.sDim])

        predicted_ids, predicted_scores = [], []
        for i in range(self.max_len_labels):
            if i == 0:
                y_prev = paddle.full(
                    shape=[batch_size], fill_value=self.num_classes)
            else:
                y_prev = predicted

            output, state = self.decoder(x, state, y_prev)
            output = F.softmax(output, axis=1)
            score, predicted = output.max(1)
            predicted_ids.append(predicted.unsqueeze(1))
            predicted_scores.append(score.unsqueeze(1))
        predicted_ids = paddle.concat([predicted_ids, 1])
        predicted_scores = paddle.concat([predicted_scores, 1])
        # return predicted_ids.squeeze(), predicted_scores.squeeze()
        return predicted_ids, predicted_scores


class AttentionUnit(nn.Layer):
    def __init__(self, sDim, xDim, attDim):
        super(AttentionUnit, self).__init__()

        self.sDim = sDim
        self.xDim = xDim
        self.attDim = attDim

        self.sEmbed = nn.Linear(
            sDim,
            attDim,
            weight_attr=paddle.nn.initializer.Normal(std=0.01),
            bias_attr=paddle.nn.initializer.Constant(0.0))
        self.xEmbed = nn.Linear(
            xDim,
            attDim,
            weight_attr=paddle.nn.initializer.Normal(std=0.01),
            bias_attr=paddle.nn.initializer.Constant(0.0))
        self.wEmbed = nn.Linear(
            attDim,
            1,
            weight_attr=paddle.nn.initializer.Normal(std=0.01),
            bias_attr=paddle.nn.initializer.Constant(0.0))

    def forward(self, x, sPrev):
        batch_size, T, _ = x.shape  # [b x T x xDim]
        x = paddle.reshape(x, [-1, self.xDim])  # [(b x T) x xDim]
        xProj = self.xEmbed(x)  # [(b x T) x attDim]
        xProj = paddle.reshape(xProj, [batch_size, T, -1])  # [b x T x attDim]

        sPrev = sPrev.squeeze(0)
        sProj = self.sEmbed(sPrev)  # [b x attDim]
        sProj = paddle.unsqueeze(sProj, 1)  # [b x 1 x attDim]
        sProj = paddle.expand(sProj,
                              [batch_size, T, self.attDim])  # [b x T x attDim]

        sumTanh = paddle.tanh(sProj + xProj)
        sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])

        vProj = self.wEmbed(sumTanh)  # [(b x T) x 1]
        vProj = paddle.reshape(vProj, [batch_size, T])

        alpha = F.softmax(
            vProj, axis=1)  # attention weights for each sample in the minibatch
        return alpha


class DecoderUnit(nn.Layer):
    def __init__(self, sDim, xDim, yDim, attDim):
        super(DecoderUnit, self).__init__()
        self.sDim = sDim
        self.xDim = xDim
        self.yDim = yDim
        self.attDim = attDim
        self.emdDim = attDim

        self.attention_unit = AttentionUnit(sDim, xDim, attDim)
        self.tgt_embedding = nn.Embedding(
            yDim + 1,
            self.emdDim,
            weight_attr=nn.initializer.Normal(
                std=0.01))  # the last is used for <BOS>
        self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
        self.fc = nn.Linear(
            sDim,
            yDim,
            weight_attr=nn.initializer.Normal(std=0.01),
            bias_attr=nn.initializer.Constant(value=0))
        self.embed_fc = nn.Linear(300, self.sDim)

    def get_initial_state(self, embed, tile_times=1):
        assert embed.shape[1] == 300
        state = self.embed_fc(embed)  # N * sDim
        if tile_times != 1:
            state = state.unsqueeze(1)
            trans_state = paddle.transpose(state, perm=[1, 0, 2])
            state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
            trans_state = paddle.transpose(state, perm=[1, 0, 2])
            state = paddle.reshape(trans_state, shape=[-1, self.sDim])
        state = state.unsqueeze(0)  # 1 * N * sDim
        return state

    def forward(self, x, sPrev, yPrev):
        # x: feature sequence from the image decoder.
        batch_size, T, _ = x.shape
        alpha = self.attention_unit(x, sPrev)
        context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
        yPrev = paddle.cast(yPrev, dtype="int64")
        yProj = self.tgt_embedding(yPrev)

        concat_context = paddle.concat([yProj, context], 1)
        concat_context = paddle.squeeze(concat_context, 1)
        sPrev = paddle.squeeze(sPrev, 0)
        output, state = self.gru(concat_context, sPrev)
        output = paddle.squeeze(output, axis=1)
        output = self.fc(output)
        return output, state


if __name__ == "__main__":
    model = AttentionRecognitionHead(
        num_classes=20,
        in_channels=30,
        sDim=512,
        attDim=512,
        max_len_labels=25,
        out_channels=38)
    data = paddle.ones([16, 64, 3])
    targets = paddle.ones([16, 25])
    length = paddle.to_tensor(20)
    x = [data, targets, length]
    output = model(x)
    print(output.shape)
ppocr/modeling/heads/rec_att_head.py

@@ -44,10 +44,13 @@ class AttentionHead(nn.Layer):
         hidden = paddle.zeros((batch_size, self.hidden_size))
         output_hiddens = []
+        targets = targets[0]
+        print(targets)

         if targets is not None:
             for i in range(num_steps):
                 char_onehots = self._char_to_onehot(
                     targets[:, i], onehot_dim=self.num_classes)
+                # print("char_onehots:", char_onehots)
                 (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                                char_onehots)
                 output_hiddens.append(paddle.unsqueeze(outputs, axis=1))

@@ -104,6 +107,8 @@ class AttentionGRUCell(nn.Layer):
         alpha = paddle.transpose(alpha, [0, 2, 1])
         context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
         concat_context = paddle.concat([context, char_onehots], 1)
+        # print("concat_context:", concat_context.shape)
+        # print("prev_hidden:", prev_hidden.shape)

         cur_hidden = self.rnn(concat_context, prev_hidden)
ppocr/modeling/transforms/__init__.py

@@ -17,8 +17,9 @@ __all__ = ['build_transform']
 def build_transform(config):
     from .tps import TPS
+    from .tps import STN_ON

-    support_dict = ['TPS']
+    support_dict = ['TPS', 'STN_ON']

     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
ppocr/modeling/transforms/stn.py (new file, mode 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np


def conv3x3_block(in_channels, out_channels, stride=1):
    n = 3 * 3 * out_channels
    w = math.sqrt(2. / n)
    conv_layer = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        weight_attr=nn.initializer.Normal(
            mean=0.0, std=w),
        bias_attr=nn.initializer.Constant(0))
    block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU())
    return block


class STN(nn.Layer):
    def __init__(self, in_channels, num_ctrlpoints, activation='none'):
        super(STN, self).__init__()
        self.in_channels = in_channels
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation
        self.stn_convnet = nn.Sequential(
            conv3x3_block(in_channels, 32),  #32x64
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(32, 64),  #16x32
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(64, 128),  # 8*16
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(128, 256),  # 4*8
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(256, 256),  # 2*4,
            nn.MaxPool2D(
                kernel_size=2, stride=2),
            conv3x3_block(256, 256))  # 1*2
        self.stn_fc1 = nn.Sequential(
            nn.Linear(
                2 * 256,
                512,
                weight_attr=nn.initializer.Normal(0, 0.001),
                bias_attr=nn.initializer.Constant(0)),
            nn.BatchNorm1D(512),
            nn.ReLU())
        fc2_bias = self.init_stn()
        self.stn_fc2 = nn.Linear(
            512,
            num_ctrlpoints * 2,
            weight_attr=nn.initializer.Constant(0.0),
            bias_attr=nn.initializer.Assign(fc2_bias))

    def init_stn(self):
        margin = 0.01
        sampling_num_per_side = int(self.num_ctrlpoints / 2)
        ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
        ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
        ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        ctrl_points = np.concatenate(
            [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
        if self.activation == 'none':
            pass
        elif self.activation == 'sigmoid':
            ctrl_points = -np.log(1. / ctrl_points - 1.)
        ctrl_points = paddle.to_tensor(ctrl_points)
        fc2_bias = paddle.reshape(
            ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]])
        return fc2_bias

    def forward(self, x):
        x = self.stn_convnet(x)
        batch_size, _, h, w = x.shape
        x = paddle.reshape(x, shape=(batch_size, -1))
        img_feat = self.stn_fc1(x)
        x = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            x = F.sigmoid(x)
        x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
        return img_feat, x


if __name__ == "__main__":
    in_planes = 3
    num_ctrlpoints = 20
    np.random.seed(100)
    activation = 'none'  # 'sigmoid'
    stn_head = STN(in_planes, num_ctrlpoints, activation)
    data = np.random.randn(10, 3, 32, 64).astype("float32")
    print("data:", np.sum(data))
    input = paddle.to_tensor(data)
    #input = paddle.randn([10, 3, 32, 64])
    control_points = stn_head(input)
ppocr/modeling/transforms/tps.py

@@ -22,6 +22,9 @@ from paddle import nn, ParamAttr
 from paddle.nn import functional as F
 import numpy as np

+from .tps_spatial_transformer import TPSSpatialTransformer
+from .stn import STN
+

 class ConvBNLayer(nn.Layer):
     def __init__(self,

@@ -231,7 +234,8 @@ class GridGenerator(nn.Layer):
         """ Return inv_delta_C which is needed to calculate T """
         F = self.F
         hat_eye = paddle.eye(F, dtype='float64')  # F x F
         hat_C = paddle.norm(
             C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
         hat_C = (hat_C**2) * paddle.log(hat_C)
         delta_C = paddle.concat(  # F+3 x F+3
             [

@@ -301,3 +305,26 @@ class TPS(nn.Layer):
             [-1, image.shape[2], image.shape[3], 2])
         batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
         return batch_I_r
+
+
+class STN_ON(nn.Layer):
+    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
+                 num_control_points, tps_margins, stn_activation):
+        super(STN_ON, self).__init__()
+        self.tps = TPSSpatialTransformer(
+            output_image_size=tuple(tps_outputsize),
+            num_control_points=num_control_points,
+            margins=tuple(tps_margins))
+        self.stn_head = STN(in_channels=in_channels,
+                            num_ctrlpoints=num_control_points,
+                            activation=stn_activation)
+        self.tps_inputsize = tps_inputsize
+        self.out_channels = in_channels
+
+    def forward(self, image):
+        stn_input = paddle.nn.functional.interpolate(
+            image, self.tps_inputsize, mode="bilinear", align_corners=True)
+        stn_img_feat, ctrl_points = self.stn_head(stn_input)
+        x, _ = self.tps(image, ctrl_points)
+        # print(x.shape)
+        return x
ppocr/modeling/transforms/tps_spatial_transformer.py (new file, mode 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
import itertools


def grid_sample(input, grid, canvas=None):
    input.stop_gradient = False
    output = F.grid_sample(input, grid)
    if canvas is None:
        return output
    else:
        input_mask = paddle.ones(shape=input.shape)
        output_mask = F.grid_sample(input_mask, grid)
        padded_output = output * output_mask + canvas * (1 - output_mask)
        return padded_output


# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
    N = input_points.shape[0]
    M = control_points.shape[0]
    pairwise_diff = paddle.reshape(
        input_points, shape=[N, 1, 2]) - paddle.reshape(
            control_points, shape=[1, M, 2])
    # original implementation, very slow
    # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
    pairwise_diff_square = pairwise_diff * pairwise_diff
    pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
    repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist)
    # fix numerical error for 0 * log(0), substitute all nan with 0
    mask = repr_matrix != repr_matrix
    repr_matrix[mask] = 0
    return repr_matrix


# output_ctrl_pts are specified, according to our task.
def build_output_control_points(num_control_points, margins):
    margin_x, margin_y = margins
    num_ctrl_pts_per_side = num_control_points // 2
    ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
    ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
    ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
    ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
    ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
    # ctrl_pts_top = ctrl_pts_top[1:-1,:]
    # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
    output_ctrl_pts_arr = np.concatenate(
        [ctrl_pts_top, ctrl_pts_bottom], axis=0)
    output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr)
    return output_ctrl_pts


class TPSSpatialTransformer(nn.Layer):
    def __init__(self,
                 output_image_size=None,
                 num_control_points=None,
                 margins=None):
        super(TPSSpatialTransformer, self).__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins

        self.target_height, self.target_width = output_image_size
        target_control_points = build_output_control_points(num_control_points,
                                                            margins)
        N = num_control_points
        # N = N - 4

        # create padded kernel matrix
        forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
        target_control_partial_repr = compute_partial_repr(
            target_control_points, target_control_points)
        target_control_partial_repr = paddle.cast(target_control_partial_repr,
                                                  forward_kernel.dtype)
        forward_kernel[:N, :N] = target_control_partial_repr
        forward_kernel[:N, -3] = 1
        forward_kernel[-3, :N] = 1
        target_control_points = paddle.cast(target_control_points,
                                            forward_kernel.dtype)
        forward_kernel[:N, -2:] = target_control_points
        forward_kernel[-2:, :N] = paddle.transpose(
            target_control_points, perm=[1, 0])
        # compute inverse matrix
        inverse_kernel = paddle.inverse(forward_kernel)

        # create target cordinate matrix
        HW = self.target_height * self.target_width
        target_coordinate = list(
            itertools.product(
                range(self.target_height), range(self.target_width)))
        target_coordinate = paddle.to_tensor(target_coordinate)  # HW x 2
        Y, X = paddle.split(
            target_coordinate, target_coordinate.shape[1], axis=1)
        #Y, X = target_coordinate.split(1, dim = 1)
        Y = Y / (self.target_height - 1)
        X = X / (self.target_width - 1)
        target_coordinate = paddle.concat(
            [X, Y], axis=1)  # convert from (y, x) to (x, y)
        target_coordinate_partial_repr = compute_partial_repr(
            target_coordinate, target_control_points)
        target_coordinate_repr = paddle.concat(
            [
                target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
                target_coordinate
            ],
            axis=1)

        # register precomputed matrices
        self.inverse_kernel = inverse_kernel
        self.padding_matrix = paddle.zeros(shape=[3, 2])
        self.target_coordinate_repr = target_coordinate_repr
        self.target_control_points = target_control_points

    def forward(self, input, source_control_points):
        assert source_control_points.ndimension() == 3
        assert source_control_points.shape[1] == self.num_control_points
        assert source_control_points.shape[2] == 2
        batch_size = source_control_points.shape[0]

        self.padding_matrix = paddle.expand(
            self.padding_matrix, shape=[batch_size, 3, 2])
        Y = paddle.concat([source_control_points, self.padding_matrix], 1)
        mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
        source_coordinate = paddle.matmul(self.target_coordinate_repr,
                                          mapping_matrix)

        grid = paddle.reshape(
            source_coordinate,
            shape=[-1, self.target_height, self.target_width, 2])
        grid = paddle.clip(grid, 0,
                           1)  # the source_control_points may be out of [0, 1].
        # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
        # grid = 2.0 * grid - 1.0
        output_maps = grid_sample(input, grid, canvas=None)
        return output_maps, source_coordinate


if __name__ == "__main__":
    from stn import STN
    in_planes = 3
    num_ctrlpoints = 20
    np.random.seed(100)
    activation = 'none'  # 'sigmoid'
    stn_head = STN(in_planes, num_ctrlpoints, activation)
    data = np.random.randn(10, 3, 32, 64).astype("float32")
    input = paddle.to_tensor(data)
    #input = paddle.randn([10, 3, 32, 64])
    control_points = stn_head(input)
    #print("control points:", control_points)
    #input = paddle.randn(shape=[10,3,32,100])
    tps = TPSSpatialTransformer(
        output_image_size=[32, 320],
        num_control_points=20,
        margins=[0.05, 0.05])
    out = tps(input, control_points[1])
    print("out 0 :", out[0].shape)
    print("out 1:", out[1].shape)
ppocr/modeling/transforms/tps_torch.py (new file, mode 100644)

from __future__ import absolute_import

import numpy as np
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F


def grid_sample(input, grid, canvas=None):
    output = F.grid_sample(input, grid)
    if canvas is None:
        return output
    else:
        input_mask = input.data.new(input.size()).fill_(1)
        output_mask = F.grid_sample(input_mask, grid)
        padded_output = output * output_mask + canvas * (1 - output_mask)
        return padded_output


# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
    N = input_points.size(0)
    M = control_points.size(0)
    pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
    # original implementation, very slow
    # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
    pairwise_diff_square = pairwise_diff * pairwise_diff
    pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
    repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
    # fix numerical error for 0 * log(0), substitute all nan with 0
    mask = repr_matrix != repr_matrix
    repr_matrix.masked_fill_(mask, 0)
    return repr_matrix


# output_ctrl_pts are specified, according to our task.
def build_output_control_points(num_control_points, margins):
    margin_x, margin_y = margins
    num_ctrl_pts_per_side = num_control_points // 2
    ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
    ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
    ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
    ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
    ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
    # ctrl_pts_top = ctrl_pts_top[1:-1,:]
    # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
    output_ctrl_pts_arr = np.concatenate(
        [ctrl_pts_top, ctrl_pts_bottom], axis=0)
    output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
    return output_ctrl_pts


# demo: ~/test/models/test_tps_transformation.py
class TPSSpatialTransformer(nn.Module):
    def __init__(self,
                 output_image_size=None,
                 num_control_points=None,
                 margins=None):
        super(TPSSpatialTransformer, self).__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins

        self.target_height, self.target_width = output_image_size
        target_control_points = build_output_control_points(num_control_points,
                                                            margins)
        N = num_control_points
        # N = N - 4

        # create padded kernel matrix
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = compute_partial_repr(
            target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel)

        # create target cordinate matrix
        HW = self.target_height * self.target_width
        target_coordinate = list(
            itertools.product(
                range(self.target_height), range(self.target_width)))
        target_coordinate = torch.Tensor(target_coordinate)  # HW x 2
        Y, X = target_coordinate.split(1, dim=1)
        Y = Y / (self.target_height - 1)
        X = X / (self.target_width - 1)
        target_coordinate = torch.cat(
            [X, Y], dim=1)  # convert from (y, x) to (x, y)
        target_coordinate_partial_repr = compute_partial_repr(
            target_coordinate, target_control_points)
        target_coordinate_repr = torch.cat(
            [
                target_coordinate_partial_repr, torch.ones(HW, 1),
                target_coordinate
            ],
            dim=1)

        # register precomputed matrices
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)
        self.register_buffer('target_control_points', target_control_points)

    def forward(self, input, source_control_points):
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_control_points
        assert source_control_points.size(2) == 2
        batch_size = source_control_points.size(0)

        Y = torch.cat([
            source_control_points, self.padding_matrix.expand(batch_size, 3, 2)
        ], 1)
        mapping_matrix = torch.matmul(self.inverse_kernel, Y)
        source_coordinate = torch.matmul(self.target_coordinate_repr,
                                         mapping_matrix)

        grid = source_coordinate.view(-1, self.target_height,
                                      self.target_width, 2)
        grid = torch.clamp(grid, 0,
                           1)  # the source_control_points may be out of [0, 1].
        # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
        grid = 2.0 * grid - 1.0
        output_maps = grid_sample(input, grid, canvas=None)
        return output_maps, source_coordinate


if __name__ == "__main__":
    from stn_torch import STNHead
    in_planes = 3
    num_ctrlpoints = 20
    torch.manual_seed(10)
    activation = 'none'  # 'sigmoid'
    stn_head = STNHead(in_planes, num_ctrlpoints, activation)
    np.random.seed(100)
    data = np.random.randn(10, 3, 32, 64).astype("float32")
    input = torch.tensor(data)
    control_points = stn_head(input)
    tps = TPSSpatialTransformer(
        output_image_size=[32, 320],
        num_control_points=20,
        margins=[0.05, 0.05])
    out = tps(input, control_points[1])
    print("out 0 :", out[0].shape)
    print("out 1:", out[1].shape)
ppocr/postprocess/rec_postprocess.py

@@ -170,8 +170,10 @@ class AttnLabelDecode(BaseRecLabelDecode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
+        self.unkonwn = "UNKNOWN"
         dict_character = dict_character
-        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        dict_character = [self.beg_str] + dict_character + [self.end_str] + [self.unkonwn]
         return dict_character

     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):

@@ -212,6 +214,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
         label = self.decode(label, is_remove_duplicate=False)
         return text, label
         """
+        preds = preds["rec_pred"]
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()

@@ -324,10 +327,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
 class TableLabelDecode(object):
     """  """

     def __init__(self, character_dict_path, **kwargs):
         list_character, list_elem = self.load_char_elem_dict(character_dict_path)
         list_character = self.add_special_char(list_character)
         list_elem = self.add_special_char(list_elem)
         self.dict_character = {}

@@ -366,14 +368,14 @@ class TableLabelDecode(object):
     def __call__(self, preds):
         structure_probs = preds['structure_probs']
         loc_preds = preds['loc_preds']
         if isinstance(structure_probs, paddle.Tensor):
             structure_probs = structure_probs.numpy()
         if isinstance(loc_preds, paddle.Tensor):
             loc_preds = loc_preds.numpy()
         structure_idx = structure_probs.argmax(axis=2)
         structure_probs = structure_probs.max(axis=2)
         structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
             structure_idx, structure_probs, 'elem')
         res_html_code_list = []
         res_loc_list = []
         batch_num = len(structure_str)

@@ -388,8 +390,13 @@ class TableLabelDecode(object):
             res_loc = np.array(res_loc)
             res_html_code_list.append(res_html_code)
             res_loc_list.append(res_loc)
         return {
             'res_html_code': res_html_code_list,
             'res_loc': res_loc_list,
             'res_score_list': result_score_list,
             'res_elem_idx_list': result_elem_idx_list,
             'structure_str_list': structure_str
         }

     def decode(self, text_index, structure_probs, char_or_elem):
         """convert text-label into text-index.
ppocr/utils/save_load.py

@@ -105,13 +105,16 @@ def load_dygraph_params(config, model, logger, optimizer):
         params = paddle.load(pm)
         state_dict = model.state_dict()
         new_state_dict = {}
-        for k1, k2 in zip(state_dict.keys(), params.keys()):
-            if list(state_dict[k1].shape) == list(params[k2].shape):
-                new_state_dict[k1] = params[k2]
-            else:
-                logger.info(
-                    f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
-                )
+        # for k1, k2 in zip(state_dict.keys(), params.keys()):
+        for k1 in state_dict.keys():
+            if k1 not in params:
+                continue
+            if list(state_dict[k1].shape) == list(params[k1].shape):
+                new_state_dict[k1] = params[k1]
+            else:
+                logger.info(
+                    f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !"
+                )
         model.set_state_dict(new_state_dict)
         logger.info(f"loaded pretrained_model successful from {pm}")
         return {}
tools/program.py

@@ -187,6 +187,7 @@ def train(config,
     use_srn = config['Architecture']['algorithm'] == "SRN"
     model_type = config['Architecture']['model_type']
+    algorithm = config['Architecture']['algorithm']

     if 'start_epoch' in best_model_dict:
         start_epoch = best_model_dict['start_epoch']

@@ -210,10 +211,14 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-            if use_srn or model_type == 'table':
-                preds = model(images, data=batch[1:])
-            else:
-                preds = model(images)
+            # if use_srn or model_type == 'table' or algorithm == "ASTER":
+            #     preds = model(images, data=batch[1:])
+            # else:
+            #     preds = model(images)
+            preds = model(images, data=batch[1:])
+            state_dict = model.state_dict()
+            # for key in state_dict:
+            #     print(key)
             loss = loss_class(preds, batch)
             avg_loss = loss['loss']
             avg_loss.backward()

@@ -395,7 +400,7 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
-        'CLS', 'PGNet', 'Distillation', 'TableAttn'
+        'CLS', 'PGNet', 'Distillation', 'TableAttn', 'ASTER'
     ]

     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
tools/train.py

@@ -72,6 +72,8 @@ def main(config, device, logger, vdl_writer):
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
         char_num = len(getattr(post_process_class, 'character'))
+        character = getattr(post_process_class, 'character')
+        print("getattr character:", character)
         if config['Architecture']["algorithm"] in ["Distillation", ]:  # distillation model
             for key in config['Architecture']["Models"]: