Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
d8719969
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d8719969
编写于
2月 22, 2021
作者:
L
littletomatodonkey
提交者:
GitHub
2月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
improve style text infer process (#2055)
* improve style text * fix dead loop
上级
6a42745f
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
81 addition
and
25 deletion
+81
-25
StyleText/engine/predictors.py
StyleText/engine/predictors.py
+24
-1
StyleText/engine/synthesisers.py
StyleText/engine/synthesisers.py
+13
-7
StyleText/engine/text_drawers.py
StyleText/engine/text_drawers.py
+44
-17
未找到文件。
StyleText/engine/predictors.py
浏览文件 @
d8719969
...
@@ -38,7 +38,15 @@ class StyleTextRecPredictor(object):
...
@@ -38,7 +38,15 @@ class StyleTextRecPredictor(object):
self
.
std
=
config
[
"Predictor"
][
"std"
]
self
.
std
=
config
[
"Predictor"
][
"std"
]
self
.
expand_result
=
config
[
"Predictor"
][
"expand_result"
]
self
.
expand_result
=
config
[
"Predictor"
][
"expand_result"
]
def
predict
(
self
,
style_input
,
text_input
):
def reshape_to_same_height(self, img_list):
    """Resize every image in ``img_list`` to the height of the first image.

    Each image after the first is rescaled with ``cv2.resize`` so that its
    height equals ``img_list[0].shape[0]``; the new width is chosen to keep
    the image's width/height ratio, rounded to the nearest pixel.

    Args:
        img_list: list of array images whose ``shape`` starts with
            (height, width). The list is updated in place.

    Returns:
        The same list object, with all entries sharing one height. An empty
        list is returned unchanged.
    """
    if not img_list:
        # Nothing to align; avoids an IndexError on img_list[0].
        return img_list

    target_h = img_list[0].shape[0]
    for idx in range(1, len(img_list)):
        img = img_list[idx]
        img_h, img_w = img.shape[0], img.shape[1]
        if img_h == target_h:
            # Already the right height, so the computed size would equal the
            # current one; skip the redundant resize.
            continue
        # Keep the aspect ratio, but never ask cv2 for a zero-width image,
        # and make sure the size is an int (cv2 rejects float sizes).
        new_w = max(1, int(round(1.0 * img_w / img_h * target_h)))
        img_list[idx] = cv2.resize(img, (new_w, target_h))
    return img_list
def
predict_single_image
(
self
,
style_input
,
text_input
):
style_input
=
self
.
rep_style_input
(
style_input
,
text_input
)
style_input
=
self
.
rep_style_input
(
style_input
,
text_input
)
tensor_style_input
=
self
.
preprocess
(
style_input
)
tensor_style_input
=
self
.
preprocess
(
style_input
)
tensor_text_input
=
self
.
preprocess
(
text_input
)
tensor_text_input
=
self
.
preprocess
(
text_input
)
...
@@ -64,6 +72,21 @@ class StyleTextRecPredictor(object):
...
@@ -64,6 +72,21 @@ class StyleTextRecPredictor(object):
"fake_bg"
:
fake_bg
,
"fake_bg"
:
fake_bg
,
}
}
def predict(self, style_input, text_input_list):
    """Synthesise one result for a single text image or a list of them.

    When ``text_input_list`` is not a list/tuple it is forwarded as a single
    text image to ``predict_single_image`` and that result is returned
    directly. Otherwise every text image is predicted separately with the
    same ``style_input``; then, for each key of the result dicts, the
    per-image outputs are brought to a common height with
    ``reshape_to_same_height`` and concatenated along axis 1.

    Args:
        style_input: style image passed unchanged to ``predict_single_image``.
        text_input_list: a single text image, or a list/tuple of text images.

    Returns:
        A dict with the same keys as the ``predict_single_image`` results.

    Raises:
        ValueError: if ``text_input_list`` is an empty list/tuple (previously
            this surfaced as an UnboundLocalError).
    """
    if not isinstance(text_input_list, (tuple, list)):
        return self.predict_single_image(style_input, text_input_list)

    if len(text_input_list) == 0:
        raise ValueError("text_input_list must contain at least one image")

    synth_result_list = [
        self.predict_single_image(style_input, text_input)
        for text_input in text_input_list
    ]

    # Build a fresh dict instead of overwriting the last per-image result, so
    # the entries of synth_result_list stay untouched while being merged.
    merged_result = {}
    for key in synth_result_list[-1]:
        res = [r[key] for r in synth_result_list]
        res = self.reshape_to_same_height(res)
        merged_result[key] = np.concatenate(res, axis=1)
    return merged_result
def
preprocess
(
self
,
img
):
def
preprocess
(
self
,
img
):
img
=
(
img
.
astype
(
'float32'
)
*
self
.
scale
-
self
.
mean
)
/
self
.
std
img
=
(
img
.
astype
(
'float32'
)
*
self
.
scale
-
self
.
mean
)
/
self
.
std
img_height
,
img_width
,
channel
=
img
.
shape
img_height
,
img_width
,
channel
=
img
.
shape
...
...
StyleText/engine/synthesisers.py
浏览文件 @
d8719969
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
os
import
numpy
as
np
import
cv2
from
utils.config
import
ArgsParser
,
load_config
,
override_config
from
utils.config
import
ArgsParser
,
load_config
,
override_config
from
utils.logging
import
get_logger
from
utils.logging
import
get_logger
...
@@ -36,8 +38,9 @@ class ImageSynthesiser(object):
...
@@ -36,8 +38,9 @@ class ImageSynthesiser(object):
self
.
predictor
=
getattr
(
predictors
,
predictor_method
)(
self
.
config
)
self
.
predictor
=
getattr
(
predictors
,
predictor_method
)(
self
.
config
)
def
synth_image
(
self
,
corpus
,
style_input
,
language
=
"en"
):
def
synth_image
(
self
,
corpus
,
style_input
,
language
=
"en"
):
corpus
,
text_input
=
self
.
text_drawer
.
draw_text
(
corpus
,
language
)
corpus_list
,
text_input_list
=
self
.
text_drawer
.
draw_text
(
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input
)
corpus
,
language
,
style_input_width
=
style_input
.
shape
[
1
])
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input_list
)
return
synth_result
return
synth_result
...
@@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser):
...
@@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser):
for
i
in
range
(
self
.
output_num
):
for
i
in
range
(
self
.
output_num
):
style_data
=
self
.
style_sampler
.
sample
()
style_data
=
self
.
style_sampler
.
sample
()
style_input
=
style_data
[
"image"
]
style_input
=
style_data
[
"image"
]
corpus_language
,
text_input_label
=
self
.
corpus_generator
.
generate
(
corpus_language
,
text_input_label
=
self
.
corpus_generator
.
generate
()
)
text_input_label_list
,
text_input_list
=
self
.
text_drawer
.
draw_text
(
text_input_label
,
text_input
=
self
.
text_drawer
.
draw_text
(
text_input_label
,
text_input_label
,
corpus_language
)
corpus_language
,
style_input_width
=
style_input
.
shape
[
1
])
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input
)
text_input_label
=
""
.
join
(
text_input_label_list
)
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input_list
)
fake_fusion
=
synth_result
[
"fake_fusion"
]
fake_fusion
=
synth_result
[
"fake_fusion"
]
self
.
writer
.
save_image
(
fake_fusion
,
text_input_label
)
self
.
writer
.
save_image
(
fake_fusion
,
text_input_label
)
self
.
writer
.
save_label
()
self
.
writer
.
save_label
()
...
...
StyleText/engine/text_drawers.py
浏览文件 @
d8719969
from
PIL
import
Image
,
ImageDraw
,
ImageFont
from
PIL
import
Image
,
ImageDraw
,
ImageFont
import
numpy
as
np
import
numpy
as
np
import
cv2
from
utils.logging
import
get_logger
from
utils.logging
import
get_logger
...
@@ -28,7 +29,11 @@ class StdTextDrawer(object):
...
@@ -28,7 +29,11 @@ class StdTextDrawer(object):
else
:
else
:
return
int
((
self
.
height
-
4
)
**
2
/
font_height
)
return
int
((
self
.
height
-
4
)
**
2
/
font_height
)
def
draw_text
(
self
,
corpus
,
language
=
"en"
,
crop
=
True
):
def
draw_text
(
self
,
corpus
,
language
=
"en"
,
crop
=
True
,
style_input_width
=
None
):
if
language
not
in
self
.
support_languages
:
if
language
not
in
self
.
support_languages
:
self
.
logger
.
warning
(
self
.
logger
.
warning
(
"language {} not supported, use en instead."
.
format
(
language
))
"language {} not supported, use en instead."
.
format
(
language
))
...
@@ -37,21 +42,43 @@ class StdTextDrawer(object):
...
@@ -37,21 +42,43 @@ class StdTextDrawer(object):
width
=
min
(
self
.
max_width
,
len
(
corpus
)
*
self
.
height
)
+
4
width
=
min
(
self
.
max_width
,
len
(
corpus
)
*
self
.
height
)
+
4
else
:
else
:
width
=
len
(
corpus
)
*
self
.
height
+
4
width
=
len
(
corpus
)
*
self
.
height
+
4
if
style_input_width
is
not
None
:
width
=
min
(
width
,
style_input_width
)
corpus_list
=
[]
text_input_list
=
[]
while
len
(
corpus
)
!=
0
:
bg
=
Image
.
new
(
"RGB"
,
(
width
,
self
.
height
),
color
=
(
127
,
127
,
127
))
bg
=
Image
.
new
(
"RGB"
,
(
width
,
self
.
height
),
color
=
(
127
,
127
,
127
))
draw
=
ImageDraw
.
Draw
(
bg
)
draw
=
ImageDraw
.
Draw
(
bg
)
char_x
=
2
char_x
=
2
font
=
self
.
font_dict
[
language
]
font
=
self
.
font_dict
[
language
]
for
i
,
char_i
in
enumerate
(
corpus
):
i
=
0
while
i
<
len
(
corpus
):
char_i
=
corpus
[
i
]
char_size
=
font
.
getsize
(
char_i
)[
0
]
char_size
=
font
.
getsize
(
char_i
)[
0
]
# split when drawing the next char would reach the image width and index is not 0 (at least 1 char should be written on the image)
if
char_x
+
char_size
>=
width
and
i
!=
0
:
text_input
=
np
.
array
(
bg
).
astype
(
np
.
uint8
)
text_input
=
text_input
[:,
0
:
char_x
,
:]
corpus_list
.
append
(
corpus
[
0
:
i
])
text_input_list
.
append
(
text_input
)
corpus
=
corpus
[
i
:]
break
draw
.
text
((
char_x
,
2
),
char_i
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
draw
.
text
((
char_x
,
2
),
char_i
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
char_x
+=
char_size
char_x
+=
char_size
if
char_x
>=
width
:
corpus
=
corpus
[
0
:
i
+
1
]
self
.
logger
.
warning
(
"corpus length exceed limit: {}"
.
format
(
corpus
))
break
i
+=
1
# the whole text is shorter than style input
if
i
==
len
(
corpus
):
text_input
=
np
.
array
(
bg
).
astype
(
np
.
uint8
)
text_input
=
np
.
array
(
bg
).
astype
(
np
.
uint8
)
text_input
=
text_input
[:,
0
:
char_x
,
:]
text_input
=
text_input
[:,
0
:
char_x
,
:]
return
corpus
,
text_input
corpus_list
.
append
(
corpus
[
0
:
i
])
text_input_list
.
append
(
text_input
)
corpus
=
corpus
[
i
:]
break
return
corpus_list
,
text_input_list
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录