Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
c9e1077d
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c9e1077d
编写于
8月 30, 2021
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish code
上级
59cc4efd
变更
23
展开全部
隐藏空白更改
内联
并排
Showing
23 changed file
with
461 addition
and
1020 deletion
+461
-1020
configs/rec/rec_resnet_stn_bilstm_att.yml
configs/rec/rec_resnet_stn_bilstm_att.yml
+35
-30
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+1
-1
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+34
-4
ppocr/data/imaug/operators.py
ppocr/data/imaug/operators.py
+14
-2
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+23
-0
ppocr/data/simple_dataset.py
ppocr/data/simple_dataset.py
+0
-1
ppocr/losses/rec_aster_loss.py
ppocr/losses/rec_aster_loss.py
+37
-18
ppocr/losses/rec_att_loss.py
ppocr/losses/rec_att_loss.py
+0
-2
ppocr/metrics/rec_metric.py
ppocr/metrics/rec_metric.py
+11
-1
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+4
-3
ppocr/modeling/backbones/levit.py
ppocr/modeling/backbones/levit.py
+0
-707
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+0
-1
ppocr/modeling/heads/rec_aster_head.py
ppocr/modeling/heads/rec_aster_head.py
+170
-38
ppocr/modeling/heads/rec_att_head.py
ppocr/modeling/heads/rec_att_head.py
+0
-5
ppocr/modeling/transforms/stn.py
ppocr/modeling/transforms/stn.py
+0
-13
ppocr/modeling/transforms/tps.py
ppocr/modeling/transforms/tps.py
+1
-0
ppocr/modeling/transforms/tps_spatial_transformer.py
ppocr/modeling/transforms/tps_spatial_transformer.py
+3
-24
ppocr/modeling/transforms/tps_torch.py
ppocr/modeling/transforms/tps_torch.py
+0
-149
ppocr/optimizer/optimizer.py
ppocr/optimizer/optimizer.py
+31
-0
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+2
-2
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+83
-4
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+7
-10
tools/program.py
tools/program.py
+5
-5
未找到文件。
configs/rec/rec_resnet_stn_bilstm_att.yml
浏览文件 @
c9e1077d
Global
:
use_gpu
:
Fals
e
use_gpu
:
Tru
e
epoch_num
:
400
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec/
b3_rare_r34_none_gru/
save_model_dir
:
./output/rec/
seed
save_epoch_step
:
3
# evaluation is run every 2000 iterations, starting from iteration 0
eval_batch_step
:
[
0
,
2000
]
...
...
@@ -12,28 +12,32 @@ Global:
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
doc/imgs_words
/ch/word_1.jp
g
infer_img
:
doc/imgs_words
_en/word_10.pn
g
# for data or label process
character_dict_path
:
character_type
:
EN_symbol
max_text_length
:
25
max_text_length
:
100
infer_mode
:
False
use_space_char
:
False
save_res_path
:
./output/rec/predicts_b3_rare_r34_none_gru.txt
eval_filter
:
True
save_res_path
:
./output/rec/predicts_seed.txt
Optimizer
:
name
:
Ada
m
beta1
:
0.9
beta2
:
0.99
9
name
:
Ada
delta
weight_deacy
:
0.0
momentum
:
0.
9
lr
:
learning_rate
:
0.0005
name
:
Piecewise
decay_epochs
:
[
4
,
5
,
8
]
values
:
[
1.0
,
0.1
,
0.01
]
regularizer
:
name
:
'
L2'
factor
:
0.00000
factor
:
2.0e-05
Architecture
:
model_type
:
rec
model_type
:
seed
algorithm
:
ASTER
Transform
:
name
:
STN_ON
...
...
@@ -54,48 +58,49 @@ Loss:
name
:
AsterLoss
PostProcess
:
name
:
Attn
LabelDecode
name
:
SEED
LabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
is_filter
:
True
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/ic15_data/
label_file_list
:
[
"
./train_data/ic15_data/1.txt"
]
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
-
Fasttext
:
path
:
"
./cc.en.300.bin"
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
Attn
LabelEncode
:
# Class handling label
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
100
]
-
SEED
LabelEncode
:
# Class handling label
-
SEEDResize
:
image_shape
:
[
3
,
64
,
256
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
,
'
fast_label'
]
# dataloader will return list in this order
loader
:
shuffle
:
True
batch_size_per_card
:
2
batch_size_per_card
:
2
56
drop_last
:
True
num_workers
:
8
num_workers
:
6
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/ic15_data/
label_file_list
:
[
"
./train_data/ic15_data/1.txt"
]
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/evaluation/
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
Attn
LabelEncode
:
# Class handling label
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
100
]
-
SEED
LabelEncode
:
# Class handling label
-
SEEDResize
:
image_shape
:
[
3
,
64
,
256
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
drop_last
:
Fals
e
batch_size_per_card
:
2
num_workers
:
8
drop_last
:
Tru
e
batch_size_per_card
:
2
56
num_workers
:
4
ppocr/data/imaug/__init__.py
浏览文件 @
c9e1077d
...
...
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from
.make_shrink_map
import
MakeShrinkMap
from
.random_crop_data
import
EastRandomCropData
,
PSERandomCrop
from
.rec_img_aug
import
RecAug
,
RecResizeImg
,
ClsResizeImg
,
SRNRecResizeImg
from
.rec_img_aug
import
RecAug
,
RecResizeImg
,
ClsResizeImg
,
SRNRecResizeImg
,
SEEDResize
from
.randaugment
import
RandAugment
from
.copy_paste
import
CopyPaste
from
.operators
import
*
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
c9e1077d
...
...
@@ -276,9 +276,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
unknown
=
"UNKNOWN"
dict_character
=
[
self
.
beg_str
]
+
dict_character
+
[
self
.
end_str
]
+
[
self
.
unknown
]
dict_character
=
[
self
.
beg_str
]
+
dict_character
+
[
self
.
end_str
]
return
dict_character
def
__call__
(
self
,
data
):
...
...
@@ -291,7 +289,6 @@ class AttnLabelEncode(BaseRecLabelEncode):
data
[
'length'
]
=
np
.
array
(
len
(
text
))
text
=
[
0
]
+
text
+
[
len
(
self
.
character
)
-
1
]
+
[
0
]
*
(
self
.
max_text_len
-
len
(
text
)
-
2
)
data
[
'label'
]
=
np
.
array
(
text
)
return
data
...
...
@@ -311,6 +308,39 @@ class AttnLabelEncode(BaseRecLabelEncode):
return
idx
class SEEDLabelEncode(BaseRecLabelEncode):
    """Encode a text label into a fixed-length index array for SEED.

    The character dictionary gets an "eos" entry appended at its end, and
    each encoded label is right-padded to ``max_text_len`` entries.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 **kwargs):
        # Extra keyword arguments are accepted but not forwarded.
        super(SEEDLabelEncode, self).__init__(
            max_text_length, character_dict_path, character_type,
            use_space_char)

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with the "eos" marker appended last."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return dict_character + [self.end_str]

    def __call__(self, data):
        """Replace ``data['label']`` with padded indices and set ``data['length']``.

        Returns None when ``self.encode`` returns None or when the encoded
        label has ``max_text_len`` or more entries; otherwise returns the
        same ``data`` dict, updated in place.
        """
        encoded = self.encode(data['label'])
        if encoded is None:
            return None
        label_len = len(encoded)
        if label_len >= self.max_text_len:
            return None
        # The stored length is one more than the label so it includes eos.
        data['length'] = np.array(label_len) + 1
        # Padding uses the last index of self.character; presumably that is
        # the "eos" entry appended above -- confirm against the base class.
        pad_index = len(self.character) - 1
        padding = [pad_index] * (self.max_text_len - label_len)
        data['label'] = np.array(encoded + padding)
        return data
class
SRNLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
...
...
ppocr/data/imaug/operators.py
浏览文件 @
c9e1077d
...
...
@@ -23,6 +23,7 @@ import sys
import
six
import
cv2
import
numpy
as
np
import
fasttext
class
DecodeImage
(
object
):
...
...
@@ -81,7 +82,7 @@ class NormalizeImage(object):
assert
isinstance
(
img
,
np
.
ndarray
),
"invalid input 'img' in NormalizeImage"
data
[
'image'
]
=
(
img
.
astype
(
'float32'
)
*
self
.
scale
-
self
.
mean
)
/
self
.
std
img
.
astype
(
'float32'
)
*
self
.
scale
-
self
.
mean
)
/
self
.
std
return
data
...
...
@@ -101,6 +102,17 @@ class ToCHWImage(object):
return
data
class Fasttext(object):
    """Attach a fastText lookup of the sample label to the sample dict."""

    def __init__(self, path="None", **kwargs):
        # The model is loaded once, when the operator is constructed.
        self.fast_model = fasttext.load_model(path)

    def __call__(self, data):
        """Store ``self.fast_model[data['label']]`` under ``data['fast_label']``.

        The input dict is modified in place and returned.
        """
        data['fast_label'] = self.fast_model[data['label']]
        return data
class
KeepKeys
(
object
):
def
__init__
(
self
,
keep_keys
,
**
kwargs
):
self
.
keep_keys
=
keep_keys
...
...
@@ -183,7 +195,7 @@ class DetResizeForTest(object):
else
:
ratio
=
1.
elif
self
.
limit_type
==
'resize_long'
:
ratio
=
float
(
limit_side_len
)
/
max
(
h
,
w
)
ratio
=
float
(
limit_side_len
)
/
max
(
h
,
w
)
else
:
raise
Exception
(
'not support limit type, image '
)
resize_h
=
int
(
h
*
ratio
)
...
...
ppocr/data/imaug/rec_img_aug.py
浏览文件 @
c9e1077d
...
...
@@ -63,6 +63,18 @@ class RecResizeImg(object):
return
data
class SEEDResize(object):
    """Resize ``data['image']`` to a fixed shape via ``resize_no_padding_img``."""

    def __init__(self, image_shape, infer_mode=False, **kwargs):
        # image_shape is handed unchanged to resize_no_padding_img;
        # infer_mode is stored but not read by __call__.
        self.infer_mode = infer_mode
        self.image_shape = image_shape

    def __call__(self, data):
        """Overwrite ``data['image']`` with the resized image and return ``data``."""
        data['image'] = resize_no_padding_img(data['image'], self.image_shape)
        return data
class
SRNRecResizeImg
(
object
):
def
__init__
(
self
,
image_shape
,
num_heads
,
max_text_length
,
**
kwargs
):
self
.
image_shape
=
image_shape
...
...
@@ -106,6 +118,17 @@ def resize_norm_img(img, image_shape):
return
padding_im
def resize_no_padding_img(img, image_shape):
    """Stretch ``img`` to the target size and normalize it, without padding.

    Args:
        img: image array; it must be 3-dimensional after resizing because the
            axes are reordered with ``transpose((2, 0, 1))``.
        image_shape: three values unpacked as (channels, height, width); only
            height and width are used.

    Returns:
        float32 array with the last axis moved to the front, computed as
        ``(resized / 255 - 0.5) / 0.5``.
    """
    _, target_h, target_w = image_shape
    # cv2.resize takes the size as (width, height).
    stretched = cv2.resize(
        img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    normalized = stretched.astype('float32').transpose((2, 0, 1)) / 255
    normalized -= 0.5
    normalized /= 0.5
    return normalized
def
resize_norm_img_chinese
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
# todo: change to 0 and modified image shape
...
...
ppocr/data/simple_dataset.py
浏览文件 @
c9e1077d
...
...
@@ -22,7 +22,6 @@ from .imaug import transform, create_operators
class
SimpleDataSet
(
Dataset
):
def
__init__
(
self
,
config
,
mode
,
logger
,
seed
=
None
):
print
(
"===== simpledataset ========"
)
super
(
SimpleDataSet
,
self
).
__init__
()
self
.
logger
=
logger
self
.
mode
=
mode
.
lower
()
...
...
ppocr/losses/rec_aster_loss.py
浏览文件 @
c9e1077d
...
...
@@ -18,7 +18,26 @@ from __future__ import print_function
import
paddle
from
paddle
import
nn
import
fasttext
class CosineEmbeddingLoss(nn.Layer):
    """Cosine embedding loss between two batches of vectors.

    Where ``target`` equals 1 the per-pair loss is ``1 - cos(x1, x2)``;
    everywhere else it is ``max(0, cos(x1, x2) - margin)``. The result is
    the mean over all pairs.

    Args:
        margin: value subtracted from the similarity for pairs whose target
            is not 1. Defaults to 0.
    """

    def __init__(self, margin=0.):
        super(CosineEmbeddingLoss, self).__init__()
        self.margin = margin

    def forward(self, x1, x2, target):
        """Return the mean cosine embedding loss of ``x1`` against ``x2``.

        The similarity is reduced over the last axis of ``x1 * x2``.
        """
        # paddle.sum / paddle.mean replace the legacy
        # paddle.fluid.layers.reduce_sum / reduce_mean calls, matching the
        # paddle.sum usage elsewhere in this file.
        # NOTE(review): no epsilon guards the division, so an all-zero
        # vector would produce nan/inf here.
        similarity = paddle.sum(x1 * x2, axis=-1) / (
            paddle.norm(x1, axis=-1) * paddle.norm(x2, axis=-1))
        one_list = paddle.full_like(target, fill_value=1)
        out = paddle.mean(
            paddle.where(
                paddle.equal(target, one_list), 1. - similarity,
                paddle.maximum(
                    paddle.zeros_like(similarity), similarity - self.margin)))
        return out
class
AsterLoss
(
nn
.
Layer
):
...
...
@@ -35,28 +54,28 @@ class AsterLoss(nn.Layer):
self
.
ignore_index
=
ignore_index
self
.
sequence_normalize
=
sequence_normalize
self
.
sample_normalize
=
sample_normalize
self
.
loss_func
=
paddle
.
nn
.
CosineSimilarity
()
self
.
loss_sem
=
CosineEmbeddingLoss
()
self
.
is_cosin_loss
=
True
self
.
loss_func_rec
=
nn
.
CrossEntropyLoss
(
weight
=
None
,
reduction
=
'none'
)
def
forward
(
self
,
predicts
,
batch
):
targets
=
batch
[
1
].
astype
(
"int64"
)
label_lengths
=
batch
[
2
].
astype
(
'int64'
)
#
sem_target = batch[3].astype('float32')
sem_target
=
batch
[
3
].
astype
(
'float32'
)
embedding_vectors
=
predicts
[
'embedding_vectors'
]
rec_pred
=
predicts
[
'rec_pred'
]
# semantic loss
# print(embedding_vectors)
# print(embedding_vectors.shape)
# targets = fasttext[targets]
# sem_loss = 1 - self.loss_func(embedding_vectors, targets)
if
not
self
.
is_cosin_loss
:
sem_loss
=
paddle
.
sum
(
self
.
loss_sem
(
embedding_vectors
,
sem_target
))
else
:
label_target
=
paddle
.
ones
([
embedding_vectors
.
shape
[
0
]])
sem_loss
=
paddle
.
sum
(
self
.
loss_sem
(
embedding_vectors
,
sem_target
,
label_target
))
# rec loss
batch_size
,
num_steps
,
num_classes
=
rec_pred
.
shape
[
0
],
rec_pred
.
shape
[
1
],
rec_pred
.
shape
[
2
]
assert
len
(
targets
.
shape
)
==
len
(
list
(
rec_pred
.
shape
))
-
1
,
\
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
batch_size
,
def_max_length
=
targets
.
shape
[
0
],
targets
.
shape
[
1
]
mask
=
paddle
.
zeros
([
batch_size
,
num_steps
])
mask
=
paddle
.
zeros
([
batch_size
,
def_max_length
])
for
i
in
range
(
batch_size
):
mask
[
i
,
:
label_lengths
[
i
]]
=
1
mask
=
paddle
.
cast
(
mask
,
"float32"
)
...
...
@@ -64,16 +83,16 @@ class AsterLoss(nn.Layer):
assert
max_length
==
rec_pred
.
shape
[
1
]
targets
=
targets
[:,
:
max_length
]
mask
=
mask
[:,
:
max_length
]
rec_pred
=
paddle
.
reshape
(
rec_pred
,
[
-
1
,
rec_pred
.
shape
[
-
1
]])
rec_pred
=
paddle
.
reshape
(
rec_pred
,
[
-
1
,
rec_pred
.
shape
[
2
]])
input
=
nn
.
functional
.
log_softmax
(
rec_pred
,
axis
=
1
)
targets
=
paddle
.
reshape
(
targets
,
[
-
1
,
1
])
mask
=
paddle
.
reshape
(
mask
,
[
-
1
,
1
])
# print("input:", input)
output
=
-
paddle
.
gather
(
input
,
index
=
targets
,
axis
=
1
)
*
mask
output
=
-
paddle
.
index_sample
(
input
,
index
=
targets
)
*
mask
output
=
paddle
.
sum
(
output
)
if
self
.
sequence_normalize
:
output
=
output
/
paddle
.
sum
(
mask
)
if
self
.
sample_normalize
:
output
=
output
/
batch_size
loss
=
output
return
{
'loss'
:
loss
}
# , 'sem_loss':sem_loss}
loss
=
output
+
sem_loss
*
0.1
return
{
'loss'
:
loss
}
ppocr/losses/rec_att_loss.py
浏览文件 @
c9e1077d
...
...
@@ -35,7 +35,5 @@ class AttentionLoss(nn.Layer):
inputs
=
paddle
.
reshape
(
predicts
,
[
-
1
,
predicts
.
shape
[
-
1
]])
targets
=
paddle
.
reshape
(
targets
,
[
-
1
])
print
(
"input:"
,
paddle
.
argmax
(
inputs
,
axis
=
1
))
print
(
"targets:"
,
targets
)
return
{
'loss'
:
paddle
.
sum
(
self
.
loss_func
(
inputs
,
targets
))}
ppocr/metrics/rec_metric.py
浏览文件 @
c9e1077d
...
...
@@ -13,13 +13,20 @@
# limitations under the License.
import
Levenshtein
import
string
class
RecMetric
(
object
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
**
kwargs
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
is_filter
=
False
,
**
kwargs
):
self
.
main_indicator
=
main_indicator
self
.
is_filter
=
is_filter
self
.
reset
()
def _normalize_text(self, text):
    """Drop every character that is not an ASCII letter or digit, then lowercase."""
    allowed = string.digits + string.ascii_letters
    kept = [ch for ch in text if ch in allowed]
    return ''.join(kept).lower()
def
__call__
(
self
,
pred_label
,
*
args
,
**
kwargs
):
preds
,
labels
=
pred_label
correct_num
=
0
...
...
@@ -28,6 +35,9 @@ class RecMetric(object):
for
(
pred
,
pred_conf
),
(
target
,
_
)
in
zip
(
preds
,
labels
):
pred
=
pred
.
replace
(
" "
,
""
)
target
=
target
.
replace
(
" "
,
""
)
if
self
.
is_filter
:
pred
=
self
.
_normalize_text
(
pred
)
target
=
self
.
_normalize_text
(
target
)
norm_edit_dis
+=
Levenshtein
.
distance
(
pred
,
target
)
/
max
(
len
(
pred
),
len
(
target
),
1
)
if
pred
==
target
:
...
...
ppocr/modeling/backbones/__init__.py
浏览文件 @
c9e1077d
...
...
@@ -26,10 +26,8 @@ def build_backbone(config, model_type):
from
.rec_resnet_vd
import
ResNet
from
.rec_resnet_fpn
import
ResNetFPN
from
.rec_mv1_enhance
import
MobileNetV1Enhance
from
.rec_resnet_aster
import
ResNet_ASTER
support_dict
=
[
"MobileNetV1Enhance"
,
"MobileNetV3"
,
"ResNet"
,
"ResNetFPN"
,
"ResNet_ASTER"
"MobileNetV1Enhance"
,
"MobileNetV3"
,
"ResNet"
,
"ResNetFPN"
]
elif
model_type
==
"e2e"
:
from
.e2e_resnet_vd_pg
import
ResNet
...
...
@@ -38,6 +36,9 @@ def build_backbone(config, model_type):
from
.table_resnet_vd
import
ResNet
from
.table_mobilenet_v3
import
MobileNetV3
support_dict
=
[
"ResNet"
,
"MobileNetV3"
]
elif
model_type
==
"seed"
:
from
.rec_resnet_aster
import
ResNet_ASTER
support_dict
=
[
"ResNet_ASTER"
]
else
:
raise
NotImplementedError
...
...
ppocr/modeling/backbones/levit.py
已删除
100644 → 0
浏览文件 @
59cc4efd
此差异已折叠。
点击以展开。
ppocr/modeling/heads/__init__.py
浏览文件 @
c9e1077d
...
...
@@ -42,6 +42,5 @@ def build_head(config):
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
'head only support {}'
.
format
(
support_dict
))
print
(
config
)
module_class
=
eval
(
module_name
)(
**
config
)
return
module_class
ppocr/modeling/heads/rec_aster_head.py
浏览文件 @
c9e1077d
...
...
@@ -43,13 +43,14 @@ class AsterHead(nn.Layer):
self
.
time_step
=
time_step
self
.
embeder
=
Embedding
(
self
.
time_step
,
in_channels
)
self
.
beam_width
=
beam_width
self
.
eos
=
self
.
num_classes
-
1
def
forward
(
self
,
x
,
targets
=
None
,
embed
=
None
):
return_dict
=
{}
embedding_vectors
=
self
.
embeder
(
x
)
rec_targets
,
rec_lengths
=
targets
if
self
.
training
:
rec_targets
,
rec_lengths
,
_
=
targets
rec_pred
=
self
.
decoder
([
x
,
rec_targets
,
rec_lengths
],
embedding_vectors
)
return_dict
[
'rec_pred'
]
=
rec_pred
...
...
@@ -104,14 +105,12 @@ class AttentionRecognitionHead(nn.Layer):
# Decoder
state
=
self
.
decoder
.
get_initial_state
(
embed
)
outputs
=
[]
for
i
in
range
(
max
(
lengths
)):
if
i
==
0
:
y_prev
=
paddle
.
full
(
shape
=
[
batch_size
],
fill_value
=
self
.
num_classes
)
else
:
y_prev
=
targets
[:,
i
-
1
]
output
,
state
=
self
.
decoder
(
x
,
state
,
y_prev
)
outputs
.
append
(
output
)
outputs
=
paddle
.
concat
([
_
.
unsqueeze
(
1
)
for
_
in
outputs
],
1
)
...
...
@@ -142,6 +141,170 @@ class AttentionRecognitionHead(nn.Layer):
# return predicted_ids.squeeze(), predicted_scores.squeeze()
return
predicted_ids
,
predicted_scores
def beam_search(self, x, beam_width, eos, embed):
    """Decode ``x`` with beam search and return the top beam per sample.

    Args:
        x: encoder features; ``x.shape`` is unpacked into three values
            (batch_size, l, d).
        beam_width: number of beams kept per sample.
        eos: symbol index treated as end-of-sequence.
        embed: passed to ``self.decoder.get_initial_state`` to build the
            initial decoder state.

    Returns:
        A pair ``(p, paddle.ones_like(p))``. ``p`` holds the symbols of beam
        0 of every sample after the final re-sort, concatenated over the
        ``self.max_len_labels`` decoding steps. The second element is an
        all-ones tensor of the same shape, not the real beam scores.
    """

    def _inflate(tensor, times, dim):
        # Tile `tensor` `times` times along axis `dim`.
        repeat_dims = [1] * tensor.dim()
        repeat_dims[dim] = times
        output = paddle.tile(tensor, repeat_dims)
        return output

    # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
    batch_size, l, d = x.shape
    # inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC
    # Repeat every sample beam_width times, keeping copies of the same
    # sample adjacent (AABBCC): tile on a new leading axis, swap it back
    # behind the batch axis, then flatten to [batch_size * beam_width, l, d].
    x = paddle.tile(
        paddle.transpose(
            x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
    inflated_encoder_feats = paddle.reshape(
        paddle.transpose(
            x, perm=[1, 0, 2, 3]), [-1, l, d])

    # Initialize the decoder
    state = self.decoder.get_initial_state(embed, tile_times=beam_width)

    # Flat offset of the first beam of each sample: 0, k, 2k, ...
    pos_index = paddle.reshape(
        paddle.arange(batch_size) * beam_width, shape=[-1, 1])

    # Initialize the scores
    # All beams start at -inf except the first beam of each sample, so the
    # first top-k does not pick identical duplicates.
    sequence_scores = paddle.full(
        shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
    index = [i * beam_width for i in range(0, batch_size)]
    sequence_scores[index] = 0.0

    # Initialize the input vector
    # The first decoder input is the value self.num_classes for every beam.
    y_prev = paddle.full(
        shape=[batch_size * beam_width], fill_value=self.num_classes)

    # Store decisions for backtracking
    stored_scores = list()
    stored_predecessors = list()
    stored_emitted_symbols = list()

    for i in range(self.max_len_labels):
        output, state = self.decoder(inflated_encoder_feats, state, y_prev)
        state = paddle.unsqueeze(state, axis=0)
        log_softmax_output = paddle.nn.functional.log_softmax(
            output, axis=1)

        # Add the per-class log-probabilities to each beam's running score.
        sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
        sequence_scores += log_softmax_output
        # Pick the best beam_width (beam, class) pairs per sample.
        scores, candidates = paddle.topk(
            paddle.reshape(sequence_scores, [batch_size, -1]),
            beam_width,
            axis=1)

        # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
        # candidates index the flattened (beam, class) grid, so the modulo
        # recovers the class of each candidate.
        y_prev = paddle.reshape(
            candidates % self.num_classes, shape=[batch_size * beam_width])
        sequence_scores = paddle.reshape(
            scores, shape=[batch_size * beam_width, 1])

        # Update fields for next timestep
        pos_index = paddle.expand_as(pos_index, candidates)
        # Source beam of each candidate, as an index into the flat b*k axis.
        # NOTE(review): relies on `/` followed by the int64 cast acting as
        # floor division -- confirm for the paddle version in use.
        predecessors = paddle.cast(
            candidates / self.num_classes + pos_index, dtype='int64')
        predecessors = paddle.reshape(
            predecessors, shape=[batch_size * beam_width, 1])
        # Re-order the decoder state so it follows the surviving beams.
        state = paddle.index_select(
            state, index=predecessors.squeeze(), axis=1)

        # Update sequence scores and erase scores for <eos> symbol so that they aren't expanded
        stored_scores.append(sequence_scores.clone())
        y_prev = paddle.reshape(y_prev, shape=[-1, 1])
        eos_prev = paddle.full_like(y_prev, fill_value=eos)
        mask = eos_prev == y_prev
        mask = paddle.nonzero(mask)
        # NOTE(review): this tests the rank of the nonzero() result, not
        # whether any eos was found -- confirm this is the intended check.
        if mask.dim() > 0:
            # The masked assignment is done on a numpy copy, then converted
            # back to a tensor.
            sequence_scores = sequence_scores.numpy()
            mask = mask.numpy()
            sequence_scores[mask] = -float('inf')
            sequence_scores = paddle.to_tensor(sequence_scores)

        # Cache results for backtracking
        stored_predecessors.append(predecessors)
        y_prev = paddle.squeeze(y_prev)
        stored_emitted_symbols.append(y_prev)

    # Do backtracking to return the optimal values
    #====== backtrack ======#
    # Initialize return variables given different types
    p = list()
    # `l` is rebound here: it held the feature length above and now holds
    # the per-beam sequence lengths.
    l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
         ]  # Placeholder for lengths of top-k sequences

    # the last step output of the beams are not sorted
    # thus they are sorted here
    sorted_score, sorted_idx = paddle.topk(
        paddle.reshape(
            stored_scores[-1], shape=[batch_size, beam_width]),
        beam_width)

    # initialize the sequence scores with the sorted last step beam scores
    s = sorted_score.clone()

    batch_eos_found = [0] * batch_size  # the number of EOS found
    # in the backward loop below for each batch
    t = self.max_len_labels - 1
    # initialize the back pointer with the sorted order of the last step beams.
    # add pos_index for indexing variable with b*k as the first dimension.
    t_predecessors = paddle.reshape(
        sorted_idx + pos_index.expand_as(sorted_idx),
        shape=[batch_size * beam_width])
    while t >= 0:
        # Re-order the variables with the back pointer
        current_symbol = paddle.index_select(
            stored_emitted_symbols[t], index=t_predecessors, axis=0)
        t_predecessors = paddle.index_select(
            stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
        eos_indices = stored_emitted_symbols[t] == eos
        eos_indices = paddle.nonzero(eos_indices)

        if eos_indices.dim() > 0:
            # Walk the eos hits from last to first.
            for i in range(eos_indices.shape[0] - 1, -1, -1):
                # Indices of the EOS symbol for both variables
                # with b*k as the first dimension, and b, k for
                # the first two dimensions
                idx = eos_indices[i]
                b_idx = int(idx[0] / beam_width)
                # The indices of the replacing position
                # according to the replacement strategy noted above
                # Slots are filled from the last beam backwards and wrap
                # around after beam_width hits.
                res_k_idx = beam_width - (batch_eos_found[b_idx] %
                                          beam_width) - 1
                batch_eos_found[b_idx] += 1
                res_idx = b_idx * beam_width + res_k_idx

                # Replace the old information in return variables
                # with the new ended sequence information
                t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
                current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
                s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
                l[b_idx][res_k_idx] = t + 1

        # record the back tracked results
        p.append(current_symbol)

        t -= 1

    # Sort and re-order again as the added ended sequences may change
    # the order (very unlikely)
    s, re_sorted_idx = s.topk(beam_width)
    for b_idx in range(batch_size):
        l[b_idx] = [
            l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
        ]

    re_sorted_idx = paddle.reshape(
        re_sorted_idx + pos_index.expand_as(re_sorted_idx),
        [batch_size * beam_width])

    # Reverse the sequences and re-order at the same time
    # It is reversed because the backtracking happens in reverse time order
    p = [
        paddle.reshape(
            paddle.index_select(step, re_sorted_idx, 0),
            shape=[batch_size, beam_width, -1]) for step in reversed(p)
    ]
    # Keep only beam 0 of each sample; the second return value is a
    # placeholder of ones.
    p = paddle.concat(p, -1)[:, 0, :]
    return p, paddle.ones_like(p)
class
AttentionUnit
(
nn
.
Layer
):
def
__init__
(
self
,
sDim
,
xDim
,
attDim
):
...
...
@@ -151,21 +314,9 @@ class AttentionUnit(nn.Layer):
self
.
xDim
=
xDim
self
.
attDim
=
attDim
self
.
sEmbed
=
nn
.
Linear
(
sDim
,
attDim
,
weight_attr
=
paddle
.
nn
.
initializer
.
Normal
(
std
=
0.01
),
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
))
self
.
xEmbed
=
nn
.
Linear
(
xDim
,
attDim
,
weight_attr
=
paddle
.
nn
.
initializer
.
Normal
(
std
=
0.01
),
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
))
self
.
wEmbed
=
nn
.
Linear
(
attDim
,
1
,
weight_attr
=
paddle
.
nn
.
initializer
.
Normal
(
std
=
0.01
),
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
))
self
.
sEmbed
=
nn
.
Linear
(
sDim
,
attDim
)
self
.
xEmbed
=
nn
.
Linear
(
xDim
,
attDim
)
self
.
wEmbed
=
nn
.
Linear
(
attDim
,
1
)
def
forward
(
self
,
x
,
sPrev
):
batch_size
,
T
,
_
=
x
.
shape
# [b x T x xDim]
...
...
@@ -184,10 +335,8 @@ class AttentionUnit(nn.Layer):
vProj
=
self
.
wEmbed
(
sumTanh
)
# [(b x T) x 1]
vProj
=
paddle
.
reshape
(
vProj
,
[
batch_size
,
T
])
alpha
=
F
.
softmax
(
vProj
,
axis
=
1
)
# attention weights for each sample in the minibatch
return
alpha
...
...
@@ -238,21 +387,4 @@ class DecoderUnit(nn.Layer):
output
,
state
=
self
.
gru
(
concat_context
,
sPrev
)
output
=
paddle
.
squeeze
(
output
,
axis
=
1
)
output
=
self
.
fc
(
output
)
return
output
,
state
if
__name__
==
"__main__"
:
model
=
AttentionRecognitionHead
(
num_classes
=
20
,
in_channels
=
30
,
sDim
=
512
,
attDim
=
512
,
max_len_labels
=
25
,
out_channels
=
38
)
data
=
paddle
.
ones
([
16
,
64
,
3
])
targets
=
paddle
.
ones
([
16
,
25
])
length
=
paddle
.
to_tensor
(
20
)
x
=
[
data
,
targets
,
length
]
output
=
model
(
x
)
print
(
output
.
shape
)
return
output
,
state
\ No newline at end of file
ppocr/modeling/heads/rec_att_head.py
浏览文件 @
c9e1077d
...
...
@@ -44,13 +44,10 @@ class AttentionHead(nn.Layer):
hidden
=
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
))
output_hiddens
=
[]
targets
=
targets
[
0
]
print
(
targets
)
if
targets
is
not
None
:
for
i
in
range
(
num_steps
):
char_onehots
=
self
.
_char_to_onehot
(
targets
[:,
i
],
onehot_dim
=
self
.
num_classes
)
# print("char_onehots:", char_onehots)
(
outputs
,
hidden
),
alpha
=
self
.
attention_cell
(
hidden
,
inputs
,
char_onehots
)
output_hiddens
.
append
(
paddle
.
unsqueeze
(
outputs
,
axis
=
1
))
...
...
@@ -107,8 +104,6 @@ class AttentionGRUCell(nn.Layer):
alpha
=
paddle
.
transpose
(
alpha
,
[
0
,
2
,
1
])
context
=
paddle
.
squeeze
(
paddle
.
mm
(
alpha
,
batch_H
),
axis
=
1
)
concat_context
=
paddle
.
concat
([
context
,
char_onehots
],
1
)
# print("concat_context:", concat_context.shape)
# print("prev_hidden:", prev_hidden.shape)
cur_hidden
=
self
.
rnn
(
concat_context
,
prev_hidden
)
...
...
ppocr/modeling/transforms/stn.py
浏览文件 @
c9e1077d
...
...
@@ -106,16 +106,3 @@ class STN(nn.Layer):
x
=
F
.
sigmoid
(
x
)
x
=
paddle
.
reshape
(
x
,
shape
=
[
-
1
,
self
.
num_ctrlpoints
,
2
])
return
img_feat
,
x
if
__name__
==
"__main__"
:
in_planes
=
3
num_ctrlpoints
=
20
np
.
random
.
seed
(
100
)
activation
=
'none'
# 'sigmoid'
stn_head
=
STN
(
in_planes
,
num_ctrlpoints
,
activation
)
data
=
np
.
random
.
randn
(
10
,
3
,
32
,
64
).
astype
(
"float32"
)
print
(
"data:"
,
np
.
sum
(
data
))
input
=
paddle
.
to_tensor
(
data
)
#input = paddle.randn([10, 3, 32, 64])
control_points
=
stn_head
(
input
)
ppocr/modeling/transforms/tps.py
浏览文件 @
c9e1077d
...
...
@@ -326,5 +326,6 @@ class STN_ON(nn.Layer):
image
,
self
.
tps_inputsize
,
mode
=
"bilinear"
,
align_corners
=
True
)
stn_img_feat
,
ctrl_points
=
self
.
stn_head
(
stn_input
)
x
,
_
=
self
.
tps
(
image
,
ctrl_points
)
#print("x:", np.sum(x.numpy()))
# print(x.shape)
return
x
ppocr/modeling/transforms/tps_spatial_transformer.py
浏览文件 @
c9e1077d
...
...
@@ -136,7 +136,8 @@ class TPSSpatialTransformer(nn.Layer):
assert
source_control_points
.
ndimension
()
==
3
assert
source_control_points
.
shape
[
1
]
==
self
.
num_control_points
assert
source_control_points
.
shape
[
2
]
==
2
batch_size
=
source_control_points
.
shape
[
0
]
#batch_size = source_control_points.shape[0]
batch_size
=
paddle
.
shape
(
source_control_points
)[
0
]
self
.
padding_matrix
=
paddle
.
expand
(
self
.
padding_matrix
,
shape
=
[
batch_size
,
3
,
2
])
...
...
@@ -151,28 +152,6 @@ class TPSSpatialTransformer(nn.Layer):
grid
=
paddle
.
clip
(
grid
,
0
,
1
)
# the source_control_points may be out of [0, 1].
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
#
grid = 2.0 * grid - 1.0
grid
=
2.0
*
grid
-
1.0
output_maps
=
grid_sample
(
input
,
grid
,
canvas
=
None
)
return
output_maps
,
source_coordinate
if
__name__
==
"__main__"
:
from
stn
import
STN
in_planes
=
3
num_ctrlpoints
=
20
np
.
random
.
seed
(
100
)
activation
=
'none'
# 'sigmoid'
stn_head
=
STN
(
in_planes
,
num_ctrlpoints
,
activation
)
data
=
np
.
random
.
randn
(
10
,
3
,
32
,
64
).
astype
(
"float32"
)
input
=
paddle
.
to_tensor
(
data
)
#input = paddle.randn([10, 3, 32, 64])
control_points
=
stn_head
(
input
)
#print("control points:", control_points)
#input = paddle.randn(shape=[10,3,32,100])
tps
=
TPSSpatialTransformer
(
output_image_size
=
[
32
,
320
],
num_control_points
=
20
,
margins
=
[
0.05
,
0.05
])
out
=
tps
(
input
,
control_points
[
1
])
print
(
"out 0 :"
,
out
[
0
].
shape
)
print
(
"out 1:"
,
out
[
1
].
shape
)
ppocr/modeling/transforms/tps_torch.py
已删除
100644 → 0
浏览文件 @
59cc4efd
from
__future__
import
absolute_import
import
numpy
as
np
import
itertools
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
def grid_sample(input, grid, canvas=None):
    """Sample ``input`` at ``grid``, optionally blending in ``canvas``.

    Without a canvas this is just ``F.grid_sample(input, grid)``. With one,
    an all-ones tensor shaped like ``input`` is sampled through the same
    grid and used as a blend weight: the output keeps the sampled value
    where the weight is 1 and takes ``canvas`` where it is 0.
    """
    sampled = F.grid_sample(input, grid)
    if canvas is None:
        return sampled
    ones = input.data.new(input.size()).fill_(1)
    coverage = F.grid_sample(ones, grid)
    return sampled * coverage + canvas * (1 - coverage)
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
    """Evaluate the radial basis ``r^2 * log(r)`` for every point pair.

    Args:
        input_points: tensor viewable as (N, 2).
        control_points: tensor viewable as (M, 2).

    Returns:
        (N, M) tensor of ``0.5 * d * log(d)``, where ``d`` is the squared
        distance between the points (equal to ``r^2 * log(r)``). Entries
        that come out as nan, such as coincident points, are set to 0.
    """
    num_inputs = input_points.size(0)
    num_controls = control_points.size(0)
    delta = input_points.view(num_inputs, 1, 2) - control_points.view(
        1, num_controls, 2)
    # Summing the two squared components by hand avoids torch.sum over the
    # last axis, which the original author noted was very slow.
    delta_sq = delta * delta
    sq_dist = delta_sq[:, :, 0] + delta_sq[:, :, 1]
    kernel = 0.5 * sq_dist * torch.log(sq_dist)
    # 0 * log(0) evaluates to nan; replace those entries with 0.
    kernel.masked_fill_(torch.isnan(kernel), 0)
    return kernel
# output_ctrl_pts are specified, according to our task.
def build_output_control_points(num_control_points, margins):
    """Build the fixed target control points along the top and bottom edges.

    Args:
        num_control_points: total number of points; half go on each edge
            (integer division by 2).
        margins: pair ``(margin_x, margin_y)`` in normalized coordinates.

    Returns:
        ``torch.Tensor`` with one (x, y) row per point, top-edge points
        first and bottom-edge points after. x runs evenly from ``margin_x``
        to ``1 - margin_x``; y is ``margin_y`` on the top edge and
        ``1 - margin_y`` on the bottom edge.
    """
    margin_x, margin_y = margins
    per_side = num_control_points // 2
    xs = np.linspace(margin_x, 1.0 - margin_x, per_side)
    top_ys = np.full(per_side, margin_y, dtype=np.float64)
    bottom_ys = np.full(per_side, 1.0 - margin_y, dtype=np.float64)
    top_edge = np.stack([xs, top_ys], axis=1)
    bottom_edge = np.stack([xs, bottom_ys], axis=1)
    return torch.Tensor(np.concatenate([top_edge, bottom_edge], axis=0))
# demo: ~/test/models/test_tps_transformation.py
class TPSSpatialTransformer(nn.Module):
    """Thin-plate-spline warp driven by predicted source control points.

    The constructor precomputes everything that depends only on the output
    size and the fixed target control points; ``forward`` then maps each
    batch of source control points to a sampling grid and resamples the
    input with ``grid_sample``.
    """

    def __init__(self,
                 output_image_size=None,
                 num_control_points=None,
                 margins=None):
        """Precompute the inverse kernel and the per-pixel representation.

        Args:
            output_image_size: ``(height, width)`` of the rectified output.
            num_control_points: Number of control points.
            margins: ``(margin_x, margin_y)`` forwarded to
                ``build_output_control_points``.
        """
        super(TPSSpatialTransformer, self).__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins
        self.target_height, self.target_width = output_image_size

        ctrl_pts = build_output_control_points(num_control_points, margins)
        num_pts = num_control_points

        # Padded kernel matrix: the pairwise basis block in the top-left,
        # bordered by a column/row of ones and the control point coordinates.
        kernel = torch.zeros(num_pts + 3, num_pts + 3)
        kernel[:num_pts, :num_pts].copy_(
            compute_partial_repr(ctrl_pts, ctrl_pts))
        kernel[:num_pts, -3].fill_(1)
        kernel[-3, :num_pts].fill_(1)
        kernel[:num_pts, -2:].copy_(ctrl_pts)
        kernel[-2:, :num_pts].copy_(ctrl_pts.transpose(0, 1))
        kernel_inv = torch.inverse(kernel)

        # Target coordinate matrix: every output pixel as a row (HW x 2),
        # enumerated in (y, x) order and normalised by (size - 1).
        num_pixels = self.target_height * self.target_width
        pixel_yx = torch.Tensor(
            list(
                itertools.product(
                    range(self.target_height), range(self.target_width))))
        ys, xs = pixel_yx.split(1, dim=1)
        ys = ys / (self.target_height - 1)
        xs = xs / (self.target_width - 1)
        # Convert from (y, x) to (x, y).
        pixel_xy = torch.cat([xs, ys], dim=1)
        pixel_repr = torch.cat(
            [
                compute_partial_repr(pixel_xy, ctrl_pts),
                torch.ones(num_pixels, 1),
                pixel_xy,
            ],
            dim=1)

        # Register the precomputed matrices as buffers.
        self.register_buffer('inverse_kernel', kernel_inv)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', pixel_repr)
        self.register_buffer('target_control_points', ctrl_pts)

    def forward(self, input, source_control_points):
        """Warp ``input`` according to ``source_control_points``.

        Args:
            input: Tensor handed to ``grid_sample``.
            source_control_points: Tensor of shape
                ``(batch, num_control_points, 2)``; enforced by assertions.

        Returns:
            ``(output_maps, source_coordinate)``: the resampled input and the
            unclamped source coordinate computed for every output pixel.
        """
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_control_points
        assert source_control_points.size(2) == 2

        batch_size = source_control_points.size(0)
        zero_rows = self.padding_matrix.expand(batch_size, 3, 2)
        padded_points = torch.cat([source_control_points, zero_rows], 1)

        mapping_matrix = torch.matmul(self.inverse_kernel, padded_points)
        source_coordinate = torch.matmul(self.target_coordinate_repr,
                                         mapping_matrix)

        grid = source_coordinate.view(-1, self.target_height,
                                      self.target_width, 2)
        # The source control points may fall outside [0, 1], so clamp first;
        # grid_sample expects [-1, 1] whereas the values here are in [0, 1].
        grid = 2.0 * torch.clamp(grid, 0, 1) - 1.0

        output_maps = grid_sample(input, grid, canvas=None)
        return output_maps, source_coordinate
if __name__ == "__main__":
    # Manual smoke test: run random images through the STN head, feed the
    # resulting control points to the TPS transformer, and print both
    # output shapes.
    from stn_torch import STNHead

    in_planes, num_ctrlpoints = 3, 20
    activation = 'none'  # 'sigmoid'

    # Seed torch before building the head and NumPy before drawing the data.
    torch.manual_seed(10)
    stn_head = STNHead(in_planes, num_ctrlpoints, activation)

    np.random.seed(100)
    data = np.random.randn(10, 3, 32, 64).astype("float32")
    input = torch.tensor(data)

    control_points = stn_head(input)
    tps = TPSSpatialTransformer(
        output_image_size=[32, 320],
        num_control_points=20,
        margins=[0.05, 0.05])
    # Element 1 of the head's output is what gets passed as control points.
    out = tps(input, control_points[1])

    print("out 0 :", out[0].shape)
    print("out 1:", out[1].shape)
ppocr/optimizer/optimizer.py
浏览文件 @
c9e1077d
...
...
@@ -127,3 +127,34 @@ class RMSProp(object):
grad_clip
=
self
.
grad_clip
,
parameters
=
parameters
)
return
opt
class Adadelta(object):
    """Configuration holder that builds an ``optim.Adadelta`` optimizer.

    The constructor only records hyper-parameters; the optimizer itself is
    created when the instance is called with the parameters to optimize.

    Args:
        learning_rate: Forwarded as ``learning_rate``.
        epsilon: Forwarded as ``epsilon``.
        rho: Forwarded as ``rho``.
        parameter_list: Stored on the instance but not forwarded by
            ``__call__``, which uses its own ``parameters`` argument.
        weight_decay: Forwarded as ``weight_decay``.
        grad_clip: Forwarded as ``grad_clip``.
        name: Forwarded as ``name``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 learning_rate=0.001,
                 epsilon=1e-08,
                 rho=0.95,
                 parameter_list=None,
                 weight_decay=None,
                 grad_clip=None,
                 name=None,
                 **kwargs):
        # Each hyper-parameter is stored exactly once (the original assigned
        # learning_rate twice).
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.rho = rho
        self.parameter_list = parameter_list
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip
        self.name = name

    def __call__(self, parameters):
        """Return an ``optim.Adadelta`` over ``parameters``.

        ``optim`` is resolved from the enclosing module's scope.
        """
        opt = optim.Adadelta(
            learning_rate=self.learning_rate,
            epsilon=self.epsilon,
            rho=self.rho,
            weight_decay=self.weight_decay,
            grad_clip=self.grad_clip,
            name=self.name,
            parameters=parameters)
        return opt
ppocr/postprocess/__init__.py
浏览文件 @
c9e1077d
...
...
@@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
DistillationCTCLabelDecode
,
\
TableLabelDecode
TableLabelDecode
,
SEEDLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
...
...
@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
support_dict
=
[
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'SEEDLabelDecode'
]
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
c9e1077d
...
...
@@ -170,10 +170,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
unkonwn
=
"UNKNOWN"
dict_character
=
dict_character
dict_character
=
[
self
.
beg_str
]
+
dict_character
+
[
self
.
end_str
]
+
[
self
.
unkonwn
]
dict_character
=
[
self
.
beg_str
]
+
dict_character
+
[
self
.
end_str
]
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
...
...
@@ -214,7 +212,6 @@ class AttnLabelDecode(BaseRecLabelDecode):
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds
=
preds
[
"rec_pred"
]
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
...
...
@@ -242,6 +239,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
return
idx
class SEEDLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 **kwargs):
        """Forward the dictionary settings to the base class.

        Extra keyword arguments are accepted and ignored.
        """
        super(SEEDLabelDecode, self).__init__(character_dict_path,
                                              character_type, use_space_char)

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with the end token appended.

        Only ``"eos"`` is added to the character list; ``beg_str`` is
        recorded on the instance but not inserted.
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        return dict_character + [self.end_str]

    def get_ignored_tokens(self):
        """Return a one-element list holding the index of the end token."""
        end_idx = self.get_beg_end_flag_idx("eos")
        return [end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Look up the index of the begin (``"sos"``) or end (``"eos"``) token.

        ``self.dict`` is presumably populated by the base class from the
        character list; confirm against ``BaseRecLabelDecode``.

        Raises:
            AssertionError: if ``beg_or_end`` is neither ``"sos"`` nor
                ``"eos"``.
        """
        if beg_or_end == "sos":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "eos":
            return np.array(self.dict[self.end_str])
        # Raised explicitly (rather than ``assert False``) so the check is
        # not stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError("unsupport type %s in get_beg_end_flag_idx" %
                             beg_or_end)

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        [end_idx] = self.get_ignored_tokens()
        end_idx = int(end_idx)
        for batch_idx in range(len(text_index)):
            indices = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(indices)):
                # Everything from the first end token onwards is discarded.
                if int(indices[idx]) == end_idx:
                    break
                # only for predict
                if (is_remove_duplicate and idx > 0 and
                        indices[idx - 1] == indices[idx]):
                    continue
                char_list.append(self.character[int(indices[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            # A sequence that starts with the end token yields no scores;
            # np.mean([]) would be nan with a RuntimeWarning, so report 0.
            if len(conf_list) == 0:
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model output, and the label when one is given.

        Args:
            preds: Mapping with ``"rec_pred"`` and, optionally,
                ``"rec_pred_scores"``. With scores present, ``"rec_pred"``
                is used as the index sequence directly; otherwise indices
                and scores are the argmax / max of ``"rec_pred"`` over
                axis 2.
            label: Optional index sequences decoded the same way.

        Returns:
            The decoded ``(text, score)`` list, or a ``(text, label)`` pair
            of such lists when ``label`` is not ``None``.
        """
        rec_pred = preds["rec_pred"]
        # Convert once and use the converted value below; the original
        # converted to NumPy and then overwrote the result in both branches.
        if isinstance(rec_pred, paddle.Tensor):
            rec_pred = rec_pred.numpy()

        if "rec_pred_scores" in preds:
            preds_idx = rec_pred
            preds_prob = preds["rec_pred_scores"]
            if isinstance(preds_prob, paddle.Tensor):
                preds_prob = preds_prob.numpy()
        else:
            preds_idx = rec_pred.argmax(axis=2)
            preds_prob = rec_pred.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label
class
SRNLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
...
...
ppocr/utils/save_load.py
浏览文件 @
c9e1077d
...
...
@@ -105,16 +105,13 @@ def load_dygraph_params(config, model, logger, optimizer):
params
=
paddle
.
load
(
pm
)
state_dict
=
model
.
state_dict
()
new_state_dict
=
{}
# for k1, k2 in zip(state_dict.keys(), params.keys()):
for
k1
in
state_dict
.
keys
():
if
k1
not
in
params
:
continue
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k1
].
shape
):
new_state_dict
[
k1
]
=
params
[
k1
]
else
:
logger
.
info
(
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k1
}
{
params
[
k1
].
shape
}
!"
)
for
k1
,
k2
in
zip
(
state_dict
.
keys
(),
params
.
keys
()):
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k2
].
shape
):
new_state_dict
[
k1
]
=
params
[
k2
]
else
:
logger
.
info
(
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k2
}
{
params
[
k2
].
shape
}
!"
)
model
.
set_state_dict
(
new_state_dict
)
logger
.
info
(
f
"loaded pretrained_model successful from
{
pm
}
"
)
return
{}
...
...
tools/program.py
浏览文件 @
c9e1077d
...
...
@@ -211,11 +211,10 @@ def train(config,
images
=
batch
[
0
]
if
use_srn
:
model_average
=
True
# if use_srn or model_type == 'table' or algorithm == "ASTER":
# preds = model(images, data=batch[1:])
# else:
# preds = model(images)
preds
=
model
(
images
,
data
=
batch
[
1
:])
if
use_srn
or
model_type
==
'table'
or
model_type
==
"seed"
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
preds
=
model
(
images
)
state_dict
=
model
.
state_dict
()
# for key in state_dict:
# print(key)
...
...
@@ -415,6 +414,7 @@ def preprocess(is_train=False):
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
log_file
=
'{}/train.log'
.
format
(
save_model_dir
)
print
(
"log has save in {}/train.log"
.
format
(
save_model_dir
))
else
:
log_file
=
None
logger
=
get_logger
(
name
=
'root'
,
log_file
=
log_file
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录