PaddlePaddle / PaddleOCR — commit

Commit cf533b65 ("add vl")
Author: andyjpaddle
Date: Jul 19, 2022
Parent: 05a98305

Showing 16 changed files with 1,297 additions and 27 deletions.
Changed files:
    ppocr/data/imaug/__init__.py                    +2    -1
    ppocr/data/imaug/label_ops.py                   +62   -1
    ppocr/data/imaug/rec_img_aug.py                 +35   -0
    ppocr/data/imaug/text_image_aug/__init__.py     +2    -1
    ppocr/data/imaug/text_image_aug/vl_aug.py       +460  -0   (new file)
    ppocr/losses/__init__.py                        +3    -1
    ppocr/modeling/backbones/__init__.py            +2    -2
    ppocr/modeling/backbones/rec_resnet_aster.py    +111  -0
    ppocr/modeling/heads/__init__.py                +2    -1
    ppocr/modeling/heads/rec_visionlan_head.py      +498  -0   (new file)
    ppocr/postprocess/__init__.py                   +2    -2
    ppocr/postprocess/rec_postprocess.py            +69   -1
    tools/eval.py                                   +1    -1
    tools/export_model.py                           +1    -1
    tools/infer/predict_rec.py                      +20   -0
    tools/program.py                                +27   -15
ppocr/data/imaug/__init__.py — view file @ cf533b65

@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
 from .make_pse_gt import MakePseGt
-from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
-    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
+from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
+    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, VLRecResizeImg
+from .text_image_aug import VLAug
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
ppocr/data/imaug/label_ops.py — view file @ cf533b65

@@ -23,6 +23,7 @@ import string
 from shapely.geometry import LineString, Point, Polygon
 import json
 import copy
+from random import sample
 from ppocr.utils.logging import get_logger

@@ -443,7 +444,9 @@ class KieLabelEncode(object):
             elif 'key_cls' in anno.keys():
                 labels.append(anno['key_cls'])
             else:
                 raise ValueError(
                     "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
                 )
             edges.append(ann.get('edge', 0))
         ann_infos = dict(
             image=data['image'],

@@ -1044,3 +1047,61 @@ class MultiLabelEncode(BaseRecLabelEncode):
         data_out['label_sar'] = sar['label']
         data_out['length'] = ctc['length']
         return data_out
+
+
+class VLLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(VLLabelEncode, self).__init__(max_text_length,
+                                            character_dict_path, use_space_char)
+
+    def __call__(self, data):
+        text = data['label']  # original string
+        # generate occluded text
+        len_str = len(text)
+        if len_str <= 0:
+            return None
+        change_num = 1
+        order = list(range(len_str))
+        change_id = sample(order, change_num)[0]
+        label_sub = text[change_id]
+        if change_id == (len_str - 1):
+            label_res = text[:change_id]
+        elif change_id == 0:
+            label_res = text[1:]
+        else:
+            label_res = text[:change_id] + text[change_id + 1:]
+        data['label_res'] = label_res  # remaining string
+        data['label_sub'] = label_sub  # occluded character
+        data['label_id'] = change_id  # character index
+        # encode label
+        text = self.encode(text)
+        if text is None:
+            return None
+        text = [i + 1 for i in text]
+        data['length'] = np.array(len(text))
+        text = text + [0] * (self.max_text_len - len(text))
+        data['label'] = np.array(text)
+        label_res = self.encode(label_res)
+        label_sub = self.encode(label_sub)
+        if label_res is None:
+            label_res = []
+        else:
+            label_res = [i + 1 for i in label_res]
+        if label_sub is None:
+            label_sub = []
+        else:
+            label_sub = [i + 1 for i in label_sub]
+        data['length_res'] = np.array(len(label_res))
+        data['length_sub'] = np.array(len(label_sub))
+        label_res = label_res + [0] * (self.max_text_len - len(label_res))
+        label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
+        data['label_res'] = np.array(label_res)
+        data['label_sub'] = np.array(label_sub)
+        return data
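For reference, a tiny standalone sketch of the character-splitting step performed by VLLabelEncode above (plain Python, not PaddleOCR code; the sample string is arbitrary): one position is drawn at random, the character there becomes label_sub and everything else becomes label_res.

from random import sample

def split_occluded(text):
    # pick one character position to occlude, as VLLabelEncode.__call__ does
    change_id = sample(range(len(text)), 1)[0]
    label_sub = text[change_id]                            # occluded character
    label_res = text[:change_id] + text[change_id + 1:]    # remaining string
    return change_id, label_sub, label_res

change_id, sub, res = split_occluded("paddle")
print(change_id, sub, res)   # e.g. 2 'd' 'padle'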
ppocr/data/imaug/rec_img_aug.py — view file @ cf533b65

@@ -213,6 +213,41 @@ class RecResizeImg(object):
         return data


+class VLRecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 infer_mode=False,
+                 character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+                 padding=True,
+                 **kwargs):
+        self.image_shape = image_shape
+        self.infer_mode = infer_mode
+        self.character_dict_path = character_dict_path
+        self.padding = padding
+
+    def __call__(self, data):
+        img = data['image']
+        if self.infer_mode and self.character_dict_path is not None:
+            norm_img, valid_ratio = resize_norm_img_chinese(img, self.image_shape)
+        else:
+            imgC, imgH, imgW = self.image_shape
+            resized_image = cv2.resize(
+                img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+            resized_w = imgW
+            resized_image = resized_image.astype('float32')
+            if self.image_shape[0] == 1:
+                resized_image = resized_image / 255
+                norm_img = resized_image[np.newaxis, :]
+            else:
+                norm_img = resized_image.transpose((2, 0, 1)) / 255
+            valid_ratio = min(1.0, float(resized_w / imgW))
+
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
 class SRNRecResizeImg(object):
     def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
         self.image_shape = image_shape
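For intuition, a minimal standalone sketch of the resize-and-normalize path taken above when infer_mode is False (the 3x64x256 target shape and the dummy input size are assumptions, not values fixed by this commit):

import cv2
import numpy as np

image_shape = (3, 64, 256)   # (C, H, W), assumed target shape
img = np.random.randint(0, 255, (48, 160, 3), dtype=np.uint8)   # dummy text crop

imgC, imgH, imgW = image_shape
resized = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
norm_img = resized.astype('float32').transpose((2, 0, 1)) / 255   # HWC -> CHW, scaled to [0, 1]
valid_ratio = 1.0   # the width is resized to imgW directly, so the ratio is always 1.0 here

print(norm_img.shape, valid_ratio)   # (3, 64, 256) 1.0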
ppocr/data/imaug/text_image_aug/__init__.py — view file @ cf533b65

@@ -13,5 +13,6 @@
 # limitations under the License.

 from .augment import tia_perspective, tia_distort, tia_stretch
+from .vl_aug import VLAug

-__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
+__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective', 'VLAug']
ppocr/data/imaug/text_image_aug/vl_aug.py — new file (0 → 100644), view file @ cf533b65

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numbers
import random

import cv2
import numpy as np
from PIL import Image
from paddle.vision import transforms
from paddle.vision.transforms import Compose


def sample_asym(magnitude, size=None):
    return np.random.beta(1, 4, size) * magnitude


def sample_sym(magnitude, size=None):
    return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude


def sample_uniform(low, high, size=None):
    return np.random.uniform(low, high, size=size)


def get_interpolation(type='random'):
    if type == 'random':
        choice = [
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
        ]
        interpolation = choice[random.randint(0, len(choice) - 1)]
    elif type == 'nearest':
        interpolation = cv2.INTER_NEAREST
    elif type == 'linear':
        interpolation = cv2.INTER_LINEAR
    elif type == 'cubic':
        interpolation = cv2.INTER_CUBIC
    elif type == 'area':
        interpolation = cv2.INTER_AREA
    else:
        raise TypeError(
            'Interpolation types only nearest, linear, cubic, area are supported!')
    return interpolation


class CVRandomRotation(object):
    def __init__(self, degrees=15):
        assert isinstance(degrees, numbers.Number), "degree should be a single number."
        assert degrees >= 0, "degree must be positive."
        self.degrees = degrees

    @staticmethod
    def get_params(degrees):
        return sample_sym(degrees)

    def __call__(self, img):
        angle = self.get_params(self.degrees)
        src_h, src_w = img.shape[:2]
        M = cv2.getRotationMatrix2D(
            center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
        abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
        dst_w = int(src_h * abs_sin + src_w * abs_cos)
        dst_h = int(src_h * abs_cos + src_w * abs_sin)
        M[0, 2] += (dst_w - src_w) / 2
        M[1, 2] += (dst_h - src_h) / 2

        flags = get_interpolation()
        return cv2.warpAffine(
            img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)


class CVRandomAffine(object):
    def __init__(self, degrees, translate=None, scale=None, shear=None):
        assert isinstance(degrees, numbers.Number), "degree should be a single number."
        assert degrees >= 0, "degree must be positive."
        self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                self.shear = [shear]
            else:
                assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
                    "shear should be a list or tuple and it must be of length 2."
                self.shear = shear
        else:
            self.shear = shear

    def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear):
        from numpy import sin, cos, tan

        if isinstance(shear, numbers.Number):
            shear = [shear, 0]

        if not isinstance(shear, (tuple, list)) and len(shear) == 2:
            raise ValueError(
                "Shear should be a single value or a tuple/list containing " +
                "two values. Got {}".format(shear))

        rot = math.radians(angle)
        sx, sy = [math.radians(s) for s in shear]

        cx, cy = center
        tx, ty = translate

        # RSS without scaling
        a = cos(rot - sy) / cos(sy)
        b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
        c = sin(rot - sy) / cos(sy)
        d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)

        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        M = [d, -b, 0, -c, a, 0]
        M = [x / scale for x in M]

        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
        M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)

        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
        M[2] += cx
        M[5] += cy
        return M

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, height):
        angle = sample_sym(degrees)
        if translate is not None:
            max_dx = translate[0] * height
            max_dy = translate[1] * height
            translations = (np.round(sample_sym(max_dx)),
                            np.round(sample_sym(max_dy)))
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = sample_uniform(scale_ranges[0], scale_ranges[1])
        else:
            scale = 1.0

        if shears is not None:
            if len(shears) == 1:
                shear = [sample_sym(shears[0]), 0.]
            elif len(shears) == 2:
                shear = [sample_sym(shears[0]), sample_sym(shears[1])]
        else:
            shear = 0.0
        return angle, translations, scale, shear

    def __call__(self, img):
        src_h, src_w = img.shape[:2]
        angle, translate, scale, shear = self.get_params(
            self.degrees, self.translate, self.scale, self.shear, src_h)

        M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
                                            (0, 0), scale, shear)
        M = np.array(M).reshape(2, 3)

        startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
                       (0, src_h - 1)]
        project = lambda x, y, a, b, c: int(a * x + b * y + c)
        endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
                     for x, y in startpoints]

        rect = cv2.minAreaRect(np.array(endpoints))
        bbox = cv2.boxPoints(rect).astype(dtype=np.int)
        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()

        dst_w = int(max_x - min_x)
        dst_h = int(max_y - min_y)
        M[0, 2] += (dst_w - src_w) / 2
        M[1, 2] += (dst_h - src_h) / 2

        # add translate
        dst_w += int(abs(translate[0]))
        dst_h += int(abs(translate[1]))
        if translate[0] < 0:
            M[0, 2] += abs(translate[0])
        if translate[1] < 0:
            M[1, 2] += abs(translate[1])

        flags = get_interpolation()
        return cv2.warpAffine(
            img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)


class CVRandomPerspective(object):
    def __init__(self, distortion=0.5):
        self.distortion = distortion

    def get_params(self, width, height, distortion):
        offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=np.int)
        offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=np.int)
        topleft = (offset_w[0], offset_h[0])
        topright = (width - 1 - offset_w[1], offset_h[1])
        botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
        botleft = (offset_w[3], height - 1 - offset_h[3])

        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
                       (0, height - 1)]
        endpoints = [topleft, topright, botright, botleft]
        return np.array(
            startpoints, dtype=np.float32), np.array(
                endpoints, dtype=np.float32)

    def __call__(self, img):
        height, width = img.shape[:2]
        startpoints, endpoints = self.get_params(width, height, self.distortion)
        M = cv2.getPerspectiveTransform(startpoints, endpoints)

        # TODO: more robust way to crop image
        rect = cv2.minAreaRect(endpoints)
        bbox = cv2.boxPoints(rect).astype(dtype=np.int)
        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
        min_x, min_y = max(min_x, 0), max(min_y, 0)

        flags = get_interpolation()
        img = cv2.warpPerspective(
            img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE)
        img = img[min_y:, min_x:]
        return img


class CVRescale(object):
    def __init__(self, factor=4, base_size=(128, 512)):
        """ Define image scales using gaussian pyramid and rescale image to target scale.

        Args:
            factor: the decayed factor from base size, factor=4 keeps target scale by default.
            base_size: base size the build the bottom layer of pyramid
        """
        if isinstance(factor, numbers.Number):
            self.factor = round(sample_uniform(0, factor))
        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
            self.factor = round(sample_uniform(factor[0], factor[1]))
        else:
            raise Exception('factor must be number or list with length 2')
        # assert factor is valid
        self.base_h, self.base_w = base_size[:2]

    def __call__(self, img):
        if self.factor == 0:
            return img
        src_h, src_w = img.shape[:2]
        cur_w, cur_h = self.base_w, self.base_h
        scale_img = cv2.resize(
            img, (cur_w, cur_h), interpolation=get_interpolation())
        for _ in range(np.int(self.factor)):
            scale_img = cv2.pyrDown(scale_img)
        scale_img = cv2.resize(
            scale_img, (src_w, src_h), interpolation=get_interpolation())
        return scale_img


class CVGaussianNoise(object):
    def __init__(self, mean=0, var=20):
        self.mean = mean
        if isinstance(var, numbers.Number):
            self.var = max(int(sample_asym(var)), 1)
        elif isinstance(var, (tuple, list)) and len(var) == 2:
            self.var = int(sample_uniform(var[0], var[1]))
        else:
            raise Exception('degree must be number or list with length 2')

    def __call__(self, img):
        noise = np.random.normal(self.mean, self.var**0.5, img.shape)
        img = np.clip(img + noise, 0, 255).astype(np.uint8)
        return img


class CVMotionBlur(object):
    def __init__(self, degrees=12, angle=90):
        if isinstance(degrees, numbers.Number):
            self.degree = max(int(sample_asym(degrees)), 1)
        elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
            self.degree = int(sample_uniform(degrees[0], degrees[1]))
        else:
            raise Exception('degree must be number or list with length 2')
        self.angle = sample_uniform(-angle, angle)

    def __call__(self, img):
        M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
                                    self.angle, 1)
        motion_blur_kernel = np.zeros((self.degree, self.degree))
        motion_blur_kernel[self.degree // 2, :] = 1
        motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
                                            (self.degree, self.degree))
        motion_blur_kernel = motion_blur_kernel / self.degree
        img = cv2.filter2D(img, -1, motion_blur_kernel)
        img = np.clip(img, 0, 255).astype(np.uint8)
        return img


class CVGeometry(object):
    def __init__(self,
                 degrees=15,
                 translate=(0.3, 0.3),
                 scale=(0.5, 2.),
                 shear=(45, 15),
                 distortion=0.5,
                 p=0.5):
        self.p = p
        type_p = random.random()
        if type_p < 0.33:
            self.transforms = CVRandomRotation(degrees=degrees)
        elif type_p < 0.66:
            self.transforms = CVRandomAffine(
                degrees=degrees, translate=translate, scale=scale, shear=shear)
        else:
            self.transforms = CVRandomPerspective(distortion=distortion)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transforms(img)
        else:
            return img


class CVDeterioration(object):
    def __init__(self, var, degrees, factor, p=0.5):
        self.p = p
        transforms = []
        if var is not None:
            transforms.append(CVGaussianNoise(var=var))
        if degrees is not None:
            transforms.append(CVMotionBlur(degrees=degrees))
        if factor is not None:
            transforms.append(CVRescale(factor=factor))

        random.shuffle(transforms)
        transforms = Compose(transforms)
        self.transforms = transforms

    def __call__(self, img):
        if random.random() < self.p:
            return self.transforms(img)
        else:
            return img


class CVColorJitter(object):
    def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5):
        self.p = p
        self.transforms = transforms.ColorJitter(
            brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transforms(img)
        else:
            return img


class VLAug(object):
    def __init__(self,
                 geometry_p=0.5,
                 Deterioration_p=0.25,
                 ColorJitter_p=0.25,
                 **kwargs):
        self.Geometry = CVGeometry(
            degrees=45,
            translate=(0.0, 0.0),
            scale=(0.5, 2.),
            shear=(45, 15),
            distortion=0.5,
            p=geometry_p)
        self.Deterioration = CVDeterioration(
            var=20, degrees=6, factor=4, p=Deterioration_p)
        self.ColorJitter = CVColorJitter(
            brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=ColorJitter_p)

    def __call__(self, data):
        img = data['image']
        img = self.Geometry(img)
        img = self.Deterioration(img)
        img = self.ColorJitter(img)
        data['image'] = img
        return data


if __name__ == '__main__':
    geo = CVGeometry(
        degrees=45,
        translate=(0.0, 0.0),
        scale=(0.5, 2.),
        shear=(45, 15),
        distortion=0.5,
        p=1)
    det = CVDeterioration(var=20, degrees=6, factor=4, p=1)
    color = CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=1)
    img = np.ones((64, 256, 3))
    img = geo(img)
    img = det(img)
    img = color(img)
    # import pdb
    # pdb.set_trace()
    # print()
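For intuition about the two sampling helpers defined at the top of vl_aug.py: sample_asym draws from Beta(1, 4), so magnitudes are biased toward zero, while sample_sym draws from a shifted Beta(4, 4), giving roughly symmetric values around zero. A small numpy-only sketch of that behaviour (independent of the file above):

import numpy as np

np.random.seed(0)
asym = np.random.beta(1, 4, 10000) * 10              # like sample_asym(10): skewed toward 0
sym = (np.random.beta(4, 4, 10000) - 0.5) * 2 * 45   # like sample_sym(45): symmetric about 0

print(round(asym.mean(), 2))   # close to 2.0, i.e. magnitude * 1 / (1 + 4)
print(round(sym.mean(), 2))    # close to 0.0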
ppocr/losses/__init__.py — view file @ cf533b65

@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
 from .rec_aster_loss import AsterLoss
 from .rec_pren_loss import PRENLoss
 from .rec_multi_loss import MultiLoss
+from .rec_vl_loss import VLLoss

 # cls loss
 from .cls_loss import ClsLoss

@@ -61,7 +62,8 @@ def build_loss(config):
         'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
         'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
         'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
-        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
+        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
+        'VLLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
ppocr/modeling/backbones/__init__.py — view file @ cf533b65

@@ -28,14 +28,14 @@ def build_backbone(config, model_type):
         from .rec_mv1_enhance import MobileNetV1Enhance
         from .rec_nrtr_mtb import MTB
         from .rec_resnet_31 import ResNet31
-        from .rec_resnet_aster import ResNet_ASTER
+        from .rec_resnet_aster import ResNet_ASTER, ResNet45
         from .rec_micronet import MicroNet
        from .rec_efficientb3_pren import EfficientNetb3_PREN
         from .rec_svtrnet import SVTRNet
         support_dict = [
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
             "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
-            'SVTRNet'
+            'SVTRNet', 'ResNet45'
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
ppocr/modeling/backbones/rec_resnet_aster.py — view file @ cf533b65

@@ -20,6 +20,10 @@ import paddle.nn as nn
 import sys
 import math
+from paddle.nn.initializer import KaimingNormal, Constant
+
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)


 def conv3x3(in_planes, out_planes, stride=1):

@@ -141,3 +145,110 @@ class ResNet_ASTER(nn.Layer):
             return rnn_feat
         else:
             return cnn_feat
+
+
+class Block(nn.Layer):
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Block, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2D(planes)
+        self.relu = nn.ReLU()
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2D(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class ResNet45(nn.Layer):
+    def __init__(self, in_channels=3, compress_layer=False):
+        super(ResNet45, self).__init__()
+        self.compress_layer = compress_layer
+        self.conv1_new = nn.Conv2D(
+            in_channels,
+            32,
+            kernel_size=(3, 3),
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn1 = nn.BatchNorm2D(32)
+        self.relu = nn.ReLU()
+        self.inplanes = 32
+
+        self.layer1 = self._make_layer(32, 3, [2, 2])    # [32, 128]
+        self.layer2 = self._make_layer(64, 4, [2, 2])    # [16, 64]
+        self.layer3 = self._make_layer(128, 6, [2, 2])   # [8, 32]
+        self.layer4 = self._make_layer(256, 6, [1, 1])   # [8, 32]
+        self.layer5 = self._make_layer(512, 3, [1, 1])   # [8, 32]
+
+        if self.compress_layer:
+            self.layer6 = nn.Sequential(
+                nn.Conv2D(
+                    512, 256, kernel_size=(3, 1), padding=(0, 0), stride=(1, 1)),
+                nn.BatchNorm(256),
+                nn.ReLU())
+            self.out_channels = 256
+        else:
+            self.out_channels = 512
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Conv2D):
+            KaimingNormal(m.weight)
+        elif isinstance(m, nn.BatchNorm):
+            ones_(m.weight)
+            zeros_(m.bias)
+
+    def _make_layer(self, planes, blocks, stride):
+        downsample = None
+        if stride != [1, 1] or self.inplanes != planes:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
+
+        layers = []
+        layers.append(Block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes
+        for _ in range(1, blocks):
+            layers.append(Block(self.inplanes, planes))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1_new(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x1 = self.layer1(x)
+        x2 = self.layer2(x1)
+        x3 = self.layer3(x2)
+        x4 = self.layer4(x3)
+        x5 = self.layer5(x4)
+        if not self.compress_layer:
+            return x5
+        else:
+            x6 = self.layer6(x5)
+            return x6
+
+
+if __name__ == '__main__':
+    model = ResNet45()
+    x = paddle.rand([1, 3, 64, 256])
+    x = paddle.to_tensor(x)
+    print(x.shape)
+    out = model(x)
+    print(out.shape)
ppocr/modeling/heads/__init__.py — view file @ cf533b65

@@ -33,6 +33,7 @@ def build_head(config):
     from .rec_aster_head import AsterHead
     from .rec_pren_head import PRENHead
     from .rec_multi_head import MultiHead
+    from .rec_visionlan_head import VLHead

     # cls head
     from .cls_head import ClsHead

@@ -46,7 +47,7 @@ def build_head(config):
         'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
         'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
-        'MultiHead'
+        'MultiHead', 'VLHead'
     ]

     #table head
ppocr/modeling/heads/rec_visionlan_head.py — new file (0 → 100644), view file @ cf533b65

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, XavierNormal
import numpy as np
from ppocr.modeling.backbones.rec_resnet_aster import ResNet45


class PositionalEncoding(nn.Layer):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()
        self.register_buffer(
            'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''

        def get_position_angle_vec(position):
            return [
                position / np.power(10000, 2 * (hid_j // 2) / d_hid)
                for hid_j in range(d_hid)
            ]

        sinusoid_table = np.array(
            [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
        sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
        sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
        return sinusoid_table

    def forward(self, x):
        return x + self.pos_table[:, :x.shape[1]].clone().detach()


class ScaledDotProductAttention(nn.Layer):
    "Scaled Dot-Product Attention"

    def __init__(self, temperature, attn_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(axis=2)

    def forward(self, q, k, v, mask=None):
        k = paddle.transpose(k, perm=[0, 2, 1])
        attn = paddle.bmm(q, k)
        attn = attn / self.temperature
        if mask is not None:
            attn = attn.masked_fill(mask, -1e9)
            if mask.dim() == 3:
                mask = paddle.unsqueeze(mask, axis=1)
            elif mask.dim() == 2:
                mask = paddle.unsqueeze(mask, axis=1)
                mask = paddle.unsqueeze(mask, axis=1)
            repeat_times = [
                attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
            ]
            mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
            attn[mask == 0] = -1e9
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = paddle.bmm(attn, v)
        return output


class MultiHeadAttention(nn.Layer):
    " Multi-Head Attention module"

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.w_qs = nn.Linear(
            d_model,
            n_head * d_k,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
        self.w_ks = nn.Linear(
            d_model,
            n_head * d_k,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
        self.w_vs = nn.Linear(
            d_model,
            n_head * d_v,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(
            n_head * d_v,
            d_model,
            weight_attr=ParamAttr(initializer=XavierNormal()))
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.shape
        sz_b, len_k, _ = k.shape
        sz_b, len_v, _ = v.shape
        residual = q

        q = self.w_qs(q)
        q = paddle.reshape(
            q, shape=[-1, len_q, n_head, d_k])  # 4*21*512 ---- 4*21*8*64
        k = self.w_ks(k)
        k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
        v = self.w_vs(v)
        v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])

        q = paddle.transpose(q, perm=[2, 0, 1, 3])
        q = paddle.reshape(q, shape=[-1, len_q, d_k])  # (n*b) x lq x dk
        k = paddle.transpose(k, perm=[2, 0, 1, 3])
        k = paddle.reshape(k, shape=[-1, len_k, d_k])  # (n*b) x lk x dk
        v = paddle.transpose(v, perm=[2, 0, 1, 3])
        v = paddle.reshape(v, shape=[-1, len_v, d_v])  # (n*b) x lv x dv

        mask = paddle.tile(
            mask, [n_head, 1, 1]) if mask is not None else None  # (n*b) x .. x ..
        output = self.attention(q, k, v, mask=mask)
        output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
        output = paddle.transpose(output, perm=[1, 2, 0, 3])
        output = paddle.reshape(
            output, shape=[-1, len_q, n_head * d_v])  # b x lq x (n*dv)
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output


class PositionwiseFeedForward(nn.Layer):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Conv1D(d_in, d_hid, 1)  # position-wise
        self.w_2 = nn.Conv1D(d_hid, d_in, 1)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = paddle.transpose(x, perm=[0, 2, 1])
        x = self.w_2(F.relu(self.w_1(x)))
        x = paddle.transpose(x, perm=[0, 2, 1])
        x = self.dropout(x)
        x = self.layer_norm(x + residual)
        return x


class EncoderLayer(nn.Layer):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output


class Transformer_Encoder(nn.Layer):
    def __init__(self,
                 n_layers=2,
                 n_head=8,
                 d_word_vec=512,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=2048,
                 dropout=0.1,
                 n_position=256):
        super(Transformer_Encoder, self).__init__()
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.LayerList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)

    def forward(self, enc_output, src_mask, return_attns=False):
        enc_output = self.dropout(
            self.position_enc(enc_output))  # position embeding
        for enc_layer in self.layer_stack:
            enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
        enc_output = self.layer_norm(enc_output)
        return enc_output


class PP_layer(nn.Layer):
    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
        super(PP_layer, self).__init__()
        self.character_len = N_max_character
        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
        self.w0 = nn.Linear(N_max_character, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.we = nn.Linear(n_dim, N_max_character)
        self.active = nn.Tanh()
        self.softmax = nn.Softmax(axis=2)

    def forward(self, enc_output):
        # enc_output: b,256,512
        reading_order = paddle.arange(self.character_len, dtype='int64')
        reading_order = reading_order.unsqueeze(0).expand(
            [enc_output.shape[0], -1])  # (S,) -> (B, S)
        reading_order = self.f0_embedding(reading_order)  # b,25,512

        # calculate attention
        reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
        t = self.w0(reading_order)  # b,512,256
        t = self.active(
            paddle.transpose(
                t, perm=[0, 2, 1]) + self.wv(enc_output))  # b,256,512
        t = self.we(t)  # b,256,25
        t = self.softmax(paddle.transpose(t, perm=[0, 2, 1]))  # b,25,256
        g_output = paddle.bmm(t, enc_output)  # b,25,512
        return g_output


class Prediction(nn.Layer):
    def __init__(self, n_dim=512, n_position=256, N_max_character=25, n_class=37):
        super(Prediction, self).__init__()
        self.pp = PP_layer(
            n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
        self.pp_share = PP_layer(
            n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
        self.w_vrm = nn.Linear(n_dim, n_class)  # output layer
        self.w_share = nn.Linear(n_dim, n_class)  # output layer
        self.nclass = n_class

    def forward(self, cnn_feature, f_res, f_sub, train_mode=False, use_mlm=True):
        if train_mode:
            if not use_mlm:
                g_output = self.pp(cnn_feature)  # b,25,512
                g_output = self.w_vrm(g_output)
                f_res = 0
                f_sub = 0
                return g_output, f_res, f_sub
            g_output = self.pp(cnn_feature)  # b,25,512
            f_res = self.pp_share(f_res)
            f_sub = self.pp_share(f_sub)
            g_output = self.w_vrm(g_output)
            f_res = self.w_share(f_res)
            f_sub = self.w_share(f_sub)
            return g_output, f_res, f_sub
        else:
            g_output = self.pp(cnn_feature)  # b,25,512
            g_output = self.w_vrm(g_output)
            return g_output


class MLM(nn.Layer):
    "Architecture of MLM"

    def __init__(self, n_dim=512, n_position=256, max_text_length=25):
        super(MLM, self).__init__()
        self.MLM_SequenceModeling_mask = Transformer_Encoder(
            n_layers=2, n_position=n_position)
        self.MLM_SequenceModeling_WCL = Transformer_Encoder(
            n_layers=1, n_position=n_position)
        self.pos_embedding = nn.Embedding(max_text_length, n_dim)
        self.w0_linear = nn.Linear(1, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.active = nn.Tanh()
        self.we = nn.Linear(n_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, label_pos):
        # transformer unit for generating mask_c
        feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
        # position embedding layer
        label_pos = paddle.to_tensor(label_pos, dtype='int64')
        pos_emb = self.pos_embedding(label_pos)
        pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
        pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
        # fusion position embedding with features V & generate mask_c
        att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
        att_map_sub = self.we(att_map_sub)  # b,256,1
        att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
        att_map_sub = self.sigmoid(att_map_sub)  # b,1,256
        # WCL
        ## generate inputs for WCL
        att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
        f_res = x * (1 - att_map_sub)  # second path with remaining string
        f_sub = x * att_map_sub  # first path with occluded character
        ## transformer units in WCL
        f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
        f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
        return f_res, f_sub, att_map_sub


def trans_1d_2d(x):
    b, w_h, c = x.shape  # b, 256, 512
    x = paddle.transpose(x, perm=[0, 2, 1])
    x = paddle.reshape(x, [-1, c, 32, 8])
    x = paddle.transpose(x, perm=[0, 1, 3, 2])  # [b, c, 8, 32]
    return x


class MLM_VRM(nn.Layer):
    """
    MLM+VRM, MLM is only used in training.
    ratio controls the occluded number in a batch.
    The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer).
    x: input image
    label_pos: character index
    training_step: LF or LA process
    output
        text_pre: prediction of VRM
        test_rem: prediction of remaining string in MLM
        text_mas: prediction of occluded character in MLM
        mask_c_show: visualization of Mask_c
    """

    def __init__(self,
                 n_layers=3,
                 n_position=256,
                 n_dim=512,
                 max_text_length=25,
                 nclass=37):
        super(MLM_VRM, self).__init__()
        self.MLM = MLM(n_dim=n_dim,
                       n_position=n_position,
                       max_text_length=max_text_length)
        self.SequenceModeling = Transformer_Encoder(
            n_layers=n_layers, n_position=n_position)
        self.Prediction = Prediction(
            n_dim=n_dim,
            n_position=n_position,
            N_max_character=max_text_length + 1,  # N_max_character = 1 eos + 25 characters
            n_class=nclass)
        self.nclass = nclass
        self.max_text_length = max_text_length

    def forward(self, x, label_pos, training_step, train_mode=False):
        b, c, h, w = x.shape
        nT = self.max_text_length
        x = paddle.transpose(x, perm=[0, 1, 3, 2])
        x = paddle.reshape(x, [-1, c, h * w])
        x = paddle.transpose(x, perm=[0, 2, 1])
        if train_mode:
            if training_step == 'LF_1':
                f_res = 0
                f_sub = 0
                x = self.SequenceModeling(x, src_mask=None)
                text_pre, test_rem, text_mas = self.Prediction(
                    x, f_res, f_sub, train_mode=True, use_mlm=False)
                return text_pre, text_pre, text_pre, text_pre
            elif training_step == 'LF_2':
                # MLM
                f_res, f_sub, mask_c = self.MLM(x, label_pos)
                x = self.SequenceModeling(x, src_mask=None)
                text_pre, test_rem, text_mas = self.Prediction(
                    x, f_res, f_sub, train_mode=True)
                mask_c_show = trans_1d_2d(mask_c)
                return text_pre, test_rem, text_mas, mask_c_show
            elif training_step == 'LA':
                # MLM
                f_res, f_sub, mask_c = self.MLM(x, label_pos)
                ## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
                ## ratio controls the occluded number in a batch
                character_mask = paddle.zeros_like(mask_c)
                ratio = b // 2
                if ratio >= 1:
                    with paddle.no_grad():
                        character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
                else:
                    character_mask = mask_c
                x = x * (1 - character_mask)
                # VRM
                ## transformer unit for VRM
                x = self.SequenceModeling(x, src_mask=None)
                ## prediction layer for MLM and VSR
                text_pre, test_rem, text_mas = self.Prediction(
                    x, f_res, f_sub, train_mode=True)
                mask_c_show = trans_1d_2d(mask_c)
                return text_pre, test_rem, text_mas, mask_c_show
            else:
                raise NotImplementedError
        else:  # VRM is only used in the testing stage
            f_res = 0
            f_sub = 0
            contextual_feature = self.SequenceModeling(x, src_mask=None)
            text_pre = self.Prediction(
                contextual_feature, f_res, f_sub, train_mode=False, use_mlm=False)
            text_pre = paddle.transpose(text_pre, perm=[1, 0, 2])  # (26, b, 37))
            lenText = nT
            nsteps = nT
            out_res = paddle.zeros(
                shape=[lenText, b, self.nclass], dtype=x.dtype)  # (25, b, 37)
            out_length = paddle.zeros(shape=[b], dtype=x.dtype)
            now_step = 0
            for _ in range(nsteps):
                if 0 in out_length and now_step < nsteps:
                    tmp_result = text_pre[now_step, :, :]
                    out_res[now_step] = tmp_result
                    tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
                    for j in range(b):
                        if out_length[j] == 0 and tmp_result[j] == 0:
                            out_length[j] = now_step + 1
                    now_step += 1
            # while 0 in out_length and now_step < nsteps:
            #     tmp_result = text_pre[now_step, :, :]
            #     out_res[now_step] = tmp_result
            #     tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
            #     for j in range(b):
            #         if out_length[j] == 0 and tmp_result[j] == 0:
            #             out_length[j] = now_step + 1
            #     now_step += 1
            for j in range(0, b):
                if int(out_length[j]) == 0:
                    out_length[j] = nsteps
            start = 0
            output = paddle.zeros(
                shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
            for i in range(0, b):
                cur_length = int(out_length[i])
                output[start:start + cur_length] = out_res[0:cur_length, i, :]
                start += cur_length
            return output, out_length


class VLHead(nn.Layer):
    """
    Architecture of VisionLAN
    """

    def __init__(self,
                 in_channels,
                 out_channels=36,
                 n_layers=3,
                 n_position=256,
                 n_dim=512,
                 max_text_length=25,
                 training_step='LA'):
        super(VLHead, self).__init__()
        self.MLM_VRM = MLM_VRM(
            n_layers=n_layers,
            n_position=n_position,
            n_dim=n_dim,
            max_text_length=max_text_length,
            nclass=out_channels + 1)
        self.training_step = training_step

    def forward(self, feat, targets=None):
        if self.training:
            label_pos = targets[-2]
            text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
                feat, label_pos, self.training_step, train_mode=True)
            return text_pre, test_rem, text_mas, mask_map
        else:
            output, out_length = self.MLM_VRM(
                feat, targets, self.training_step, train_mode=False)
            return output, out_length
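A minimal sketch of how VLHead above is called in its two modes (training returns the VRM/MLM predictions, inference returns flattened logits plus one length per image). The 512-channel 8x32 feature map is an assumption matching the shape printed by ResNet45's own __main__ block, and running it requires paddle plus the PaddleOCR source tree:

import paddle
from ppocr.modeling.heads.rec_visionlan_head import VLHead  # added by this commit

head = VLHead(in_channels=512, out_channels=36, training_step='LA')
feat = paddle.rand([2, 512, 8, 32])   # assumed backbone output (b, c, h, w)

# inference path: targets is unused, returns flattened per-step logits and per-sample lengths
head.eval()
output, out_length = head(feat, targets=None)
print(output.shape, out_length.shape)   # logits over 36 + 1 classes, and one length per image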
ppocr/postprocess/__init__.py — view file @ cf533b65

@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
 from .fce_postprocess import FCEPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
-    SEEDLabelDecode, PRENLabelDecode
+    SEEDLabelDecode, PRENLabelDecode, VLLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess

@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
         'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
         'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
         'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
-        'DistillationSARLabelDecode'
+        'DistillationSARLabelDecode', 'VLLabelDecode'
     ]

     if config['name'] == 'PSEPostProcess':
ppocr/postprocess/rec_postprocess.py — view file @ cf533b65

@@ -27,7 +27,8 @@ class BaseRecLabelDecode(object):
         self.character_str = []
         if character_dict_path is None:
-            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            # self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            self.character_str = "abcdefghijklmnopqrstuvwxyz1234567890"
             dict_character = list(self.character_str)
         else:
             with open(character_dict_path, "rb") as fin:

@@ -752,3 +753,70 @@ class PRENLabelDecode(BaseRecLabelDecode):
             return text
         label = self.decode(label)
         return text, label
+
+
+class VLLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+        super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+            if is_remove_duplicate:
+                selection[1:] = text_index[batch_idx][1:] != text_index[
+                    batch_idx][:-1]
+            for ignored_token in ignored_tokens:
+                selection &= text_index[batch_idx] != ignored_token
+
+            char_list = [
+                self.character[text_id - 1]
+                for text_id in text_index[batch_idx][selection]
+            ]
+            if text_prob is not None:
+                conf_list = text_prob[batch_idx][selection]
+            else:
+                conf_list = [1] * len(selection)
+            if len(conf_list) == 0:
+                conf_list = [0]
+
+            text = ''.join(char_list)
+            result_list.append((text, np.mean(conf_list).tolist()))
+        return result_list
+
+    def __call__(self, preds, label=None, length=None, *args, **kwargs):
+        if len(preds) == 2:  # eval mode
+            net_out, length = preds
+        else:  # train mode
+            net_out = preds[0]
+            length = length
+            net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
+        text = []
+        if not isinstance(net_out, paddle.Tensor):
+            net_out = paddle.to_tensor(net_out, dtype='float32')
+        net_out = F.softmax(net_out, axis=1)
+        for i in range(0, length.shape[0]):
+            preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
+            ) + length[i])].topk(1)[1][:, 0].tolist()
+            preds_text = ''.join([
+                self.character[idx - 1]
+                if idx > 0 and idx <= len(self.character) else ''
+                for idx in preds_idx
+            ])
+            preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
+            ) + length[i])].topk(1)[0][:, 0]
+            preds_prob = paddle.exp(
+                paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
+            text.append((preds_text, preds_prob))
+        if label is None:
+            return text
+        label = self.decode(label)
+        return text, label
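A toy illustration of the 1-based index convention the decode above relies on (index 0 is the padding/end token, characters start at 1); the character string matches the built-in default set used when no dictionary file is given:

import numpy as np

characters = list("abcdefghijklmnopqrstuvwxyz1234567890")  # built-in default set
pred_indices = np.array([16, 1, 4, 4, 12, 5, 0, 0])        # trailing zeros are padding / end

decoded = ''.join(characters[i - 1] for i in pred_indices if i > 0)
print(decoded)   # 'paddle'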
tools/eval.py — view file @ cf533b65

@@ -73,7 +73,7 @@ def main():
         config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])
-    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
     extra_input = False
     if config['Architecture']['algorithm'] == 'Distillation':
         for key in config['Architecture']["Models"]:
tools/export_model.py — view file @ cf533b65

@@ -55,7 +55,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
                 shape=[None, 3, 48, 160], dtype="float32"),
         ]
         model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "SVTR":
+    elif arch_config["algorithm"] in ["SVTR", "VisionLAN"]:
         if arch_config["Head"]["name"] == 'MultiHead':
             other_shape = [
                 paddle.static.InputSpec(
tools/infer/predict_rec.py — view file @ cf533b65

@@ -69,6 +69,12 @@ class TextRecognizer(object):
                 "character_dict_path": args.rec_char_dict_path,
                 "use_space_char": args.use_space_char
             }
+        elif self.rec_algorithm == "VisionLAN":
+            postprocess_params = {
+                'name': 'VLLabelDecode',
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)

@@ -143,6 +149,15 @@ class TextRecognizer(object):
             resized_image /= 0.5
         return resized_image

+    def resize_norm_img_vl(self, img, image_shape):
+        imgC, imgH, imgW = image_shape
+        resized_image = cv2.resize(
+            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+        resized_image = resized_image.astype('float32')
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        return resized_image
+
     def resize_norm_img_srn(self, img, image_shape):
         imgC, imgH, imgW = image_shape

@@ -300,6 +315,11 @@ class TextRecognizer(object):
                     self.rec_image_shape)
                 norm_img = norm_img[np.newaxis, :]
                 norm_img_batch.append(norm_img)
+            elif self.rec_algorithm == "VisionLAN":
+                norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
+                                                   self.rec_image_shape)
+                norm_img = norm_img[np.newaxis, :]
+                norm_img_batch.append(norm_img)
             else:
                 norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                 max_wh_ratio)
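A sketch of the post-process construction that the new VisionLAN branch above performs (requires the PaddleOCR source tree; passing None for the dictionary path falls back to the built-in lowercase character set):

from ppocr.postprocess import build_post_process

postprocess_params = {
    'name': 'VLLabelDecode',
    'character_dict_path': None,   # None falls back to the built-in a-z / 0-9 set
    'use_space_char': False,
}
postprocess_op = build_post_process(postprocess_params)
print(type(postprocess_op).__name__)   # VLLabelDecode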
tools/program.py — view file @ cf533b65

@@ -207,7 +207,7 @@ def train(config,
     model.train()

     use_srn = config['Architecture']['algorithm'] == "SRN"
-    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
     extra_input = False
     if config['Architecture']['algorithm'] == 'Distillation':
         for key in config['Architecture']["Models"]:

@@ -249,7 +249,6 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
             # use amp
             if scaler:
                 with paddle.amp.auto_cast():

@@ -264,7 +263,6 @@ def train(config,
                     preds = model(batch)
                 else:
                     preds = model(images)
             loss = loss_class(preds, batch)
             avg_loss = loss['loss']

@@ -286,6 +284,9 @@ def train(config,
                 ]:  # for multi head loss
                     post_result = post_process_class(
                         preds['ctc'], batch[1])  # for CTC head out
+                elif config['Loss']['name'] in ['VLLoss']:
+                    post_result = post_process_class(preds, batch[1],
+                                                     batch[-1])
                 else:
                     post_result = post_process_class(preds, batch[1])
                 eval_class(post_result, batch)

@@ -307,7 +308,8 @@ def train(config,
             train_stats.update(stats)

             if log_writer is not None and dist.get_rank() == 0:
                 log_writer.log_metrics(
                     metrics=train_stats.get(), prefix="TRAIN", step=global_step)

             if dist.get_rank() == 0 and (
                     (global_step > 0 and global_step % print_batch_step == 0) or

@@ -354,7 +356,8 @@ def train(config,
                 # logger metric
                 if log_writer is not None:
                     log_writer.log_metrics(
                         metrics=cur_metric, prefix="EVAL", step=global_step)

                 if cur_metric[main_indicator] >= best_model_dict[
                         main_indicator]:

@@ -377,11 +380,18 @@ def train(config,
                     logger.info(best_str)
                 # logger best metric
                 if log_writer is not None:
                     log_writer.log_metrics(
                         metrics={
                             "best_{}".format(main_indicator):
                             best_model_dict[main_indicator]
                         },
                         prefix="EVAL",
                         step=global_step)
+                    log_writer.log_model(
+                        is_best=True,
+                        prefix="best_accuracy",
+                        metadata=best_model_dict)

             reader_start = time.time()
         if dist.get_rank() == 0:

@@ -413,7 +423,8 @@ def train(config,
                 epoch=epoch,
                 global_step=global_step)
             if log_writer is not None:
                 log_writer.log_model(
                     is_best=False, prefix='iter_epoch_{}'.format(epoch))

     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))

@@ -451,7 +462,6 @@ def eval(model,
                 preds = model(batch)
             else:
                 preds = model(images)
             batch_numpy = []
             for item in batch:
                 if isinstance(item, paddle.Tensor):

@@ -564,7 +574,8 @@ def preprocess(is_train=False):
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
+        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
+        'VisionLAN'
     ]

     if use_xpu:

@@ -583,9 +594,10 @@ def preprocess(is_train=False):
     if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
         save_model_dir = config['Global']['save_model_dir']
         vdl_writer_path = '{}/vdl/'.format(save_model_dir)
-        log_writer = VDLLogger(save_model_dir)
+        log_writer = VDLLogger(vdl_writer_path)
         loggers.append(log_writer)
     if ('use_wandb' in config['Global'] and
             config['Global']['use_wandb']) or 'wandb' in config:
         save_dir = config['Global']['save_model_dir']
         wandb_writer_path = "{}/wandb".format(save_dir)
         if "wandb" in config: