Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
115b5175
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
115b5175
编写于
9月 23, 2020
作者:
D
Double_V
提交者:
GitHub
9月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #814 from littletomatodonkey/add_tia
add tia aug
上级
5d202e44
3a18b08f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
315 addition
and
38 deletion
+315
-38
ppocr/data/rec/img_tools.py
ppocr/data/rec/img_tools.py
+54
-38
ppocr/data/rec/text_image_aug/augment.py
ppocr/data/rec/text_image_aug/augment.py
+107
-0
ppocr/data/rec/text_image_aug/warp_mls.py
ppocr/data/rec/text_image_aug/warp_mls.py
+154
-0
未找到文件。
ppocr/data/rec/img_tools.py
浏览文件 @
115b5175
...
@@ -19,6 +19,8 @@ import random
...
@@ -19,6 +19,8 @@ import random
from
ppocr.utils.utility
import
initial_logger
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
logger
=
initial_logger
()
from
.text_image_aug.augment
import
tia_distort
,
tia_stretch
,
tia_perspective
def
get_bounding_box_rect
(
pos
):
def
get_bounding_box_rect
(
pos
):
left
=
min
(
pos
[
0
])
left
=
min
(
pos
[
0
])
...
@@ -196,6 +198,9 @@ class Config:
...
@@ -196,6 +198,9 @@ class Config:
self
.
h
=
h
self
.
h
=
h
self
.
perspective
=
True
self
.
perspective
=
True
self
.
stretch
=
True
self
.
distort
=
True
self
.
crop
=
True
self
.
crop
=
True
self
.
affine
=
False
self
.
affine
=
False
self
.
reverse
=
True
self
.
reverse
=
True
...
@@ -299,41 +304,40 @@ def warp(img, ang):
...
@@ -299,41 +304,40 @@ def warp(img, ang):
config
.
make
(
w
,
h
,
ang
)
config
.
make
(
w
,
h
,
ang
)
new_img
=
img
new_img
=
img
prob
=
0.4
if
config
.
distort
:
img_height
,
img_width
=
img
.
shape
[
0
:
2
]
if
random
.
random
()
<=
prob
and
img_height
>=
20
and
img_width
>=
20
:
new_img
=
tia_distort
(
new_img
,
random
.
randint
(
3
,
6
))
if
config
.
stretch
:
img_height
,
img_width
=
img
.
shape
[
0
:
2
]
if
random
.
random
()
<=
prob
and
img_height
>=
20
and
img_width
>=
20
:
new_img
=
tia_stretch
(
new_img
,
random
.
randint
(
3
,
6
))
if
config
.
perspective
:
if
config
.
perspective
:
tp
=
random
.
randint
(
1
,
100
)
if
random
.
random
()
<=
prob
:
if
tp
>=
50
:
new_img
=
tia_perspective
(
new_img
)
warpR
,
(
r1
,
c1
),
ratio
,
dst
=
get_warpR
(
config
)
new_w
=
int
(
np
.
max
(
dst
[:,
0
]))
-
int
(
np
.
min
(
dst
[:,
0
]))
new_img
=
cv2
.
warpPerspective
(
new_img
,
warpR
,
(
int
(
new_w
*
ratio
),
h
),
borderMode
=
config
.
borderMode
)
if
config
.
crop
:
if
config
.
crop
:
img_height
,
img_width
=
img
.
shape
[
0
:
2
]
img_height
,
img_width
=
img
.
shape
[
0
:
2
]
tp
=
random
.
randint
(
1
,
100
)
if
random
.
random
()
<=
prob
and
img_height
>=
20
and
img_width
>=
20
:
if
tp
>=
50
and
img_height
>=
20
and
img_width
>=
20
:
new_img
=
get_crop
(
new_img
)
new_img
=
get_crop
(
new_img
)
if
config
.
affine
:
warpT
=
get_warpAffine
(
config
)
new_img
=
cv2
.
warpAffine
(
new_img
,
warpT
,
(
w
,
h
),
borderMode
=
config
.
borderMode
)
if
config
.
blur
:
if
config
.
blur
:
tp
=
random
.
randint
(
1
,
100
)
if
random
.
random
()
<=
prob
:
if
tp
>=
50
:
new_img
=
blur
(
new_img
)
new_img
=
blur
(
new_img
)
if
config
.
color
:
if
config
.
color
:
tp
=
random
.
randint
(
1
,
100
)
if
random
.
random
()
<=
prob
:
if
tp
>=
50
:
new_img
=
cvtColor
(
new_img
)
new_img
=
cvtColor
(
new_img
)
if
config
.
jitter
:
if
config
.
jitter
:
new_img
=
jitter
(
new_img
)
new_img
=
jitter
(
new_img
)
if
config
.
noise
:
if
config
.
noise
:
tp
=
random
.
randint
(
1
,
100
)
if
random
.
random
()
<=
prob
:
if
tp
>=
50
:
new_img
=
add_gasuss_noise
(
new_img
)
new_img
=
add_gasuss_noise
(
new_img
)
if
config
.
reverse
:
if
config
.
reverse
:
tp
=
random
.
randint
(
1
,
100
)
if
random
.
random
()
<=
prob
:
if
tp
>=
50
:
new_img
=
255
-
new_img
new_img
=
255
-
new_img
return
new_img
return
new_img
...
@@ -382,6 +386,7 @@ def process_image(img,
...
@@ -382,6 +386,7 @@ def process_image(img,
%
loss_type
%
loss_type
return
(
norm_img
)
return
(
norm_img
)
def
resize_norm_img_srn
(
img
,
image_shape
):
def
resize_norm_img_srn
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
imgC
,
imgH
,
imgW
=
image_shape
...
@@ -408,30 +413,39 @@ def resize_norm_img_srn(img, image_shape):
...
@@ -408,30 +413,39 @@ def resize_norm_img_srn(img, image_shape):
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
srn_other_inputs
(
image_shape
,
num_heads
,
def
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
,
char_num
):
max_text_length
,
char_num
):
imgC
,
imgH
,
imgW
=
image_shape
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
((
feature_dim
,
1
)).
astype
(
'int64'
)
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
(
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
((
max_text_length
,
1
)).
astype
(
'int64'
)
(
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
(
(
max_text_length
,
1
)).
astype
(
'int64'
)
lbl_weight
=
np
.
array
([
int
(
char_num
-
1
)]
*
max_text_length
).
reshape
((
-
1
,
1
)).
astype
(
'int64'
)
lbl_weight
=
np
.
array
([
int
(
char_num
-
1
)]
*
max_text_length
).
reshape
(
(
-
1
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
([
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
(
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
1
,
num_heads
,
1
,
1
])
*
[
-
1e9
]
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
1
,
num_heads
,
1
,
1
])
*
[
-
1e9
]
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
([
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
(
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
1
,
num_heads
,
1
,
1
])
*
[
-
1e9
]
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
1
,
num_heads
,
1
,
1
])
*
[
-
1e9
]
encoder_word_pos
=
encoder_word_pos
[
np
.
newaxis
,
:]
encoder_word_pos
=
encoder_word_pos
[
np
.
newaxis
,
:]
gsrm_word_pos
=
gsrm_word_pos
[
np
.
newaxis
,
:]
gsrm_word_pos
=
gsrm_word_pos
[
np
.
newaxis
,
:]
return
[
lbl_weight
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
return
[
lbl_weight
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
def
process_image_srn
(
img
,
def
process_image_srn
(
img
,
image_shape
,
image_shape
,
...
@@ -453,14 +467,16 @@ def process_image_srn(img,
...
@@ -453,14 +467,16 @@ def process_image_srn(img,
return
None
return
None
else
:
else
:
if
loss_type
==
"srn"
:
if
loss_type
==
"srn"
:
text_padded
=
[
int
(
char_num
-
1
)]
*
max_text_length
text_padded
=
[
int
(
char_num
-
1
)]
*
max_text_length
for
i
in
range
(
len
(
text
)):
for
i
in
range
(
len
(
text
)):
text_padded
[
i
]
=
text
[
i
]
text_padded
[
i
]
=
text
[
i
]
lbl_weight
[
i
]
=
[
1.0
]
lbl_weight
[
i
]
=
[
1.0
]
text_padded
=
np
.
array
(
text_padded
)
text_padded
=
np
.
array
(
text_padded
)
text
=
text_padded
.
reshape
(
-
1
,
1
)
text
=
text_padded
.
reshape
(
-
1
,
1
)
return
(
norm_img
,
text
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
,
lbl_weight
)
return
(
norm_img
,
text
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
,
lbl_weight
)
else
:
else
:
assert
False
,
"Unsupport loss_type %s in process_image"
\
assert
False
,
"Unsupport loss_type %s in process_image"
\
%
loss_type
%
loss_type
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
ppocr/data/rec/text_image_aug/augment.py
0 → 100644
浏览文件 @
115b5175
# -*- coding:utf-8 -*-
# Author: RubanSeven
# Reference: https://github.com/RubanSeven/Text-Image-Augmentation-python
# import cv2
import
numpy
as
np
from
.warp_mls
import
WarpMLS
def
tia_distort
(
src
,
segment
=
4
):
img_h
,
img_w
=
src
.
shape
[:
2
]
cut
=
img_w
//
segment
thresh
=
cut
//
3
src_pts
=
list
()
dst_pts
=
list
()
src_pts
.
append
([
0
,
0
])
src_pts
.
append
([
img_w
,
0
])
src_pts
.
append
([
img_w
,
img_h
])
src_pts
.
append
([
0
,
img_h
])
dst_pts
.
append
([
np
.
random
.
randint
(
thresh
),
np
.
random
.
randint
(
thresh
)])
dst_pts
.
append
(
[
img_w
-
np
.
random
.
randint
(
thresh
),
np
.
random
.
randint
(
thresh
)])
dst_pts
.
append
(
[
img_w
-
np
.
random
.
randint
(
thresh
),
img_h
-
np
.
random
.
randint
(
thresh
)])
dst_pts
.
append
(
[
np
.
random
.
randint
(
thresh
),
img_h
-
np
.
random
.
randint
(
thresh
)])
half_thresh
=
thresh
*
0.5
for
cut_idx
in
np
.
arange
(
1
,
segment
,
1
):
src_pts
.
append
([
cut
*
cut_idx
,
0
])
src_pts
.
append
([
cut
*
cut_idx
,
img_h
])
dst_pts
.
append
([
cut
*
cut_idx
+
np
.
random
.
randint
(
thresh
)
-
half_thresh
,
np
.
random
.
randint
(
thresh
)
-
half_thresh
])
dst_pts
.
append
([
cut
*
cut_idx
+
np
.
random
.
randint
(
thresh
)
-
half_thresh
,
img_h
+
np
.
random
.
randint
(
thresh
)
-
half_thresh
])
trans
=
WarpMLS
(
src
,
src_pts
,
dst_pts
,
img_w
,
img_h
)
dst
=
trans
.
generate
()
return
dst
def
tia_stretch
(
src
,
segment
=
4
):
img_h
,
img_w
=
src
.
shape
[:
2
]
cut
=
img_w
//
segment
thresh
=
cut
*
4
//
5
src_pts
=
list
()
dst_pts
=
list
()
src_pts
.
append
([
0
,
0
])
src_pts
.
append
([
img_w
,
0
])
src_pts
.
append
([
img_w
,
img_h
])
src_pts
.
append
([
0
,
img_h
])
dst_pts
.
append
([
0
,
0
])
dst_pts
.
append
([
img_w
,
0
])
dst_pts
.
append
([
img_w
,
img_h
])
dst_pts
.
append
([
0
,
img_h
])
half_thresh
=
thresh
*
0.5
for
cut_idx
in
np
.
arange
(
1
,
segment
,
1
):
move
=
np
.
random
.
randint
(
thresh
)
-
half_thresh
src_pts
.
append
([
cut
*
cut_idx
,
0
])
src_pts
.
append
([
cut
*
cut_idx
,
img_h
])
dst_pts
.
append
([
cut
*
cut_idx
+
move
,
0
])
dst_pts
.
append
([
cut
*
cut_idx
+
move
,
img_h
])
trans
=
WarpMLS
(
src
,
src_pts
,
dst_pts
,
img_w
,
img_h
)
dst
=
trans
.
generate
()
return
dst
def
tia_perspective
(
src
):
img_h
,
img_w
=
src
.
shape
[:
2
]
thresh
=
img_h
//
2
src_pts
=
list
()
dst_pts
=
list
()
src_pts
.
append
([
0
,
0
])
src_pts
.
append
([
img_w
,
0
])
src_pts
.
append
([
img_w
,
img_h
])
src_pts
.
append
([
0
,
img_h
])
dst_pts
.
append
([
0
,
np
.
random
.
randint
(
thresh
)])
dst_pts
.
append
([
img_w
,
np
.
random
.
randint
(
thresh
)])
dst_pts
.
append
([
img_w
,
img_h
-
np
.
random
.
randint
(
thresh
)])
dst_pts
.
append
([
0
,
img_h
-
np
.
random
.
randint
(
thresh
)])
trans
=
WarpMLS
(
src
,
src_pts
,
dst_pts
,
img_w
,
img_h
)
dst
=
trans
.
generate
()
return
dst
ppocr/data/rec/text_image_aug/warp_mls.py
0 → 100644
浏览文件 @
115b5175
# -*- coding:utf-8 -*-
# Author: RubanSeven
# Reference: https://github.com/RubanSeven/Text-Image-Augmentation-python
import
math
import
numpy
as
np
class
WarpMLS
:
def
__init__
(
self
,
src
,
src_pts
,
dst_pts
,
dst_w
,
dst_h
,
trans_ratio
=
1.
):
self
.
src
=
src
self
.
src_pts
=
src_pts
self
.
dst_pts
=
dst_pts
self
.
pt_count
=
len
(
self
.
dst_pts
)
self
.
dst_w
=
dst_w
self
.
dst_h
=
dst_h
self
.
trans_ratio
=
trans_ratio
self
.
grid_size
=
100
self
.
rdx
=
np
.
zeros
((
self
.
dst_h
,
self
.
dst_w
))
self
.
rdy
=
np
.
zeros
((
self
.
dst_h
,
self
.
dst_w
))
@
staticmethod
def
__bilinear_interp
(
x
,
y
,
v11
,
v12
,
v21
,
v22
):
return
(
v11
*
(
1
-
y
)
+
v12
*
y
)
*
(
1
-
x
)
+
(
v21
*
(
1
-
y
)
+
v22
*
y
)
*
x
def
generate
(
self
):
self
.
calc_delta
()
return
self
.
gen_img
()
def
calc_delta
(
self
):
w
=
np
.
zeros
(
self
.
pt_count
,
dtype
=
np
.
float32
)
if
self
.
pt_count
<
2
:
return
i
=
0
while
1
:
if
self
.
dst_w
<=
i
<
self
.
dst_w
+
self
.
grid_size
-
1
:
i
=
self
.
dst_w
-
1
elif
i
>=
self
.
dst_w
:
break
j
=
0
while
1
:
if
self
.
dst_h
<=
j
<
self
.
dst_h
+
self
.
grid_size
-
1
:
j
=
self
.
dst_h
-
1
elif
j
>=
self
.
dst_h
:
break
sw
=
0
swp
=
np
.
zeros
(
2
,
dtype
=
np
.
float32
)
swq
=
np
.
zeros
(
2
,
dtype
=
np
.
float32
)
new_pt
=
np
.
zeros
(
2
,
dtype
=
np
.
float32
)
cur_pt
=
np
.
array
([
i
,
j
],
dtype
=
np
.
float32
)
k
=
0
for
k
in
range
(
self
.
pt_count
):
if
i
==
self
.
dst_pts
[
k
][
0
]
and
j
==
self
.
dst_pts
[
k
][
1
]:
break
w
[
k
]
=
1.
/
(
(
i
-
self
.
dst_pts
[
k
][
0
])
*
(
i
-
self
.
dst_pts
[
k
][
0
])
+
(
j
-
self
.
dst_pts
[
k
][
1
])
*
(
j
-
self
.
dst_pts
[
k
][
1
]))
sw
+=
w
[
k
]
swp
=
swp
+
w
[
k
]
*
np
.
array
(
self
.
dst_pts
[
k
])
swq
=
swq
+
w
[
k
]
*
np
.
array
(
self
.
src_pts
[
k
])
if
k
==
self
.
pt_count
-
1
:
pstar
=
1
/
sw
*
swp
qstar
=
1
/
sw
*
swq
miu_s
=
0
for
k
in
range
(
self
.
pt_count
):
if
i
==
self
.
dst_pts
[
k
][
0
]
and
j
==
self
.
dst_pts
[
k
][
1
]:
continue
pt_i
=
self
.
dst_pts
[
k
]
-
pstar
miu_s
+=
w
[
k
]
*
np
.
sum
(
pt_i
*
pt_i
)
cur_pt
-=
pstar
cur_pt_j
=
np
.
array
([
-
cur_pt
[
1
],
cur_pt
[
0
]])
for
k
in
range
(
self
.
pt_count
):
if
i
==
self
.
dst_pts
[
k
][
0
]
and
j
==
self
.
dst_pts
[
k
][
1
]:
continue
pt_i
=
self
.
dst_pts
[
k
]
-
pstar
pt_j
=
np
.
array
([
-
pt_i
[
1
],
pt_i
[
0
]])
tmp_pt
=
np
.
zeros
(
2
,
dtype
=
np
.
float32
)
tmp_pt
[
0
]
=
np
.
sum
(
pt_i
*
cur_pt
)
*
self
.
src_pts
[
k
][
0
]
-
\
np
.
sum
(
pt_j
*
cur_pt
)
*
self
.
src_pts
[
k
][
1
]
tmp_pt
[
1
]
=
-
np
.
sum
(
pt_i
*
cur_pt_j
)
*
self
.
src_pts
[
k
][
0
]
+
\
np
.
sum
(
pt_j
*
cur_pt_j
)
*
self
.
src_pts
[
k
][
1
]
tmp_pt
*=
(
w
[
k
]
/
miu_s
)
new_pt
+=
tmp_pt
new_pt
+=
qstar
else
:
new_pt
=
self
.
src_pts
[
k
]
self
.
rdx
[
j
,
i
]
=
new_pt
[
0
]
-
i
self
.
rdy
[
j
,
i
]
=
new_pt
[
1
]
-
j
j
+=
self
.
grid_size
i
+=
self
.
grid_size
def
gen_img
(
self
):
src_h
,
src_w
=
self
.
src
.
shape
[:
2
]
dst
=
np
.
zeros_like
(
self
.
src
,
dtype
=
np
.
float32
)
for
i
in
np
.
arange
(
0
,
self
.
dst_h
,
self
.
grid_size
):
for
j
in
np
.
arange
(
0
,
self
.
dst_w
,
self
.
grid_size
):
ni
=
i
+
self
.
grid_size
nj
=
j
+
self
.
grid_size
w
=
h
=
self
.
grid_size
if
ni
>=
self
.
dst_h
:
ni
=
self
.
dst_h
-
1
h
=
ni
-
i
+
1
if
nj
>=
self
.
dst_w
:
nj
=
self
.
dst_w
-
1
w
=
nj
-
j
+
1
di
=
np
.
reshape
(
np
.
arange
(
h
),
(
-
1
,
1
))
dj
=
np
.
reshape
(
np
.
arange
(
w
),
(
1
,
-
1
))
delta_x
=
self
.
__bilinear_interp
(
di
/
h
,
dj
/
w
,
self
.
rdx
[
i
,
j
],
self
.
rdx
[
i
,
nj
],
self
.
rdx
[
ni
,
j
],
self
.
rdx
[
ni
,
nj
])
delta_y
=
self
.
__bilinear_interp
(
di
/
h
,
dj
/
w
,
self
.
rdy
[
i
,
j
],
self
.
rdy
[
i
,
nj
],
self
.
rdy
[
ni
,
j
],
self
.
rdy
[
ni
,
nj
])
nx
=
j
+
dj
+
delta_x
*
self
.
trans_ratio
ny
=
i
+
di
+
delta_y
*
self
.
trans_ratio
nx
=
np
.
clip
(
nx
,
0
,
src_w
-
1
)
ny
=
np
.
clip
(
ny
,
0
,
src_h
-
1
)
nxi
=
np
.
array
(
np
.
floor
(
nx
),
dtype
=
np
.
int32
)
nyi
=
np
.
array
(
np
.
floor
(
ny
),
dtype
=
np
.
int32
)
nxi1
=
np
.
array
(
np
.
ceil
(
nx
),
dtype
=
np
.
int32
)
nyi1
=
np
.
array
(
np
.
ceil
(
ny
),
dtype
=
np
.
int32
)
if
len
(
self
.
src
.
shape
)
==
3
:
x
=
np
.
tile
(
np
.
expand_dims
(
ny
-
nyi
,
axis
=-
1
),
(
1
,
1
,
3
))
y
=
np
.
tile
(
np
.
expand_dims
(
nx
-
nxi
,
axis
=-
1
),
(
1
,
1
,
3
))
else
:
x
=
ny
-
nyi
y
=
nx
-
nxi
dst
[
i
:
i
+
h
,
j
:
j
+
w
]
=
self
.
__bilinear_interp
(
x
,
y
,
self
.
src
[
nyi
,
nxi
],
self
.
src
[
nyi
,
nxi1
],
self
.
src
[
nyi1
,
nxi
],
self
.
src
[
nyi1
,
nxi1
])
dst
=
np
.
clip
(
dst
,
0
,
255
)
dst
=
np
.
array
(
dst
,
dtype
=
np
.
uint8
)
return
dst
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录