diff --git a/StyleTextRec/README.md b/StyleTextRec/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..36db8d44527c660bb0fcf985555896ae272babff
--- /dev/null
+++ b/StyleTextRec/README.md
@@ -0,0 +1,106 @@
+## Style Text Rec
+
+### Quick Start
+
+`Style-Text` is an improvement on the SRNet network proposed in Baidu's self-developed text editing algorithm ["Editing Text in the Wild"](https://arxiv.org/abs/1908.03047). Unlike commonly used GAN-based methods, this tool decomposes the text synthesis task into three sub-modules to improve the quality of the synthetic data: a text style transfer module, a background extraction module, and a fusion module.
+
+The following figure shows some example results. In addition, the effectiveness of the synthesis tool has been verified in two real scenarios, `nameplate text recognition` and `Korean text recognition`, as described below.
+
+
+#### Preparation
+
+1. Please refer to the [QUICK INSTALLATION](../doc/doc_en/installation_en.md) to install PaddlePaddle. A Python 3 environment is strongly recommended.
+2. Download the pretrained models and unzip them:
+
+```bash
+cd StyleTextRec
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
+unzip style_text_models.zip
+```
+
+You can download the models [here](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip). If you save the model files in other folders, please edit all three model paths in `configs/config.yml`:
+
+```
+bg_generator:
+  pretrain: style_text_models/bg_generator
+...
+text_generator:
+ pretrain: style_text_models/text_generator
+...
+fusion_generator:
+ pretrain: style_text_models/fusion_generator
+```
+
+
+
+#### Demo
+
+1. You can use the following command to run a demo:
+
+```bash
+python -m tools.synth_image -c configs/config.yml
+```
+
+2. The results, `fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`, are shown in the figure above. Specifically:
+   * `fake_text.jpg` is the generated image with the same font style as the `Style Input`;
+   * `fake_bg.jpg` is the generated background of the `Style Input` with the foreground text removed;
+   * `fake_fusion.jpg` is the final result, synthesized from `fake_text.jpg` and `fake_bg.jpg`.
+
+3. If you want to generate images with another `Style Input` or `Text Input`, you can modify `tools/synth_image.py`:
+   * `img = cv2.imread("examples/style_images/1.jpg")`: the path of the `Style Input`;
+   * `corpus = "PaddleOCR"`: the `Text Input`;
+   * Note: modify the language option (`language = "en"`) to match the `Text Input`; `en`, `ch` and `ko` are supported.
+
+4. We also provide a `batch_synth_images` method, which combines corpora and pictures in pairs to generate a batch of data; see the sketch below.
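+
+The following is a minimal sketch of both usages. It is not the exact content of `tools/synth_image.py`: the variable names come from the items above, the `ImageSynthesiser` API is taken from `engine/synthesisers.py`, and the hand-written pairing loop stands in for `batch_synth_images`, whose exact signature may differ.
+
+```python
+import cv2
+from engine.synthesisers import ImageSynthesiser
+
+# Run from the StyleTextRec directory, e.g.:
+#   python3 -m tools.synth_image -c configs/config.yml
+# ImageSynthesiser reads the -c config path from the command line.
+synthesiser = ImageSynthesiser()
+
+# Single image: one Style Input paired with one Text Input.
+img = cv2.imread("examples/style_images/1.jpg")  # Style Input
+corpus = "PaddleOCR"                             # Text Input
+language = "en"                                  # must match the corpus: en, ch or ko
+result = synthesiser.synth_image(corpus, img, language)
+cv2.imwrite("fake_fusion.jpg", result["fake_fusion"])
+
+# Batch: combine corpora and style images in pairs.
+corpora = ["PaddleOCR", "Style-Text"]
+style_paths = ["examples/style_images/1.jpg", "examples/style_images/2.jpg"]
+for i, text in enumerate(corpora):
+    for j, path in enumerate(style_paths):
+        out = synthesiser.synth_image(text, cv2.imread(path), "en")
+        cv2.imwrite("fake_fusion_{}_{}.jpg".format(i, j), out["fake_fusion"])
+```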
+
+### Advanced Usage
+
+#### Components
+
+`Style Text Rec` mainly contains the following components:
+
+* `style_samplers`: samples a `Style Input` from a dataset. Currently, only `DatasetSampler` is provided.
+
+* `corpus_generators`: generates corpora. Currently, two `corpus_generators` are provided:
+    * `EnNumCorpus`: generates a random string of a given length, containing uppercase and lowercase English letters, digits and spaces.
+    * `FileCorpus`: reads a text file and randomly returns words from it.
+
+* `text_drawers`: generates the `Text Input` (a text image rendered in a standard font from the input corpus). Note that the language setting must match the corpus.
+
+* `predictors`: calls the deep learning model to generate new data based on the `Style Input` and `Text Input`.
+
+* `writers`: writes the generated pictures (`fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`) and label information to disk.
+
+* `synthesisers`: calls all of the above modules to complete the synthesis; the flow is sketched below.
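+
+The following is a condensed view of one synthesis step, adapted from `DatasetSynthesiser.synth_dataset` in `engine/synthesisers.py` (object construction and error handling omitted):
+
+```python
+# Sample a style image, generate a corpus line, draw it in a standard
+# font, run the model, and write the fused result to disk.
+style_data = style_sampler.sample()                          # style_samplers
+language, label = corpus_generator.generate()                # corpus_generators
+label, text_input = text_drawer.draw_text(label, language)   # text_drawers
+result = predictor.predict(style_data["image"], text_input)  # predictors
+writer.save_image(result["fake_fusion"], label)              # writers
+```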
+
+### Generate Dataset
+
+Before starting, you need to prepare some data as material.
+First, you need style reference images for the synthesis task; the datasets used for OCR recognition tasks are generally suitable.
+
+1. The reference dataset can be specified in `configs/dataset_config.yml`:
+   * `StyleSampler`:
+     * `method`: the method of `StyleSampler`.
+     * `image_home`: the directory of the style images.
+     * `label_file`: the list of image paths if `with_label` is `false`; otherwise, the path of the label file.
+     * `with_label`: whether `label_file` is a label file.
+
+   * `CorpusGenerator`:
+     * `method`: the method of `CorpusGenerator`. If `FileCorpus` is used, `corpus_file` and `language` must be set accordingly; if `EnNumCorpus` is used, no other configuration is needed.
+     * `language`: the language of the corpus. Needed if the method is not `EnNumCorpus`.
+     * `corpus_file`: the corpus file path. Needed if the method is not `EnNumCorpus`.
+
+2. You can run the following command to start the synthesis task:
+
+   ```bash
+   python -m tools.synth_dataset -c configs/dataset_config.yml
+   ```
+
+3. You can use the following commands to run multiple synthesis tasks in parallel; each task needs its own tag, specified with `-t`:
+
+   ```bash
+   python -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
+   python -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
+   ```
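+
+For reference, the corresponding section of the `configs/dataset_config.yml` shipped with the tool points at the bundled `examples` data:
+
+```
+StyleSampler:
+  method: DatasetSampler
+  image_home: examples
+  label_file: examples/image_list.txt
+  with_label: true
+CorpusGenerator:
+  method: FileCorpus
+  language: ch
+  corpus_file: examples/corpus/example.txt
+```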
+
+### OCR Recognition Training
+
+After completing the above operations, you can get a synthetic dataset for OCR recognition. Next, please complete the training by referring to the [OCR Recognition Document](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83).
\ No newline at end of file
diff --git a/StyleTextRec/README_ch.md b/StyleTextRec/README_ch.md
new file mode 100644
index 0000000000000000000000000000000000000000..b85615309f9a27187cd5378cc4b61fe2e4d6c50e
--- /dev/null
+++ b/StyleTextRec/README_ch.md
@@ -0,0 +1,164 @@
+## Style Text Rec
+
+### Contents
+- [Tool Introduction](#tool-introduction)
+- [Environment Setup](#environment-setup)
+- [Quick Start](#quick-start)
+- [Advanced Usage](#advanced-usage)
+- [Application Example](#application-example)
+
+### Tool Introduction
+
+
+
+
+
+
+
+
+The Style-Text data synthesis tool is based on Baidu's self-developed text editing algorithm ["Editing Text in the Wild"](https://arxiv.org/abs/1908.03047).
+Unlike common GAN-based data synthesis tools, the main framework of Style-Text consists of three modules: 1. a text foreground style transfer module, 2. a background extraction module, and 3. a fusion module. With these three steps, the text style of an image can be transferred quickly. The figure below shows some results of this data synthesis tool.
+
+
+
+
+
+### Environment Setup
+
+1. Refer to [Quick Installation](../doc/doc_ch/installation.md) to install PaddleOCR.
+2. Enter the `StyleTextRec` directory, then download and unzip the models:
+
+```bash
+cd StyleTextRec
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
+unzip style_text_models.zip
+```
+
+If you save the models in another location, please modify the model paths in `configs/config.yml`; all three of the following entries need to be updated:
+
+```
+bg_generator:
+ pretrain: style_text_models/bg_generator
+...
+text_generator:
+ pretrain: style_text_models/text_generator
+...
+fusion_generator:
+ pretrain: style_text_models/fusion_generator
+```
+
+### Quick Start
+
+1. Run `tools/synth_image` to generate an example image:
+
+```bash
+python3 -m tools.synth_image -c configs/config.yml
+```
+
+2. After running, `fake_fusion.jpg` is generated as the final result.
+
+
+
+In addition, the program also generates and saves intermediate results:
+ * `fake_bg.jpg`: the background of the style reference image with the text removed;
+ * `fake_text.jpg`: the given text, rendered on a gray background in the text style of the style reference image.
+
+3. To try other style images and texts, add the `style_image`, `text_corpus` and `language` arguments:
+```bash
+python3 -m tools.synth_image -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
+```
+ * Note: the language option must match the corpus; English, Simplified Chinese and Korean are currently supported.
+
+4. In `tools/synth_image.py`, we also provide a `batch_synth_images` method, which pairs corpora and images to generate a batch of data.
+
+### Advanced Usage
+
+Before synthesizing a dataset, some materials need to be prepared.
+
+First, style images are needed as references for the synthesis. They can come from the datasets used to train OCR recognition models; this example uses a dataset with label files as the style images.
+
+1. Configure the input data paths in `configs/dataset_config.yml`.
+    * `StyleSampler`:
+      * `method`: the sampling method for style images;
+      * `image_home`: the directory of the style images;
+      * `label_file`: a file listing the style image paths; if the dataset is labeled, `label_file` is the path of the label file (see the format example after step 3 below);
+      * `with_label`: whether `label_file` is a label file.
+    * `CorpusGenerator`:
+      * `method`: the corpus generation method; `FileCorpus` and `EnNumCorpus` are currently available. With `EnNumCorpus`, no other settings are needed; otherwise, `corpus_file` and `language` must be set;
+      * `language`: the language of the corpus;
+      * `corpus_file`: the corpus file path.
+
+    We provide a batch of [sample images](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar) for you to try. Some examples are shown below:
+
+
+
+2. Run `tools/synth_dataset` to synthesize data:
+
+   ```bash
+   python -m tools.synth_dataset -c configs/dataset_config.yml
+   ```
+
+3. If you want to synthesize data faster in parallel, you can start multiple processes, each with a different tag (`-t`), as shown below:
+
+ ```bash
+ python3 -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
+ python3 -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
+ ```
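+
+For reference, with `with_label: true` each line of the `label_file` holds an image path and its label separated by a tab, and the corpus file holds one text entry per line. The contents below are illustrative; the formats follow `engine/style_samplers.py` and `engine/corpus_generators.py`:
+
+```
+# examples/image_list.txt  (relative image path <TAB> label)
+style_images/1.jpg	PaddleOCR
+style_images/2.jpg	文字合成
+
+# examples/corpus/example.txt  (one corpus entry per line)
+PaddleOCR
+文字风格迁移
+```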
+
+
+### Application Example
+
+After completing the above steps, you will have a synthetic dataset for OCR recognition. Some generated examples are shown below:
+
+
+
+Please refer to the [OCR recognition documentation](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83) to complete the training.
+
+Some results from training with the synthetic data are shown below:
+
+| Scenario | Characters | Raw Data | Test Data | Accuracy (raw data only) | Added Synthetic Data | Accuracy (with synthetic data) | Improvement |
+| ----------------- | ------------------ | -------- | --------- | ------------------------ | -------------------- | ------------------------------ | ----------- |
+| Metal surface     | English and digits | 2203     | 650       | 0.5938                   | 20000                | 0.7546                         | 16%         |
+| Random background | Korean             | 5631     | 1230      | 0.3012                   | 100000               | 0.5057                         | 20%         |
+
+### Project Structure
+```
+StyleTextRec
+|-- arch
+| |-- base_module.py
+| |-- decoder.py
+| |-- encoder.py
+| |-- spectral_norm.py
+| `-- style_text_rec.py
+|-- configs
+| |-- config.yml
+| `-- dataset_config.yml
+|-- engine
+| |-- corpus_generators.py
+| |-- predictors.py
+| |-- style_samplers.py
+| |-- synthesisers.py
+| |-- text_drawers.py
+| `-- writers.py
+|-- examples
+| |-- corpus
+| | `-- example.txt
+| |-- image_list.txt
+| `-- style_images
+| |-- 1.jpg
+| `-- 2.jpg
+|-- fonts
+| |-- ch_standard.ttf
+| |-- en_standard.ttf
+| `-- ko_standard.ttf
+|-- tools
+| |-- __init__.py
+| |-- synth_dataset.py
+| `-- synth_image.py
+`-- utils
+ |-- config.py
+ |-- load_params.py
+ |-- logging.py
+ |-- math_functions.py
+ `-- sys_funcs.py
+```
diff --git a/StyleTextRec/__init__.py b/StyleTextRec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StyleTextRec/arch/__init__.py b/StyleTextRec/arch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StyleTextRec/arch/base_module.py b/StyleTextRec/arch/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..da2b6b834c6a86b1c3efeb5cef4cb9d02e44e405
--- /dev/null
+++ b/StyleTextRec/arch/base_module.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.spectral_norm import spectral_norm
+
+
+class CBN(nn.Layer):
+ def __init__(self,
+ name,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ norm_layer=None,
+ act=None,
+ act_attr=None):
+ super(CBN, self).__init__()
+ if use_bias:
+ bias_attr = paddle.ParamAttr(name=name + "_bias")
+ else:
+ bias_attr = None
+ self._conv = paddle.nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ weight_attr=paddle.ParamAttr(name=name + "_weights"),
+ bias_attr=bias_attr)
+ if norm_layer:
+ self._norm_layer = getattr(paddle.nn, norm_layer)(
+ num_features=out_channels, name=name + "_bn")
+ else:
+ self._norm_layer = None
+ if act:
+ if act_attr:
+ self._act = getattr(paddle.nn, act)(**act_attr,
+ name=name + "_" + act)
+ else:
+ self._act = getattr(paddle.nn, act)(name=name + "_" + act)
+ else:
+ self._act = None
+
+ def forward(self, x):
+ out = self._conv(x)
+ if self._norm_layer:
+ out = self._norm_layer(out)
+ if self._act:
+ out = self._act(out)
+ return out
+
+
+class SNConv(nn.Layer):
+ def __init__(self,
+ name,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ norm_layer=None,
+ act=None,
+ act_attr=None):
+ super(SNConv, self).__init__()
+ if use_bias:
+ bias_attr = paddle.ParamAttr(name=name + "_bias")
+ else:
+ bias_attr = None
+ self._sn_conv = spectral_norm(
+ paddle.nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ weight_attr=paddle.ParamAttr(name=name + "_weights"),
+ bias_attr=bias_attr))
+ if norm_layer:
+ self._norm_layer = getattr(paddle.nn, norm_layer)(
+ num_features=out_channels, name=name + "_bn")
+ else:
+ self._norm_layer = None
+ if act:
+ if act_attr:
+ self._act = getattr(paddle.nn, act)(**act_attr,
+ name=name + "_" + act)
+ else:
+ self._act = getattr(paddle.nn, act)(name=name + "_" + act)
+ else:
+ self._act = None
+
+ def forward(self, x):
+ out = self._sn_conv(x)
+ if self._norm_layer:
+ out = self._norm_layer(out)
+ if self._act:
+ out = self._act(out)
+ return out
+
+
+class SNConvTranspose(nn.Layer):
+ def __init__(self,
+ name,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ norm_layer=None,
+ act=None,
+ act_attr=None):
+ super(SNConvTranspose, self).__init__()
+ if use_bias:
+ bias_attr = paddle.ParamAttr(name=name + "_bias")
+ else:
+ bias_attr = None
+ self._sn_conv_transpose = spectral_norm(
+ paddle.nn.Conv2DTranspose(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ dilation=dilation,
+ groups=groups,
+ weight_attr=paddle.ParamAttr(name=name + "_weights"),
+ bias_attr=bias_attr))
+ if norm_layer:
+ self._norm_layer = getattr(paddle.nn, norm_layer)(
+ num_features=out_channels, name=name + "_bn")
+ else:
+ self._norm_layer = None
+ if act:
+ if act_attr:
+ self._act = getattr(paddle.nn, act)(**act_attr,
+ name=name + "_" + act)
+ else:
+ self._act = getattr(paddle.nn, act)(name=name + "_" + act)
+ else:
+ self._act = None
+
+ def forward(self, x):
+ out = self._sn_conv_transpose(x)
+ if self._norm_layer:
+ out = self._norm_layer(out)
+ if self._act:
+ out = self._act(out)
+ return out
+
+
+class MiddleNet(nn.Layer):
+ def __init__(self, name, in_channels, mid_channels, out_channels,
+ use_bias):
+ super(MiddleNet, self).__init__()
+ self._sn_conv1 = SNConv(
+ name=name + "_sn_conv1",
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ use_bias=use_bias,
+ norm_layer=None,
+ act=None)
+ self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
+ self._sn_conv2 = SNConv(
+ name=name + "_sn_conv2",
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=3,
+ use_bias=use_bias)
+ self._sn_conv3 = SNConv(
+ name=name + "_sn_conv3",
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_bias=use_bias)
+
+ def forward(self, x):
+
+ sn_conv1 = self._sn_conv1.forward(x)
+ pad_2d = self._pad2d.forward(sn_conv1)
+ sn_conv2 = self._sn_conv2.forward(pad_2d)
+ sn_conv3 = self._sn_conv3.forward(sn_conv2)
+ return sn_conv3
+
+
+class ResBlock(nn.Layer):
+ def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
+ use_bias):
+ super(ResBlock, self).__init__()
+ if use_dilation:
+ padding_mat = [1, 1, 1, 1]
+ else:
+ padding_mat = [0, 0, 0, 0]
+ self._pad1 = nn.Pad2D(padding_mat, mode="replicate")
+
+ self._sn_conv1 = SNConv(
+ name=name + "_sn_conv1",
+ in_channels=channels,
+ out_channels=channels,
+ kernel_size=3,
+ padding=0,
+ norm_layer=norm_layer,
+ use_bias=use_bias,
+ act="ReLU",
+ act_attr=None)
+ if use_dropout:
+ self._dropout = nn.Dropout(0.5)
+ else:
+ self._dropout = None
+ self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
+ self._sn_conv2 = SNConv(
+ name=name + "_sn_conv2",
+ in_channels=channels,
+ out_channels=channels,
+ kernel_size=3,
+ norm_layer=norm_layer,
+ use_bias=use_bias,
+ act="ReLU",
+ act_attr=None)
+
+ def forward(self, x):
+ pad1 = self._pad1.forward(x)
+ sn_conv1 = self._sn_conv1.forward(pad1)
+ pad2 = self._pad2.forward(sn_conv1)
+ sn_conv2 = self._sn_conv2.forward(pad2)
+ return sn_conv2 + x
diff --git a/StyleTextRec/arch/decoder.py b/StyleTextRec/arch/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f07c5998a8f6b400997eacae0b44860312f432
--- /dev/null
+++ b/StyleTextRec/arch/decoder.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.base_module import SNConv, SNConvTranspose, ResBlock
+
+
+class Decoder(nn.Layer):
+ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
+ act, act_attr, conv_block_dropout, conv_block_num,
+ conv_block_dilation, out_conv_act, out_conv_act_attr):
+ super(Decoder, self).__init__()
+ conv_blocks = []
+ for i in range(conv_block_num):
+ conv_blocks.append(
+ ResBlock(
+ name="{}_conv_block_{}".format(name, i),
+ channels=encode_dim * 8,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias))
+ self.conv_blocks = nn.Sequential(*conv_blocks)
+ self._up1 = SNConvTranspose(
+ name=name + "_up1",
+ in_channels=encode_dim * 8,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up2 = SNConvTranspose(
+ name=name + "_up2",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up3 = SNConvTranspose(
+ name=name + "_up3",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
+ self._out_conv = SNConv(
+ name=name + "_out_conv",
+ in_channels=encode_dim,
+ out_channels=out_channels,
+ kernel_size=3,
+ use_bias=use_bias,
+ norm_layer=None,
+ act=out_conv_act,
+ act_attr=out_conv_act_attr)
+
+ def forward(self, x):
+ if isinstance(x, (list, tuple)):
+ x = paddle.concat(x, axis=1)
+ output_dict = dict()
+ output_dict["conv_blocks"] = self.conv_blocks.forward(x)
+ output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
+ output_dict["up2"] = self._up2.forward(output_dict["up1"])
+ output_dict["up3"] = self._up3.forward(output_dict["up2"])
+ output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
+ output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
+ return output_dict
+
+
+class DecoderUnet(nn.Layer):
+ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
+ act, act_attr, conv_block_dropout, conv_block_num,
+ conv_block_dilation, out_conv_act, out_conv_act_attr):
+ super(DecoderUnet, self).__init__()
+ conv_blocks = []
+ for i in range(conv_block_num):
+ conv_blocks.append(
+ ResBlock(
+ name="{}_conv_block_{}".format(name, i),
+ channels=encode_dim * 8,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias))
+ self._conv_blocks = nn.Sequential(*conv_blocks)
+ self._up1 = SNConvTranspose(
+ name=name + "_up1",
+ in_channels=encode_dim * 8,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up2 = SNConvTranspose(
+ name=name + "_up2",
+ in_channels=encode_dim * 8,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up3 = SNConvTranspose(
+ name=name + "_up3",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
+ self._out_conv = SNConv(
+ name=name + "_out_conv",
+ in_channels=encode_dim,
+ out_channels=out_channels,
+ kernel_size=3,
+ use_bias=use_bias,
+ norm_layer=None,
+ act=out_conv_act,
+ act_attr=out_conv_act_attr)
+
+ def forward(self, x, y, feature2, feature1):
+ output_dict = dict()
+ output_dict["conv_blocks"] = self._conv_blocks(
+ paddle.concat(
+ (x, y), axis=1))
+ output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
+ output_dict["up2"] = self._up2.forward(
+ paddle.concat(
+ (output_dict["up1"], feature2), axis=1))
+ output_dict["up3"] = self._up3.forward(
+ paddle.concat(
+ (output_dict["up2"], feature1), axis=1))
+ output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
+ output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
+ return output_dict
+
+
+class SingleDecoder(nn.Layer):
+ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
+ act, act_attr, conv_block_dropout, conv_block_num,
+ conv_block_dilation, out_conv_act, out_conv_act_attr):
+ super(SingleDecoder, self).__init__()
+ conv_blocks = []
+ for i in range(conv_block_num):
+ conv_blocks.append(
+ ResBlock(
+ name="{}_conv_block_{}".format(name, i),
+ channels=encode_dim * 4,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias))
+ self._conv_blocks = nn.Sequential(*conv_blocks)
+ self._up1 = SNConvTranspose(
+ name=name + "_up1",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up2 = SNConvTranspose(
+ name=name + "_up2",
+ in_channels=encode_dim * 8,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up3 = SNConvTranspose(
+ name=name + "_up3",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
+ self._out_conv = SNConv(
+ name=name + "_out_conv",
+ in_channels=encode_dim,
+ out_channels=out_channels,
+ kernel_size=3,
+ use_bias=use_bias,
+ norm_layer=None,
+ act=out_conv_act,
+ act_attr=out_conv_act_attr)
+
+ def forward(self, x, feature2, feature1):
+ output_dict = dict()
+ output_dict["conv_blocks"] = self._conv_blocks.forward(x)
+ output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
+ output_dict["up2"] = self._up2.forward(
+ paddle.concat(
+ (output_dict["up1"], feature2), axis=1))
+ output_dict["up3"] = self._up3.forward(
+ paddle.concat(
+ (output_dict["up2"], feature1), axis=1))
+ output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
+ output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
+ return output_dict
diff --git a/StyleTextRec/arch/encoder.py b/StyleTextRec/arch/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b884cda2934477082a1ed98c94e33b736d1f96b4
--- /dev/null
+++ b/StyleTextRec/arch/encoder.py
@@ -0,0 +1,186 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.base_module import SNConv, SNConvTranspose, ResBlock
+
+
+class Encoder(nn.Layer):
+ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
+ act, act_attr, conv_block_dropout, conv_block_num,
+ conv_block_dilation):
+ super(Encoder, self).__init__()
+ self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
+ self._in_conv = SNConv(
+ name=name + "_in_conv",
+ in_channels=in_channels,
+ out_channels=encode_dim,
+ kernel_size=7,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down1 = SNConv(
+ name=name + "_down1",
+ in_channels=encode_dim,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down2 = SNConv(
+ name=name + "_down2",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down3 = SNConv(
+ name=name + "_down3",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ conv_blocks = []
+ for i in range(conv_block_num):
+ conv_blocks.append(
+ ResBlock(
+ name="{}_conv_block_{}".format(name, i),
+ channels=encode_dim * 4,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias))
+ self._conv_blocks = nn.Sequential(*conv_blocks)
+
+ def forward(self, x):
+ out_dict = dict()
+ x = self._pad2d(x)
+ out_dict["in_conv"] = self._in_conv.forward(x)
+ out_dict["down1"] = self._down1.forward(out_dict["in_conv"])
+ out_dict["down2"] = self._down2.forward(out_dict["down1"])
+ out_dict["down3"] = self._down3.forward(out_dict["down2"])
+ out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"])
+ return out_dict
+
+
+class EncoderUnet(nn.Layer):
+ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
+ act, act_attr):
+ super(EncoderUnet, self).__init__()
+ self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
+ self._in_conv = SNConv(
+ name=name + "_in_conv",
+ in_channels=in_channels,
+ out_channels=encode_dim,
+ kernel_size=7,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down1 = SNConv(
+ name=name + "_down1",
+ in_channels=encode_dim,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down2 = SNConv(
+ name=name + "_down2",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down3 = SNConv(
+ name=name + "_down3",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down4 = SNConv(
+ name=name + "_down4",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up1 = SNConvTranspose(
+ name=name + "_up1",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up2 = SNConvTranspose(
+ name=name + "_up2",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+
+ def forward(self, x):
+ output_dict = dict()
+ x = self._pad2d(x)
+ output_dict['in_conv'] = self._in_conv.forward(x)
+ output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
+ output_dict['down2'] = self._down2.forward(output_dict['down1'])
+ output_dict['down3'] = self._down3.forward(output_dict['down2'])
+ output_dict['down4'] = self._down4.forward(output_dict['down3'])
+ output_dict['up1'] = self._up1.forward(output_dict['down4'])
+ output_dict['up2'] = self._up2.forward(
+ paddle.concat(
+ (output_dict['down3'], output_dict['up1']), axis=1))
+ output_dict['concat'] = paddle.concat(
+ (output_dict['down2'], output_dict['up2']), axis=1)
+ return output_dict
diff --git a/StyleTextRec/arch/spectral_norm.py b/StyleTextRec/arch/spectral_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..21d0afc8d4a8fd4e2262db5c8461d6ffc3dadd45
--- /dev/null
+++ b/StyleTextRec/arch/spectral_norm.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def normal_(x, mean=0., std=1.):
+ temp_value = paddle.normal(mean, std, shape=x.shape)
+ x.set_value(temp_value)
+ return x
+
+
+class SpectralNorm(object):
+ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+ self.name = name
+ self.dim = dim
+ if n_power_iterations <= 0:
+ raise ValueError('Expected n_power_iterations to be positive, but '
+ 'got n_power_iterations={}'.format(
+ n_power_iterations))
+ self.n_power_iterations = n_power_iterations
+ self.eps = eps
+
+ def reshape_weight_to_matrix(self, weight):
+ weight_mat = weight
+ if self.dim != 0:
+ # transpose dim to front
+ weight_mat = weight_mat.transpose([
+ self.dim,
+ * [d for d in range(weight_mat.dim()) if d != self.dim]
+ ])
+
+ height = weight_mat.shape[0]
+
+ return weight_mat.reshape([height, -1])
+
+ def compute_weight(self, module, do_power_iteration):
+ weight = getattr(module, self.name + '_orig')
+ u = getattr(module, self.name + '_u')
+ v = getattr(module, self.name + '_v')
+ weight_mat = self.reshape_weight_to_matrix(weight)
+
+ if do_power_iteration:
+ with paddle.no_grad():
+ for _ in range(self.n_power_iterations):
+ v.set_value(
+ F.normalize(
+ paddle.matmul(
+ weight_mat,
+ u,
+ transpose_x=True,
+ transpose_y=False),
+ axis=0,
+ epsilon=self.eps, ))
+
+ u.set_value(
+ F.normalize(
+ paddle.matmul(weight_mat, v),
+ axis=0,
+ epsilon=self.eps, ))
+ if self.n_power_iterations > 0:
+ u = u.clone()
+ v = v.clone()
+
+ sigma = paddle.dot(u, paddle.mv(weight_mat, v))
+ weight = weight / sigma
+ return weight
+
+ def remove(self, module):
+ with paddle.no_grad():
+ weight = self.compute_weight(module, do_power_iteration=False)
+ delattr(module, self.name)
+ delattr(module, self.name + '_u')
+ delattr(module, self.name + '_v')
+ delattr(module, self.name + '_orig')
+
+ module.add_parameter(self.name, weight.detach())
+
+ def __call__(self, module, inputs):
+ setattr(
+ module,
+ self.name,
+ self.compute_weight(
+ module, do_power_iteration=module.training))
+
+ @staticmethod
+ def apply(module, name, n_power_iterations, dim, eps):
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ raise RuntimeError(
+ "Cannot register two spectral_norm hooks on "
+ "the same parameter {}".format(name))
+
+ fn = SpectralNorm(name, n_power_iterations, dim, eps)
+ weight = module._parameters[name]
+
+ with paddle.no_grad():
+ weight_mat = fn.reshape_weight_to_matrix(weight)
+ h, w = weight_mat.shape
+
+ # randomly initialize u and v
+ u = module.create_parameter([h])
+ u = normal_(u, 0., 1.)
+ v = module.create_parameter([w])
+ v = normal_(v, 0., 1.)
+ u = F.normalize(u, axis=0, epsilon=fn.eps)
+ v = F.normalize(v, axis=0, epsilon=fn.eps)
+
+        # delete fn.name from parameters; otherwise we cannot set the attribute
+ del module._parameters[fn.name]
+ module.add_parameter(fn.name + "_orig", weight)
+ # still need to assign weight back as fn.name because all sorts of
+ # things may assume that it exists, e.g., when initializing weights.
+        # However, we can't directly assign it, as it could be a Parameter and
+        # would get added as a parameter. Instead, we register weight * 1.0 as a plain
+ # attribute.
+ setattr(module, fn.name, weight * 1.0)
+ module.register_buffer(fn.name + "_u", u)
+ module.register_buffer(fn.name + "_v", v)
+
+ module.register_forward_pre_hook(fn)
+ return fn
+
+
+def spectral_norm(module,
+ name='weight',
+ n_power_iterations=1,
+ eps=1e-12,
+ dim=None):
+
+ if dim is None:
+ if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
+ nn.Conv3DTranspose, nn.Linear)):
+ dim = 1
+ else:
+ dim = 0
+ SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+ return module
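+
+
+# Example usage (illustrative): wrap a layer so that its weight is
+# spectrally normalized on every forward pass, e.g.
+#     conv = spectral_norm(nn.Conv2D(3, 64, kernel_size=3))
+# SNConv in arch/base_module.py builds its convolutions this way.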
diff --git a/StyleTextRec/arch/style_text_rec.py b/StyleTextRec/arch/style_text_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..599927ce3edefc90f14191ef3d29b1221355867e
--- /dev/null
+++ b/StyleTextRec/arch/style_text_rec.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.base_module import MiddleNet, ResBlock
+from arch.encoder import Encoder
+from arch.decoder import Decoder, DecoderUnet, SingleDecoder
+from utils.load_params import load_dygraph_pretrain
+from utils.logging import get_logger
+
+
+class StyleTextRec(nn.Layer):
+ def __init__(self, config):
+ super(StyleTextRec, self).__init__()
+ self.logger = get_logger()
+ self.text_generator = TextGenerator(config["Predictor"][
+ "text_generator"])
+ self.bg_generator = BgGeneratorWithMask(config["Predictor"][
+ "bg_generator"])
+ self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
+ "fusion_generator"])
+ bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
+ text_generator_pretrain = config["Predictor"]["text_generator"][
+ "pretrain"]
+ fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
+ "pretrain"]
+ load_dygraph_pretrain(
+ self.bg_generator,
+ self.logger,
+ path=bg_generator_pretrain,
+ load_static_weights=False)
+ load_dygraph_pretrain(
+ self.text_generator,
+ self.logger,
+ path=text_generator_pretrain,
+ load_static_weights=False)
+ load_dygraph_pretrain(
+ self.fusion_generator,
+ self.logger,
+ path=fusion_generator_pretrain,
+ load_static_weights=False)
+
+ def forward(self, style_input, text_input):
+ text_gen_output = self.text_generator.forward(style_input, text_input)
+ fake_text = text_gen_output["fake_text"]
+ fake_sk = text_gen_output["fake_sk"]
+ bg_gen_output = self.bg_generator.forward(style_input)
+ bg_encode_feature = bg_gen_output["bg_encode_feature"]
+ bg_decode_feature1 = bg_gen_output["bg_decode_feature1"]
+ bg_decode_feature2 = bg_gen_output["bg_decode_feature2"]
+ fake_bg = bg_gen_output["fake_bg"]
+
+ fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
+ fake_fusion = fusion_gen_output["fake_fusion"]
+ return {
+ "fake_fusion": fake_fusion,
+ "fake_text": fake_text,
+ "fake_sk": fake_sk,
+ "fake_bg": fake_bg,
+ }
+
+
+class TextGenerator(nn.Layer):
+ def __init__(self, config):
+ super(TextGenerator, self).__init__()
+ name = config["module_name"]
+ encode_dim = config["encode_dim"]
+ norm_layer = config["norm_layer"]
+ conv_block_dropout = config["conv_block_dropout"]
+ conv_block_num = config["conv_block_num"]
+ conv_block_dilation = config["conv_block_dilation"]
+ if norm_layer == "InstanceNorm2D":
+ use_bias = True
+ else:
+ use_bias = False
+ self.encoder_text = Encoder(
+ name=name + "_encoder_text",
+ in_channels=3,
+ encode_dim=encode_dim,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation)
+ self.encoder_style = Encoder(
+ name=name + "_encoder_style",
+ in_channels=3,
+ encode_dim=encode_dim,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation)
+ self.decoder_text = Decoder(
+ name=name + "_decoder_text",
+ encode_dim=encode_dim,
+ out_channels=int(encode_dim / 2),
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation,
+ out_conv_act="Tanh",
+ out_conv_act_attr=None)
+ self.decoder_sk = Decoder(
+ name=name + "_decoder_sk",
+ encode_dim=encode_dim,
+ out_channels=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation,
+ out_conv_act="Sigmoid",
+ out_conv_act_attr=None)
+
+ self.middle = MiddleNet(
+ name=name + "_middle_net",
+ in_channels=int(encode_dim / 2) + 1,
+ mid_channels=encode_dim,
+ out_channels=3,
+ use_bias=use_bias)
+
+ def forward(self, style_input, text_input):
+ style_feature = self.encoder_style.forward(style_input)["res_blocks"]
+ text_feature = self.encoder_text.forward(text_input)["res_blocks"]
+ fake_c_temp = self.decoder_text.forward([text_feature,
+ style_feature])["out_conv"]
+ fake_sk = self.decoder_sk.forward([text_feature,
+ style_feature])["out_conv"]
+ fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
+ return {"fake_sk": fake_sk, "fake_text": fake_text}
+
+
+class BgGeneratorWithMask(nn.Layer):
+ def __init__(self, config):
+ super(BgGeneratorWithMask, self).__init__()
+ name = config["module_name"]
+ encode_dim = config["encode_dim"]
+ norm_layer = config["norm_layer"]
+ conv_block_dropout = config["conv_block_dropout"]
+ conv_block_num = config["conv_block_num"]
+ conv_block_dilation = config["conv_block_dilation"]
+ self.output_factor = config.get("output_factor", 1.0)
+
+ if norm_layer == "InstanceNorm2D":
+ use_bias = True
+ else:
+ use_bias = False
+
+ self.encoder_bg = Encoder(
+ name=name + "_encoder_bg",
+ in_channels=3,
+ encode_dim=encode_dim,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation)
+
+ self.decoder_bg = SingleDecoder(
+ name=name + "_decoder_bg",
+ encode_dim=encode_dim,
+ out_channels=3,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation,
+ out_conv_act="Tanh",
+ out_conv_act_attr=None)
+
+ self.decoder_mask = Decoder(
+ name=name + "_decoder_mask",
+ encode_dim=encode_dim // 2,
+ out_channels=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation,
+ out_conv_act="Sigmoid",
+ out_conv_act_attr=None)
+
+ self.middle = MiddleNet(
+ name=name + "_middle_net",
+ in_channels=3 + 1,
+ mid_channels=encode_dim,
+ out_channels=3,
+ use_bias=use_bias)
+
+ def forward(self, style_input):
+ encode_bg_output = self.encoder_bg(style_input)
+ decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
+ encode_bg_output["down2"],
+ encode_bg_output["down1"])
+
+ fake_c_temp = decode_bg_output["out_conv"]
+ fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
+ "res_blocks"])["out_conv"]
+ fake_bg = self.middle(
+ paddle.concat(
+ (fake_c_temp, fake_bg_mask), axis=1))
+ return {
+ "bg_encode_feature": encode_bg_output["res_blocks"],
+ "bg_decode_feature1": decode_bg_output["up1"],
+ "bg_decode_feature2": decode_bg_output["up2"],
+ "fake_bg": fake_bg,
+ "fake_bg_mask": fake_bg_mask,
+ }
+
+
+class FusionGeneratorSimple(nn.Layer):
+ def __init__(self, config):
+ super(FusionGeneratorSimple, self).__init__()
+ name = config["module_name"]
+ encode_dim = config["encode_dim"]
+ norm_layer = config["norm_layer"]
+ conv_block_dropout = config["conv_block_dropout"]
+ conv_block_dilation = config["conv_block_dilation"]
+ if norm_layer == "InstanceNorm2D":
+ use_bias = True
+ else:
+ use_bias = False
+
+ self._conv = nn.Conv2D(
+ in_channels=6,
+ out_channels=encode_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
+ bias_attr=False)
+
+ self._res_block = ResBlock(
+ name="{}_conv_block".format(name),
+ channels=encode_dim,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias)
+
+ self._reduce_conv = nn.Conv2D(
+ in_channels=encode_dim,
+ out_channels=3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
+ bias_attr=False)
+
+ def forward(self, fake_text, fake_bg):
+ fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
+ fake_concat_tmp = self._conv(fake_concat)
+ output_res = self._res_block(fake_concat_tmp)
+ fake_fusion = self._reduce_conv(output_res)
+ return {"fake_fusion": fake_fusion}
diff --git a/StyleTextRec/configs/config.yml b/StyleTextRec/configs/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3b10b3d2761a4aa40c28abe10134a2f276e1af9d
--- /dev/null
+++ b/StyleTextRec/configs/config.yml
@@ -0,0 +1,54 @@
+Global:
+ output_num: 10
+ output_dir: output_data
+ use_gpu: false
+ image_height: 32
+ image_width: 320
+TextDrawer:
+ fonts:
+ en: fonts/en_standard.ttf
+ ch: fonts/ch_standard.ttf
+ ko: fonts/ko_standard.ttf
+Predictor:
+ method: StyleTextRecPredictor
+ algorithm: StyleTextRec
+ scale: 0.00392156862745098
+ mean:
+ - 0.5
+ - 0.5
+ - 0.5
+ std:
+ - 0.5
+ - 0.5
+ - 0.5
+ expand_result: false
+ bg_generator:
+ pretrain: style_text_models/bg_generator
+ module_name: bg_generator
+ generator_type: BgGeneratorWithMask
+ encode_dim: 64
+ norm_layer: null
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+ output_factor: 1.05
+ text_generator:
+ pretrain: style_text_models/text_generator
+ module_name: text_generator
+ generator_type: TextGenerator
+ encode_dim: 64
+ norm_layer: InstanceNorm2D
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+ fusion_generator:
+ pretrain: style_text_models/fusion_generator
+ module_name: fusion_generator
+ generator_type: FusionGeneratorSimple
+ encode_dim: 64
+ norm_layer: null
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+Writer:
+ method: SimpleWriter
diff --git a/StyleTextRec/configs/dataset_config.yml b/StyleTextRec/configs/dataset_config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e047489e5d82e4c561a835ccf4de1b385e4f5c08
--- /dev/null
+++ b/StyleTextRec/configs/dataset_config.yml
@@ -0,0 +1,64 @@
+Global:
+ output_num: 10
+ output_dir: output_data
+ use_gpu: false
+ image_height: 32
+ image_width: 320
+ standard_font: fonts/en_standard.ttf
+TextDrawer:
+ fonts:
+ en: fonts/en_standard.ttf
+ ch: fonts/ch_standard.ttf
+ ko: fonts/ko_standard.ttf
+StyleSampler:
+ method: DatasetSampler
+ image_home: examples
+ label_file: examples/image_list.txt
+ with_label: true
+CorpusGenerator:
+ method: FileCorpus
+ language: ch
+ corpus_file: examples/corpus/example.txt
+Predictor:
+ method: StyleTextRecPredictor
+ algorithm: StyleTextRec
+ scale: 0.00392156862745098
+ mean:
+ - 0.5
+ - 0.5
+ - 0.5
+ std:
+ - 0.5
+ - 0.5
+ - 0.5
+ expand_result: false
+ bg_generator:
+ pretrain: models/style_text_rec/bg_generator
+ module_name: bg_generator
+ generator_type: BgGeneratorWithMask
+ encode_dim: 64
+ norm_layer: null
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+ output_factor: 1.05
+ text_generator:
+ pretrain: models/style_text_rec/text_generator
+ module_name: text_generator
+ generator_type: TextGenerator
+ encode_dim: 64
+ norm_layer: InstanceNorm2D
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+ fusion_generator:
+ pretrain: models/style_text_rec/fusion_generator
+ module_name: fusion_generator
+ generator_type: FusionGeneratorSimple
+ encode_dim: 64
+ norm_layer: null
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+Writer:
+ method: SimpleWriter
diff --git a/StyleTextRec/doc/images/1.png b/StyleTextRec/doc/images/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..8f7574ba2f723ac82241fec6dc52828713a5d293
Binary files /dev/null and b/StyleTextRec/doc/images/1.png differ
diff --git a/StyleTextRec/doc/images/2.png b/StyleTextRec/doc/images/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce9bf4712a551b9d9d27eae00f9c7b9b5845d8b3
Binary files /dev/null and b/StyleTextRec/doc/images/2.png differ
diff --git a/StyleTextRec/doc/images/3.png b/StyleTextRec/doc/images/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..0fb73a31f58c1c476cf84f3c507f0af6523385f4
Binary files /dev/null and b/StyleTextRec/doc/images/3.png differ
diff --git a/StyleTextRec/doc/images/4.jpg b/StyleTextRec/doc/images/4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5fda9548632b63e55b42315dca4a5b9cec2a353c
Binary files /dev/null and b/StyleTextRec/doc/images/4.jpg differ
diff --git a/StyleTextRec/doc/images/5.png b/StyleTextRec/doc/images/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..ea0b89034940cc70e6ec8f77471f3af1c2b54219
Binary files /dev/null and b/StyleTextRec/doc/images/5.png differ
diff --git a/StyleTextRec/doc/images/6.png b/StyleTextRec/doc/images/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..75af7275a009ec01c4bc0903a57d559daf93101b
Binary files /dev/null and b/StyleTextRec/doc/images/6.png differ
diff --git a/StyleTextRec/engine/__init__.py b/StyleTextRec/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StyleTextRec/engine/corpus_generators.py b/StyleTextRec/engine/corpus_generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..186d15f36d16971d9e7700535b50b1f724a80fe7
--- /dev/null
+++ b/StyleTextRec/engine/corpus_generators.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+
+from utils.logging import get_logger
+
+
+class FileCorpus(object):
+ def __init__(self, config):
+ self.logger = get_logger()
+ self.logger.info("using FileCorpus")
+
+ self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+
+ corpus_file = config["CorpusGenerator"]["corpus_file"]
+ self.language = config["CorpusGenerator"]["language"]
+ with open(corpus_file, 'r') as f:
+ corpus_raw = f.read()
+ self.corpus_list = corpus_raw.split("\n")[:-1]
+ assert len(self.corpus_list) > 0
+ random.shuffle(self.corpus_list)
+ self.index = 0
+
+ def generate(self, corpus_length=0):
+ if self.index >= len(self.corpus_list):
+ self.index = 0
+ random.shuffle(self.corpus_list)
+ corpus = self.corpus_list[self.index]
+ if corpus_length != 0:
+ corpus = corpus[0:corpus_length]
+ if corpus_length > len(corpus):
+ self.logger.warning("generated corpus is shorter than expected.")
+ self.index += 1
+ return self.language, corpus
+
+
+class EnNumCorpus(object):
+ def __init__(self, config):
+ self.logger = get_logger()
+        self.logger.info("using EnNumCorpus")
+ self.num_list = "0123456789"
+ self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ self.height = config["Global"]["image_height"]
+ self.max_width = config["Global"]["image_width"]
+
+ def generate(self, corpus_length=0):
+ corpus = ""
+ if corpus_length == 0:
+ corpus_length = random.randint(5, 15)
+ for i in range(corpus_length):
+ if random.random() < 0.2:
+ corpus += "{}".format(random.choice(self.en_char_list))
+ else:
+ corpus += "{}".format(random.choice(self.num_list))
+ return "en", corpus
diff --git a/StyleTextRec/engine/predictors.py b/StyleTextRec/engine/predictors.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9f4afe4a18bd1e0a96ac37aa0359f26434ddb3d
--- /dev/null
+++ b/StyleTextRec/engine/predictors.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import cv2
+import math
+import paddle
+
+from arch import style_text_rec
+from utils.sys_funcs import check_gpu
+from utils.logging import get_logger
+
+
+class StyleTextRecPredictor(object):
+ def __init__(self, config):
+ algorithm = config['Predictor']['algorithm']
+ assert algorithm in ["StyleTextRec"
+ ], "Generator {} not supported.".format(algorithm)
+ use_gpu = config["Global"]['use_gpu']
+ check_gpu(use_gpu)
+ self.logger = get_logger()
+ self.generator = getattr(style_text_rec, algorithm)(config)
+ self.height = config["Global"]["image_height"]
+ self.width = config["Global"]["image_width"]
+ self.scale = config["Predictor"]["scale"]
+ self.mean = config["Predictor"]["mean"]
+ self.std = config["Predictor"]["std"]
+ self.expand_result = config["Predictor"]["expand_result"]
+
+ def predict(self, style_input, text_input):
+ style_input = self.rep_style_input(style_input, text_input)
+ tensor_style_input = self.preprocess(style_input)
+ tensor_text_input = self.preprocess(text_input)
+ style_text_result = self.generator.forward(tensor_style_input,
+ tensor_text_input)
+ fake_fusion = self.postprocess(style_text_result["fake_fusion"])
+ fake_text = self.postprocess(style_text_result["fake_text"])
+ fake_sk = self.postprocess(style_text_result["fake_sk"])
+ fake_bg = self.postprocess(style_text_result["fake_bg"])
+ bbox = self.get_text_boundary(fake_text)
+ if bbox:
+ left, right, top, bottom = bbox
+ fake_fusion = fake_fusion[top:bottom, left:right, :]
+ fake_text = fake_text[top:bottom, left:right, :]
+ fake_sk = fake_sk[top:bottom, left:right, :]
+ fake_bg = fake_bg[top:bottom, left:right, :]
+
+ # fake_fusion = self.crop_by_text(img_fake_fusion, img_fake_text)
+ return {
+ "fake_fusion": fake_fusion,
+ "fake_text": fake_text,
+ "fake_sk": fake_sk,
+ "fake_bg": fake_bg,
+ }
+
+ def preprocess(self, img):
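+        # Normalize to [-1, 1], resize to the target height while keeping the
+        # aspect ratio, then right-pad with zeros to the target width.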
+ img = (img.astype('float32') * self.scale - self.mean) / self.std
+ img_height, img_width, channel = img.shape
+        assert channel == 3, "Please use an RGB image."
+ ratio = img_width / float(img_height)
+ if math.ceil(self.height * ratio) > self.width:
+ resized_w = self.width
+ else:
+ resized_w = int(math.ceil(self.height * ratio))
+ img = cv2.resize(img, (resized_w, self.height))
+
+ new_img = np.zeros([self.height, self.width, 3]).astype('float32')
+ new_img[:, 0:resized_w, :] = img
+ img = new_img.transpose((2, 0, 1))
+ img = img[np.newaxis, :, :, :]
+ return paddle.to_tensor(img)
+
+ def postprocess(self, tensor):
+ img = tensor.numpy()[0]
+ img = img.transpose((1, 2, 0))
+ img = (img * self.std + self.mean) / self.scale
+ img = np.maximum(img, 0.0)
+ img = np.minimum(img, 255.0)
+ img = img.astype('uint8')
+ return img
+
+ def rep_style_input(self, style_input, text_input):
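+        # Tile the style image horizontally until it covers the aspect ratio
+        # of the text image (with a 1.2x margin), then crop it to the widest
+        # input the model accepts.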
+ rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) /
+ (style_input.shape[1] / style_input.shape[0])) + 1
+ style_input = np.tile(style_input, reps=[1, rep_num, 1])
+ max_width = int(self.width / self.height * style_input.shape[0])
+ style_input = style_input[:, :max_width, :]
+ return style_input
+
+ def get_text_boundary(self, text_img):
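+        # Locate the text region from Canny edges and return a slightly padded
+        # [left, right, top, bottom] box, or None if no edges are found.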
+ img_height = text_img.shape[0]
+ img_width = text_img.shape[1]
+ bounder = 3
+ text_canny_img = cv2.Canny(text_img, 10, 20)
+ edge_num_h = text_canny_img.sum(axis=0)
+ no_zero_list_h = np.where(edge_num_h > 0)[0]
+ edge_num_w = text_canny_img.sum(axis=1)
+ no_zero_list_w = np.where(edge_num_w > 0)[0]
+ if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0:
+ return None
+ left = max(no_zero_list_h[0] - bounder, 0)
+ right = min(no_zero_list_h[-1] + bounder, img_width)
+ top = max(no_zero_list_w[0] - bounder, 0)
+ bottom = min(no_zero_list_w[-1] + bounder, img_height)
+ return [left, right, top, bottom]
diff --git a/StyleTextRec/engine/style_samplers.py b/StyleTextRec/engine/style_samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e171d58db7527ffb37972524991e58ac59c6bb0a
--- /dev/null
+++ b/StyleTextRec/engine/style_samplers.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import random
+import cv2
+
+
+class DatasetSampler(object):
+ def __init__(self, config):
+ self.image_home = config["StyleSampler"]["image_home"]
+ label_file = config["StyleSampler"]["label_file"]
+ self.dataset_with_label = config["StyleSampler"]["with_label"]
+ self.height = config["Global"]["image_height"]
+ self.index = 0
+ with open(label_file, "r") as f:
+ label_raw = f.read()
+ self.path_label_list = label_raw.split("\n")[:-1]
+ assert len(self.path_label_list) > 0
+ random.shuffle(self.path_label_list)
+
+ def sample(self):
+ if self.index >= len(self.path_label_list):
+ random.shuffle(self.path_label_list)
+ self.index = 0
+ if self.dataset_with_label:
+ path_label = self.path_label_list[self.index]
+ rel_image_path, label = path_label.split('\t')
+ else:
+ rel_image_path = self.path_label_list[self.index]
+ label = None
+ img_path = "{}/{}".format(self.image_home, rel_image_path)
+ image = cv2.imread(img_path)
+ origin_height = image.shape[0]
+ ratio = self.height / origin_height
+ width = int(image.shape[1] * ratio)
+ height = int(image.shape[0] * ratio)
+ image = cv2.resize(image, (width, height))
+
+ self.index += 1
+ if label:
+ return {"image": image, "label": label}
+ else:
+ return {"image": image}
+
+
+def duplicate_image(image, width):
+ image_width = image.shape[1]
+ dup_num = width // image_width + 1
+ image = np.tile(image, reps=[1, dup_num, 1])
+ cropped_image = image[:, :width, :]
+ return cropped_image
diff --git a/StyleTextRec/engine/synthesisers.py b/StyleTextRec/engine/synthesisers.py
new file mode 100644
index 0000000000000000000000000000000000000000..177e3e049a695ecd06f5d2271f21336dd4eff997
--- /dev/null
+++ b/StyleTextRec/engine/synthesisers.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from utils.config import ArgsParser, load_config, override_config
+from utils.logging import get_logger
+from engine import style_samplers, corpus_generators, text_drawers, predictors, writers
+
+
+class ImageSynthesiser(object):
+ def __init__(self):
+ self.FLAGS = ArgsParser().parse_args()
+ self.config = load_config(self.FLAGS.config)
+ self.config = override_config(self.config, options=self.FLAGS.override)
+ self.output_dir = self.config["Global"]["output_dir"]
+ if not os.path.exists(self.output_dir):
+ os.mkdir(self.output_dir)
+ self.logger = get_logger(
+ log_file='{}/predict.log'.format(self.output_dir))
+
+ self.text_drawer = text_drawers.StdTextDrawer(self.config)
+
+ predictor_method = self.config["Predictor"]["method"]
+ assert predictor_method is not None
+ self.predictor = getattr(predictors, predictor_method)(self.config)
+
+ def synth_image(self, corpus, style_input, language="en"):
+ corpus, text_input = self.text_drawer.draw_text(corpus, language)
+ synth_result = self.predictor.predict(style_input, text_input)
+ return synth_result
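+ # A minimal usage sketch (hypothetical file names):
+ #   syn = ImageSynthesiser()
+ #   result = syn.synth_image("PaddleOCR", cv2.imread("style.jpg"), "en")
+ #   cv2.imwrite("fake_fusion.jpg", result["fake_fusion"])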
+
+
+class DatasetSynthesiser(ImageSynthesiser):
+ def __init__(self):
+ super(DatasetSynthesiser, self).__init__()
+ self.tag = self.FLAGS.tag
+ self.output_num = self.config["Global"]["output_num"]
+ corpus_generator_method = self.config["CorpusGenerator"]["method"]
+ self.corpus_generator = getattr(corpus_generators,
+ corpus_generator_method)(self.config)
+
+ style_sampler_method = self.config["StyleSampler"]["method"]
+ assert style_sampler_method is not None
+ self.style_sampler = style_samplers.DatasetSampler(self.config)
+ self.writer = writers.SimpleWriter(self.config, self.tag)
+
+ def synth_dataset(self):
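+ # For each output sample: pick a style image, render the corpus string
+ # in a standard font, run the predictor, and save only the fused result.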
+ for i in range(self.output_num):
+ style_data = self.style_sampler.sample()
+ style_input = style_data["image"]
+ corpus_language, text_input_label = self.corpus_generator.generate(
+ )
+ text_input_label, text_input = self.text_drawer.draw_text(
+ text_input_label, corpus_language)
+
+ synth_result = self.predictor.predict(style_input, text_input)
+ fake_fusion = synth_result["fake_fusion"]
+ self.writer.save_image(fake_fusion, text_input_label)
+ self.writer.save_label()
+ self.writer.merge_label()
diff --git a/StyleTextRec/engine/text_drawers.py b/StyleTextRec/engine/text_drawers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aaac06ec50816bb6e2774972644c0a7dfb908c6
--- /dev/null
+++ b/StyleTextRec/engine/text_drawers.py
@@ -0,0 +1,57 @@
+from PIL import Image, ImageDraw, ImageFont
+import numpy as np
+from utils.logging import get_logger
+
+
+class StdTextDrawer(object):
+ def __init__(self, config):
+ self.logger = get_logger()
+ self.max_width = config["Global"]["image_width"]
+ self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ self.height = config["Global"]["image_height"]
+ self.font_dict = {}
+ self.load_fonts(config["TextDrawer"]["fonts"])
+ self.support_languages = list(self.font_dict)
+
+ def load_fonts(self, fonts_config):
+ for language in fonts_config:
+ font_path = fonts_config[language]
+ font_height = self.get_valid_height(font_path)
+ font = ImageFont.truetype(font_path, font_height)
+ self.font_dict[language] = font
+
+ def get_valid_height(self, font_path):
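+ # Try font size (height - 4); if the measured glyph height overflows,
+ # shrink the size proportionally: new_size = (height - 4)**2 / font_height.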
+ font = ImageFont.truetype(font_path, self.height - 4)
+ _, font_height = font.getsize(self.char_list)
+ if font_height <= self.height - 4:
+ return self.height - 4
+ else:
+ return int((self.height - 4)**2 / font_height)
+
+ def draw_text(self, corpus, language="en", crop=True):
+ if language not in self.support_languages:
+ self.logger.warning(
+ "language {} not supported, use en instead.".format(language))
+ language = "en"
+ if crop:
+ width = min(self.max_width, len(corpus) * self.height) + 4
+ else:
+ width = len(corpus) * self.height + 4
+ bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
+ draw = ImageDraw.Draw(bg)
+
+ char_x = 2
+ font = self.font_dict[language]
+ for i, char_i in enumerate(corpus):
+ char_size = font.getsize(char_i)[0]
+ draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
+ char_x += char_size
+ if char_x >= width:
+ corpus = corpus[0:i + 1]
+ self.logger.warning("corpus length exceed limit: {}".format(
+ corpus))
+ break
+
+ text_input = np.array(bg).astype(np.uint8)
+ text_input = text_input[:, 0:char_x, :]
+ return corpus, text_input
diff --git a/StyleTextRec/engine/writers.py b/StyleTextRec/engine/writers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df75e7234812c3fbab69ceed50040aa16cd83bc
--- /dev/null
+++ b/StyleTextRec/engine/writers.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import cv2
+import glob
+
+from utils.logging import get_logger
+
+
+class SimpleWriter(object):
+ def __init__(self, config, tag):
+ self.logger = get_logger()
+ self.output_dir = config["Global"]["output_dir"]
+ self.counter = 0
+ self.label_dict = {}
+ self.tag = tag
+ self.label_file_index = 0
+
+ def save_image(self, image, text_input_label):
+ image_home = os.path.join(self.output_dir, "images", self.tag)
+ if not os.path.exists(image_home):
+ os.makedirs(image_home)
+
+ image_path = os.path.join(image_home, "{}.png".format(self.counter))
+ # TODO: support resuming synthesis from an existing output directory
+ cv2.imwrite(image_path, image)
+ self.logger.info("generate image: {}".format(image_path))
+
+ image_name = os.path.join(self.tag, "{}.png".format(self.counter))
+ self.label_dict[image_name] = text_input_label
+
+ self.counter += 1
+ if not self.counter % 100:
+ self.save_label()
+
+ def save_label(self):
+ label_raw = ""
+ label_home = os.path.join(self.output_dir, "label")
+ if not os.path.exists(label_home):
+ os.mkdir(label_home)
+ for image_path in self.label_dict:
+ label = self.label_dict[image_path]
+ label_raw += "{}\t{}\n".format(image_path, label)
+ label_file_path = os.path.join(label_home,
+ "{}_label.txt".format(self.tag))
+ with open(label_file_path, "w") as f:
+ f.write(label_raw)
+ self.label_file_index += 1
+
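+ # Concatenate the per-tag label files written by save_label into a single
+ # label.txt covering the whole synthesized dataset.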
+ def merge_label(self):
+ label_raw = ""
+ label_file_regex = os.path.join(self.output_dir, "label",
+ "*_label.txt")
+ label_file_list = glob.glob(label_file_regex)
+ for label_file_i in label_file_list:
+ with open(label_file_i, "r") as f:
+ label_raw += f.read()
+ label_file_path = os.path.join(self.output_dir, "label.txt")
+ with open(label_file_path, "w") as f:
+ f.write(label_raw)
diff --git a/StyleTextRec/examples/corpus/example.txt b/StyleTextRec/examples/corpus/example.txt
new file mode 100644
index 0000000000000000000000000000000000000000..78451cc3d92a3353f5de0c74c2cb0a06e6197653
--- /dev/null
+++ b/StyleTextRec/examples/corpus/example.txt
@@ -0,0 +1,2 @@
+PaddleOCR
+飞桨文字识别
diff --git a/StyleTextRec/examples/image_list.txt b/StyleTextRec/examples/image_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b07be0353516f7822e4994d5dcddcd85766035dc
--- /dev/null
+++ b/StyleTextRec/examples/image_list.txt
@@ -0,0 +1,2 @@
+style_images/1.jpg NEATNESS
+style_images/2.jpg 锁店君和宾馆
diff --git a/StyleTextRec/examples/style_images/1.jpg b/StyleTextRec/examples/style_images/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4da7838e5d3c711cdeab60df63ae4c7af7b475ae
Binary files /dev/null and b/StyleTextRec/examples/style_images/1.jpg differ
diff --git a/StyleTextRec/examples/style_images/2.jpg b/StyleTextRec/examples/style_images/2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f68ce49aa5558124d36ae6eaa801be5b0e79e152
Binary files /dev/null and b/StyleTextRec/examples/style_images/2.jpg differ
diff --git a/StyleTextRec/fonts/ch_standard.ttf b/StyleTextRec/fonts/ch_standard.ttf
new file mode 100755
index 0000000000000000000000000000000000000000..cdb7fa5907587b8dbe0ad1da7442d3e4f8bd7488
Binary files /dev/null and b/StyleTextRec/fonts/ch_standard.ttf differ
diff --git a/StyleTextRec/fonts/en_standard.ttf b/StyleTextRec/fonts/en_standard.ttf
new file mode 100755
index 0000000000000000000000000000000000000000..2e31d02424ed50b9e05c19b5d82500699a6edbb0
Binary files /dev/null and b/StyleTextRec/fonts/en_standard.ttf differ
diff --git a/StyleTextRec/fonts/ko_standard.ttf b/StyleTextRec/fonts/ko_standard.ttf
new file mode 100755
index 0000000000000000000000000000000000000000..982bd879c27c731d2601ea8da988784e06f4b5b3
Binary files /dev/null and b/StyleTextRec/fonts/ko_standard.ttf differ
diff --git a/StyleTextRec/tools/__init__.py b/StyleTextRec/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StyleTextRec/tools/synth_dataset.py b/StyleTextRec/tools/synth_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a0e6d5e1f701c49558cfe1ea1df61e9b4180a89
--- /dev/null
+++ b/StyleTextRec/tools/synth_dataset.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from engine.synthesisers import DatasetSynthesiser
+
+
+def synth_dataset():
+ dataset_synthesiser = DatasetSynthesiser()
+ dataset_synthesiser.synth_dataset()
+
+
+if __name__ == '__main__':
+ synth_dataset()
diff --git a/StyleTextRec/tools/synth_image.py b/StyleTextRec/tools/synth_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b4827b825e4a28dd1fb2eba722d23e64e8ce0be
--- /dev/null
+++ b/StyleTextRec/tools/synth_image.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+# Make utils/ and engine/ importable when this file is run as a script.
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+import cv2
+import glob
+
+from utils.config import ArgsParser
+from engine.synthesisers import ImageSynthesiser
+
+
+def synth_image():
+ args = ArgsParser().parse_args()
+ image_synthesiser = ImageSynthesiser()
+ style_image_path = args.style_image
+ img = cv2.imread(style_image_path)
+ text_corpus = args.text_corpus
+ language = args.language
+
+ synth_result = image_synthesiser.synth_image(text_corpus, img, language)
+ fake_fusion = synth_result["fake_fusion"]
+ fake_text = synth_result["fake_text"]
+ fake_bg = synth_result["fake_bg"]
+ cv2.imwrite("fake_fusion.jpg", fake_fusion)
+ cv2.imwrite("fake_text.jpg", fake_text)
+ cv2.imwrite("fake_bg.jpg", fake_bg)
+
+
+def batch_synth_images():
+ image_synthesiser = ImageSynthesiser()
+
+ corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
+ style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
+ save_path = "./output_data/"
+ corpus_list = []
+ with open(corpus_file, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ substr = line.decode("utf-8").strip("\n").split("\t")
+ corpus_list.append(substr)
+ style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))
+ corpus_num = len(corpus_list)
+ style_img_num = len(style_img_list)
+ for cno in range(corpus_num):
+ for sno in range(style_img_num):
+ corpus, lang = corpus_list[cno]
+ style_img_path = style_img_list[sno]
+ img = cv2.imread(style_img_path)
+ synth_result = image_synthesiser.synth_image(corpus, img, lang)
+ fake_fusion = synth_result["fake_fusion"]
+ fake_text = synth_result["fake_text"]
+ fake_bg = synth_result["fake_bg"]
+ # Save each result under two name orders: corpus-first and style-first.
+ for tp in range(2):
+ if tp == 0:
+ prefix = "%s/c%d_s%d" % (save_path, cno, sno)
+ else:
+ prefix = "%s/s%d_c%d" % (save_path, sno, cno)
+ cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
+ cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
+ cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
+ cv2.imwrite("%s_input_style.jpg" % prefix, img)
+ print(cno, corpus_num, sno, style_img_num)
+
+
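+# A typical invocation (repo defaults shown; -c is required):
+#   python3 -m tools.synth_image -c configs/config.yml \
+#       --style_image examples/style_images/1.jpg --text_corpus PaddleOCR --language en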
+if __name__ == '__main__':
+ # batch_synth_images()
+ synth_image()
diff --git a/StyleTextRec/utils/__init__.py b/StyleTextRec/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StyleTextRec/utils/config.py b/StyleTextRec/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f8a618a838db361da4867e00df8dcd619f9f3d
--- /dev/null
+++ b/StyleTextRec/utils/config.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import yaml
+import os
+from argparse import ArgumentParser, RawDescriptionHelpFormatter
+
+from utils.logging import get_logger
+
+
+def override(dl, ks, v):
+ """
+ Recursively replace dict of list
+
+ Args:
+ dl(dict or list): dict or list to be replaced
+ ks(list): list of keys
+ v(str): value to be replaced
+ """
+
+ def str2num(v):
+ try:
+ return eval(v)
+ except Exception:
+ return v
+
+ assert isinstance(dl, (list, dict)), ("{} should be a list or a dict".format(dl))
+ assert len(ks) > 0, ('length of keys should be larger than 0')
+ if isinstance(dl, list):
+ k = str2num(ks[0])
+ if len(ks) == 1:
+ assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
+ dl[k] = str2num(v)
+ else:
+ override(dl[k], ks[1:], v)
+ else:
+ if len(ks) == 1:
+ if ks[0] not in dl:
+ get_logger().warning('A new field ({}) detected!'.format(ks[0]))
+ dl[ks[0]] = str2num(v)
+ else:
+ assert ks[0] in dl, (
+ '({}) doesn\'t exist in {}, a new dict field is invalid'.
+ format(ks[0], dl))
+ override(dl[ks[0]], ks[1:], v)
+
+
+def override_config(config, options=None):
+ """
+ Recursively override the config
+
+ Args:
+ config(dict): dict to be replaced
+ options(list): list of pairs(key0.key1.idx.key2=value)
+ such as: [
+ 'topk=2',
+ 'VALID.transforms.1.ResizeImage.resize_short=300'
+ ]
+
+ Returns:
+ config(dict): replaced config
+ """
+ if options is not None:
+ for opt in options:
+ assert isinstance(opt, str), (
+ "option({}) should be a str".format(opt))
+ assert "=" in opt, (
+ "option({}) should contain a ="
+ "to distinguish between key and value".format(opt))
+ pair = opt.split('=')
+ assert len(pair) == 2, ("option({}) should contain exactly one =".format(opt))
+ key, value = pair
+ keys = key.split('.')
+ override(config, keys, value)
+
+ return config
+
+
+class ArgsParser(ArgumentParser):
+ def __init__(self):
+ super(ArgsParser, self).__init__(
+ formatter_class=RawDescriptionHelpFormatter)
+ self.add_argument("-c", "--config", help="configuration file to use")
+ self.add_argument(
+ "-t", "--tag", default="0", help="tag for marking worker")
+ self.add_argument(
+ '-o',
+ '--override',
+ action='append',
+ default=[],
+ help='config options to be overridden')
+ self.add_argument(
+ "--style_image", default="examples/style_images/1.jpg", help="path of the style image")
+ self.add_argument(
+ "--text_corpus", default="PaddleOCR", help="text to synthesize into the style image")
+ self.add_argument(
+ "--language", default="en", help="language of the text corpus; supports en, ch, ko")
+
+ def parse_args(self, argv=None):
+ args = super(ArgsParser, self).parse_args(argv)
+ assert args.config is not None, \
+ "Please specify --config=configure_file_path."
+ return args
+
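+# Example of the -o override mechanism (hypothetical value): keys are
+# dot-separated and routed through override_config/override above, e.g.
+#   python3 -m tools.synth_dataset -c configs/config.yml -o Global.output_dir=./synth_output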
+
+def load_config(file_path):
+ """
+ Load config from yml/yaml file.
+ Args:
+ file_path (str): Path of the config file to be loaded.
+ Returns: config
+ """
+ ext = os.path.splitext(file_path)[1]
+ assert ext in ['.yml', '.yaml'], "only support yaml files for now"
+ with open(file_path, 'rb') as f:
+ config = yaml.load(f, Loader=yaml.Loader)
+
+ return config
+
+
+def gen_config():
+ base_config = {
+ "Global": {
+ "algorithm": "SRNet",
+ "use_gpu": True,
+ "start_epoch": 1,
+ "stage1_epoch_num": 100,
+ "stage2_epoch_num": 100,
+ "log_smooth_window": 20,
+ "print_batch_step": 2,
+ "save_model_dir": "./output/SRNet",
+ "use_visualdl": False,
+ "save_epoch_step": 10,
+ "vgg_pretrain": "./pretrained/VGG19_pretrained",
+ "vgg_load_static_pretrain": True
+ },
+ "Architecture": {
+ "model_type": "data_aug",
+ "algorithm": "SRNet",
+ "net_g": {
+ "name": "srnet_net_g",
+ "encode_dim": 64,
+ "norm": "batch",
+ "use_dropout": False,
+ "init_type": "xavier",
+ "init_gain": 0.02,
+ "use_dilation": 1
+ },
+ # Discriminator defaults: input_nc, ndf, netD, n_layers_D=3, norm='instance',
+ # use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
+ "bg_discriminator": {
+ "name": "srnet_bg_discriminator",
+ "input_nc": 6,
+ "ndf": 64,
+ "netD": "basic",
+ "norm": "none",
+ "init_type": "xavier",
+ },
+ "fusion_discriminator": {
+ "name": "srnet_fusion_discriminator",
+ "input_nc": 6,
+ "ndf": 64,
+ "netD": "basic",
+ "norm": "none",
+ "init_type": "xavier",
+ }
+ },
+ "Loss": {
+ "lamb": 10,
+ "perceptual_lamb": 1,
+ "muvar_lamb": 50,
+ "style_lamb": 500
+ },
+ "Optimizer": {
+ "name": "Adam",
+ "learning_rate": {
+ "name": "lambda",
+ "lr": 0.0002,
+ "lr_decay_iters": 50
+ },
+ "beta1": 0.5,
+ "beta2": 0.999,
+ },
+ "Train": {
+ "batch_size_per_card": 8,
+ "num_workers_per_card": 4,
+ "dataset": {
+ "delimiter": "\t",
+ "data_dir": "/",
+ "label_file": "tmp/label.txt",
+ "transforms": [{
+ "DecodeImage": {
+ "to_rgb": True,
+ "to_np": False,
+ "channel_first": False
+ }
+ }, {
+ "NormalizeImage": {
+ "scale": 1. / 255.,
+ "mean": [0.485, 0.456, 0.406],
+ "std": [0.229, 0.224, 0.225],
+ "order": None
+ }
+ }, {
+ "ToCHWImage": None
+ }]
+ }
+ }
+ }
+ with open("config.yml", "w") as f:
+ yaml.dump(base_config, f)
+
+
+if __name__ == '__main__':
+ gen_config()
diff --git a/StyleTextRec/utils/load_params.py b/StyleTextRec/utils/load_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..be0561363eb21483d267ff6557c1d453d330c5f8
--- /dev/null
+++ b/StyleTextRec/utils/load_params.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import paddle
+
+__all__ = ['load_dygraph_pretrain']
+
+
+def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
+ if not os.path.exists(path + '.pdparams'):
+ raise ValueError("Model pretrain path {} does not "
+ "exists.".format(path))
+ param_state_dict = paddle.load(path + '.pdparams')
+ model.set_state_dict(param_state_dict)
+ logger.info("load pretrained model from {}".format(path))
+ return
diff --git a/StyleTextRec/utils/logging.py b/StyleTextRec/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..f700fe26bc9bfda21d39a0bddd89180f5de442ab
--- /dev/null
+++ b/StyleTextRec/utils/logging.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import logging
+import functools
+import paddle.distributed as dist
+
+logger_initialized = {}
+
+
+@functools.lru_cache()
+def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
+ """Initialize and get a logger by name.
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified a FileHandler will also be added.
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected, and other processes will set the level to
+ "Error" thus be silent most of the time.
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ formatter = logging.Formatter(
+ '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
+ datefmt="%Y/%m/%d %H:%M:%S")
+
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+ if log_file is not None and dist.get_rank() == 0:
+ log_file_folder = os.path.split(log_file)[0]
+ os.makedirs(log_file_folder, exist_ok=True)
+ file_handler = logging.FileHandler(log_file, 'a')
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+ if dist.get_rank() == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+ logger_initialized[name] = True
+ return logger
diff --git a/StyleTextRec/utils/math_functions.py b/StyleTextRec/utils/math_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc8d9160f8941f825d7aade79afc99035577bca
--- /dev/null
+++ b/StyleTextRec/utils/math_functions.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+
+
+def compute_mean_covariance(img):
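+ """Return the per-channel mean and channel covariance of an NCHW batch.
+
+ For each sample b: covariance[b] = X_b @ X_b.T / (H * W), where X_b is
+ the mean-centered (C, H*W) matrix of that sample.
+ """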
+ batch_size = img.shape[0]
+ channel_num = img.shape[1]
+ height = img.shape[2]
+ width = img.shape[3]
+ num_pixels = height * width
+
+ # batch_size * channel_num * 1 * 1
+ mu = img.mean(2, keepdim=True).mean(3, keepdim=True)
+
+ # batch_size * channel_num * num_pixels
+ img_hat = img - mu.expand_as(img)
+ img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
+ # batch_size * num_pixels * channel_num
+ img_hat_transpose = img_hat.transpose([0, 2, 1])
+ # batch_size * channel_num * channel_num
+ covariance = paddle.bmm(img_hat, img_hat_transpose)
+ covariance = covariance / num_pixels
+
+ return mu, covariance
+
+
+def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
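+ # Dice loss: 1 - 2*|X & Y| / (|X| + |Y|), computed only over pixels where
+ # training_mask is 1; eps guards against an empty union.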
+ eps = 1e-5
+ intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
+ union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
+ y_pred_cls * training_mask) + eps
+ loss = 1. - (2 * intersection / union)
+ return loss
diff --git a/StyleTextRec/utils/sys_funcs.py b/StyleTextRec/utils/sys_funcs.py
new file mode 100644
index 0000000000000000000000000000000000000000..203d91d83630e41fbe931a055e81e65cf0fb2e7d
--- /dev/null
+++ b/StyleTextRec/utils/sys_funcs.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import os
+import errno
+import paddle
+
+
+def get_check_global_params(mode):
+ check_params = [
+ 'use_gpu', 'max_text_length', 'image_shape',
+ 'character_type', 'loss_type'
+ ]
+ if mode == "train_eval":
+ check_params = check_params + [
+ 'train_batch_size_per_card', 'test_batch_size_per_card'
+ ]
+ elif mode == "test":
+ check_params = check_params + ['test_batch_size_per_card']
+ return check_params
+
+
+def check_gpu(use_gpu):
+ """
+ Log error and exit when set use_gpu=true in paddlepaddle
+ cpu version.
+ """
+ err = "Config use_gpu cannot be set as true while you are " \
+ "using paddlepaddle cpu version ! \nPlease try: \n" \
+ "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
+ "\t2. Set use_gpu as false in config file to run " \
+ "model on CPU"
+ if use_gpu:
+ try:
+ if not paddle.is_compiled_with_cuda():
+ print(err)
+ sys.exit(1)
+ except Exception:
+ print("Fail to check gpu state.")
+ sys.exit(1)
+
+
+def _mkdir_if_not_exist(path, logger):
+ """
+ mkdir if not exists, ignore the exception when multiprocess mkdir together
+ """
+ if not os.path.exists(path):
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno == errno.EEXIST and os.path.isdir(path):
+ logger.warning(
+ 'directory {} already created by another process'.format(
+ path))
+ else:
+ raise OSError('Failed to mkdir {}'.format(path))
diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml
index 7efbd5cf0d963229a94aa43558589b828d17cbd0..3de0ce7741cc8086b41cd1f5b98f6a8bbced90fa 100644
--- a/configs/rec/rec_icdar15_train.yml
+++ b/configs/rec/rec_icdar15_train.yml
@@ -36,12 +36,13 @@ Architecture:
algorithm: CRNN
Transform:
Backbone:
- name: ResNet
- layers: 34
+ name: MobileNetV3
+ scale: 0.5
+ model_name: large
Neck:
name: SequenceEncoder
encoder_type: rnn
- hidden_size: 256
+ hidden_size: 96
Head:
name: CTCHead
fc_decay: 0
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 83a0b3cdc3a59e92e3bca019870806fc5dc8852d..baf0ece8b91f05ca105aba019cb87964eedac0d3 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -9,7 +9,7 @@
### 1.文本检测算法
PaddleOCR开源的文本检测算法列表:
-- [x] DB([paper](https://arxiv.org/abs/1911.08947))(ppocr推荐)
+- [x] DB([paper]( https://arxiv.org/abs/1911.08947) )(ppocr推荐)
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] SAST([paper](https://arxiv.org/abs/1908.05498))
@@ -38,9 +38,9 @@ PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训
### 2.文本识别算法
PaddleOCR基于动态图开源的文本识别算法列表:
-- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))(ppocr推荐)
+- [x] CRNN([paper](https://arxiv.org/abs/1507.05717) )(ppocr推荐)
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
-- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
+- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294)) coming soon
diff --git a/doc/doc_ch/angle_class.md b/doc/doc_ch/angle_class.md
index d6a36b86b476f15b7b34f67e888ceb781b2ed7a0..3f2027b9ddff331b3259ed62c7c7b43e686efcce 100644
--- a/doc/doc_ch/angle_class.md
+++ b/doc/doc_ch/angle_class.md
@@ -62,9 +62,9 @@ PaddleOCR提供了训练脚本、评估脚本和预测脚本。
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
```
-# GPU训练 支持单卡,多卡训练,通过selected_gpus指定卡号
+# GPU训练 支持单卡,多卡训练,通过 '--gpus' 指定卡号,如果使用的paddle版本小于2.0rc1,请使用'--selected_gpus'参数选择要使用的GPU
# 启动训练,下面的命令已经写入train.sh文件中,只需修改文件里的配置文件路径即可
-python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
+python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
```
- 数据增强
@@ -74,7 +74,7 @@ PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入
默认的扰动方式有:颜色空间转换(cvtColor)、模糊(blur)、抖动(jitter)、噪声(Gasuss noise)、随机切割(random crop)、透视(perspective)、颜色反转(reverse),随机数据增强(RandAugment)。
训练过程中除随机数据增强外每种扰动方式以50%的概率被选择,具体代码实现请参考:
-[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
+[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
[randaugment.py](../../ppocr/data/imaug/randaugment.py)
*由于OpenCV的兼容性问题,扰动操作暂时只支持linux*
diff --git a/doc/doc_ch/detection.md b/doc/doc_ch/detection.md
index ec3cb2766071d4c1ff6927de3e79c3e3c0c51131..08b94a9c838cb265a1e6145e29db676bf52c7de7 100644
--- a/doc/doc_ch/detection.md
+++ b/doc/doc_ch/detection.md
@@ -107,17 +107,13 @@ PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall
运行如下代码,根据配置文件`det_db_mv3.yml`中`save_res_path`指定的测试集检测结果文件,计算评估指标。
-评估时设置后处理参数`box_thresh=0.6`,`unclip_ratio=1.5`,使用不同数据集、不同模型训练,可调整这两个参数进行优化
-```shell
-python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
-```
+评估时设置后处理参数`box_thresh=0.5`,`unclip_ratio=1.5`,使用不同数据集、不同模型训练,可调整这两个参数进行优化
训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。
-
-比如:
```shell
-python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
+python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.5 PostProcess.unclip_ratio=1.5
```
+
* 注:`box_thresh`、`unclip_ratio`是DB后处理所需要的参数,在评估EAST模型时不需要设置
## 测试检测效果
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index 10c01666404c1a66a14485ed30195954ed881b6f..aea7ff1de242dec75cae26a2bf3d6838d7559882 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -22,9 +22,8 @@ inference 模型(`paddle.jit.save`保存的模型)
- [三、文本识别模型推理](#文本识别模型推理)
- [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理)
- [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理)
- - [3. 基于Attention损失的识别模型推理](#基于Attention损失的识别模型推理)
- - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
- - [5. 多语言模型的推理](#多语言模型的推理)
+ - [3. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
+ - [4. 多语言模型的推理](#多语言模型的推理)
- [四、方向分类模型推理](#方向识别模型推理)
- [1. 方向分类模型推理](#方向分类模型推理)
@@ -129,24 +128,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
超轻量中文检测模型推理,可以执行如下命令:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"
+# 下载超轻量中文检测模型:
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./ch_ppocr_mobile_v2.0_det_infer/"
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
-![](../imgs_results/det_res_2.jpg)
+![](../imgs_results/det_res_22.jpg)
-通过参数`limit_type`和`det_limit_side_len`来对图片的尺寸进行限制限,`limit_type=max`为限制长边长度<`det_limit_side_len`,`limit_type=min`为限制短边长度>`det_limit_side_len`,
-图片不满足限制条件时(`limit_type=max`时长边长度>`det_limit_side_len`或`limit_type=min`时短边长度<`det_limit_side_len`),将对图片进行等比例缩放。
-该参数默认设置为`limit_type='max',det_max_side_len=960`。 如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以执行如下命令:
+通过参数`limit_type`和`det_limit_side_len`来对图片的尺寸进行限制,
+`limit_type`可选参数为[`max`, `min`],
+`det_limit_side_len` 为正整数,一般设置为32 的倍数,比如960。
+参数默认设置为`limit_type='max', det_limit_side_len=960`。表示网络输入图像的最长边不能超过960,
+如果超过这个值,会对图像做等宽比的resize操作,确保最长边为`det_limit_side_len`。
+设置为`limit_type='min', det_limit_side_len=960` 则表示限制图像的最短边为960。
+
+如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1200
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
```
如果想使用CPU进行预测,执行命令如下
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
```
@@ -268,16 +275,6 @@ CRNN 文本识别模型推理,可以执行如下命令:
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rec_crnn/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
```
-
-### 3. 基于Attention损失的识别模型推理
-
-基于Attention损失的识别模型与ctc不同,需要额外设置识别算法参数 --rec_algorithm="RARE"
-RARE 文本识别模型推理,可以执行如下命令:
-```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
-
-```
-
![](../imgs_words_en/word_336.png)
执行命令后,上面图像的识别结果如下:
@@ -297,7 +294,7 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
```
-### 4. 自定义文本识别字典的推理
+### 3. 自定义文本识别字典的推理
如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径,并且设置 `rec_char_type=ch`
```
@@ -305,7 +302,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
```
-### 5. 多语言模型的推理
+### 4. 多语言模型的推理
如果您需要预测的是其他语言模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,
需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/` 路径下有默认提供的小语种字体,例如韩文识别:
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index d67179e026f6dc5b0f2baaea482f6b8cee337dc5..dc06365c6ef66fe5539887a19042dfdbfb45efa3 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -167,7 +167,7 @@ tar -xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf rec_mv3_none_bilstm_ctc
```
# GPU训练 支持单卡,多卡训练,通过--gpus参数指定卡号
-# 训练icdar15英文数据 并将训练日志保存为 tain_rec.log
+# 训练icdar15英文数据 训练日志会自动保存为 "{save_model_dir}" 下的train.log
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
```
@@ -200,11 +200,8 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
-| rec_mv3_tps_bilstm_ctc.yml | STARNet | Mobilenet_v3 large 0.5 | tps | BiLSTM | ctc |
-| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
-| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
@@ -356,8 +353,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkp
```
infer_img: doc/imgs_words/en/word_1.png
- index: [19 24 18 23 29]
- word : joint
+ result: ('joint', 0.9998967)
```
预测使用的配置文件必须与训练一致,如您通过 `python3 tools/train.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml` 完成了中文模型的训练,
@@ -376,6 +372,5 @@ python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v
```
infer_img: doc/imgs_words/ch/word_1.jpg
- index: [2092 177 312 2503]
- word : 韩国小馆
+ result: ('韩国小馆', 0.997218)
```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 46438d5d59db3edbc296ce25d1b8b06c3d79265a..c4c522ced1cdcd187e27b5099aecd074975e90d3 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -13,7 +13,7 @@ This tutorial lists the text detection algorithms and text recognition algorithm
PaddleOCR open source text detection algorithms list:
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] DB([paper](https://arxiv.org/abs/1911.08947))
-- [x] SAST([paper](https://arxiv.org/abs/1908.05498))(Baidu Self-Research)
+- [x] SAST([paper](https://arxiv.org/abs/1908.05498) )(Baidu Self-Research)
On the ICDAR2015 dataset, the text detection result is as follows:
@@ -41,9 +41,9 @@ For the training guide and use of PaddleOCR text detection algorithms, please re
PaddleOCR open-source text recognition algorithms list:
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
-- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
+- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))(Baidu Self-Research) coming soon
+- [ ] SRN([paper](https://arxiv.org/abs/2003.12294) )(Baidu Self-Research) coming soon
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
diff --git a/doc/doc_en/angle_class_en.md b/doc/doc_en/angle_class_en.md
index defdff3ccbbad9d0201305529073bdc80abd5d29..4c479e7b22e7caea6bf5f864d32b57197b925dd9 100644
--- a/doc/doc_en/angle_class_en.md
+++ b/doc/doc_en/angle_class_en.md
@@ -65,9 +65,9 @@ Start training:
```
# Set PYTHONPATH path
export PYTHONPATH=$PYTHONPATH:.
-# GPU training Support single card and multi-card training, specify the card number through selected_gpus
+# GPU training supports single-card and multi-card training; specify the card numbers through --gpus. If your paddle version is older than 2.0rc1, please use '--selected_gpus'
# Start training, the following command has been written into the train.sh file, just modify the configuration file path in the file
-python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
+python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
```
- Data Augmentation
@@ -77,7 +77,7 @@ PaddleOCR provides a variety of data augmentation methods. If you want to add di
The default perturbation methods are: cvtColor, blur, jitter, Gasuss noise, random crop, perspective, color reverse, RandAugment.
Except for RandAugment, each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to:
-[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
+[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
[randaugment.py](../../ppocr/data/imaug/randaugment.py)
diff --git a/doc/doc_en/benchmark_en.md b/doc/doc_en/benchmark_en.md
old mode 100644
new mode 100755
diff --git a/doc/doc_en/detection_en.md b/doc/doc_en/detection_en.md
index 6a2bda6b497df5d9a6ccb914976f53b2e27ce9b0..7638315ae9991c909d7079c904d646a656173dca 100644
--- a/doc/doc_en/detection_en.md
+++ b/doc/doc_en/detection_en.md
@@ -101,15 +101,11 @@ Run the following code to calculate the evaluation indicators. The result will b
When evaluating, set post-processing parameters `box_thresh=0.6`, `unclip_ratio=1.5`. If you use different datasets, different models for training, these two parameters should be adjusted for better result.
+The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating indicators, you need to set `Global.checkpoints` to point to the saved parameter file.
```shell
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
-The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating indicators, you need to set `Global.checkpoints` to point to the saved parameter file.
-Such as:
-```shell
-python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
-```
* Note: `box_thresh` and `unclip_ratio` are parameters required for DB post-processing, and not need to be set when evaluating the EAST model.
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index 606565275ba243969bf919bfb91a0a2067a7a8cd..db86b109d1a13d00aab833aa31d0279622e7c7f8 100755
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -25,9 +25,8 @@ Next, we first introduce how to convert a trained model into an inference model,
- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
- [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
- - [3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE](#ATTENTION-BASED_RECOGNITION)
- - [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
- - [5. MULTILINGUAL MODEL INFERENCE](MULTILINGUAL_MODEL_INFERENCE)
+ - [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
+ - [4. MULTILINGUAL MODEL INFERENCE](#MULTILINGUAL_MODEL_INFERENCE)
- [ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
- [1. ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
@@ -135,24 +134,33 @@ Because EAST and DB algorithms are very different, when inference, it is necessa
For lightweight Chinese detection model inference, you can execute the following commands:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"
+# download DB text detection inference model
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+# predict
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/"
```
The visual text detection results are saved to the ./inference_results folder by default, and the name of the result file is prefixed with'det_res'. Examples of results are as follows:
-![](../imgs_results/det_res_2.jpg)
+![](../imgs_results/det_res_22.jpg)
-The size of the image is limited by the parameters `limit_type` and `det_limit_side_len`, `limit_type=max` is to limit the length of the long side <`det_limit_side_len`, and `limit_type=min` is to limit the length of the short side>`det_limit_side_len`,
-When the picture does not meet the restriction conditions (for `limit_type=max`and long side >`det_limit_side_len` or for `min` and short side <`det_limit_side_len`), the image will be scaled proportionally.
-This parameter is set to `limit_type='max', det_max_side_len=960` by default. If the resolution of the input picture is relatively large, and you want to use a larger resolution prediction, you can execute the following command:
+You can use the parameters `limit_type` and `det_limit_side_len` to limit the size of the input image.
+The optional values of `limit_type` are [`max`, `min`], and
+`det_limit_side_len` is a positive integer, generally set to a multiple of 32, such as 960.
+The default setting is `limit_type='max', det_limit_side_len=960`, which means that the longest side of the network input image cannot exceed 960;
+if it does, the image is resized proportionally so that the longest side equals `det_limit_side_len`.
+Setting `limit_type='min', det_limit_side_len=960` instead limits the shortest side of the image to at least 960.
+
+If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1200
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
```
If you want to use the CPU for prediction, execute the command as follows
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
```
@@ -275,15 +283,6 @@ For CRNN text recognition model inference, execute the following commands:
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
```
-
-### 3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE
-
-The recognition model based on Attention loss is different from ctc, and additional recognition algorithm parameters need to be set --rec_algorithm="RARE"
-After executing the command, the recognition result of the above image is as follows:
-```bash
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
-```
-
![](../imgs_words_en/word_336.png)
After executing the command, the recognition result of the above image is as follows:
@@ -303,7 +302,7 @@ dict_character = list(self.character_str)
```
-### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
+### 3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
```
@@ -311,7 +310,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
```
-### 5. MULTILINGAUL MODEL INFERENCE
+### 4. MULTILINGUAL MODEL INFERENCE
If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/` path, such as Korean recognition:
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index 1539b288da2518bf5441adea7983135f3c46619f..bc8faa0fc3df936855ead965f1e22107b576bc7a 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -162,7 +162,7 @@ Start training:
```
# GPU training Support single card and multi-card training, specify the card number through --gpus
-# Training icdar15 English data and saving the log as train_rec.log
+# Training icdar15 English data. The training log will be automatically saved as train.log under "{save_model_dir}"
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
```
@@ -193,11 +193,8 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
-| rec_mv3_tps_bilstm_ctc.yml | STARNet | Mobilenet_v3 large 0.5 | tps | BiLSTM | ctc |
-| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
-| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
For training Chinese data, it is recommended to use
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
@@ -350,8 +347,7 @@ Get the prediction result of the input image:
```
infer_img: doc/imgs_words/en/word_1.png
- index: [19 24 18 23 29]
- word : joint
+ result: ('joint', 0.9998967)
```
The configuration file used for prediction must be consistent with the training. For example, you completed the training of the Chinese model with `python3 tools/train.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml`, you can use the following command to predict the Chinese model:
@@ -369,6 +365,5 @@ Get the prediction result of the input image:
```
infer_img: doc/imgs_words/ch/word_1.jpg
- index: [2092 177 312 2503]
- word : 韩国小馆
+ result: ('韩国小馆', 0.997218)
```
diff --git a/doc/imgs_results/det_res_22.jpg b/doc/imgs_results/det_res_22.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d1255f49d9d371d4b91d98c6750a10a01f56b629
Binary files /dev/null and b/doc/imgs_results/det_res_22.jpg differ
diff --git a/doc/imgs_words_en/.DS_Store b/doc/imgs_words_en/.DS_Store
deleted file mode 100644
index 5008ddfcf53c02e82d7eee2e57c38e5672ef89f6..0000000000000000000000000000000000000000
Binary files a/doc/imgs_words_en/.DS_Store and /dev/null differ
diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py
index 74bec7416bb1fd970ad00aecfdafc4173827a145..86665bedfff726c174e676cb544000a37ada0dad 100644
--- a/ppocr/modeling/transforms/tps.py
+++ b/ppocr/modeling/transforms/tps.py
@@ -180,7 +180,6 @@ class GridGenerator(nn.Layer):
P = self.build_P_paddle(I_r_size)
inv_delta_C_tensor = self.build_inv_delta_C_paddle(C).astype('float32')
- # inv_delta_C_tensor = paddle.zeros((23,23)).astype('float32')
P_hat_tensor = self.build_P_hat_paddle(
C, paddle.to_tensor(P)).astype('float32')