Unverified commit a26968a4, authored by dyning, committed by GitHub

Merge pull request #1428 from weisy11/dygraph

Add Style Text Rec
## Style Text Rec
### Contents
* [Introduction](#工具简介)
* [Environment Setup](#环境配置)
* [Quick Start](#快速上手)
* [Advanced Usage](#高级使用)
* [Application Example](#应用示例)
### Introduction <a name="工具简介"></a>
<div align="center">
<img src="../imgs_style_text/3.png" width="800">
</div>
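Style-Text is an improvement on the SRNet network proposed in "Editing Text in the Wild", a text editing algorithm developed in-house at Baidu. Unlike common GAN-based methods, which use only a single branch, this tool decomposes text synthesis into three sub-modules (a text style transfer module, a background extraction module, and a foreground-background fusion module) to improve the quality of the synthesized data. The figures below show some example results.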
<div align="center">
<img src="../imgs_style_text/1.png" width="800">
<img src="../imgs_style_text/2.png" width="800">
</div>
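In addition, the effectiveness of the synthesis tool has been verified in a real-world nameplate text recognition scenario and a Korean text recognition scenario.
### Environment Setup <a name="环境配置"></a>
1. Follow [Quick Installation](./installation.md) to install PaddleOCR. A Python 3 environment is strongly recommended.
2. Enter the `style_text_rec` directory, then download and unzip the models: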
```bash
cd style_text_rec
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
unzip style_text_models.zip
```
If you save the models somewhere else, update the model paths in `configs/config.yml`. All three of the following entries must be changed together:
```
bg_generator:
  pretrain: style_text_models/bg_generator
...
text_generator:
  pretrain: style_text_models/text_generator
...
fusion_generator:
  pretrain: style_text_models/fusion_generator
```
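### Quick Start <a name="快速上手"></a>
1. Run `tools/synth_image` to generate an example image:
```bash
python3 -m tools.synth_image -c configs/config.yml
```
After the command finishes, `fake_fusion.jpg` is generated; this is the final result.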
<div align="center">
<img src="../imgs_style_text/4.jpg" width="300">
</div>
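The program also generates and saves the intermediate results:
* `fake_bg.jpg`: the background of the style reference image with the text removed;
* `fake_text.jpg`: the supplied string rendered on a gray background in the text style of the style reference image.
2. To try other style images and text, add the `style_image`, `text_corpus` and `language` arguments:
```bash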
python3 -m tools.synth_image -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
```
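* Note: the language option must match the corpus. English, Simplified Chinese and Korean are currently supported.
3. `tools/synth_image.py` also provides a `batch_synth_images` method, which pairs every corpus entry with every style image to generate a batch of data.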
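### Advanced Usage <a name="高级使用"></a>
Before synthesizing a dataset, you need to prepare some source material.
First, you need style images to serve as references for the synthesized images. These can come from a dataset used to train an OCR recognition model. This example uses a dataset with a label file as the style images.
1. Configure the input data paths in `configs/dataset_config.yml`.
* `StyleSampler`:
  * `method`: the sampling method used for style images;
  * `image_home`: the directory of style images;
  * `label_file`: the file listing the style image paths; if the dataset has labels, this is the path of the label file;
  * `with_label`: whether `label_file` is a label file.
  A set of [sample images](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar) is provided for you to try.
* `CorpusGenerator`:
  * `method`: the corpus generation method; `FileCorpus` and `EnNumCorpus` are currently available. With `EnNumCorpus` no other settings are needed; otherwise set `corpus_file` and `language`;
  * `language`: the language of the corpus;
  * `corpus_file`: the path of the corpus file.
2. Run `tools/synth_dataset` to synthesize data: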
``` bash
python -m tools.synth_dataset -c configs/dataset_config.yml
```
3. To synthesize data faster in parallel, start multiple processes and give each one a different `tag` (`-t`) at launch, as shown below:
```bash
python3 -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
python3 -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
```
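The two commands above can also be generated programmatically. The following is a minimal sketch, assuming only the `-t` and `-c` flags shown above; the helper names `build_synth_commands` and `launch` are hypothetical and not part of the repository.

```python
import subprocess
import sys


def build_synth_commands(num_workers, config_path="configs/dataset_config.yml"):
    """Return one argument list per worker, each with a distinct -t tag."""
    return [
        [sys.executable, "-m", "tools.synth_dataset",
         "-t", str(tag), "-c", config_path]
        for tag in range(num_workers)
    ]


def launch(num_workers):
    """Start all workers, wait for them, and return their exit codes."""
    procs = [subprocess.Popen(cmd) for cmd in build_synth_commands(num_workers)]
    return [proc.wait() for proc in procs]
```

Calling `launch(2)` from the `style_text_rec` directory would start the same two processes as the shell commands; each worker is an independent process, so the tags are what keep the runs distinguishable.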
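### Application Example <a name="应用示例"></a>
After completing the steps above, you have a synthetic dataset for OCR recognition. Some examples of generated datasets are given below:
Next, follow the [OCR recognition document](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83) to complete training.
### Project Structure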
```
style_text_rec
|-- arch
|   |-- base_module.py
|   |-- decoder.py
|   |-- encoder.py
|   |-- spectral_norm.py
|   `-- style_text_rec.py
|-- configs
|   |-- config.yml
|   `-- dataset_config.yml
|-- engine
|   |-- corpus_generators.py
|   |-- predictors.py
|   |-- style_samplers.py
|   |-- synthesisers.py
|   |-- text_drawers.py
|   `-- writers.py
|-- examples
|   |-- corpus
|   |   `-- example.txt
|   |-- image_list.txt
|   `-- style_images
|       |-- 1.jpg
|       `-- 2.jpg
|-- fonts
|   |-- ch_standard.ttf
|   |-- en_standard.ttf
|   `-- ko_standard.ttf
|-- tools
|   |-- __init__.py
|   |-- synth_dataset.py
|   `-- synth_image.py
`-- utils
    |-- config.py
    |-- load_params.py
    |-- logging.py
    |-- math_functions.py
    `-- sys_funcs.py
```
### Quick Start
`Style-Text` is an improvement on the SRNet network proposed in Baidu's self-developed text editing algorithm "Editing Text in the Wild". Unlike commonly used GAN methods, this tool decomposes the text synthesis task into three sub-modules to improve the quality of the synthetic data: a text style transfer module, a background extraction module and a fusion module.
The following figure shows some example results. In addition, the effectiveness of the synthesis tool has been verified in a real-world `nameplate text recognition` scenario and a `Korean text recognition` scenario, as follows.
#### Preparation
1. Please refer to [QUICK INSTALLATION](./installation_en.md) to install PaddlePaddle. A Python3 environment is strongly recommended.
2. Download the pretrained models and unzip:
```bash
cd tools/style_text_rec
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
unzip style_text_models.zip
```
You can download the models [here](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip). If you save the model files in another folder, please edit the three model paths in `configs/config.yml`:
```
bg_generator:
  pretrain: style_text_models/bg_generator
...
text_generator:
  pretrain: style_text_models/text_generator
...
fusion_generator:
  pretrain: style_text_models/fusion_generator
```
#### Demo
1. You can use the following commands to run a demo:
```bash
python -m tools.synth_image -c configs/config.yml
```
2. The results are `fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`, as shown in the figure above. Among them:
* `fake_text.jpg` is the generated image with the same font style as `Style Input`;
* `fake_bg.jpg` is the image generated from `Style Input` by removing the foreground;
* `fake_fusion.jpg` is the final result, synthesised from `fake_text.jpg` and `fake_bg.jpg`.
3. If you want to generate images from another `Style Input` or `Text Input`, you can modify `tools/synth_image.py`:
* `img = cv2.imread("examples/style_images/1.jpg")`: the path of `Style Input`;
* `corpus = "PaddleOCR"`: the `Text Input`;
* Notice: set the language option (`language = "en"`) to match `Text Input`; `en`, `ch` and `ko` are supported.
4. We also provide a `batch_synth_images` method, which combines corpus entries and pictures in pairs to generate a batch of data.
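The pairing described for `batch_synth_images` can be pictured as a Cartesian product of style images and corpus entries. This is only an illustrative sketch under that assumption; `pair_inputs` is a hypothetical helper, and the actual method in `tools/synth_image.py` may be organised differently.

```python
from itertools import product


def pair_inputs(style_image_paths, corpus_list):
    """Pair every style image with every corpus entry."""
    return list(product(style_image_paths, corpus_list))


pairs = pair_inputs(
    ["examples/style_images/1.jpg", "examples/style_images/2.jpg"],
    ["PaddleOCR", "Style Text"],
)
for image_path, text in pairs:
    print(image_path, text)
```

With two images and two strings this yields four combinations, one synthesis call per pair.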
### Advanced Usage
#### Components
`Style Text Rec` mainly contains the following components:
* `style_samplers`: It can sample `Style Input` from a dataset. Now, we only provide `DatasetSampler`.
* `corpus_generators`: It can generate corpus. Now, we only provide two `corpus_generators`:
  * `EnNumCorpus`: It can generate a random string of a given length, consisting of uppercase and lowercase English letters, numbers and spaces.
  * `FileCorpus`: It can read a text file and randomly return the words in it.
* `text_drawers`: It can generate `Text Input` (a text picture in a standard font, drawn from the input corpus). Note that you have to set the language information according to the corpus.
* `predictors`: It can call the deep learning model to generate new data based on the `Style Input` and `Text Input`.
* `writers`: It can write the generated pictures (`fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`) and label information to the disk.
* `synthesisers`: It can call all the modules to complete the work.
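To make the `EnNumCorpus` description concrete, here is a minimal sketch of a generator with the same stated behaviour (a random string of a given length drawn from English letters, digits and spaces). The function name, the `CHARSET` constant and the optional `rng` parameter are assumptions for illustration; the real class in `engine/corpus_generators.py` may differ.

```python
import random
import string

# Uppercase and lowercase English letters, digits, and the space character.
CHARSET = string.ascii_letters + string.digits + " "


def random_en_num_corpus(length, rng=None):
    """Return a random string of `length` characters drawn from CHARSET."""
    rng = rng or random.Random()
    return "".join(rng.choice(CHARSET) for _ in range(length))


sample = random_en_num_corpus(10, random.Random(0))
print(repr(sample))
```

Passing a seeded `random.Random` makes the output reproducible, which can help when comparing synthesis runs.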
### Generate Dataset
Before you start, you need to prepare some data as material.
First, you need style reference data for the synthesis task; these are generally datasets used for OCR recognition tasks.
1. The referenced dataset can be specified in `configs/dataset_config.yml`:
* `StyleSampler`:
  * `method`: The method of `StyleSampler`.
  * `image_home`: The directory of pictures.
  * `label_file`: The file listing picture paths if `with_label` is `false`; otherwise, the label file path.
  * `with_label`: Whether `label_file` is a label file.
* `CorpusGenerator`:
  * `method`: The method of `CorpusGenerator`. If `FileCorpus` is used, you need to modify `corpus_file` and `language` accordingly; if `EnNumCorpus` is used, no other configuration is needed.
  * `language`: The language of the corpus. Needed if method is not `EnNumCorpus`.
  * `corpus_file`: The corpus file path. Needed if method is not `EnNumCorpus`.
2. You can run the following command to start a synthesis task:
``` bash
python -m tools.synth_dataset -c configs/dataset_config.yml
```
3. You can use the following commands to start multiple synthesis tasks as separate processes; each process needs its own tag, specified by `-t`:
```bash
python -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
python -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
```
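The two meanings of `label_file` described under `StyleSampler` can be sketched as follows. This assumes that a label file holds one `image_path<TAB>label` pair per line, which is an assumption about the file layout rather than something stated in this document; `read_style_list` is a hypothetical helper.

```python
def read_style_list(lines, with_label):
    """Return (image_path, label_or_None) pairs from the lines of label_file.

    Assumed layout: 'path<TAB>label' per line when with_label is true,
    otherwise one image path per line.
    """
    samples = []
    for line in lines:
        line = line.rstrip("\n")
        if not line.strip():
            continue  # skip blank lines
        if with_label:
            path, _, label = line.partition("\t")
            samples.append((path, label))
        else:
            samples.append((line.strip(), None))
    return samples


print(read_style_list(["imgs/1.jpg\thello\n"], with_label=True))
print(read_style_list(["imgs/1.jpg\n", "imgs/2.jpg\n"], with_label=False))
```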
### OCR Recognition Training
After completing the above operations, you will have a synthetic dataset for OCR recognition. Next, please complete the training by referring to the [OCR Recognition Document](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83).
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from arch.spectral_norm import spectral_norm
class CBN(nn.Layer):
    """Conv2D followed by an optional normalization layer and activation."""

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(CBN, self).__init__()
        # Note: in Paddle, bias_attr=None keeps the default bias parameter;
        # only bias_attr=False disables the bias.
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._conv = paddle.nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._conv(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out
class SNConv(nn.Layer):
    """Spectrally normalized Conv2D with optional normalization and activation."""

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConv, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._sn_conv = spectral_norm(
            paddle.nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                weight_attr=paddle.ParamAttr(name=name + "_weights"),
                bias_attr=bias_attr))
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._sn_conv(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out
class SNConvTranspose(nn.Layer):
    """Spectrally normalized Conv2DTranspose with optional normalization and activation."""

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConvTranspose, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._sn_conv_transpose = spectral_norm(
            paddle.nn.Conv2DTranspose(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                dilation=dilation,
                groups=groups,
                weight_attr=paddle.ParamAttr(name=name + "_weights"),
                bias_attr=bias_attr))
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._sn_conv_transpose(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out
class MiddleNet(nn.Layer):
    """1x1 conv, replicate padding, 3x3 conv, then 1x1 conv, all spectrally normalized."""

    def __init__(self, name, in_channels, mid_channels, out_channels,
                 use_bias):
        super(MiddleNet, self).__init__()
        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            use_bias=use_bias,
            norm_layer=None,
            act=None)
        self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            use_bias=use_bias)
        self._sn_conv3 = SNConv(
            name=name + "_sn_conv3",
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_bias=use_bias)

    def forward(self, x):
        sn_conv1 = self._sn_conv1.forward(x)
        pad_2d = self._pad2d.forward(sn_conv1)
        sn_conv2 = self._sn_conv2.forward(pad_2d)
        sn_conv3 = self._sn_conv3.forward(sn_conv2)
        return sn_conv3
class ResBlock(nn.Layer):
    """Two padded 3x3 spectrally normalized convolutions with a residual connection."""

    def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
                 use_bias):
        super(ResBlock, self).__init__()
        if use_dilation:
            padding_mat = [1, 1, 1, 1]
        else:
            padding_mat = [0, 0, 0, 0]
        self._pad1 = nn.Pad2D(padding_mat, mode="replicate")
        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=0,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)
        # Note: the dropout layer is created here but not applied in forward().
        if use_dropout:
            self._dropout = nn.Dropout(0.5)
        else:
            self._dropout = None
        self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)

    def forward(self, x):
        pad1 = self._pad1.forward(x)
        sn_conv1 = self._sn_conv1.forward(pad1)
        pad2 = self._pad2.forward(sn_conv1)
        sn_conv2 = self._sn_conv2.forward(pad2)
        return sn_conv2 + x
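The residual addition `sn_conv2 + x` in `ResBlock.forward` requires both convolutions to preserve the spatial size. The following sketch checks that with the standard convolution output-size formula; the helper names are hypothetical and the code only mirrors the padding and kernel settings visible in `ResBlock`, without using Paddle.

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Standard convolution output size along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def res_block_out_size(size, use_dilation):
    """Spatial size after ResBlock's pad1 -> conv1 -> pad2 -> conv2 chain."""
    pad1 = 1 if use_dilation else 0      # padding_mat is all ones or all zeros
    after_conv1 = conv_out_size(size + 2 * pad1, kernel=3)
    after_conv2 = conv_out_size(after_conv1 + 2 * 1, kernel=3)  # pad2 is always 1
    return after_conv2


print(res_block_out_size(32, use_dilation=True))
print(res_block_out_size(32, use_dilation=False))
```

Under these assumptions the size is preserved only when `use_dilation` is true; with the zero padding branch the output is two pixels smaller per axis, so the addition with `x` would not line up.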
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from arch.base_module import SNConv, SNConvTranspose, ResBlock
class Decoder(nn.Layer):
    """Residual blocks, three stride-2 transposed convolutions, and an output convolution."""

    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(Decoder, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 8,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self.conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 2,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x):
        # A list or tuple of feature maps is concatenated along the channel axis.
        if isinstance(x, (list, tuple)):
            x = paddle.concat(x, axis=1)
        output_dict = dict()
        output_dict["conv_blocks"] = self.conv_blocks.forward(x)
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(output_dict["up1"])
        output_dict["up3"] = self._up3.forward(output_dict["up2"])
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict
class DecoderUnet(nn.Layer):
    """Decoder whose second and third upsampling stages take concatenated skip features."""

    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(DecoderUnet, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 8,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, y, feature2, feature1):
        output_dict = dict()
        output_dict["conv_blocks"] = self._conv_blocks(
            paddle.concat(
                (x, y), axis=1))
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(
            paddle.concat(
                (output_dict["up1"], feature2), axis=1))
        output_dict["up3"] = self._up3.forward(
            paddle.concat(
                (output_dict["up2"], feature1), axis=1))
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict
class SingleDecoder(nn.Layer):
    """Skip-connected decoder that takes a single encoded input tensor."""

    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(SingleDecoder, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 4,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, feature2, feature1):
        output_dict = dict()
        output_dict["conv_blocks"] = self._conv_blocks.forward(x)
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(
            paddle.concat(
                (output_dict["up1"], feature2), axis=1))
        output_dict["up3"] = self._up3.forward(
            paddle.concat(
                (output_dict["up2"], feature1), axis=1))
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict
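Every upsampling stage in the decoders above uses `kernel_size=3, stride=2, padding=1, output_padding=1`, and the skip-connected variants concatenate extra features before `_up2` and `_up3`. The sketch below works through both pieces of arithmetic with the standard transposed-convolution size formula; the helper names are hypothetical, and the channel counts are read off the `DecoderUnet` constructor rather than taken from any configuration file.

```python
def conv_transpose_out_size(size, kernel, stride=1, padding=0,
                            output_padding=0, dilation=1):
    """Standard transposed-convolution output size along one spatial axis."""
    return ((size - 1) * stride - 2 * padding + dilation * (kernel - 1)
            + output_padding + 1)


def decoder_unet_skip_channels(encode_dim):
    """Channels each skip feature needs so the concatenated tensor matches
    the in_channels declared for _up2 and _up3 in DecoderUnet."""
    up1_out, up2_out = encode_dim * 4, encode_dim * 2
    up2_in, up3_in = encode_dim * 8, encode_dim * 4
    return {"feature2": up2_in - up1_out, "feature1": up3_in - up2_out}


print(conv_transpose_out_size(8, kernel=3, stride=2, padding=1, output_padding=1))
print(decoder_unet_skip_channels(64))
```

With these settings each stage doubles the spatial size, so three stages scale the input by a factor of eight.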
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from arch.base_module import SNConv, SNConvTranspose, ResBlock
class Encoder(nn.Layer):
    """Input convolution, three stride-2 downsampling convolutions, then residual blocks."""

    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation):
        super(Encoder, self).__init__()
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = SNConv(
            name=name + "_down1",
            in_channels=encode_dim,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down2 = SNConv(
            name=name + "_down2",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down3 = SNConv(
            name=name + "_down3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 4,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)

    def forward(self, x):
        out_dict = dict()
        x = self._pad2d(x)
        out_dict["in_conv"] = self._in_conv.forward(x)
        out_dict["down1"] = self._down1.forward(out_dict["in_conv"])
        out_dict["down2"] = self._down2.forward(out_dict["down1"])
        out_dict["down3"] = self._down3.forward(out_dict["down2"])
        out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"])
        return out_dict
class EncoderUnet(nn.Layer):
    """Encoder with four downsampling stages and two skip-connected upsampling stages."""

    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr):
        super(EncoderUnet, self).__init__()
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = SNConv(
            name=name + "_down1",
            in_channels=encode_dim,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down2 = SNConv(
            name=name + "_down2",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down3 = SNConv(
            name=name + "_down3",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down4 = SNConv(
            name=name + "_down4",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)

    def forward(self, x):
        output_dict = dict()
        x = self._pad2d(x)
        output_dict['in_conv'] = self._in_conv.forward(x)
        output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
        output_dict['down2'] = self._down2.forward(output_dict['down1'])
        output_dict['down3'] = self._down3.forward(output_dict['down2'])
        output_dict['down4'] = self._down4.forward(output_dict['down3'])
        output_dict['up1'] = self._up1.forward(output_dict['down4'])
        output_dict['up2'] = self._up2.forward(
            paddle.concat(
                (output_dict['down3'], output_dict['up1']), axis=1))
        output_dict['concat'] = paddle.concat(
            (output_dict['down2'], output_dict['up2']), axis=1)
        return output_dict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def normal_(x, mean=0., std=1.):
    temp_value = paddle.normal(mean, std, shape=x.shape)
    x.set_value(temp_value)
    return x
class SpectralNorm(object):
    """Forward pre-hook that divides a weight by its largest singular value,
    estimated by power iteration."""

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(
                                 n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # transpose dim to front
            weight_mat = weight_mat.transpose([
                self.dim,
                * [d for d in range(weight_mat.dim()) if d != self.dim]
            ])
        height = weight_mat.shape[0]
        return weight_mat.reshape([height, -1])

    def compute_weight(self, module, do_power_iteration):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)
        if do_power_iteration:
            with paddle.no_grad():
                for _ in range(self.n_power_iterations):
                    v.set_value(
                        F.normalize(
                            paddle.matmul(
                                weight_mat,
                                u,
                                transpose_x=True,
                                transpose_y=False),
                            axis=0,
                            epsilon=self.eps, ))
                    u.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat, v),
                            axis=0,
                            epsilon=self.eps, ))
                if self.n_power_iterations > 0:
                    u = u.clone()
                    v = v.clone()
        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        with paddle.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.add_parameter(self.name, weight.detach())

    def __call__(self, module, inputs):
        setattr(
            module,
            self.name,
            self.compute_weight(
                module, do_power_iteration=module.training))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))
        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        with paddle.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.shape
            # randomly initialize u and v
            u = module.create_parameter([h])
            u = normal_(u, 0., 1.)
            v = module.create_parameter([w])
            v = normal_(v, 0., 1.)
            u = F.normalize(u, axis=0, epsilon=fn.eps)
            v = F.normalize(v, axis=0, epsilon=fn.eps)
        # delete fn.name from parameters, otherwise you can not set attribute
        del module._parameters[fn.name]
        module.add_parameter(fn.name + "_orig", weight)
        # still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be a Parameter and
        # gets added as a parameter. Instead, we register weight * 1.0 as a plain
        # attribute.
        setattr(module, fn.name, weight * 1.0)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)
        module.register_forward_pre_hook(fn)
        return fn
def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    if dim is None:
        if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
                               nn.Conv3DTranspose, nn.Linear)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
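The power iteration inside `SpectralNorm.compute_weight` can be illustrated without Paddle. The sketch below is a dependency-free version of the same update rule (normalize `W^T u` into `v`, normalize `W v` into `u`, then take `sigma = u^T W v`); the function names are hypothetical, and it works on plain nested lists rather than tensors.

```python
import math
import random


def normalize(vec, eps=1e-12):
    """Scale vec to unit L2 norm, guarding against division by zero."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def matvec(mat, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def spectral_sigma(weight_mat, n_power_iterations=1, seed=0):
    """Estimate the largest singular value of weight_mat by power iteration."""
    rng = random.Random(seed)
    h, w = len(weight_mat), len(weight_mat[0])
    u = normalize([rng.gauss(0.0, 1.0) for _ in range(h)])
    v = normalize([rng.gauss(0.0, 1.0) for _ in range(w)])
    weight_t = [list(col) for col in zip(*weight_mat)]
    for _ in range(n_power_iterations):
        v = normalize(matvec(weight_t, u))
        u = normalize(matvec(weight_mat, v))
    return sum(a * b for a, b in zip(u, matvec(weight_mat, v)))


sigma = spectral_sigma([[3.0, 0.0], [0.0, 1.0]], n_power_iterations=50)
print(round(sigma, 6))
```

Dividing the weight by this estimate, as the hook does, bounds the layer's largest singular value near one. The hook runs only `n_power_iterations` steps per forward pass and keeps `u` and `v` between calls, so the estimate is refined over the course of training rather than recomputed from scratch.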
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from arch.base_module import MiddleNet, ResBlock
from arch.encoder import Encoder
from arch.decoder import Decoder, DecoderUnet, SingleDecoder
from utils.load_params import load_dygraph_pretrain
from utils.logging import get_logger
class StyleTextRec(nn.Layer):
    """Combines the text, background and fusion generators and loads their
    pretrained weights."""

    def __init__(self, config):
        super(StyleTextRec, self).__init__()
        self.logger = get_logger()
        self.text_generator = TextGenerator(config["Predictor"][
            "text_generator"])
        self.bg_generator = BgGeneratorWithMask(config["Predictor"][
            "bg_generator"])
        self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
            "fusion_generator"])
        bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
        text_generator_pretrain = config["Predictor"]["text_generator"][
            "pretrain"]
        fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
            "pretrain"]
        load_dygraph_pretrain(
            self.bg_generator,
            self.logger,
            path=bg_generator_pretrain,
            load_static_weights=False)
        load_dygraph_pretrain(
            self.text_generator,
            self.logger,
            path=text_generator_pretrain,
            load_static_weights=False)
        load_dygraph_pretrain(
            self.fusion_generator,
            self.logger,
            path=fusion_generator_pretrain,
            load_static_weights=False)

    def forward(self, style_input, text_input):
        text_gen_output = self.text_generator.forward(style_input, text_input)
        fake_text = text_gen_output["fake_text"]
        fake_sk = text_gen_output["fake_sk"]
        bg_gen_output = self.bg_generator.forward(style_input)
        bg_encode_feature = bg_gen_output["bg_encode_feature"]
        bg_decode_feature1 = bg_gen_output["bg_decode_feature1"]
        bg_decode_feature2 = bg_gen_output["bg_decode_feature2"]
        fake_bg = bg_gen_output["fake_bg"]
        fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
        fake_fusion = fusion_gen_output["fake_fusion"]
        return {
            "fake_fusion": fake_fusion,
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }
class TextGenerator(nn.Layer):
    def __init__(self, config):
        super(TextGenerator, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False
        self.encoder_text = Encoder(
            name=name + "_encoder_text",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.encoder_style = Encoder(
            name=name + "_encoder_style",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.decoder_text = Decoder(
            name=name + "_decoder_text",
            encode_dim=encode_dim,
            out_channels=int(encode_dim / 2),
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)
        self.decoder_sk = Decoder(
            name=name + "_decoder_sk",
            encode_dim=encode_dim,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=int(encode_dim / 2) + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input, text_input):
        style_feature = self.encoder_style.forward(style_input)["res_blocks"]
        text_feature = self.encoder_text.forward(text_input)["res_blocks"]
        fake_c_temp = self.decoder_text.forward([text_feature,
                                                 style_feature])["out_conv"]
        fake_sk = self.decoder_sk.forward([text_feature,
                                           style_feature])["out_conv"]
        fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
        return {"fake_sk": fake_sk, "fake_text": fake_text}
class BgGeneratorWithMask(nn.Layer):
    def __init__(self, config):
        super(BgGeneratorWithMask, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        self.output_factor = config.get("output_factor", 1.0)
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False
        self.encoder_bg = Encoder(
            name=name + "_encoder_bg",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.decoder_bg = SingleDecoder(
            name=name + "_decoder_bg",
            encode_dim=encode_dim,
            out_channels=3,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)
        self.decoder_mask = Decoder(
            name=name + "_decoder_mask",
            encode_dim=encode_dim // 2,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=3 + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input):
        encode_bg_output = self.encoder_bg(style_input)
        decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
                                           encode_bg_output["down2"],
                                           encode_bg_output["down1"])
        fake_c_temp = decode_bg_output["out_conv"]
        fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
            "res_blocks"])["out_conv"]
        fake_bg = self.middle(
            paddle.concat(
                (fake_c_temp, fake_bg_mask), axis=1))
        return {
            "bg_encode_feature": encode_bg_output["res_blocks"],
            "bg_decode_feature1": decode_bg_output["up1"],
            "bg_decode_feature2": decode_bg_output["up2"],
            "fake_bg": fake_bg,
            "fake_bg_mask": fake_bg_mask,
        }
class FusionGeneratorSimple(nn.Layer):
    def __init__(self, config):
        super(FusionGeneratorSimple, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False
        self._conv = nn.Conv2D(
            in_channels=6,
            out_channels=encode_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
            bias_attr=False)
        self._res_block = ResBlock(
            name="{}_conv_block".format(name),
            channels=encode_dim,
            norm_layer=norm_layer,
            use_dropout=conv_block_dropout,
            use_dilation=conv_block_dilation,
            use_bias=use_bias)
        self._reduce_conv = nn.Conv2D(
            in_channels=encode_dim,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
            bias_attr=False)

    def forward(self, fake_text, fake_bg):
        # Stack the 3-channel text and background images into 6 channels.
        fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
        fake_concat_tmp = self._conv(fake_concat)
        output_res = self._res_block(fake_concat_tmp)
        fake_fusion = self._reduce_conv(output_res)
        return {"fake_fusion": fake_fusion}
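The channel counts passed to `MiddleNet` and to the first fusion convolution follow directly from the constructor arguments above. The sketch below only redoes that arithmetic in plain Python; the helper names `text_generator_channels` and `fusion_input_channels` are hypothetical and not part of the repository.

```python
def text_generator_channels(encode_dim):
    """Channel bookkeeping implied by TextGenerator's constructor arguments."""
    decoder_text_out = int(encode_dim / 2)  # out_channels of decoder_text
    decoder_sk_out = 1                      # out_channels of decoder_sk
    return {
        "decoder_text": decoder_text_out,
        "decoder_sk": decoder_sk_out,
        # paddle.concat(..., axis=1) stacks both outputs along channels,
        # which is why MiddleNet is built with int(encode_dim / 2) + 1 inputs.
        "middle_in": decoder_text_out + decoder_sk_out,
        "fake_text": 3,                     # out_channels of MiddleNet
    }


def fusion_input_channels(fake_text_channels=3, fake_bg_channels=3):
    """FusionGeneratorSimple concatenates fake_text and fake_bg on axis 1."""
    return fake_text_channels + fake_bg_channels


if __name__ == "__main__":
    print(text_generator_channels(64))
    print(fusion_input_channels())
```

With the `encode_dim: 64` value used in the configs, the text branch feeds 33 channels into `MiddleNet`, and the fusion network receives the 6 channels that its first `nn.Conv2D` expects.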
Global:
  output_num: 10
  output_dir: output_data
  use_gpu: false
  image_height: 32
  image_width: 320
TextDrawer:
  fonts:
    en: fonts/en_standard.ttf
    ch: fonts/ch_standard.ttf
    ko: fonts/ko_standard.ttf
Predictor:
  method: StyleTextRecPredictor
  algorithm: StyleTextRec
  scale: 0.00392156862745098
  mean:
    - 0.5
    - 0.5
    - 0.5
  std:
    - 0.5
    - 0.5
    - 0.5
  expand_result: false
  bg_generator:
    pretrain: style_text_models/bg_generator
    module_name: bg_generator
    generator_type: BgGeneratorWithMask
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
    output_factor: 1.05
  text_generator:
    pretrain: style_text_models/text_generator
    module_name: text_generator
    generator_type: TextGenerator
    encode_dim: 64
    norm_layer: InstanceNorm2D
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
  fusion_generator:
    pretrain: style_text_models/fusion_generator
    module_name: fusion_generator
    generator_type: FusionGeneratorSimple
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
Writer:
  method: SimpleWriter
Global:
  output_num: 10
  output_dir: output_data
  use_gpu: false
  image_height: 32
  image_width: 320
  standard_font: fonts/en_standard.ttf
TextDrawer:
  fonts:
    en: fonts/en_standard.ttf
    ch: fonts/ch_standard.ttf
    ko: fonts/ko_standard.ttf
StyleSampler:
  method: DatasetSampler
  image_home: examples
  label_file: examples/image_list.txt
  with_label: true
CorpusGenerator:
  method: FileCorpus
  language: ch
  corpus_file: examples/corpus/example.txt
Predictor:
  method: StyleTextRecPredictor
  algorithm: StyleTextRec
  scale: 0.00392156862745098
  mean:
    - 0.5
    - 0.5
    - 0.5
  std:
    - 0.5
    - 0.5
    - 0.5
  expand_result: false
  bg_generator:
    pretrain: models/style_text_rec/bg_generator
    module_name: bg_generator
    generator_type: BgGeneratorWithMask
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
    output_factor: 1.05
  text_generator:
    pretrain: models/style_text_rec/text_generator
    module_name: text_generator
    generator_type: TextGenerator
    encode_dim: 64
    norm_layer: InstanceNorm2D
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
  fusion_generator:
    pretrain: models/style_text_rec/fusion_generator
    module_name: fusion_generator
    generator_type: FusionGeneratorSimple
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
Writer:
  method: SimpleWriter
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from utils.logging import get_logger
class FileCorpus(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.logger.info("using FileCorpus")

        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        corpus_file = config["CorpusGenerator"]["corpus_file"]
        self.language = config["CorpusGenerator"]["language"]
        # Read as UTF-8 so Chinese and Korean corpora load on any platform.
        with open(corpus_file, 'r', encoding='utf-8') as f:
            corpus_raw = f.read()
        # Keep every non-empty line; this also keeps the last line when the
        # file does not end with a newline.
        self.corpus_list = [line for line in corpus_raw.split("\n") if line]
        assert len(self.corpus_list) > 0
        random.shuffle(self.corpus_list)
        self.index = 0

    def generate(self, corpus_length=0):
        # Reshuffle and start over once every line has been used.
        if self.index >= len(self.corpus_list):
            self.index = 0
            random.shuffle(self.corpus_list)
        corpus = self.corpus_list[self.index]
        if corpus_length != 0:
            corpus = corpus[0:corpus_length]
            if corpus_length > len(corpus):
                self.logger.warning("generated corpus is shorter than expected.")
        self.index += 1
        return self.language, corpus
class EnNumCorpus(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.logger.info("using EnNumCorpus")
        self.num_list = "0123456789"
        self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = config["Global"]["image_height"]
        self.max_width = config["Global"]["image_width"]

    def generate(self, corpus_length=0):
        corpus = ""
        if corpus_length == 0:
            corpus_length = random.randint(5, 15)
        for i in range(corpus_length):
            # Roughly 20% letters and 80% digits.
            if random.random() < 0.2:
                corpus += "{}".format(random.choice(self.en_char_list))
            else:
                corpus += "{}".format(random.choice(self.num_list))
        return "en", corpus
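`FileCorpus` walks through a shuffled list and reshuffles once it runs out, so every line is used once per pass. A minimal stand-alone sketch of that pattern follows, assuming an in-memory list instead of a corpus file; the class name `CyclingCorpus` and its `seed` parameter are hypothetical additions for reproducibility.

```python
import random


class CyclingCorpus:
    """Shuffle-and-cycle sampling over an in-memory list of corpus lines."""

    def __init__(self, lines, language="en", seed=None):
        if not lines:
            raise ValueError("corpus must contain at least one line")
        self._rng = random.Random(seed)
        self.lines = list(lines)
        self.language = language
        self._rng.shuffle(self.lines)
        self.index = 0

    def generate(self, corpus_length=0):
        # Start a new shuffled pass once every line has been returned.
        if self.index >= len(self.lines):
            self.index = 0
            self._rng.shuffle(self.lines)
        corpus = self.lines[self.index]
        if corpus_length:
            corpus = corpus[:corpus_length]
        self.index += 1
        return self.language, corpus


if __name__ == "__main__":
    sampler = CyclingCorpus(["alpha", "beta", "gamma"], seed=0)
    for _ in range(4):
        print(sampler.generate())
```

The first three calls return each line exactly once in shuffled order; the fourth call triggers a reshuffle and starts the next pass.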
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import math
import paddle
from arch import style_text_rec
from utils.sys_funcs import check_gpu
from utils.logging import get_logger
class StyleTextRecPredictor(object):
    def __init__(self, config):
        algorithm = config['Predictor']['algorithm']
        assert algorithm in ["StyleTextRec"], \
            "Generator {} not supported.".format(algorithm)
        use_gpu = config["Global"]['use_gpu']
        check_gpu(use_gpu)
        self.logger = get_logger()
        self.generator = getattr(style_text_rec, algorithm)(config)
        self.height = config["Global"]["image_height"]
        self.width = config["Global"]["image_width"]
        self.scale = config["Predictor"]["scale"]
        self.mean = config["Predictor"]["mean"]
        self.std = config["Predictor"]["std"]
        self.expand_result = config["Predictor"]["expand_result"]

    def predict(self, style_input, text_input):
        style_input = self.rep_style_input(style_input, text_input)
        tensor_style_input = self.preprocess(style_input)
        tensor_text_input = self.preprocess(text_input)
        style_text_result = self.generator.forward(tensor_style_input,
                                                   tensor_text_input)
        fake_fusion = self.postprocess(style_text_result["fake_fusion"])
        fake_text = self.postprocess(style_text_result["fake_text"])
        fake_sk = self.postprocess(style_text_result["fake_sk"])
        fake_bg = self.postprocess(style_text_result["fake_bg"])
        # Crop all outputs to the region that actually contains text.
        bbox = self.get_text_boundary(fake_text)
        if bbox:
            left, right, top, bottom = bbox
            fake_fusion = fake_fusion[top:bottom, left:right, :]
            fake_text = fake_text[top:bottom, left:right, :]
            fake_sk = fake_sk[top:bottom, left:right, :]
            fake_bg = fake_bg[top:bottom, left:right, :]
        return {
            "fake_fusion": fake_fusion,
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }

    def preprocess(self, img):
        # Normalize, resize to the target height and pad to the target width.
        img = (img.astype('float32') * self.scale - self.mean) / self.std
        img_height, img_width, channel = img.shape
        assert channel == 3, "Please use an rgb image."
        ratio = img_width / float(img_height)
        if math.ceil(self.height * ratio) > self.width:
            resized_w = self.width
        else:
            resized_w = int(math.ceil(self.height * ratio))
        img = cv2.resize(img, (resized_w, self.height))
        new_img = np.zeros([self.height, self.width, 3]).astype('float32')
        new_img[:, 0:resized_w, :] = img
        img = new_img.transpose((2, 0, 1))
        img = img[np.newaxis, :, :, :]
        return paddle.to_tensor(img)

    def postprocess(self, tensor):
        # Undo the normalization and convert back to an HWC uint8 image.
        img = tensor.numpy()[0]
        img = img.transpose((1, 2, 0))
        img = (img * self.std + self.mean) / self.scale
        img = np.maximum(img, 0.0)
        img = np.minimum(img, 255.0)
        img = img.astype('uint8')
        return img

    def rep_style_input(self, style_input, text_input):
        # Tile the style image horizontally until it is wide enough for the
        # text image, then cut it to the maximum aspect ratio.
        rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) /
                      (style_input.shape[1] / style_input.shape[0])) + 1
        style_input = np.tile(style_input, reps=[1, rep_num, 1])
        max_width = int(self.width / self.height * style_input.shape[0])
        style_input = style_input[:, :max_width, :]
        return style_input

    def get_text_boundary(self, text_img):
        img_height = text_img.shape[0]
        img_width = text_img.shape[1]
        bounder = 3
        text_canny_img = cv2.Canny(text_img, 10, 20)
        edge_num_h = text_canny_img.sum(axis=0)
        no_zero_list_h = np.where(edge_num_h > 0)[0]
        edge_num_w = text_canny_img.sum(axis=1)
        no_zero_list_w = np.where(edge_num_w > 0)[0]
        if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0:
            return None
        left = max(no_zero_list_h[0] - bounder, 0)
        right = min(no_zero_list_h[-1] + bounder, img_width)
        top = max(no_zero_list_w[0] - bounder, 0)
        bottom = min(no_zero_list_w[-1] + bounder, img_height)
        return [left, right, top, bottom]
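The arithmetic in `preprocess` and `postprocess` can be checked on single values without NumPy or Paddle. The sketch below assumes the `scale`, `mean` and `std` values from the configs (1/255, 0.5, 0.5); the function names are hypothetical, and it works on scalars rather than image arrays.

```python
import math

SCALE = 1.0 / 255.0  # written as 0.00392156862745098 in the configs
MEAN = 0.5
STD = 0.5


def normalize(pixel):
    """Map a pixel in [0, 255] to roughly [-1, 1], as preprocess does."""
    return (pixel * SCALE - MEAN) / STD


def denormalize(value):
    """Invert normalize, clamp to [0, 255] and truncate like astype('uint8')."""
    pixel = (value * STD + MEAN) / SCALE
    return int(min(max(pixel, 0.0), 255.0))


def resized_width(img_height, img_width, target_height=32, max_width=320):
    """Width after scaling to target_height, capped at max_width."""
    ratio = img_width / float(img_height)
    return min(max_width, int(math.ceil(target_height * ratio)))


if __name__ == "__main__":
    print(normalize(0), normalize(255))
    print(denormalize(0.0))
    print(resized_width(64, 128), resized_width(32, 1000))
```

Images narrower than `max_width` after scaling are zero-padded on the right by `preprocess`; wider ones are squeezed to `max_width`.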
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import cv2
class DatasetSampler(object):
    def __init__(self, config):
        self.image_home = config["StyleSampler"]["image_home"]
        label_file = config["StyleSampler"]["label_file"]
        self.dataset_with_label = config["StyleSampler"]["with_label"]
        self.height = config["Global"]["image_height"]
        self.index = 0
        # Labels may contain non-ASCII text, so read the file as UTF-8.
        with open(label_file, "r", encoding="utf-8") as f:
            label_raw = f.read()
        # Keep every non-empty line, including a last line without a newline.
        self.path_label_list = [line for line in label_raw.split("\n") if line]
        assert len(self.path_label_list) > 0
        random.shuffle(self.path_label_list)

    def sample(self):
        # Reshuffle and start over once every image has been sampled.
        if self.index >= len(self.path_label_list):
            random.shuffle(self.path_label_list)
            self.index = 0
        if self.dataset_with_label:
            path_label = self.path_label_list[self.index]
            rel_image_path, label = path_label.split('\t')
        else:
            rel_image_path = self.path_label_list[self.index]
            label = None
        img_path = "{}/{}".format(self.image_home, rel_image_path)
        image = cv2.imread(img_path)
        if image is None:
            raise IOError("Cannot read style image {}".format(img_path))
        # Scale to the configured height while keeping the aspect ratio.
        origin_height = image.shape[0]
        ratio = self.height / origin_height
        width = int(image.shape[1] * ratio)
        height = int(image.shape[0] * ratio)
        image = cv2.resize(image, (width, height))

        self.index += 1
        if label:
            return {"image": image, "label": label}
        else:
            return {"image": image}


def duplicate_image(image, width):
    # Tile the image horizontally, then crop it to exactly `width` pixels.
    image_width = image.shape[1]
    dup_num = width // image_width + 1
    image = np.tile(image, reps=[1, dup_num, 1])
    cropped_image = image[:, :width, :]
    return cropped_image
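`DatasetSampler.sample` expects each line of the label file to hold a relative image path and, when `with_label` is true, a tab-separated label. The following sketch isolates that parsing and the height-based scaling in plain Python; `parse_style_entry` and `scaled_size` are hypothetical helper names, not functions from the repository.

```python
def parse_style_entry(line, with_label=True):
    """Split one label-file line into (relative_path, label)."""
    if with_label:
        rel_path, label = line.split("\t")
        return rel_path, label
    return line, None


def scaled_size(orig_height, orig_width, target_height=32):
    """Return (width, height) after scaling to target_height, as sample() does."""
    ratio = target_height / orig_height
    return int(orig_width * ratio), int(orig_height * ratio)


if __name__ == "__main__":
    print(parse_style_entry("style_images/1.jpg\tNEATNESS"))
    print(parse_style_entry("style_images/1.jpg", with_label=False))
    print(scaled_size(64, 200))
```

A line that uses a space instead of a tab raises a `ValueError` in this sketch when `with_label` is true, which matches how the unpacking in `sample` would fail.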
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from utils.config import ArgsParser, load_config, override_config
from utils.logging import get_logger
from engine import style_samplers, corpus_generators, text_drawers, predictors, writers
class ImageSynthesiser(object):
    def __init__(self):
        self.FLAGS = ArgsParser().parse_args()
        self.config = load_config(self.FLAGS.config)
        self.config = override_config(self.config, options=self.FLAGS.override)
        self.output_dir = self.config["Global"]["output_dir"]
        # makedirs also creates missing parent directories.
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        self.logger = get_logger(
            log_file='{}/predict.log'.format(self.output_dir))

        self.text_drawer = text_drawers.StdTextDrawer(self.config)

        predictor_method = self.config["Predictor"]["method"]
        assert predictor_method is not None
        self.predictor = getattr(predictors, predictor_method)(self.config)

    def synth_image(self, corpus, style_input, language="en"):
        corpus, text_input = self.text_drawer.draw_text(corpus, language)
        synth_result = self.predictor.predict(style_input, text_input)
        return synth_result
class DatasetSynthesiser(ImageSynthesiser):
    def __init__(self):
        super(DatasetSynthesiser, self).__init__()
        self.tag = self.FLAGS.tag
        self.output_num = self.config["Global"]["output_num"]
        corpus_generator_method = self.config["CorpusGenerator"]["method"]
        self.corpus_generator = getattr(corpus_generators,
                                        corpus_generator_method)(self.config)

        style_sampler_method = self.config["StyleSampler"]["method"]
        assert style_sampler_method is not None
        # Use the sampler named in the config rather than a hard-coded class.
        self.style_sampler = getattr(style_samplers,
                                     style_sampler_method)(self.config)
        self.writer = writers.SimpleWriter(self.config, self.tag)

    def synth_dataset(self):
        for i in range(self.output_num):
            style_data = self.style_sampler.sample()
            style_input = style_data["image"]
            corpus_language, text_input_label = \
                self.corpus_generator.generate()
            text_input_label, text_input = self.text_drawer.draw_text(
                text_input_label, corpus_language)

            synth_result = self.predictor.predict(style_input, text_input)
            fake_fusion = synth_result["fake_fusion"]
            self.writer.save_image(fake_fusion, text_input_label)
        # Flush the remaining labels, then merge the per-tag label files.
        self.writer.save_label()
        self.writer.merge_label()
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from utils.logging import get_logger
class StdTextDrawer(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.max_width = config["Global"]["image_width"]
        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = config["Global"]["image_height"]
        self.font_dict = {}
        self.load_fonts(config["TextDrawer"]["fonts"])
        self.support_languages = list(self.font_dict)

    def load_fonts(self, fonts_config):
        for language in fonts_config:
            font_path = fonts_config[language]
            font_height = self.get_valid_height(font_path)
            font = ImageFont.truetype(font_path, font_height)
            self.font_dict[language] = font

    def get_valid_height(self, font_path):
        # Shrink the font size if the rendered glyphs would exceed the
        # image height minus a 2-pixel margin at the top and bottom.
        font = ImageFont.truetype(font_path, self.height - 4)
        _, font_height = font.getsize(self.char_list)
        if font_height <= self.height - 4:
            return self.height - 4
        else:
            return int((self.height - 4)**2 / font_height)

    def draw_text(self, corpus, language="en", crop=True):
        if language not in self.support_languages:
            self.logger.warning(
                "language {} not supported, use en instead.".format(language))
            language = "en"
        if crop:
            width = min(self.max_width, len(corpus) * self.height) + 4
        else:
            width = len(corpus) * self.height + 4
        bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
        draw = ImageDraw.Draw(bg)

        char_x = 2
        font = self.font_dict[language]
        for i, char_i in enumerate(corpus):
            char_size = font.getsize(char_i)[0]
            draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
            char_x += char_size
            # Stop once the canvas is full and truncate the label to match.
            if char_x >= width:
                corpus = corpus[0:i + 1]
                self.logger.warning("corpus length exceed limit: {}".format(
                    corpus))
                break

        text_input = np.array(bg).astype(np.uint8)
        text_input = text_input[:, 0:char_x, :]
        return corpus, text_input
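`draw_text` sizes its canvas from the character count before any glyph is drawn: one square of `height` pixels per character, optionally capped at `image_width`, plus 4 pixels of margin. A small sketch of that rule, using the `image_height: 32` and `image_width: 320` values from the configs as defaults; `canvas_width` is a hypothetical helper name.

```python
def canvas_width(corpus, height=32, max_width=320, crop=True):
    """Canvas width that draw_text allocates before rendering the text."""
    width = len(corpus) * height
    if crop:
        width = min(max_width, width)
    return width + 4  # 2-pixel margin on each side


if __name__ == "__main__":
    print(canvas_width("PaddleOCR"))
    print(canvas_width("a" * 20))
    print(canvas_width("a" * 20, crop=False))
```

With cropping enabled, a long string hits the cap, and `draw_text` then truncates the returned corpus to the characters that actually fit.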
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import glob
from utils.logging import get_logger
class SimpleWriter(object):
    def __init__(self, config, tag):
        self.logger = get_logger()
        self.output_dir = config["Global"]["output_dir"]
        self.counter = 0
        self.label_dict = {}
        self.tag = tag
        self.label_file_index = 0

    def save_image(self, image, text_input_label):
        image_home = os.path.join(self.output_dir, "images", self.tag)
        if not os.path.exists(image_home):
            os.makedirs(image_home)

        image_path = os.path.join(image_home, "{}.png".format(self.counter))
        # todo support continue synth
        cv2.imwrite(image_path, image)
        self.logger.info("generate image: {}".format(image_path))

        image_name = os.path.join(self.tag, "{}.png".format(self.counter))
        self.label_dict[image_name] = text_input_label

        self.counter += 1
        # Write the labels to disk every 100 images.
        if not self.counter % 100:
            self.save_label()

    def save_label(self):
        label_raw = ""
        label_home = os.path.join(self.output_dir, "label")
        if not os.path.exists(label_home):
            os.mkdir(label_home)
        for image_path in self.label_dict:
            label = self.label_dict[image_path]
            label_raw += "{}\t{}\n".format(image_path, label)
        label_file_path = os.path.join(label_home,
                                       "{}_label.txt".format(self.tag))
        # Labels may contain non-ASCII text, so write the file as UTF-8.
        with open(label_file_path, "w", encoding="utf-8") as f:
            f.write(label_raw)
        self.label_file_index += 1

    def merge_label(self):
        # Concatenate the label files of all tags into a single label.txt.
        label_raw = ""
        label_file_regex = os.path.join(self.output_dir, "label",
                                        "*_label.txt")
        label_file_list = glob.glob(label_file_regex)
        for label_file_i in label_file_list:
            with open(label_file_i, "r", encoding="utf-8") as f:
                label_raw += f.read()
        label_file_path = os.path.join(self.output_dir, "label.txt")
        with open(label_file_path, "w", encoding="utf-8") as f:
            f.write(label_raw)
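`SimpleWriter` writes one `<tag>_label.txt` per worker tag and `merge_label` concatenates them into `label.txt`. The sketch below reproduces the merge step with the standard library only. It is an illustration under two assumptions: the function name `merge_label_files` is hypothetical, and the file list is sorted so the merged order is deterministic, which the original `glob.glob` call does not guarantee.

```python
import glob
import os
import tempfile


def merge_label_files(output_dir):
    """Concatenate output_dir/label/*_label.txt into output_dir/label.txt."""
    pattern = os.path.join(output_dir, "label", "*_label.txt")
    merged = ""
    # Sorting is an addition of this sketch, for a reproducible order.
    for path in sorted(glob.glob(pattern)):
        with open(path, "r", encoding="utf-8") as f:
            merged += f.read()
    merged_path = os.path.join(output_dir, "label.txt")
    with open(merged_path, "w", encoding="utf-8") as f:
        f.write(merged)
    return merged_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as out_dir:
        os.makedirs(os.path.join(out_dir, "label"))
        for tag, line in (("0", "0/0.png\tabc\n"), ("1", "1/0.png\txyz\n")):
            path = os.path.join(out_dir, "label", tag + "_label.txt")
            with open(path, "w", encoding="utf-8") as f:
                f.write(line)
        with open(merge_label_files(out_dir), encoding="utf-8") as f:
            print(f.read())
```

Each line keeps the `path<TAB>label` layout, so the merged file can be used directly as a recognition label file.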
style_images/1.jpg	NEATNESS
style_images/2.jpg	锁店君和宾馆
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from engine.synthesisers import DatasetSynthesiser


def synth_dataset():
    dataset_synthesiser = DatasetSynthesiser()
    dataset_synthesiser.synth_dataset()


if __name__ == '__main__':
    synth_dataset()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
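import os
import cv2
import sys
import glob

# Extend sys.path before importing the local packages below, so the script
# also works when it is run directly instead of with `python3 -m`.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from utils.config import ArgsParser
from engine.synthesisers import ImageSynthesiser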
def synth_image():
    args = ArgsParser().parse_args()
    image_synthesiser = ImageSynthesiser()
    style_image_path = args.style_image
    img = cv2.imread(style_image_path)
    text_corpus = args.text_corpus
    language = args.language

    synth_result = image_synthesiser.synth_image(text_corpus, img, language)
    fake_fusion = synth_result["fake_fusion"]
    fake_text = synth_result["fake_text"]
    fake_bg = synth_result["fake_bg"]
    # Final result plus the two intermediate images.
    cv2.imwrite("fake_fusion.jpg", fake_fusion)
    cv2.imwrite("fake_text.jpg", fake_text)
    cv2.imwrite("fake_bg.jpg", fake_bg)
def batch_synth_images():
    image_synthesiser = ImageSynthesiser()

    corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
    style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
    save_path = "./output_data/"
    # Each corpus line holds the text and its language, separated by a tab.
    corpus_list = []
    with open(corpus_file, "rb") as fin:
        lines = fin.readlines()
        for line in lines:
            substr = line.decode("utf-8").strip("\n").split("\t")
            corpus_list.append(substr)
    style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))
    corpus_num = len(corpus_list)
    style_img_num = len(style_img_list)
    # Synthesise every (corpus, style image) combination.
    for cno in range(corpus_num):
        for sno in range(style_img_num):
            corpus, lang = corpus_list[cno]
            style_img_path = style_img_list[sno]
            img = cv2.imread(style_img_path)
            synth_result = image_synthesiser.synth_image(corpus, img, lang)
            fake_fusion = synth_result["fake_fusion"]
            fake_text = synth_result["fake_text"]
            fake_bg = synth_result["fake_bg"]
            # Save each result twice: once ordered by corpus index and once
            # by style index, so both groupings sort together on disk.
            for tp in range(2):
                if tp == 0:
                    prefix = "%s/c%d_s%d" % (save_path, cno, sno)
                else:
                    prefix = "%s/s%d_c%d" % (save_path, sno, cno)
                cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
                cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
                cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
                cv2.imwrite("%s_input_style.jpg" % prefix, img)
            print(cno, corpus_num, sno, style_img_num)


if __name__ == '__main__':
    # batch_synth_images()
    synth_image()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter
def override(dl, ks, v):
    """
    Recursively replace a value in a dict or list
    Args:
        dl(dict or list): dict or list to be replaced
        ks(list): list of keys
        v(str): value to be replaced
    """

    def str2num(v):
        # eval turns "300" into 300 and "[1, 2]" into a list; anything that
        # fails to evaluate is kept as a string. Only pass trusted input.
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), (
        "{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('length of keys should be larger than 0')
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            if ks[0] not in dl:
                # `logger` was never defined in this module; fetch it lazily
                # so the module stays importable on its own.
                from utils.logging import get_logger
                get_logger().warning(
                    'A new field ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
    """
    Recursively override the config
    Args:
        config(dict): dict to be replaced
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]
    Returns:
        config(dict): replaced config
    """
    if options is not None:
        for opt in options:
            assert isinstance(opt, str), (
                "option({}) should be a str".format(opt))
            assert "=" in opt, (
                "option({}) should contain a = "
                "to distinguish between key and value".format(opt))
            pair = opt.split('=')
            assert len(pair) == 2, (
                "option({}) should contain exactly one =".format(opt))
            key, value = pair
            keys = key.split('.')
            override(config, keys, value)

    return config
class ArgsParser(ArgumentParser):
    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-t", "--tag", default="0", help="tag for marking worker")
        self.add_argument(
            '-o',
            '--override',
            action='append',
            default=[],
            help='config options to be overridden')
        self.add_argument(
            "--style_image",
            default="examples/style_images/1.jpg",
            help="path of the style reference image")
        self.add_argument(
            "--text_corpus",
            default="PaddleOCR",
            help="text to render in the style of the reference image")
        self.add_argument(
            "--language",
            default="en",
            help="language of the text corpus, e.g. en, ch or ko")

    def parse_args(self, argv=None):
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        return args
def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: config
    """
    ext = os.path.splitext(file_path)[1]
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    return config
def gen_config():
    base_config = {
        "Global": {
            "algorithm": "SRNet",
            "use_gpu": True,
            "start_epoch": 1,
            "stage1_epoch_num": 100,
            "stage2_epoch_num": 100,
            "log_smooth_window": 20,
            "print_batch_step": 2,
            "save_model_dir": "./output/SRNet",
            "use_visualdl": False,
            "save_epoch_step": 10,
            "vgg_pretrain": "./pretrained/VGG19_pretrained",
            "vgg_load_static_pretrain": True
        },
        "Architecture": {
            "model_type": "data_aug",
            "algorithm": "SRNet",
            "net_g": {
                "name": "srnet_net_g",
                "encode_dim": 64,
                "norm": "batch",
                "use_dropout": False,
                "init_type": "xavier",
                "init_gain": 0.02,
                "use_dilation": 1
            },
            # input_nc, ndf, netD,
            # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
            "bg_discriminator": {
                "name": "srnet_bg_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            },
            "fusion_discriminator": {
                "name": "srnet_fusion_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            }
        },
        "Loss": {
            "lamb": 10,
            "perceptual_lamb": 1,
            "muvar_lamb": 50,
            "style_lamb": 500
        },
        "Optimizer": {
            "name": "Adam",
            "learning_rate": {
                "name": "lambda",
                "lr": 0.0002,
                "lr_decay_iters": 50
            },
            "beta1": 0.5,
            "beta2": 0.999,
        },
        "Train": {
            "batch_size_per_card": 8,
            "num_workers_per_card": 4,
            "dataset": {
                "delimiter": "\t",
                "data_dir": "/",
                "label_file": "tmp/label.txt",
                "transforms": [{
                    "DecodeImage": {
                        "to_rgb": True,
                        "to_np": False,
                        "channel_first": False
                    }
                }, {
                    "NormalizeImage": {
                        "scale": 1. / 255.,
                        "mean": [0.485, 0.456, 0.406],
                        "std": [0.229, 0.224, 0.225],
                        "order": None
                    }
                }, {
                    "ToCHWImage": None
                }]
            }
        }
    }
    with open("config.yml", "w") as f:
        yaml.dump(base_config, f)


if __name__ == '__main__':
    gen_config()
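The nested dictionary built by `gen_config` is addressed by section and key (for example `Global` then `use_gpu`). As a minimal sketch of how a single value in such a config could be overridden without mutating the defaults, the helper below takes a dotted key. The name `set_by_dotted_key` is hypothetical and is not part of this commit; it only illustrates the nested layout.

```python
import copy


def set_by_dotted_key(config, dotted_key, value):
    """Return a copy of `config` with the nested key set to `value`.

    Hypothetical helper for illustration; assumes every intermediate
    level addressed by `dotted_key` is a dict (or is missing).
    """
    updated = copy.deepcopy(config)
    node = updated
    *parents, leaf = dotted_key.split(".")
    for key in parents:
        # Create intermediate sections on demand.
        node = node.setdefault(key, {})
    node[leaf] = value
    return updated


base = {"Global": {"use_gpu": True, "save_epoch_step": 10}}
cpu_config = set_by_dotted_key(base, "Global.use_gpu", False)
```

Working on a deep copy keeps the base defaults intact, so several variants can be derived from one dictionary.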
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
__all__ = ['load_dygraph_pretrain']
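def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
    # `path` is a checkpoint prefix; weights are read from `path + '.pdparams'`.
    # `load_static_weights` is not used in this function.
    if path is None:
        raise ValueError("Model pretrain path must be provided.")
    if not os.path.exists(path + '.pdparams'):
        raise ValueError("Model pretrain path {} does not "
                         "exist.".format(path))
    param_state_dict = paddle.load(path + '.pdparams')
    model.set_state_dict(param_state_dict)
    logger.info("load pretrained model from {}".format(path))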
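`load_dygraph_pretrain` treats `path` as a prefix and appends `.pdparams` to find the weights file. The sketch below isolates just that path check using only the standard library, so it can be tried without Paddle installed. The helper name `resolve_pretrain_file` and the file name `bg_generator` used in the demo are illustrative assumptions.

```python
import os
import tempfile


def resolve_pretrain_file(path):
    """Return the '.pdparams' file for a checkpoint prefix, or raise ValueError.

    Hypothetical helper: it mirrors the existence check only and does not
    load any weights.
    """
    if not path:
        raise ValueError("A pretrained model path prefix is required.")
    weights_file = path + ".pdparams"
    if not os.path.isfile(weights_file):
        raise ValueError(
            "Pretrained weights {} do not exist.".format(weights_file))
    return weights_file


with tempfile.TemporaryDirectory() as tmp_dir:
    prefix = os.path.join(tmp_dir, "bg_generator")
    # Create an empty placeholder file standing in for real weights.
    open(prefix + ".pdparams", "wb").close()
    found_name = os.path.basename(resolve_pretrain_file(prefix))
```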
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import functools
import paddle.distributed as dist
logger_initialized = {}
@functools.lru_cache()
def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified a FileHandler will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if log_file is not None and dist.get_rank() == 0:
        log_file_folder = os.path.split(log_file)[0]
        # A bare filename has no folder part; os.makedirs('') would raise.
        if log_file_folder:
            os.makedirs(log_file_folder, exist_ok=True)
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if dist.get_rank() == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)
    logger_initialized[name] = True
    return logger
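The two ideas in `get_logger` are that handlers are attached only once per logger name and that only rank 0 logs at the requested level. The sketch below reproduces both with the standard library alone. The rank is passed in explicitly instead of being read from `paddle.distributed`, and the names `get_rank_logger`, `demo_rank0` and `demo_rank1` are assumptions made for this example.

```python
import logging
import sys

_initialized = set()


def get_rank_logger(name, rank=0, log_level=logging.INFO):
    """Return a logger that is verbose on rank 0 and quiet elsewhere."""
    logger = logging.getLogger(name)
    if name in _initialized:
        # Already configured: do not attach a second handler.
        return logger
    handler = logging.StreamHandler(stream=sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
            datefmt="%Y/%m/%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.setLevel(log_level if rank == 0 else logging.ERROR)
    _initialized.add(name)
    return logger


main_logger = get_rank_logger("demo_rank0", rank=0)
worker_logger = get_rank_logger("demo_rank1", rank=1)
again = get_rank_logger("demo_rank0", rank=0)

main_logger.info("visible on rank 0")
worker_logger.info("suppressed on other ranks")
```

Without the initialization guard, every repeated call would add another handler and each message would be printed multiple times.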
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def compute_mean_covariance(img):
    """Return the per-channel mean and the channel covariance of an NCHW batch."""
    batch_size = img.shape[0]
    channel_num = img.shape[1]
    height = img.shape[2]
    width = img.shape[3]
    num_pixels = height * width

    # batch_size * channel_num * 1 * 1
    mu = img.mean(2, keepdim=True).mean(3, keepdim=True)

    # batch_size * channel_num * num_pixels
    img_hat = img - mu.expand_as(img)
    img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
    # batch_size * num_pixels * channel_num
    img_hat_transpose = img_hat.transpose([0, 2, 1])
    # batch_size * channel_num * channel_num
    covariance = paddle.bmm(img_hat, img_hat_transpose)
    covariance = covariance / num_pixels

    return mu, covariance
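def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
    """Return the Dice loss (1 - Dice coefficient) over the masked pixels."""
    eps = 1e-5
    intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
    # Despite the name, `union` is the sum of both masked areas, not a set union.
    union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
        y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)
    return loss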
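To make the arithmetic of the two functions above concrete, here is a small pure-Python sketch that works on plain lists instead of Paddle tensors. It handles a single image whose channels are already flattened to pixel lists, so the batch dimension is dropped. The names `mean_and_covariance` and `dice_loss` and the sample values are illustrative assumptions.

```python
def mean_and_covariance(channels):
    """channels: list of C lists, each holding the N pixel values of one channel."""
    num_pixels = len(channels[0])
    means = [sum(ch) / num_pixels for ch in channels]
    # Center each channel on its own mean.
    centered = [[v - m for v in ch] for ch, m in zip(channels, means)]
    # C x C matrix: dot products of centered channels, divided by N
    # (population covariance, matching the division by num_pixels).
    covariance = [[sum(a * b for a, b in zip(ci, cj)) / num_pixels
                   for cj in centered] for ci in centered]
    return means, covariance


def dice_loss(y_true, y_pred, mask, eps=1e-5):
    """1 - 2 * |A and B| / (|A| + |B| + eps), restricted to masked pixels."""
    intersection = sum(t * p * m for t, p, m in zip(y_true, y_pred, mask))
    total = (sum(t * m for t, m in zip(y_true, mask)) +
             sum(p * m for p, m in zip(y_pred, mask)) + eps)
    return 1.0 - 2 * intersection / total


means, cov = mean_and_covariance([[1, 2, 3, 4], [2, 4, 6, 8]])
half_overlap = dice_loss([1, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1])
full_overlap = dice_loss([1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1])
```

In this example the second channel is twice the first, so its variance is four times larger and the cross term is twice the first channel's variance. The Dice loss is close to 0.5 when half of the positive pixels agree and close to 0 for a perfect match.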
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import errno
import paddle
def get_check_global_params(mode):
    check_params = [
        'use_gpu', 'max_text_length', 'image_shape', 'character_type',
        'loss_type'
    ]
    if mode == "train_eval":
        check_params = check_params + [
            'train_batch_size_per_card', 'test_batch_size_per_card'
        ]
    elif mode == "test":
        check_params = check_params + ['test_batch_size_per_card']
    return check_params
def check_gpu(use_gpu):
    """
    Log an error and exit when use_gpu is set to true but the installed
    paddlepaddle is the CPU-only version.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"
    if use_gpu:
        # Keep sys.exit outside the try block: a bare `except:` would also
        # catch the SystemExit it raises.
        try:
            compiled_with_cuda = paddle.is_compiled_with_cuda()
        except Exception:
            print("Failed to check GPU state.")
            sys.exit(1)
        if not compiled_with_cuda:
            print(err)
            sys.exit(1)
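Because `sys.exit` works by raising `SystemExit`, which derives from `BaseException` rather than `Exception`, the kind of `except` clause placed around an exit call changes what happens. The following sketch contrasts the two cases. The function names are made up for this illustration.

```python
def exit_inside_bare_except():
    try:
        raise SystemExit(1)
    except:  # noqa: E722 - deliberately bare to show the effect
        # A bare except also intercepts SystemExit.
        return "swallowed"


def exit_inside_except_exception():
    try:
        try:
            raise SystemExit(1)
        except Exception:
            # Not reached: SystemExit is not a subclass of Exception.
            return "swallowed"
    except SystemExit as exc:
        # The outer handler exists only so the demo can report the result.
        return "propagated with code {}".format(exc.code)


bare_result = exit_inside_bare_except()
narrow_result = exit_inside_except_exception()
```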
def _mkdir_if_not_exist(path, logger):
    """
    mkdir if not exists, ignore the exception when multiprocess mkdir together
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'be happy if some process has already created {}'.format(
                        path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))
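On Python 3, the race that `_mkdir_if_not_exist` guards against (several processes creating the same directory at once) can also be handled with `os.makedirs(path, exist_ok=True)`. The sketch below shows that alternative. The helper name `ensure_dir` is an assumption, and unlike the function above it does not log anything.

```python
import os
import tempfile


def ensure_dir(path):
    """Create `path` and its parents if missing; succeed if it already exists."""
    os.makedirs(path, exist_ok=True)
    return path


with tempfile.TemporaryDirectory() as tmp_dir:
    target = os.path.join(tmp_dir, "output", "SRNet")
    ensure_dir(target)
    ensure_dir(target)  # second call is a no-op, not an error
    created = os.path.isdir(target)
```

`exist_ok=True` still raises `FileExistsError` when the path exists but is a regular file, so real failures are not hidden.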