Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
a3a09515
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a3a09515
编写于
7月 20, 2022
作者:
A
andyjpaddle
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vl
上级
0401e520
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
293 addition
and
608 deletion
+293
-608
configs/rec/rec_r45_visionlan.yml
configs/rec/rec_r45_visionlan.yml
+106
-0
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+6
-4
ppocr/data/imaug/text_image_aug/__init__.py
ppocr/data/imaug/text_image_aug/__init__.py
+1
-2
ppocr/data/imaug/text_image_aug/vl_aug.py
ppocr/data/imaug/text_image_aug/vl_aug.py
+0
-460
ppocr/losses/rec_vl_loss.py
ppocr/losses/rec_vl_loss.py
+66
-0
ppocr/modeling/backbones/rec_resnet_45.py
ppocr/modeling/backbones/rec_resnet_45.py
+11
-14
ppocr/modeling/backbones/rec_resnet_aster.py
ppocr/modeling/backbones/rec_resnet_aster.py
+1
-112
ppocr/modeling/heads/rec_visionlan_head.py
ppocr/modeling/heads/rec_visionlan_head.py
+1
-9
ppocr/optimizer/optimizer.py
ppocr/optimizer/optimizer.py
+54
-3
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+1
-2
ppocr/utils/dict36.txt
ppocr/utils/dict36.txt
+36
-0
tools/export_model.py
tools/export_model.py
+8
-2
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+2
-0
未找到文件。
configs/rec/rec_r45_visionlan.yml
0 → 100644
浏览文件 @
a3a09515
Global
:
use_gpu
:
true
epoch_num
:
8
log_smooth_window
:
200
print_batch_step
:
200
save_model_dir
:
/paddle/backup/visionlan/LA_v2
save_epoch_step
:
1
# evaluation is run every 2000 iterations
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
True
pretrained_model
:
./pretrained_model/LF_2_ocr
checkpoints
:
save_inference_dir
:
use_visualdl
:
True
infer_img
:
doc/imgs_words/en/word_2.png
# for data or label process
character_dict_path
:
ppocr/utils/dict36.txt
max_text_length
:
&max_text_length
25
training_step
:
&training_step
LA
infer_mode
:
False
use_space_char
:
False
save_res_path
:
./output/rec/predicts_visionlan.txt
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
clip_norm
:
20.0
group_lr
:
true
training_step
:
*training_step
lr
:
name
:
Piecewise
decay_epochs
:
[
6
]
values
:
[
0.0001
,
0.00001
]
regularizer
:
name
:
'
L2'
factor
:
0
Architecture
:
model_type
:
rec
algorithm
:
VisionLAN
Transform
:
Backbone
:
name
:
ResNet45
strides
:
[
2
,
2
,
2
,
1
,
1
]
Head
:
name
:
VLHead
n_layers
:
3
n_position
:
256
n_dim
:
512
max_text_length
:
*max_text_length
training_step
:
*training_step
Loss
:
name
:
VLLoss
mode
:
*training_step
weight_res
:
0.5
weight_mas
:
0.5
PostProcess
:
name
:
VLLabelDecode
Metric
:
name
:
RecMetric
is_filter
:
true
Train
:
dataset
:
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
-
DecodeImage
:
# load image
img_mode
:
RGB
channel_first
:
False
-
ABINetRecAug
:
-
VLLabelEncode
:
# Class handling label
-
VLRecResizeImg
:
image_shape
:
[
3
,
64
,
256
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
label_res'
,
'
label_sub'
,
'
label_id'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
True
batch_size_per_card
:
220
drop_last
:
True
num_workers
:
4
Eval
:
dataset
:
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/validation/
transforms
:
-
DecodeImage
:
# load image
img_mode
:
RGB
channel_first
:
False
-
VLLabelEncode
:
# Class handling label
-
VLRecResizeImg
:
image_shape
:
[
3
,
64
,
256
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
label_res'
,
'
label_sub'
,
'
label_id'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
64
num_workers
:
4
ppocr/data/imaug/label_ops.py
浏览文件 @
a3a09515
...
...
@@ -99,12 +99,13 @@ class BaseRecLabelEncode(object):
def
__init__
(
self
,
max_text_length
,
character_dict_path
=
None
,
use_space_char
=
False
):
use_space_char
=
False
,
lower
=
False
):
self
.
max_text_len
=
max_text_length
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
lower
=
False
self
.
lower
=
lower
if
character_dict_path
is
None
:
logger
=
get_logger
()
...
...
@@ -1227,9 +1228,10 @@ class VLLabelEncode(BaseRecLabelEncode):
max_text_length
,
character_dict_path
=
None
,
use_space_char
=
False
,
lower
=
True
,
**
kwargs
):
super
(
VLLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
use_space_cha
r
)
super
(
VLLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
use_space_char
,
lowe
r
)
def
__call__
(
self
,
data
):
text
=
data
[
'label'
]
# original string
...
...
ppocr/data/imaug/text_image_aug/__init__.py
浏览文件 @
a3a09515
...
...
@@ -13,6 +13,5 @@
# limitations under the License.
from
.augment
import
tia_perspective
,
tia_distort
,
tia_stretch
from
.vl_aug
import
VLAug
__all__
=
[
'tia_distort'
,
'tia_stretch'
,
'tia_perspective'
,
'VLAug'
]
__all__
=
[
'tia_distort'
,
'tia_stretch'
,
'tia_perspective'
]
ppocr/data/imaug/text_image_aug/vl_aug.py
已删除
100644 → 0
浏览文件 @
0401e520
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
numbers
import
random
import
cv2
import
numpy
as
np
from
PIL
import
Image
from
paddle.vision
import
transforms
from
paddle.vision.transforms
import
Compose
def
sample_asym
(
magnitude
,
size
=
None
):
return
np
.
random
.
beta
(
1
,
4
,
size
)
*
magnitude
def
sample_sym
(
magnitude
,
size
=
None
):
return
(
np
.
random
.
beta
(
4
,
4
,
size
=
size
)
-
0.5
)
*
2
*
magnitude
def
sample_uniform
(
low
,
high
,
size
=
None
):
return
np
.
random
.
uniform
(
low
,
high
,
size
=
size
)
def
get_interpolation
(
type
=
'random'
):
if
type
==
'random'
:
choice
=
[
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_AREA
]
interpolation
=
choice
[
random
.
randint
(
0
,
len
(
choice
)
-
1
)]
elif
type
==
'nearest'
:
interpolation
=
cv2
.
INTER_NEAREST
elif
type
==
'linear'
:
interpolation
=
cv2
.
INTER_LINEAR
elif
type
==
'cubic'
:
interpolation
=
cv2
.
INTER_CUBIC
elif
type
==
'area'
:
interpolation
=
cv2
.
INTER_AREA
else
:
raise
TypeError
(
'Interpolation types only nearest, linear, cubic, area are supported!'
)
return
interpolation
class
CVRandomRotation
(
object
):
def
__init__
(
self
,
degrees
=
15
):
assert
isinstance
(
degrees
,
numbers
.
Number
),
"degree should be a single number."
assert
degrees
>=
0
,
"degree must be positive."
self
.
degrees
=
degrees
@
staticmethod
def
get_params
(
degrees
):
return
sample_sym
(
degrees
)
def
__call__
(
self
,
img
):
angle
=
self
.
get_params
(
self
.
degrees
)
src_h
,
src_w
=
img
.
shape
[:
2
]
M
=
cv2
.
getRotationMatrix2D
(
center
=
(
src_w
/
2
,
src_h
/
2
),
angle
=
angle
,
scale
=
1.0
)
abs_cos
,
abs_sin
=
abs
(
M
[
0
,
0
]),
abs
(
M
[
0
,
1
])
dst_w
=
int
(
src_h
*
abs_sin
+
src_w
*
abs_cos
)
dst_h
=
int
(
src_h
*
abs_cos
+
src_w
*
abs_sin
)
M
[
0
,
2
]
+=
(
dst_w
-
src_w
)
/
2
M
[
1
,
2
]
+=
(
dst_h
-
src_h
)
/
2
flags
=
get_interpolation
()
return
cv2
.
warpAffine
(
img
,
M
,
(
dst_w
,
dst_h
),
flags
=
flags
,
borderMode
=
cv2
.
BORDER_REPLICATE
)
class
CVRandomAffine
(
object
):
def
__init__
(
self
,
degrees
,
translate
=
None
,
scale
=
None
,
shear
=
None
):
assert
isinstance
(
degrees
,
numbers
.
Number
),
"degree should be a single number."
assert
degrees
>=
0
,
"degree must be positive."
self
.
degrees
=
degrees
if
translate
is
not
None
:
assert
isinstance
(
translate
,
(
tuple
,
list
))
and
len
(
translate
)
==
2
,
\
"translate should be a list or tuple and it must be of length 2."
for
t
in
translate
:
if
not
(
0.0
<=
t
<=
1.0
):
raise
ValueError
(
"translation values should be between 0 and 1"
)
self
.
translate
=
translate
if
scale
is
not
None
:
assert
isinstance
(
scale
,
(
tuple
,
list
))
and
len
(
scale
)
==
2
,
\
"scale should be a list or tuple and it must be of length 2."
for
s
in
scale
:
if
s
<=
0
:
raise
ValueError
(
"scale values should be positive"
)
self
.
scale
=
scale
if
shear
is
not
None
:
if
isinstance
(
shear
,
numbers
.
Number
):
if
shear
<
0
:
raise
ValueError
(
"If shear is a single number, it must be positive."
)
self
.
shear
=
[
shear
]
else
:
assert
isinstance
(
shear
,
(
tuple
,
list
))
and
(
len
(
shear
)
==
2
),
\
"shear should be a list or tuple and it must be of length 2."
self
.
shear
=
shear
else
:
self
.
shear
=
shear
def
_get_inverse_affine_matrix
(
self
,
center
,
angle
,
translate
,
scale
,
shear
):
from
numpy
import
sin
,
cos
,
tan
if
isinstance
(
shear
,
numbers
.
Number
):
shear
=
[
shear
,
0
]
if
not
isinstance
(
shear
,
(
tuple
,
list
))
and
len
(
shear
)
==
2
:
raise
ValueError
(
"Shear should be a single value or a tuple/list containing "
+
"two values. Got {}"
.
format
(
shear
))
rot
=
math
.
radians
(
angle
)
sx
,
sy
=
[
math
.
radians
(
s
)
for
s
in
shear
]
cx
,
cy
=
center
tx
,
ty
=
translate
# RSS without scaling
a
=
cos
(
rot
-
sy
)
/
cos
(
sy
)
b
=
-
cos
(
rot
-
sy
)
*
tan
(
sx
)
/
cos
(
sy
)
-
sin
(
rot
)
c
=
sin
(
rot
-
sy
)
/
cos
(
sy
)
d
=
-
sin
(
rot
-
sy
)
*
tan
(
sx
)
/
cos
(
sy
)
+
cos
(
rot
)
# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
M
=
[
d
,
-
b
,
0
,
-
c
,
a
,
0
]
M
=
[
x
/
scale
for
x
in
M
]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
M
[
2
]
+=
M
[
0
]
*
(
-
cx
-
tx
)
+
M
[
1
]
*
(
-
cy
-
ty
)
M
[
5
]
+=
M
[
3
]
*
(
-
cx
-
tx
)
+
M
[
4
]
*
(
-
cy
-
ty
)
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
M
[
2
]
+=
cx
M
[
5
]
+=
cy
return
M
@
staticmethod
def
get_params
(
degrees
,
translate
,
scale_ranges
,
shears
,
height
):
angle
=
sample_sym
(
degrees
)
if
translate
is
not
None
:
max_dx
=
translate
[
0
]
*
height
max_dy
=
translate
[
1
]
*
height
translations
=
(
np
.
round
(
sample_sym
(
max_dx
)),
np
.
round
(
sample_sym
(
max_dy
)))
else
:
translations
=
(
0
,
0
)
if
scale_ranges
is
not
None
:
scale
=
sample_uniform
(
scale_ranges
[
0
],
scale_ranges
[
1
])
else
:
scale
=
1.0
if
shears
is
not
None
:
if
len
(
shears
)
==
1
:
shear
=
[
sample_sym
(
shears
[
0
]),
0.
]
elif
len
(
shears
)
==
2
:
shear
=
[
sample_sym
(
shears
[
0
]),
sample_sym
(
shears
[
1
])]
else
:
shear
=
0.0
return
angle
,
translations
,
scale
,
shear
def
__call__
(
self
,
img
):
src_h
,
src_w
=
img
.
shape
[:
2
]
angle
,
translate
,
scale
,
shear
=
self
.
get_params
(
self
.
degrees
,
self
.
translate
,
self
.
scale
,
self
.
shear
,
src_h
)
M
=
self
.
_get_inverse_affine_matrix
((
src_w
/
2
,
src_h
/
2
),
angle
,
(
0
,
0
),
scale
,
shear
)
M
=
np
.
array
(
M
).
reshape
(
2
,
3
)
startpoints
=
[(
0
,
0
),
(
src_w
-
1
,
0
),
(
src_w
-
1
,
src_h
-
1
),
(
0
,
src_h
-
1
)]
project
=
lambda
x
,
y
,
a
,
b
,
c
:
int
(
a
*
x
+
b
*
y
+
c
)
endpoints
=
[(
project
(
x
,
y
,
*
M
[
0
]),
project
(
x
,
y
,
*
M
[
1
]))
for
x
,
y
in
startpoints
]
rect
=
cv2
.
minAreaRect
(
np
.
array
(
endpoints
))
bbox
=
cv2
.
boxPoints
(
rect
).
astype
(
dtype
=
np
.
int
)
max_x
,
max_y
=
bbox
[:,
0
].
max
(),
bbox
[:,
1
].
max
()
min_x
,
min_y
=
bbox
[:,
0
].
min
(),
bbox
[:,
1
].
min
()
dst_w
=
int
(
max_x
-
min_x
)
dst_h
=
int
(
max_y
-
min_y
)
M
[
0
,
2
]
+=
(
dst_w
-
src_w
)
/
2
M
[
1
,
2
]
+=
(
dst_h
-
src_h
)
/
2
# add translate
dst_w
+=
int
(
abs
(
translate
[
0
]))
dst_h
+=
int
(
abs
(
translate
[
1
]))
if
translate
[
0
]
<
0
:
M
[
0
,
2
]
+=
abs
(
translate
[
0
])
if
translate
[
1
]
<
0
:
M
[
1
,
2
]
+=
abs
(
translate
[
1
])
flags
=
get_interpolation
()
return
cv2
.
warpAffine
(
img
,
M
,
(
dst_w
,
dst_h
),
flags
=
flags
,
borderMode
=
cv2
.
BORDER_REPLICATE
)
class
CVRandomPerspective
(
object
):
def
__init__
(
self
,
distortion
=
0.5
):
self
.
distortion
=
distortion
def
get_params
(
self
,
width
,
height
,
distortion
):
offset_h
=
sample_asym
(
distortion
*
height
/
2
,
size
=
4
).
astype
(
dtype
=
np
.
int
)
offset_w
=
sample_asym
(
distortion
*
width
/
2
,
size
=
4
).
astype
(
dtype
=
np
.
int
)
topleft
=
(
offset_w
[
0
],
offset_h
[
0
])
topright
=
(
width
-
1
-
offset_w
[
1
],
offset_h
[
1
])
botright
=
(
width
-
1
-
offset_w
[
2
],
height
-
1
-
offset_h
[
2
])
botleft
=
(
offset_w
[
3
],
height
-
1
-
offset_h
[
3
])
startpoints
=
[(
0
,
0
),
(
width
-
1
,
0
),
(
width
-
1
,
height
-
1
),
(
0
,
height
-
1
)]
endpoints
=
[
topleft
,
topright
,
botright
,
botleft
]
return
np
.
array
(
startpoints
,
dtype
=
np
.
float32
),
np
.
array
(
endpoints
,
dtype
=
np
.
float32
)
def
__call__
(
self
,
img
):
height
,
width
=
img
.
shape
[:
2
]
startpoints
,
endpoints
=
self
.
get_params
(
width
,
height
,
self
.
distortion
)
M
=
cv2
.
getPerspectiveTransform
(
startpoints
,
endpoints
)
# TODO: more robust way to crop image
rect
=
cv2
.
minAreaRect
(
endpoints
)
bbox
=
cv2
.
boxPoints
(
rect
).
astype
(
dtype
=
np
.
int
)
max_x
,
max_y
=
bbox
[:,
0
].
max
(),
bbox
[:,
1
].
max
()
min_x
,
min_y
=
bbox
[:,
0
].
min
(),
bbox
[:,
1
].
min
()
min_x
,
min_y
=
max
(
min_x
,
0
),
max
(
min_y
,
0
)
flags
=
get_interpolation
()
img
=
cv2
.
warpPerspective
(
img
,
M
,
(
max_x
,
max_y
),
flags
=
flags
,
borderMode
=
cv2
.
BORDER_REPLICATE
)
img
=
img
[
min_y
:,
min_x
:]
return
img
class
CVRescale
(
object
):
def
__init__
(
self
,
factor
=
4
,
base_size
=
(
128
,
512
)):
""" Define image scales using gaussian pyramid and rescale image to target scale.
Args:
factor: the decayed factor from base size, factor=4 keeps target scale by default.
base_size: base size the build the bottom layer of pyramid
"""
if
isinstance
(
factor
,
numbers
.
Number
):
self
.
factor
=
round
(
sample_uniform
(
0
,
factor
))
elif
isinstance
(
factor
,
(
tuple
,
list
))
and
len
(
factor
)
==
2
:
self
.
factor
=
round
(
sample_uniform
(
factor
[
0
],
factor
[
1
]))
else
:
raise
Exception
(
'factor must be number or list with length 2'
)
# assert factor is valid
self
.
base_h
,
self
.
base_w
=
base_size
[:
2
]
def
__call__
(
self
,
img
):
if
self
.
factor
==
0
:
return
img
src_h
,
src_w
=
img
.
shape
[:
2
]
cur_w
,
cur_h
=
self
.
base_w
,
self
.
base_h
scale_img
=
cv2
.
resize
(
img
,
(
cur_w
,
cur_h
),
interpolation
=
get_interpolation
())
for
_
in
range
(
np
.
int
(
self
.
factor
)):
scale_img
=
cv2
.
pyrDown
(
scale_img
)
scale_img
=
cv2
.
resize
(
scale_img
,
(
src_w
,
src_h
),
interpolation
=
get_interpolation
())
return
scale_img
class
CVGaussianNoise
(
object
):
def
__init__
(
self
,
mean
=
0
,
var
=
20
):
self
.
mean
=
mean
if
isinstance
(
var
,
numbers
.
Number
):
self
.
var
=
max
(
int
(
sample_asym
(
var
)),
1
)
elif
isinstance
(
var
,
(
tuple
,
list
))
and
len
(
var
)
==
2
:
self
.
var
=
int
(
sample_uniform
(
var
[
0
],
var
[
1
]))
else
:
raise
Exception
(
'degree must be number or list with length 2'
)
def
__call__
(
self
,
img
):
noise
=
np
.
random
.
normal
(
self
.
mean
,
self
.
var
**
0.5
,
img
.
shape
)
img
=
np
.
clip
(
img
+
noise
,
0
,
255
).
astype
(
np
.
uint8
)
return
img
class
CVMotionBlur
(
object
):
def
__init__
(
self
,
degrees
=
12
,
angle
=
90
):
if
isinstance
(
degrees
,
numbers
.
Number
):
self
.
degree
=
max
(
int
(
sample_asym
(
degrees
)),
1
)
elif
isinstance
(
degrees
,
(
tuple
,
list
))
and
len
(
degrees
)
==
2
:
self
.
degree
=
int
(
sample_uniform
(
degrees
[
0
],
degrees
[
1
]))
else
:
raise
Exception
(
'degree must be number or list with length 2'
)
self
.
angle
=
sample_uniform
(
-
angle
,
angle
)
def
__call__
(
self
,
img
):
M
=
cv2
.
getRotationMatrix2D
((
self
.
degree
//
2
,
self
.
degree
//
2
),
self
.
angle
,
1
)
motion_blur_kernel
=
np
.
zeros
((
self
.
degree
,
self
.
degree
))
motion_blur_kernel
[
self
.
degree
//
2
,
:]
=
1
motion_blur_kernel
=
cv2
.
warpAffine
(
motion_blur_kernel
,
M
,
(
self
.
degree
,
self
.
degree
))
motion_blur_kernel
=
motion_blur_kernel
/
self
.
degree
img
=
cv2
.
filter2D
(
img
,
-
1
,
motion_blur_kernel
)
img
=
np
.
clip
(
img
,
0
,
255
).
astype
(
np
.
uint8
)
return
img
class
CVGeometry
(
object
):
def
__init__
(
self
,
degrees
=
15
,
translate
=
(
0.3
,
0.3
),
scale
=
(
0.5
,
2.
),
shear
=
(
45
,
15
),
distortion
=
0.5
,
p
=
0.5
):
self
.
p
=
p
type_p
=
random
.
random
()
if
type_p
<
0.33
:
self
.
transforms
=
CVRandomRotation
(
degrees
=
degrees
)
elif
type_p
<
0.66
:
self
.
transforms
=
CVRandomAffine
(
degrees
=
degrees
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
)
else
:
self
.
transforms
=
CVRandomPerspective
(
distortion
=
distortion
)
def
__call__
(
self
,
img
):
if
random
.
random
()
<
self
.
p
:
return
self
.
transforms
(
img
)
else
:
return
img
class
CVDeterioration
(
object
):
def
__init__
(
self
,
var
,
degrees
,
factor
,
p
=
0.5
):
self
.
p
=
p
transforms
=
[]
if
var
is
not
None
:
transforms
.
append
(
CVGaussianNoise
(
var
=
var
))
if
degrees
is
not
None
:
transforms
.
append
(
CVMotionBlur
(
degrees
=
degrees
))
if
factor
is
not
None
:
transforms
.
append
(
CVRescale
(
factor
=
factor
))
random
.
shuffle
(
transforms
)
transforms
=
Compose
(
transforms
)
self
.
transforms
=
transforms
def
__call__
(
self
,
img
):
if
random
.
random
()
<
self
.
p
:
return
self
.
transforms
(
img
)
else
:
return
img
class
CVColorJitter
(
object
):
def
__init__
(
self
,
brightness
=
0.5
,
contrast
=
0.5
,
saturation
=
0.5
,
hue
=
0.1
,
p
=
0.5
):
self
.
p
=
p
self
.
transforms
=
transforms
.
ColorJitter
(
brightness
=
brightness
,
contrast
=
contrast
,
saturation
=
saturation
,
hue
=
hue
)
def
__call__
(
self
,
img
):
if
random
.
random
()
<
self
.
p
:
return
self
.
transforms
(
img
)
else
:
return
img
class
VLAug
(
object
):
def
__init__
(
self
,
geometry_p
=
0.5
,
Deterioration_p
=
0.25
,
ColorJitter_p
=
0.25
,
**
kwargs
):
self
.
Geometry
=
CVGeometry
(
degrees
=
45
,
translate
=
(
0.0
,
0.0
),
scale
=
(
0.5
,
2.
),
shear
=
(
45
,
15
),
distortion
=
0.5
,
p
=
geometry_p
)
self
.
Deterioration
=
CVDeterioration
(
var
=
20
,
degrees
=
6
,
factor
=
4
,
p
=
Deterioration_p
)
self
.
ColorJitter
=
CVColorJitter
(
brightness
=
0.5
,
contrast
=
0.5
,
saturation
=
0.5
,
hue
=
0.1
,
p
=
ColorJitter_p
)
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
img
=
self
.
Geometry
(
img
)
img
=
self
.
Deterioration
(
img
)
img
=
self
.
ColorJitter
(
img
)
data
[
'image'
]
=
img
return
data
if
__name__
==
'__main__'
:
geo
=
CVGeometry
(
degrees
=
45
,
translate
=
(
0.0
,
0.0
),
scale
=
(
0.5
,
2.
),
shear
=
(
45
,
15
),
distortion
=
0.5
,
p
=
1
)
det
=
CVDeterioration
(
var
=
20
,
degrees
=
6
,
factor
=
4
,
p
=
1
)
color
=
CVColorJitter
(
brightness
=
0.5
,
contrast
=
0.5
,
saturation
=
0.5
,
hue
=
0.1
,
p
=
1
)
img
=
np
.
ones
((
64
,
256
,
3
))
img
=
geo
(
img
)
img
=
det
(
img
)
img
=
color
(
img
)
# import pdb
# pdb.set_trace()
# print()
ppocr/losses/rec_vl_loss.py
0 → 100644
浏览文件 @
a3a09515
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
from
paddle
import
nn
class
VLLoss
(
nn
.
Layer
):
def
__init__
(
self
,
mode
=
'LF_1'
,
weight_res
=
0.5
,
weight_mas
=
0.5
,
**
kwargs
):
super
(
VLLoss
,
self
).
__init__
()
self
.
loss_func
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
(
reduction
=
"mean"
)
assert
mode
in
[
'LF_1'
,
'LF_2'
,
'LA'
]
self
.
mode
=
mode
self
.
weight_res
=
weight_res
self
.
weight_mas
=
weight_mas
def
flatten_label
(
self
,
target
):
label_flatten
=
[]
label_length
=
[]
for
i
in
range
(
0
,
target
.
shape
[
0
]):
cur_label
=
target
[
i
].
tolist
()
label_flatten
+=
cur_label
[:
cur_label
.
index
(
0
)
+
1
]
label_length
.
append
(
cur_label
.
index
(
0
)
+
1
)
label_flatten
=
paddle
.
to_tensor
(
label_flatten
,
dtype
=
'int64'
)
label_length
=
paddle
.
to_tensor
(
label_length
,
dtype
=
'int32'
)
return
(
label_flatten
,
label_length
)
def
_flatten
(
self
,
sources
,
lengths
):
return
paddle
.
concat
([
t
[:
l
]
for
t
,
l
in
zip
(
sources
,
lengths
)])
def
forward
(
self
,
predicts
,
batch
):
text_pre
=
predicts
[
0
]
target
=
batch
[
1
].
astype
(
'int64'
)
label_flatten
,
length
=
self
.
flatten_label
(
target
)
text_pre
=
self
.
_flatten
(
text_pre
,
length
)
if
self
.
mode
==
'LF_1'
:
loss
=
self
.
loss_func
(
text_pre
,
label_flatten
)
else
:
text_rem
=
predicts
[
1
]
text_mas
=
predicts
[
2
]
target_res
=
batch
[
2
].
astype
(
'int64'
)
target_sub
=
batch
[
3
].
astype
(
'int64'
)
label_flatten_res
,
length_res
=
self
.
flatten_label
(
target_res
)
label_flatten_sub
,
length_sub
=
self
.
flatten_label
(
target_sub
)
text_rem
=
self
.
_flatten
(
text_rem
,
length_res
)
text_mas
=
self
.
_flatten
(
text_mas
,
length_sub
)
loss_ori
=
self
.
loss_func
(
text_pre
,
label_flatten
)
loss_res
=
self
.
loss_func
(
text_rem
,
label_flatten_res
)
loss_mas
=
self
.
loss_func
(
text_mas
,
label_flatten_sub
)
loss
=
loss_ori
+
loss_res
*
self
.
weight_res
+
loss_mas
*
self
.
weight_mas
return
{
'loss'
:
loss
}
ppocr/modeling/backbones/rec_resnet_45.py
浏览文件 @
a3a09515
...
...
@@ -84,11 +84,15 @@ class BasicBlock(nn.Layer):
class
ResNet45
(
nn
.
Layer
):
def
__init__
(
self
,
block
=
BasicBlock
,
layers
=
[
3
,
4
,
6
,
6
,
3
],
in_channels
=
3
):
def
__init__
(
self
,
in_channels
=
3
,
block
=
BasicBlock
,
layers
=
[
3
,
4
,
6
,
6
,
3
],
strides
=
[
2
,
1
,
2
,
1
,
1
]):
self
.
inplanes
=
32
super
(
ResNet45
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2D
(
3
,
in_channels
,
32
,
kernel_size
=
3
,
stride
=
1
,
...
...
@@ -98,18 +102,13 @@ class ResNet45(nn.Layer):
self
.
bn1
=
nn
.
BatchNorm2D
(
32
)
self
.
relu
=
nn
.
ReLU
()
self
.
layer1
=
self
.
_make_layer
(
block
,
32
,
layers
[
0
],
stride
=
2
)
self
.
layer2
=
self
.
_make_layer
(
block
,
64
,
layers
[
1
],
stride
=
1
)
self
.
layer3
=
self
.
_make_layer
(
block
,
128
,
layers
[
2
],
stride
=
2
)
self
.
layer4
=
self
.
_make_layer
(
block
,
256
,
layers
[
3
],
stride
=
1
)
self
.
layer5
=
self
.
_make_layer
(
block
,
512
,
layers
[
4
],
stride
=
1
)
self
.
layer1
=
self
.
_make_layer
(
block
,
32
,
layers
[
0
],
stride
=
strides
[
0
]
)
self
.
layer2
=
self
.
_make_layer
(
block
,
64
,
layers
[
1
],
stride
=
strides
[
1
]
)
self
.
layer3
=
self
.
_make_layer
(
block
,
128
,
layers
[
2
],
stride
=
strides
[
2
]
)
self
.
layer4
=
self
.
_make_layer
(
block
,
256
,
layers
[
3
],
stride
=
strides
[
3
]
)
self
.
layer5
=
self
.
_make_layer
(
block
,
512
,
layers
[
4
],
stride
=
strides
[
4
]
)
self
.
out_channels
=
512
# for m in self.modules():
# if isinstance(m, nn.Conv2D):
# n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
# m.weight.data.normal_(0, math.sqrt(2. / n))
def
_make_layer
(
self
,
block
,
planes
,
blocks
,
stride
=
1
):
downsample
=
None
if
stride
!=
1
or
self
.
inplanes
!=
planes
*
block
.
expansion
:
...
...
@@ -137,11 +136,9 @@ class ResNet45(nn.Layer):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
(
x
)
# print(x)
x
=
self
.
layer1
(
x
)
x
=
self
.
layer2
(
x
)
x
=
self
.
layer3
(
x
)
# print(x)
x
=
self
.
layer4
(
x
)
x
=
self
.
layer5
(
x
)
return
x
ppocr/modeling/backbones/rec_resnet_aster.py
浏览文件 @
a3a09515
...
...
@@ -20,10 +20,6 @@ import paddle.nn as nn
import
sys
import
math
from
paddle.nn.initializer
import
KaimingNormal
,
Constant
zeros_
=
Constant
(
value
=
0.
)
ones_
=
Constant
(
value
=
1.
)
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
):
...
...
@@ -144,111 +140,4 @@ class ResNet_ASTER(nn.Layer):
rnn_feat
,
_
=
self
.
rnn
(
cnn_feat
)
return
rnn_feat
else
:
return
cnn_feat
class
Block
(
nn
.
Layer
):
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
downsample
=
None
):
super
(
Block
,
self
).
__init__
()
self
.
conv1
=
conv1x1
(
inplanes
,
planes
)
self
.
bn1
=
nn
.
BatchNorm2D
(
planes
)
self
.
relu
=
nn
.
ReLU
()
self
.
conv2
=
conv3x3
(
planes
,
planes
,
stride
)
self
.
bn2
=
nn
.
BatchNorm2D
(
planes
)
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
residual
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
out
+=
residual
out
=
self
.
relu
(
out
)
return
out
class
ResNet45
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
compress_layer
=
False
):
super
(
ResNet45
,
self
).
__init__
()
self
.
compress_layer
=
compress_layer
self
.
conv1_new
=
nn
.
Conv2D
(
in_channels
,
32
,
kernel_size
=
(
3
,
3
),
stride
=
1
,
padding
=
1
,
bias_attr
=
False
)
self
.
bn1
=
nn
.
BatchNorm2D
(
32
)
self
.
relu
=
nn
.
ReLU
()
self
.
inplanes
=
32
self
.
layer1
=
self
.
_make_layer
(
32
,
3
,
[
2
,
2
])
# [32, 128]
self
.
layer2
=
self
.
_make_layer
(
64
,
4
,
[
2
,
2
])
# [16, 64]
self
.
layer3
=
self
.
_make_layer
(
128
,
6
,
[
2
,
2
])
# [8, 32]
self
.
layer4
=
self
.
_make_layer
(
256
,
6
,
[
1
,
1
])
# [8, 32]
self
.
layer5
=
self
.
_make_layer
(
512
,
3
,
[
1
,
1
])
# [8, 32]
if
self
.
compress_layer
:
self
.
layer6
=
nn
.
Sequential
(
nn
.
Conv2D
(
512
,
256
,
kernel_size
=
(
3
,
1
),
padding
=
(
0
,
0
),
stride
=
(
1
,
1
)),
nn
.
BatchNorm
(
256
),
nn
.
ReLU
())
self
.
out_channels
=
256
else
:
self
.
out_channels
=
512
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Conv2D
):
KaimingNormal
(
m
.
weight
)
elif
isinstance
(
m
,
nn
.
BatchNorm
):
ones_
(
m
.
weight
)
zeros_
(
m
.
bias
)
def
_make_layer
(
self
,
planes
,
blocks
,
stride
):
downsample
=
None
if
stride
!=
[
1
,
1
]
or
self
.
inplanes
!=
planes
:
downsample
=
nn
.
Sequential
(
conv1x1
(
self
.
inplanes
,
planes
,
stride
),
nn
.
BatchNorm2D
(
planes
))
layers
=
[]
layers
.
append
(
Block
(
self
.
inplanes
,
planes
,
stride
,
downsample
))
self
.
inplanes
=
planes
for
_
in
range
(
1
,
blocks
):
layers
.
append
(
Block
(
self
.
inplanes
,
planes
))
return
nn
.
Sequential
(
*
layers
)
def
forward
(
self
,
x
):
x
=
self
.
conv1_new
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
(
x
)
x1
=
self
.
layer1
(
x
)
x2
=
self
.
layer2
(
x1
)
x3
=
self
.
layer3
(
x2
)
x4
=
self
.
layer4
(
x3
)
x5
=
self
.
layer5
(
x4
)
if
not
self
.
compress_layer
:
return
x5
else
:
x6
=
self
.
layer6
(
x5
)
return
x6
if
__name__
==
'__main__'
:
model
=
ResNet45
()
x
=
paddle
.
rand
([
1
,
3
,
64
,
256
])
x
=
paddle
.
to_tensor
(
x
)
print
(
x
.
shape
)
out
=
model
(
x
)
print
(
out
.
shape
)
return
cnn_feat
\ No newline at end of file
ppocr/modeling/heads/rec_visionlan_head.py
浏览文件 @
a3a09515
...
...
@@ -22,7 +22,7 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
from
paddle.nn.initializer
import
Normal
,
XavierNormal
import
numpy
as
np
from
ppocr.modeling.backbones.rec_resnet_
aster
import
ResNet45
from
ppocr.modeling.backbones.rec_resnet_
45
import
ResNet45
class
PositionalEncoding
(
nn
.
Layer
):
...
...
@@ -442,14 +442,6 @@ class MLM_VRM(nn.Layer):
if
out_length
[
j
]
==
0
and
tmp_result
[
j
]
==
0
:
out_length
[
j
]
=
now_step
+
1
now_step
+=
1
# while 0 in out_length and now_step < nsteps:
# tmp_result = text_pre[now_step, :, :]
# out_res[now_step] = tmp_result
# tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
# for j in range(b):
# if out_length[j] == 0 and tmp_result[j] == 0:
# out_length[j] = now_step + 1
# now_step += 1
for
j
in
range
(
0
,
b
):
if
int
(
out_length
[
j
])
==
0
:
out_length
[
j
]
=
nsteps
...
...
ppocr/optimizer/optimizer.py
浏览文件 @
a3a09515
...
...
@@ -77,11 +77,62 @@ class Adam(object):
self
.
grad_clip
=
grad_clip
self
.
name
=
name
self
.
lazy_mode
=
lazy_mode
self
.
group_lr
=
kwargs
.
get
(
'group_lr'
,
False
)
self
.
training_step
=
kwargs
.
get
(
'training_step'
,
None
)
def
__call__
(
self
,
model
):
train_params
=
[
param
for
param
in
model
.
parameters
()
if
param
.
trainable
is
True
]
if
self
.
group_lr
:
if
self
.
training_step
==
'LF_2'
:
import
paddle
if
isinstance
(
model
,
paddle
.
fluid
.
dygraph
.
parallel
.
DataParallel
):
# multi gpu
mlm
=
model
.
_layers
.
head
.
MLM_VRM
.
MLM
.
parameters
()
pre_mlm_pp
=
model
.
_layers
.
head
.
MLM_VRM
.
Prediction
.
pp_share
.
parameters
(
)
pre_mlm_w
=
model
.
_layers
.
head
.
MLM_VRM
.
Prediction
.
w_share
.
parameters
(
)
else
:
# single gpu
mlm
=
model
.
head
.
MLM_VRM
.
MLM
.
parameters
()
pre_mlm_pp
=
model
.
head
.
MLM_VRM
.
Prediction
.
pp_share
.
parameters
(
)
pre_mlm_w
=
model
.
head
.
MLM_VRM
.
Prediction
.
w_share
.
parameters
(
)
total
=
[]
for
param
in
mlm
:
total
.
append
(
id
(
param
))
for
param
in
pre_mlm_pp
:
total
.
append
(
id
(
param
))
for
param
in
pre_mlm_w
:
total
.
append
(
id
(
param
))
group_base_params
=
[
param
for
param
in
model
.
parameters
()
if
id
(
param
)
in
total
]
group_small_params
=
[
param
for
param
in
model
.
parameters
()
if
id
(
param
)
not
in
total
]
train_params
=
[{
'params'
:
group_base_params
},
{
'params'
:
group_small_params
,
'learning_rate'
:
self
.
learning_rate
.
values
[
0
]
*
0.1
}]
else
:
print
(
'group lr currently only support VisionLAN in LF_2 training step'
)
train_params
=
[
param
for
param
in
model
.
parameters
()
if
param
.
trainable
is
True
]
else
:
train_params
=
[
param
for
param
in
model
.
parameters
()
if
param
.
trainable
is
True
]
opt
=
optim
.
Adam
(
learning_rate
=
self
.
learning_rate
,
beta1
=
self
.
beta1
,
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
a3a09515
...
...
@@ -27,8 +27,7 @@ class BaseRecLabelDecode(object):
self
.
character_str
=
[]
if
character_dict_path
is
None
:
# self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
self
.
character_str
=
"abcdefghijklmnopqrstuvwxyz1234567890"
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
else
:
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
...
...
ppocr/utils/dict36.txt
0 → 100644
浏览文件 @
a3a09515
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
1
2
3
4
5
6
7
8
9
0
\ No newline at end of file
tools/export_model.py
浏览文件 @
a3a09515
...
...
@@ -60,7 +60,7 @@ def export_single_model(model,
shape
=
[
None
,
3
,
48
,
160
],
dtype
=
"float32"
),
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
in
[
"SVTR"
,
"VisionLAN"
]
:
elif
arch_config
[
"algorithm"
]
==
"SVTR"
:
if
arch_config
[
"Head"
][
"name"
]
==
'MultiHead'
:
other_shape
=
[
paddle
.
static
.
InputSpec
(
...
...
@@ -97,6 +97,12 @@ def export_single_model(model,
shape
=
[
None
,
1
,
32
,
100
],
dtype
=
"float32"
),
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"VisionLAN"
:
other_shape
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
64
,
256
],
dtype
=
"float32"
),
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
in
[
"LayoutLM"
,
"LayoutLMv2"
,
"LayoutXLM"
]:
input_spec
=
[
paddle
.
static
.
InputSpec
(
...
...
@@ -217,4 +223,4 @@ def main():
if
__name__
==
"__main__"
:
main
()
main
()
\ No newline at end of file
tools/infer/predict_rec.py
浏览文件 @
a3a09515
...
...
@@ -366,6 +366,8 @@ class TextRecognizer(object):
elif
self
.
rec_algorithm
==
"VisionLAN"
:
norm_img
=
self
.
resize_norm_img_vl
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
elif
self
.
rec_algorithm
==
"ABINet"
:
norm_img
=
self
.
resize_norm_img_abinet
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录