diff --git a/doc/doc_ch/style_text_rec.md b/doc/doc_ch/style_text_rec.md new file mode 100644 index 0000000000000000000000000000000000000000..3a6d8b8bc020b3af84140bb07b6bcc4e6e7c5de3 --- /dev/null +++ b/doc/doc_ch/style_text_rec.md @@ -0,0 +1,150 @@ +## Style Text Rec + +### Contents +[Introduction](#introduction) +[Environment Setup](#environment-setup) +[Quick Start](#quick-start) +[Advanced Usage](#advanced-usage) +[Application Examples](#application-examples) + +### Introduction +
+ +
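+ +Style-Text is an improvement on the SRNet network proposed in Baidu's self-developed text editing algorithm "Editing Text in the Wild". Unlike commonly used GAN methods, which use only a single branch, this tool decomposes the text synthesis task into three sub-modules (a text style transfer module, a background extraction module and a foreground-background fusion module) to improve the quality of the synthesized data. The figure below shows some example results. + +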
+ + +
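+ +In addition, the effectiveness of this synthesis tool has been verified in real-world nameplate text recognition and Korean text recognition scenarios. + +### Environment Setup + +1. Refer to [Quick Installation](./installation.md) to install PaddleOCR. A python3 environment is strongly recommended. +2. Enter the `style_text_rec` directory, then download and unzip the models: + +```bash +cd style_text_rec +wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip +unzip style_text_models.zip +``` + +If you save the models elsewhere, update the model file paths in `configs/config.yml`; all three of the following settings must be changed together: + +``` +bg_generator: + pretrain: style_text_models/bg_generator +... +text_generator: + pretrain: style_text_models/text_generator +... +fusion_generator: + pretrain: style_text_models/fusion_generator +``` + +### Quick Start + +1. Run tools/synth_image to generate an example image: + +```bash +python3 -m tools.synth_image -c configs/config.yml +``` + +1. After it runs, `fake_fusion.jpg` is generated; this is the final result. +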
+ +
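+In addition, the program also generates and saves intermediate results: + * `fake_bg.jpg`: the background of the style reference image after the text is removed; + * `fake_text.jpg`: a text image rendered on a gray background from the provided string, imitating the text style of the style reference image. + +2. To try other style images and texts, add the style_image, text_corpus and language arguments: +```bash +python3 -m tools.synth_image -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en +``` + * Note: the language option must match the corpus. English, Simplified Chinese and Korean are currently supported. + +3. `tools/synth_image.py` also provides a `batch_synth_images` method, which combines corpora and images pairwise to generate a batch of data. + +### Advanced Usage + +Before synthesizing a dataset, some materials need to be prepared. + +First, style images are needed as references for the synthesized images; these can come from a dataset used to train an OCR recognition model. This example uses a dataset with an annotation file as the style images. + +1. Configure the input data paths in `configs/dataset_config.yml`. + * `StyleSampler`: + * `method`: the style image sampling method to use; + * `image_home`: the style image directory; + * `label_file`: the file listing style image paths; if the dataset has labels, `label_file` is the path of the label file; + * `with_label`: whether `label_file` is a label file. + We provide a set of [sample images](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar) for you to try. + * `CorpusGenerator`: + * `method`: the corpus generation method; `FileCorpus` and `EnNumCorpus` are currently available. With `EnNumCorpus`, no other settings are needed; otherwise set `corpus_file` and `language`; + * `language`: the language of the corpus; + * `corpus_file`: the corpus file path. + +2. Run `tools/synth_dataset` to synthesize data: + + ``` bash + python -m tools.synth_dataset -c configs/dataset_config.yml + ``` + +3. 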
To synthesize data faster in parallel, start multiple processes and give each a different `tag` (`-t`) at launch, as shown below: + + ```bash + python3 -m tools.synth_dataset -t 0 -c configs/dataset_config.yml + python3 -m tools.synth_dataset -t 1 -c configs/dataset_config.yml + ``` + + +### Application Examples + +After completing the steps above, you have a synthetic dataset for OCR recognition. Some examples of generated datasets are given below: + +Next, refer to the [OCR recognition document](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83) to complete the training. + +### Project Structure +``` +style_text_rec +|-- arch +| |-- base_module.py +| |-- decoder.py +| |-- encoder.py +| |-- spectral_norm.py +| `-- style_text_rec.py +|-- configs +| |-- config.yml +| `-- dataset_config.yml +|-- engine +| |-- corpus_generators.py +| |-- predictors.py +| |-- style_samplers.py +| |-- synthesisers.py +| |-- text_drawers.py +| `-- writers.py +|-- examples +| |-- corpus +| | `-- example.txt +| |-- image_list.txt +| `-- style_images +| |-- 1.jpg +| `-- 2.jpg +|-- fonts +| |-- ch_standard.ttf +| |-- en_standard.ttf +| `-- ko_standard.ttf +|-- tools +| |-- __init__.py +| |-- synth_dataset.py +| `-- synth_image.py +`-- utils + |-- config.py + |-- load_params.py + |-- logging.py + |-- math_functions.py + `-- sys_funcs.py +``` \ No newline at end of file diff --git a/doc/doc_en/style_text_rec_en.md b/doc/doc_en/style_text_rec_en.md new file mode 100644 index 0000000000000000000000000000000000000000..7e7d29c93606032fd39ab686af232cec395bac9c --- /dev/null +++ b/doc/doc_en/style_text_rec_en.md @@ -0,0 +1,106 @@ +### Quick Start + +`Style-Text` is an improvement of the SRNet network proposed in Baidu's self-developed text editing algorithm "Editing Text in the Wild". Unlike the commonly used GAN methods, this tool decomposes the text synthesis task into three sub-modules to improve the quality of the synthetic data: a text style transfer module, a background extraction module and a fusion module. + +The following figure shows some example results. 
In addition, the effectiveness of the synthesis tool has been verified in real `nameplate text recognition` and `Korean text recognition` scenarios, as follows. + + +#### Preparation + +1. Please refer to [QUICK INSTALLATION](./installation_en.md) to install PaddlePaddle. A Python3 environment is strongly recommended. +2. Download the pretrained models and unzip: + +```bash +cd style_text_rec +wget /path/to/style_text_models.zip +unzip style_text_models.zip +``` + +You can download the models [here](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip). If you save the model files in another folder, please edit the three model paths in `configs/config.yml`: + +``` +bg_generator: + pretrain: style_text_models/bg_generator +... +text_generator: + pretrain: style_text_models/text_generator +... +fusion_generator: + pretrain: style_text_models/fusion_generator +``` + + + +#### Demo + +1. You can use the following command to run a demo: + +```bash +python -m tools.synth_image -c configs/config.yml +``` + +2. The results are `fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg` as shown in the figure above. Among them: + * `fake_text.jpg` is the generated image with the same font style as `Style Input`; + * `fake_bg.jpg` is the generated image of `Style Input` after removing the foreground; + * `fake_fusion.jpg` is the final result, synthesised from `fake_text.jpg` and `fake_bg.jpg`. + +3. If you want to generate an image from another `Style Input` or `Text Input`, you can modify `tools/synth_image.py`: + * `img = cv2.imread("examples/style_images/1.jpg")`: the path of `Style Input`; + * `corpus = "PaddleOCR"`: the `Text Input`; + * Notice: modify the language option (`language = "en"`) to match the `Text Input`; `en`, `ch` and `ko` are supported. + +4. We also provide a `batch_synth_images` method, which combines corpus and pictures in pairs to generate a batch of data. 
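The pairwise combination mentioned for `batch_synth_images` can be pictured with a small sketch. This is only an illustration, assuming that "in pairs" means every corpus string is combined with every style image; `pair_corpus_with_styles` is a hypothetical helper that is not part of `tools/synth_image.py`, and the real method may differ.

```python
from itertools import product


def pair_corpus_with_styles(corpus_list, style_image_paths):
    """Hypothetical helper: return every (corpus, style image path) pair."""
    # itertools.product yields the Cartesian product of the two inputs,
    # varying the style image fastest.
    return list(product(corpus_list, style_image_paths))


pairs = pair_corpus_with_styles(
    ["PaddleOCR", "Style-Text"],
    ["examples/style_images/1.jpg", "examples/style_images/2.jpg"])

for corpus, style_path in pairs:
    # Each pair would be one synthesis job: render `corpus` in the
    # style of the image stored at `style_path`.
    print(corpus, style_path)
```

Under that assumption, two corpus strings and two style images give four synthesis jobs, so the size of the batch grows with the product of the two list lengths.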
+ +### Advanced Usage + +#### Components + +`Style Text Rec` mainly contains the following components: + +* `style_samplers`: It can sample `Style Input` from a dataset. Currently, we only provide `DatasetSampler`. + +* `corpus_generators`: It can generate corpus. Currently, we only provide two `corpus_generators`: + * `EnNumCorpus`: It can generate a random string of a given length, consisting of uppercase and lowercase English letters, numbers and spaces. + * `FileCorpus`: It can read a text file and randomly return the words in it. + +* `text_drawers`: It can generate `Text Input` (a text picture in a standard font according to the input corpus). Note that you have to set the language information to match the corpus. + +* `predictors`: It can call the deep learning model to generate new data based on the `Style Input` and `Text Input`. + +* `writers`: It can write the generated pictures (`fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`) and label information to disk. + +* `synthesisers`: It can call all the modules to complete the work. + +### Generate Dataset + +Before you start, you need to prepare some data as material. +First, you should have the style reference data for synthesis tasks, which can generally be datasets used for OCR recognition tasks. + +1. The referenced dataset can be specified in `configs/dataset_config.yml`: + * `StyleSampler`: + * `method`: The method of `StyleSampler`. + * `image_home`: The directory of pictures. + * `label_file`: The file listing picture paths if `with_label` is `false`; otherwise, the label file path. + * `with_label`: Whether `label_file` is a label file. + + * `CorpusGenerator`: + * `method`: The method of `CorpusGenerator`. If `FileCorpus` is used, you need to modify `corpus_file` and `language` accordingly; if `EnNumCorpus` is used, no other configuration is needed. + * `language`: The language of the corpus. Needed if method is not `EnNumCorpus`. + * `corpus_file`: The corpus file path. 
Needed if method is not `EnNumCorpus`. + +2. You can run the following command to start a synthesis task: + + ``` bash + python -m tools.synth_dataset -c configs/dataset_config.yml + ``` + +3. You can use the following commands to start multiple synthesis tasks in parallel as separate processes; each one needs a different tag, specified with `-t`: + + ```bash + python -m tools.synth_dataset -t 0 -c configs/dataset_config.yml + python -m tools.synth_dataset -t 1 -c configs/dataset_config.yml + ``` + +### OCR Recognition Training + +After completing the above operations, you have the synthetic dataset for OCR recognition. Next, please complete the training by referring to the [OCR Recognition Document](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83). \ No newline at end of file diff --git a/doc/imgs_style_text/1.png b/doc/imgs_style_text/1.png new file mode 100644 index 0000000000000000000000000000000000000000..8f7574ba2f723ac82241fec6dc52828713a5d293 Binary files /dev/null and b/doc/imgs_style_text/1.png differ diff --git a/doc/imgs_style_text/2.png b/doc/imgs_style_text/2.png new file mode 100644 index 0000000000000000000000000000000000000000..ce9bf4712a551b9d9d27eae00f9c7b9b5845d8b3 Binary files /dev/null and b/doc/imgs_style_text/2.png differ diff --git a/doc/imgs_style_text/3.png b/doc/imgs_style_text/3.png new file mode 100644 index 0000000000000000000000000000000000000000..0fb73a31f58c1c476cf84f3c507f0af6523385f4 Binary files /dev/null and b/doc/imgs_style_text/3.png differ diff --git a/doc/imgs_style_text/4.jpg b/doc/imgs_style_text/4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5fda9548632b63e55b42315dca4a5b9cec2a353c Binary files /dev/null and b/doc/imgs_style_text/4.jpg differ diff --git a/doc/imgs_style_text/5.png b/doc/imgs_style_text/5.png new file mode 100644 index 0000000000000000000000000000000000000000..ea0b89034940cc70e6ec8f77471f3af1c2b54219 Binary files 
/dev/null and b/doc/imgs_style_text/5.png differ diff --git a/style_text_rec/__init__.py b/style_text_rec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/style_text_rec/arch/__init__.py b/style_text_rec/arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/style_text_rec/arch/base_module.py b/style_text_rec/arch/base_module.py new file mode 100644 index 0000000000000000000000000000000000000000..da2b6b834c6a86b1c3efeb5cef4cb9d02e44e405 --- /dev/null +++ b/style_text_rec/arch/base_module.py @@ -0,0 +1,255 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import paddle +import paddle.nn as nn + +from arch.spectral_norm import spectral_norm + + +class CBN(nn.Layer): + def __init__(self, + name, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + use_bias=False, + norm_layer=None, + act=None, + act_attr=None): + super(CBN, self).__init__() + if use_bias: + bias_attr = paddle.ParamAttr(name=name + "_bias") + else: + bias_attr = None + self._conv = paddle.nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + weight_attr=paddle.ParamAttr(name=name + "_weights"), + bias_attr=bias_attr) + if norm_layer: + self._norm_layer = getattr(paddle.nn, norm_layer)( + num_features=out_channels, name=name + "_bn") + else: + self._norm_layer = None + if act: + if act_attr: + self._act = getattr(paddle.nn, act)(**act_attr, + name=name + "_" + act) + else: + self._act = getattr(paddle.nn, act)(name=name + "_" + act) + else: + self._act = None + + def forward(self, x): + out = self._conv(x) + if self._norm_layer: + out = self._norm_layer(out) + if self._act: + out = self._act(out) + return out + + +class SNConv(nn.Layer): + def __init__(self, + name, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + use_bias=False, + norm_layer=None, + act=None, + act_attr=None): + super(SNConv, self).__init__() + if use_bias: + bias_attr = paddle.ParamAttr(name=name + "_bias") + else: + bias_attr = None + self._sn_conv = spectral_norm( + paddle.nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + weight_attr=paddle.ParamAttr(name=name + "_weights"), + bias_attr=bias_attr)) + if norm_layer: + self._norm_layer = getattr(paddle.nn, norm_layer)( + num_features=out_channels, name=name + "_bn") + else: + self._norm_layer = None + if 
act: + if act_attr: + self._act = getattr(paddle.nn, act)(**act_attr, + name=name + "_" + act) + else: + self._act = getattr(paddle.nn, act)(name=name + "_" + act) + else: + self._act = None + + def forward(self, x): + out = self._sn_conv(x) + if self._norm_layer: + out = self._norm_layer(out) + if self._act: + out = self._act(out) + return out + + +class SNConvTranspose(nn.Layer): + def __init__(self, + name, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + dilation=1, + groups=1, + use_bias=False, + norm_layer=None, + act=None, + act_attr=None): + super(SNConvTranspose, self).__init__() + if use_bias: + bias_attr = paddle.ParamAttr(name=name + "_bias") + else: + bias_attr = None + self._sn_conv_transpose = spectral_norm( + paddle.nn.Conv2DTranspose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + groups=groups, + weight_attr=paddle.ParamAttr(name=name + "_weights"), + bias_attr=bias_attr)) + if norm_layer: + self._norm_layer = getattr(paddle.nn, norm_layer)( + num_features=out_channels, name=name + "_bn") + else: + self._norm_layer = None + if act: + if act_attr: + self._act = getattr(paddle.nn, act)(**act_attr, + name=name + "_" + act) + else: + self._act = getattr(paddle.nn, act)(name=name + "_" + act) + else: + self._act = None + + def forward(self, x): + out = self._sn_conv_transpose(x) + if self._norm_layer: + out = self._norm_layer(out) + if self._act: + out = self._act(out) + return out + + +class MiddleNet(nn.Layer): + def __init__(self, name, in_channels, mid_channels, out_channels, + use_bias): + super(MiddleNet, self).__init__() + self._sn_conv1 = SNConv( + name=name + "_sn_conv1", + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + use_bias=use_bias, + norm_layer=None, + act=None) + self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate") + self._sn_conv2 
= SNConv( + name=name + "_sn_conv2", + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=3, + use_bias=use_bias) + self._sn_conv3 = SNConv( + name=name + "_sn_conv3", + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + use_bias=use_bias) + + def forward(self, x): + + sn_conv1 = self._sn_conv1.forward(x) + pad_2d = self._pad2d.forward(sn_conv1) + sn_conv2 = self._sn_conv2.forward(pad_2d) + sn_conv3 = self._sn_conv3.forward(sn_conv2) + return sn_conv3 + + +class ResBlock(nn.Layer): + def __init__(self, name, channels, norm_layer, use_dropout, use_dilation, + use_bias): + super(ResBlock, self).__init__() + if use_dilation: + padding_mat = [1, 1, 1, 1] + else: + padding_mat = [0, 0, 0, 0] + self._pad1 = nn.Pad2D(padding_mat, mode="replicate") + + self._sn_conv1 = SNConv( + name=name + "_sn_conv1", + in_channels=channels, + out_channels=channels, + kernel_size=3, + padding=0, + norm_layer=norm_layer, + use_bias=use_bias, + act="ReLU", + act_attr=None) + if use_dropout: + self._dropout = nn.Dropout(0.5) + else: + self._dropout = None + self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate") + self._sn_conv2 = SNConv( + name=name + "_sn_conv2", + in_channels=channels, + out_channels=channels, + kernel_size=3, + norm_layer=norm_layer, + use_bias=use_bias, + act="ReLU", + act_attr=None) + + def forward(self, x): + pad1 = self._pad1.forward(x) + sn_conv1 = self._sn_conv1.forward(pad1) + pad2 = self._pad2.forward(sn_conv1) + sn_conv2 = self._sn_conv2.forward(pad2) + return sn_conv2 + x diff --git a/style_text_rec/arch/decoder.py b/style_text_rec/arch/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..36f07c5998a8f6b400997eacae0b44860312f432 --- /dev/null +++ b/style_text_rec/arch/decoder.py @@ -0,0 +1,251 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import paddle.nn as nn + +from arch.base_module import SNConv, SNConvTranspose, ResBlock + + +class Decoder(nn.Layer): + def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer, + act, act_attr, conv_block_dropout, conv_block_num, + conv_block_dilation, out_conv_act, out_conv_act_attr): + super(Decoder, self).__init__() + conv_blocks = [] + for i in range(conv_block_num): + conv_blocks.append( + ResBlock( + name="{}_conv_block_{}".format(name, i), + channels=encode_dim * 8, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias)) + self.conv_blocks = nn.Sequential(*conv_blocks) + self._up1 = SNConvTranspose( + name=name + "_up1", + in_channels=encode_dim * 8, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up2 = SNConvTranspose( + name=name + "_up2", + in_channels=encode_dim * 4, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up3 = SNConvTranspose( + name=name + "_up3", + in_channels=encode_dim * 2, + out_channels=encode_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate") + self._out_conv = SNConv( + name=name + "_out_conv", + in_channels=encode_dim, + 
out_channels=out_channels, + kernel_size=3, + use_bias=use_bias, + norm_layer=None, + act=out_conv_act, + act_attr=out_conv_act_attr) + + def forward(self, x): + if isinstance(x, (list, tuple)): + x = paddle.concat(x, axis=1) + output_dict = dict() + output_dict["conv_blocks"] = self.conv_blocks.forward(x) + output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"]) + output_dict["up2"] = self._up2.forward(output_dict["up1"]) + output_dict["up3"] = self._up3.forward(output_dict["up2"]) + output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"]) + output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"]) + return output_dict + + +class DecoderUnet(nn.Layer): + def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer, + act, act_attr, conv_block_dropout, conv_block_num, + conv_block_dilation, out_conv_act, out_conv_act_attr): + super(DecoderUnet, self).__init__() + conv_blocks = [] + for i in range(conv_block_num): + conv_blocks.append( + ResBlock( + name="{}_conv_block_{}".format(name, i), + channels=encode_dim * 8, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias)) + self._conv_blocks = nn.Sequential(*conv_blocks) + self._up1 = SNConvTranspose( + name=name + "_up1", + in_channels=encode_dim * 8, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up2 = SNConvTranspose( + name=name + "_up2", + in_channels=encode_dim * 8, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up3 = SNConvTranspose( + name=name + "_up3", + in_channels=encode_dim * 4, + out_channels=encode_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + 
self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate") + self._out_conv = SNConv( + name=name + "_out_conv", + in_channels=encode_dim, + out_channels=out_channels, + kernel_size=3, + use_bias=use_bias, + norm_layer=None, + act=out_conv_act, + act_attr=out_conv_act_attr) + + def forward(self, x, y, feature2, feature1): + output_dict = dict() + output_dict["conv_blocks"] = self._conv_blocks( + paddle.concat( + (x, y), axis=1)) + output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"]) + output_dict["up2"] = self._up2.forward( + paddle.concat( + (output_dict["up1"], feature2), axis=1)) + output_dict["up3"] = self._up3.forward( + paddle.concat( + (output_dict["up2"], feature1), axis=1)) + output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"]) + output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"]) + return output_dict + + +class SingleDecoder(nn.Layer): + def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer, + act, act_attr, conv_block_dropout, conv_block_num, + conv_block_dilation, out_conv_act, out_conv_act_attr): + super(SingleDecoder, self).__init__() + conv_blocks = [] + for i in range(conv_block_num): + conv_blocks.append( + ResBlock( + name="{}_conv_block_{}".format(name, i), + channels=encode_dim * 4, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias)) + self._conv_blocks = nn.Sequential(*conv_blocks) + self._up1 = SNConvTranspose( + name=name + "_up1", + in_channels=encode_dim * 4, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up2 = SNConvTranspose( + name=name + "_up2", + in_channels=encode_dim * 8, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up3 = SNConvTranspose( + 
name=name + "_up3", + in_channels=encode_dim * 4, + out_channels=encode_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate") + self._out_conv = SNConv( + name=name + "_out_conv", + in_channels=encode_dim, + out_channels=out_channels, + kernel_size=3, + use_bias=use_bias, + norm_layer=None, + act=out_conv_act, + act_attr=out_conv_act_attr) + + def forward(self, x, feature2, feature1): + output_dict = dict() + output_dict["conv_blocks"] = self._conv_blocks.forward(x) + output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"]) + output_dict["up2"] = self._up2.forward( + paddle.concat( + (output_dict["up1"], feature2), axis=1)) + output_dict["up3"] = self._up3.forward( + paddle.concat( + (output_dict["up2"], feature1), axis=1)) + output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"]) + output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"]) + return output_dict diff --git a/style_text_rec/arch/encoder.py b/style_text_rec/arch/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b884cda2934477082a1ed98c94e33b736d1f96b4 --- /dev/null +++ b/style_text_rec/arch/encoder.py @@ -0,0 +1,186 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import paddle +import paddle.nn as nn + +from arch.base_module import SNConv, SNConvTranspose, ResBlock + + +class Encoder(nn.Layer): + def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer, + act, act_attr, conv_block_dropout, conv_block_num, + conv_block_dilation): + super(Encoder, self).__init__() + self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate") + self._in_conv = SNConv( + name=name + "_in_conv", + in_channels=in_channels, + out_channels=encode_dim, + kernel_size=7, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down1 = SNConv( + name=name + "_down1", + in_channels=encode_dim, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down2 = SNConv( + name=name + "_down2", + in_channels=encode_dim * 2, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down3 = SNConv( + name=name + "_down3", + in_channels=encode_dim * 4, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + conv_blocks = [] + for i in range(conv_block_num): + conv_blocks.append( + ResBlock( + name="{}_conv_block_{}".format(name, i), + channels=encode_dim * 4, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias)) + self._conv_blocks = nn.Sequential(*conv_blocks) + + def forward(self, x): + out_dict = dict() + x = self._pad2d(x) + out_dict["in_conv"] = self._in_conv.forward(x) + out_dict["down1"] = self._down1.forward(out_dict["in_conv"]) + out_dict["down2"] = self._down2.forward(out_dict["down1"]) + out_dict["down3"] = self._down3.forward(out_dict["down2"]) + out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"]) + return out_dict + + +class 
EncoderUnet(nn.Layer): + def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer, + act, act_attr): + super(EncoderUnet, self).__init__() + self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate") + self._in_conv = SNConv( + name=name + "_in_conv", + in_channels=in_channels, + out_channels=encode_dim, + kernel_size=7, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down1 = SNConv( + name=name + "_down1", + in_channels=encode_dim, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down2 = SNConv( + name=name + "_down2", + in_channels=encode_dim * 2, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down3 = SNConv( + name=name + "_down3", + in_channels=encode_dim * 2, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down4 = SNConv( + name=name + "_down4", + in_channels=encode_dim * 2, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up1 = SNConvTranspose( + name=name + "_up1", + in_channels=encode_dim * 2, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up2 = SNConvTranspose( + name=name + "_up2", + in_channels=encode_dim * 4, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + + def forward(self, x): + output_dict = dict() + x = self._pad2d(x) + output_dict['in_conv'] = self._in_conv.forward(x) + output_dict['down1'] = self._down1.forward(output_dict['in_conv']) + 
output_dict['down2'] = self._down2.forward(output_dict['down1']) + output_dict['down3'] = self._down3.forward(output_dict['down2']) + output_dict['down4'] = self._down4.forward(output_dict['down3']) + output_dict['up1'] = self._up1.forward(output_dict['down4']) + output_dict['up2'] = self._up2.forward( + paddle.concat( + (output_dict['down3'], output_dict['up1']), axis=1)) + output_dict['concat'] = paddle.concat( + (output_dict['down2'], output_dict['up2']), axis=1) + return output_dict diff --git a/style_text_rec/arch/spectral_norm.py b/style_text_rec/arch/spectral_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..21d0afc8d4a8fd4e2262db5c8461d6ffc3dadd45 --- /dev/null +++ b/style_text_rec/arch/spectral_norm.py @@ -0,0 +1,150 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +def normal_(x, mean=0., std=1.): + temp_value = paddle.normal(mean, std, shape=x.shape) + x.set_value(temp_value) + return x + + +class SpectralNorm(object): + def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError('Expected n_power_iterations to be positive, but ' + 'got n_power_iterations={}'.format( + n_power_iterations)) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def reshape_weight_to_matrix(self, weight): + weight_mat = weight + if self.dim != 0: + # transpose dim to front + weight_mat = weight_mat.transpose([ + self.dim, + * [d for d in range(weight_mat.dim()) if d != self.dim] + ]) + + height = weight_mat.shape[0] + + return weight_mat.reshape([height, -1]) + + def compute_weight(self, module, do_power_iteration): + weight = getattr(module, self.name + '_orig') + u = getattr(module, self.name + '_u') + v = getattr(module, self.name + '_v') + weight_mat = self.reshape_weight_to_matrix(weight) + + if do_power_iteration: + with paddle.no_grad(): + for _ in range(self.n_power_iterations): + v.set_value( + F.normalize( + paddle.matmul( + weight_mat, + u, + transpose_x=True, + transpose_y=False), + axis=0, + epsilon=self.eps, )) + + u.set_value( + F.normalize( + paddle.matmul(weight_mat, v), + axis=0, + epsilon=self.eps, )) + if self.n_power_iterations > 0: + u = u.clone() + v = v.clone() + + sigma = paddle.dot(u, paddle.mv(weight_mat, v)) + weight = weight / sigma + return weight + + def remove(self, module): + with paddle.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_v') + delattr(module, self.name + '_orig') + + module.add_parameter(self.name, weight.detach()) + + def __call__(self, module, inputs): + setattr( + module, + self.name, 
+ self.compute_weight( + module, do_power_iteration=module.training)) + + @staticmethod + def apply(module, name, n_power_iterations, dim, eps): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError( + "Cannot register two spectral_norm hooks on " + "the same parameter {}".format(name)) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + + with paddle.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + h, w = weight_mat.shape + + # randomly initialize u and v + u = module.create_parameter([h]) + u = normal_(u, 0., 1.) + v = module.create_parameter([w]) + v = normal_(v, 0., 1.) + u = F.normalize(u, axis=0, epsilon=fn.eps) + v = F.normalize(v, axis=0, epsilon=fn.eps) + + # delete fn.name from parameters, otherwise the attribute cannot be set + del module._parameters[fn.name] + module.add_parameter(fn.name + "_orig", weight) + # still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be a Parameter and + # would get added as a parameter. Instead, we register weight * 1.0 as a plain + # attribute. 
+ setattr(module, fn.name, weight * 1.0) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + return fn + + +def spectral_norm(module, + name='weight', + n_power_iterations=1, + eps=1e-12, + dim=None): + + if dim is None: + if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose, + nn.Conv3DTranspose, nn.Linear)): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module diff --git a/style_text_rec/arch/style_text_rec.py b/style_text_rec/arch/style_text_rec.py new file mode 100644 index 0000000000000000000000000000000000000000..599927ce3edefc90f14191ef3d29b1221355867e --- /dev/null +++ b/style_text_rec/arch/style_text_rec.py @@ -0,0 +1,285 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
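The spectral normalization hook above estimates the largest singular value of the reshaped weight by power iteration and divides the weight by it. A minimal pure-Python sketch of that idea follows. It is an illustration under stated assumptions, not the module's code: `estimate_sigma`, `spectral_normalize` and the helpers are hypothetical names, plain lists stand in for Paddle tensors, and the starting vector is fixed instead of randomly drawn.

```python
import math


def _normalize(vec, eps=1e-12):
    # Divide by the L2 norm, guarding against a zero norm with eps.
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def _matvec(mat, vec):
    # Matrix-vector product for a matrix stored as a list of rows.
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def _transpose(mat):
    return [list(col) for col in zip(*mat)]


def estimate_sigma(weight_mat, n_power_iterations=20, eps=1e-12):
    """Estimate the largest singular value of weight_mat by power iteration."""
    if n_power_iterations <= 0:
        raise ValueError("n_power_iterations must be positive")
    # Assumption: a fixed all-ones start vector keeps the sketch deterministic.
    u = _normalize([1.0] * len(weight_mat), eps)
    w_t = _transpose(weight_mat)
    for _ in range(n_power_iterations):
        v = _normalize(_matvec(w_t, u), eps)         # v <- normalize(W^T u)
        u = _normalize(_matvec(weight_mat, v), eps)  # u <- normalize(W v)
    # sigma = u^T W v
    return sum(a * b for a, b in zip(u, _matvec(weight_mat, v)))


def spectral_normalize(weight_mat, n_power_iterations=20):
    """Return weight_mat divided by its estimated largest singular value."""
    sigma = estimate_sigma(weight_mat, n_power_iterations)
    return [[w / sigma for w in row] for row in weight_mat]
```

The hook in this diff runs only `n_power_iterations` steps per forward pass and keeps `u` and `v` between calls, so its estimate improves over training rather than converging in a single call as this sketch does.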
+import paddle +import paddle.nn as nn + +from arch.base_module import MiddleNet, ResBlock +from arch.encoder import Encoder +from arch.decoder import Decoder, DecoderUnet, SingleDecoder +from utils.load_params import load_dygraph_pretrain +from utils.logging import get_logger + + +class StyleTextRec(nn.Layer): + def __init__(self, config): + super(StyleTextRec, self).__init__() + self.logger = get_logger() + self.text_generator = TextGenerator(config["Predictor"][ + "text_generator"]) + self.bg_generator = BgGeneratorWithMask(config["Predictor"][ + "bg_generator"]) + self.fusion_generator = FusionGeneratorSimple(config["Predictor"][ + "fusion_generator"]) + bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"] + text_generator_pretrain = config["Predictor"]["text_generator"][ + "pretrain"] + fusion_generator_pretrain = config["Predictor"]["fusion_generator"][ + "pretrain"] + load_dygraph_pretrain( + self.bg_generator, + self.logger, + path=bg_generator_pretrain, + load_static_weights=False) + load_dygraph_pretrain( + self.text_generator, + self.logger, + path=text_generator_pretrain, + load_static_weights=False) + load_dygraph_pretrain( + self.fusion_generator, + self.logger, + path=fusion_generator_pretrain, + load_static_weights=False) + + def forward(self, style_input, text_input): + text_gen_output = self.text_generator.forward(style_input, text_input) + fake_text = text_gen_output["fake_text"] + fake_sk = text_gen_output["fake_sk"] + bg_gen_output = self.bg_generator.forward(style_input) + bg_encode_feature = bg_gen_output["bg_encode_feature"] + bg_decode_feature1 = bg_gen_output["bg_decode_feature1"] + bg_decode_feature2 = bg_gen_output["bg_decode_feature2"] + fake_bg = bg_gen_output["fake_bg"] + + fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg) + fake_fusion = fusion_gen_output["fake_fusion"] + return { + "fake_fusion": fake_fusion, + "fake_text": fake_text, + "fake_sk": fake_sk, + "fake_bg": fake_bg, + } + + +class 
TextGenerator(nn.Layer): + def __init__(self, config): + super(TextGenerator, self).__init__() + name = config["module_name"] + encode_dim = config["encode_dim"] + norm_layer = config["norm_layer"] + conv_block_dropout = config["conv_block_dropout"] + conv_block_num = config["conv_block_num"] + conv_block_dilation = config["conv_block_dilation"] + if norm_layer == "InstanceNorm2D": + use_bias = True + else: + use_bias = False + self.encoder_text = Encoder( + name=name + "_encoder_text", + in_channels=3, + encode_dim=encode_dim, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation) + self.encoder_style = Encoder( + name=name + "_encoder_style", + in_channels=3, + encode_dim=encode_dim, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation) + self.decoder_text = Decoder( + name=name + "_decoder_text", + encode_dim=encode_dim, + out_channels=int(encode_dim / 2), + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation, + out_conv_act="Tanh", + out_conv_act_attr=None) + self.decoder_sk = Decoder( + name=name + "_decoder_sk", + encode_dim=encode_dim, + out_channels=1, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation, + out_conv_act="Sigmoid", + out_conv_act_attr=None) + + self.middle = MiddleNet( + name=name + "_middle_net", + in_channels=int(encode_dim / 2) + 1, + mid_channels=encode_dim, + out_channels=3, + use_bias=use_bias) + + def forward(self, style_input, text_input): + style_feature = 
self.encoder_style.forward(style_input)["res_blocks"] + text_feature = self.encoder_text.forward(text_input)["res_blocks"] + fake_c_temp = self.decoder_text.forward([text_feature, + style_feature])["out_conv"] + fake_sk = self.decoder_sk.forward([text_feature, + style_feature])["out_conv"] + fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1)) + return {"fake_sk": fake_sk, "fake_text": fake_text} + + +class BgGeneratorWithMask(nn.Layer): + def __init__(self, config): + super(BgGeneratorWithMask, self).__init__() + name = config["module_name"] + encode_dim = config["encode_dim"] + norm_layer = config["norm_layer"] + conv_block_dropout = config["conv_block_dropout"] + conv_block_num = config["conv_block_num"] + conv_block_dilation = config["conv_block_dilation"] + self.output_factor = config.get("output_factor", 1.0) + + if norm_layer == "InstanceNorm2D": + use_bias = True + else: + use_bias = False + + self.encoder_bg = Encoder( + name=name + "_encoder_bg", + in_channels=3, + encode_dim=encode_dim, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation) + + self.decoder_bg = SingleDecoder( + name=name + "_decoder_bg", + encode_dim=encode_dim, + out_channels=3, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation, + out_conv_act="Tanh", + out_conv_act_attr=None) + + self.decoder_mask = Decoder( + name=name + "_decoder_mask", + encode_dim=encode_dim // 2, + out_channels=1, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation, + out_conv_act="Sigmoid", + out_conv_act_attr=None) + + self.middle = MiddleNet( + name=name + "_middle_net", + 
in_channels=3 + 1, + mid_channels=encode_dim, + out_channels=3, + use_bias=use_bias) + + def forward(self, style_input): + encode_bg_output = self.encoder_bg(style_input) + decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"], + encode_bg_output["down2"], + encode_bg_output["down1"]) + + fake_c_temp = decode_bg_output["out_conv"] + fake_bg_mask = self.decoder_mask.forward(encode_bg_output[ + "res_blocks"])["out_conv"] + fake_bg = self.middle( + paddle.concat( + (fake_c_temp, fake_bg_mask), axis=1)) + return { + "bg_encode_feature": encode_bg_output["res_blocks"], + "bg_decode_feature1": decode_bg_output["up1"], + "bg_decode_feature2": decode_bg_output["up2"], + "fake_bg": fake_bg, + "fake_bg_mask": fake_bg_mask, + } + + +class FusionGeneratorSimple(nn.Layer): + def __init__(self, config): + super(FusionGeneratorSimple, self).__init__() + name = config["module_name"] + encode_dim = config["encode_dim"] + norm_layer = config["norm_layer"] + conv_block_dropout = config["conv_block_dropout"] + conv_block_dilation = config["conv_block_dilation"] + if norm_layer == "InstanceNorm2D": + use_bias = True + else: + use_bias = False + + self._conv = nn.Conv2D( + in_channels=6, + out_channels=encode_dim, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=paddle.ParamAttr(name=name + "_conv_weights"), + bias_attr=False) + + self._res_block = ResBlock( + name="{}_conv_block".format(name), + channels=encode_dim, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias) + + self._reduce_conv = nn.Conv2D( + in_channels=encode_dim, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"), + bias_attr=False) + + def forward(self, fake_text, fake_bg): + fake_concat = paddle.concat((fake_text, fake_bg), axis=1) + fake_concat_tmp = self._conv(fake_concat) + output_res = self._res_block(fake_concat_tmp) + fake_fusion = 
self._reduce_conv(output_res) + return {"fake_fusion": fake_fusion} diff --git a/style_text_rec/configs/config.yml b/style_text_rec/configs/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..3b10b3d2761a4aa40c28abe10134a2f276e1af9d --- /dev/null +++ b/style_text_rec/configs/config.yml @@ -0,0 +1,54 @@ +Global: + output_num: 10 + output_dir: output_data + use_gpu: false + image_height: 32 + image_width: 320 +TextDrawer: + fonts: + en: fonts/en_standard.ttf + ch: fonts/ch_standard.ttf + ko: fonts/ko_standard.ttf +Predictor: + method: StyleTextRecPredictor + algorithm: StyleTextRec + scale: 0.00392156862745098 + mean: + - 0.5 + - 0.5 + - 0.5 + std: + - 0.5 + - 0.5 + - 0.5 + expand_result: false + bg_generator: + pretrain: style_text_models/bg_generator + module_name: bg_generator + generator_type: BgGeneratorWithMask + encode_dim: 64 + norm_layer: null + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true + output_factor: 1.05 + text_generator: + pretrain: style_text_models/text_generator + module_name: text_generator + generator_type: TextGenerator + encode_dim: 64 + norm_layer: InstanceNorm2D + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true + fusion_generator: + pretrain: style_text_models/fusion_generator + module_name: fusion_generator + generator_type: FusionGeneratorSimple + encode_dim: 64 + norm_layer: null + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true +Writer: + method: SimpleWriter diff --git a/style_text_rec/configs/dataset_config.yml b/style_text_rec/configs/dataset_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..e047489e5d82e4c561a835ccf4de1b385e4f5c08 --- /dev/null +++ b/style_text_rec/configs/dataset_config.yml @@ -0,0 +1,64 @@ +Global: + output_num: 10 + output_dir: output_data + use_gpu: false + image_height: 32 + image_width: 320 + standard_font: fonts/en_standard.ttf +TextDrawer: + fonts: + en: 
fonts/en_standard.ttf + ch: fonts/ch_standard.ttf + ko: fonts/ko_standard.ttf +StyleSampler: + method: DatasetSampler + image_home: examples + label_file: examples/image_list.txt + with_label: true +CorpusGenerator: + method: FileCorpus + language: ch + corpus_file: examples/corpus/example.txt +Predictor: + method: StyleTextRecPredictor + algorithm: StyleTextRec + scale: 0.00392156862745098 + mean: + - 0.5 + - 0.5 + - 0.5 + std: + - 0.5 + - 0.5 + - 0.5 + expand_result: false + bg_generator: + pretrain: models/style_text_rec/bg_generator + module_name: bg_generator + generator_type: BgGeneratorWithMask + encode_dim: 64 + norm_layer: null + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true + output_factor: 1.05 + text_generator: + pretrain: models/style_text_rec/text_generator + module_name: text_generator + generator_type: TextGenerator + encode_dim: 64 + norm_layer: InstanceNorm2D + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true + fusion_generator: + pretrain: models/style_text_rec/fusion_generator + module_name: fusion_generator + generator_type: FusionGeneratorSimple + encode_dim: 64 + norm_layer: null + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true +Writer: + method: SimpleWriter diff --git a/style_text_rec/engine/__init__.py b/style_text_rec/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/style_text_rec/engine/corpus_generators.py b/style_text_rec/engine/corpus_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..186d15f36d16971d9e7700535b50b1f724a80fe7 --- /dev/null +++ b/style_text_rec/engine/corpus_generators.py @@ -0,0 +1,66 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random + +from utils.logging import get_logger + + +class FileCorpus(object): + def __init__(self, config): + self.logger = get_logger() + self.logger.info("using FileCorpus") + + self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + + corpus_file = config["CorpusGenerator"]["corpus_file"] + self.language = config["CorpusGenerator"]["language"] + with open(corpus_file, 'r') as f: + corpus_raw = f.read() + self.corpus_list = corpus_raw.split("\n")[:-1] + assert len(self.corpus_list) > 0 + random.shuffle(self.corpus_list) + self.index = 0 + + def generate(self, corpus_length=0): + if self.index >= len(self.corpus_list): + self.index = 0 + random.shuffle(self.corpus_list) + corpus = self.corpus_list[self.index] + if corpus_length != 0: + corpus = corpus[0:corpus_length] + if corpus_length > len(corpus): + self.logger.warning("generated corpus is shorter than expected.") + self.index += 1 + return self.language, corpus + + +class EnNumCorpus(object): + def __init__(self, config): + self.logger = get_logger() + self.logger.info("using EnNumCorpus") + self.num_list = "0123456789" + self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + self.height = config["Global"]["image_height"] + self.max_width = config["Global"]["image_width"] + + def generate(self, corpus_length=0): + corpus = "" + if corpus_length == 0: + corpus_length = random.randint(5, 15) + for i in range(corpus_length): + if random.random() < 0.2: + corpus += "{}".format(random.choice(self.en_char_list)) + else: + corpus +=
"{}".format(random.choice(self.num_list)) + return "en", corpus diff --git a/style_text_rec/engine/predictors.py b/style_text_rec/engine/predictors.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f4afe4a18bd1e0a96ac37aa0359f26434ddb3d --- /dev/null +++ b/style_text_rec/engine/predictors.py @@ -0,0 +1,115 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import cv2 +import math +import paddle + +from arch import style_text_rec +from utils.sys_funcs import check_gpu +from utils.logging import get_logger + + +class StyleTextRecPredictor(object): + def __init__(self, config): + algorithm = config['Predictor']['algorithm'] + assert algorithm in ["StyleTextRec" + ], "Generator {} not supported.".format(algorithm) + use_gpu = config["Global"]['use_gpu'] + check_gpu(use_gpu) + self.logger = get_logger() + self.generator = getattr(style_text_rec, algorithm)(config) + self.height = config["Global"]["image_height"] + self.width = config["Global"]["image_width"] + self.scale = config["Predictor"]["scale"] + self.mean = config["Predictor"]["mean"] + self.std = config["Predictor"]["std"] + self.expand_result = config["Predictor"]["expand_result"] + + def predict(self, style_input, text_input): + style_input = self.rep_style_input(style_input, text_input) + tensor_style_input = self.preprocess(style_input) + tensor_text_input = self.preprocess(text_input) + 
style_text_result = self.generator.forward(tensor_style_input, + tensor_text_input) + fake_fusion = self.postprocess(style_text_result["fake_fusion"]) + fake_text = self.postprocess(style_text_result["fake_text"]) + fake_sk = self.postprocess(style_text_result["fake_sk"]) + fake_bg = self.postprocess(style_text_result["fake_bg"]) + bbox = self.get_text_boundary(fake_text) + if bbox: + left, right, top, bottom = bbox + fake_fusion = fake_fusion[top:bottom, left:right, :] + fake_text = fake_text[top:bottom, left:right, :] + fake_sk = fake_sk[top:bottom, left:right, :] + fake_bg = fake_bg[top:bottom, left:right, :] + + # fake_fusion = self.crop_by_text(img_fake_fusion, img_fake_text) + return { + "fake_fusion": fake_fusion, + "fake_text": fake_text, + "fake_sk": fake_sk, + "fake_bg": fake_bg, + } + + def preprocess(self, img): + img = (img.astype('float32') * self.scale - self.mean) / self.std + img_height, img_width, channel = img.shape + assert channel == 3, "Please use an rgb image." + ratio = img_width / float(img_height) + if math.ceil(self.height * ratio) > self.width: + resized_w = self.width + else: + resized_w = int(math.ceil(self.height * ratio)) + img = cv2.resize(img, (resized_w, self.height)) + + new_img = np.zeros([self.height, self.width, 3]).astype('float32') + new_img[:, 0:resized_w, :] = img + img = new_img.transpose((2, 0, 1)) + img = img[np.newaxis, :, :, :] + return paddle.to_tensor(img) + + def postprocess(self, tensor): + img = tensor.numpy()[0] + img = img.transpose((1, 2, 0)) + img = (img * self.std + self.mean) / self.scale + img = np.maximum(img, 0.0) + img = np.minimum(img, 255.0) + img = img.astype('uint8') + return img + + def rep_style_input(self, style_input, text_input): + rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) / + (style_input.shape[1] / style_input.shape[0])) + 1 + style_input = np.tile(style_input, reps=[1, rep_num, 1]) + max_width = int(self.width / self.height * style_input.shape[0]) + style_input = 
style_input[:, :max_width, :] + return style_input + + def get_text_boundary(self, text_img): + img_height = text_img.shape[0] + img_width = text_img.shape[1] + bounder = 3 + text_canny_img = cv2.Canny(text_img, 10, 20) + edge_num_h = text_canny_img.sum(axis=0) + no_zero_list_h = np.where(edge_num_h > 0)[0] + edge_num_w = text_canny_img.sum(axis=1) + no_zero_list_w = np.where(edge_num_w > 0)[0] + if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0: + return None + left = max(no_zero_list_h[0] - bounder, 0) + right = min(no_zero_list_h[-1] + bounder, img_width) + top = max(no_zero_list_w[0] - bounder, 0) + bottom = min(no_zero_list_w[-1] + bounder, img_height) + return [left, right, top, bottom] diff --git a/style_text_rec/engine/style_samplers.py b/style_text_rec/engine/style_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..e171d58db7527ffb37972524991e58ac59c6bb0a --- /dev/null +++ b/style_text_rec/engine/style_samplers.py @@ -0,0 +1,62 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
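The arithmetic in the predictor's `preprocess` and `postprocess` methods earlier in this diff can be checked without Paddle or OpenCV. The sketch below reproduces only the scalar math, using the `scale`, `mean` and `std` values from the configs and the default 32x320 canvas. The function names are hypothetical and it works on single numbers rather than image arrays.

```python
import math

# Values copied from the Predictor section of the configs in this diff.
SCALE = 0.00392156862745098  # roughly 1 / 255
MEAN = 0.5
STD = 0.5


def normalize_pixel(value):
    # Same formula as preprocess: maps [0, 255] to roughly [-1, 1].
    return (value * SCALE - MEAN) / STD


def denormalize_pixel(value):
    # Same formula as postprocess, including the clip to [0, 255] and the
    # truncation that a cast to uint8 performs.
    restored = (value * STD + MEAN) / SCALE
    return int(min(max(restored, 0.0), 255.0))


def resized_width(img_height, img_width, target_height=32, max_width=320):
    # Keep the aspect ratio when scaling to target_height, capped at max_width.
    ratio = img_width / float(img_height)
    return min(max_width, int(math.ceil(target_height * ratio)))
```

Under these settings a 64x200 style image is resized to a width of 100 before being padded to the 320-pixel canvas, while anything wider than ten times its height is squeezed to 320.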
+import numpy as np +import random +import cv2 + + +class DatasetSampler(object): + def __init__(self, config): + self.image_home = config["StyleSampler"]["image_home"] + label_file = config["StyleSampler"]["label_file"] + self.dataset_with_label = config["StyleSampler"]["with_label"] + self.height = config["Global"]["image_height"] + self.index = 0 + with open(label_file, "r") as f: + label_raw = f.read() + self.path_label_list = label_raw.split("\n")[:-1] + assert len(self.path_label_list) > 0 + random.shuffle(self.path_label_list) + + def sample(self): + if self.index >= len(self.path_label_list): + random.shuffle(self.path_label_list) + self.index = 0 + if self.dataset_with_label: + path_label = self.path_label_list[self.index] + rel_image_path, label = path_label.split('\t') + else: + rel_image_path = self.path_label_list[self.index] + label = None + img_path = "{}/{}".format(self.image_home, rel_image_path) + image = cv2.imread(img_path) + origin_height = image.shape[0] + ratio = self.height / origin_height + width = int(image.shape[1] * ratio) + height = int(image.shape[0] * ratio) + image = cv2.resize(image, (width, height)) + + self.index += 1 + if label: + return {"image": image, "label": label} + else: + return {"image": image} + + +def duplicate_image(image, width): + image_width = image.shape[1] + dup_num = width // image_width + 1 + image = np.tile(image, reps=[1, dup_num, 1]) + cropped_image = image[:, :width, :] + return cropped_image diff --git a/style_text_rec/engine/synthesisers.py b/style_text_rec/engine/synthesisers.py new file mode 100644 index 0000000000000000000000000000000000000000..177e3e049a695ecd06f5d2271f21336dd4eff997 --- /dev/null +++ b/style_text_rec/engine/synthesisers.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from utils.config import ArgsParser, load_config, override_config +from utils.logging import get_logger +from engine import style_samplers, corpus_generators, text_drawers, predictors, writers + + +class ImageSynthesiser(object): + def __init__(self): + self.FLAGS = ArgsParser().parse_args() + self.config = load_config(self.FLAGS.config) + self.config = override_config(self.config, options=self.FLAGS.override) + self.output_dir = self.config["Global"]["output_dir"] + if not os.path.exists(self.output_dir): + os.mkdir(self.output_dir) + self.logger = get_logger( + log_file='{}/predict.log'.format(self.output_dir)) + + self.text_drawer = text_drawers.StdTextDrawer(self.config) + + predictor_method = self.config["Predictor"]["method"] + assert predictor_method is not None + self.predictor = getattr(predictors, predictor_method)(self.config) + + def synth_image(self, corpus, style_input, language="en"): + corpus, text_input = self.text_drawer.draw_text(corpus, language) + synth_result = self.predictor.predict(style_input, text_input) + return synth_result + + +class DatasetSynthesiser(ImageSynthesiser): + def __init__(self): + super(DatasetSynthesiser, self).__init__() + self.tag = self.FLAGS.tag + self.output_num = self.config["Global"]["output_num"] + corpus_generator_method = self.config["CorpusGenerator"]["method"] + self.corpus_generator = getattr(corpus_generators, + corpus_generator_method)(self.config) + + style_sampler_method = self.config["StyleSampler"]["method"] + assert style_sampler_method is not None + self.style_sampler = 
style_samplers.DatasetSampler(self.config) + self.writer = writers.SimpleWriter(self.config, self.tag) + + def synth_dataset(self): + for i in range(self.output_num): + style_data = self.style_sampler.sample() + style_input = style_data["image"] + corpus_language, text_input_label = self.corpus_generator.generate( + ) + text_input_label, text_input = self.text_drawer.draw_text( + text_input_label, corpus_language) + + synth_result = self.predictor.predict(style_input, text_input) + fake_fusion = synth_result["fake_fusion"] + self.writer.save_image(fake_fusion, text_input_label) + self.writer.save_label() + self.writer.merge_label() diff --git a/style_text_rec/engine/text_drawers.py b/style_text_rec/engine/text_drawers.py new file mode 100644 index 0000000000000000000000000000000000000000..8aaac06ec50816bb6e2774972644c0a7dfb908c6 --- /dev/null +++ b/style_text_rec/engine/text_drawers.py @@ -0,0 +1,57 @@ +from PIL import Image, ImageDraw, ImageFont +import numpy as np +from utils.logging import get_logger + + +class StdTextDrawer(object): + def __init__(self, config): + self.logger = get_logger() + self.max_width = config["Global"]["image_width"] + self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + self.height = config["Global"]["image_height"] + self.font_dict = {} + self.load_fonts(config["TextDrawer"]["fonts"]) + self.support_languages = list(self.font_dict) + + def load_fonts(self, fonts_config): + for language in fonts_config: + font_path = fonts_config[language] + font_height = self.get_valid_height(font_path) + font = ImageFont.truetype(font_path, font_height) + self.font_dict[language] = font + + def get_valid_height(self, font_path): + font = ImageFont.truetype(font_path, self.height - 4) + _, font_height = font.getsize(self.char_list) + if font_height <= self.height - 4: + return self.height - 4 + else: + return int((self.height - 4)**2 / font_height) + + def draw_text(self, corpus, language="en", crop=True): + if language 
not in self.support_languages: + self.logger.warning( + "language {} not supported, use en instead.".format(language)) + language = "en" + if crop: + width = min(self.max_width, len(corpus) * self.height) + 4 + else: + width = len(corpus) * self.height + 4 + bg = Image.new("RGB", (width, self.height), color=(127, 127, 127)) + draw = ImageDraw.Draw(bg) + + char_x = 2 + font = self.font_dict[language] + for i, char_i in enumerate(corpus): + char_size = font.getsize(char_i)[0] + draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font) + char_x += char_size + if char_x >= width: + corpus = corpus[0:i + 1] + self.logger.warning("corpus length exceeds limit: {}".format( + corpus)) + break + + text_input = np.array(bg).astype(np.uint8) + text_input = text_input[:, 0:char_x, :] + return corpus, text_input diff --git a/style_text_rec/engine/writers.py b/style_text_rec/engine/writers.py new file mode 100644 index 0000000000000000000000000000000000000000..0df75e7234812c3fbab69ceed50040aa16cd83bc --- /dev/null +++ b/style_text_rec/engine/writers.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
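The sizing rules in `StdTextDrawer` above (canvas width in `draw_text`, font size in `get_valid_height`) are plain integer arithmetic and can be sketched without Pillow. The helper names below are hypothetical, and `measured_height` stands in for the value the drawer obtains by measuring the font at `height - 4`.

```python
def canvas_width(corpus, height=32, max_width=320, crop=True):
    # One square cell per character, optionally capped, plus 4 px of margin.
    width = len(corpus) * height
    if crop:
        width = min(max_width, width)
    return width + 4


def valid_font_size(measured_height, height=32):
    # Leave a 2 px margin above and below the glyphs.
    target = height - 4
    if measured_height <= target:
        return target
    # Shrink the font size proportionally when the rendered glyphs are too tall.
    return int(target ** 2 / measured_height)
```

With the default 32x320 settings, a nine-character string such as `PaddleOCR` gets a 292-pixel canvas, and a font that renders 35 pixels tall at size 28 is reduced to size 22.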
+import os +import cv2 +import glob + +from utils.logging import get_logger + + +class SimpleWriter(object): + def __init__(self, config, tag): + self.logger = get_logger() + self.output_dir = config["Global"]["output_dir"] + self.counter = 0 + self.label_dict = {} + self.tag = tag + self.label_file_index = 0 + + def save_image(self, image, text_input_label): + image_home = os.path.join(self.output_dir, "images", self.tag) + if not os.path.exists(image_home): + os.makedirs(image_home) + + image_path = os.path.join(image_home, "{}.png".format(self.counter)) + # todo support continue synth + cv2.imwrite(image_path, image) + self.logger.info("generate image: {}".format(image_path)) + + image_name = os.path.join(self.tag, "{}.png".format(self.counter)) + self.label_dict[image_name] = text_input_label + + self.counter += 1 + if not self.counter % 100: + self.save_label() + + def save_label(self): + label_raw = "" + label_home = os.path.join(self.output_dir, "label") + if not os.path.exists(label_home): + os.mkdir(label_home) + for image_path in self.label_dict: + label = self.label_dict[image_path] + label_raw += "{}\t{}\n".format(image_path, label) + label_file_path = os.path.join(label_home, + "{}_label.txt".format(self.tag)) + with open(label_file_path, "w") as f: + f.write(label_raw) + self.label_file_index += 1 + + def merge_label(self): + label_raw = "" + label_file_regex = os.path.join(self.output_dir, "label", + "*_label.txt") + label_file_list = glob.glob(label_file_regex) + for label_file_i in label_file_list: + with open(label_file_i, "r") as f: + label_raw += f.read() + label_file_path = os.path.join(self.output_dir, "label.txt") + with open(label_file_path, "w") as f: + f.write(label_raw) diff --git a/style_text_rec/examples/corpus/example.txt b/style_text_rec/examples/corpus/example.txt new file mode 100644 index 0000000000000000000000000000000000000000..78451cc3d92a3353f5de0c74c2cb0a06e6197653 --- /dev/null +++ b/style_text_rec/examples/corpus/example.txt 
@@ -0,0 +1,2 @@ +PaddleOCR +飞桨文字识别 diff --git a/style_text_rec/examples/image_list.txt b/style_text_rec/examples/image_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..b07be0353516f7822e4994d5dcddcd85766035dc --- /dev/null +++ b/style_text_rec/examples/image_list.txt @@ -0,0 +1,2 @@ +style_images/1.jpg NEATNESS +style_images/2.jpg 锁店君和宾馆 diff --git a/style_text_rec/examples/style_images/1.jpg b/style_text_rec/examples/style_images/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4da7838e5d3c711cdeab60df63ae4c7af7b475ae Binary files /dev/null and b/style_text_rec/examples/style_images/1.jpg differ diff --git a/style_text_rec/examples/style_images/2.jpg b/style_text_rec/examples/style_images/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f68ce49aa5558124d36ae6eaa801be5b0e79e152 Binary files /dev/null and b/style_text_rec/examples/style_images/2.jpg differ diff --git a/style_text_rec/fonts/ch_standard.ttf b/style_text_rec/fonts/ch_standard.ttf new file mode 100755 index 0000000000000000000000000000000000000000..cdb7fa5907587b8dbe0ad1da7442d3e4f8bd7488 Binary files /dev/null and b/style_text_rec/fonts/ch_standard.ttf differ diff --git a/style_text_rec/fonts/en_standard.ttf b/style_text_rec/fonts/en_standard.ttf new file mode 100755 index 0000000000000000000000000000000000000000..2e31d02424ed50b9e05c19b5d82500699a6edbb0 Binary files /dev/null and b/style_text_rec/fonts/en_standard.ttf differ diff --git a/style_text_rec/fonts/ko_standard.ttf b/style_text_rec/fonts/ko_standard.ttf new file mode 100755 index 0000000000000000000000000000000000000000..982bd879c27c731d2601ea8da988784e06f4b5b3 Binary files /dev/null and b/style_text_rec/fonts/ko_standard.ttf differ diff --git a/style_text_rec/tools/__init__.py b/style_text_rec/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/style_text_rec/tools/synth_dataset.py b/style_text_rec/tools/synth_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0e6d5e1f701c49558cfe1ea1df61e9b4180a89 --- /dev/null +++ b/style_text_rec/tools/synth_dataset.py @@ -0,0 +1,23 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from engine.synthesisers import DatasetSynthesiser + + +def synth_dataset(): + dataset_synthesiser = DatasetSynthesiser() + dataset_synthesiser.synth_dataset() + + +if __name__ == '__main__': + synth_dataset() diff --git a/style_text_rec/tools/synth_image.py b/style_text_rec/tools/synth_image.py new file mode 100644 index 0000000000000000000000000000000000000000..7b4827b825e4a28dd1fb2eba722d23e64e8ce0be --- /dev/null +++ b/style_text_rec/tools/synth_image.py @@ -0,0 +1,82 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import cv2 +import sys +import glob + +from utils.config import ArgsParser +from engine.synthesisers import ImageSynthesiser + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + + +def synth_image(): + args = ArgsParser().parse_args() + image_synthesiser = ImageSynthesiser() + style_image_path = args.style_image + img = cv2.imread(style_image_path) + text_corpus = args.text_corpus + language = args.language + + synth_result = image_synthesiser.synth_image(text_corpus, img, language) + fake_fusion = synth_result["fake_fusion"] + fake_text = synth_result["fake_text"] + fake_bg = synth_result["fake_bg"] + cv2.imwrite("fake_fusion.jpg", fake_fusion) + cv2.imwrite("fake_text.jpg", fake_text) + cv2.imwrite("fake_bg.jpg", fake_bg) + + +def batch_synth_images(): + image_synthesiser = ImageSynthesiser() + + corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt" + style_data_dir = "../StyleTextRec_data/test_20201208/style_images/" + save_path = "./output_data/" + corpus_list = [] + with open(corpus_file, "rb") as fin: + lines = fin.readlines() + for line in lines: + substr = line.decode("utf-8").strip("\n").split("\t") + corpus_list.append(substr) + style_img_list = glob.glob("{}/*.jpg".format(style_data_dir)) + corpus_num = len(corpus_list) + style_img_num = len(style_img_list) + for cno in range(corpus_num): + for sno in range(style_img_num): + corpus, lang = corpus_list[cno] + style_img_path = style_img_list[sno] + img = cv2.imread(style_img_path) + synth_result = image_synthesiser.synth_image(corpus, img, lang) + fake_fusion = synth_result["fake_fusion"] + fake_text = synth_result["fake_text"] + fake_bg = synth_result["fake_bg"] + for tp in range(2): + if tp == 0: + prefix = "%s/c%d_s%d_" % (save_path, cno, sno) + else: + prefix = "%s/s%d_c%d_" % (save_path, sno, cno) + cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion) + 
cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text) + cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg) + cv2.imwrite("%s_input_style.jpg" % prefix, img) + print(cno, corpus_num, sno, style_img_num) + + +if __name__ == '__main__': + # batch_synth_images() + synth_image() diff --git a/style_text_rec/utils/__init__.py b/style_text_rec/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/style_text_rec/utils/config.py b/style_text_rec/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f8a618a838db361da4867e00df8dcd619f9f3d --- /dev/null +++ b/style_text_rec/utils/config.py @@ -0,0 +1,224 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import yaml +import os +from argparse import ArgumentParser, RawDescriptionHelpFormatter + + +def override(dl, ks, v): + """ + Recursively replace a value in a dict or list + + Args: + dl(dict or list): dict or list to be replaced + ks(list): list of keys + v(str): value to be replaced + """ + + def str2num(v): + try: + return eval(v) + except Exception: + return v + + assert isinstance(dl, (list, dict)), ("{} should be a list or a dict".format(dl)) + assert len(ks) > 0, ('length of keys should be larger than 0') + if isinstance(dl, list): + k = str2num(ks[0]) + if len(ks) == 1: + assert k < len(dl), ('index({}) out of range({})'.format(k, dl)) + dl[k] = str2num(v) + else: + override(dl[k], ks[1:], v) + else: + if len(ks) == 1: + #assert ks[0] in dl, ('{} does not exist in {}'.format(ks[0], dl)) + if not ks[0] in dl: + print('A new field ({}) detected!'.format(ks[0])) + dl[ks[0]] = str2num(v) + else: + assert ks[0] in dl, ( + '({}) doesn\'t exist in {}, a new dict field is invalid'. + format(ks[0], dl)) + override(dl[ks[0]], ks[1:], v) + + +def override_config(config, options=None): + """ + Recursively override the config + + Args: + config(dict): dict to be replaced + options(list): list of pairs(key0.key1.idx.key2=value) + such as: [ + 'topk=2', + 'VALID.transforms.1.ResizeImage.resize_short=300' + ] + + Returns: + config(dict): replaced config + """ + if options is not None: + for opt in options: + assert isinstance(opt, str), ( + "option({}) should be a str".format(opt)) + assert "=" in opt, ( + "option({}) should contain a = " + "to distinguish between key and value".format(opt)) + pair = opt.split('=') + assert len(pair) == 2, ("there can be only one = in the option") + key, value = pair + keys = key.split('.') + override(config, keys, value) + + return config + + +class ArgsParser(ArgumentParser): + def __init__(self): + super(ArgsParser, self).__init__( + formatter_class=RawDescriptionHelpFormatter) + self.add_argument("-c", "--config", help="configuration file to use") +
self.add_argument( + "-t", "--tag", default="0", help="tag for marking worker") + self.add_argument( + '-o', + '--override', + action='append', + default=[], + help='config options to be overridden') + self.add_argument( + "--style_image", default="examples/style_images/1.jpg", help="path of the style reference image") + self.add_argument( + "--text_corpus", default="PaddleOCR", help="text to render in the synthesized image") + self.add_argument( + "--language", default="en", help="language of the text corpus") + + def parse_args(self, argv=None): + args = super(ArgsParser, self).parse_args(argv) + assert args.config is not None, \ + "Please specify --config=configure_file_path." + return args + + +def load_config(file_path): + """ + Load config from yml/yaml file. + Args: + file_path (str): Path of the config file to be loaded. + Returns: config + """ + ext = os.path.splitext(file_path)[1] + assert ext in ['.yml', '.yaml'], "only support yaml files for now" + with open(file_path, 'rb') as f: + config = yaml.load(f, Loader=yaml.Loader) + + return config + + +def gen_config(): + base_config = { + "Global": { + "algorithm": "SRNet", + "use_gpu": True, + "start_epoch": 1, + "stage1_epoch_num": 100, + "stage2_epoch_num": 100, + "log_smooth_window": 20, + "print_batch_step": 2, + "save_model_dir": "./output/SRNet", + "use_visualdl": False, + "save_epoch_step": 10, + "vgg_pretrain": "./pretrained/VGG19_pretrained", + "vgg_load_static_pretrain": True + }, + "Architecture": { + "model_type": "data_aug", + "algorithm": "SRNet", + "net_g": { + "name": "srnet_net_g", + "encode_dim": 64, + "norm": "batch", + "use_dropout": False, + "init_type": "xavier", + "init_gain": 0.02, + "use_dilation": 1 + }, + # input_nc, ndf, netD, + # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0' + "bg_discriminator": { + "name": "srnet_bg_discriminator", + "input_nc": 6, + "ndf": 64, + "netD": "basic", + "norm": "none", + "init_type": "xavier", + }, + "fusion_discriminator": {
+ "name": "srnet_fusion_discriminator", + "input_nc": 6, + "ndf": 64, + "netD": "basic", + "norm": "none", + "init_type": "xavier", + } + }, + "Loss": { + "lamb": 10, + "perceptual_lamb": 1, + "muvar_lamb": 50, + "style_lamb": 500 + }, + "Optimizer": { + "name": "Adam", + "learning_rate": { + "name": "lambda", + "lr": 0.0002, + "lr_decay_iters": 50 + }, + "beta1": 0.5, + "beta2": 0.999, + }, + "Train": { + "batch_size_per_card": 8, + "num_workers_per_card": 4, + "dataset": { + "delimiter": "\t", + "data_dir": "/", + "label_file": "tmp/label.txt", + "transforms": [{ + "DecodeImage": { + "to_rgb": True, + "to_np": False, + "channel_first": False + } + }, { + "NormalizeImage": { + "scale": 1. / 255., + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "order": None + } + }, { + "ToCHWImage": None + }] + } + } + } + with open("config.yml", "w") as f: + yaml.dump(base_config, f) + + +if __name__ == '__main__': + gen_config() diff --git a/style_text_rec/utils/load_params.py b/style_text_rec/utils/load_params.py new file mode 100644 index 0000000000000000000000000000000000000000..be0561363eb21483d267ff6557c1d453d330c5f8 --- /dev/null +++ b/style_text_rec/utils/load_params.py @@ -0,0 +1,27 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import paddle + +__all__ = ['load_dygraph_pretrain'] + + +def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): + if not os.path.exists(path + '.pdparams'): + raise ValueError("Model pretrain path {} does not " + "exist.".format(path)) + param_state_dict = paddle.load(path + '.pdparams') + model.set_state_dict(param_state_dict) + logger.info("load pretrained model from {}".format(path)) + return diff --git a/style_text_rec/utils/logging.py b/style_text_rec/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..f700fe26bc9bfda21d39a0bddd89180f5de442ab --- /dev/null +++ b/style_text_rec/utils/logging.py @@ -0,0 +1,65 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import logging +import functools +import paddle.distributed as dist + +logger_initialized = {} + + +@functools.lru_cache() +def get_logger(name='srnet', log_file=None, log_level=logging.INFO): + """Initialize and get a logger by name. + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified a FileHandler will also be added. + Args: + name (str): Logger name. + log_file (str | None): The log filename.
If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + formatter = logging.Formatter( + '[%(asctime)s] %(name)s %(levelname)s: %(message)s', + datefmt="%Y/%m/%d %H:%M:%S") + + stream_handler = logging.StreamHandler(stream=sys.stdout) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + if log_file is not None and dist.get_rank() == 0: + log_file_folder = os.path.split(log_file)[0] + os.makedirs(log_file_folder, exist_ok=True) + file_handler = logging.FileHandler(log_file, 'a') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + if dist.get_rank() == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + logger_initialized[name] = True + return logger diff --git a/style_text_rec/utils/math_functions.py b/style_text_rec/utils/math_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc8d9160f8941f825d7aade79afc99035577bca --- /dev/null +++ b/style_text_rec/utils/math_functions.py @@ -0,0 +1,45 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import paddle + + +def compute_mean_covariance(img): + batch_size = img.shape[0] + channel_num = img.shape[1] + height = img.shape[2] + width = img.shape[3] + num_pixels = height * width + + # batch_size * channel_num * 1 * 1 + mu = img.mean(2, keepdim=True).mean(3, keepdim=True) + + # batch_size * channel_num * num_pixels + img_hat = img - mu.expand_as(img) + img_hat = img_hat.reshape([batch_size, channel_num, num_pixels]) + # batch_size * num_pixels * channel_num + img_hat_transpose = img_hat.transpose([0, 2, 1]) + # batch_size * channel_num * channel_num + covariance = paddle.bmm(img_hat, img_hat_transpose) + covariance = covariance / num_pixels + + return mu, covariance + + +def dice_coefficient(y_true_cls, y_pred_cls, training_mask): + eps = 1e-5 + intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask) + union = paddle.sum(y_true_cls * training_mask) + paddle.sum( + y_pred_cls * training_mask) + eps + loss = 1. - (2 * intersection / union) + return loss diff --git a/style_text_rec/utils/sys_funcs.py b/style_text_rec/utils/sys_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..203d91d83630e41fbe931a055e81e65cf0fb2e7d --- /dev/null +++ b/style_text_rec/utils/sys_funcs.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +import os +import errno +import paddle + + +def get_check_global_params(mode): + check_params = [ + 'use_gpu', 'max_text_length', 'image_shape', + 'character_type', 'loss_type' + ] + if mode == "train_eval": + check_params = check_params + [ + 'train_batch_size_per_card', 'test_batch_size_per_card' + ] + elif mode == "test": + check_params = check_params + ['test_batch_size_per_card'] + return check_params + + +def check_gpu(use_gpu): + """ + Log an error and exit when use_gpu=true is set with the paddlepaddle + cpu version. + """ + err = "Config use_gpu cannot be set as true while you are " \ + "using paddlepaddle cpu version ! \nPlease try: \n" \ + "\t1. Install paddlepaddle-gpu to run model on GPU \n" \ + "\t2. Set use_gpu as false in config file to run " \ + "model on CPU" + if use_gpu: + try: + if not paddle.is_compiled_with_cuda(): + print(err) + sys.exit(1) + except Exception: + print("Fail to check gpu state.") + sys.exit(1) + + +def _mkdir_if_not_exist(path, logger): + """ + mkdir if not exists, ignore the exception when multiple processes mkdir together + """ + if not os.path.exists(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno == errno.EEXIST and os.path.isdir(path): + logger.warning( + 'be happy if some process has already created {}'.format( + path)) + else: + raise OSError('Failed to mkdir {}'.format(path))
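
Review note on `utils/config.py`: the `-o key0.key1.idx=value` override in `override_config` parses values with `eval`, which executes arbitrary expressions passed on the command line. The sketch below is not part of this patch. It shows one possible stricter variant of the same dotted-key override, assuming that only Python literals (numbers, booleans, lists, quoted strings) need to be parsed. The names `parse_value`, `override` and `override_config` are illustrative, and unlike the patch this version splits on the first `=` only, so values may themselves contain `=`. Expressions such as `1./255` stay strings here, which differs from the `eval` behaviour.

```python
import ast


def parse_value(raw):
    """Parse a command-line string as a Python literal, else keep the string."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw


def override(node, keys, raw):
    """Walk `node` (dict or list) along `keys` and set the final entry."""
    # List levels are indexed by integer, dict levels by the key string itself.
    key = parse_value(keys[0]) if isinstance(node, list) else keys[0]
    if len(keys) == 1:
        node[key] = parse_value(raw)
    else:
        override(node[key], keys[1:], raw)


def override_config(config, options=None):
    """Apply options of the form 'key0.key1.idx.key2=value' to `config`."""
    for opt in options or []:
        key, sep, raw = opt.partition("=")
        if not sep:
            raise ValueError("option {!r} must look like key=value".format(opt))
        override(config, key.split("."), raw)
    return config


if __name__ == "__main__":
    cfg = {"Global": {"use_gpu": True}, "transforms": [{"Resize": {"size": 32}}]}
    override_config(cfg, ["Global.use_gpu=False", "transforms.0.Resize.size=64"])
    print(cfg)
```

A missing intermediate key still raises `KeyError`, which matches the patch's intent that only leaf fields may be newly created.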