Commit f2d98c5e
Authored on Dec 15, 2020 by weishengyu

add style_text_rec

Parent: b1623d69
Showing 27 changed files with 2,190 additions and 0 deletions (+2190 -0)
tools/style_text_rec/arch/base_module.py              +255  -0
tools/style_text_rec/arch/decoder.py                   +250  -0
tools/style_text_rec/arch/encoder.py                   +185  -0
tools/style_text_rec/arch/spectral_norm.py             +154  -0
tools/style_text_rec/arch/style_text_rec.py            +288  -0
tools/style_text_rec/configs/config.yml                +54   -0
tools/style_text_rec/configs/dataset_config.yml        +64   -0
tools/style_text_rec/engine/corpus_generators.py       +54   -0
tools/style_text_rec/engine/predictors.py              +115  -0
tools/style_text_rec/engine/style_samplers.py          +62   -0
tools/style_text_rec/engine/synthesisers.py            +58   -0
tools/style_text_rec/engine/text_drawers.py            +58   -0
tools/style_text_rec/engine/writers.py                 +71   -0
tools/style_text_rec/examples/corpus/example.txt       +2    -0
tools/style_text_rec/examples/image_list.txt           +2    -0
tools/style_text_rec/examples/style_images/1.jpg       +0    -0
tools/style_text_rec/examples/style_images/2.jpg       +0    -0
tools/style_text_rec/fonts/ch_standard.ttf             +0    -0
tools/style_text_rec/fonts/en_standard.ttf             +0    -0
tools/style_text_rec/fonts/ko_standard.ttf             +0    -0
tools/style_text_rec/tools/synth_dataset.py            +10   -0
tools/style_text_rec/tools/synth_image.py              +78   -0
tools/style_text_rec/utils/config.py                   +219  -0
tools/style_text_rec/utils/load_params.py              +33   -0
tools/style_text_rec/utils/logging.py                  +66   -0
tools/style_text_rec/utils/math_functions.py           +45   -0
tools/style_text_rec/utils/sys_funcs.py                +67   -0
tools/style_text_rec/arch/base_module.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import paddle
import paddle.nn as nn

from arch.spectral_norm import spectral_norm


class CBN(nn.Layer):
    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(CBN, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._conv = paddle.nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._conv(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


class SNConv(nn.Layer):
    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConv, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._sn_conv = spectral_norm(
            paddle.nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                weight_attr=paddle.ParamAttr(name=name + "_weights"),
                bias_attr=bias_attr))
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._sn_conv(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


class SNConvTranspose(nn.Layer):
    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConvTranspose, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._sn_conv_transpose = spectral_norm(
            paddle.nn.Conv2DTranspose(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                dilation=dilation,
                groups=groups,
                weight_attr=paddle.ParamAttr(name=name + "_weights"),
                bias_attr=bias_attr))
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._sn_conv_transpose(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


class MiddleNet(nn.Layer):
    def __init__(self, name, in_channels, mid_channels, out_channels,
                 use_bias):
        super(MiddleNet, self).__init__()
        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            use_bias=use_bias,
            norm_layer=None,
            act=None)
        self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            use_bias=use_bias)
        self._sn_conv3 = SNConv(
            name=name + "_sn_conv3",
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_bias=use_bias)

    def forward(self, x):
        sn_conv1 = self._sn_conv1.forward(x)
        pad_2d = self._pad2d.forward(sn_conv1)
        sn_conv2 = self._sn_conv2.forward(pad_2d)
        sn_conv3 = self._sn_conv3.forward(sn_conv2)
        return sn_conv3


class ResBlock(nn.Layer):
    def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
                 use_bias):
        super(ResBlock, self).__init__()
        if use_dilation:
            padding_mat = [1, 1, 1, 1]
        else:
            padding_mat = [0, 0, 0, 0]
        self._pad1 = nn.Pad2D(padding_mat, mode="replicate")

        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=0,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)
        if use_dropout:
            self._dropout = nn.Dropout(0.5)
        else:
            self._dropout = None
        self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)

    def forward(self, x):
        pad1 = self._pad1.forward(x)
        sn_conv1 = self._sn_conv1.forward(pad1)
        pad2 = self._pad2.forward(sn_conv1)
        sn_conv2 = self._sn_conv2.forward(pad2)
        return sn_conv2 + x
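All the generator sub-nets in this change are assembled from these few blocks: SNConv and SNConvTranspose wrap a (transposed) convolution in the spectral-norm hook from arch/spectral_norm.py and optionally attach a normalization layer and activation looked up by name from paddle.nn, while ResBlock is two such 3x3 convolutions with replicate padding plus a skip connection. A minimal smoke-test sketch, assuming Paddle 2.x is installed and the working directory is tools/style_text_rec so that the arch package is importable:

import paddle

from arch.base_module import SNConv, ResBlock

# A spectral-normalized 3x3 conv followed by a residual block; spatial size is preserved.
conv = SNConv(name="demo_conv", in_channels=3, out_channels=8, kernel_size=3, padding=1)
block = ResBlock(name="demo_block", channels=8, norm_layer=None,
                 use_dropout=False, use_dilation=True, use_bias=True)

x = paddle.rand([1, 3, 32, 320])   # NCHW, the 32x320 size this tool works with
y = block(conv(x))
print(y.shape)                     # [1, 8, 32, 320]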
tools/style_text_rec/arch/decoder.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.base_module import SNConv, SNConvTranspose, ResBlock


class Decoder(nn.Layer):
    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(Decoder, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 8,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self.conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 2,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x):
        if isinstance(x, (list, tuple)):
            x = paddle.concat(x, axis=1)
        output_dict = dict()
        output_dict["conv_blocks"] = self.conv_blocks.forward(x)
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(output_dict["up1"])
        output_dict["up3"] = self._up3.forward(output_dict["up2"])
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict


class DecoderUnet(nn.Layer):
    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(DecoderUnet, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 8,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, y, feature2, feature1):
        output_dict = dict()
        output_dict["conv_blocks"] = self._conv_blocks(
            paddle.concat(
                (x, y), axis=1))
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(
            paddle.concat(
                (output_dict["up1"], feature2), axis=1))
        output_dict["up3"] = self._up3.forward(
            paddle.concat(
                (output_dict["up2"], feature1), axis=1))
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict


class SingleDecoder(nn.Layer):
    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(SingleDecoder, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 4,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, feature2, feature1):
        output_dict = dict()
        output_dict["conv_blocks"] = self._conv_blocks.forward(x)
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(
            paddle.concat(
                (output_dict["up1"], feature2), axis=1))
        output_dict["up3"] = self._up3.forward(
            paddle.concat(
                (output_dict["up2"], feature1), axis=1))
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict
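Each decoder mirrors the Encoder in arch/encoder.py: the encoder downsamples three times with stride-2 convolutions, and every decoder upsamples three times with stride-2 transposed convolutions (output_padding=1) before a final replicate-padded 3x3 output convolution, so the original spatial size is restored. A rough, purely illustrative shape trace for Decoder with encode_dim=64 on a 32x320 input:

# Illustrative only: feature-map shapes through Decoder(encode_dim=64, ...)
#   x = concat(text_feature, style_feature) -> [N, 512, 4, 40]    # 64 * 8 channels, H/8 x W/8
#   conv_blocks (ResBlocks, stride 1)       -> [N, 512, 4, 40]
#   _up1 (stride 2, output_padding 1)       -> [N, 256, 8, 80]
#   _up2                                    -> [N, 128, 16, 160]
#   _up3                                    -> [N,  64, 32, 320]
#   _pad2d + _out_conv (3x3)                -> [N, out_channels, 32, 320]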
tools/style_text_rec/arch/encoder.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.base_module import SNConv, SNConvTranspose, ResBlock


class Encoder(nn.Layer):
    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation):
        super(Encoder, self).__init__()
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = SNConv(
            name=name + "_down1",
            in_channels=encode_dim,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down2 = SNConv(
            name=name + "_down2",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down3 = SNConv(
            name=name + "_down3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 4,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)

    def forward(self, x):
        out_dict = dict()
        x = self._pad2d(x)
        out_dict["in_conv"] = self._in_conv.forward(x)
        out_dict["down1"] = self._down1.forward(out_dict["in_conv"])
        out_dict["down2"] = self._down2.forward(out_dict["down1"])
        out_dict["down3"] = self._down3.forward(out_dict["down2"])
        out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"])
        return out_dict


class EncoderUnet(nn.Layer):
    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr):
        super(EncoderUnet, self).__init__()
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = SNConv(
            name=name + "_down1",
            in_channels=encode_dim,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down2 = SNConv(
            name=name + "_down2",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down3 = SNConv(
            name=name + "_down3",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down4 = SNConv(
            name=name + "_down4",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)

    def forward(self, x):
        output_dict = dict()
        x = self._pad2d(x)
        output_dict['in_conv'] = self._in_conv.forward(x)
        output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
        output_dict['down2'] = self._down2.forward(output_dict['down1'])
        output_dict['down3'] = self._down3.forward(output_dict['down2'])
        output_dict['down4'] = self._down4.forward(output_dict['down3'])
        output_dict['up1'] = self._up1.forward(output_dict['down4'])
        output_dict['up2'] = self._up2.forward(
            paddle.concat(
                (output_dict['down3'], output_dict['up1']), axis=1))
        output_dict['concat'] = paddle.concat(
            (output_dict['down2'], output_dict['up2']), axis=1)
        return output_dict
tools/style_text_rec/arch/spectral_norm.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def normal_(x, mean=0., std=1.):
    temp_value = paddle.normal(mean, std, shape=x.shape)
    x.set_value(temp_value)
    return x


class SpectralNorm(object):
    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(
                                 n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # transpose dim to front
            weight_mat = weight_mat.transpose([
                self.dim,
                * [d for d in range(weight_mat.dim()) if d != self.dim]
            ])

        height = weight_mat.shape[0]

        return weight_mat.reshape([height, -1])

    def compute_weight(self, module, do_power_iteration):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with paddle.no_grad():
                for _ in range(self.n_power_iterations):
                    v.set_value(
                        F.normalize(
                            paddle.matmul(
                                weight_mat,
                                u,
                                transpose_x=True,
                                transpose_y=False),
                            axis=0,
                            epsilon=self.eps, ))

                    u.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat, v),
                            axis=0,
                            epsilon=self.eps, ))
                if self.n_power_iterations > 0:
                    u = u.clone()
                    v = v.clone()

        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        with paddle.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')

        module.add_parameter(self.name, weight.detach())

    def __call__(self, module, inputs):
        setattr(
            module,
            self.name,
            self.compute_weight(
                module, do_power_iteration=module.training))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with paddle.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.shape

            # randomly initialize u and v
            u = module.create_parameter([h])
            u = normal_(u, 0., 1.)
            v = module.create_parameter([w])
            v = normal_(v, 0., 1.)
            u = F.normalize(u, axis=0, epsilon=fn.eps)
            v = F.normalize(v, axis=0, epsilon=fn.eps)

        # delete fn.name from parameters, otherwise you can not set attribute
        del module._parameters[fn.name]
        module.add_parameter(fn.name + "_orig", weight)
        # still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be a Parameter and
        # gets added as a parameter. Instead, we register weight * 1.0 as a plain
        # attribute.
        setattr(module, fn.name, weight * 1.0)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        return fn


def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    if dim is None:
        if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
                               nn.Conv3DTranspose, nn.Linear)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
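This is a dygraph re-implementation of spectral normalization as a forward pre-hook, much in the spirit of torch.nn.utils.spectral_norm: the original weight is kept as weight_orig, the u/v power-iteration vectors are registered as buffers, and every forward pass in training mode refines them and rescales the weight by the estimated largest singular value sigma. A minimal usage sketch, assuming Paddle 2.x:

import paddle
import paddle.nn as nn

from arch.spectral_norm import spectral_norm

# Wrap a plain Conv2D; the hook adds weight_orig, weight_u and weight_v.
conv = spectral_norm(nn.Conv2D(3, 8, kernel_size=3, padding=1))
y = conv(paddle.rand([1, 3, 32, 32]))   # one power iteration runs before this forward
print(y.shape)                          # [1, 8, 32, 32]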
tools/style_text_rec/arch/style_text_rec.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import math
import paddle
import paddle.nn as nn

from arch.base_module import MiddleNet, ResBlock
from arch.encoder import Encoder
from arch.decoder import Decoder, DecoderUnet, SingleDecoder
from utils.load_params import load_dygraph_pretrain
from utils.logging import get_logger


class StyleTextRec(nn.Layer):
    def __init__(self, config):
        super(StyleTextRec, self).__init__()
        self.logger = get_logger()
        self.text_generator = TextGenerator(config["Predictor"][
            "text_generator"])
        self.bg_generator = BgGeneratorWithMask(config["Predictor"][
            "bg_generator"])
        self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
            "fusion_generator"])
        bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
        text_generator_pretrain = config["Predictor"]["text_generator"][
            "pretrain"]
        fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
            "pretrain"]
        load_dygraph_pretrain(
            self.bg_generator,
            self.logger,
            path=bg_generator_pretrain,
            load_static_weights=False)
        load_dygraph_pretrain(
            self.text_generator,
            self.logger,
            path=text_generator_pretrain,
            load_static_weights=False)
        load_dygraph_pretrain(
            self.fusion_generator,
            self.logger,
            path=fusion_generator_pretrain,
            load_static_weights=False)

    def forward(self, style_input, text_input):
        text_gen_output = self.text_generator.forward(style_input, text_input)
        fake_text = text_gen_output["fake_text"]
        fake_sk = text_gen_output["fake_sk"]
        bg_gen_output = self.bg_generator.forward(style_input)
        bg_encode_feature = bg_gen_output["bg_encode_feature"]
        bg_decode_feature1 = bg_gen_output["bg_decode_feature1"]
        bg_decode_feature2 = bg_gen_output["bg_decode_feature2"]
        fake_bg = bg_gen_output["fake_bg"]

        fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
        fake_fusion = fusion_gen_output["fake_fusion"]
        return {
            "fake_fusion": fake_fusion,
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }


class TextGenerator(nn.Layer):
    def __init__(self, config):
        super(TextGenerator, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False
        self.encoder_text = Encoder(
            name=name + "_encoder_text",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.encoder_style = Encoder(
            name=name + "_encoder_style",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.decoder_text = Decoder(
            name=name + "_decoder_text",
            encode_dim=encode_dim,
            out_channels=int(encode_dim / 2),
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)
        self.decoder_sk = Decoder(
            name=name + "_decoder_sk",
            encode_dim=encode_dim,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=int(encode_dim / 2) + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input, text_input):
        style_feature = self.encoder_style.forward(style_input)["res_blocks"]
        text_feature = self.encoder_text.forward(text_input)["res_blocks"]
        fake_c_temp = self.decoder_text.forward([text_feature,
                                                 style_feature])["out_conv"]
        fake_sk = self.decoder_sk.forward([text_feature,
                                           style_feature])["out_conv"]
        fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
        return {"fake_sk": fake_sk, "fake_text": fake_text}


class BgGeneratorWithMask(nn.Layer):
    def __init__(self, config):
        super(BgGeneratorWithMask, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        self.output_factor = config.get("output_factor", 1.0)

        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False

        self.encoder_bg = Encoder(
            name=name + "_encoder_bg",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.decoder_bg = SingleDecoder(
            name=name + "_decoder_bg",
            encode_dim=encode_dim,
            out_channels=3,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)
        self.decoder_mask = Decoder(
            name=name + "_decoder_mask",
            encode_dim=encode_dim // 2,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=3 + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input):
        encode_bg_output = self.encoder_bg(style_input)
        decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
                                           encode_bg_output["down2"],
                                           encode_bg_output["down1"])

        fake_c_temp = decode_bg_output["out_conv"]
        fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
            "res_blocks"])["out_conv"]
        fake_bg = self.middle(
            paddle.concat(
                (fake_c_temp, fake_bg_mask), axis=1))
        return {
            "bg_encode_feature": encode_bg_output["res_blocks"],
            "bg_decode_feature1": decode_bg_output["up1"],
            "bg_decode_feature2": decode_bg_output["up2"],
            "fake_bg": fake_bg,
            "fake_bg_mask": fake_bg_mask,
        }


class FusionGeneratorSimple(nn.Layer):
    def __init__(self, config):
        super(FusionGeneratorSimple, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False

        self._conv = nn.Conv2D(
            in_channels=6,
            out_channels=encode_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
            bias_attr=False)

        self._res_block = ResBlock(
            name="{}_conv_block".format(name),
            channels=encode_dim,
            norm_layer=norm_layer,
            use_dropout=conv_block_dropout,
            use_dilation=conv_block_dilation,
            use_bias=use_bias)

        self._reduce_conv = nn.Conv2D(
            in_channels=encode_dim,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
            bias_attr=False)

    def forward(self, fake_text, fake_bg):
        fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
        fake_concat_tmp = self._conv(fake_concat)
        output_res = self._res_block(fake_concat_tmp)
        fake_fusion = self._reduce_conv(output_res)
        return {"fake_fusion": fake_fusion}
tools/style_text_rec/configs/config.yml
0 → 100644
Global:
  output_num: 10
  output_dir: output_data
  use_gpu: false
  image_height: 32
  image_width: 320
TextDrawer:
  fonts:
    en: fonts/en_standard.ttf
    ch: fonts/ch_standard.ttf
    ko: fonts/ko_standard.ttf
Predictor:
  method: StyleTextRecPredictor
  algorithm: StyleTextRec
  scale: 0.00392156862745098
  mean:
    - 0.5
    - 0.5
    - 0.5
  std:
    - 0.5
    - 0.5
    - 0.5
  expand_result: false
  bg_generator:
    pretrain: style_text_models/bg_generator
    module_name: bg_generator
    generator_type: BgGeneratorWithMask
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
    output_factor: 1.05
  text_generator:
    pretrain: style_text_models/text_generator
    module_name: text_generator
    generator_type: TextGenerator
    encode_dim: 64
    norm_layer: InstanceNorm2D
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
  fusion_generator:
    pretrain: style_text_models/fusion_generator
    module_name: fusion_generator
    generator_type: FusionGeneratorSimple
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
Writer:
  method: SimpleWriter
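The Predictor scale/mean/std values implement a plain [0, 255] to [-1, 1] mapping (scale is 1/255 written out), which is exactly what StyleTextRecPredictor.preprocess applies before feeding the generators and what postprocess inverts; it also matches the Tanh output activation of the text and background decoders. A quick arithmetic check, illustrative only:

# scale = 0.00392156862745098 == 1 / 255
scale, mean, std = 1 / 255.0, 0.5, 0.5
dark, bright = (0 * scale - mean) / std, (255 * scale - mean) / std
print(dark, bright)   # -1.0 1.0, the range of the generators' Tanh outputs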
tools/style_text_rec/configs/dataset_config.yml
0 → 100644
Global:
  output_num: 10
  output_dir: output_data
  use_gpu: false
  image_height: 32
  image_width: 320
  standard_font: fonts/en_standard.ttf
TextDrawer:
  fonts:
    en: fonts/en_standard.ttf
    ch: fonts/ch_standard.ttf
    ko: fonts/ko_standard.ttf
StyleSampler:
  method: DatasetSampler
  image_home: examples
  label_file: examples/image_list.txt
  with_label: true
CorpusGenerator:
  method: FileCorpus
  language: ch
  corpus_file: examples/corpus/example.txt
Predictor:
  method: StyleTextRecPredictor
  algorithm: StyleTextRec
  scale: 0.00392156862745098
  mean:
    - 0.5
    - 0.5
    - 0.5
  std:
    - 0.5
    - 0.5
    - 0.5
  expand_result: false
  bg_generator:
    pretrain: style_text_models/bg_generator
    module_name: bg_generator
    generator_type: BgGeneratorWithMask
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
    output_factor: 1.05
  text_generator:
    pretrain: style_text_models/text_generator
    module_name: text_generator
    generator_type: TextGenerator
    encode_dim: 64
    norm_layer: InstanceNorm2D
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
  fusion_generator:
    pretrain: style_text_models/fusion_generator
    module_name: fusion_generator
    generator_type: FusionGeneratorSimple
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
Writer:
  method: SimpleWriter
tools/style_text_rec/engine/corpus_generators.py
0 → 100644
import random

from PIL import Image, ImageDraw, ImageFont
import numpy as np

from utils.logging import get_logger


class FileCorpus(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.logger.info("using FileCorpus")

        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

        corpus_file = config["CorpusGenerator"]["corpus_file"]
        self.language = config["CorpusGenerator"]["language"]
        with open(corpus_file, 'r') as f:
            corpus_raw = f.read()
        self.corpus_list = corpus_raw.split("\n")[:-1]
        assert len(self.corpus_list) > 0
        random.shuffle(self.corpus_list)
        self.index = 0

    def generate(self, corpus_length=0):
        if self.index >= len(self.corpus_list):
            self.index = 0
            random.shuffle(self.corpus_list)
        corpus = self.corpus_list[self.index]
        if corpus_length != 0:
            corpus = corpus[0:corpus_length]
        if corpus_length > len(corpus):
            self.logger.warning("generated corpus is shorter than expected.")
        self.index += 1
        return self.language, corpus


class EnNumCorpus(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.logger.info("using NumberCorpus")
        self.num_list = "0123456789"
        self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = config["Global"]["image_height"]
        self.max_width = config["Global"]["image_width"]

    def generate(self, corpus_length=0):
        corpus = ""
        if corpus_length == 0:
            corpus_length = random.randint(5, 15)
        for i in range(corpus_length):
            if random.random() < 0.2:
                corpus += "{}".format(random.choice(self.en_char_list))
            else:
                corpus += "{}".format(random.choice(self.num_list))
        return "en", corpus
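FileCorpus cycles through a shuffled corpus file and returns (language, text) pairs, while EnNumCorpus synthesizes random alphanumeric strings instead. A minimal sketch of driving FileCorpus directly with the bundled example corpus, assuming the script is run from tools/style_text_rec so the engine and utils packages are importable:

from engine.corpus_generators import FileCorpus

# Only the CorpusGenerator section of the config is needed here.
config = {
    "CorpusGenerator": {
        "corpus_file": "examples/corpus/example.txt",
        "language": "ch",
    }
}
corpus_gen = FileCorpus(config)
language, text = corpus_gen.generate()
print(language, text)   # e.g. "ch 飞桨文字识别"; the corpus order is shuffled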
tools/style_text_rec/engine/predictors.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import math
import paddle

from arch import style_text_rec
from utils.sys_funcs import check_gpu
from utils.logging import get_logger


class StyleTextRecPredictor(object):
    def __init__(self, config):
        algorithm = config['Predictor']['algorithm']
        assert algorithm in ["StyleTextRec"
                             ], "Generator {} not supported.".format(algorithm)
        use_gpu = config["Global"]['use_gpu']
        check_gpu(use_gpu)
        self.logger = get_logger()
        self.generator = getattr(style_text_rec, algorithm)(config)
        self.height = config["Global"]["image_height"]
        self.width = config["Global"]["image_width"]
        self.scale = config["Predictor"]["scale"]
        self.mean = config["Predictor"]["mean"]
        self.std = config["Predictor"]["std"]
        self.expand_result = config["Predictor"]["expand_result"]

    def predict(self, style_input, text_input):
        style_input = self.rep_style_input(style_input, text_input)
        tensor_style_input = self.preprocess(style_input)
        tensor_text_input = self.preprocess(text_input)
        style_text_result = self.generator.forward(tensor_style_input,
                                                   tensor_text_input)
        fake_fusion = self.postprocess(style_text_result["fake_fusion"])
        fake_text = self.postprocess(style_text_result["fake_text"])
        fake_sk = self.postprocess(style_text_result["fake_sk"])
        fake_bg = self.postprocess(style_text_result["fake_bg"])
        bbox = self.get_text_boundary(fake_text)
        if bbox:
            left, right, top, bottom = bbox
            fake_fusion = fake_fusion[top:bottom, left:right, :]
            fake_text = fake_text[top:bottom, left:right, :]
            fake_sk = fake_sk[top:bottom, left:right, :]
            fake_bg = fake_bg[top:bottom, left:right, :]

        # fake_fusion = self.crop_by_text(img_fake_fusion, img_fake_text)
        return {
            "fake_fusion": fake_fusion,
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }

    def preprocess(self, img):
        img = (img.astype('float32') * self.scale - self.mean) / self.std
        img_height, img_width, channel = img.shape
        assert channel == 3, "Please use an rgb image."
        ratio = img_width / float(img_height)
        if math.ceil(self.height * ratio) > self.width:
            resized_w = self.width
        else:
            resized_w = int(math.ceil(self.height * ratio))
        img = cv2.resize(img, (resized_w, self.height))

        new_img = np.zeros([self.height, self.width, 3]).astype('float32')
        new_img[:, 0:resized_w, :] = img
        img = new_img.transpose((2, 0, 1))
        img = img[np.newaxis, :, :, :]
        return paddle.to_tensor(img)

    def postprocess(self, tensor):
        img = tensor.numpy()[0]
        img = img.transpose((1, 2, 0))
        img = (img * self.std + self.mean) / self.scale
        img = np.maximum(img, 0.0)
        img = np.minimum(img, 255.0)
        img = img.astype('uint8')
        return img

    def rep_style_input(self, style_input, text_input):
        rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) /
                      (style_input.shape[1] / style_input.shape[0])) + 1
        style_input = np.tile(style_input, reps=[1, rep_num, 1])
        max_width = int(self.width / self.height * style_input.shape[0])
        style_input = style_input[:, :max_width, :]
        return style_input

    def get_text_boundary(self, text_img):
        img_height = text_img.shape[0]
        img_width = text_img.shape[1]
        bounder = 3
        text_canny_img = cv2.Canny(text_img, 10, 20)
        edge_num_h = text_canny_img.sum(axis=0)
        no_zero_list_h = np.where(edge_num_h > 0)[0]
        edge_num_w = text_canny_img.sum(axis=1)
        no_zero_list_w = np.where(edge_num_w > 0)[0]
        if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0:
            return None
        left = max(no_zero_list_h[0] - bounder, 0)
        right = min(no_zero_list_h[-1] + bounder, img_width)
        top = max(no_zero_list_w[0] - bounder, 0)
        bottom = min(no_zero_list_w[-1] + bounder, img_height)
        return [left, right, top, bottom]
tools/style_text_rec/engine/style_samplers.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import cv2


class DatasetSampler(object):
    def __init__(self, config):
        self.image_home = config["StyleSampler"]["image_home"]
        label_file = config["StyleSampler"]["label_file"]
        self.dataset_with_label = config["StyleSampler"]["with_label"]
        self.height = config["Global"]["image_height"]
        self.index = 0
        with open(label_file, "r") as f:
            label_raw = f.read()
            self.path_label_list = label_raw.split("\n")[:-1]
        assert len(self.path_label_list) > 0
        random.shuffle(self.path_label_list)

    def sample(self):
        if self.index >= len(self.path_label_list):
            random.shuffle(self.path_label_list)
            self.index = 0
        if self.dataset_with_label:
            path_label = self.path_label_list[self.index]
            rel_image_path, label = path_label.split('\t')
        else:
            rel_image_path = self.path_label_list[self.index]
            label = None
        img_path = "{}/{}".format(self.image_home, rel_image_path)
        image = cv2.imread(img_path)
        origin_height = image.shape[0]
        ratio = self.height / origin_height
        width = int(image.shape[1] * ratio)
        height = int(image.shape[0] * ratio)
        image = cv2.resize(image, (width, height))

        self.index += 1
        if label:
            return {"image": image, "label": label}
        else:
            return {"image": image}


def duplicate_image(image, width):
    image_width = image.shape[1]
    dup_num = width // image_width + 1
    image = np.tile(image, reps=[1, dup_num, 1])
    cropped_image = image[:, :width, :]
    return cropped_image
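DatasetSampler rescales each style image to the configured image_height while keeping its aspect ratio; duplicate_image is a helper that tiles an image horizontally until it covers a target width and then crops it. An illustrative sketch of the helper:

import numpy as np

from engine.style_samplers import duplicate_image

patch = np.zeros([32, 100, 3], dtype=np.uint8)   # a 32x100 style crop
wide = duplicate_image(patch, 320)               # tiled 4x along the width, then cut to 320
print(wide.shape)                                # (32, 320, 3)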
tools/style_text_rec/engine/synthesisers.py
0 → 100644
import os

from utils.config import ArgsParser, load_config, override_config
from utils.logging import get_logger
from engine import style_samplers, corpus_generators, text_drawers, predictors, writers


class ImageSynthesiser(object):
    def __init__(self):
        self.FLAGS = ArgsParser().parse_args()
        self.config = load_config(self.FLAGS.config)
        self.config = override_config(self.config, options=self.FLAGS.override)
        self.output_dir = self.config["Global"]["output_dir"]
        if not os.path.exists(self.output_dir):
            os.mkdir(self.output_dir)
        self.logger = get_logger(
            log_file='{}/predict.log'.format(self.output_dir))

        self.text_drawer = text_drawers.StdTextDrawer(self.config)

        predictor_method = self.config["Predictor"]["method"]
        assert predictor_method is not None
        self.predictor = getattr(predictors, predictor_method)(self.config)

    def synth_image(self, corpus, style_input, language="en"):
        corpus, text_input = self.text_drawer.draw_text(corpus, language)
        synth_result = self.predictor.predict(style_input, text_input)
        return synth_result


class DatasetSynthesiser(ImageSynthesiser):
    def __init__(self):
        super(DatasetSynthesiser, self).__init__()
        self.tag = self.FLAGS.tag
        self.output_num = self.config["Global"]["output_num"]
        corpus_generator_method = self.config["CorpusGenerator"]["method"]
        self.corpus_generator = getattr(corpus_generators,
                                        corpus_generator_method)(self.config)

        style_sampler_method = self.config["StyleSampler"]["method"]
        assert style_sampler_method is not None
        self.style_sampler = style_samplers.DatasetSampler(self.config)
        self.writer = writers.SimpleWriter(self.config, self.tag)

    def synth_dataset(self):
        for i in range(self.output_num):
            style_data = self.style_sampler.sample()
            style_input = style_data["image"]
            corpus_language, text_input_label = self.corpus_generator.generate()
            text_input_label, text_input = self.text_drawer.draw_text(
                text_input_label, corpus_language)

            synth_result = self.predictor.predict(style_input, text_input)
            fake_fusion = synth_result["fake_fusion"]
            self.writer.save_image(fake_fusion, text_input_label)
        self.writer.save_label()
        self.writer.merge_label()
tools/style_text_rec/engine/text_drawers.py
0 → 100644
import random

from PIL import Image, ImageDraw, ImageFont
import numpy as np

from utils.logging import get_logger


class StdTextDrawer(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.max_width = config["Global"]["image_width"]
        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = config["Global"]["image_height"]
        self.font_dict = {}
        self.load_fonts(config["TextDrawer"]["fonts"])
        self.support_languages = list(self.font_dict)

    def load_fonts(self, fonts_config):
        for language in fonts_config:
            font_path = fonts_config[language]
            font_height = self.get_valid_height(font_path)
            font = ImageFont.truetype(font_path, font_height)
            self.font_dict[language] = font

    def get_valid_height(self, font_path):
        font = ImageFont.truetype(font_path, self.height - 4)
        _, font_height = font.getsize(self.char_list)
        if font_height <= self.height - 4:
            return self.height - 4
        else:
            return int((self.height - 4)**2 / font_height)

    def draw_text(self, corpus, language="en", crop=True):
        if language not in self.support_languages:
            self.logger.warning(
                "language {} not supported, use en instead.".format(language))
            language = "en"
        if crop:
            width = min(self.max_width, len(corpus) * self.height) + 4
        else:
            width = len(corpus) * self.height + 4
        bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
        draw = ImageDraw.Draw(bg)

        char_x = 2
        font = self.font_dict[language]
        for i, char_i in enumerate(corpus):
            char_size = font.getsize(char_i)[0]
            draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
            char_x += char_size
            if char_x >= width:
                corpus = corpus[0:i + 1]
                self.logger.warning("corpus length exceed limit: {}".format(
                    corpus))
                break

        text_input = np.array(bg).astype(np.uint8)
        text_input = text_input[:, 0:char_x, :]
        return corpus, text_input
tools/style_text_rec/engine/writers.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import glob

from utils.logging import get_logger


class SimpleWriter(object):
    def __init__(self, config, tag):
        self.logger = get_logger()
        self.output_dir = config["Global"]["output_dir"]
        self.counter = 0
        self.label_dict = {}
        self.tag = tag
        self.label_file_index = 0

    def save_image(self, image, text_input_label):
        image_home = os.path.join(self.output_dir, "images", self.tag)
        if not os.path.exists(image_home):
            os.makedirs(image_home)

        image_path = os.path.join(image_home, "{}.png".format(self.counter))
        # todo support continue synth
        cv2.imwrite(image_path, image)
        self.logger.info("generate image: {}".format(image_path))

        image_name = os.path.join(self.tag, "{}.png".format(self.counter))
        self.label_dict[image_name] = text_input_label

        self.counter += 1
        if not self.counter % 100:
            self.save_label()

    def save_label(self):
        label_raw = ""
        label_home = os.path.join(self.output_dir, "label")
        if not os.path.exists(label_home):
            os.mkdir(label_home)
        for image_path in self.label_dict:
            label = self.label_dict[image_path]
            label_raw += "{}\t{}\n".format(image_path, label)
        label_file_path = os.path.join(label_home,
                                       "{}_label.txt".format(self.tag))
        with open(label_file_path, "w") as f:
            f.write(label_raw)
        self.label_file_index += 1

    def merge_label(self):
        label_raw = ""
        label_file_regex = os.path.join(self.output_dir, "label",
                                        "*_label.txt")
        label_file_list = glob.glob(label_file_regex)
        for label_file_i in label_file_list:
            with open(label_file_i, "r") as f:
                label_raw += f.read()
        label_file_path = os.path.join(self.output_dir, "label.txt")
        with open(label_file_path, "w") as f:
            f.write(label_raw)
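SimpleWriter saves each synthesized image under <output_dir>/images/<tag>/, flushes the accumulated labels to <output_dir>/label/<tag>_label.txt every 100 images, and merge_label() concatenates the per-tag label files into <output_dir>/label.txt. Each line uses the tab-separated image-path/label layout used elsewhere in PaddleOCR recognition labels; with a hypothetical tag "synth" and the example corpus, the merged file would look roughly like:

synth/0.png	PaddleOCR
synth/1.png	飞桨文字识别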
tools/style_text_rec/examples/corpus/example.txt
0 → 100644
PaddleOCR
飞桨文字识别
tools/style_text_rec/examples/image_list.txt
0 → 100644
style_images/1.jpg NEATNESS
style_images/2.jpg 锁店君和宾馆
tools/style_text_rec/examples/style_images/1.jpg
0 → 100644
2.5 KB
tools/style_text_rec/examples/style_images/2.jpg
0 → 100644
3.8 KB
tools/style_text_rec/fonts/ch_standard.ttf
0 → 100755
Binary file added
tools/style_text_rec/fonts/en_standard.ttf
0 → 100755
Binary file added
tools/style_text_rec/fonts/ko_standard.ttf
0 → 100755
Binary file added
tools/style_text_rec/tools/synth_dataset.py
0 → 100644
from engine.synthesisers import DatasetSynthesiser


def synth_dataset():
    dataset_synthesiser = DatasetSynthesiser()
    dataset_synthesiser.synth_dataset()


if __name__ == '__main__':
    synth_dataset()
tools/style_text_rec/tools/synth_image.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import sys
import glob

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from engine.synthesisers import ImageSynthesiser


def synth_image():
    image_synthesiser = ImageSynthesiser()
    img = cv2.imread("examples/style_images/1.jpg")
    corpus = "PaddleOCR"
    language = "en"

    synth_result = image_synthesiser.synth_image(corpus, img, language)
    fake_fusion = synth_result["fake_fusion"]
    fake_text = synth_result["fake_text"]
    fake_bg = synth_result["fake_bg"]
    cv2.imwrite("fake_fusion.jpg", fake_fusion)
    cv2.imwrite("fake_text.jpg", fake_text)
    cv2.imwrite("fake_bg.jpg", fake_bg)


def batch_synth_images():
    image_synthesiser = ImageSynthesiser()
    corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
    style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
    save_path = "./output_data/"
    corpus_list = []
    with open(corpus_file, "rb") as fin:
        lines = fin.readlines()
        for line in lines:
            substr = line.decode("utf-8").strip("\n").split("\t")
            corpus_list.append(substr)
    style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))

    corpus_num = len(corpus_list)
    style_img_num = len(style_img_list)
    for cno in range(corpus_num):
        for sno in range(style_img_num):
            corpus, lang = corpus_list[cno]
            style_img_path = style_img_list[sno]
            img = cv2.imread(style_img_path)
            synth_result = image_synthesiser.synth_image(corpus, img, lang)
            fake_fusion = synth_result["fake_fusion"]
            fake_text = synth_result["fake_text"]
            fake_bg = synth_result["fake_bg"]
            for tp in range(2):
                if tp == 0:
                    prefix = "%s/c%d_s%d_" % (save_path, cno, sno)
                else:
                    prefix = "%s/s%d_c%d_" % (save_path, sno, cno)
                cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
                cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
                cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
                cv2.imwrite("%s_input_style.jpg" % prefix, img)
            print(cno, corpus_num, sno, style_img_num)


if __name__ == '__main__':
    # batch_synth_images()
    synth_image()
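
batch_synth_images expects each line of the corpus list file to carry a text string and a language code separated by a tab (it unpacks corpus, lang per line). A small sketch of producing such a file; the file name only mirrors the hard-coded path above:

# each line is "<corpus text>\t<language>", matching corpus, lang = corpus_list[cno]
samples = [("PaddleOCR", "en"), ("飞桨文字识别", "ch")]
with open("test_text_list.txt", "w", encoding="utf-8") as f:
    for text, lang in samples:
        f.write("{}\t{}\n".format(text, lang))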
tools/style_text_rec/utils/config.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os
from collections import OrderedDict
from argparse import ArgumentParser, RawDescriptionHelpFormatter

from utils.logging import get_logger


def override(dl, ks, v):
    """
    Recursively replace a value inside a nested dict or list.

    Args:
        dl(dict or list): dict or list to be replaced
        ks(list): list of keys
        v(str): value to be replaced
    """

    def str2num(v):
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), (
        "{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('length of keys should be larger than 0')
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            #assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
            if not ks[0] in dl:
                get_logger().warning(
                    'A new field ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
    """
    Recursively override the config.

    Args:
        config(dict): dict to be replaced
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]

    Returns:
        config(dict): replaced config
    """
    if options is not None:
        for opt in options:
            assert isinstance(opt, str), (
                "option({}) should be a str".format(opt))
            assert "=" in opt, (
                "option({}) should contain a ="
                "to distinguish between key and value".format(opt))
            pair = opt.split('=')
            assert len(pair) == 2, ("there can be only a = in the option")
            key, value = pair
            keys = key.split('.')
            override(config, keys, value)

    return config
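

# Example (illustrative): overriding nested entries in a toy config dict.
# The keys below are made up and only mirror the docstring pattern above.
#
#   cfg = {"Global": {"output_dir": "./output"},
#          "Predictor": {"scale": [1, 32, 64]}}
#   cfg = override_config(cfg, options=[
#       "Global.output_dir=./output_data",  # replace an existing scalar
#       "Predictor.scale.2=128",            # an integer key indexes into a list
#   ])
#   # cfg["Global"]["output_dir"] == "./output_data"
#   # cfg["Predictor"]["scale"] == [1, 32, 128]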
class ArgsParser(ArgumentParser):
    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-t", "--tag", default="0", help="tag for marking worker")
        self.add_argument(
            '-o',
            '--override',
            action='append',
            default=[],
            help='config options to be overridden')

    def parse_args(self, argv=None):
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        return args
def load_config(file_path):
    """
    Load config from yml/yaml file.

    Args:
        file_path (str): Path of the config file to be loaded.

    Returns: config
    """
    ext = os.path.splitext(file_path)[1]
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    return config
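

# Example (illustrative): typical wiring of the helpers above in a tool script.
# The config path and the override value are placeholders only.
#
#   args = ArgsParser().parse_args(
#       ["-c", "configs/config.yml", "-o", "Global.output_dir=./output_data"])
#   config = load_config(args.config)
#   config = override_config(config, options=args.override)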
def gen_config():
    base_config = {
        "Global": {
            "algorithm": "SRNet",
            "use_gpu": True,
            "start_epoch": 1,
            "stage1_epoch_num": 100,
            "stage2_epoch_num": 100,
            "log_smooth_window": 20,
            "print_batch_step": 2,
            "save_model_dir": "./output/SRNet",
            "use_visualdl": False,
            "save_epoch_step": 10,
            "vgg_pretrain": "./pretrained/VGG19_pretrained",
            "vgg_load_static_pretrain": True
        },
        "Architecture": {
            "model_type": "data_aug",
            "algorithm": "SRNet",
            "net_g": {
                "name": "srnet_net_g",
                "encode_dim": 64,
                "norm": "batch",
                "use_dropout": False,
                "init_type": "xavier",
                "init_gain": 0.02,
                "use_dilation": 1
            },
            # input_nc, ndf, netD,
            # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
            "bg_discriminator": {
                "name": "srnet_bg_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            },
            "fusion_discriminator": {
                "name": "srnet_fusion_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            }
        },
        "Loss": {
            "lamb": 10,
            "perceptual_lamb": 1,
            "muvar_lamb": 50,
            "style_lamb": 500
        },
        "Optimizer": {
            "name": "Adam",
            "learning_rate": {
                "name": "lambda",
                "lr": 0.0002,
                "lr_decay_iters": 50
            },
            "beta1": 0.5,
            "beta2": 0.999,
        },
        "Train": {
            "batch_size_per_card": 8,
            "num_workers_per_card": 4,
            "dataset": {
                "delimiter": "\t",
                "data_dir": "/",
                "label_file": "tmp/label.txt",
                "transforms": [{
                    "DecodeImage": {
                        "to_rgb": True,
                        "to_np": False,
                        "channel_first": False
                    }
                }, {
                    "NormalizeImage": {
                        "scale": 1. / 255.,
                        "mean": [0.485, 0.456, 0.406],
                        "std": [0.229, 0.224, 0.225],
                        "order": None
                    }
                }, {
                    "ToCHWImage": None
                }]
            }
        }
    }
    with open("config.yml", "w") as f:
        yaml.dump(base_config, f)


if __name__ == '__main__':
    gen_config()
tools/style_text_rec/utils/load_params.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import paddle

__all__ = ['load_dygraph_pretrain']


def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
    if not os.path.exists(path + '.pdparams'):
        raise ValueError("Model pretrain path {} does not "
                         "exist.".format(path))
    param_state_dict = paddle.load(path + '.pdparams')
    model.set_state_dict(param_state_dict)
    logger.info("load pretrained model from {}".format(path))
    return
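
A usage sketch for load_dygraph_pretrain, assuming a Paddle model instance and the get_logger helper from utils/logging.py; the checkpoint prefix is hypothetical (the loader appends '.pdparams' itself):

import paddle
from utils.logging import get_logger
from utils.load_params import load_dygraph_pretrain

logger = get_logger()
model = paddle.nn.Linear(4, 2)  # stand-in for a real network
# expects ./pretrained/demo_model.pdparams to exist (hypothetical path)
load_dygraph_pretrain(model, logger, path="./pretrained/demo_model")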
tools/style_text_rec/utils/logging.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import functools
import paddle.distributed as dist

logger_initialized = {}


@functools.lru_cache()
def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified a FileHandler will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if log_file is not None and dist.get_rank() == 0:
        log_file_folder = os.path.split(log_file)[0]
        os.makedirs(log_file_folder, exist_ok=True)
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if dist.get_rank() == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)
    logger_initialized[name] = True
    return logger
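
A usage sketch for get_logger; the log file path is illustrative, and omitting log_file keeps logging on stdout only:

from utils.logging import get_logger

# rank-0 processes log at INFO level; other ranks are silenced to ERROR
logger = get_logger(name='srnet', log_file='./output/synth.log')
logger.info("style text synthesis started")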
tools/style_text_rec/utils/math_functions.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle


def compute_mean_covariance(img):
    batch_size = img.shape[0]
    channel_num = img.shape[1]
    height = img.shape[2]
    width = img.shape[3]
    num_pixels = height * width

    # batch_size * channel_num * 1 * 1
    mu = img.mean(2, keepdim=True).mean(3, keepdim=True)

    # batch_size * channel_num * num_pixels
    img_hat = img - mu.expand_as(img)
    img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
    # batch_size * num_pixels * channel_num
    img_hat_transpose = img_hat.transpose([0, 2, 1])
    # batch_size * channel_num * channel_num
    covariance = paddle.bmm(img_hat, img_hat_transpose)
    covariance = covariance / num_pixels

    return mu, covariance


def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
    eps = 1e-5
    intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
    union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
        y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)
    return loss
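
A usage sketch with random tensors, assuming the package root is on sys.path; shapes follow the N, C, H, W convention noted in the comments above:

import paddle
from utils.math_functions import compute_mean_covariance, dice_coefficient

feat = paddle.rand([2, 3, 32, 64])            # N, C, H, W feature map
mu, cov = compute_mean_covariance(feat)       # shapes [2, 3, 1, 1] and [2, 3, 3]

pred = paddle.rand([2, 1, 32, 64])
gt = paddle.ones([2, 1, 32, 64])
mask = paddle.ones([2, 1, 32, 64])
loss = dice_coefficient(gt, pred, mask)       # scalar, 1 minus the dice overlap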
tools/style_text_rec/utils/sys_funcs.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import errno
import paddle


def get_check_global_params(mode):
    check_params = [
        'use_gpu', 'max_text_length', 'image_shape', 'character_type',
        'loss_type'
    ]
    if mode == "train_eval":
        check_params = check_params + [
            'train_batch_size_per_card', 'test_batch_size_per_card'
        ]
    elif mode == "test":
        check_params = check_params + ['test_batch_size_per_card']
    return check_params


def check_gpu(use_gpu):
    """
    Log error and exit when use_gpu is set to true with a
    CPU-only build of paddlepaddle.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version !\nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"
    if use_gpu:
        try:
            if not paddle.is_compiled_with_cuda():
                print(err)
                sys.exit(1)
        except Exception:
            print("Fail to check gpu state.")
            sys.exit(1)


def _mkdir_if_not_exist(path, logger):
    """
    mkdir if not exists, ignore the exception when multiprocess mkdir together
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'be happy if some process has already created {}'.format(
                        path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))
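
A short sketch of calling these helpers, assuming the package root is on sys.path; the output directory name is illustrative:

from utils.logging import get_logger
from utils.sys_funcs import check_gpu, _mkdir_if_not_exist

logger = get_logger()
check_gpu(use_gpu=True)                        # exits with a hint if this Paddle build lacks CUDA
_mkdir_if_not_exist("./output_data", logger)   # tolerant of concurrent workers creating the same dir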