### Quick Start
`Style-Text` is an improvement on SRNet, the network proposed in Baidu's self-developed text editing algorithm "Editing Text in the Wild". Unlike commonly used GAN-based methods, this tool decomposes the text synthesis task into three sub-modules to improve the quality of the synthetic data: a text style transfer module, a background extraction module and a fusion module.
The following figure shows some example results. In addition, the actual `nameplate text recognition` and `Korean text recognition` scenarios verify the effectiveness of the synthesis tool, as shown below.
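For orientation, the data flow through these three sub-modules is sketched below. This is a minimal illustration that mirrors `StyleTextRec.forward` in `arch/style_text_rec.py` later in this repository; the generator calls stand in for the pretrained sub-networks loaded from `configs/config.yml`.
```python
# Minimal sketch of the three-stage pipeline (mirrors StyleTextRec.forward
# in arch/style_text_rec.py); inputs are preprocessed image tensors.
text_out = text_generator(style_input, text_input)   # text style transfer
bg_out = bg_generator(style_input)                   # background extraction
fusion = fusion_generator(text_out["fake_text"], bg_out["fake_bg"])  # fusion
fake_fusion = fusion["fake_fusion"]                  # final synthetic image
```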
#### Preparation
1. Please refer to the [QUICK INSTALLATION](../doc/doc_en/installation_en.md) to install PaddlePaddle. A Python 3 environment is strongly recommended.
2. Download the pretrained models and unzip them:
```bash
cd tools/style_text_rec
wget /path/to/style_text_models.zip
unzip style_text_models.zip
```
You can download the models [here](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip). If you save the model files in another folder, please edit the three model paths in `configs/config.yml`:
```
bg_generator:
  pretrain: style_text_models/bg_generator
...
text_generator:
pretrain: style_text_models/text_generator
...
fusion_generator:
pretrain: style_text_models/fusion_generator
```
#### Demo
1. You can use the following command to run a demo:
```bash
python -m tools.synth_image -c configs/config.yml
```
2. The results are `fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`, as shown in the figure above. Among them:
   * `fake_text.jpg` is the generated image with the same font style as `Style Input`;
   * `fake_bg.jpg` is the generated image of `Style Input` with the foreground text removed;
   * `fake_fusion.jpg` is the final result, synthesized from `fake_text.jpg` and `fake_bg.jpg`.
3. If you want to generate images with a different `Style Input` or `Text Input`, you can modify `tools/synth_image.py` (see the sketch after this list):
   * `img = cv2.imread("examples/style_images/1.jpg")`: the path of `Style Input`;
   * `corpus = "PaddleOCR"`: the `Text Input`;
   * Note: modify the language option (`language = "en"`) to match the `Text Input`; `en`, `ch` and `ko` are supported.
4. We also provide a `batch_synth_images` method, which combines corpora and pictures in pairs to generate a batch of data.
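A minimal sketch of the step-3 edits is given below. The image path and corpus string are placeholder values; note that the `tools/synth_image.py` shipped in this repository also accepts the same options as the `--style_image`, `--text_corpus` and `--language` command-line arguments.
```python
# Sketch of the in-file customization described in step 3
# (the path and corpus below are placeholders).
img = cv2.imread("examples/style_images/1.jpg")  # Style Input
corpus = "PaddleOCR"                             # Text Input
language = "en"                                  # must match the corpus: en / ch / ko
```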
### Advanced Usage
#### Components
`Style Text Rec` mainly contains the following components:
* `style_samplers`: It can sample `Style Input` from a dataset. Currently, we only provide `DatasetSampler`.
* `corpus_generators`: It can generate corpus text. Currently, we only provide two `corpus_generators`:
* `EnNumCorpus`: It can generate a random string according to a given length, including uppercase and lowercase English letters, numbers and spaces.
* `FileCorpus`: It can read a text file and randomly return the words in it.
* `text_drawers`: It can generate `Text Input` (a text image rendered in a standard font from the input corpus). Note that the language information has to be set according to the corpus.
* `predictors`: It can call the deep learning model to generate new data based on the `Style Input` and `Text Input`.
* `writers`: It can write the generated pictures(`fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`) and label information to the disk.
* `synthesisers`: It calls all of the above modules to complete the synthesis work (a composition sketch follows this list).
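The sketch below shows one way these components fit together. It is a simplified version of what `engine/synthesisers.py` does, assuming `config` is the parsed `configs/dataset_config.yml`.
```python
# Simplified composition of the components (cf. engine/synthesisers.py);
# `config` is assumed to be the parsed configs/dataset_config.yml.
sampler = style_samplers.DatasetSampler(config)
language, corpus = corpus_generators.FileCorpus(config).generate()
corpus, text_input = text_drawers.StdTextDrawer(config).draw_text(corpus, language)
style_input = sampler.sample()["image"]
predictor = predictors.StyleTextRecPredictor(config)
result = predictor.predict(style_input, text_input)
writers.SimpleWriter(config, tag="0").save_image(result["fake_fusion"], corpus)
```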
### Generate Dataset
Before starting, you need to prepare some data as material.
First, you should have the style reference data for the synthesis task, which is generally a dataset of the kind used for OCR recognition tasks.
1. The reference dataset can be specified in `configs/dataset_config.yml`:
   * `StyleSampler`:
     * `method`: The method of `StyleSampler`.
     * `image_home`: The directory of the style images.
     * `label_file`: The list of picture paths if `with_label` is `false`; otherwise, the path of the label file, where each line is `relative/image/path<TAB>label`.
     * `with_label`: Whether `label_file` is a label file.
   * `CorpusGenerator`:
     * `method`: The method of `CorpusGenerator`. If `FileCorpus` is used, you need to set `corpus_file` and `language` accordingly; if `EnNumCorpus` is used, no other configuration is needed.
     * `language`: The language of the corpus. Needed if the method is not `EnNumCorpus`.
     * `corpus_file`: The corpus file path. Needed if the method is not `EnNumCorpus`.
2. You can run the following command to start the synthesis task:
```bash
python -m tools.synth_dataset -c configs/dataset_config.yml
```
3. You can run the following commands to start multiple synthesis tasks in parallel, each of which needs a distinct tag specified by `-t` (the resulting output layout is sketched after the commands):
```bash
python -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
python -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
```
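Each process writes its images and labels under its own tag, and the per-tag label files are merged into a single `label.txt` at the end. Based on `engine/writers.py`, the output directory (`output_dir: output_data` by default) looks roughly like this:
```
output_data
|-- images
|   |-- 0              # one sub-directory per tag (-t)
|   |   |-- 0.png
|   |   `-- 1.png
|   `-- 1
|-- label
|   |-- 0_label.txt    # per-tag labels: "<tag>/<n>.png\t<text>"
|   `-- 1_label.txt
`-- label.txt          # merged label file
```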
### OCR Recognition Training
After completing the above operations, you can get the synthetic dataset for OCR recognition. Next, please complete the training by referring to the [OCR Recognition Document](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83).
## Style Text Rec
### Contents
- [Tool Introduction](#tool-introduction)
- [Environment Setup](#environment-setup)
- [Quick Start](#quick-start)
- [Advanced Usage](#advanced-usage)
- [Application Examples](#application-examples)
### Tool Introduction
<div align="center">
<img src="doc/images/3.png" width="800">
</div>
<div align="center">
<img src="doc/images/1.png" width="600">
</div>
The Style-Text data synthesis tool is based on Baidu's self-developed text editing algorithm "Editing Text in the Wild" (https://arxiv.org/abs/1908.03047).
Unlike common GAN-based data synthesis tools, the main framework of Style-Text consists of: 1. a text foreground style transfer module; 2. a background extraction module; 3. a fusion module. With these three steps, the text style of an image can be transferred quickly. The figure below shows some results of this data synthesis tool.
<div align="center">
<img src="doc/images/2.png" width="1000">
</div>
### Environment Setup
1. Refer to [Quick Installation](../doc/doc_ch/installation.md) to install PaddleOCR.
2. Enter the `style_text_rec` directory and download the models, then unzip them:
```bash
cd style_text_rec
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
unzip style_text_models.zip
```
If you save the models in another location, modify the model paths in `configs/config.yml`; all three of the following entries need to be changed:
```
bg_generator:
pretrain: style_text_models/bg_generator
...
text_generator:
pretrain: style_text_models/text_generator
...
fusion_generator:
pretrain: style_text_models/fusion_generator
```
### Quick Start
1. Run tools/synth_image to generate an example image:
```bash
python3 -m tools.synth_image -c configs/config.yml
```
After running, `fake_fusion.jpg` will be generated, which is the final result.
<div align="center">
<img src="doc/images/4.jpg" width="300">
</div>
In addition, the program also generates and saves intermediate results:
* `fake_bg.jpg`: the background of the style reference image with the text removed;
* `fake_text.jpg`: the provided text rendered on a gray background, imitating the text style of the style reference image.
2. If you want to try other style images and texts, you can add the `style_image`, `text_corpus` and `language` parameters:
```bash
python3 -m tools.synth_image -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
```
* Note: the language option must correspond to the corpus; currently, English (`en`), Simplified Chinese (`ch`) and Korean (`ko`) are supported.
3. In `tools/synth_image.py`, we also provide a `batch_synth_images` method, which combines corpora and images in pairs to generate a batch of data.
### Advanced Usage
Before you start synthesizing a dataset, you need to prepare some materials.
First, you need style images as references for synthesis. These can be datasets normally used for training OCR recognition models; in this example, a dataset with a label file is used as the style images.
1. Configure the input data paths in `configs/dataset_config.yml`:
   * `StyleSampler`:
     * `method`: the style image sampling method to use;
     * `image_home`: the directory of the style images;
     * `label_file`: a file listing the style image paths; if the dataset has labels, `label_file` is the path of the label file;
     * `with_label`: whether `label_file` is a label file.
   * `CorpusGenerator`:
     * `method`: the corpus generation method; currently `FileCorpus` and `EnNumCorpus` are available. If `EnNumCorpus` is used, no other configuration is needed; otherwise `corpus_file` and `language` need to be set;
     * `language`: the language of the corpus;
     * `corpus_file`: the corpus file path.
We provide a batch of [sample images](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar) for you to try. Some examples are shown below:
<div align="center">
<img src="doc/images/5.png" width="800">
</div>
2. Run `tools/synth_dataset` to synthesize data:
```bash
python -m tools.synth_dataset -c configs/dataset_config.yml
```
3. If you want to synthesize data faster in parallel, you can launch multiple processes; each launch must specify a different `tag` (`-t`), as shown below:
```bash
python3 -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
python3 -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
```
### Application Examples
After completing the above steps, you will obtain a synthetic dataset for OCR recognition. Some examples of the generated data are shown below:
<div align="center">
<img src="doc/images/6.png" width="800">
</div>
Please refer to the [OCR recognition documentation](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83) to complete the training.
Some results of training with the synthetic data are shown below:
| Scenario | Characters | Original Data | Test Data | Accuracy (original data only) | Synthetic Data Added | Accuracy (with synthetic data) | Improvement |
| -------- | ---------- | ------------- | --------- | ----------------------------- | -------------------- | ------------------------------ | ----------- |
| Metal surface | English and digits | 2203 | 650 | 0.5938 | 20000 | 0.7546 | 16% |
| Random background | Korean | 5631 | 1230 | 0.3012 | 100000 | 0.5057 | 20% |
### Project Structure
```
style_text_rec
|-- arch
| |-- base_module.py
| |-- decoder.py
| |-- encoder.py
| |-- spectral_norm.py
| `-- style_text_rec.py
|-- configs
| |-- config.yml
| `-- dataset_config.yml
|-- engine
| |-- corpus_generators.py
| |-- predictors.py
| |-- style_samplers.py
| |-- synthesisers.py
| |-- text_drawers.py
| `-- writers.py
|-- examples
| |-- corpus
| | `-- example.txt
| |-- image_list.txt
| `-- style_images
| |-- 1.jpg
| `-- 2.jpg
|-- fonts
| |-- ch_standard.ttf
| |-- en_standard.ttf
| `-- ko_standard.ttf
|-- tools
| |-- __init__.py
| |-- synth_dataset.py
| `-- synth_image.py
`-- utils
|-- config.py
|-- load_params.py
|-- logging.py
|-- math_functions.py
`-- sys_funcs.py
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from arch.spectral_norm import spectral_norm
class CBN(nn.Layer):
def __init__(self,
name,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
use_bias=False,
norm_layer=None,
act=None,
act_attr=None):
super(CBN, self).__init__()
if use_bias:
bias_attr = paddle.ParamAttr(name=name + "_bias")
else:
bias_attr = None
self._conv = paddle.nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
weight_attr=paddle.ParamAttr(name=name + "_weights"),
bias_attr=bias_attr)
if norm_layer:
self._norm_layer = getattr(paddle.nn, norm_layer)(
num_features=out_channels, name=name + "_bn")
else:
self._norm_layer = None
if act:
if act_attr:
self._act = getattr(paddle.nn, act)(**act_attr,
name=name + "_" + act)
else:
self._act = getattr(paddle.nn, act)(name=name + "_" + act)
else:
self._act = None
def forward(self, x):
out = self._conv(x)
if self._norm_layer:
out = self._norm_layer(out)
if self._act:
out = self._act(out)
return out
class SNConv(nn.Layer):
def __init__(self,
name,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
use_bias=False,
norm_layer=None,
act=None,
act_attr=None):
super(SNConv, self).__init__()
if use_bias:
bias_attr = paddle.ParamAttr(name=name + "_bias")
else:
bias_attr = None
self._sn_conv = spectral_norm(
paddle.nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
weight_attr=paddle.ParamAttr(name=name + "_weights"),
bias_attr=bias_attr))
if norm_layer:
self._norm_layer = getattr(paddle.nn, norm_layer)(
num_features=out_channels, name=name + "_bn")
else:
self._norm_layer = None
if act:
if act_attr:
self._act = getattr(paddle.nn, act)(**act_attr,
name=name + "_" + act)
else:
self._act = getattr(paddle.nn, act)(name=name + "_" + act)
else:
self._act = None
def forward(self, x):
out = self._sn_conv(x)
if self._norm_layer:
out = self._norm_layer(out)
if self._act:
out = self._act(out)
return out
class SNConvTranspose(nn.Layer):
def __init__(self,
name,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
dilation=1,
groups=1,
use_bias=False,
norm_layer=None,
act=None,
act_attr=None):
super(SNConvTranspose, self).__init__()
if use_bias:
bias_attr = paddle.ParamAttr(name=name + "_bias")
else:
bias_attr = None
self._sn_conv_transpose = spectral_norm(
paddle.nn.Conv2DTranspose(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
dilation=dilation,
groups=groups,
weight_attr=paddle.ParamAttr(name=name + "_weights"),
bias_attr=bias_attr))
if norm_layer:
self._norm_layer = getattr(paddle.nn, norm_layer)(
num_features=out_channels, name=name + "_bn")
else:
self._norm_layer = None
if act:
if act_attr:
self._act = getattr(paddle.nn, act)(**act_attr,
name=name + "_" + act)
else:
self._act = getattr(paddle.nn, act)(name=name + "_" + act)
else:
self._act = None
def forward(self, x):
out = self._sn_conv_transpose(x)
if self._norm_layer:
out = self._norm_layer(out)
if self._act:
out = self._act(out)
return out
class MiddleNet(nn.Layer):
def __init__(self, name, in_channels, mid_channels, out_channels,
use_bias):
super(MiddleNet, self).__init__()
self._sn_conv1 = SNConv(
name=name + "_sn_conv1",
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
use_bias=use_bias,
norm_layer=None,
act=None)
self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
self._sn_conv2 = SNConv(
name=name + "_sn_conv2",
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
use_bias=use_bias)
self._sn_conv3 = SNConv(
name=name + "_sn_conv3",
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
use_bias=use_bias)
def forward(self, x):
sn_conv1 = self._sn_conv1.forward(x)
pad_2d = self._pad2d.forward(sn_conv1)
sn_conv2 = self._sn_conv2.forward(pad_2d)
sn_conv3 = self._sn_conv3.forward(sn_conv2)
return sn_conv3
class ResBlock(nn.Layer):
def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
use_bias):
super(ResBlock, self).__init__()
if use_dilation:
padding_mat = [1, 1, 1, 1]
else:
padding_mat = [0, 0, 0, 0]
self._pad1 = nn.Pad2D(padding_mat, mode="replicate")
self._sn_conv1 = SNConv(
name=name + "_sn_conv1",
in_channels=channels,
out_channels=channels,
kernel_size=3,
padding=0,
norm_layer=norm_layer,
use_bias=use_bias,
act="ReLU",
act_attr=None)
if use_dropout:
self._dropout = nn.Dropout(0.5)
else:
self._dropout = None
self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
self._sn_conv2 = SNConv(
name=name + "_sn_conv2",
in_channels=channels,
out_channels=channels,
kernel_size=3,
norm_layer=norm_layer,
use_bias=use_bias,
act="ReLU",
act_attr=None)
def forward(self, x):
pad1 = self._pad1.forward(x)
sn_conv1 = self._sn_conv1.forward(pad1)
pad2 = self._pad2.forward(sn_conv1)
sn_conv2 = self._sn_conv2.forward(pad2)
return sn_conv2 + x
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from arch.base_module import SNConv, SNConvTranspose, ResBlock
class Decoder(nn.Layer):
def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
act, act_attr, conv_block_dropout, conv_block_num,
conv_block_dilation, out_conv_act, out_conv_act_attr):
super(Decoder, self).__init__()
conv_blocks = []
for i in range(conv_block_num):
conv_blocks.append(
ResBlock(
name="{}_conv_block_{}".format(name, i),
channels=encode_dim * 8,
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
use_bias=use_bias))
self.conv_blocks = nn.Sequential(*conv_blocks)
self._up1 = SNConvTranspose(
name=name + "_up1",
in_channels=encode_dim * 8,
out_channels=encode_dim * 4,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._up2 = SNConvTranspose(
name=name + "_up2",
in_channels=encode_dim * 4,
out_channels=encode_dim * 2,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._up3 = SNConvTranspose(
name=name + "_up3",
in_channels=encode_dim * 2,
out_channels=encode_dim,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
self._out_conv = SNConv(
name=name + "_out_conv",
in_channels=encode_dim,
out_channels=out_channels,
kernel_size=3,
use_bias=use_bias,
norm_layer=None,
act=out_conv_act,
act_attr=out_conv_act_attr)
def forward(self, x):
if isinstance(x, (list, tuple)):
x = paddle.concat(x, axis=1)
output_dict = dict()
output_dict["conv_blocks"] = self.conv_blocks.forward(x)
output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
output_dict["up2"] = self._up2.forward(output_dict["up1"])
output_dict["up3"] = self._up3.forward(output_dict["up2"])
output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
return output_dict
class DecoderUnet(nn.Layer):
def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
act, act_attr, conv_block_dropout, conv_block_num,
conv_block_dilation, out_conv_act, out_conv_act_attr):
super(DecoderUnet, self).__init__()
conv_blocks = []
for i in range(conv_block_num):
conv_blocks.append(
ResBlock(
name="{}_conv_block_{}".format(name, i),
channels=encode_dim * 8,
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
use_bias=use_bias))
self._conv_blocks = nn.Sequential(*conv_blocks)
self._up1 = SNConvTranspose(
name=name + "_up1",
in_channels=encode_dim * 8,
out_channels=encode_dim * 4,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._up2 = SNConvTranspose(
name=name + "_up2",
in_channels=encode_dim * 8,
out_channels=encode_dim * 2,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._up3 = SNConvTranspose(
name=name + "_up3",
in_channels=encode_dim * 4,
out_channels=encode_dim,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
self._out_conv = SNConv(
name=name + "_out_conv",
in_channels=encode_dim,
out_channels=out_channels,
kernel_size=3,
use_bias=use_bias,
norm_layer=None,
act=out_conv_act,
act_attr=out_conv_act_attr)
def forward(self, x, y, feature2, feature1):
output_dict = dict()
output_dict["conv_blocks"] = self._conv_blocks(
paddle.concat(
(x, y), axis=1))
output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
output_dict["up2"] = self._up2.forward(
paddle.concat(
(output_dict["up1"], feature2), axis=1))
output_dict["up3"] = self._up3.forward(
paddle.concat(
(output_dict["up2"], feature1), axis=1))
output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
return output_dict
class SingleDecoder(nn.Layer):
def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
act, act_attr, conv_block_dropout, conv_block_num,
conv_block_dilation, out_conv_act, out_conv_act_attr):
super(SingleDecoder, self).__init__()
conv_blocks = []
for i in range(conv_block_num):
conv_blocks.append(
ResBlock(
name="{}_conv_block_{}".format(name, i),
channels=encode_dim * 4,
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
use_bias=use_bias))
self._conv_blocks = nn.Sequential(*conv_blocks)
self._up1 = SNConvTranspose(
name=name + "_up1",
in_channels=encode_dim * 4,
out_channels=encode_dim * 4,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._up2 = SNConvTranspose(
name=name + "_up2",
in_channels=encode_dim * 8,
out_channels=encode_dim * 2,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._up3 = SNConvTranspose(
name=name + "_up3",
in_channels=encode_dim * 4,
out_channels=encode_dim,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
self._out_conv = SNConv(
name=name + "_out_conv",
in_channels=encode_dim,
out_channels=out_channels,
kernel_size=3,
use_bias=use_bias,
norm_layer=None,
act=out_conv_act,
act_attr=out_conv_act_attr)
def forward(self, x, feature2, feature1):
output_dict = dict()
output_dict["conv_blocks"] = self._conv_blocks.forward(x)
output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
output_dict["up2"] = self._up2.forward(
paddle.concat(
(output_dict["up1"], feature2), axis=1))
output_dict["up3"] = self._up3.forward(
paddle.concat(
(output_dict["up2"], feature1), axis=1))
output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
return output_dict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from arch.base_module import SNConv, SNConvTranspose, ResBlock
class Encoder(nn.Layer):
def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
act, act_attr, conv_block_dropout, conv_block_num,
conv_block_dilation):
super(Encoder, self).__init__()
self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
self._in_conv = SNConv(
name=name + "_in_conv",
in_channels=in_channels,
out_channels=encode_dim,
kernel_size=7,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._down1 = SNConv(
name=name + "_down1",
in_channels=encode_dim,
out_channels=encode_dim * 2,
kernel_size=3,
stride=2,
padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._down2 = SNConv(
name=name + "_down2",
in_channels=encode_dim * 2,
out_channels=encode_dim * 4,
kernel_size=3,
stride=2,
padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._down3 = SNConv(
name=name + "_down3",
in_channels=encode_dim * 4,
out_channels=encode_dim * 4,
kernel_size=3,
stride=2,
padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
conv_blocks = []
for i in range(conv_block_num):
conv_blocks.append(
ResBlock(
name="{}_conv_block_{}".format(name, i),
channels=encode_dim * 4,
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
use_bias=use_bias))
self._conv_blocks = nn.Sequential(*conv_blocks)
def forward(self, x):
out_dict = dict()
x = self._pad2d(x)
out_dict["in_conv"] = self._in_conv.forward(x)
out_dict["down1"] = self._down1.forward(out_dict["in_conv"])
out_dict["down2"] = self._down2.forward(out_dict["down1"])
out_dict["down3"] = self._down3.forward(out_dict["down2"])
out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"])
return out_dict
class EncoderUnet(nn.Layer):
def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
act, act_attr):
super(EncoderUnet, self).__init__()
self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
self._in_conv = SNConv(
name=name + "_in_conv",
in_channels=in_channels,
out_channels=encode_dim,
kernel_size=7,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._down1 = SNConv(
name=name + "_down1",
in_channels=encode_dim,
out_channels=encode_dim * 2,
kernel_size=3,
stride=2,
padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._down2 = SNConv(
name=name + "_down2",
in_channels=encode_dim * 2,
out_channels=encode_dim * 2,
kernel_size=3,
stride=2,
padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._down3 = SNConv(
name=name + "_down3",
in_channels=encode_dim * 2,
out_channels=encode_dim * 2,
kernel_size=3,
stride=2,
padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._down4 = SNConv(
name=name + "_down4",
in_channels=encode_dim * 2,
out_channels=encode_dim * 2,
kernel_size=3,
stride=2,
padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._up1 = SNConvTranspose(
name=name + "_up1",
in_channels=encode_dim * 2,
out_channels=encode_dim * 2,
kernel_size=3,
stride=2,
padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
self._up2 = SNConvTranspose(
name=name + "_up2",
in_channels=encode_dim * 4,
out_channels=encode_dim * 4,
kernel_size=3,
stride=2,
padding=1,
use_bias=use_bias,
norm_layer=norm_layer,
act=act,
act_attr=act_attr)
def forward(self, x):
output_dict = dict()
x = self._pad2d(x)
output_dict['in_conv'] = self._in_conv.forward(x)
output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
output_dict['down2'] = self._down2.forward(output_dict['down1'])
output_dict['down3'] = self._down3.forward(output_dict['down2'])
output_dict['down4'] = self._down4.forward(output_dict['down3'])
output_dict['up1'] = self._up1.forward(output_dict['down4'])
output_dict['up2'] = self._up2.forward(
paddle.concat(
(output_dict['down3'], output_dict['up1']), axis=1))
output_dict['concat'] = paddle.concat(
(output_dict['down2'], output_dict['up2']), axis=1)
return output_dict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def normal_(x, mean=0., std=1.):
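    # In-place normal initializer for a Paddle tensor (analogous to torch.nn.init.normal_).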
temp_value = paddle.normal(mean, std, shape=x.shape)
x.set_value(temp_value)
return x
class SpectralNorm(object):
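    """Weight-normalization pre-forward hook: rescales the wrapped layer's
    weight by its largest singular value, estimated with power iteration
    (a Paddle port of torch.nn.utils.spectral_norm)."""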
def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
self.name = name
self.dim = dim
if n_power_iterations <= 0:
raise ValueError('Expected n_power_iterations to be positive, but '
'got n_power_iterations={}'.format(
n_power_iterations))
self.n_power_iterations = n_power_iterations
self.eps = eps
def reshape_weight_to_matrix(self, weight):
weight_mat = weight
if self.dim != 0:
# transpose dim to front
weight_mat = weight_mat.transpose([
self.dim,
* [d for d in range(weight_mat.dim()) if d != self.dim]
])
height = weight_mat.shape[0]
return weight_mat.reshape([height, -1])
def compute_weight(self, module, do_power_iteration):
weight = getattr(module, self.name + '_orig')
u = getattr(module, self.name + '_u')
v = getattr(module, self.name + '_v')
weight_mat = self.reshape_weight_to_matrix(weight)
if do_power_iteration:
with paddle.no_grad():
for _ in range(self.n_power_iterations):
v.set_value(
F.normalize(
paddle.matmul(
weight_mat,
u,
transpose_x=True,
transpose_y=False),
axis=0,
epsilon=self.eps, ))
u.set_value(
F.normalize(
paddle.matmul(weight_mat, v),
axis=0,
epsilon=self.eps, ))
if self.n_power_iterations > 0:
u = u.clone()
v = v.clone()
sigma = paddle.dot(u, paddle.mv(weight_mat, v))
weight = weight / sigma
return weight
def remove(self, module):
with paddle.no_grad():
weight = self.compute_weight(module, do_power_iteration=False)
delattr(module, self.name)
delattr(module, self.name + '_u')
delattr(module, self.name + '_v')
delattr(module, self.name + '_orig')
module.add_parameter(self.name, weight.detach())
def __call__(self, module, inputs):
setattr(
module,
self.name,
self.compute_weight(
module, do_power_iteration=module.training))
@staticmethod
def apply(module, name, n_power_iterations, dim, eps):
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, SpectralNorm) and hook.name == name:
raise RuntimeError(
"Cannot register two spectral_norm hooks on "
"the same parameter {}".format(name))
fn = SpectralNorm(name, n_power_iterations, dim, eps)
weight = module._parameters[name]
with paddle.no_grad():
weight_mat = fn.reshape_weight_to_matrix(weight)
h, w = weight_mat.shape
# randomly initialize u and v
u = module.create_parameter([h])
u = normal_(u, 0., 1.)
v = module.create_parameter([w])
v = normal_(v, 0., 1.)
u = F.normalize(u, axis=0, epsilon=fn.eps)
v = F.normalize(v, axis=0, epsilon=fn.eps)
# delete fn.name form parameters, otherwise you can not set attribute
del module._parameters[fn.name]
module.add_parameter(fn.name + "_orig", weight)
# still need to assign weight back as fn.name because all sorts of
# things may assume that it exists, e.g., when initializing weights.
# However, we can't directly assign as it could be an Parameter and
# gets added as a parameter. Instead, we register weight * 1.0 as a plain
# attribute.
setattr(module, fn.name, weight * 1.0)
module.register_buffer(fn.name + "_u", u)
module.register_buffer(fn.name + "_v", v)
module.register_forward_pre_hook(fn)
return fn
def spectral_norm(module,
name='weight',
n_power_iterations=1,
eps=1e-12,
dim=None):
if dim is None:
if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
nn.Conv3DTranspose, nn.Linear)):
dim = 1
else:
dim = 0
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
return module
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from arch.base_module import MiddleNet, ResBlock
from arch.encoder import Encoder
from arch.decoder import Decoder, DecoderUnet, SingleDecoder
from utils.load_params import load_dygraph_pretrain
from utils.logging import get_logger
class StyleTextRec(nn.Layer):
def __init__(self, config):
super(StyleTextRec, self).__init__()
self.logger = get_logger()
self.text_generator = TextGenerator(config["Predictor"][
"text_generator"])
self.bg_generator = BgGeneratorWithMask(config["Predictor"][
"bg_generator"])
self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
"fusion_generator"])
bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
text_generator_pretrain = config["Predictor"]["text_generator"][
"pretrain"]
fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
"pretrain"]
load_dygraph_pretrain(
self.bg_generator,
self.logger,
path=bg_generator_pretrain,
load_static_weights=False)
load_dygraph_pretrain(
self.text_generator,
self.logger,
path=text_generator_pretrain,
load_static_weights=False)
load_dygraph_pretrain(
self.fusion_generator,
self.logger,
path=fusion_generator_pretrain,
load_static_weights=False)
def forward(self, style_input, text_input):
text_gen_output = self.text_generator.forward(style_input, text_input)
fake_text = text_gen_output["fake_text"]
fake_sk = text_gen_output["fake_sk"]
bg_gen_output = self.bg_generator.forward(style_input)
bg_encode_feature = bg_gen_output["bg_encode_feature"]
bg_decode_feature1 = bg_gen_output["bg_decode_feature1"]
bg_decode_feature2 = bg_gen_output["bg_decode_feature2"]
fake_bg = bg_gen_output["fake_bg"]
fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
fake_fusion = fusion_gen_output["fake_fusion"]
return {
"fake_fusion": fake_fusion,
"fake_text": fake_text,
"fake_sk": fake_sk,
"fake_bg": fake_bg,
}
class TextGenerator(nn.Layer):
def __init__(self, config):
super(TextGenerator, self).__init__()
name = config["module_name"]
encode_dim = config["encode_dim"]
norm_layer = config["norm_layer"]
conv_block_dropout = config["conv_block_dropout"]
conv_block_num = config["conv_block_num"]
conv_block_dilation = config["conv_block_dilation"]
if norm_layer == "InstanceNorm2D":
use_bias = True
else:
use_bias = False
self.encoder_text = Encoder(
name=name + "_encoder_text",
in_channels=3,
encode_dim=encode_dim,
use_bias=use_bias,
norm_layer=norm_layer,
act="ReLU",
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation)
self.encoder_style = Encoder(
name=name + "_encoder_style",
in_channels=3,
encode_dim=encode_dim,
use_bias=use_bias,
norm_layer=norm_layer,
act="ReLU",
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation)
self.decoder_text = Decoder(
name=name + "_decoder_text",
encode_dim=encode_dim,
out_channels=int(encode_dim / 2),
use_bias=use_bias,
norm_layer=norm_layer,
act="ReLU",
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation,
out_conv_act="Tanh",
out_conv_act_attr=None)
self.decoder_sk = Decoder(
name=name + "_decoder_sk",
encode_dim=encode_dim,
out_channels=1,
use_bias=use_bias,
norm_layer=norm_layer,
act="ReLU",
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation,
out_conv_act="Sigmoid",
out_conv_act_attr=None)
self.middle = MiddleNet(
name=name + "_middle_net",
in_channels=int(encode_dim / 2) + 1,
mid_channels=encode_dim,
out_channels=3,
use_bias=use_bias)
def forward(self, style_input, text_input):
style_feature = self.encoder_style.forward(style_input)["res_blocks"]
text_feature = self.encoder_text.forward(text_input)["res_blocks"]
fake_c_temp = self.decoder_text.forward([text_feature,
style_feature])["out_conv"]
fake_sk = self.decoder_sk.forward([text_feature,
style_feature])["out_conv"]
fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
return {"fake_sk": fake_sk, "fake_text": fake_text}
class BgGeneratorWithMask(nn.Layer):
def __init__(self, config):
super(BgGeneratorWithMask, self).__init__()
name = config["module_name"]
encode_dim = config["encode_dim"]
norm_layer = config["norm_layer"]
conv_block_dropout = config["conv_block_dropout"]
conv_block_num = config["conv_block_num"]
conv_block_dilation = config["conv_block_dilation"]
self.output_factor = config.get("output_factor", 1.0)
if norm_layer == "InstanceNorm2D":
use_bias = True
else:
use_bias = False
self.encoder_bg = Encoder(
name=name + "_encoder_bg",
in_channels=3,
encode_dim=encode_dim,
use_bias=use_bias,
norm_layer=norm_layer,
act="ReLU",
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation)
self.decoder_bg = SingleDecoder(
name=name + "_decoder_bg",
encode_dim=encode_dim,
out_channels=3,
use_bias=use_bias,
norm_layer=norm_layer,
act="ReLU",
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation,
out_conv_act="Tanh",
out_conv_act_attr=None)
self.decoder_mask = Decoder(
name=name + "_decoder_mask",
encode_dim=encode_dim // 2,
out_channels=1,
use_bias=use_bias,
norm_layer=norm_layer,
act="ReLU",
act_attr=None,
conv_block_dropout=conv_block_dropout,
conv_block_num=conv_block_num,
conv_block_dilation=conv_block_dilation,
out_conv_act="Sigmoid",
out_conv_act_attr=None)
self.middle = MiddleNet(
name=name + "_middle_net",
in_channels=3 + 1,
mid_channels=encode_dim,
out_channels=3,
use_bias=use_bias)
def forward(self, style_input):
encode_bg_output = self.encoder_bg(style_input)
decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
encode_bg_output["down2"],
encode_bg_output["down1"])
fake_c_temp = decode_bg_output["out_conv"]
fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
"res_blocks"])["out_conv"]
fake_bg = self.middle(
paddle.concat(
(fake_c_temp, fake_bg_mask), axis=1))
return {
"bg_encode_feature": encode_bg_output["res_blocks"],
"bg_decode_feature1": decode_bg_output["up1"],
"bg_decode_feature2": decode_bg_output["up2"],
"fake_bg": fake_bg,
"fake_bg_mask": fake_bg_mask,
}
class FusionGeneratorSimple(nn.Layer):
def __init__(self, config):
super(FusionGeneratorSimple, self).__init__()
name = config["module_name"]
encode_dim = config["encode_dim"]
norm_layer = config["norm_layer"]
conv_block_dropout = config["conv_block_dropout"]
conv_block_dilation = config["conv_block_dilation"]
if norm_layer == "InstanceNorm2D":
use_bias = True
else:
use_bias = False
self._conv = nn.Conv2D(
in_channels=6,
out_channels=encode_dim,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
bias_attr=False)
self._res_block = ResBlock(
name="{}_conv_block".format(name),
channels=encode_dim,
norm_layer=norm_layer,
use_dropout=conv_block_dropout,
use_dilation=conv_block_dilation,
use_bias=use_bias)
self._reduce_conv = nn.Conv2D(
in_channels=encode_dim,
out_channels=3,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
bias_attr=False)
def forward(self, fake_text, fake_bg):
fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
fake_concat_tmp = self._conv(fake_concat)
output_res = self._res_block(fake_concat_tmp)
fake_fusion = self._reduce_conv(output_res)
return {"fake_fusion": fake_fusion}
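# configs/config.yml: synthesis configuration used by tools/synth_image (see Quick Start).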
Global:
output_num: 10
output_dir: output_data
use_gpu: false
image_height: 32
image_width: 320
TextDrawer:
fonts:
en: fonts/en_standard.ttf
ch: fonts/ch_standard.ttf
ko: fonts/ko_standard.ttf
Predictor:
method: StyleTextRecPredictor
algorithm: StyleTextRec
scale: 0.00392156862745098
mean:
- 0.5
- 0.5
- 0.5
std:
- 0.5
- 0.5
- 0.5
expand_result: false
bg_generator:
pretrain: style_text_models/bg_generator
module_name: bg_generator
generator_type: BgGeneratorWithMask
encode_dim: 64
norm_layer: null
conv_block_num: 4
conv_block_dropout: false
conv_block_dilation: true
output_factor: 1.05
text_generator:
pretrain: style_text_models/text_generator
module_name: text_generator
generator_type: TextGenerator
encode_dim: 64
norm_layer: InstanceNorm2D
conv_block_num: 4
conv_block_dropout: false
conv_block_dilation: true
fusion_generator:
pretrain: style_text_models/fusion_generator
module_name: fusion_generator
generator_type: FusionGeneratorSimple
encode_dim: 64
norm_layer: null
conv_block_num: 4
conv_block_dropout: false
conv_block_dilation: true
Writer:
method: SimpleWriter
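# configs/dataset_config.yml: dataset synthesis configuration used by tools/synth_dataset.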
Global:
output_num: 10
output_dir: output_data
use_gpu: false
image_height: 32
image_width: 320
standard_font: fonts/en_standard.ttf
TextDrawer:
fonts:
en: fonts/en_standard.ttf
ch: fonts/ch_standard.ttf
ko: fonts/ko_standard.ttf
StyleSampler:
method: DatasetSampler
image_home: examples
label_file: examples/image_list.txt
with_label: true
CorpusGenerator:
method: FileCorpus
language: ch
corpus_file: examples/corpus/example.txt
Predictor:
method: StyleTextRecPredictor
algorithm: StyleTextRec
scale: 0.00392156862745098
mean:
- 0.5
- 0.5
- 0.5
std:
- 0.5
- 0.5
- 0.5
expand_result: false
bg_generator:
pretrain: models/style_text_rec/bg_generator
module_name: bg_generator
generator_type: BgGeneratorWithMask
encode_dim: 64
norm_layer: null
conv_block_num: 4
conv_block_dropout: false
conv_block_dilation: true
output_factor: 1.05
text_generator:
pretrain: models/style_text_rec/text_generator
module_name: text_generator
generator_type: TextGenerator
encode_dim: 64
norm_layer: InstanceNorm2D
conv_block_num: 4
conv_block_dropout: false
conv_block_dilation: true
fusion_generator:
pretrain: models/style_text_rec/fusion_generator
module_name: fusion_generator
generator_type: FusionGeneratorSimple
encode_dim: 64
norm_layer: null
conv_block_num: 4
conv_block_dropout: false
conv_block_dilation: true
Writer:
method: SimpleWriter
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from utils.logging import get_logger
class FileCorpus(object):
def __init__(self, config):
self.logger = get_logger()
self.logger.info("using FileCorpus")
self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
corpus_file = config["CorpusGenerator"]["corpus_file"]
self.language = config["CorpusGenerator"]["language"]
with open(corpus_file, 'r') as f:
corpus_raw = f.read()
self.corpus_list = corpus_raw.split("\n")[:-1]
assert len(self.corpus_list) > 0
random.shuffle(self.corpus_list)
self.index = 0
def generate(self, corpus_length=0):
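        # Return (language, text) for the next corpus line, reshuffling once the list is exhausted.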
if self.index >= len(self.corpus_list):
self.index = 0
random.shuffle(self.corpus_list)
corpus = self.corpus_list[self.index]
if corpus_length != 0:
corpus = corpus[0:corpus_length]
if corpus_length > len(corpus):
self.logger.warning("generated corpus is shorter than expected.")
self.index += 1
return self.language, corpus
class EnNumCorpus(object):
def __init__(self, config):
self.logger = get_logger()
self.logger.info("using NumberCorpus")
self.num_list = "0123456789"
self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
self.height = config["Global"]["image_height"]
self.max_width = config["Global"]["image_width"]
def generate(self, corpus_length=0):
corpus = ""
if corpus_length == 0:
corpus_length = random.randint(5, 15)
for i in range(corpus_length):
if random.random() < 0.2:
corpus += "{}".format(random.choice(self.en_char_list))
else:
corpus += "{}".format(random.choice(self.num_list))
return "en", corpus
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import math
import paddle
from arch import style_text_rec
from utils.sys_funcs import check_gpu
from utils.logging import get_logger
class StyleTextRecPredictor(object):
def __init__(self, config):
algorithm = config['Predictor']['algorithm']
assert algorithm in ["StyleTextRec"
], "Generator {} not supported.".format(algorithm)
use_gpu = config["Global"]['use_gpu']
check_gpu(use_gpu)
self.logger = get_logger()
self.generator = getattr(style_text_rec, algorithm)(config)
self.height = config["Global"]["image_height"]
self.width = config["Global"]["image_width"]
self.scale = config["Predictor"]["scale"]
self.mean = config["Predictor"]["mean"]
self.std = config["Predictor"]["std"]
self.expand_result = config["Predictor"]["expand_result"]
def predict(self, style_input, text_input):
style_input = self.rep_style_input(style_input, text_input)
tensor_style_input = self.preprocess(style_input)
tensor_text_input = self.preprocess(text_input)
style_text_result = self.generator.forward(tensor_style_input,
tensor_text_input)
fake_fusion = self.postprocess(style_text_result["fake_fusion"])
fake_text = self.postprocess(style_text_result["fake_text"])
fake_sk = self.postprocess(style_text_result["fake_sk"])
fake_bg = self.postprocess(style_text_result["fake_bg"])
bbox = self.get_text_boundary(fake_text)
if bbox:
left, right, top, bottom = bbox
fake_fusion = fake_fusion[top:bottom, left:right, :]
fake_text = fake_text[top:bottom, left:right, :]
fake_sk = fake_sk[top:bottom, left:right, :]
fake_bg = fake_bg[top:bottom, left:right, :]
# fake_fusion = self.crop_by_text(img_fake_fusion, img_fake_text)
return {
"fake_fusion": fake_fusion,
"fake_text": fake_text,
"fake_sk": fake_sk,
"fake_bg": fake_bg,
}
def preprocess(self, img):
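        # Normalize the image, resize it to the model height keeping the
        # aspect ratio, pad to the fixed model width, and convert HWC -> NCHW.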
img = (img.astype('float32') * self.scale - self.mean) / self.std
img_height, img_width, channel = img.shape
assert channel == 3, "Please use an rgb image."
ratio = img_width / float(img_height)
if math.ceil(self.height * ratio) > self.width:
resized_w = self.width
else:
resized_w = int(math.ceil(self.height * ratio))
img = cv2.resize(img, (resized_w, self.height))
new_img = np.zeros([self.height, self.width, 3]).astype('float32')
new_img[:, 0:resized_w, :] = img
img = new_img.transpose((2, 0, 1))
img = img[np.newaxis, :, :, :]
return paddle.to_tensor(img)
def postprocess(self, tensor):
img = tensor.numpy()[0]
img = img.transpose((1, 2, 0))
img = (img * self.std + self.mean) / self.scale
img = np.maximum(img, 0.0)
img = np.minimum(img, 255.0)
img = img.astype('uint8')
return img
def rep_style_input(self, style_input, text_input):
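        # Tile the style image horizontally until it covers the aspect ratio
        # of the rendered text, then crop it to the model's maximum width.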
rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) /
(style_input.shape[1] / style_input.shape[0])) + 1
style_input = np.tile(style_input, reps=[1, rep_num, 1])
max_width = int(self.width / self.height * style_input.shape[0])
style_input = style_input[:, :max_width, :]
return style_input
def get_text_boundary(self, text_img):
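        # Compute a tight bounding box around the rendered text from its
        # Canny edge map (with a small margin); returns None if no edges found.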
img_height = text_img.shape[0]
img_width = text_img.shape[1]
bounder = 3
text_canny_img = cv2.Canny(text_img, 10, 20)
edge_num_h = text_canny_img.sum(axis=0)
no_zero_list_h = np.where(edge_num_h > 0)[0]
edge_num_w = text_canny_img.sum(axis=1)
no_zero_list_w = np.where(edge_num_w > 0)[0]
if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0:
return None
left = max(no_zero_list_h[0] - bounder, 0)
right = min(no_zero_list_h[-1] + bounder, img_width)
top = max(no_zero_list_w[0] - bounder, 0)
bottom = min(no_zero_list_w[-1] + bounder, img_height)
return [left, right, top, bottom]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import cv2
class DatasetSampler(object):
def __init__(self, config):
self.image_home = config["StyleSampler"]["image_home"]
label_file = config["StyleSampler"]["label_file"]
self.dataset_with_label = config["StyleSampler"]["with_label"]
self.height = config["Global"]["image_height"]
self.index = 0
with open(label_file, "r") as f:
label_raw = f.read()
self.path_label_list = label_raw.split("\n")[:-1]
assert len(self.path_label_list) > 0
random.shuffle(self.path_label_list)
def sample(self):
if self.index >= len(self.path_label_list):
random.shuffle(self.path_label_list)
self.index = 0
if self.dataset_with_label:
path_label = self.path_label_list[self.index]
rel_image_path, label = path_label.split('\t')
else:
rel_image_path = self.path_label_list[self.index]
label = None
img_path = "{}/{}".format(self.image_home, rel_image_path)
image = cv2.imread(img_path)
origin_height = image.shape[0]
ratio = self.height / origin_height
width = int(image.shape[1] * ratio)
height = int(image.shape[0] * ratio)
image = cv2.resize(image, (width, height))
self.index += 1
if label:
return {"image": image, "label": label}
else:
return {"image": image}
def duplicate_image(image, width):
image_width = image.shape[1]
dup_num = width // image_width + 1
image = np.tile(image, reps=[1, dup_num, 1])
cropped_image = image[:, :width, :]
return cropped_image
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from utils.config import ArgsParser, load_config, override_config
from utils.logging import get_logger
from engine import style_samplers, corpus_generators, text_drawers, predictors, writers
class ImageSynthesiser(object):
def __init__(self):
self.FLAGS = ArgsParser().parse_args()
self.config = load_config(self.FLAGS.config)
self.config = override_config(self.config, options=self.FLAGS.override)
self.output_dir = self.config["Global"]["output_dir"]
if not os.path.exists(self.output_dir):
os.mkdir(self.output_dir)
self.logger = get_logger(
log_file='{}/predict.log'.format(self.output_dir))
self.text_drawer = text_drawers.StdTextDrawer(self.config)
predictor_method = self.config["Predictor"]["method"]
assert predictor_method is not None
self.predictor = getattr(predictors, predictor_method)(self.config)
def synth_image(self, corpus, style_input, language="en"):
corpus, text_input = self.text_drawer.draw_text(corpus, language)
synth_result = self.predictor.predict(style_input, text_input)
return synth_result
class DatasetSynthesiser(ImageSynthesiser):
def __init__(self):
super(DatasetSynthesiser, self).__init__()
self.tag = self.FLAGS.tag
self.output_num = self.config["Global"]["output_num"]
corpus_generator_method = self.config["CorpusGenerator"]["method"]
self.corpus_generator = getattr(corpus_generators,
corpus_generator_method)(self.config)
style_sampler_method = self.config["StyleSampler"]["method"]
assert style_sampler_method is not None
self.style_sampler = style_samplers.DatasetSampler(self.config)
self.writer = writers.SimpleWriter(self.config, self.tag)
def synth_dataset(self):
for i in range(self.output_num):
style_data = self.style_sampler.sample()
style_input = style_data["image"]
corpus_language, text_input_label = self.corpus_generator.generate(
)
text_input_label, text_input = self.text_drawer.draw_text(
text_input_label, corpus_language)
synth_result = self.predictor.predict(style_input, text_input)
fake_fusion = synth_result["fake_fusion"]
self.writer.save_image(fake_fusion, text_input_label)
self.writer.save_label()
self.writer.merge_label()
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from utils.logging import get_logger
class StdTextDrawer(object):
def __init__(self, config):
self.logger = get_logger()
self.max_width = config["Global"]["image_width"]
self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
self.height = config["Global"]["image_height"]
self.font_dict = {}
self.load_fonts(config["TextDrawer"]["fonts"])
self.support_languages = list(self.font_dict)
def load_fonts(self, fonts_config):
for language in fonts_config:
font_path = fonts_config[language]
font_height = self.get_valid_height(font_path)
font = ImageFont.truetype(font_path, font_height)
self.font_dict[language] = font
def get_valid_height(self, font_path):
font = ImageFont.truetype(font_path, self.height - 4)
_, font_height = font.getsize(self.char_list)
if font_height <= self.height - 4:
return self.height - 4
else:
return int((self.height - 4)**2 / font_height)
def draw_text(self, corpus, language="en", crop=True):
if language not in self.support_languages:
self.logger.warning(
"language {} not supported, use en instead.".format(language))
language = "en"
if crop:
width = min(self.max_width, len(corpus) * self.height) + 4
else:
width = len(corpus) * self.height + 4
bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
draw = ImageDraw.Draw(bg)
char_x = 2
font = self.font_dict[language]
for i, char_i in enumerate(corpus):
char_size = font.getsize(char_i)[0]
draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
char_x += char_size
if char_x >= width:
corpus = corpus[0:i + 1]
self.logger.warning("corpus length exceed limit: {}".format(
corpus))
break
text_input = np.array(bg).astype(np.uint8)
text_input = text_input[:, 0:char_x, :]
return corpus, text_input
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import glob
from utils.logging import get_logger
class SimpleWriter(object):
def __init__(self, config, tag):
self.logger = get_logger()
self.output_dir = config["Global"]["output_dir"]
self.counter = 0
self.label_dict = {}
self.tag = tag
self.label_file_index = 0
def save_image(self, image, text_input_label):
image_home = os.path.join(self.output_dir, "images", self.tag)
if not os.path.exists(image_home):
os.makedirs(image_home)
image_path = os.path.join(image_home, "{}.png".format(self.counter))
# todo support continue synth
cv2.imwrite(image_path, image)
self.logger.info("generate image: {}".format(image_path))
image_name = os.path.join(self.tag, "{}.png".format(self.counter))
self.label_dict[image_name] = text_input_label
self.counter += 1
if not self.counter % 100:
self.save_label()
def save_label(self):
label_raw = ""
label_home = os.path.join(self.output_dir, "label")
if not os.path.exists(label_home):
os.mkdir(label_home)
for image_path in self.label_dict:
label = self.label_dict[image_path]
label_raw += "{}\t{}\n".format(image_path, label)
label_file_path = os.path.join(label_home,
"{}_label.txt".format(self.tag))
with open(label_file_path, "w") as f:
f.write(label_raw)
self.label_file_index += 1
def merge_label(self):
label_raw = ""
label_file_regex = os.path.join(self.output_dir, "label",
"*_label.txt")
label_file_list = glob.glob(label_file_regex)
for label_file_i in label_file_list:
with open(label_file_i, "r") as f:
label_raw += f.read()
label_file_path = os.path.join(self.output_dir, "label.txt")
with open(label_file_path, "w") as f:
f.write(label_raw)
style_images/1.jpg NEATNESS
style_images/2.jpg 锁店君和宾馆
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from engine.synthesisers import DatasetSynthesiser
def synth_dataset():
dataset_synthesiser = DatasetSynthesiser()
dataset_synthesiser.synth_dataset()
if __name__ == '__main__':
synth_dataset()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import glob

import cv2

# Make the project packages importable before importing project modules.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from utils.config import ArgsParser
from engine.synthesisers import ImageSynthesiser
def synth_image():
args = ArgsParser().parse_args()
image_synthesiser = ImageSynthesiser()
style_image_path = args.style_image
img = cv2.imread(style_image_path)
text_corpus = args.text_corpus
language = args.language
synth_result = image_synthesiser.synth_image(text_corpus, img, language)
fake_fusion = synth_result["fake_fusion"]
fake_text = synth_result["fake_text"]
fake_bg = synth_result["fake_bg"]
cv2.imwrite("fake_fusion.jpg", fake_fusion)
cv2.imwrite("fake_text.jpg", fake_text)
cv2.imwrite("fake_bg.jpg", fake_bg)
def batch_synth_images():
image_synthesiser = ImageSynthesiser()
corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
    save_path = "./output_data/"
    # Make sure the output directory exists before cv2.imwrite is called.
    os.makedirs(save_path, exist_ok=True)
    corpus_list = []
with open(corpus_file, "rb") as fin:
lines = fin.readlines()
for line in lines:
substr = line.decode("utf-8").strip("\n").split("\t")
corpus_list.append(substr)
style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))
corpus_num = len(corpus_list)
style_img_num = len(style_img_list)
for cno in range(corpus_num):
for sno in range(style_img_num):
corpus, lang = corpus_list[cno]
style_img_path = style_img_list[sno]
img = cv2.imread(style_img_path)
synth_result = image_synthesiser.synth_image(corpus, img, lang)
fake_fusion = synth_result["fake_fusion"]
fake_text = synth_result["fake_text"]
fake_bg = synth_result["fake_bg"]
            # Save each result twice, under a corpus-first and a style-first
            # prefix, so outputs can be browsed grouped by either corpus or style.
            for tp in range(2):
if tp == 0:
prefix = "%s/c%d_s%d_" % (save_path, cno, sno)
else:
prefix = "%s/s%d_c%d_" % (save_path, sno, cno)
cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
cv2.imwrite("%s_input_style.jpg" % prefix, img)
print(cno, corpus_num, sno, style_img_num)
if __name__ == '__main__':
# batch_synth_images()
synth_image()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter

import yaml

from utils.logging import get_logger
def override(dl, ks, v):
"""
Recursively replace dict of list
Args:
dl(dict or list): dict or list to be replaced
ks(list): list of keys
v(str): value to be replaced
"""
    def str2num(v):
        # Interpret numeric-looking strings as numbers; leave others unchanged.
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), ("{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('length of keys should be larger than 0')
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            if ks[0] not in dl:
                get_logger().warning('A new field ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
"""
Recursively override the config
Args:
config(dict): dict to be replaced
options(list): list of pairs(key0.key1.idx.key2=value)
such as: [
'topk=2',
'VALID.transforms.1.ResizeImage.resize_short=300'
]
Returns:
config(dict): replaced config
"""
if options is not None:
for opt in options:
assert isinstance(opt, str), (
"option({}) should be a str".format(opt))
assert "=" in opt, (
"option({}) should contain a ="
"to distinguish between key and value".format(opt))
            pair = opt.split('=')
            assert len(pair) == 2, ("option({}) should contain exactly one =".format(opt))
key, value = pair
keys = key.split('.')
override(config, keys, value)
return config
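# Illustrative example (not in the original source): override_config applies
# dotted key paths to a nested config dict, which is how `-o` options from the
# command line reach the loaded YAML. For instance:
#
#     cfg = {"Global": {"use_gpu": True},
#            "Train": {"dataset": {"label_file": "a.txt"}}}
#     cfg = override_config(cfg, ["Global.use_gpu=False",
#                                 "Train.dataset.label_file=tmp/label.txt"])
#     # cfg["Global"]["use_gpu"] is now the bool False, because str2num
#     # evaluates "False" before assignment.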
class ArgsParser(ArgumentParser):
def __init__(self):
super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument(
"-t", "--tag", default="0", help="tag for marking worker")
self.add_argument(
'-o',
'--override',
action='append',
default=[],
help='config options to be overridden')
        self.add_argument(
            "--style_image",
            default="examples/style_images/1.jpg",
            help="path of the style image (Style Input)")
        self.add_argument(
            "--text_corpus",
            default="PaddleOCR",
            help="corpus to be rendered (Text Input)")
        self.add_argument(
            "--language",
            default="en",
            help="language of the text corpus, one of en, ch, ko")
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
assert args.config is not None, \
"Please specify --config=configure_file_path."
return args
def load_config(file_path):
"""
Load config from yml/yaml file.
Args:
file_path (str): Path of the config file to be loaded.
Returns: config
"""
ext = os.path.splitext(file_path)[1]
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
with open(file_path, 'rb') as f:
config = yaml.load(f, Loader=yaml.Loader)
return config
def gen_config():
base_config = {
"Global": {
"algorithm": "SRNet",
"use_gpu": True,
"start_epoch": 1,
"stage1_epoch_num": 100,
"stage2_epoch_num": 100,
"log_smooth_window": 20,
"print_batch_step": 2,
"save_model_dir": "./output/SRNet",
"use_visualdl": False,
"save_epoch_step": 10,
"vgg_pretrain": "./pretrained/VGG19_pretrained",
"vgg_load_static_pretrain": True
},
"Architecture": {
"model_type": "data_aug",
"algorithm": "SRNet",
"net_g": {
"name": "srnet_net_g",
"encode_dim": 64,
"norm": "batch",
"use_dropout": False,
"init_type": "xavier",
"init_gain": 0.02,
"use_dilation": 1
},
# input_nc, ndf, netD,
# n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
"bg_discriminator": {
"name": "srnet_bg_discriminator",
"input_nc": 6,
"ndf": 64,
"netD": "basic",
"norm": "none",
"init_type": "xavier",
},
"fusion_discriminator": {
"name": "srnet_fusion_discriminator",
"input_nc": 6,
"ndf": 64,
"netD": "basic",
"norm": "none",
"init_type": "xavier",
}
},
"Loss": {
"lamb": 10,
"perceptual_lamb": 1,
"muvar_lamb": 50,
"style_lamb": 500
},
"Optimizer": {
"name": "Adam",
"learning_rate": {
"name": "lambda",
"lr": 0.0002,
"lr_decay_iters": 50
},
"beta1": 0.5,
"beta2": 0.999,
},
"Train": {
"batch_size_per_card": 8,
"num_workers_per_card": 4,
"dataset": {
"delimiter": "\t",
"data_dir": "/",
"label_file": "tmp/label.txt",
"transforms": [{
"DecodeImage": {
"to_rgb": True,
"to_np": False,
"channel_first": False
}
}, {
"NormalizeImage": {
"scale": 1. / 255.,
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"order": None
}
}, {
"ToCHWImage": None
}]
}
}
}
with open("config.yml", "w") as f:
yaml.dump(base_config, f)
if __name__ == '__main__':
gen_config()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
__all__ = ['load_dygraph_pretrain']
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
    # `load_static_weights` is kept for interface compatibility; only dygraph
    # (.pdparams) weights are loaded here.
    if not os.path.exists(path + '.pdparams'):
        raise ValueError("Model pretrain path {} does not "
                         "exist.".format(path))
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(path))
return
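# Illustrative call (not in the original file; assumes `model` is a
# paddle.nn.Layer and that bg_generator.pdparams exists on disk — the path is
# passed without the .pdparams extension, and get_logger comes from utils.logging):
#
#     logger = get_logger()
#     load_dygraph_pretrain(model, logger, path="style_text_models/bg_generator")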
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import functools
import paddle.distributed as dist
logger_initialized = {}
@functools.lru_cache()
def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
file_handler = logging.FileHandler(log_file, 'a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if dist.get_rank() == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
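# Illustrative usage (not in the original file): thanks to functools.lru_cache
# and the logger_initialized registry, repeated calls with the same name return
# the same configured logger without attaching duplicate handlers.
#
#     logger = get_logger(name='srnet', log_file='./output/train.log')
#     logger.info('handlers are attached once per logger name')
#     assert get_logger(name='srnet') is logger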
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def compute_mean_covariance(img):
batch_size = img.shape[0]
channel_num = img.shape[1]
height = img.shape[2]
width = img.shape[3]
num_pixels = height * width
# batch_size * channel_num * 1 * 1
mu = img.mean(2, keepdim=True).mean(3, keepdim=True)
# batch_size * channel_num * num_pixels
img_hat = img - mu.expand_as(img)
img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
# batch_size * num_pixels * channel_num
img_hat_transpose = img_hat.transpose([0, 2, 1])
# batch_size * channel_num * channel_num
covariance = paddle.bmm(img_hat, img_hat_transpose)
covariance = covariance / num_pixels
return mu, covariance
def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
eps = 1e-5
intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
y_pred_cls * training_mask) + eps
loss = 1. - (2 * intersection / union)
return loss
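# Worked example (illustrative, not in the original file): for a batch of four
# 3-channel feature maps, the statistics used by the mu/var style loss have the
# shapes noted below, and identical masks drive the dice loss toward zero.
#
#     import paddle
#     img = paddle.rand([4, 3, 32, 64])
#     mu, cov = compute_mean_covariance(img)
#     # mu:  shape [4, 3, 1, 1] -- per-channel means
#     # cov: shape [4, 3, 3]    -- per-image channel covariance
#
#     mask = paddle.ones([4, 1, 32, 64])
#     loss = dice_coefficient(mask, mask, mask)  # identical maps -> loss ~ 0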
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import errno
import paddle
def get_check_global_params(mode):
    check_params = [
        'use_gpu', 'max_text_length', 'image_shape', 'character_type',
        'loss_type'
    ]
if mode == "train_eval":
check_params = check_params + [
'train_batch_size_per_card', 'test_batch_size_per_card'
]
elif mode == "test":
check_params = check_params + ['test_batch_size_per_card']
return check_params
def check_gpu(use_gpu):
    """
    Log an error and exit when use_gpu is set to true but the CPU
    version of PaddlePaddle is installed.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"
    if use_gpu:
        try:
            if not paddle.is_compiled_with_cuda():
                print(err)
                sys.exit(1)
        except Exception:
            print("Failed to check GPU state.")
            sys.exit(1)
def _mkdir_if_not_exist(path, logger):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if not os.path.exists(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
'be happy if some process has already created {}'.format(
path))
else:
raise OSError('Failed to mkdir {}'.format(path))
@@ -36,12 +36,13 @@ Architecture:
   algorithm: CRNN
   Transform:
   Backbone:
-    name: ResNet
-    layers: 34
+    name: MobileNetV3
+    scale: 0.5
+    model_name: large
   Neck:
     name: SequenceEncoder
     encoder_type: rnn
-    hidden_size: 256
+    hidden_size: 96
   Head:
     name: CTCHead
     fc_decay: 0
@@ -9,7 +9,7 @@
### 1. Text Detection Algorithms
List of text detection algorithms open-sourced by PaddleOCR:
-- [x] DB([paper](https://arxiv.org/abs/1911.08947))(recommended by ppocr)
+- [x] DB([paper]( https://arxiv.org/abs/1911.08947) )(recommended by ppocr)
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] SAST([paper](https://arxiv.org/abs/1908.05498))
@@ -38,9 +38,9 @@ For the training and usage of PaddleOCR text detection algorithms, please refer to the [model tra
### 2. Text Recognition Algorithms
List of text recognition algorithms open-sourced by PaddleOCR on the dynamic graph branch:
-- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))(recommended by ppocr)
+- [x] CRNN([paper](https://arxiv.org/abs/1507.05717) )(recommended by ppocr)
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
-- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
+- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294)) coming soon
@@ -62,9 +62,9 @@ PaddleOCR provides training, evaluation and prediction scripts.
*If you installed the CPU version, please change the `use_gpu` field in the configuration file to false*
```
-# GPU training: single-card and multi-card training are supported; specify the card numbers via selected_gpus
+# GPU training: single-card and multi-card training are supported; specify the card numbers via '--gpus'. If your paddle version is lower than 2.0rc1, please use the '--selected_gpus' parameter to choose the GPUs
# Start training. The following command has been written into the train.sh file; you only need to modify the configuration file path inside it
-python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
+python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
```
- Data augmentation
@@ -74,7 +74,7 @@ PaddleOCR provides a variety of data augmentation methods; if you want to add perturbations during training
The default perturbations are: color space conversion (cvtColor), blur, jitter, Gauss noise, random crop, perspective and color reverse, plus random data augmentation (RandAugment).
Except for RandAugment, each perturbation is selected with a 50% probability during training. For the concrete implementation, please refer to:
-[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
+[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)  
[randaugment.py](../../ppocr/data/imaug/randaugment.py)
*Due to OpenCV compatibility issues, the perturbation operations currently support Linux only*
@@ -107,17 +107,13 @@ PaddleOCR computes three OCR detection metrics: Precision, Recall
Run the following code to compute the evaluation metrics from the test-set detection result file specified by `save_res_path` in the config file `det_db_mv3.yml`.
-When evaluating, set the post-processing parameters `box_thresh=0.6` and `unclip_ratio=1.5`; when training on different datasets or with different models, these two parameters can be tuned for better results
-```shell
-python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
-```
+When evaluating, set the post-processing parameters `box_thresh=0.5` and `unclip_ratio=1.5`; when training on different datasets or with different models, these two parameters can be tuned for better results
During training, model parameters are saved under the `Global.save_model_dir` directory by default. When computing metrics, `Global.checkpoints` needs to be set to point to the saved parameter file.
-For example:
```shell
-python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
+python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.5 PostProcess.unclip_ratio=1.5
```
* Note: `box_thresh` and `unclip_ratio` are parameters required by DB post-processing; they do not need to be set when evaluating the EAST model
## Testing Detection Results
@@ -22,9 +22,8 @@ inference models (models saved with `paddle.jit.save`)
- [3. Text Recognition Model Inference](#文本识别模型推理)
    - [1. Lightweight Chinese Recognition Model Inference](#超轻量中文识别模型推理)
    - [2. CTC-Based Recognition Model Inference](#基于CTC损失的识别模型推理)
-    - [3. Attention-Based Recognition Model Inference](#基于Attention损失的识别模型推理)
-    - [4. Inference with a Custom Recognition Dictionary](#自定义文本识别字典的推理)
-    - [5. Multilingual Model Inference](#多语言模型的推理)
+    - [3. Inference with a Custom Recognition Dictionary](#自定义文本识别字典的推理)
+    - [4. Multilingual Model Inference](#多语言模型的推理)
- [4. Angle Classification Model Inference](#方向识别模型推理)
    - [1. Angle Classification Model Inference](#方向分类模型推理)
@@ -129,24 +128,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
For lightweight Chinese detection model inference, the following command can be executed:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"
+# Download the lightweight Chinese detection model:
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./ch_ppocr_mobile_v2.0_det_infer/"
```
The visualized text detection results are saved to the `./inference_results` folder by default, with result file names prefixed with 'det_res'. An example result is as follows:
-![](../imgs_results/det_res_2.jpg)
+![](../imgs_results/det_res_22.jpg)
-The image size is limited by the parameters `limit_type` and `det_limit_side_len`: `limit_type=max` limits the long side to <`det_limit_side_len`, while `limit_type=min` limits the short side to >`det_limit_side_len`.
-When the image does not satisfy the limit (long side >`det_limit_side_len` for `limit_type=max`, or short side <`det_limit_side_len` for `limit_type=min`), it is scaled proportionally.
-The default setting is `limit_type='max', det_max_side_len=960`. If the input image has a high resolution and you want to predict with a larger resolution, you can execute the following command:
+The parameters `limit_type` and `det_limit_side_len` limit the size of the input image.
+The optional values of `limit_type` are [`max`, `min`], and
+`det_limit_side_len` is a positive integer, generally set to a multiple of 32, such as 960.
+The default setting is `limit_type='max', det_limit_side_len=960`, which means the longest side of the network input image cannot exceed 960;
+if it does, the image is resized with the aspect ratio kept, so that the longest side equals `det_limit_side_len`.
+Setting `limit_type='min', det_limit_side_len=960` instead limits the shortest side of the image to 960.
+If the input image has a high resolution and you want to predict with a larger resolution, you can set det_limit_side_len to the desired value, such as 1216:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1200
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
```
To predict on the CPU, execute the following command
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
```
<a name="DB文本检测模型推理"></a>
@@ -268,16 +275,6 @@ For CRNN text recognition model inference, the following command can be executed:
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rec_crnn/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
```
-<a name="基于Attention损失的识别模型推理"></a>
-### 3. Attention-Based Recognition Model Inference
-Recognition models based on the Attention loss differ from CTC models and require the additional recognition algorithm parameter --rec_algorithm="RARE"
-For RARE text recognition model inference, the following command can be executed:
-```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
-```
![](../imgs_words_en/word_336.png)
After executing the command, the recognition result of the above image is as follows:
@@ -297,7 +294,7 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
```
-### 4. Inference with a Custom Recognition Dictionary
+### 3. Inference with a Custom Recognition Dictionary
If the text dictionary was modified during training, you need to specify the dictionary path via `--rec_char_dict_path` when predicting with the inference model, and set `rec_char_type=ch`
```
@@ -305,7 +302,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
```
<a name="多语言模型的推理"></a>
-### 5. Multilingual Model Inference
+### 4. Multilingual Model Inference
If you need to predict with a model for another language, you need to specify the dictionary path via `--rec_char_dict_path` when predicting with the inference model; meanwhile, to obtain correct visualization results,
you need to specify the visualization font path via `--vis_font_path`. Fonts for several small languages are provided by default under the `doc/` path, e.g. for Korean recognition:
@@ -167,7 +167,7 @@ tar -xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf rec_mv3_none_bilstm_ctc
```
# GPU training: single-card and multi-card training are supported; specify the card numbers via the --gpus parameter
-# Train on the icdar15 English data and save the training log as train_rec.log
+# Train on the icdar15 English data; the training log is automatically saved as train.log under "{save_model_dir}"
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
```
<a name="数据增强"></a>
@@ -200,11 +200,8 @@ PaddleOCR supports alternating training and evaluation; this can be configured in `configs/rec/rec_icdar15_t
| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
-| rec_mv3_tps_bilstm_ctc.yml | STARNet | Mobilenet_v3 large 0.5 | tps | BiLSTM | ctc |
-| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
-| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
For training on Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try other algorithms on a Chinese dataset, please refer to the following instructions to modify the configuration file:
@@ -356,8 +353,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkp
```
infer_img: doc/imgs_words/en/word_1.png
-        index: [19 24 18 23 29]
-        word : joint
+        result: ('joint', 0.9998967)
```
The configuration file used for prediction must be consistent with training. For example, if you completed training of the Chinese model with `python3 tools/train.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml`,
@@ -376,6 +372,5 @@ python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v
```
infer_img: doc/imgs_words/ch/word_1.jpg
-        index: [2092  177  312 2503]
-        word : 韩国小馆
+        result: ('韩国小馆', 0.997218)
```
@@ -13,7 +13,7 @@ This tutorial lists the text detection algorithms and text recognition algorithm
PaddleOCR open source text detection algorithms list:
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] DB([paper](https://arxiv.org/abs/1911.08947))
-- [x] SAST([paper](https://arxiv.org/abs/1908.05498))(Baidu Self-Research)
+- [x] SAST([paper](https://arxiv.org/abs/1908.05498) )(Baidu Self-Research)
On the ICDAR2015 dataset, the text detection result is as follows:
@@ -41,9 +41,9 @@ For the training guide and use of PaddleOCR text detection algorithms, please re
PaddleOCR open-source text recognition algorithms list:
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
-- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
+- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))(Baidu Self-Research) coming soon
+- [ ] SRN([paper](https://arxiv.org/abs/2003.12294) )(Baidu Self-Research) coming soon
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@@ -65,9 +65,9 @@ Start training:
```
# Set PYTHONPATH path
export PYTHONPATH=$PYTHONPATH:.
-# GPU training Support single card and multi-card training, specify the card number through selected_gpus
+# GPU training Support single card and multi-card training, specify the card number through --gpus. If your paddle version is less than 2.0rc1, please use '--selected_gpus'
# Start training, the following command has been written into the train.sh file, just modify the configuration file path in the file
-python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
+python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
```
- Data Augmentation
@@ -77,7 +77,7 @@ PaddleOCR provides a variety of data augmentation methods. If you want to add di
The default perturbation methods are: cvtColor, blur, jitter, Gauss noise, random crop, perspective, color reverse, RandAugment.
Except for RandAugment, each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to:
-[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
+[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)  
[randaugment.py](../../ppocr/data/imaug/randaugment.py)
File mode changed from 100644 to 100755
@@ -101,15 +101,11 @@ Run the following code to calculate the evaluation indicators. The result will b
When evaluating, set post-processing parameters `box_thresh=0.6`, `unclip_ratio=1.5`. If you use different datasets, different models for training, these two parameters should be adjusted for better result.
The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating indicators, you need to set `Global.checkpoints` to point to the saved parameter file.
-```shell
-python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
-```
-The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating indicators, you need to set `Global.checkpoints` to point to the saved parameter file.
-Such as:
```shell
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
* Note: `box_thresh` and `unclip_ratio` are parameters required for DB post-processing, and do not need to be set when evaluating the EAST model.
@@ -25,9 +25,8 @@ Next, we first introduce how to convert a trained model into an inference model,
- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
    - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
    - [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
-    - [3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE](#ATTENTION-BASED_RECOGNITION)
-    - [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
-    - [5. MULTILINGUAL MODEL INFERENCE](#MULTILINGUAL_MODEL_INFERENCE)
+    - [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
+    - [4. MULTILINGUAL MODEL INFERENCE](#MULTILINGUAL_MODEL_INFERENCE)
- [ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
    - [1. ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
@@ -135,24 +134,33 @@ Because EAST and DB algorithms are very different, when inference, it is necessa
For lightweight Chinese detection model inference, you can execute the following commands:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"
+# download DB text detection inference model
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+# predict
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/"
```
The visual text detection results are saved to the ./inference_results folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
-![](../imgs_results/det_res_2.jpg)
+![](../imgs_results/det_res_22.jpg)
-The size of the image is limited by the parameters `limit_type` and `det_limit_side_len`: `limit_type=max` limits the length of the long side to <`det_limit_side_len`, and `limit_type=min` limits the length of the short side to >`det_limit_side_len`.
-When the picture does not meet the restriction (for `limit_type=max` and long side >`det_limit_side_len`, or for `min` and short side <`det_limit_side_len`), the image will be scaled proportionally.
-This parameter is set to `limit_type='max', det_max_side_len=960` by default. If the resolution of the input picture is relatively large, and you want to use a larger resolution prediction, you can execute the following command:
+You can use the parameters `limit_type` and `det_limit_side_len` to limit the size of the input image.
+The optional values of `limit_type` are [`max`, `min`], and
+`det_limit_side_len` is a positive integer, generally set to a multiple of 32, such as 960.
+The default setting of the parameters is `limit_type='max', det_limit_side_len=960`, which means that the longest side of the network input image cannot exceed 960;
+if it does, the image will be resized while keeping the aspect ratio, so that the longest side equals `det_limit_side_len` (see the sketch after the following command).
+Setting `limit_type='min', det_limit_side_len=960` instead limits the shortest side of the image to 960.
+If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1200
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
```
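The proportional-resize rule described above can be summarized with a short sketch (an illustrative approximation, not the actual PaddleOCR preprocessing code; the function name and the divisible-by-32 rounding are assumptions made for clarity):

```python
def limit_resize_shape(h, w, limit_type="max", det_limit_side_len=960):
    # Pick the side the limit applies to and compute a proportional ratio.
    if limit_type == "max":
        side = max(h, w)
        ratio = det_limit_side_len / side if side > det_limit_side_len else 1.0
    else:  # limit_type == "min"
        side = min(h, w)
        ratio = det_limit_side_len / side if side < det_limit_side_len else 1.0
    # Detection networks usually expect side lengths divisible by 32 (assumption).
    resize_h = max(int(round(h * ratio / 32)) * 32, 32)
    resize_w = max(int(round(w * ratio / 32)) * 32, 32)
    return resize_h, resize_w

# Example: a 2000x1000 image with limit_type="max" is scaled so that
# its longest side becomes det_limit_side_len:
print(limit_resize_shape(2000, 1000))  # (960, 480)
```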
If you want to use the CPU for prediction, execute the command as follows
```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
```
<a name="DB_DETECTION"></a>
@@ -275,15 +283,6 @@ For CRNN text recognition model inference, execute the following commands:
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
```
-<a name="ATTENTION-BASED_RECOGNITION"></a>
-### 3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE
-The recognition model based on Attention loss is different from ctc, and additional recognition algorithm parameters need to be set --rec_algorithm="RARE"
-For RARE text recognition model inference, execute the following commands:
-```bash
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
-```
![](../imgs_words_en/word_336.png)
After executing the command, the recognition result of the above image is as follows:
@@ -303,7 +302,7 @@ dict_character = list(self.character_str)
```
<a name="USING_CUSTOM_CHARACTERS"></a>
-### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
+### 3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
```
@@ -311,7 +310,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
```
<a name="MULTILINGUAL_MODEL_INFERENCE"></a>
-### 5. MULTILINGUAL MODEL INFERENCE
+### 4. MULTILINGUAL MODEL INFERENCE
If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/` path, such as Korean recognition:
@@ -162,7 +162,7 @@ Start training:
```
# GPU training Support single card and multi-card training, specify the card number through --gpus
-# Training icdar15 English data and saving the log as train_rec.log
+# Training icdar15 English data; the training log will be automatically saved as train.log under "{save_model_dir}"
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
```
<a name="Data_Augmentation"></a>
@@ -193,11 +193,8 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
-| rec_mv3_tps_bilstm_ctc.yml | STARNet | Mobilenet_v3 large 0.5 | tps | BiLSTM | ctc |
-| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
-| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
For training Chinese data, it is recommended to use
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
@@ -350,8 +347,7 @@ Get the prediction result of the input image:
```
infer_img: doc/imgs_words/en/word_1.png
-        index: [19 24 18 23 29]
-        word : joint
+        result: ('joint', 0.9998967)
```
The configuration file used for prediction must be consistent with the training. For example, you completed the training of the Chinese model with `python3 tools/train.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml`, you can use the following command to predict the Chinese model:
@@ -369,6 +365,5 @@ Get the prediction result of the input image:
```
infer_img: doc/imgs_words/ch/word_1.jpg
-        index: [2092  177  312 2503]
-        word : 韩国小馆
+        result: ('韩国小馆', 0.997218)
```
@@ -180,7 +180,7 @@ class GridGenerator(nn.Layer):
        P = self.build_P_paddle(I_r_size)
        inv_delta_C_tensor = self.build_inv_delta_C_paddle(C).astype('float32')
-        # inv_delta_C_tensor = paddle.zeros((23,23)).astype('float32')
+
        P_hat_tensor = self.build_P_hat_paddle(
            C, paddle.to_tensor(P)).astype('float32')