Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
3ea09ad3
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3ea09ad3
编写于
1月 21, 2021
作者:
D
Double_V
提交者:
GitHub
1月 21, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1785 from LDOUBLEV/trt_cpp
fix sast process
上级
18669cc3
16bd2dd0
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
188 addition
and
103 deletion
+188
-103
ppocr/data/imaug/sast_process.py
ppocr/data/imaug/sast_process.py
+188
-103
未找到文件。
ppocr/data/imaug/sast_process.py
浏览文件 @
3ea09ad3
...
@@ -24,11 +24,11 @@ __all__ = ['SASTProcessTrain']
...
@@ -24,11 +24,11 @@ __all__ = ['SASTProcessTrain']
class
SASTProcessTrain
(
object
):
class
SASTProcessTrain
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
image_shape
=
[
512
,
512
],
image_shape
=
[
512
,
512
],
min_crop_size
=
24
,
min_crop_size
=
24
,
min_crop_side_ratio
=
0.3
,
min_crop_side_ratio
=
0.3
,
min_text_size
=
10
,
min_text_size
=
10
,
max_text_size
=
512
,
max_text_size
=
512
,
**
kwargs
):
**
kwargs
):
self
.
input_size
=
image_shape
[
1
]
self
.
input_size
=
image_shape
[
1
]
self
.
min_crop_size
=
min_crop_size
self
.
min_crop_size
=
min_crop_size
...
@@ -42,12 +42,10 @@ class SASTProcessTrain(object):
...
@@ -42,12 +42,10 @@ class SASTProcessTrain(object):
:param poly:
:param poly:
:return:
:return:
"""
"""
edge
=
[
edge
=
[(
poly
[
1
][
0
]
-
poly
[
0
][
0
])
*
(
poly
[
1
][
1
]
+
poly
[
0
][
1
]),
(
poly
[
1
][
0
]
-
poly
[
0
][
0
])
*
(
poly
[
1
][
1
]
+
poly
[
0
][
1
]),
(
poly
[
2
][
0
]
-
poly
[
1
][
0
])
*
(
poly
[
2
][
1
]
+
poly
[
1
][
1
]),
(
poly
[
2
][
0
]
-
poly
[
1
][
0
])
*
(
poly
[
2
][
1
]
+
poly
[
1
][
1
]),
(
poly
[
3
][
0
]
-
poly
[
2
][
0
])
*
(
poly
[
3
][
1
]
+
poly
[
2
][
1
]),
(
poly
[
3
][
0
]
-
poly
[
2
][
0
])
*
(
poly
[
3
][
1
]
+
poly
[
2
][
1
]),
(
poly
[
0
][
0
]
-
poly
[
3
][
0
])
*
(
poly
[
0
][
1
]
+
poly
[
3
][
1
])]
(
poly
[
0
][
0
]
-
poly
[
3
][
0
])
*
(
poly
[
0
][
1
]
+
poly
[
3
][
1
])
]
return
np
.
sum
(
edge
)
/
2.
return
np
.
sum
(
edge
)
/
2.
def
gen_quad_from_poly
(
self
,
poly
):
def
gen_quad_from_poly
(
self
,
poly
):
...
@@ -57,7 +55,8 @@ class SASTProcessTrain(object):
...
@@ -57,7 +55,8 @@ class SASTProcessTrain(object):
point_num
=
poly
.
shape
[
0
]
point_num
=
poly
.
shape
[
0
]
min_area_quad
=
np
.
zeros
((
4
,
2
),
dtype
=
np
.
float32
)
min_area_quad
=
np
.
zeros
((
4
,
2
),
dtype
=
np
.
float32
)
if
True
:
if
True
:
rect
=
cv2
.
minAreaRect
(
poly
.
astype
(
np
.
int32
))
# (center (x,y), (width, height), angle of rotation)
rect
=
cv2
.
minAreaRect
(
poly
.
astype
(
np
.
int32
))
# (center (x,y), (width, height), angle of rotation)
center_point
=
rect
[
0
]
center_point
=
rect
[
0
]
box
=
np
.
array
(
cv2
.
boxPoints
(
rect
))
box
=
np
.
array
(
cv2
.
boxPoints
(
rect
))
...
@@ -102,23 +101,33 @@ class SASTProcessTrain(object):
...
@@ -102,23 +101,33 @@ class SASTProcessTrain(object):
if
p_area
>
0
:
if
p_area
>
0
:
if
tag
==
False
:
if
tag
==
False
:
print
(
'poly in wrong direction'
)
print
(
'poly in wrong direction'
)
tag
=
True
# reversed cases should be ignore
tag
=
True
# reversed cases should be ignore
poly
=
poly
[(
0
,
15
,
14
,
13
,
12
,
11
,
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
),
:]
poly
=
poly
[(
0
,
15
,
14
,
13
,
12
,
11
,
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
),
:]
quad
=
quad
[(
0
,
3
,
2
,
1
),
:]
quad
=
quad
[(
0
,
3
,
2
,
1
),
:]
len_w
=
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
1
])
+
np
.
linalg
.
norm
(
quad
[
3
]
-
quad
[
2
])
len_w
=
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
1
])
+
np
.
linalg
.
norm
(
quad
[
3
]
-
len_h
=
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
3
])
+
np
.
linalg
.
norm
(
quad
[
1
]
-
quad
[
2
])
quad
[
2
])
len_h
=
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
3
])
+
np
.
linalg
.
norm
(
quad
[
1
]
-
quad
[
2
])
hv_tag
=
1
hv_tag
=
1
if
len_w
*
2.0
<
len_h
:
if
len_w
*
2.0
<
len_h
:
hv_tag
=
0
hv_tag
=
0
validated_polys
.
append
(
poly
)
validated_polys
.
append
(
poly
)
validated_tags
.
append
(
tag
)
validated_tags
.
append
(
tag
)
hv_tags
.
append
(
hv_tag
)
hv_tags
.
append
(
hv_tag
)
return
np
.
array
(
validated_polys
),
np
.
array
(
validated_tags
),
np
.
array
(
hv_tags
)
return
np
.
array
(
validated_polys
),
np
.
array
(
validated_tags
),
np
.
array
(
hv_tags
)
def
crop_area
(
self
,
im
,
polys
,
tags
,
hv_tags
,
crop_background
=
False
,
max_tries
=
25
):
def
crop_area
(
self
,
im
,
polys
,
tags
,
hv_tags
,
crop_background
=
False
,
max_tries
=
25
):
"""
"""
make random crop from the input image
make random crop from the input image
:param im:
:param im:
...
@@ -137,10 +146,10 @@ class SASTProcessTrain(object):
...
@@ -137,10 +146,10 @@ class SASTProcessTrain(object):
poly
=
np
.
round
(
poly
,
decimals
=
0
).
astype
(
np
.
int32
)
poly
=
np
.
round
(
poly
,
decimals
=
0
).
astype
(
np
.
int32
)
minx
=
np
.
min
(
poly
[:,
0
])
minx
=
np
.
min
(
poly
[:,
0
])
maxx
=
np
.
max
(
poly
[:,
0
])
maxx
=
np
.
max
(
poly
[:,
0
])
w_array
[
minx
+
pad_w
:
maxx
+
pad_w
]
=
1
w_array
[
minx
+
pad_w
:
maxx
+
pad_w
]
=
1
miny
=
np
.
min
(
poly
[:,
1
])
miny
=
np
.
min
(
poly
[:,
1
])
maxy
=
np
.
max
(
poly
[:,
1
])
maxy
=
np
.
max
(
poly
[:,
1
])
h_array
[
miny
+
pad_h
:
maxy
+
pad_h
]
=
1
h_array
[
miny
+
pad_h
:
maxy
+
pad_h
]
=
1
# ensure the cropped area not across a text
# ensure the cropped area not across a text
h_axis
=
np
.
where
(
h_array
==
0
)[
0
]
h_axis
=
np
.
where
(
h_array
==
0
)[
0
]
w_axis
=
np
.
where
(
w_array
==
0
)[
0
]
w_axis
=
np
.
where
(
w_array
==
0
)[
0
]
...
@@ -166,17 +175,18 @@ class SASTProcessTrain(object):
...
@@ -166,17 +175,18 @@ class SASTProcessTrain(object):
if
polys
.
shape
[
0
]
!=
0
:
if
polys
.
shape
[
0
]
!=
0
:
poly_axis_in_area
=
(
polys
[:,
:,
0
]
>=
xmin
)
&
(
polys
[:,
:,
0
]
<=
xmax
)
\
poly_axis_in_area
=
(
polys
[:,
:,
0
]
>=
xmin
)
&
(
polys
[:,
:,
0
]
<=
xmax
)
\
&
(
polys
[:,
:,
1
]
>=
ymin
)
&
(
polys
[:,
:,
1
]
<=
ymax
)
&
(
polys
[:,
:,
1
]
>=
ymin
)
&
(
polys
[:,
:,
1
]
<=
ymax
)
selected_polys
=
np
.
where
(
np
.
sum
(
poly_axis_in_area
,
axis
=
1
)
==
4
)[
0
]
selected_polys
=
np
.
where
(
np
.
sum
(
poly_axis_in_area
,
axis
=
1
)
==
4
)[
0
]
else
:
else
:
selected_polys
=
[]
selected_polys
=
[]
if
len
(
selected_polys
)
==
0
:
if
len
(
selected_polys
)
==
0
:
# no text in this area
# no text in this area
if
crop_background
:
if
crop_background
:
return
im
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
,
:],
\
return
im
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
,
:],
\
polys
[
selected_polys
],
tags
[
selected_polys
],
hv_tags
[
selected_polys
]
,
txts
polys
[
selected_polys
],
tags
[
selected_polys
],
hv_tags
[
selected_polys
]
else
:
else
:
continue
continue
im
=
im
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
,
:]
im
=
im
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
,
:]
polys
=
polys
[
selected_polys
]
polys
=
polys
[
selected_polys
]
tags
=
tags
[
selected_polys
]
tags
=
tags
[
selected_polys
]
hv_tags
=
hv_tags
[
selected_polys
]
hv_tags
=
hv_tags
[
selected_polys
]
...
@@ -192,18 +202,28 @@ class SASTProcessTrain(object):
...
@@ -192,18 +202,28 @@ class SASTProcessTrain(object):
width_list
=
[]
width_list
=
[]
height_list
=
[]
height_list
=
[]
for
quad
in
poly_quads
:
for
quad
in
poly_quads
:
quad_w
=
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
1
])
+
np
.
linalg
.
norm
(
quad
[
2
]
-
quad
[
3
]))
/
2.0
quad_w
=
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
1
])
+
quad_h
=
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
3
])
+
np
.
linalg
.
norm
(
quad
[
2
]
-
quad
[
1
]))
/
2.0
np
.
linalg
.
norm
(
quad
[
2
]
-
quad
[
3
]))
/
2.0
quad_h
=
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
3
])
+
np
.
linalg
.
norm
(
quad
[
2
]
-
quad
[
1
]))
/
2.0
width_list
.
append
(
quad_w
)
width_list
.
append
(
quad_w
)
height_list
.
append
(
quad_h
)
height_list
.
append
(
quad_h
)
norm_width
=
max
(
sum
(
width_list
)
/
(
len
(
width_list
)
+
1e-6
),
1.0
)
norm_width
=
max
(
sum
(
width_list
)
/
(
len
(
width_list
)
+
1e-6
),
1.0
)
average_height
=
max
(
sum
(
height_list
)
/
(
len
(
height_list
)
+
1e-6
),
1.0
)
average_height
=
max
(
sum
(
height_list
)
/
(
len
(
height_list
)
+
1e-6
),
1.0
)
for
quad
in
poly_quads
:
for
quad
in
poly_quads
:
direct_vector_full
=
((
quad
[
1
]
+
quad
[
2
])
-
(
quad
[
0
]
+
quad
[
3
]))
/
2.0
direct_vector_full
=
(
direct_vector
=
direct_vector_full
/
(
np
.
linalg
.
norm
(
direct_vector_full
)
+
1e-6
)
*
norm_width
(
quad
[
1
]
+
quad
[
2
])
-
(
quad
[
0
]
+
quad
[
3
]))
/
2.0
direction_label
=
tuple
(
map
(
float
,
[
direct_vector
[
0
],
direct_vector
[
1
],
1.0
/
(
average_height
+
1e-6
)]))
direct_vector
=
direct_vector_full
/
(
cv2
.
fillPoly
(
direction_map
,
quad
.
round
().
astype
(
np
.
int32
)[
np
.
newaxis
,
:,
:],
direction_label
)
np
.
linalg
.
norm
(
direct_vector_full
)
+
1e-6
)
*
norm_width
direction_label
=
tuple
(
map
(
float
,
[
direct_vector
[
0
],
direct_vector
[
1
],
1.0
/
(
average_height
+
1e-6
)
]))
cv2
.
fillPoly
(
direction_map
,
quad
.
round
().
astype
(
np
.
int32
)[
np
.
newaxis
,
:,
:],
direction_label
)
return
direction_map
return
direction_map
def
calculate_average_height
(
self
,
poly_quads
):
def
calculate_average_height
(
self
,
poly_quads
):
...
@@ -211,13 +231,19 @@ class SASTProcessTrain(object):
...
@@ -211,13 +231,19 @@ class SASTProcessTrain(object):
"""
"""
height_list
=
[]
height_list
=
[]
for
quad
in
poly_quads
:
for
quad
in
poly_quads
:
quad_h
=
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
3
])
+
np
.
linalg
.
norm
(
quad
[
2
]
-
quad
[
1
]))
/
2.0
quad_h
=
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
3
])
+
np
.
linalg
.
norm
(
quad
[
2
]
-
quad
[
1
]))
/
2.0
height_list
.
append
(
quad_h
)
height_list
.
append
(
quad_h
)
average_height
=
max
(
sum
(
height_list
)
/
len
(
height_list
),
1.0
)
average_height
=
max
(
sum
(
height_list
)
/
len
(
height_list
),
1.0
)
return
average_height
return
average_height
def
generate_tcl_label
(
self
,
hw
,
polys
,
tags
,
ds_ratio
,
def
generate_tcl_label
(
self
,
tcl_ratio
=
0.3
,
shrink_ratio_of_width
=
0.15
):
hw
,
polys
,
tags
,
ds_ratio
,
tcl_ratio
=
0.3
,
shrink_ratio_of_width
=
0.15
):
"""
"""
Generate polygon.
Generate polygon.
"""
"""
...
@@ -225,21 +251,30 @@ class SASTProcessTrain(object):
...
@@ -225,21 +251,30 @@ class SASTProcessTrain(object):
h
,
w
=
int
(
h
*
ds_ratio
),
int
(
w
*
ds_ratio
)
h
,
w
=
int
(
h
*
ds_ratio
),
int
(
w
*
ds_ratio
)
polys
=
polys
*
ds_ratio
polys
=
polys
*
ds_ratio
score_map
=
np
.
zeros
((
h
,
w
,),
dtype
=
np
.
float32
)
score_map
=
np
.
zeros
(
(
h
,
w
,
),
dtype
=
np
.
float32
)
tbo_map
=
np
.
zeros
((
h
,
w
,
5
),
dtype
=
np
.
float32
)
tbo_map
=
np
.
zeros
((
h
,
w
,
5
),
dtype
=
np
.
float32
)
training_mask
=
np
.
ones
((
h
,
w
,),
dtype
=
np
.
float32
)
training_mask
=
np
.
ones
(
direction_map
=
np
.
ones
((
h
,
w
,
3
))
*
np
.
array
([
0
,
0
,
1
]).
reshape
([
1
,
1
,
3
]).
astype
(
np
.
float32
)
(
h
,
w
,
),
dtype
=
np
.
float32
)
direction_map
=
np
.
ones
((
h
,
w
,
3
))
*
np
.
array
([
0
,
0
,
1
]).
reshape
(
[
1
,
1
,
3
]).
astype
(
np
.
float32
)
for
poly_idx
,
poly_tag
in
enumerate
(
zip
(
polys
,
tags
)):
for
poly_idx
,
poly_tag
in
enumerate
(
zip
(
polys
,
tags
)):
poly
=
poly_tag
[
0
]
poly
=
poly_tag
[
0
]
tag
=
poly_tag
[
1
]
tag
=
poly_tag
[
1
]
# generate min_area_quad
# generate min_area_quad
min_area_quad
,
center_point
=
self
.
gen_min_area_quad_from_poly
(
poly
)
min_area_quad
,
center_point
=
self
.
gen_min_area_quad_from_poly
(
poly
)
min_area_quad_h
=
0.5
*
(
np
.
linalg
.
norm
(
min_area_quad
[
0
]
-
min_area_quad
[
3
])
+
min_area_quad_h
=
0.5
*
(
np
.
linalg
.
norm
(
min_area_quad
[
1
]
-
min_area_quad
[
2
]))
np
.
linalg
.
norm
(
min_area_quad
[
0
]
-
min_area_quad
[
3
])
+
min_area_quad_w
=
0.5
*
(
np
.
linalg
.
norm
(
min_area_quad
[
0
]
-
min_area_quad
[
1
])
+
np
.
linalg
.
norm
(
min_area_quad
[
1
]
-
min_area_quad
[
2
]))
np
.
linalg
.
norm
(
min_area_quad
[
2
]
-
min_area_quad
[
3
]))
min_area_quad_w
=
0.5
*
(
np
.
linalg
.
norm
(
min_area_quad
[
0
]
-
min_area_quad
[
1
])
+
np
.
linalg
.
norm
(
min_area_quad
[
2
]
-
min_area_quad
[
3
]))
if
min
(
min_area_quad_h
,
min_area_quad_w
)
<
self
.
min_text_size
*
ds_ratio
\
if
min
(
min_area_quad_h
,
min_area_quad_w
)
<
self
.
min_text_size
*
ds_ratio
\
or
min
(
min_area_quad_h
,
min_area_quad_w
)
>
self
.
max_text_size
*
ds_ratio
:
or
min
(
min_area_quad_h
,
min_area_quad_w
)
>
self
.
max_text_size
*
ds_ratio
:
...
@@ -247,25 +282,37 @@ class SASTProcessTrain(object):
...
@@ -247,25 +282,37 @@ class SASTProcessTrain(object):
if
tag
:
if
tag
:
# continue
# continue
cv2
.
fillPoly
(
training_mask
,
poly
.
astype
(
np
.
int32
)[
np
.
newaxis
,
:,
:],
0.15
)
cv2
.
fillPoly
(
training_mask
,
poly
.
astype
(
np
.
int32
)[
np
.
newaxis
,
:,
:],
0.15
)
else
:
else
:
tcl_poly
=
self
.
poly2tcl
(
poly
,
tcl_ratio
)
tcl_poly
=
self
.
poly2tcl
(
poly
,
tcl_ratio
)
tcl_quads
=
self
.
poly2quads
(
tcl_poly
)
tcl_quads
=
self
.
poly2quads
(
tcl_poly
)
poly_quads
=
self
.
poly2quads
(
poly
)
poly_quads
=
self
.
poly2quads
(
poly
)
# stcl map
# stcl map
stcl_quads
,
quad_index
=
self
.
shrink_poly_along_width
(
tcl_quads
,
shrink_ratio_of_width
=
shrink_ratio_of_width
,
stcl_quads
,
quad_index
=
self
.
shrink_poly_along_width
(
expand_height_ratio
=
1.0
/
tcl_ratio
)
tcl_quads
,
shrink_ratio_of_width
=
shrink_ratio_of_width
,
expand_height_ratio
=
1.0
/
tcl_ratio
)
# generate tcl map
# generate tcl map
cv2
.
fillPoly
(
score_map
,
np
.
round
(
stcl_quads
).
astype
(
np
.
int32
),
1.0
)
cv2
.
fillPoly
(
score_map
,
np
.
round
(
stcl_quads
).
astype
(
np
.
int32
),
1.0
)
# generate tbo map
# generate tbo map
for
idx
,
quad
in
enumerate
(
stcl_quads
):
for
idx
,
quad
in
enumerate
(
stcl_quads
):
quad_mask
=
np
.
zeros
((
h
,
w
),
dtype
=
np
.
float32
)
quad_mask
=
np
.
zeros
((
h
,
w
),
dtype
=
np
.
float32
)
quad_mask
=
cv2
.
fillPoly
(
quad_mask
,
np
.
round
(
quad
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
1.0
)
quad_mask
=
cv2
.
fillPoly
(
tbo_map
=
self
.
gen_quad_tbo
(
poly_quads
[
quad_index
[
idx
]],
quad_mask
,
tbo_map
)
quad_mask
,
np
.
round
(
quad
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
1.0
)
tbo_map
=
self
.
gen_quad_tbo
(
poly_quads
[
quad_index
[
idx
]],
quad_mask
,
tbo_map
)
return
score_map
,
tbo_map
,
training_mask
return
score_map
,
tbo_map
,
training_mask
def
generate_tvo_and_tco
(
self
,
hw
,
polys
,
tags
,
tcl_ratio
=
0.3
,
ds_ratio
=
0.25
):
def
generate_tvo_and_tco
(
self
,
hw
,
polys
,
tags
,
tcl_ratio
=
0.3
,
ds_ratio
=
0.25
):
"""
"""
Generate tcl map, tvo map and tbo map.
Generate tcl map, tvo map and tbo map.
"""
"""
...
@@ -297,35 +344,44 @@ class SASTProcessTrain(object):
...
@@ -297,35 +344,44 @@ class SASTProcessTrain(object):
# generate min_area_quad
# generate min_area_quad
min_area_quad
,
center_point
=
self
.
gen_min_area_quad_from_poly
(
poly
)
min_area_quad
,
center_point
=
self
.
gen_min_area_quad_from_poly
(
poly
)
min_area_quad_h
=
0.5
*
(
np
.
linalg
.
norm
(
min_area_quad
[
0
]
-
min_area_quad
[
3
])
+
min_area_quad_h
=
0.5
*
(
np
.
linalg
.
norm
(
min_area_quad
[
1
]
-
min_area_quad
[
2
]))
np
.
linalg
.
norm
(
min_area_quad
[
0
]
-
min_area_quad
[
3
])
+
min_area_quad_w
=
0.5
*
(
np
.
linalg
.
norm
(
min_area_quad
[
0
]
-
min_area_quad
[
1
])
+
np
.
linalg
.
norm
(
min_area_quad
[
1
]
-
min_area_quad
[
2
]))
np
.
linalg
.
norm
(
min_area_quad
[
2
]
-
min_area_quad
[
3
]))
min_area_quad_w
=
0.5
*
(
np
.
linalg
.
norm
(
min_area_quad
[
0
]
-
min_area_quad
[
1
])
+
np
.
linalg
.
norm
(
min_area_quad
[
2
]
-
min_area_quad
[
3
]))
# generate tcl map and text, 128 * 128
# generate tcl map and text, 128 * 128
tcl_poly
=
self
.
poly2tcl
(
poly
,
tcl_ratio
)
tcl_poly
=
self
.
poly2tcl
(
poly
,
tcl_ratio
)
# generate poly_tv_xy_map
# generate poly_tv_xy_map
for
idx
in
range
(
4
):
for
idx
in
range
(
4
):
cv2
.
fillPoly
(
poly_tv_xy_map
[
2
*
idx
],
cv2
.
fillPoly
(
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
poly_tv_xy_map
[
2
*
idx
],
float
(
min
(
max
(
min_area_quad
[
idx
,
0
],
0
),
w
)))
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
cv2
.
fillPoly
(
poly_tv_xy_map
[
2
*
idx
+
1
],
float
(
min
(
max
(
min_area_quad
[
idx
,
0
],
0
),
w
)))
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
cv2
.
fillPoly
(
float
(
min
(
max
(
min_area_quad
[
idx
,
1
],
0
),
h
)))
poly_tv_xy_map
[
2
*
idx
+
1
],
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
float
(
min
(
max
(
min_area_quad
[
idx
,
1
],
0
),
h
)))
# generate poly_tc_xy_map
# generate poly_tc_xy_map
for
idx
in
range
(
2
):
for
idx
in
range
(
2
):
cv2
.
fillPoly
(
poly_tc_xy_map
[
idx
],
cv2
.
fillPoly
(
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
float
(
center_point
[
idx
]))
poly_tc_xy_map
[
idx
],
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
float
(
center_point
[
idx
]))
# generate poly_short_edge_map
# generate poly_short_edge_map
cv2
.
fillPoly
(
poly_short_edge_map
,
cv2
.
fillPoly
(
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
poly_short_edge_map
,
float
(
max
(
min
(
min_area_quad_h
,
min_area_quad_w
),
1.0
)))
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
float
(
max
(
min
(
min_area_quad_h
,
min_area_quad_w
),
1.0
)))
# generate poly_mask and training_mask
# generate poly_mask and training_mask
cv2
.
fillPoly
(
poly_mask
,
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
1
)
cv2
.
fillPoly
(
poly_mask
,
np
.
round
(
tcl_poly
[
np
.
newaxis
,
:,
:]).
astype
(
np
.
int32
),
1
)
tvo_map
*=
poly_mask
tvo_map
*=
poly_mask
tvo_map
[:
8
]
-=
poly_tv_xy_map
tvo_map
[:
8
]
-=
poly_tv_xy_map
...
@@ -356,7 +412,8 @@ class SASTProcessTrain(object):
...
@@ -356,7 +412,8 @@ class SASTProcessTrain(object):
elif
point_num
>
4
:
elif
point_num
>
4
:
vector_1
=
poly
[
0
]
-
poly
[
1
]
vector_1
=
poly
[
0
]
-
poly
[
1
]
vector_2
=
poly
[
1
]
-
poly
[
2
]
vector_2
=
poly
[
1
]
-
poly
[
2
]
cos_theta
=
np
.
dot
(
vector_1
,
vector_2
)
/
(
np
.
linalg
.
norm
(
vector_1
)
*
np
.
linalg
.
norm
(
vector_2
)
+
1e-6
)
cos_theta
=
np
.
dot
(
vector_1
,
vector_2
)
/
(
np
.
linalg
.
norm
(
vector_1
)
*
np
.
linalg
.
norm
(
vector_2
)
+
1e-6
)
theta
=
np
.
arccos
(
np
.
round
(
cos_theta
,
decimals
=
4
))
theta
=
np
.
arccos
(
np
.
round
(
cos_theta
,
decimals
=
4
))
if
abs
(
theta
)
>
(
70
/
180
*
math
.
pi
):
if
abs
(
theta
)
>
(
70
/
180
*
math
.
pi
):
...
@@ -374,7 +431,8 @@ class SASTProcessTrain(object):
...
@@ -374,7 +431,8 @@ class SASTProcessTrain(object):
min_area_quad
=
poly
min_area_quad
=
poly
center_point
=
np
.
sum
(
poly
,
axis
=
0
)
/
4
center_point
=
np
.
sum
(
poly
,
axis
=
0
)
/
4
else
:
else
:
rect
=
cv2
.
minAreaRect
(
poly
.
astype
(
np
.
int32
))
# (center (x,y), (width, height), angle of rotation)
rect
=
cv2
.
minAreaRect
(
poly
.
astype
(
np
.
int32
))
# (center (x,y), (width, height), angle of rotation)
center_point
=
rect
[
0
]
center_point
=
rect
[
0
]
box
=
np
.
array
(
cv2
.
boxPoints
(
rect
))
box
=
np
.
array
(
cv2
.
boxPoints
(
rect
))
...
@@ -394,16 +452,23 @@ class SASTProcessTrain(object):
...
@@ -394,16 +452,23 @@ class SASTProcessTrain(object):
return
min_area_quad
,
center_point
return
min_area_quad
,
center_point
def
shrink_quad_along_width
(
self
,
quad
,
begin_width_ratio
=
0.
,
end_width_ratio
=
1.
):
def
shrink_quad_along_width
(
self
,
quad
,
begin_width_ratio
=
0.
,
end_width_ratio
=
1.
):
"""
"""
Generate shrink_quad_along_width.
Generate shrink_quad_along_width.
"""
"""
ratio_pair
=
np
.
array
([[
begin_width_ratio
],
[
end_width_ratio
]],
dtype
=
np
.
float32
)
ratio_pair
=
np
.
array
(
[[
begin_width_ratio
],
[
end_width_ratio
]],
dtype
=
np
.
float32
)
p0_1
=
quad
[
0
]
+
(
quad
[
1
]
-
quad
[
0
])
*
ratio_pair
p0_1
=
quad
[
0
]
+
(
quad
[
1
]
-
quad
[
0
])
*
ratio_pair
p3_2
=
quad
[
3
]
+
(
quad
[
2
]
-
quad
[
3
])
*
ratio_pair
p3_2
=
quad
[
3
]
+
(
quad
[
2
]
-
quad
[
3
])
*
ratio_pair
return
np
.
array
([
p0_1
[
0
],
p0_1
[
1
],
p3_2
[
1
],
p3_2
[
0
]])
return
np
.
array
([
p0_1
[
0
],
p0_1
[
1
],
p3_2
[
1
],
p3_2
[
0
]])
def
shrink_poly_along_width
(
self
,
quads
,
shrink_ratio_of_width
,
expand_height_ratio
=
1.0
):
def
shrink_poly_along_width
(
self
,
quads
,
shrink_ratio_of_width
,
expand_height_ratio
=
1.0
):
"""
"""
shrink poly with given length.
shrink poly with given length.
"""
"""
...
@@ -421,22 +486,28 @@ class SASTProcessTrain(object):
...
@@ -421,22 +486,28 @@ class SASTProcessTrain(object):
upper_edge_list
.
append
(
upper_edge_len
)
upper_edge_list
.
append
(
upper_edge_len
)
# length of left edge and right edge.
# length of left edge and right edge.
left_length
=
np
.
linalg
.
norm
(
quads
[
0
][
0
]
-
quads
[
0
][
3
])
*
expand_height_ratio
left_length
=
np
.
linalg
.
norm
(
quads
[
0
][
0
]
-
quads
[
0
][
right_length
=
np
.
linalg
.
norm
(
quads
[
-
1
][
1
]
-
quads
[
-
1
][
2
])
*
expand_height_ratio
3
])
*
expand_height_ratio
right_length
=
np
.
linalg
.
norm
(
quads
[
-
1
][
1
]
-
quads
[
-
1
][
2
])
*
expand_height_ratio
shrink_length
=
min
(
left_length
,
right_length
,
sum
(
upper_edge_list
))
*
shrink_ratio_of_width
shrink_length
=
min
(
left_length
,
right_length
,
sum
(
upper_edge_list
))
*
shrink_ratio_of_width
# shrinking length
# shrinking length
upper_len_left
=
shrink_length
upper_len_left
=
shrink_length
upper_len_right
=
sum
(
upper_edge_list
)
-
shrink_length
upper_len_right
=
sum
(
upper_edge_list
)
-
shrink_length
left_idx
,
left_ratio
=
get_cut_info
(
upper_edge_list
,
upper_len_left
)
left_idx
,
left_ratio
=
get_cut_info
(
upper_edge_list
,
upper_len_left
)
left_quad
=
self
.
shrink_quad_along_width
(
quads
[
left_idx
],
begin_width_ratio
=
left_ratio
,
end_width_ratio
=
1
)
left_quad
=
self
.
shrink_quad_along_width
(
quads
[
left_idx
],
begin_width_ratio
=
left_ratio
,
end_width_ratio
=
1
)
right_idx
,
right_ratio
=
get_cut_info
(
upper_edge_list
,
upper_len_right
)
right_idx
,
right_ratio
=
get_cut_info
(
upper_edge_list
,
upper_len_right
)
right_quad
=
self
.
shrink_quad_along_width
(
quads
[
right_idx
],
begin_width_ratio
=
0
,
end_width_ratio
=
right_ratio
)
right_quad
=
self
.
shrink_quad_along_width
(
quads
[
right_idx
],
begin_width_ratio
=
0
,
end_width_ratio
=
right_ratio
)
out_quad_list
=
[]
out_quad_list
=
[]
if
left_idx
==
right_idx
:
if
left_idx
==
right_idx
:
out_quad_list
.
append
([
left_quad
[
0
],
right_quad
[
1
],
right_quad
[
2
],
left_quad
[
3
]])
out_quad_list
.
append
(
[
left_quad
[
0
],
right_quad
[
1
],
right_quad
[
2
],
left_quad
[
3
]])
else
:
else
:
out_quad_list
.
append
(
left_quad
)
out_quad_list
.
append
(
left_quad
)
for
idx
in
range
(
left_idx
+
1
,
right_idx
):
for
idx
in
range
(
left_idx
+
1
,
right_idx
):
...
@@ -500,7 +571,8 @@ class SASTProcessTrain(object):
...
@@ -500,7 +571,8 @@ class SASTProcessTrain(object):
"""
"""
Generate center line by poly clock-wise point. (4, 2)
Generate center line by poly clock-wise point. (4, 2)
"""
"""
ratio_pair
=
np
.
array
([[
0.5
-
ratio
/
2
],
[
0.5
+
ratio
/
2
]],
dtype
=
np
.
float32
)
ratio_pair
=
np
.
array
(
[[
0.5
-
ratio
/
2
],
[
0.5
+
ratio
/
2
]],
dtype
=
np
.
float32
)
p0_3
=
poly
[
0
]
+
(
poly
[
3
]
-
poly
[
0
])
*
ratio_pair
p0_3
=
poly
[
0
]
+
(
poly
[
3
]
-
poly
[
0
])
*
ratio_pair
p1_2
=
poly
[
1
]
+
(
poly
[
2
]
-
poly
[
1
])
*
ratio_pair
p1_2
=
poly
[
1
]
+
(
poly
[
2
]
-
poly
[
1
])
*
ratio_pair
return
np
.
array
([
p0_3
[
0
],
p1_2
[
0
],
p1_2
[
1
],
p0_3
[
1
]])
return
np
.
array
([
p0_3
[
0
],
p1_2
[
0
],
p1_2
[
1
],
p0_3
[
1
]])
...
@@ -509,12 +581,14 @@ class SASTProcessTrain(object):
...
@@ -509,12 +581,14 @@ class SASTProcessTrain(object):
"""
"""
Generate center line by poly clock-wise point.
Generate center line by poly clock-wise point.
"""
"""
ratio_pair
=
np
.
array
([[
0.5
-
ratio
/
2
],
[
0.5
+
ratio
/
2
]],
dtype
=
np
.
float32
)
ratio_pair
=
np
.
array
(
[[
0.5
-
ratio
/
2
],
[
0.5
+
ratio
/
2
]],
dtype
=
np
.
float32
)
tcl_poly
=
np
.
zeros_like
(
poly
)
tcl_poly
=
np
.
zeros_like
(
poly
)
point_num
=
poly
.
shape
[
0
]
point_num
=
poly
.
shape
[
0
]
for
idx
in
range
(
point_num
//
2
):
for
idx
in
range
(
point_num
//
2
):
point_pair
=
poly
[
idx
]
+
(
poly
[
point_num
-
1
-
idx
]
-
poly
[
idx
])
*
ratio_pair
point_pair
=
poly
[
idx
]
+
(
poly
[
point_num
-
1
-
idx
]
-
poly
[
idx
]
)
*
ratio_pair
tcl_poly
[
idx
]
=
point_pair
[
0
]
tcl_poly
[
idx
]
=
point_pair
[
0
]
tcl_poly
[
point_num
-
1
-
idx
]
=
point_pair
[
1
]
tcl_poly
[
point_num
-
1
-
idx
]
=
point_pair
[
1
]
return
tcl_poly
return
tcl_poly
...
@@ -527,8 +601,10 @@ class SASTProcessTrain(object):
...
@@ -527,8 +601,10 @@ class SASTProcessTrain(object):
up_line
=
self
.
line_cross_two_point
(
quad
[
0
],
quad
[
1
])
up_line
=
self
.
line_cross_two_point
(
quad
[
0
],
quad
[
1
])
lower_line
=
self
.
line_cross_two_point
(
quad
[
3
],
quad
[
2
])
lower_line
=
self
.
line_cross_two_point
(
quad
[
3
],
quad
[
2
])
quad_h
=
0.5
*
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
3
])
+
np
.
linalg
.
norm
(
quad
[
1
]
-
quad
[
2
]))
quad_h
=
0.5
*
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
3
])
+
quad_w
=
0.5
*
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
1
])
+
np
.
linalg
.
norm
(
quad
[
2
]
-
quad
[
3
]))
np
.
linalg
.
norm
(
quad
[
1
]
-
quad
[
2
]))
quad_w
=
0.5
*
(
np
.
linalg
.
norm
(
quad
[
0
]
-
quad
[
1
])
+
np
.
linalg
.
norm
(
quad
[
2
]
-
quad
[
3
]))
# average angle of left and right line.
# average angle of left and right line.
angle
=
self
.
average_angle
(
quad
)
angle
=
self
.
average_angle
(
quad
)
...
@@ -565,7 +641,8 @@ class SASTProcessTrain(object):
...
@@ -565,7 +641,8 @@ class SASTProcessTrain(object):
quad_num
=
point_num
//
2
-
1
quad_num
=
point_num
//
2
-
1
for
idx
in
range
(
quad_num
):
for
idx
in
range
(
quad_num
):
# reshape and adjust to clock-wise
# reshape and adjust to clock-wise
quad_list
.
append
((
np
.
array
(
point_pair_list
)[[
idx
,
idx
+
1
]]).
reshape
(
4
,
2
)[[
0
,
2
,
3
,
1
]])
quad_list
.
append
((
np
.
array
(
point_pair_list
)[[
idx
,
idx
+
1
]]
).
reshape
(
4
,
2
)[[
0
,
2
,
3
,
1
]])
return
np
.
array
(
quad_list
)
return
np
.
array
(
quad_list
)
...
@@ -579,7 +656,8 @@ class SASTProcessTrain(object):
...
@@ -579,7 +656,8 @@ class SASTProcessTrain(object):
return
None
return
None
h
,
w
,
_
=
im
.
shape
h
,
w
,
_
=
im
.
shape
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
(
h
,
w
))
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
(
h
,
w
))
if
text_polys
.
shape
[
0
]
==
0
:
if
text_polys
.
shape
[
0
]
==
0
:
return
None
return
None
...
@@ -591,7 +669,7 @@ class SASTProcessTrain(object):
...
@@ -591,7 +669,7 @@ class SASTProcessTrain(object):
if
np
.
random
.
rand
()
<
0.5
:
if
np
.
random
.
rand
()
<
0.5
:
asp_scale
=
1.0
/
asp_scale
asp_scale
=
1.0
/
asp_scale
asp_scale
=
math
.
sqrt
(
asp_scale
)
asp_scale
=
math
.
sqrt
(
asp_scale
)
asp_wx
=
asp_scale
asp_wx
=
asp_scale
asp_hy
=
1.0
/
asp_scale
asp_hy
=
1.0
/
asp_scale
im
=
cv2
.
resize
(
im
,
dsize
=
None
,
fx
=
asp_wx
,
fy
=
asp_hy
)
im
=
cv2
.
resize
(
im
,
dsize
=
None
,
fx
=
asp_wx
,
fy
=
asp_hy
)
...
@@ -610,7 +688,7 @@ class SASTProcessTrain(object):
...
@@ -610,7 +688,7 @@ class SASTProcessTrain(object):
#no background
#no background
im
,
text_polys
,
text_tags
,
hv_tags
=
self
.
crop_area
(
im
,
\
im
,
text_polys
,
text_tags
,
hv_tags
=
self
.
crop_area
(
im
,
\
text_polys
,
text_tags
,
hv_tags
,
crop_background
=
False
)
text_polys
,
text_tags
,
hv_tags
,
crop_background
=
False
)
if
text_polys
.
shape
[
0
]
==
0
:
if
text_polys
.
shape
[
0
]
==
0
:
return
None
return
None
#continue for all ignore case
#continue for all ignore case
...
@@ -621,17 +699,18 @@ class SASTProcessTrain(object):
...
@@ -621,17 +699,18 @@ class SASTProcessTrain(object):
return
None
return
None
#resize image
#resize image
std_ratio
=
float
(
self
.
input_size
)
/
max
(
new_w
,
new_h
)
std_ratio
=
float
(
self
.
input_size
)
/
max
(
new_w
,
new_h
)
rand_scales
=
np
.
array
([
0.25
,
0.375
,
0.5
,
0.625
,
0.75
,
0.875
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
])
rand_scales
=
np
.
array
(
[
0.25
,
0.375
,
0.5
,
0.625
,
0.75
,
0.875
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
])
rz_scale
=
std_ratio
*
np
.
random
.
choice
(
rand_scales
)
rz_scale
=
std_ratio
*
np
.
random
.
choice
(
rand_scales
)
im
=
cv2
.
resize
(
im
,
dsize
=
None
,
fx
=
rz_scale
,
fy
=
rz_scale
)
im
=
cv2
.
resize
(
im
,
dsize
=
None
,
fx
=
rz_scale
,
fy
=
rz_scale
)
text_polys
[:,
:,
0
]
*=
rz_scale
text_polys
[:,
:,
0
]
*=
rz_scale
text_polys
[:,
:,
1
]
*=
rz_scale
text_polys
[:,
:,
1
]
*=
rz_scale
#add gaussian blur
#add gaussian blur
if
np
.
random
.
rand
()
<
0.1
*
0.5
:
if
np
.
random
.
rand
()
<
0.1
*
0.5
:
ks
=
np
.
random
.
permutation
(
5
)[
0
]
+
1
ks
=
np
.
random
.
permutation
(
5
)[
0
]
+
1
ks
=
int
(
ks
/
2
)
*
2
+
1
ks
=
int
(
ks
/
2
)
*
2
+
1
im
=
cv2
.
GaussianBlur
(
im
,
ksize
=
(
ks
,
ks
),
sigmaX
=
0
,
sigmaY
=
0
)
im
=
cv2
.
GaussianBlur
(
im
,
ksize
=
(
ks
,
ks
),
sigmaX
=
0
,
sigmaY
=
0
)
#add brighter
#add brighter
if
np
.
random
.
rand
()
<
0.1
*
0.5
:
if
np
.
random
.
rand
()
<
0.1
*
0.5
:
im
=
im
*
(
1.0
+
np
.
random
.
rand
()
*
0.5
)
im
=
im
*
(
1.0
+
np
.
random
.
rand
()
*
0.5
)
...
@@ -640,13 +719,14 @@ class SASTProcessTrain(object):
...
@@ -640,13 +719,14 @@ class SASTProcessTrain(object):
if
np
.
random
.
rand
()
<
0.1
*
0.5
:
if
np
.
random
.
rand
()
<
0.1
*
0.5
:
im
=
im
*
(
1.0
-
np
.
random
.
rand
()
*
0.5
)
im
=
im
*
(
1.0
-
np
.
random
.
rand
()
*
0.5
)
im
=
np
.
clip
(
im
,
0.0
,
255.0
)
im
=
np
.
clip
(
im
,
0.0
,
255.0
)
# Padding the im to [input_size, input_size]
# Padding the im to [input_size, input_size]
new_h
,
new_w
,
_
=
im
.
shape
new_h
,
new_w
,
_
=
im
.
shape
if
min
(
new_w
,
new_h
)
<
self
.
input_size
*
0.5
:
if
min
(
new_w
,
new_h
)
<
self
.
input_size
*
0.5
:
return
None
return
None
im_padded
=
np
.
ones
((
self
.
input_size
,
self
.
input_size
,
3
),
dtype
=
np
.
float32
)
im_padded
=
np
.
ones
(
(
self
.
input_size
,
self
.
input_size
,
3
),
dtype
=
np
.
float32
)
im_padded
[:,
:,
2
]
=
0.485
*
255
im_padded
[:,
:,
2
]
=
0.485
*
255
im_padded
[:,
:,
1
]
=
0.456
*
255
im_padded
[:,
:,
1
]
=
0.456
*
255
im_padded
[:,
:,
0
]
=
0.406
*
255
im_padded
[:,
:,
0
]
=
0.406
*
255
...
@@ -661,24 +741,29 @@ class SASTProcessTrain(object):
...
@@ -661,24 +741,29 @@ class SASTProcessTrain(object):
sw
=
int
(
np
.
random
.
rand
()
*
del_w
)
sw
=
int
(
np
.
random
.
rand
()
*
del_w
)
# Padding
# Padding
im_padded
[
sh
:
sh
+
new_h
,
sw
:
sw
+
new_w
,
:]
=
im
.
copy
()
im_padded
[
sh
:
sh
+
new_h
,
sw
:
sw
+
new_w
,
:]
=
im
.
copy
()
text_polys
[:,
:,
0
]
+=
sw
text_polys
[:,
:,
0
]
+=
sw
text_polys
[:,
:,
1
]
+=
sh
text_polys
[:,
:,
1
]
+=
sh
score_map
,
border_map
,
training_mask
=
self
.
generate_tcl_label
(
(
self
.
input_size
,
self
.
input_size
),
score_map
,
border_map
,
training_mask
=
self
.
generate_tcl_label
(
text_polys
,
text_tags
,
0.25
)
(
self
.
input_size
,
self
.
input_size
),
text_polys
,
text_tags
,
0.25
)
# SAST head
# SAST head
tvo_map
,
tco_map
=
self
.
generate_tvo_and_tco
((
self
.
input_size
,
self
.
input_size
),
text_polys
,
text_tags
,
tcl_ratio
=
0.3
,
ds_ratio
=
0.25
)
tvo_map
,
tco_map
=
self
.
generate_tvo_and_tco
(
(
self
.
input_size
,
self
.
input_size
),
text_polys
,
text_tags
,
tcl_ratio
=
0.3
,
ds_ratio
=
0.25
)
# print("test--------tvo_map shape:", tvo_map.shape)
# print("test--------tvo_map shape:", tvo_map.shape)
im_padded
[:,
:,
2
]
-=
0.485
*
255
im_padded
[:,
:,
2
]
-=
0.485
*
255
im_padded
[:,
:,
1
]
-=
0.456
*
255
im_padded
[:,
:,
1
]
-=
0.456
*
255
im_padded
[:,
:,
0
]
-=
0.406
*
255
im_padded
[:,
:,
0
]
-=
0.406
*
255
im_padded
[:,
:,
2
]
/=
(
255.0
*
0.229
)
im_padded
[:,
:,
2
]
/=
(
255.0
*
0.229
)
im_padded
[:,
:,
1
]
/=
(
255.0
*
0.224
)
im_padded
[:,
:,
1
]
/=
(
255.0
*
0.224
)
im_padded
[:,
:,
0
]
/=
(
255.0
*
0.225
)
im_padded
[:,
:,
0
]
/=
(
255.0
*
0.225
)
im_padded
=
im_padded
.
transpose
((
2
,
0
,
1
))
im_padded
=
im_padded
.
transpose
((
2
,
0
,
1
))
data
[
'image'
]
=
im_padded
[::
-
1
,
:,
:]
data
[
'image'
]
=
im_padded
[::
-
1
,
:,
:]
data
[
'score_map'
]
=
score_map
[
np
.
newaxis
,
:,
:]
data
[
'score_map'
]
=
score_map
[
np
.
newaxis
,
:,
:]
...
@@ -686,4 +771,4 @@ class SASTProcessTrain(object):
...
@@ -686,4 +771,4 @@ class SASTProcessTrain(object):
data
[
'training_mask'
]
=
training_mask
[
np
.
newaxis
,
:,
:]
data
[
'training_mask'
]
=
training_mask
[
np
.
newaxis
,
:,
:]
data
[
'tvo_map'
]
=
tvo_map
.
transpose
((
2
,
0
,
1
))
data
[
'tvo_map'
]
=
tvo_map
.
transpose
((
2
,
0
,
1
))
data
[
'tco_map'
]
=
tco_map
.
transpose
((
2
,
0
,
1
))
data
[
'tco_map'
]
=
tco_map
.
transpose
((
2
,
0
,
1
))
return
data
return
data
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录