Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
9c813bb3
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9c813bb3
编写于
7月 01, 2021
作者:
M
MissPenguin
提交者:
GitHub
7月 01, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3218 from WenmuZhou/copy_paste
add copy paste
上级
5661c686
7d480546
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
238 addition
and
38 deletion
+238
-38
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+1
-0
ppocr/data/imaug/copy_paste.py
ppocr/data/imaug/copy_paste.py
+164
-0
ppocr/data/simple_dataset.py
ppocr/data/simple_dataset.py
+35
-2
tools/infer/predict_system.py
tools/infer/predict_system.py
+2
-35
tools/infer/utility.py
tools/infer/utility.py
+36
-1
未找到文件。
ppocr/data/imaug/__init__.py
浏览文件 @
9c813bb3
...
@@ -23,6 +23,7 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop
...
@@ -23,6 +23,7 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop
from
.rec_img_aug
import
RecAug
,
RecResizeImg
,
ClsResizeImg
,
SRNRecResizeImg
from
.rec_img_aug
import
RecAug
,
RecResizeImg
,
ClsResizeImg
,
SRNRecResizeImg
from
.randaugment
import
RandAugment
from
.randaugment
import
RandAugment
from
.copy_paste
import
CopyPaste
from
.operators
import
*
from
.operators
import
*
from
.label_ops
import
*
from
.label_ops
import
*
...
...
ppocr/data/imaug/copy_paste.py
0 → 100644
浏览文件 @
9c813bb3
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
cv2
import
random
import
numpy
as
np
from
PIL
import
Image
from
shapely.geometry
import
Polygon
from
ppocr.data.imaug.iaa_augment
import
IaaAugment
from
ppocr.data.imaug.random_crop_data
import
is_poly_outside_rect
from
tools.infer.utility
import
get_rotate_crop_image
class
CopyPaste
(
object
):
def
__init__
(
self
,
objects_paste_ratio
=
0.2
,
limit_paste
=
True
,
**
kwargs
):
self
.
ext_data_num
=
1
self
.
objects_paste_ratio
=
objects_paste_ratio
self
.
limit_paste
=
limit_paste
augmenter_args
=
[{
'type'
:
'Resize'
,
'args'
:
{
'size'
:
[
0.5
,
3
]}}]
self
.
aug
=
IaaAugment
(
augmenter_args
)
def
__call__
(
self
,
data
):
src_img
=
data
[
'image'
]
src_polys
=
data
[
'polys'
].
tolist
()
src_ignores
=
data
[
'ignore_tags'
].
tolist
()
ext_data
=
data
[
'ext_data'
][
0
]
ext_image
=
ext_data
[
'image'
]
ext_polys
=
ext_data
[
'polys'
]
ext_ignores
=
ext_data
[
'ignore_tags'
]
indexs
=
[
i
for
i
in
range
(
len
(
ext_ignores
))
if
not
ext_ignores
[
i
]]
select_num
=
max
(
1
,
min
(
int
(
self
.
objects_paste_ratio
*
len
(
ext_polys
)),
30
))
random
.
shuffle
(
indexs
)
select_idxs
=
indexs
[:
select_num
]
select_polys
=
ext_polys
[
select_idxs
]
select_ignores
=
ext_ignores
[
select_idxs
]
src_img
=
cv2
.
cvtColor
(
src_img
,
cv2
.
COLOR_BGR2RGB
)
ext_image
=
cv2
.
cvtColor
(
ext_image
,
cv2
.
COLOR_BGR2RGB
)
src_img
=
Image
.
fromarray
(
src_img
).
convert
(
'RGBA'
)
for
poly
,
tag
in
zip
(
select_polys
,
select_ignores
):
box_img
=
get_rotate_crop_image
(
ext_image
,
poly
)
src_img
,
box
=
self
.
paste_img
(
src_img
,
box_img
,
src_polys
)
if
box
is
not
None
:
src_polys
.
append
(
box
)
src_ignores
.
append
(
tag
)
src_img
=
cv2
.
cvtColor
(
np
.
array
(
src_img
),
cv2
.
COLOR_RGB2BGR
)
h
,
w
=
src_img
.
shape
[:
2
]
src_polys
=
np
.
array
(
src_polys
)
src_polys
[:,
:,
0
]
=
np
.
clip
(
src_polys
[:,
:,
0
],
0
,
w
)
src_polys
[:,
:,
1
]
=
np
.
clip
(
src_polys
[:,
:,
1
],
0
,
h
)
data
[
'image'
]
=
src_img
data
[
'polys'
]
=
src_polys
data
[
'ignore_tags'
]
=
np
.
array
(
src_ignores
)
return
data
def
paste_img
(
self
,
src_img
,
box_img
,
src_polys
):
box_img_pil
=
Image
.
fromarray
(
box_img
).
convert
(
'RGBA'
)
src_w
,
src_h
=
src_img
.
size
box_w
,
box_h
=
box_img_pil
.
size
if
box_w
>
src_w
or
box_h
>
src_h
:
return
src_img
,
None
angle
=
np
.
random
.
randint
(
0
,
360
)
box
=
np
.
array
([[[
0
,
0
],
[
box_w
,
0
],
[
box_w
,
box_h
],
[
0
,
box_h
]]])
box
=
rotate_bbox
(
box_img
,
box
,
angle
)[
0
]
paste_x
,
paste_y
=
self
.
select_coord
(
src_polys
,
box
,
src_w
-
box_w
,
src_h
-
box_h
)
if
paste_x
is
None
:
return
src_img
,
None
box
[:,
0
]
+=
paste_x
box
[:,
1
]
+=
paste_y
box_img_pil
=
box_img_pil
.
rotate
(
angle
,
expand
=
1
)
r
,
g
,
b
,
A
=
box_img_pil
.
split
()
src_img
.
paste
(
box_img_pil
,
(
paste_x
,
paste_y
),
mask
=
A
)
return
src_img
,
box
def
select_coord
(
self
,
src_polys
,
box
,
endx
,
endy
):
if
self
.
limit_paste
:
xmin
,
ymin
,
xmax
,
ymax
=
box
[:,
0
].
min
(),
box
[:,
1
].
min
(
),
box
[:,
0
].
max
(),
box
[:,
1
].
max
()
for
_
in
range
(
50
):
paste_x
=
random
.
randint
(
0
,
endx
)
paste_y
=
random
.
randint
(
0
,
endy
)
xmin1
=
xmin
+
paste_x
xmax1
=
xmax
+
paste_x
ymin1
=
ymin
+
paste_y
ymax1
=
ymax
+
paste_y
num_poly_in_rect
=
0
for
poly
in
src_polys
:
if
not
is_poly_outside_rect
(
poly
,
xmax1
,
ymin1
,
xmax1
-
xmin1
,
ymax1
-
ymin1
):
num_poly_in_rect
+=
1
break
if
num_poly_in_rect
==
0
:
return
paste_x
,
paste_y
return
None
,
None
else
:
paste_x
=
random
.
randint
(
0
,
endx
)
paste_y
=
random
.
randint
(
0
,
endy
)
return
paste_x
,
paste_y
def
get_union
(
pD
,
pG
):
return
Polygon
(
pD
).
union
(
Polygon
(
pG
)).
area
def
get_intersection_over_union
(
pD
,
pG
):
return
get_intersection
(
pD
,
pG
)
/
get_union
(
pD
,
pG
)
def
get_intersection
(
pD
,
pG
):
return
Polygon
(
pD
).
intersection
(
Polygon
(
pG
)).
area
def
rotate_bbox
(
img
,
text_polys
,
angle
,
scale
=
1
):
"""
from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
Args:
img: np.ndarray
text_polys: np.ndarray N*4*2
angle: int
scale: int
Returns:
"""
w
=
img
.
shape
[
1
]
h
=
img
.
shape
[
0
]
rangle
=
np
.
deg2rad
(
angle
)
nw
=
(
abs
(
np
.
sin
(
rangle
)
*
h
)
+
abs
(
np
.
cos
(
rangle
)
*
w
))
nh
=
(
abs
(
np
.
cos
(
rangle
)
*
h
)
+
abs
(
np
.
sin
(
rangle
)
*
w
))
rot_mat
=
cv2
.
getRotationMatrix2D
((
nw
*
0.5
,
nh
*
0.5
),
angle
,
scale
)
rot_move
=
np
.
dot
(
rot_mat
,
np
.
array
([(
nw
-
w
)
*
0.5
,
(
nh
-
h
)
*
0.5
,
0
]))
rot_mat
[
0
,
2
]
+=
rot_move
[
0
]
rot_mat
[
1
,
2
]
+=
rot_move
[
1
]
# ---------------------- rotate box ----------------------
rot_text_polys
=
list
()
for
bbox
in
text_polys
:
point1
=
np
.
dot
(
rot_mat
,
np
.
array
([
bbox
[
0
,
0
],
bbox
[
0
,
1
],
1
]))
point2
=
np
.
dot
(
rot_mat
,
np
.
array
([
bbox
[
1
,
0
],
bbox
[
1
,
1
],
1
]))
point3
=
np
.
dot
(
rot_mat
,
np
.
array
([
bbox
[
2
,
0
],
bbox
[
2
,
1
],
1
]))
point4
=
np
.
dot
(
rot_mat
,
np
.
array
([
bbox
[
3
,
0
],
bbox
[
3
,
1
],
1
]))
rot_text_polys
.
append
([
point1
,
point2
,
point3
,
point4
])
return
np
.
array
(
rot_text_polys
,
dtype
=
np
.
float32
)
ppocr/data/simple_dataset.py
浏览文件 @
9c813bb3
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
import
numpy
as
np
import
numpy
as
np
import
os
import
os
import
random
import
random
import
traceback
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
from
.imaug
import
transform
,
create_operators
from
.imaug
import
transform
,
create_operators
...
@@ -69,6 +70,36 @@ class SimpleDataSet(Dataset):
...
@@ -69,6 +70,36 @@ class SimpleDataSet(Dataset):
random
.
shuffle
(
self
.
data_lines
)
random
.
shuffle
(
self
.
data_lines
)
return
return
def
get_ext_data
(
self
):
ext_data_num
=
0
for
op
in
self
.
ops
:
if
hasattr
(
op
,
'ext_data_num'
):
ext_data_num
=
getattr
(
op
,
'ext_data_num'
)
break
load_data_ops
=
self
.
ops
[:
2
]
ext_data
=
[]
while
len
(
ext_data
)
<
ext_data_num
:
file_idx
=
self
.
data_idx_order_list
[
np
.
random
.
randint
(
self
.
__len__
(
))]
data_line
=
self
.
data_lines
[
file_idx
]
data_line
=
data_line
.
decode
(
'utf-8'
)
substr
=
data_line
.
strip
(
"
\n
"
).
split
(
self
.
delimiter
)
file_name
=
substr
[
0
]
label
=
substr
[
1
]
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
file_name
)
data
=
{
'img_path'
:
img_path
,
'label'
:
label
}
if
not
os
.
path
.
exists
(
img_path
):
continue
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
img
=
f
.
read
()
data
[
'image'
]
=
img
data
=
transform
(
data
,
load_data_ops
)
if
data
is
None
:
continue
ext_data
.
append
(
data
)
return
ext_data
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
file_idx
=
self
.
data_idx_order_list
[
idx
]
file_idx
=
self
.
data_idx_order_list
[
idx
]
data_line
=
self
.
data_lines
[
file_idx
]
data_line
=
self
.
data_lines
[
file_idx
]
...
@@ -84,11 +115,13 @@ class SimpleDataSet(Dataset):
...
@@ -84,11 +115,13 @@ class SimpleDataSet(Dataset):
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
img
=
f
.
read
()
img
=
f
.
read
()
data
[
'image'
]
=
img
data
[
'image'
]
=
img
data
[
'ext_data'
]
=
self
.
get_ext_data
()
outs
=
transform
(
data
,
self
.
ops
)
outs
=
transform
(
data
,
self
.
ops
)
except
Exception
as
e
:
except
:
error_meg
=
traceback
.
format_exc
()
self
.
logger
.
error
(
self
.
logger
.
error
(
"When parsing line {}, error happened with msg: {}"
.
format
(
"When parsing line {}, error happened with msg: {}"
.
format
(
data_line
,
e
))
data_line
,
e
rror_meg
))
outs
=
None
outs
=
None
if
outs
is
None
:
if
outs
is
None
:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
# during evaluation, we should fix the idx to get same results for many times of evaluation.
...
...
tools/infer/predict_system.py
浏览文件 @
9c813bb3
...
@@ -33,7 +33,7 @@ import tools.infer.predict_det as predict_det
...
@@ -33,7 +33,7 @@ import tools.infer.predict_det as predict_det
import
tools.infer.predict_cls
as
predict_cls
import
tools.infer.predict_cls
as
predict_cls
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.logging
import
get_logger
from
tools.infer.utility
import
draw_ocr_box_txt
from
tools.infer.utility
import
draw_ocr_box_txt
,
get_rotate_crop_image
logger
=
get_logger
()
logger
=
get_logger
()
...
@@ -49,39 +49,6 @@ class TextSystem(object):
...
@@ -49,39 +49,6 @@ class TextSystem(object):
if
self
.
use_angle_cls
:
if
self
.
use_angle_cls
:
self
.
text_classifier
=
predict_cls
.
TextClassifier
(
args
)
self
.
text_classifier
=
predict_cls
.
TextClassifier
(
args
)
def
get_rotate_crop_image
(
self
,
img
,
points
):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
img_crop_width
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
1
]),
np
.
linalg
.
norm
(
points
[
2
]
-
points
[
3
])))
img_crop_height
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
3
]),
np
.
linalg
.
norm
(
points
[
1
]
-
points
[
2
])))
pts_std
=
np
.
float32
([[
0
,
0
],
[
img_crop_width
,
0
],
[
img_crop_width
,
img_crop_height
],
[
0
,
img_crop_height
]])
M
=
cv2
.
getPerspectiveTransform
(
points
,
pts_std
)
dst_img
=
cv2
.
warpPerspective
(
img
,
M
,
(
img_crop_width
,
img_crop_height
),
borderMode
=
cv2
.
BORDER_REPLICATE
,
flags
=
cv2
.
INTER_CUBIC
)
dst_img_height
,
dst_img_width
=
dst_img
.
shape
[
0
:
2
]
if
dst_img_height
*
1.0
/
dst_img_width
>=
1.5
:
dst_img
=
np
.
rot90
(
dst_img
)
return
dst_img
def
print_draw_crop_rec_res
(
self
,
img_crop_list
,
rec_res
):
def
print_draw_crop_rec_res
(
self
,
img_crop_list
,
rec_res
):
bbox_num
=
len
(
img_crop_list
)
bbox_num
=
len
(
img_crop_list
)
for
bno
in
range
(
bbox_num
):
for
bno
in
range
(
bbox_num
):
...
@@ -102,7 +69,7 @@ class TextSystem(object):
...
@@ -102,7 +69,7 @@ class TextSystem(object):
for
bno
in
range
(
len
(
dt_boxes
)):
for
bno
in
range
(
len
(
dt_boxes
)):
tmp_box
=
copy
.
deepcopy
(
dt_boxes
[
bno
])
tmp_box
=
copy
.
deepcopy
(
dt_boxes
[
bno
])
img_crop
=
self
.
get_rotate_crop_image
(
ori_im
,
tmp_box
)
img_crop
=
get_rotate_crop_image
(
ori_im
,
tmp_box
)
img_crop_list
.
append
(
img_crop
)
img_crop_list
.
append
(
img_crop
)
if
self
.
use_angle_cls
and
cls
:
if
self
.
use_angle_cls
and
cls
:
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
...
...
tools/infer/utility.py
浏览文件 @
9c813bb3
...
@@ -241,7 +241,7 @@ def create_predictor(args, mode, logger):
...
@@ -241,7 +241,7 @@ def create_predictor(args, mode, logger):
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
if
mode
==
'table'
:
if
mode
==
'table'
:
config
.
delete_pass
(
"fc_fuse_pass"
)
# not supported for table
config
.
delete_pass
(
"fc_fuse_pass"
)
# not supported for table
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_ir_optim
(
True
)
config
.
switch_ir_optim
(
True
)
...
@@ -506,5 +506,40 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
...
@@ -506,5 +506,40 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
return
image
return
image
def
get_rotate_crop_image
(
img
,
points
):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert
len
(
points
)
==
4
,
"shape of points must be 4*2"
img_crop_width
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
1
]),
np
.
linalg
.
norm
(
points
[
2
]
-
points
[
3
])))
img_crop_height
=
int
(
max
(
np
.
linalg
.
norm
(
points
[
0
]
-
points
[
3
]),
np
.
linalg
.
norm
(
points
[
1
]
-
points
[
2
])))
pts_std
=
np
.
float32
([[
0
,
0
],
[
img_crop_width
,
0
],
[
img_crop_width
,
img_crop_height
],
[
0
,
img_crop_height
]])
M
=
cv2
.
getPerspectiveTransform
(
points
,
pts_std
)
dst_img
=
cv2
.
warpPerspective
(
img
,
M
,
(
img_crop_width
,
img_crop_height
),
borderMode
=
cv2
.
BORDER_REPLICATE
,
flags
=
cv2
.
INTER_CUBIC
)
dst_img_height
,
dst_img_width
=
dst_img
.
shape
[
0
:
2
]
if
dst_img_height
*
1.0
/
dst_img_width
>=
1.5
:
dst_img
=
np
.
rot90
(
dst_img
)
return
dst_img
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
pass
pass
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录