Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
d5e6df05
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d5e6df05
编写于
9月 29, 2021
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
差异文件
fix seed typo
上级
93118497
6f7e07e6
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
398 addition
and
20 deletion
+398
-20
configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
...igs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
+126
-0
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+5
-0
ppocr/losses/__init__.py
ppocr/losses/__init__.py
+0
-1
ppocr/losses/ace_loss.py
ppocr/losses/ace_loss.py
+50
-0
ppocr/losses/center_loss.py
ppocr/losses/center_loss.py
+89
-0
ppocr/losses/combined_loss.py
ppocr/losses/combined_loss.py
+4
-0
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+1
-1
ppocr/losses/rec_ctc_loss.py
ppocr/losses/rec_ctc_loss.py
+9
-1
ppocr/modeling/heads/rec_ctc_head.py
ppocr/modeling/heads/rec_ctc_head.py
+13
-4
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+5
-1
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+2
-0
tools/export_model.py
tools/export_model.py
+6
-0
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+70
-1
tools/program.py
tools/program.py
+18
-11
未找到文件。
configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
0 → 100644
浏览文件 @
d5e6df05
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
800
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec_mobile_pp-OCRv2_enhanced_ctc_loss
save_epoch_step
:
3
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
true
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
false
infer_img
:
doc/imgs_words/ch/word_1.jpg
character_dict_path
:
ppocr/utils/ppocr_keys_v1.txt
character_type
:
ch
max_text_length
:
25
infer_mode
:
false
use_space_char
:
true
distributed
:
true
save_res_path
:
./output/rec/predicts_mobile_pp-OCRv2_enhanced_ctc_loss.txt
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Piecewise
decay_epochs
:
[
700
,
800
]
values
:
[
0.001
,
0.0001
]
warmup_epoch
:
5
regularizer
:
name
:
L2
factor
:
2.0e-05
Architecture
:
model_type
:
rec
algorithm
:
CRNN
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
return_feats
:
true
Loss
:
name
:
CombinedLoss
loss_config_list
:
-
CTCLoss
:
use_focal_loss
:
false
weight
:
1.0
-
CenterLoss
:
weight
:
0.05
num_classes
:
6625
feat_dim
:
96
init_center
:
false
center_file_path
:
"
./train_center.pkl"
# you can also try to add ace loss on your own dataset
# - ACELoss:
# weight: 0.1
PostProcess
:
name
:
CTCLabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/
label_file_list
:
-
./train_data/train_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
RecAug
:
-
CTCLabelEncode
:
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
320
]
-
KeepKeys
:
keep_keys
:
-
image
-
label
-
length
-
label_ace
loader
:
shuffle
:
true
batch_size_per_card
:
128
drop_last
:
true
num_workers
:
8
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data
label_file_list
:
-
./train_data/val_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
CTCLabelEncode
:
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
320
]
-
KeepKeys
:
keep_keys
:
-
image
-
label
-
length
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
128
num_workers
:
8
ppocr/data/imaug/label_ops.py
浏览文件 @
d5e6df05
...
@@ -215,6 +215,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
...
@@ -215,6 +215,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
data
[
'length'
]
=
np
.
array
(
len
(
text
))
data
[
'length'
]
=
np
.
array
(
len
(
text
))
text
=
text
+
[
0
]
*
(
self
.
max_text_len
-
len
(
text
))
text
=
text
+
[
0
]
*
(
self
.
max_text_len
-
len
(
text
))
data
[
'label'
]
=
np
.
array
(
text
)
data
[
'label'
]
=
np
.
array
(
text
)
label
=
[
0
]
*
len
(
self
.
character
)
for
x
in
text
:
label
[
x
]
+=
1
data
[
'label_ace'
]
=
np
.
array
(
label
)
return
data
return
data
def
add_special_char
(
self
,
dict_character
):
def
add_special_char
(
self
,
dict_character
):
...
...
ppocr/losses/__init__.py
浏览文件 @
d5e6df05
...
@@ -52,7 +52,6 @@ def build_loss(config):
...
@@ -52,7 +52,6 @@ def build_loss(config):
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'NRTRLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'NRTRLoss'
,
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
]
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
'loss only support {}'
.
format
(
assert
module_name
in
support_dict
,
Exception
(
'loss only support {}'
.
format
(
...
...
ppocr/losses/ace_loss.py
0 → 100644
浏览文件 @
d5e6df05
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
paddle.nn
as
nn
class ACELoss(nn.Layer):
    """ACE loss term computed from sequence-averaged class probabilities.

    The prediction is soft-maxed over its last axis, summed over axis 1 and
    divided by the size of axis 1.  The target is ``batch[3]`` with column 0
    overwritten by ``(size of axis 1) - batch[2]`` and then divided by the
    same size.  Both are handed to a soft-label ``nn.CrossEntropyLoss`` built
    with ``reduction='none'``.

    NOTE(review): ``batch[2]`` is presumably the text length and ``batch[3]``
    the per-class character count (``label_ace``) -- confirm against the
    ``KeepKeys`` order of the training config.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted but not used.
        super().__init__()
        self.loss_func = nn.CrossEntropyLoss(
            weight=None,
            ignore_index=0,
            reduction='none',
            soft_label=True,
            axis=-1)

    def __call__(self, predicts, batch):
        """Return ``{"loss_ace": loss}`` for one batch.

        Args:
            predicts: a tensor, or a list/tuple whose last item is the tensor
                to score.
            batch: indexable batch; items 2 and 3 are read and cast to
                float32.
        """
        # When several outputs are supplied only the last one is scored.
        logits = predicts
        if isinstance(logits, (list, tuple)):
            logits = logits[-1]

        _, num_steps = logits.shape[:2]
        div = paddle.to_tensor([num_steps]).astype('float32')

        # Average the per-step class probabilities over axis 1.
        probs = nn.functional.softmax(logits, axis=-1)
        aggregation_preds = paddle.divide(paddle.sum(probs, axis=1), div)

        length = batch[2].astype("float32")
        target = batch[3].astype("float32")
        # Column 0 receives the steps left over after the text length.
        target[:, 0] = paddle.subtract(div, length)
        target = paddle.divide(target, div)

        return {"loss_ace": self.loss_func(aggregation_preds, target)}
ppocr/losses/center_loss.py
0 → 100644
浏览文件 @
d5e6df05
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
pickle
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
class CenterLoss(nn.Layer):
    """
    Center loss over per-step features, using predicted classes as labels.

    A ``[num_classes, feat_dim]`` float64 table of centers is held in
    ``self.centers``.  It is drawn from ``paddle.randn`` and, when
    ``init_center`` is true, rows are overwritten from a pickled mapping
    ``{row index: center}`` read from ``center_file_path``.

    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    """

    def __init__(self,
                 num_classes=6625,
                 feat_dim=96,
                 init_center=False,
                 center_file_path=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Random centers; possibly replaced row by row below.
        random_centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim])
        self.centers = random_centers.astype("float64")

        if not init_center:
            return

        assert os.path.exists(
            center_file_path
        ), f"center path({center_file_path}) must exist when init_center is set as True."
        # NOTE: pickle.load can execute arbitrary code contained in the file;
        # only point center_file_path at files from a trusted source.
        with open(center_file_path, 'rb') as center_file:
            stored_centers = pickle.load(center_file)
        for class_id in stored_centers.keys():
            self.centers[class_id] = paddle.to_tensor(
                stored_centers[class_id])

    def __call__(self, predicts, batch):
        """Return ``{'loss_center': loss}``.

        Args:
            predicts: list/tuple ``(features, predictions)``.  Features are
                flattened to rows of ``features.shape[-1]`` values; the class
                of each row is the argmax of ``predictions`` over axis 2.
            batch: not read by this loss.
        """
        assert isinstance(predicts, (list, tuple))
        features, logits = predicts

        feat_dim = features.shape[-1]
        flat_feats = paddle.reshape(features, [-1, feat_dim]).astype(
            "float64")
        # Labels come from the model's own predictions, one per feature row.
        step_labels = paddle.argmax(logits, axis=2)
        step_labels = paddle.reshape(
            step_labels, [step_labels.shape[0] * step_labels.shape[1]])

        batch_size = flat_feats.shape[0]
        grid_shape = [batch_size, self.num_classes]

        # Squared norm of every feature row, repeated for every class.
        feat_sq = paddle.sum(paddle.square(flat_feats), axis=1, keepdim=True)
        feat_sq = paddle.expand(feat_sq, grid_shape)

        # Squared norm of every center, repeated for every feature row.
        center_sq = paddle.sum(
            paddle.square(self.centers), axis=1, keepdim=True)
        center_sq = paddle.expand(
            center_sq, [self.num_classes, batch_size]).astype("float64")
        center_sq = paddle.transpose(center_sq, [1, 0])

        # ||x||^2 + ||c||^2 - 2 * x . c
        distmat = paddle.add(feat_sq, center_sq)
        cross = paddle.matmul(flat_feats,
                              paddle.transpose(self.centers, [1, 0]))
        distmat = distmat - 2.0 * cross

        # One-hot mask that keeps only the distance to each row's class.
        classes = paddle.arange(self.num_classes).astype("int64")
        label_grid = paddle.expand(
            paddle.unsqueeze(step_labels, 1), (batch_size, self.num_classes))
        class_grid = paddle.expand(classes, grid_shape)
        mask = paddle.equal(class_grid, label_grid).astype("float64")

        masked_dist = paddle.multiply(distmat, mask)
        # Every entry, including masked-out zeros, is clipped to
        # [1e-12, 1e+12] before summing.
        clipped = paddle.clip(masked_dist, min=1e-12, max=1e+12)
        loss = paddle.sum(clipped) / batch_size
        return {'loss_center': loss}
ppocr/losses/combined_loss.py
浏览文件 @
d5e6df05
...
@@ -15,6 +15,10 @@
...
@@ -15,6 +15,10 @@
import
paddle
import
paddle
import
paddle.nn
as
nn
import
paddle.nn
as
nn
from
.rec_ctc_loss
import
CTCLoss
from
.center_loss
import
CenterLoss
from
.ace_loss
import
ACELoss
from
.distillation_loss
import
DistillationCTCLoss
from
.distillation_loss
import
DistillationCTCLoss
from
.distillation_loss
import
DistillationDMLLoss
from
.distillation_loss
import
DistillationDMLLoss
from
.distillation_loss
import
DistillationDistanceLoss
,
DistillationDBLoss
,
DistillationDilaDBLoss
from
.distillation_loss
import
DistillationDistanceLoss
,
DistillationDBLoss
,
DistillationDilaDBLoss
...
...
ppocr/losses/distillation_loss.py
浏览文件 @
d5e6df05
...
@@ -112,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
...
@@ -112,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
if
isinstance
(
loss
,
dict
):
if
isinstance
(
loss
,
dict
):
for
key
in
loss
:
for
key
in
loss
:
loss_dict
[
"{}_{}_{}_{}_{}"
.
format
(
key
,
pair
[
loss_dict
[
"{}_{}_{}_{}_{}"
.
format
(
key
,
pair
[
0
],
pair
[
1
],
map
_name
,
idx
)]
=
loss
[
key
]
0
],
pair
[
1
],
self
.
maps
_name
,
idx
)]
=
loss
[
key
]
else
:
else
:
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
self
.
maps_name
[
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
self
.
maps_name
[
_c
],
idx
)]
=
loss
_c
],
idx
)]
=
loss
...
...
ppocr/losses/rec_ctc_loss.py
浏览文件 @
d5e6df05
...
@@ -21,16 +21,24 @@ from paddle import nn
...
@@ -21,16 +21,24 @@ from paddle import nn
class
CTCLoss
(
nn
.
Layer
):
class
CTCLoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
use_focal_loss
=
False
,
**
kwargs
):
super
(
CTCLoss
,
self
).
__init__
()
super
(
CTCLoss
,
self
).
__init__
()
self
.
loss_func
=
nn
.
CTCLoss
(
blank
=
0
,
reduction
=
'none'
)
self
.
loss_func
=
nn
.
CTCLoss
(
blank
=
0
,
reduction
=
'none'
)
self
.
use_focal_loss
=
use_focal_loss
def
forward
(
self
,
predicts
,
batch
):
def
forward
(
self
,
predicts
,
batch
):
if
isinstance
(
predicts
,
(
list
,
tuple
)):
predicts
=
predicts
[
-
1
]
predicts
=
predicts
.
transpose
((
1
,
0
,
2
))
predicts
=
predicts
.
transpose
((
1
,
0
,
2
))
N
,
B
,
_
=
predicts
.
shape
N
,
B
,
_
=
predicts
.
shape
preds_lengths
=
paddle
.
to_tensor
([
N
]
*
B
,
dtype
=
'int64'
)
preds_lengths
=
paddle
.
to_tensor
([
N
]
*
B
,
dtype
=
'int64'
)
labels
=
batch
[
1
].
astype
(
"int32"
)
labels
=
batch
[
1
].
astype
(
"int32"
)
label_lengths
=
batch
[
2
].
astype
(
'int64'
)
label_lengths
=
batch
[
2
].
astype
(
'int64'
)
loss
=
self
.
loss_func
(
predicts
,
labels
,
preds_lengths
,
label_lengths
)
loss
=
self
.
loss_func
(
predicts
,
labels
,
preds_lengths
,
label_lengths
)
if
self
.
use_focal_loss
:
weight
=
paddle
.
exp
(
-
loss
)
weight
=
paddle
.
subtract
(
paddle
.
to_tensor
([
1.0
]),
weight
)
weight
=
paddle
.
square
(
weight
)
*
self
.
focal_loss_alpha
loss
=
paddle
.
multiply
(
loss
,
weight
)
loss
=
loss
.
mean
()
# sum
loss
=
loss
.
mean
()
# sum
return
{
'loss'
:
loss
}
return
{
'loss'
:
loss
}
ppocr/modeling/heads/rec_ctc_head.py
浏览文件 @
d5e6df05
...
@@ -38,6 +38,7 @@ class CTCHead(nn.Layer):
...
@@ -38,6 +38,7 @@ class CTCHead(nn.Layer):
out_channels
,
out_channels
,
fc_decay
=
0.0004
,
fc_decay
=
0.0004
,
mid_channels
=
None
,
mid_channels
=
None
,
return_feats
=
False
,
**
kwargs
):
**
kwargs
):
super
(
CTCHead
,
self
).
__init__
()
super
(
CTCHead
,
self
).
__init__
()
if
mid_channels
is
None
:
if
mid_channels
is
None
:
...
@@ -66,14 +67,22 @@ class CTCHead(nn.Layer):
...
@@ -66,14 +67,22 @@ class CTCHead(nn.Layer):
bias_attr
=
bias_attr2
)
bias_attr
=
bias_attr2
)
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
mid_channels
=
mid_channels
self
.
mid_channels
=
mid_channels
self
.
return_feats
=
return_feats
def
forward
(
self
,
x
,
targets
=
None
):
def
forward
(
self
,
x
,
targets
=
None
):
if
self
.
mid_channels
is
None
:
if
self
.
mid_channels
is
None
:
predicts
=
self
.
fc
(
x
)
predicts
=
self
.
fc
(
x
)
else
:
else
:
predicts
=
self
.
fc1
(
x
)
x
=
self
.
fc1
(
x
)
predicts
=
self
.
fc2
(
predicts
)
predicts
=
self
.
fc2
(
x
)
if
self
.
return_feats
:
result
=
(
x
,
predicts
)
else
:
result
=
predicts
if
not
self
.
training
:
if
not
self
.
training
:
predicts
=
F
.
softmax
(
predicts
,
axis
=
2
)
predicts
=
F
.
softmax
(
predicts
,
axis
=
2
)
return
predicts
result
=
predicts
return
result
ppocr/postprocess/__init__.py
浏览文件 @
d5e6df05
...
@@ -18,6 +18,7 @@ from __future__ import print_function
...
@@ -18,6 +18,7 @@ from __future__ import print_function
from
__future__
import
unicode_literals
from
__future__
import
unicode_literals
import
copy
import
copy
import
platform
__all__
=
[
'build_post_process'
]
__all__
=
[
'build_post_process'
]
...
@@ -28,7 +29,10 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
...
@@ -28,7 +29,10 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
SEEDLabelDecode
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
SEEDLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.pse_postprocess
import
PSEPostProcess
if
platform
.
system
()
!=
"Windows"
:
# pse is not support in Windows
from
.pse_postprocess
import
PSEPostProcess
def
build_post_process
(
config
,
global_config
=
None
):
def
build_post_process
(
config
,
global_config
=
None
):
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
d5e6df05
...
@@ -111,6 +111,8 @@ class CTCLabelDecode(BaseRecLabelDecode):
...
@@ -111,6 +111,8 @@ class CTCLabelDecode(BaseRecLabelDecode):
character_type
,
use_space_char
)
character_type
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
tuple
):
preds
=
preds
[
-
1
]
if
isinstance
(
preds
,
paddle
.
Tensor
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_idx
=
preds
.
argmax
(
axis
=
2
)
...
...
tools/export_model.py
浏览文件 @
d5e6df05
...
@@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
...
@@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
]
]
]
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"SAR"
:
other_shape
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
48
,
160
],
dtype
=
"float32"
),
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
else
:
else
:
infer_shape
=
[
3
,
-
1
,
-
1
]
infer_shape
=
[
3
,
-
1
,
-
1
]
if
arch_config
[
"model_type"
]
==
"rec"
:
if
arch_config
[
"model_type"
]
==
"rec"
:
...
...
tools/infer/predict_rec.py
浏览文件 @
d5e6df05
...
@@ -68,6 +68,13 @@ class TextRecognizer(object):
...
@@ -68,6 +68,13 @@ class TextRecognizer(object):
"character_dict_path"
:
args
.
rec_char_dict_path
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
"use_space_char"
:
args
.
use_space_char
}
}
elif
self
.
rec_algorithm
==
"SAR"
:
postprocess_params
=
{
'name'
:
'SARLabelDecode'
,
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
}
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
utility
.
create_predictor
(
args
,
'rec'
,
logger
)
utility
.
create_predictor
(
args
,
'rec'
,
logger
)
...
@@ -194,6 +201,41 @@ class TextRecognizer(object):
...
@@ -194,6 +201,41 @@ class TextRecognizer(object):
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
gsrm_slf_attn_bias2
)
def resize_norm_img_sar(self, img, image_shape, width_downsample_ratio=0.25):
    """Resize, normalise and right-pad one image for the SAR recognizer.

    Args:
        img: image array indexed as (height, width[, channels]).
        image_shape: four values ``(imgC, imgH, imgW_min, imgW_max)``.
            ``imgW_min`` and ``imgW_max`` may each be ``None`` to disable
            that bound.
        width_downsample_ratio: the resized width is rounded to a multiple
            of ``int(1 / width_downsample_ratio)``.

    Returns:
        ``(padding_im, resize_shape, pad_shape, valid_ratio)`` where
        ``padding_im`` is a float32 array filled with -1.0 whose left part
        holds the resized image scaled to [-1, 1], ``resize_shape`` and
        ``pad_shape`` are the shapes before and after padding, and
        ``valid_ratio`` is ``min(1.0, resize_w / imgW_max)`` (1.0 when
        ``imgW_max`` is ``None``).
    """
    imgC, imgH, imgW_min, imgW_max = image_shape
    h = img.shape[0]
    w = img.shape[1]
    valid_ratio = 1.0
    # make sure new_width is an integral multiple of width_divisor.
    width_divisor = int(1 / width_downsample_ratio)
    # resize, keeping the aspect ratio at the target height
    ratio = w / float(h)
    resize_w = math.ceil(imgH * ratio)
    if resize_w % width_divisor != 0:
        resize_w = round(resize_w / width_divisor) * width_divisor
    if imgW_min is not None:
        resize_w = max(imgW_min, resize_w)
    if imgW_max is not None:
        valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
        resize_w = min(imgW_max, resize_w)
    resized_image = cv2.resize(img, (resize_w, imgH))
    resized_image = resized_image.astype('float32')
    # norm: scale to [0, 1], then shift/scale to [-1, 1]
    if image_shape[0] == 1:
        resized_image = resized_image / 255
        resized_image = resized_image[np.newaxis, :]
    else:
        resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    resize_shape = resized_image.shape
    # imgW_max is optional above, so it must not be used blindly as the
    # padded width; without an upper bound the image is left unpadded.
    pad_w = resize_w if imgW_max is None else imgW_max
    padding_im = -1.0 * np.ones((imgC, imgH, pad_w), dtype=np.float32)
    padding_im[:, :, 0:resize_w] = resized_image
    pad_shape = padding_im.shape

    return padding_im, resize_shape, pad_shape, valid_ratio
def
__call__
(
self
,
img_list
):
def
__call__
(
self
,
img_list
):
img_num
=
len
(
img_list
)
img_num
=
len
(
img_list
)
# Calculate the aspect ratio of all text bars
# Calculate the aspect ratio of all text bars
...
@@ -216,11 +258,19 @@ class TextRecognizer(object):
...
@@ -216,11 +258,19 @@ class TextRecognizer(object):
wh_ratio
=
w
*
1.0
/
h
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
for
ino
in
range
(
beg_img_no
,
end_img_no
):
if
self
.
rec_algorithm
!=
"SRN"
:
if
self
.
rec_algorithm
!=
"SRN"
and
self
.
rec_algorithm
!=
"SAR"
:
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
.
append
(
norm_img
)
elif
self
.
rec_algorithm
==
"SAR"
:
norm_img
,
_
,
_
,
valid_ratio
=
self
.
resize_norm_img_sar
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
valid_ratio
=
np
.
expand_dims
(
valid_ratio
,
axis
=
0
)
valid_ratios
=
[]
valid_ratios
.
append
(
valid_ratio
)
norm_img_batch
.
append
(
norm_img
)
else
:
else
:
norm_img
=
self
.
process_image_srn
(
norm_img
=
self
.
process_image_srn
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
,
8
,
25
)
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
,
8
,
25
)
...
@@ -266,6 +316,25 @@ class TextRecognizer(object):
...
@@ -266,6 +316,25 @@ class TextRecognizer(object):
if
self
.
benchmark
:
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
self
.
autolog
.
times
.
stamp
()
preds
=
{
"predict"
:
outputs
[
2
]}
preds
=
{
"predict"
:
outputs
[
2
]}
elif
self
.
rec_algorithm
==
"SAR"
:
valid_ratios
=
np
.
concatenate
(
valid_ratios
)
inputs
=
[
norm_img_batch
,
valid_ratios
,
]
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
preds
=
outputs
[
0
]
else
:
else
:
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
run
()
self
.
predictor
.
run
()
...
...
tools/program.py
浏览文件 @
d5e6df05
...
@@ -394,6 +394,18 @@ def preprocess(is_train=False):
...
@@ -394,6 +394,18 @@ def preprocess(is_train=False):
config
=
load_config
(
FLAGS
.
config
)
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
merge_config
(
FLAGS
.
opt
)
if
is_train
:
# save_config
save_model_dir
=
config
[
'Global'
][
'save_model_dir'
]
os
.
makedirs
(
save_model_dir
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
'config.yml'
),
'w'
)
as
f
:
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
log_file
=
'{}/train.log'
.
format
(
save_model_dir
)
else
:
log_file
=
None
logger
=
get_logger
(
name
=
'root'
,
log_file
=
log_file
)
# check if set use_gpu=True in paddlepaddle cpu version
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu
=
config
[
'Global'
][
'use_gpu'
]
use_gpu
=
config
[
'Global'
][
'use_gpu'
]
check_gpu
(
use_gpu
)
check_gpu
(
use_gpu
)
...
@@ -403,22 +415,17 @@ def preprocess(is_train=False):
...
@@ -403,22 +415,17 @@ def preprocess(is_train=False):
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'SEED'
]
'SEED'
]
windows_not_support_list
=
[
'PSE'
]
if
platform
.
system
()
==
"Windows"
and
alg
in
windows_not_support_list
:
logger
.
warning
(
'{} is not support in Windows now'
.
format
(
windows_not_support_list
))
sys
.
exit
()
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
paddle
.
set_device
(
device
)
device
=
paddle
.
set_device
(
device
)
config
[
'Global'
][
'distributed'
]
=
dist
.
get_world_size
()
!=
1
config
[
'Global'
][
'distributed'
]
=
dist
.
get_world_size
()
!=
1
if
is_train
:
# save_config
save_model_dir
=
config
[
'Global'
][
'save_model_dir'
]
os
.
makedirs
(
save_model_dir
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
'config.yml'
),
'w'
)
as
f
:
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
log_file
=
'{}/train.log'
.
format
(
save_model_dir
)
else
:
log_file
=
None
logger
=
get_logger
(
name
=
'root'
,
log_file
=
log_file
)
if
config
[
'Global'
][
'use_visualdl'
]:
if
config
[
'Global'
][
'use_visualdl'
]:
from
visualdl
import
LogWriter
from
visualdl
import
LogWriter
save_model_dir
=
config
[
'Global'
][
'save_model_dir'
]
save_model_dir
=
config
[
'Global'
][
'save_model_dir'
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录