Commit ffa94415
Authored on Aug 24, 2021 by andyjpaddle

add_rec_sar, test=dygraph

Parent: d49699fb
5 changed files with 768 additions and 0 deletions (+768 -0)
configs/rec/rec_r31_sar.yml                  +99  -0
ppocr/losses/rec_sar_loss.py                 +25  -0
ppocr/modeling/backbones/rec_resnet_31.py    +176 -0
ppocr/modeling/heads/rec_sar_head.py         +378 -0
ppocr/utils/dict90.txt                       +90  -0
configs/rec/rec_r31_sar.yml  (new file, mode 100644)
Global:
  use_gpu: true
  epoch_num: 5
  log_smooth_window: 20
  print_batch_step: 20
  save_model_dir: /paddle/backup/sar_rec/sar_train_v3
  save_epoch_step: 1
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  cal_metric_during_train: True
  pretrained_model: #/paddle/backup/sar_rec/sar_train_v2/best_accuracy
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: demo_text_recog.jpg
  # for data or label process
  character_dict_path: ppocr/utils/dict90.txt
  character_type: ch
  max_text_length: 30
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/predicts_sar.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs: [3, 4]
    values: [0.001, 0.0001, 0.00001]
  regularizer:
    name: 'L2'
    factor: 0

Architecture:
  model_type: rec
  algorithm: SAR
  Transform:
  Backbone:
    name: ResNet31
  Head:
    name: SARHead

Loss:
  name: SARLoss

PostProcess:
  name: SARLabelDecode

Metric:
  name: RecMetric

Train:
  dataset:
    name: LMDBDataSet #SimpleDataSet
    # delimiter: ' '
    # label_file_list: ['/paddle/data/concat_data/icdar_2013_train20.txt', '/paddle/data/concat_data/icdar_2015_train20.txt', '/paddle/data/concat_data/coco_text_train20.txt', '/paddle/data/concat_data/IIIt5k_train20.txt', '/paddle/data/concat_data/SynthAdd_train.txt', '/paddle/data/concat_data/SynthText_train.txt', '/paddle/data/concat_data/Syn90k_train.txt']
    data_dir: /paddle/data/ocr_data/training/ #/paddle/data/concat_data/
    # ratio_list: 1.0
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - SARLabelEncode: # Class handling label
      - SARRecResizeImg:
          image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
          width_downsample_ratio: 0.25
      - KeepKeys:
          keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 64 # 32
    drop_last: True
    num_workers: 8
    use_shared_memory: False

Eval:
  dataset:
    name: LMDBDataSet
    data_dir: /paddle/data/ocr_data/evaluation/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - SARLabelEncode: # Class handling label
      - SARRecResizeImg:
          image_shape: [3, 48, 48, 160]
          width_downsample_ratio: 0.25
      - KeepKeys:
          keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 64
    num_workers: 4
    use_shared_memory: False
\ No newline at end of file
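The Optimizer section above pairs Adam with a Piecewise learning-rate schedule that drops the rate at epochs 3 and 4. Below is a minimal sketch of the same schedule built directly against Paddle's optimizer API; `steps_per_epoch` and the stand-in model are assumptions for illustration only, since PaddleOCR derives the real step boundaries from the training dataset.

# Sketch only (not part of this commit): the Piecewise schedule from the config,
# expressed with paddle.optimizer.lr.PiecewiseDecay. `steps_per_epoch` and the
# stand-in model are hypothetical.
import paddle

steps_per_epoch = 1000  # assumption for illustration
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[3 * steps_per_epoch, 4 * steps_per_epoch],  # decay_epochs: [3, 4]
    values=[0.001, 0.0001, 0.00001])                        # values from the config

model = paddle.nn.Linear(10, 10)  # stand-in for the recognition model
optimizer = paddle.optimizer.Adam(
    learning_rate=scheduler,
    beta1=0.9,
    beta2=0.999,
    parameters=model.parameters(),
    weight_decay=paddle.regularizer.L2Decay(0.0))  # regularizer: L2, factor: 0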
ppocr/losses/rec_sar_loss.py  (new file, mode 100644)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn


class SARLoss(nn.Layer):
    def __init__(self, **kwargs):
        super(SARLoss, self).__init__()
        self.loss_func = paddle.nn.loss.CrossEntropyLoss(
            reduction="mean", ignore_index=92)

    def forward(self, predicts, batch):
        predict = predicts[:, :
                           -1, :]  # ignore last index of outputs to be in same seq_len with targets
        label = batch[1].astype(
            "int64")[:, 1:]  # ignore first index of target in loss calculation
        batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
            1], predict.shape[2]
        assert len(label.shape) == len(list(predict.shape)) - 1, \
            "The target's shape and inputs's shape is [N, d] and [N, num_steps]"

        inputs = paddle.reshape(predict, [-1, num_classes])
        targets = paddle.reshape(label, [-1])
        loss = self.loss_func(inputs, targets)
        return {'loss': loss}
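SARLoss shifts the prediction and label sequences against each other: the last decoder step is dropped from `predicts` and the leading start token is dropped from the label, so both align to `seq_len` steps, and index 92 (the padding token) is excluded from the cross entropy. A minimal shape-level sketch, with dummy tensors and made-up sequence contents, might look like this:

# Sketch only (not part of this commit): the tensor layout SARLoss expects.
import numpy as np
import paddle
from ppocr.losses.rec_sar_loss import SARLoss

bsz, seq_len, num_classes = 2, 30, 92                      # illustrative sizes
predicts = paddle.rand([bsz, seq_len + 1, num_classes])    # one extra decode step

labels = np.full([bsz, seq_len + 1], 92, dtype='int64')    # pad with padding_idx
labels[:, 0] = 91                                          # leading start token
labels[:, 1:6] = np.random.randint(0, 90, size=[bsz, 5])   # a few character indices
batch = [None, paddle.to_tensor(labels)]                   # batch[0] (image) is unused here

loss = SARLoss()(predicts, batch)
print(loss['loss'])  # scalar mean cross entropy over non-padding positions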
ppocr/modeling/backbones/rec_resnet_31.py  (new file, mode 100644)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np

__all__ = ["ResNet31"]


def conv3x3(in_channel, out_channel, stride=1):
    return nn.Conv2D(
        in_channel,
        out_channel,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias_attr=False)


class BasicBlock(nn.Layer):
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, downsample=False):
        super().__init__()
        self.conv1 = conv3x3(in_channels, channels, stride)
        self.bn1 = nn.BatchNorm2D(channels)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(channels, channels)
        self.bn2 = nn.BatchNorm2D(channels)
        self.downsample = downsample
        if downsample:
            self.downsample = nn.Sequential(
                nn.Conv2D(
                    in_channels,
                    channels * self.expansion,
                    1,
                    stride,
                    bias_attr=False),
                nn.BatchNorm2D(channels * self.expansion), )
        else:
            self.downsample = nn.Sequential()
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet31(nn.Layer):
    '''
    Args:
        in_channels (int): Number of channels of input image tensor.
        layers (list[int]): List of BasicBlock number for each stage.
        channels (list[int]): List of out_channels of Conv2d layer.
        out_indices (None | Sequence[int]): Indices of output stages.
        last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
    '''

    def __init__(self,
                 in_channels=3,
                 layers=[1, 2, 5, 3],
                 channels=[64, 128, 256, 256, 512, 512, 512],
                 out_indices=None,
                 last_stage_pool=False):
        super(ResNet31, self).__init__()
        assert isinstance(in_channels, int)
        assert isinstance(last_stage_pool, bool)

        self.out_indices = out_indices
        self.last_stage_pool = last_stage_pool

        # conv 1 (Conv Conv)
        self.conv1_1 = nn.Conv2D(
            in_channels, channels[0], kernel_size=3, stride=1, padding=1)
        self.bn1_1 = nn.BatchNorm2D(channels[0])
        self.relu1_1 = nn.ReLU()

        self.conv1_2 = nn.Conv2D(
            channels[0], channels[1], kernel_size=3, stride=1, padding=1)
        self.bn1_2 = nn.BatchNorm2D(channels[1])
        self.relu1_2 = nn.ReLU()

        # conv 2 (Max-pooling, Residual block, Conv)
        self.pool2 = nn.MaxPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self.block2 = self._make_layer(channels[1], channels[2], layers[0])
        self.conv2 = nn.Conv2D(
            channels[2], channels[2], kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2D(channels[2])
        self.relu2 = nn.ReLU()

        # conv 3 (Max-pooling, Residual block, Conv)
        self.pool3 = nn.MaxPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self.block3 = self._make_layer(channels[2], channels[3], layers[1])
        self.conv3 = nn.Conv2D(
            channels[3], channels[3], kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2D(channels[3])
        self.relu3 = nn.ReLU()

        # conv 4 (Max-pooling, Residual block, Conv)
        self.pool4 = nn.MaxPool2D(
            kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
        self.block4 = self._make_layer(channels[3], channels[4], layers[2])
        self.conv4 = nn.Conv2D(
            channels[4], channels[4], kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2D(channels[4])
        self.relu4 = nn.ReLU()

        # conv 5 ((Max-pooling), Residual block, Conv)
        self.pool5 = None
        if self.last_stage_pool:
            self.pool5 = nn.MaxPool2D(
                kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self.block5 = self._make_layer(channels[4], channels[5], layers[3])
        self.conv5 = nn.Conv2D(
            channels[5], channels[5], kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2D(channels[5])
        self.relu5 = nn.ReLU()

        self.out_channels = channels[-1]

    def _make_layer(self, input_channels, output_channels, blocks):
        layers = []
        for _ in range(blocks):
            downsample = None
            if input_channels != output_channels:
                downsample = nn.Sequential(
                    nn.Conv2D(
                        input_channels,
                        output_channels,
                        kernel_size=1,
                        stride=1,
                        bias_attr=False),
                    nn.BatchNorm2D(output_channels), )
            layers.append(
                BasicBlock(
                    input_channels, output_channels, downsample=downsample))
            input_channels = output_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1_1(x)
        x = self.bn1_1(x)
        x = self.relu1_1(x)

        x = self.conv1_2(x)
        x = self.bn1_2(x)
        x = self.relu1_2(x)

        outs = []
        for i in range(4):
            layer_index = i + 2
            pool_layer = getattr(self, f'pool{layer_index}')
            block_layer = getattr(self, f'block{layer_index}')
            conv_layer = getattr(self, f'conv{layer_index}')
            bn_layer = getattr(self, f'bn{layer_index}')
            relu_layer = getattr(self, f'relu{layer_index}')

            if pool_layer is not None:
                x = pool_layer(x)
            x = block_layer(x)
            x = conv_layer(x)
            x = bn_layer(x)
            x = relu_layer(x)

            outs.append(x)

        if self.out_indices is not None:
            return tuple([outs[i] for i in self.out_indices])

        return x
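ResNet31 pools the height three times (2x2, 2x2, then 2x1) but the width only twice, so a 48x160 crop comes out as roughly a 6x40 feature map with 512 channels, matching the width_downsample_ratio of 0.25 used in rec_r31_sar.yml. A quick sanity-check sketch, not part of this commit:

# Sketch only: feature-map geometry for the 48x160 inputs configured above.
import paddle
from ppocr.modeling.backbones.rec_resnet_31 import ResNet31

backbone = ResNet31(in_channels=3)
x = paddle.rand([1, 3, 48, 160])
feat = backbone(x)
print(feat.shape)  # expected: [1, 512, 6, 40]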
ppocr/modeling/heads/rec_sar_head.py  (new file, mode 100644)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """
    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # LSTM Encoder
        if enc_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        h_feat = feat.shape[2]  # bsz c h w
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            valid_hf = []
            T = holistic_feat.shape[1]
            for i, valid_ratio in enumerate(valid_ratios):
                valid_step = min(T, math.ceil(T * valid_ratio)) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat


class BaseDecoder(nn.Layer):
    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, img_metas)
        return self.forward_test(feat, out_enc, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """
    Args:
        num_classes (int): Output class number.
        channels (list[int]): Network layer channels.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_seq_len (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        start_idx (int): Index of start token.
        padding_idx (int): Index of padding token.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(
            self,
            num_classes=93,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            start_idx=91,
            padding_idx=92,  # 92
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = start_idx
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + d_enc
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))

        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.shape
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                attn_weight[i, :, :, valid_width:, :] = float('-inf')

        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(
            paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4),
            keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(
                paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        img_metas: [label, valid_ratio]
        '''
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        label = label.cuda()
        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs


class SARHead(nn.Layer):
    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 max_text_length=30,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)

        # decoder module
        self.decoder = ParallelSARDecoder(
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        img_metas: [label, valid_ratio]
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = targets[0]  # label
            label = paddle.to_tensor(label, dtype='int64')
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        if not self.training:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
        # (bsz, seq_len, num_classes)
        return final_out
\ No newline at end of file
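SARHead encodes the backbone feature map into a holistic vector and then decodes it step by step with 2D attention; `targets` follows the [label, valid_ratio] layout produced by SARLabelEncode and KeepKeys, and at inference only the valid_ratio entry is consulted, for attention masking. A rough end-to-end sketch in eval mode follows; it is not part of this commit, and the dummy label plus the numpy valid_ratio are stand-ins for what the PaddleOCR dataloader would supply.

# Sketch only: backbone + SAR head wired together for a single 48x160 crop.
import numpy as np
import paddle
from ppocr.modeling.backbones.rec_resnet_31 import ResNet31
from ppocr.modeling.heads.rec_sar_head import SARHead

backbone = ResNet31(in_channels=3)
head = SARHead(max_text_length=30)
backbone.eval()
head.eval()

img = paddle.rand([1, 3, 48, 160])
label = paddle.zeros([1, 31], dtype='int64')        # placeholder; unused at inference
valid_ratio = np.array([0.9], dtype='float32')      # fraction of the width holding real content
targets = [label, valid_ratio]                      # [label, valid_ratio]

with paddle.no_grad():
    feat = backbone(img)
    probs = head(feat, targets)                     # per-step class probabilities
print(probs.shape)  # expected: [1, 30, 92]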
ppocr/utils/dict90.txt  (new file, mode 100644)
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
_
`
~
\ No newline at end of file
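dict90.txt lists the 90 recognizable characters, one per line, giving them indices 0 through 89. The decoder defaults (num_classes=93, start_idx=91, padding_idx=92) and the ignore_index=92 in SARLoss imply three extra special tokens on top of the dictionary; the unknown=90, start=91, padding=92 layout below is inferred from those defaults rather than stated in the commit. A small sketch of loading the dictionary:

# Sketch only: build the character-to-index mapping from dict90.txt.
with open('ppocr/utils/dict90.txt', 'r') as f:
    chars = [line.rstrip('\n') for line in f]

char_to_idx = {c: i for i, c in enumerate(chars)}  # '0' -> 0 ... '~' -> 89
unknown_idx, start_idx, padding_idx = 90, 91, 92   # special tokens (assumed layout)
print(len(chars))  # 90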