提交 0ccab047 编写于 作者: F Feiyu Chan 提交者: Yibing Liu

init commit for deepvoice3 (#3458)

* ini commit for deepvoice, add tensorboard to requirements

* fix urls for code we adapted from

* fix makedirs for python2, fix README

* fix open with encoding for python2 compatability

* fix python2's str(), use encode for unicode, and str() for int

* fix python2 encoding issue, add model architecture and project structure for README

* add model structure, add explanation for hyperparameter priority order.
上级 70ccf385
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.md$
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
- id: remove-crlf
files: \.md$
- id: forbid-tabs
files: \.md$
- id: remove-tabs
files: \.md$
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Part of code was copied or adpated from https://github.com/r9y9/deepvoice3_pytorch/
Copyright (c) 2017: Ryuichi Yamamoto, whose applies.
# Deep Voice 3 with Paddle Fluid
Paddle fluid implementation of DeepVoice 3, a convolutional network based text-to-speech synthesis model. The implementation is based on [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654).
We implement Deepvoice3 model in paddle fluid with dynamic graph, which is convenient for flexible network architectures.
## Installation
### Install paddlepaddle
For faster training speed and better support, it is recommended that you install the lasted develop version of paddlepaddle. You can either download the lasted dev wheel or build paddle from source.
1. Download lasted wheel. See [**Multi-version whl package list - dev**](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/install/Tables_en.html#multi-version-whl-package-list-dev) for more details.
2. Build paddlepaddle from source. See [**Compile From Source Code**](https://www.paddlepaddle.org.cn/documentation/docs/en/1.5/beginners_guide/install/compile/fromsource_en.html) for more details. Note that if you want to enable data parallel training for multiple GPUs, you should set `-DWITH_DISTRIBUTE=ON` with cmake.
### Other Requirements
Install other requirements with pip.
```bash
pip install -r requirements.txt
```
You also need to download punkt and cmudict for nltk, because we tokenize text with `punkt` and convert text into phonemes with `cmudict`.
```python
import nltk
nltk.download("punkt")
nltk.download("cmudict")
```
## Model Architecture
![DeepVoice3 model architecture](./_images/model_architecture.png)
The model consists of an encoder, a decoder and a converter (and a speaker embedding for multispeaker models). The encoder, together with the decoder forms the seq2seq part of the model, and the converter forms the postnet part.
## Project Structure
```text
├── audio.py # audio processing
├── compute_timestamp_ratio.py # script to compute position rate
├── conversion # parameter conversion from pytorch model
├── requirements.txt # requirements
├── hparams.py # HParam class for deepvoice3
├── hparam_tf # hyper parameter related stuffs
├── ljspeech.py # functions for ljspeech preprocessing
├── preprocess.py # preprocrssing script
├── presets # preset hyperparameters
├── deepvoice3_paddle # DeepVoice3 model implementation
├── eval_model.py # functions for model evaluation
├── synthesis.py # script for speech synthesis
├── train_model.py # functions for model training
└── train.py # script for model training
```
## Usage
There are many hyperparameters to be tuned depending on the specification of model and dataset you are working on. Hyperparameters that are known to work good are provided in the repository. See `presets` directory for details. Now we only provide preset with LJSpeech dataset (`deepvoice3_ljspeech.json`). Support for more models and datasets is pending.
Note that `preprocess.py`, `train.py` and `synthesis.py` all accept a `--preset` parameter. To ensure consistency, you should use the same preset for preprocessing, training and synthesizing.
Note that you can overwrite preset hyperparameters with command line argument `--hparams`, just pass several key-value pair in `${key}=${value}` format seperated by comma (`,`). For example `--hparams="batch_size=8, nepochs=500"` can overwrite default values in the preset json file. For more details about hyperparameters, see `hparams.py`, which contains the definition of `hparams`. Priority order of hyperparameters is command line option `--hparams` > `--preset` json configuration file > definition of hparams in `hparams.py`.
Some hyperparameters are only related to training, like `batch_size`, `checkpoint_interval` and you can use different values for preprocessing and training. But hyperparameters related to data preprocessing, like `num_mels` and `ref_level_db`, should be kept the same for preprocessing and training.
### Dataset
Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
Preprocessing with `preprocess.py`.
```bash
python preprocess.py \
--preset=${preset_json_path} \
--hparams="hyper parameters you want to overwrite" \
${name} ${in_dir} ${out_dir}
```
Now `${dataset_name}$` only supports `ljspeech`. Support for other datasets is pending.
Assuming that you use `presers/deepvoice3_ljspeech.json` for LJSpeech and the path of the unziped dataset is `~/data/LJSpeech-1.1`, then you can preprocess data with the following command.
```bash
python preprocess.py \
--preset=presets/deepvoice3_ljspeech.json \
ljspeech ~/data/LJSpeech-1.1/ ./data/ljspeech
```
When this is done, you will see extracted features in `./data/ljspeech` including:
1. text and corresponding file names for the extracted features in `train.txt`.
2. mel-spectrogram in `ljspeech-mel-*.npy` .
3. linear-spectrogram in `ljspeech-spec-*.npy`.
### Train on single GPU
Training the whole model on one single GPU:
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
For more details about `train.py`, see `python train.py --help`.
#### load checkpoints
You can load saved checkpoint and resume training with `--checkpoint`, if you wan to reset optimizer states, pass `--reset-optimizer` in addition.
#### train a part of the model
You can also train parts of the model while freezing other parts, by passing `--train-seq2seq-only` or `--train-postnet-only`. When training only parts of the model, other parts should be loaded from saved checkpoints.
To train only the `seq2seq` or `postnet`, you should load from a whole model with `--checkpoint`and keep the same configurations. Note that when training only the `postnet`, you should set `use_decoder_state_for_postnet_input=false`, because when train only the postnet, the postnet takes the ground truth mel-spectrogram as input.
example:
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override" \
--train-seq2seq-only \
--checkpoint=${path_of_the_saved_model}
```
### Training on multiple GPUs
Training on multiple GPUs with data parallel is enabled. You can run `train.py` with `paddle.distributed.launch` module. Here is the command line usage.
```bash
python -m paddle.distributed.launch \
--started_port ${port_of_the_first_worker} \
--selected_gpus ${logical_gpu_ids_to_choose} \
--log_dir ${path_of_write_log} \
training_script ...
```
`paddle.distributed.launch` parallelizes training in multiprocessing mode.`--selected_gpus` means the logical ids of the selected GPUs, and `started_port` means the port used by the first worker. Outputs of each worker are saved in `--log_dir.` Then follows the command for training on a single GPU, except that you should pass `--use-data-paralle` in addition.
```bash
export CUDA_VISIBLE_DEVICES=2,3,4,5 # The IDs of visible physical devices
python -m paddle.distributed.launch \
--selected_gpus=0,1,2,3 --log_dir ${multi_gpu_log_dir} \
train.py --data-root=${data-root} \
--use-gpu --use-data-parallel \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
In the example above, we set only GPU `2, 3, 4, 5` to be visible. Then `--selected_gpus="0, 1, 2, 3"` means the logical ids of the selected gpus, which correpond to GPU `2, 3, 4, 5`.
Model checkpoints (directory ending with `.model`) are saved in `./checkpoints` per 10000 steps by default. Layer-wise averaged attention alignments (.png) are saved in `.checkpointys/alignment_ave`. And alignments for each attention layer are saved in `.checkpointys/alignment_layer{attention_layer_num}` per 10000 steps for inspection.
Synthesis results of 6 sentences (hardcoded in `eval_model.py`) are saved in `checkpoints/eval`, including `step{step_num}_text{text_id}_single_alignment.png` for averaged alignments and `step{step_num}_text{text_id}_single_predicted.wav` for the predicted waveforms.
### Monitor with Tensorboard
Logs with tensorboard are saved in `./log/${datetime}` directory by default. You can monitor logs by tensorboard.
```bash
tensorboard --logdir=${log_dir} --host=$HOSTNAME --port=8888
```
### Synthesize from a checkpoint
Given a list of text, `synthesis.py` synthesize audio signals from a trained model.
```bash
python infer.py --use-gpu --preset=${preset_json_path} \
--hparams="parameters you may want to override" \
${checkpoint} ${text_list_file} ${dst_dir}}
```
Example test_list.txt:
```text
Generative adversarial network or variational auto-encoder.
Once upon a time there was a dear little girl who was loved by every one who looked at her, but most of all by her grandmother, and there was nothing that she would not have given to the child.
A text-to-speech synthesis system typically consists of multiple stages, such as a text analysis frontend, an acoustic model and an audio synthesis module.
```
generated waveform files and alignment files are saved in `${dst_dir}`.
### Compute position ratio
According to [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654), the position rate is different for different datasets. There are 2 position rates, one for the query and the other for the key, which are referred to as $\omega_1$ and $\omega_2$ in th paper, and the corresponding names in preset json are `query_position_rate` and `key_position_rate`.
For example, the `query_position_rate` and `key_position_rate` for LJSpeech are `1.0` and `1.385`, respectively. These values are computed with `compute_timestamp_ratio.py`. Run the command below.
```bash
python compute_timestamp_ratio.py --preset=${preset_json_path} \
--hparams="parameters you may want to override" ${data_root}
```
You will get outputs like this.
```text
100%|██████████████████████████████████████████████████████████| 13047/13047 [00:12<00:00, 1058.19it/s]
1345587 1863884.0 1.3851828235558161
```
Then set the `key_position_rate=1.385` and `query_position_rate=1.0` in the preset.
## Acknowledgement
We thankfully included and adapted some files r9y9's from [deepvoice3_pytorch](https://github.com/r9y9/deepvoice3_pytorch).
# Deep Voice 3 with Paddle Fluid
Paddle 实现的 Deepvoice3,一个基于卷积神经网络的语音合成 (Text to Speech) 模型。本实现基于 [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654)
本 Deepvoice3 实现使用 Paddle 动态图模式,这对于灵活的网络结构更为方便。
## 安装
### 安装 paddlepaddle 框架
为了更快的训练速度和更好的支持,我们推荐使用最新的开发版 paddle。用户可以最新编译的开发版 whl 包,也可以选择从源码编译 Paddle。
1. 下载最新编译的开发版 whl 包。可以从 [**多版本 wheel 包列表-dev**](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/Tables.html#whl-dev) 页面中选择合适的版本。
2. 从源码编译 Paddle. 参考[**从源码编译**](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/compile/fromsource.html) 页面。注意,如果你需要使用多卡训练,那么编译前需要设置选项 `-DWITH_DISTRIBUTE=ON`
### 其他依赖
使用 pip 安装其他依赖。
```bash
pip install -r requirements.txt
```
另外需要下载 nltk 的两个库,因为使用了 `punkt` 对文本进行 tokenization,并且使用了 `cmudict` 来将文本转为音位。
```python
import nltk
nltk.download("punkt")
nltk.download("cmudict")
```
## 模型结构
![DeepVoice3 模型结构](./_images/model_architecture.png)
模型包含 encoder, decoder, converter 几个部分,对于 multispeaker 数据集,还有一个 speaker embedding。其中 encoder 和 decoder 构成 seq2seq 部分,converter 构成 postnet 部分。
## 项目结构
```text
├── audio.py # 用于处理处理音频的函数
├── compute_timestamp_ratio.py # 计算 position rate 的脚本
├── conversion # 用于转换 pytorch 实现的参数
├── requirements.txt # 项目依赖
├── hparams.py # DeepVoice3 运行超参数配置类的定义
├── hparam_tf # 超参数相关
├── ljspeech.py # ljspeech 数据集预处理
├── preprocess.py # 通用预处理脚本
├── presets # 预设超参数配置
├── deepvoice3_paddle # DeepVoice3 模型实现的主要文件
├── eval_model.py # 模型测评相关函数
├── synthesis.py # 用于语音合成的脚本
├── train_model.py # 模型训练相关函数
└── train.py # 用于模型训练的脚本
```
## 使用方法
根据所使用的模型配置和数据集的不同,有不少超参数需要进行调节。我们提供已知结果较好的超参数设置,详见 `presets` 文件夹。目前我们只提供 LJSpeech 的预设配置 (`deepvoice3_ljspeech.json`)。后续将提供更多模型和数据集的预设配置。
`preprocess.py``train.py``synthesis.py` 都接受 `--preset` 参数。为了保持一致性,最好在数据预处理,模型训练和语音合成时使用相同的预设配置。
可以通过 `--hparams` 参数来覆盖预设的超参数配置,参数格式是逗号分隔的键值对 `${key}=${value}`,例如 `--hparams="batch_size=8, nepochs=500"`。关于超参数设置更多细节可以参考 `hparams.py` ,其中定义了 hparams。超参数的优先级序列是:通过命令行参数 `--hparams` 传入的参数优先级高于通过 `--preset` 参数传入的 json 配置文件,高于 `hparams.py` 中的定义。
部分参数可以只和训练有关,如 `batch_size`, `checkpoint_interval`, 用户在训练时可以使用不同的值。但部分参数和数据预处理相关,如 `num_mels``ref_level_db`, 这些参数在数据预处理和训练时候应该保持一致。
关于超参数设置更多细节可以参考 `hparams.py` ,其中定义了 hparams。超参数的优先级序列是:通过命令行参数 `--hparams` 传入的参数优先级高于通过 `--preset` 参数传入的 json 配置文件,高于 `hparams.py` 中的定义。
### 数据集
下载并解压 [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) 数据集。
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
使用 `preprocess.py`进行预处理。
```bash
python preprocess.py \
--preset=${preset_json_path} \
--hparams="hyper parameters you want to overwrite" \
${name} ${in_dir} ${out_dir}
```
目前 `${dataset_name}$` 只支持 `ljspeech`。未来将会支持更多数据集。
假设你使用 `presers/deepvoice3_ljspeech.json` 作为处理 LJSpeech 的预设配置文件,并且解压后的数据集位于 `~/data/LJSpeech-1.1`, 那么使用如下的命令进行数据预处理。
```bash
python preprocess.py \
--preset=presets/deepvoice3_ljspeech.json \
ljspeech ~/data/LJSpeech-1.1/ ./data/ljspeech
```
数据处理完成后,你会在 `./data/ljspeech` 看到提取的特征,包含如下文件。
1. `train.txt`,包含文本和对应的音频特征的文件名。
2. `ljspeech-mel-*.npy`,包含 mel 频谱。
3. `ljspeech-spec-*.npy`,包含线性频谱。
### 使用 GPU 单卡训练
在单个 GPU 上训练整个模型的使用方法如下。
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
用于可以通过 `python train.py --help` 查看 `train.py` 的详细使用方法。
#### 加载保存的模型
用户可以通过 `--checkpoint` 参数加载保存的模型并恢复训练。如果你想要重置优化器的状态,在训练脚本加入 `--reset-optimizer` 参数。
#### 训练模型的一部分
用户可以通过 `--train-seq2seq-only` 或者 `--train-postnet-only` 来实现固定模型的其他部分,只训练需要训练的部分。但当只训练模型的一部分时,其他的部分需要从保存的模型中加载。
当只训练模型的 `seq2seq` 部分或者 `postnet` 部分时,需要使用 `--checkpoint` 加载整个模型并保持相同的配置。注意,当只训练 `postnet` 的时候,需要保证配置中的`use_decoder_state_for_postnet_input=false`,因为在这种情况下,postnet 使用真实的 mel 频谱作为输入。
示例:
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override" \
--train-seq2seq-only \
--checkpoint=${path_of_the_saved_model}
```
### 使用 GPU 多卡训练
本模型支持使用多个 GPU 通过数据并行的方式 训练。方法是使用 `paddle.distributed.launch` 模块来启动 `train.py`
```bash
python -m paddle.distributed.launch \
--started_port ${port_of_the_first_worker} \
--selected_gpus ${logical_gpu_ids_to_choose} \
--log_dir ${path_of_write_log} \
training_script ...
```
paddle.distributed.launch 通过多进程的方式进行并行训练。`--selected_gpus` 指的是选择的 GPU 的逻辑序号,`started_port` 指的是 0 号显卡的使用的端口号,`--log_dir` 是日志保存的目录,每个进程的输出会在这个文件夹中保存为单独的文件。再在后面接上需要启动的脚本文件及其参数即可。这部分和单卡训练的脚本一致,但是需要传入 `--use-data-paralle` 以使用数据并行训练。示例命令如下。
```bash
export CUDA_VISIBLE_DEVICES=2,3,4,5 # The IDs of visible physical devices
python -m paddle.distributed.launch \
--selected_gpus=0,1,2,3 --log_dir ${multi_gpu_log_dir} \
train.py --data-root=${data-root} \
--use-gpu --use-data-parallel \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
上述的示例中,设置了 `2, 3, 4, 5` 号显卡为可见的 GPU。然后 `--selected_gpus=0,1,2,3` 选择的是 GPU 的逻辑序号,分别对应于 `2, 3, 4, 5` 号卡。
模型默认被保存为后缀为 `.model`的文件夹,保存在 `./checkpoints` 文件夹中。多层平均的注意力机制对齐结果被保存为 `.png` 图片,默认保存在 `.checkpointys/alignment_ave` 中。每一层的注意力机制对齐结果默认被保存在 `.checkpointys/alignment_layer{attention_layer_num}`文件夹中。默认每 10000 步保存一次用于查看。
对 6 个给定的句子的语音合成结果保存在 `checkpoints/eval` 中,包含多层平均平均的注意力机制对齐结果,这被保存为名为 `step{step_num}_text{text_id}_single_alignment.png` 的图片;以及合成的音频文件,保存为名为 `step{step_num}_text{text_id}_single_predicted.wav` 的音频。
### 使用 Tensorboard 查看训练
Tensorboard 训练日志默认被保存在 `./log/${datetime}` 文件夹,可以通过 tensorboard 查看。使用方法如下。
```bash
tensorboard --logdir=${log_dir} --host=$HOSTNAME --port=8888
```
### 从保存的模型合成语音
给定一组文本,使用 `synthesis.py` 从一个训练好的模型来合成语音,使用方法如下。
```bash
python infer.py --use-gpu --preset=${preset_json_path} \
--hparams="parameters you may want to override" \
${checkpoint} ${text_list_file} ${dst_dir}}
```
示例文本文件如下:
```text
Generative adversarial network or variational auto-encoder.
Once upon a time there was a dear little girl who was loved by every one who looked at her, but most of all by her grandmother, and there was nothing that she would not have given to the child.
A text-to-speech synthesis system typically consists of multiple stages, such as a text analysis frontend, an acoustic model and an audio synthesis module.
```
合成的结果包含注意力机制对齐结果和音频文件,保存于 `${dst_dir}`
### 计算 position rate
根据 [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654), 对于不同的数据集,会有不同的 position rate. 有两个不同的 position rate,一个用于 query 一个用于 key, 这在论文中称为 $\omega_1$ 和 $\omega_2$ ,在预设配置文件中的名字分别为 `query_position_rate``key_position_rate`
比如 LJSpeech 数据集的 `query_position_rate``key_position_rate` 分别为 `1.0``1.385`。这些值可以 `compute_timestamp_ratio.py`。使用如下命令计算。
```bash
python compute_timestamp_ratio.py --preset=${preset_json_path} \
--hparams="parameters you may want to override" ${data_root}
```
可以得到如下的结果。
```text
100%|██████████████████████████████████████████████████████████| 13047/13047 [00:12<00:00, 1058.19it/s]
1345587 1863884.0 1.3851828235558161
```
然后在预设配置文件中设置 `key_position_rate=1.385` 以及 `query_position_rate=1.0`
## 致谢
本实现包含及改写了 r9y9's 的 [deepvoice3_pytorch](https://github.com/r9y9/deepvoice3_pytorch) 中的部分文件,在此表示感谢。
# This file was copied from https://github.com/r9y9/deepvoice3_pytorch/tree/master/audio.py
# Copyright (c) 2017: Ryuichi Yamamoto.
import librosa
import librosa.filters
import math
import numpy as np
from scipy import signal
from hparams import hparams
from scipy.io import wavfile
import lws
def load_wav(path):
return librosa.core.load(path, sr=hparams.sample_rate)[0]
def save_wav(wav, path):
wav = wav * 32767 / max(0.01, np.max(np.abs(wav)))
wavfile.write(path, hparams.sample_rate, wav.astype(np.int16))
def preemphasis(x):
from nnmnkwii.preprocessing import preemphasis
return preemphasis(x, hparams.preemphasis)
def inv_preemphasis(x):
from nnmnkwii.preprocessing import inv_preemphasis
return inv_preemphasis(x, hparams.preemphasis)
def spectrogram(y):
D = _lws_processor().stft(preemphasis(y)).T
S = _amp_to_db(np.abs(D)) - hparams.ref_level_db
return _normalize(S)
def inv_spectrogram(spectrogram):
'''Converts spectrogram to waveform using librosa'''
S = _db_to_amp(_denormalize(spectrogram) +
hparams.ref_level_db) # Convert back to linear
processor = _lws_processor()
D = processor.run_lws(S.astype(np.float64).T**hparams.power)
y = processor.istft(D).astype(np.float32)
return inv_preemphasis(y)
def melspectrogram(y):
D = _lws_processor().stft(preemphasis(y)).T
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db
if not hparams.allow_clipping_in_normalization:
assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
return _normalize(S)
def _lws_processor():
return lws.lws(hparams.fft_size, hparams.hop_size, mode="speech")
# Conversions:
_mel_basis = None
def _linear_to_mel(spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def _build_mel_basis():
if hparams.fmax is not None:
assert hparams.fmax <= hparams.sample_rate // 2
return librosa.filters.mel(hparams.sample_rate,
hparams.fft_size,
fmin=hparams.fmin,
fmax=hparams.fmax,
n_mels=hparams.num_mels)
def _amp_to_db(x):
min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, x * 0.05)
def _normalize(S):
return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1)
def _denormalize(S):
return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db
# Part of code was adpated from https://github.com/r9y9/deepvoice3_pytorch/tree/master/compute_timestamp_ratio.py
# Copyright (c) 2017: Ryuichi Yamamoto.
import argparse
import sys
import numpy as np
from hparams import hparams, hparams_debug_string
from deepvoice3_paddle.data import TextDataSource, MelSpecDataSource
from nnmnkwii.datasets import FileSourceDataset
from tqdm import trange
from deepvoice3_paddle import frontend
def build_parser():
parser = argparse.ArgumentParser(
description="Compute output/input timestamp ratio.")
parser.add_argument(
"--hparams", type=str, default="", help="Hyper parameters.")
parser.add_argument(
"--preset",
type=str,
required=True,
help="Path of preset parameters (json).")
parser.add_argument("data_root", type=str, help="path of the dataset.")
return parser
if __name__ == "__main__":
parser = build_parser()
args, _ = parser.parse_known_args()
data_root = args.data_root
preset = args.preset
# Load preset if specified
if preset is not None:
with open(preset) as f:
hparams.parse_json(f.read())
# Override hyper parameters
hparams.parse(args.hparams)
assert hparams.name == "deepvoice3"
# Code below
X = FileSourceDataset(TextDataSource(data_root))
Mel = FileSourceDataset(MelSpecDataSource(data_root))
in_sizes = []
out_sizes = []
for i in trange(len(X)):
x, m = X[i], Mel[i]
if X.file_data_source.multi_speaker:
x = x[0]
in_sizes.append(x.shape[0])
out_sizes.append(m.shape[0])
in_sizes = np.array(in_sizes)
out_sizes = np.array(out_sizes)
input_timestamps = np.sum(in_sizes)
output_timestamps = np.sum(
out_sizes) / hparams.outputs_per_step / hparams.downsample_step
print(input_timestamps, output_timestamps,
output_timestamps / input_timestamps)
sys.exit(0)
# Parameter conversion
## generate name map
To convert a model trained with `https://github.com/r9y9/deepvoice3_pytorch`, we provide a script to generate name map between pytorch model and paddle model for `deepvoice3`. You can provide `--preset` and `--hparams` to specify the model's configuration.
```bash
python generate_name_map.py --preset=${preset_to_use} --haprams="hyper parameters to overwrite"
```
It would print a name map. The format of the name map file looks like this. Each line consists of 3 fields, the first is the name of a parameter in the saved state dict of pytorch model, the second and third is the name and shape of the corresponding parameter in the saved state dict of paddle.
```
seq2seq.encoder.embed_tokens.weight encoder/Encoder_0/Embedding_0.w_0 [149, 256]
seq2seq.encoder.convolutions.0.bias encoder/Encoder_0/ConvProj1D_1/Conv2D_0.b_0 [512]
seq2seq.encoder.convolutions.0.weight_g encoder/Encoder_0/ConvProj1D_1/Conv2D_0.w_1 [512]
```
Redirect the output to a file to save it.
```bash
python generate_name_map.py --preset=${preset_to_use} --haprams="hyper parameters to overwrite" > name_map.txt
```
## convert saved pytorch model to paddle model
Given a name map and a saved pytorch model, you can convert it to a paddle model.
```bash
python convert.py \
--pytorch-model ${pytorch_model.pth} \
--paddle-model ${path_to_save_paddle_model} \
--name-map ${name_map_path}
```
Note that the user should provide the name map file, and ensure the models are equivalent to each other and corresponding parameters have the right shapes.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import paddle
from paddle import fluid
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--pytorch-model",
dest='pytorch_model',
type=str,
help="The source pytorch mode.")
parser.add_argument(
"--paddle-model",
dest='paddle_model',
type=str,
help="The directory to save paddle model, now saves model as a folder.")
parser.add_argument(
"--name-map",
dest="name_map",
type=str,
help="name mapping for the source model and the target model.")
def read_name_map(fname):
"""
There should be a 3-column file.
The first comuln is the name of parameter in pytorch model's state dict;
The second column is the name of parameter in paddle model's state dict;
The third column is the shape of the repective parameter in paddle model.
"""
name_map = {}
with open(fname, 'rt') as f:
for line in f:
src_key, tgt_key, tgt_shape = line.strip().split('\t')
tgt_shape = eval(tgt_shape)
name_map[src_key] = (tgt_key, tgt_shape)
return name_map
def torch2paddle(state_dict, name_map, dirname):
"""
state_dict: pytorch model's state dict.
name_map: a text file for name mapping from pytorch model to paddle model.
dirname: path of the paddle model to save.
"""
program = fluid.Program()
global_block = program.global_block()
for k in state_dict.keys():
global_block.create_parameter(
name=name_map[k][0],
shape=[1],
dtype='float32',
initializer=fluid.initializer.Constant(value=0.0))
place = fluid.core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exe.run(program)
# NOTE: transpose the pytorch model's parameter if neccessary
# we do not transpose here because we used conv instead of FC layer to replace Linear in pytorch,
# which does not need us to transpose the paramerters.
# but when you use a FC layer corresponding a torch Linear module, be sure to transpose the weight.
# Other transformations are not concerned, but users should check the data shape to ensure that
# the transformations are what's expected.
for k, v in state_dict.items():
fluid.global_scope().find_var(name_map[k][0]).get_tensor().set(
v.cpu().numpy().reshape(name_map[k][1]), place)
fluid.io.save_params(exe, dirname, main_program=program)
if __name__ == "__main__":
args, _ = parser.parse_known_args()
result = torch.load(args.pytorch_model)
state_dict = result["state_dict"]
name_map = read_name_map(args.name_map)
torch2paddle(state_dict, name_map, args.paddle_model)
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepvoice3_paddle.deepvoice3 import DeepVoiceTTS, ConvSpec, WindowRange
def deepvoice3(n_vocab,
embed_dim=256,
mel_dim=80,
linear_dim=513,
r=4,
downsample_step=1,
n_speakers=1,
speaker_dim=16,
padding_idx=0,
dropout=(1 - 0.96),
filter_size=5,
encoder_channels=128,
decoder_channels=256,
converter_channels=256,
query_position_rate=1.0,
key_position_rate=1.29,
use_memory_mask=False,
trainable_positional_encodings=False,
force_monotonic_attention=True,
use_decoder_state_for_postnet_input=True,
max_positions=512,
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
freeze_embedding=False,
window_range=WindowRange(-1, 3),
key_projection=False,
value_projection=False):
time_upsampling = max(downsample_step, 1)
h = encoder_channels
k = filter_size
encoder_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3))
h = decoder_channels
prenet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3))
attentive_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1))
attention = [True, False, False, False, True]
h = converter_channels
postnet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(2 * h, k, 1), ConvSpec(2 * h, k, 3))
model = DeepVoiceTTS(
"dv3", n_speakers, speaker_dim, speaker_embedding_weight_std, n_vocab,
embed_dim, padding_idx, embedding_weight_std, freeze_embedding,
encoder_convolutions, max_positions, padding_idx,
trainable_positional_encodings, mel_dim, r, prenet_convolutions,
attentive_convolutions, attention, use_memory_mask,
force_monotonic_attention, query_position_rate, key_position_rate,
window_range, key_projection, value_projection, linear_dim,
postnet_convolutions, time_upsampling, dropout,
use_decoder_state_for_postnet_input, "float32")
return model
def deepvoice3_multispeaker(n_vocab,
embed_dim=256,
mel_dim=80,
linear_dim=513,
r=4,
downsample_step=1,
n_speakers=1,
speaker_dim=16,
padding_idx=0,
dropout=(1 - 0.96),
filter_size=5,
encoder_channels=128,
decoder_channels=256,
converter_channels=256,
query_position_rate=1.0,
key_position_rate=1.29,
use_memory_mask=False,
trainable_positional_encodings=False,
force_monotonic_attention=True,
use_decoder_state_for_postnet_input=True,
max_positions=512,
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
freeze_embedding=False,
window_range=WindowRange(-1, 3),
key_projection=False,
value_projection=False):
time_upsampling = max(downsample_step, 1)
h = encoder_channels
k = filter_size
encoder_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3))
h = decoder_channels
prenet_convolutions = (ConvSpec(h, k, 1))
attentive_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1))
attention = [True, False, False, False, False]
h = converter_channels
postnet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(2 * h, k, 1), ConvSpec(2 * h, k, 3))
model = DeepVoiceTTS(
"dv3", n_speakers, speaker_dim, speaker_embedding_weight_std, n_vocab,
embed_dim, padding_idx, embedding_weight_std, freeze_embedding,
encoder_convolutions, max_positions, padding_idx,
trainable_positional_encodings, mel_dim, r, prenet_convolutions,
attentive_convolutions, attention, use_memory_mask,
force_monotonic_attention, query_position_rate, key_position_rate,
window_range, key_projection, value_projection, linear_dim,
postnet_convolutions, time_upsampling, dropout,
use_decoder_state_for_postnet_input, "float32")
return model
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
from deepvoice3_paddle.weight_norm import Conv2D, Conv2DTranspose
class Conv1D(dg.Layer):
"""
A convolution 1D block implemented with Conv2D. Form simplicity and
ensuring the output has the same length as the input, it does not allow
stride > 1.
"""
def __init__(self,
name_scope,
in_cahnnels,
num_filters,
filter_size=3,
dilation=1,
groups=None,
causal=False,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
dtype="float32"):
super(Conv1D, self).__init__(name_scope, dtype=dtype)
if causal:
padding = dilation * (filter_size - 1)
else:
padding = (dilation * (filter_size - 1)) // 2
self.in_channels = in_cahnnels
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
self.causal = causal
self.padding = padding
self.act = act
self.conv = Conv2D(
self.full_name(),
num_filters=num_filters,
filter_size=(1, filter_size),
stride=(1, 1),
dilation=(1, dilation),
padding=(0, padding),
groups=groups,
param_attr=param_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
def forward(self, x):
"""
Args:
x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
input channels.
Returns:
x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means
output channels (num_filters).
"""
x = self.conv(x)
if self.filter_size > 1:
if self.causal:
x = fluid.layers.slice(
x, axes=[3], starts=[0], ends=[-self.padding])
elif self.filter_size % 2 == 0:
x = fluid.layers.slice(x, axes=[3], starts=[0], ends=[-1])
return x
def start_new_sequence(self):
self.temp_weight = None
self.input_buffer = None
def add_input(self, x):
"""
Adding input for a time step and compute an output for a time step.
Args:
x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
input channels, and T = 1.
Returns:
out (Variable): Shape(B, C_out, 1, T), the outputs, where C_out
means output channels (num_filters), and T = 1.
"""
if self.temp_weight is None:
self.temp_weight = self._reshaped_weight()
window_size = 1 + (self.filter_size - 1) * self.dilation
batch_size = x.shape[0]
in_channels = x.shape[1]
if self.filter_size > 1:
if self.input_buffer is None:
self.input_buffer = fluid.layers.fill_constant(
[batch_size, in_channels, 1, window_size - 1],
dtype=x.dtype,
value=0.0)
else:
self.input_buffer = self.input_buffer[:, :, :, 1:]
self.input_buffer = fluid.layers.concat(
[self.input_buffer, x], axis=3)
x = self.input_buffer
if self.dilation > 1:
if not hasattr(self, "indices"):
self.indices = dg.to_variable(
np.arange(0, window_size, self.dilation))
tmp = fluid.layers.transpose(
self.input_buffer, perm=[3, 1, 2, 0])
tmp = fluid.layers.gather(tmp, index=self.indices)
tmp = fluid.layers.transpose(tmp, perm=[3, 1, 2, 0])
x = tmp
inputs = fluid.layers.reshape(
x, shape=[batch_size, in_channels * 1 * self.filter_size])
out = fluid.layers.matmul(inputs, self.temp_weight, transpose_y=True)
out = fluid.layers.elementwise_add(out, self.conv._bias_param, axis=-1)
out = fluid.layers.reshape(out, out.shape + [1, 1])
out = self._helper.append_activation(out, act=self.act)
return out
def _reshaped_weight(self):
"""
Get the linearized weight of convolution filter, cause it is by nature
a matmul weight. And because the model uses weight norm, compute the
weight by weight_v * weight_g to make it faster.
Returns:
weight_matrix (Variable): Shape(C_out, C_in * 1 * kernel_size)
"""
shape = self.conv._filter_param_v.shape
matrix_shape = [shape[0], np.prod(shape[1:])]
weight_matrix = fluid.layers.reshape(
self.conv._filter_param_v, shape=matrix_shape)
weight_matrix = fluid.layers.elementwise_mul(
fluid.layers.l2_normalize(
weight_matrix, axis=1),
self.conv._filter_param_g,
axis=0)
return weight_matrix
class Conv1DTranspose(dg.Layer):
"""
A convolutional transpose 1D block implemented with convolutional transpose
2D. It does not ensure that the output is exactly expanded stride times in
time dimension.
"""
def __init__(self,
name_scope,
in_channels,
num_filters,
filter_size,
padding=0,
stride=1,
dilation=1,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
dtype="float32"):
super(Conv1DTranspose, self).__init__(name_scope, dtype=dtype)
self.in_channels = in_channels
self.num_filters = num_filters
self.filter_size = filter_size
self.padding = padding
self.stride = stride
self.dilation = dilation
self.groups = groups
self.conv_transpose = Conv2DTranspose(
self.full_name(),
num_filters,
filter_size=(1, filter_size),
padding=(0, padding),
stride=(1, stride),
dilation=(1, dilation),
groups=groups,
param_attr=param_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
def forward(self, x):
"""
Argss:
x (Variable): Shape(B, C_in, 1, T_in), where C_in means the input
channels and T_in means the number of time steps of input.
Returns:
out (Variable): shape(B, C_out, 1, T_out), where C_out means the
output channels and T_out means the number of time steps of
input.
"""
return self.conv_transpose(x)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import platform
from os.path import dirname, join
from nnmnkwii.datasets import FileSourceDataset, FileDataSource
from os.path import join, expanduser
import random
# import global hyper parameters
from hparams import hparams
from deepvoice3_paddle import frontend, builder
_frontend = getattr(frontend, hparams.frontend)
def _pad(seq, max_len, constant_values=0):
return np.pad(seq, (0, max_len - len(seq)),
mode="constant",
constant_values=constant_values)
def _pad_2d(x, max_len, b_pad=0):
x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)],
mode="constant",
constant_values=0)
return x
class TextDataSource(FileDataSource):
def __init__(self, data_root, speaker_id=None):
self.data_root = data_root
self.speaker_ids = None
self.multi_speaker = False
# If not None, filter by speaker_id
self.speaker_id = speaker_id
def collect_files(self):
meta = join(self.data_root, "train.txt")
with open(meta, "rb") as f:
lines = f.readlines()
l = lines[0].decode("utf-8").split("|")
assert len(l) == 4 or len(l) == 5
self.multi_speaker = len(l) == 5
texts = list(map(lambda l: l.decode("utf-8").split("|")[3], lines))
if self.multi_speaker:
speaker_ids = list(
map(lambda l: int(l.decode("utf-8").split("|")[-1]), lines))
# Filter by speaker_id
# using multi-speaker dataset as a single speaker dataset
if self.speaker_id is not None:
indices = np.array(speaker_ids) == self.speaker_id
texts = list(np.array(texts)[indices])
self.multi_speaker = False
return texts
return texts, speaker_ids
else:
return texts
def collect_features(self, *args):
if self.multi_speaker:
text, speaker_id = args
else:
text = args[0]
global _frontend
if _frontend is None:
_frontend = getattr(frontend, hparams.frontend)
seq = _frontend.text_to_sequence(
text, p=hparams.replace_pronunciation_prob)
if platform.system() == "Windows":
if hasattr(hparams, "gc_probability"):
_frontend = None # memory leaking prevention in Windows
if np.random.rand() < hparams.gc_probability:
gc.collect() # garbage collection enforced
print("GC done")
if self.multi_speaker:
return np.asarray(seq, dtype=np.int32), int(speaker_id)
else:
return np.asarray(seq, dtype=np.int32)
class _NPYDataSource(FileDataSource):
def __init__(self, data_root, col, speaker_id=None):
self.data_root = data_root
self.col = col
self.frame_lengths = []
self.speaker_id = speaker_id
def collect_files(self):
meta = join(self.data_root, "train.txt")
with open(meta, "rb") as f:
lines = f.readlines()
l = lines[0].decode("utf-8").split("|")
assert len(l) == 4 or len(l) == 5
multi_speaker = len(l) == 5
self.frame_lengths = list(
map(lambda l: int(l.decode("utf-8").split("|")[2]), lines))
paths = list(
map(lambda l: l.decode("utf-8").split("|")[self.col], lines))
paths = list(map(lambda f: join(self.data_root, f), paths))
if multi_speaker and self.speaker_id is not None:
speaker_ids = list(
map(lambda l: int(l.decode("utf-8").split("|")[-1]), lines))
# Filter by speaker_id
# using multi-speaker dataset as a single speaker dataset
indices = np.array(speaker_ids) == self.speaker_id
paths = list(np.array(paths)[indices])
self.frame_lengths = list(np.array(self.frame_lengths)[indices])
# aha, need to cast numpy.int64 to int
self.frame_lengths = list(map(int, self.frame_lengths))
return paths
def collect_features(self, path):
return np.load(path)
class MelSpecDataSource(_NPYDataSource):
def __init__(self, data_root, speaker_id=None):
super(MelSpecDataSource, self).__init__(data_root, 1, speaker_id)
class LinearSpecDataSource(_NPYDataSource):
def __init__(self, data_root, speaker_id=None):
super(LinearSpecDataSource, self).__init__(data_root, 0, speaker_id)
class PartialyRandomizedSimilarTimeLengthSampler(object):
"""Partially randmoized sampler
1. Sort by lengths
2. Pick a small patch and randomize it
3. Permutate mini-batchs
"""
def __init__(self,
lengths,
batch_size=16,
batch_group_size=None,
permutate=True):
self.sorted_indices = np.argsort(lengths)
self.lengths = np.array(lengths)[self.sorted_indices]
self.batch_size = batch_size
if batch_group_size is None:
batch_group_size = min(batch_size * 32, len(self.lengths))
if batch_group_size % batch_size != 0:
batch_group_size -= batch_group_size % batch_size
self.batch_group_size = batch_group_size
assert batch_group_size % batch_size == 0
self.permutate = permutate
def __iter__(self):
indices = self.sorted_indices.copy()
batch_group_size = self.batch_group_size
s, e = 0, 0
for i in range(len(indices) // batch_group_size):
s = i * batch_group_size
e = s + batch_group_size
random.shuffle(indices[s:e])
# Permutate batches
if self.permutate:
perm = np.arange(len(indices[:e]) // self.batch_size)
random.shuffle(perm)
indices[:e] = indices[:e].reshape(
-1, self.batch_size)[perm, :].reshape(-1)
# Handle last elements
s += batch_group_size
if s < len(indices):
random.shuffle(indices[s:])
return iter(indices)
def __len__(self):
return len(self.sorted_indices)
class Dataset(object):
def __init__(self, X, Mel, Y):
self.X = X
self.Mel = Mel
self.Y = Y
# alias
self.multi_speaker = X.file_data_source.multi_speaker
def __getitem__(self, idx):
if self.multi_speaker:
text, speaker_id = self.X[idx]
return text, self.Mel[idx], self.Y[idx], speaker_id
else:
return self.X[idx], self.Mel[idx], self.Y[idx]
def __len__(self):
return len(self.X)
def make_loader(dataset, batch_size, shuffle, sampler, create_batch_fn,
trainer_count, local_rank):
assert not (
shuffle and
sampler), "shuffle and sampler should not be valid in the same time."
num_samples = len(dataset)
def wrapper():
if sampler is None:
ids = range(num_samples)
if shuffle:
random.shuffle(ids)
else:
ids = sampler
batch, batches = [], []
for idx in ids:
batch.append(dataset[idx])
if len(batch) >= batch_size:
batches.append(batch)
batch = []
if len(batches) >= trainer_count:
yield create_batch_fn(batches[local_rank])
batches = []
if len(batch) > 0:
batches.append(batch)
if len(batches) >= trainer_count:
yield create_batch_fn(batches[local_rank])
return wrapper
def create_batch(batch):
"""Create batch"""
r = hparams.outputs_per_step
downsample_step = hparams.downsample_step
multi_speaker = len(batch[0]) == 4
# Lengths
input_lengths = [len(x[0]) for x in batch]
max_input_len = max(input_lengths)
input_lengths = np.array(input_lengths, dtype=np.int64)
target_lengths = [len(x[1]) for x in batch]
max_target_len = max(target_lengths)
target_lengths = np.array(target_lengths, dtype=np.int64)
if max_target_len % (r * downsample_step) != 0:
max_target_len += (r * downsample_step) - max_target_len % (
r * downsample_step)
assert max_target_len % (r * downsample_step) == 0
# Set 0 for zero beginning padding
# imitates initial decoder states
b_pad = r
max_target_len += b_pad * downsample_step
x_batch = np.array(
[_pad(x[0], max_input_len) for x in batch], dtype=np.int64)
x_batch = np.expand_dims(x_batch, axis=-1)
mel_batch = np.array(
[_pad_2d(
x[1], max_target_len, b_pad=b_pad) for x in batch],
dtype=np.float32)
# down sampling is done here
if downsample_step > 1:
mel_batch = mel_batch[:, 0::downsample_step, :]
mel_batch = np.expand_dims(np.transpose(mel_batch, axes=[0, 2, 1]), axis=2)
y_batch = np.array(
[_pad_2d(
x[2], max_target_len, b_pad=b_pad) for x in batch],
dtype=np.float32)
y_batch = np.expand_dims(np.transpose(y_batch, axes=[0, 2, 1]), axis=2)
# text positions
text_positions = np.array(
[_pad(np.arange(1, len(x[0]) + 1), max_input_len) for x in batch],
dtype=np.int)
text_positions = np.expand_dims(text_positions, axis=-1)
max_decoder_target_len = max_target_len // r // downsample_step
# frame positions
s, e = 1, max_decoder_target_len + 1
frame_positions = np.tile(
np.expand_dims(
np.arange(s, e), axis=0), (len(batch), 1))
frame_positions = np.expand_dims(frame_positions, axis=-1)
# done flags
done = np.array([
_pad(
np.zeros(
len(x[1]) // r // downsample_step - 1, dtype=np.float32),
max_decoder_target_len,
constant_values=1) for x in batch
])
done = np.expand_dims(np.expand_dims(done, axis=1), axis=1)
if multi_speaker:
speaker_ids = np.expand_dims(np.array([x[3] for x in batch]), axis=-1)
return (x_batch, input_lengths, mel_batch, y_batch, text_positions,
frame_positions, done, target_lengths, speaker_ids)
else:
speaker_ids = None
return (x_batch, input_lengths, mel_batch, y_batch, text_positions,
frame_positions, done, target_lengths)
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
from hparams import hparams, hparams_debug_string
from deepvoice3_paddle import frontend
from deepvoice3_paddle.deepvoice3 import DeepVoiceTTS
def dry_run(model):
"""
Run the model once, just to get it initialized.
"""
model.train()
_frontend = getattr(frontend, hparams.frontend)
batch_size = 4
enc_length = 157
snd_sample_length = 500
r = hparams.outputs_per_step
downsample_step = hparams.downsample_step
n_speakers = hparams.n_speakers
# make sure snd_sample_length can be divided by r * downsample_step
linear_shift = r * downsample_step
snd_sample_length += linear_shift - snd_sample_length % linear_shift
decoder_length = snd_sample_length // downsample_step // r
mel_length = snd_sample_length // downsample_step
n_vocab = _frontend.n_vocab
max_pos = hparams.max_positions
spker_embed = hparams.speaker_embed_dim
linear_dim = model.linear_dim
mel_dim = hparams.num_mels
x = np.random.randint(
low=0, high=n_vocab, size=(batch_size, enc_length, 1), dtype="int64")
input_lengths = np.arange(
enc_length - batch_size + 1, enc_length + 1, dtype="int64")
mel = np.random.randn(batch_size, mel_dim, 1, mel_length).astype("float32")
y = np.random.randn(batch_size, linear_dim, 1,
snd_sample_length).astype("float32")
text_positions = np.tile(
np.arange(
0, enc_length, dtype="int64"), (batch_size, 1))
text_mask = text_positions > np.expand_dims(input_lengths, 1)
text_positions[text_mask] = 0
text_positions = np.expand_dims(text_positions, axis=-1)
frame_positions = np.tile(
np.arange(
1, decoder_length + 1, dtype="int64"), (batch_size, 1))
frame_positions = np.expand_dims(frame_positions, axis=-1)
done = np.zeros(shape=(batch_size, 1, 1, decoder_length), dtype="float32")
target_lengths = np.array([snd_sample_length] * batch_size).astype("int64")
speaker_ids = np.random.randint(
low=0, high=n_speakers, size=(batch_size, 1),
dtype="int64") if n_speakers > 1 else None
ismultispeaker = speaker_ids is not None
x = dg.to_variable(x)
input_lengths = dg.to_variable(input_lengths)
mel = dg.to_variable(mel)
y = dg.to_variable(y)
text_positions = dg.to_variable(text_positions)
frame_positions = dg.to_variable(frame_positions)
done = dg.to_variable(done)
target_lengths = dg.to_variable(target_lengths)
speaker_ids = dg.to_variable(
speaker_ids) if speaker_ids is not None else None
# these two fields are used as numpy ndarray
text_lengths = input_lengths.numpy()
decoder_lengths = target_lengths.numpy() // r // downsample_step
max_seq_len = max(text_lengths.max(), decoder_lengths.max())
if max_seq_len >= hparams.max_positions:
raise RuntimeError(
"max_seq_len ({}) >= max_posision ({})\n"
"Input text or decoder targget length exceeded the maximum length.\n"
"Please set a larger value for ``max_position`` in hyper parameters."
.format(max_seq_len, hparams.max_positions))
# cause paddle's embedding layer expect shape[-1] == 1
# first dry run runs the whole model
mel_outputs, linear_outputs, attn, done_hat = model(
x, input_lengths, mel, speaker_ids, text_positions, frame_positions)
num_parameters = 0
for k, v in model.state_dict().items():
print("{}|{}|{}".format(k, v.shape, np.prod(v.shape)))
num_parameters += np.prod(v.shape)
print("now model has {} parameters".format(len(model.state_dict())))
print("now model has {} elements".format(num_parameters))
This package is adapted from https://github.com/r9y9/deepvoice3_pytorch/tree/master/deepvoice3_pytorch/frontend, Copyright (c) 2017: Ryuichi Yamamoto, whose license applies.
# coding: utf-8
"""Text processing frontend
All frontend module should have the following functions:
- text_to_sequence(text, p)
- sequence_to_text(sequence)
and the property:
- n_vocab
"""
from deepvoice3_paddle.frontend import en
# optinoal Japanese frontend
try:
from deepvoice3_paddle.frontend import jp
except ImportError:
jp = None
try:
from deepvoice3_paddle.frontend import ko
except ImportError:
ko = None
# if you are going to use the frontend, you need to modify _characters in
# symbol.py:
# _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' + '¡¿ñáéíóúÁÉÍÓÚÑ'
try:
from deepvoice3_paddle.frontend import es
except ImportError:
es = None
# coding: utf-8
from deepvoice3_paddle.frontend.text.symbols import symbols
import nltk
from random import random
n_vocab = len(symbols)
_arpabet = nltk.corpus.cmudict.dict()
def _maybe_get_arpabet(word, p):
try:
phonemes = _arpabet[word][0]
phonemes = " ".join(phonemes)
except KeyError:
return word
return '{%s}' % phonemes if random() < p else word
def mix_pronunciation(text, p):
text = ' '.join(_maybe_get_arpabet(word, p) for word in text.split(' '))
return text
def text_to_sequence(text, p=0.0):
if p >= 0:
text = mix_pronunciation(text, p)
from deepvoice3_paddle.frontend.text import text_to_sequence
text = text_to_sequence(text, ["english_cleaners"])
return text
from deepvoice3_paddle.frontend.text import sequence_to_text
# coding: utf-8
from deepvoice3_paddle.frontend.text.symbols import symbols
import nltk
from random import random
n_vocab = len(symbols)
def text_to_sequence(text, p=0.0):
from deepvoice3_paddle.frontend.text import text_to_sequence
text = text_to_sequence(text, ["basic_cleaners"])
return text
from deepvoice3_paddle.frontend.text import sequence_to_text
# coding: utf-8
import MeCab
import jaconv
from random import random
n_vocab = 0xffff
_eos = 1
_pad = 0
_tagger = None
def _yomi(mecab_result):
tokens = []
yomis = []
for line in mecab_result.split("\n")[:-1]:
s = line.split("\t")
if len(s) == 1:
break
token, rest = s
rest = rest.split(",")
tokens.append(token)
yomi = rest[7] if len(rest) > 7 else None
yomi = None if yomi == "*" else yomi
yomis.append(yomi)
return tokens, yomis
def _mix_pronunciation(tokens, yomis, p):
return "".join(yomis[idx]
if yomis[idx] is not None and random() < p else tokens[idx]
for idx in range(len(tokens)))
def mix_pronunciation(text, p):
global _tagger
if _tagger is None:
_tagger = MeCab.Tagger("")
tokens, yomis = _yomi(_tagger.parse(text))
return _mix_pronunciation(tokens, yomis, p)
def add_punctuation(text):
last = text[-1]
if last not in [".", ",", "、", "。", "!", "?", "!", "?"]:
text = text + "。"
return text
def normalize_delimitor(text):
text = text.replace(",", "、")
text = text.replace(".", "。")
text = text.replace(",", "、")
text = text.replace(".", "。")
return text
def text_to_sequence(text, p=0.0):
for c in [" ", " ", "「", "」", "『", "』", "・", "【", "】", "(", ")", "(", ")"]:
text = text.replace(c, "")
text = text.replace("!", "!")
text = text.replace("?", "?")
text = normalize_delimitor(text)
text = jaconv.normalize(text)
if p > 0:
text = mix_pronunciation(text, p)
text = jaconv.hira2kata(text)
text = add_punctuation(text)
return [ord(c) for c in text] + [_eos] # EOS
def sequence_to_text(seq):
return "".join(chr(n) for n in seq)
# coding: utf-8
from random import random
n_vocab = 0xffff
_eos = 1
_pad = 0
_tagger = None
def text_to_sequence(text, p=0.0):
return [ord(c) for c in text] + [_eos] # EOS
def sequence_to_text(seq):
return "".join(chr(n) for n in seq)
import re
from deepvoice3_paddle.frontend.text import cleaners
from deepvoice3_paddle.frontend.text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def text_to_sequence(text, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
# Append EOS token
sequence.append(_symbol_to_id['~'])
return sequence
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
result += s
return result.replace('}{', ' ')
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
def _symbols_to_sequence(symbols):
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
def _arpabet_to_sequence(text):
return _symbols_to_sequence(['@' + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s is not '_' and s is not '~'
'''
Cleaners are transformations that run over the input text at both training and
eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as
the "cleaners" hyperparameter. Some cleaners are English-specific. You'll
typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated
to ASCII using the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you
should also update the symbols in symbols.py to match your data).
'''
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def add_punctuation(text):
if len(text) == 0:
return text
if text[-1] not in '!,.:;?':
text = text + '.' # without this decoder is confused when to output EOS
return text
def basic_cleaners(text):
'''
Basic pipeline that lowercases and collapses whitespace without
transliteration.
'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
'''
Pipeline for English text, including number and abbreviation expansion.
'''
text = convert_to_ascii(text)
text = add_punctuation(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
return text
import re
valid_symbols = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1',
'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T',
'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y',
'Z', 'ZH'
]
_valid_symbol_set = set(valid_symbols)
class CMUDict:
'''
Thin wrapper around CMUDict data.
http://www.speech.cs.cmu.edu/cgi-bin/cmudict
'''
def __init__(self, file_or_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
with open(file_or_path, encoding='latin-1') as f:
entries = _parse_cmudict(f)
else:
entries = _parse_cmudict(file_or_path)
if not keep_ambiguous:
entries = {
word: pron
for word, pron in entries.items() if len(pron) == 1
}
self._entries = entries
def __len__(self):
return len(self._entries)
def lookup(self, word):
'''Returns list of ARPAbet pronunciations of the given word.'''
return self._entries.get(word.upper())
_alt_re = re.compile(r'\([0-9]+\)')
def _parse_cmudict(file):
cmudict = {}
for line in file:
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
parts = line.split(' ')
word = re.sub(_alt_re, '', parts[0])
pronunciation = _get_pronunciation(parts[1])
if pronunciation:
if word in cmudict:
cmudict[word].append(pronunciation)
else:
cmudict[word] = [pronunciation]
return cmudict
def _get_pronunciation(s):
parts = s.strip().split(' ')
for part in parts:
if part not in _valid_symbol_set:
return None
return ' '.join(parts)
# -*- coding: utf-8 -*-
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(
num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
'''
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text
that has been run through Unidecode. For other data, you can modify _characters.
See TRAINING_DATA.md for details.
'''
from .cmudict import valid_symbols
_pad = '_'
_eos = '~'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from numba import jit
from paddle import fluid
import paddle.fluid.dygraph as dg
def masked_mean(inputs, mask):
"""
Args:
inputs (Variable): Shape(B, C, 1, T), the input, where B means
batch size, C means channels of input, T means timesteps of
the input.
mask (Variable): Shape(B, T), a mask.
Returns:
loss (Variable): Shape(1, ), masked mean.
"""
channels = inputs.shape[1]
reshaped_mask = fluid.layers.reshape(
mask, shape=[mask.shape[0], 1, 1, mask.shape[-1]])
expanded_mask = fluid.layers.expand(
reshaped_mask, expand_times=[1, channels, 1, 1])
expanded_mask.stop_gradient = True
valid_cnt = fluid.layers.reduce_sum(expanded_mask)
valid_cnt.stop_gradient = True
masked_inputs = inputs * expanded_mask
loss = fluid.layers.reduce_sum(masked_inputs) / valid_cnt
return loss
@jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g):
W = np.zeros((max_N, max_T), dtype=np.float32)
for n in range(N):
for t in range(T):
W[n, t] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
return W
def guided_attentions(input_lengths, target_lengths, max_target_len, g=0.2):
B = len(input_lengths)
max_input_len = input_lengths.max()
W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
for b in range(B):
W[b] = guided_attention(input_lengths[b], max_input_len,
target_lengths[b], max_target_len, g).T
return W
class TTSLoss(object):
def __init__(self,
masked_weight=0.0,
priority_weight=0.0,
binary_divergence_weight=0.0,
guided_attention_sigma=0.2):
self.masked_weight = masked_weight
self.priority_weight = priority_weight
self.binary_divergence_weight = binary_divergence_weight
self.guided_attention_sigma = guided_attention_sigma
def l1_loss(self, prediction, target, mask, priority_bin=None):
abs_diff = fluid.layers.abs(prediction - target)
# basic mask-weighted l1 loss
w = self.masked_weight
if w > 0 and mask is not None:
base_l1_loss = w * masked_mean(abs_diff, mask) + (
1 - w) * fluid.layers.reduce_mean(abs_diff)
else:
base_l1_loss = fluid.layers.reduce_mean(abs_diff)
if self.priority_weight > 0 and priority_bin is not None:
# mask-weighted priority channels' l1-loss
priority_abs_diff = fluid.layers.slice(
abs_diff, axes=[1], starts=[0], ends=[priority_bin])
if w > 0 and mask is not None:
priority_loss = w * masked_mean(priority_abs_diff, mask) + (
1 - w) * fluid.layers.reduce_mean(priority_abs_diff)
else:
priority_loss = fluid.layers.reduce_mean(priority_abs_diff)
# priority weighted sum
p = self.priority_weight
loss = p * priority_loss + (1 - p) * base_l1_loss
else:
loss = base_l1_loss
return loss
def binary_divergence(self, prediction, target, mask):
flattened_prediction = fluid.layers.reshape(prediction, [-1, 1])
flattened_target = fluid.layers.reshape(target, [-1, 1])
flattened_loss = fluid.layers.log_loss(
flattened_prediction, flattened_target, epsilon=1e-8)
bin_div = fluid.layers.reshape(flattened_loss, prediction.shape)
w = self.masked_weight
if w > 0 and mask is not None:
loss = w * masked_mean(bin_div, mask) + (
1 - w) * fluid.layers.reduce_mean(bin_div)
else:
loss = fluid.layers.reduce_mean(bin_div)
return loss
@staticmethod
def done_loss(done_hat, done):
flat_done_hat = fluid.layers.reshape(done_hat, [-1, 1])
flat_done = fluid.layers.reshape(done, [-1, 1])
loss = fluid.layers.log_loss(flat_done_hat, flat_done, epsilon=1e-8)
loss = fluid.layers.reduce_mean(loss)
return loss
def attention_loss(self, predicted_attention, input_lengths,
target_lengths):
"""
Given valid encoder_lengths and decoder_lengths, compute a diagonal
guide, and compute loss from the predicted attention and the guide.
Args:
predicted_attention (Variable): Shape(*, B, T_dec, T_enc), the
alignment tensor, where B means batch size, T_dec means number
of time steps of the decoder, T_enc means the number of time
steps of the encoder, * means other possible dimensions.
input_lengths (numpy.ndarray): Shape(B,), dtype:int64, valid lengths
(time steps) of encoder outputs.
target_lengths (numpy.ndarray): Shape(batch_size,), dtype:int64,
valid lengths (time steps) of decoder outputs.
Returns:
loss (Variable): Shape(1, ) attention loss.
"""
n_attention, batch_size, max_target_len, max_input_len = (
predicted_attention.shape)
soft_mask = guided_attentions(input_lengths, target_lengths,
max_target_len,
self.guided_attention_sigma)
soft_mask_ = dg.to_variable(soft_mask)
loss = fluid.layers.reduce_mean(predicted_attention * soft_mask_)
return loss
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
import numpy as np
import deepvoice3_paddle.conv as conv
import deepvoice3_paddle.weight_norm as weight_norm
def FC(name_scope,
in_features,
size,
num_flatten_dims=1,
dropout=0.0,
epsilon=1e-30,
act=None,
is_test=False,
dtype="float32"):
"""
A special Linear Layer, when it is used with dropout, the weight is
initialized as normal(0, std=np.sqrt((1-dropout) / in_features))
"""
# stds
if isinstance(in_features, int):
in_features = [in_features]
stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features]
weight_inits = [
fluid.initializer.NormalInitializer(scale=std) for std in stds
]
bias_init = fluid.initializer.ConstantInitializer(0.0)
# param attrs
weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits]
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = weight_norm.FC(name_scope,
size,
num_flatten_dims=num_flatten_dims,
param_attr=weight_attrs,
bias_attr=bias_attr,
act=act,
dtype=dtype)
return layer
def Conv1D(name_scope,
in_channels,
num_filters,
filter_size=3,
dilation=1,
groups=None,
causal=False,
std_mul=1.0,
dropout=0.0,
use_cudnn=True,
act=None,
dtype="float32"):
"""
A special Conv1D Layer, when it is used with dropout, the weight is
initialized as
normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_features)))
"""
# std
std = np.sqrt((std_mul * (1 - dropout)) / (filter_size * in_channels))
weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std)
bias_init = fluid.initializer.ConstantInitializer(0.0)
# param attrs
weight_attr = fluid.ParamAttr(initializer=weight_init)
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = conv.Conv1D(
name_scope,
in_channels,
num_filters,
filter_size,
dilation,
groups=groups,
causal=causal,
param_attr=weight_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer
def Embedding(name_scope,
num_embeddings,
embed_dim,
is_sparse=False,
is_distributed=False,
padding_idx=None,
std=0.01,
dtype="float32"):
# param attrs
weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=std))
layer = dg.Embedding(
name_scope, (num_embeddings, embed_dim),
padding_idx=padding_idx,
param_attr=weight_attr,
dtype=dtype)
return layer
class Conv1DGLU(dg.Layer):
"""
A Convolution 1D block with GLU activation. It also applys dropout for the
input x. It fuses speaker embeddings through a FC activated by softsign. It
has residual connection from the input x, and scale the output by
np.sqrt(0.5).
"""
def __init__(self,
name_scope,
n_speakers,
speaker_dim,
in_channels,
num_filters,
filter_size,
dilation,
std_mul=4.0,
dropout=0.0,
causal=False,
residual=True,
dtype="float32"):
super(Conv1DGLU, self).__init__(name_scope, dtype=dtype)
# conv spec
self.in_channels = in_channels
self.n_speakers = n_speakers
self.speaker_dim = speaker_dim
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
self.causal = causal
self.residual = residual
# weight init and dropout
self.std_mul = std_mul
self.dropout = dropout
if residual:
assert (
in_channels == num_filters
), "this block uses residual connection"\
"the input_channes should equals num_filters"
self.conv = Conv1D(
self.full_name(),
in_channels,
2 * num_filters,
filter_size,
dilation,
causal=causal,
std_mul=std_mul,
dropout=dropout,
dtype=dtype)
if n_speakers > 1:
assert (speaker_dim is not None
), "speaker embed should not be null in multi-speaker case"
self.fc = Conv1D(
self.full_name(),
speaker_dim,
num_filters,
filter_size=1,
dilation=1,
causal=False,
act="softsign",
dtype=dtype)
def forward(self, x, speaker_embed_bc1t=None):
"""
Args:
x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU
layer, where B means batch_size, C_in means the input channels
T means input time steps.
speaker_embed_bct1 (Variable): Shape(B, C_sp, 1, T), expanded
speaker embed, where C_sp means speaker embedding size. Note
that when using residual connection, the Conv1DGLU does not
change the number of channels, so out channels equals input
channels.
Returns:
x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where
C_out means the output channels of Conv1DGLU.
"""
residual = x
x = fluid.layers.dropout(
x, self.dropout, dropout_implementation="upscale_in_train")
x = self.conv(x)
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
if speaker_embed_bc1t is not None:
sp = self.fc(speaker_embed_bc1t)
content = content + sp
# glu
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
if self.residual:
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
return x
def add_input(self, x, speaker_embed_bc11=None):
"""
Inputs:
x: shape(B, num_filters, 1, time_steps)
speaker_embed_bc11: shape(B, speaker_dim, 1, time_steps)
Outputs:
out: shape(B, num_filters, 1, time_steps), where time_steps = 1
"""
residual = x
# add step input and produce step output
x = fluid.layers.dropout(
x, self.dropout, dropout_implementation="upscale_in_train")
x = self.conv.add_input(x)
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
if speaker_embed_bc11 is not None:
sp = self.fc(speaker_embed_bc11)
content = content + sp
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
if self.residual:
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
return x
def Conv1DTranspose(name_scope,
in_channels,
num_filters,
filter_size,
padding=0,
stride=1,
dilation=1,
groups=None,
std_mul=1.0,
dropout=0.0,
use_cudnn=True,
act=None,
dtype="float32"):
std = np.sqrt(std_mul * (1 - dropout) / (in_channels * filter_size))
weight_init = fluid.initializer.NormalInitializer(scale=std)
weight_attr = fluid.ParamAttr(initializer=weight_init)
bias_init = fluid.initializer.ConstantInitializer(0.0)
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = conv.Conv1DTranspose(
name_scope,
in_channels,
num_filters,
filter_size,
padding=padding,
stride=stride,
dilation=dilation,
groups=groups,
param_attr=weight_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer
def compute_position_embedding(rad):
# rad is a transposed radius, shape(embed_dim, n_vocab)
embed_dim, n_vocab = rad.shape
even_dims = dg.to_variable(np.arange(0, embed_dim, 2).astype("int32"))
odd_dims = dg.to_variable(np.arange(1, embed_dim, 2).astype("int32"))
even_rads = fluid.layers.gather(rad, even_dims)
odd_rads = fluid.layers.gather(rad, odd_dims)
sines = fluid.layers.sin(even_rads)
cosines = fluid.layers.cos(odd_rads)
temp = fluid.layers.scatter(rad, even_dims, sines)
out = fluid.layers.scatter(temp, odd_dims, cosines)
out = fluid.layers.transpose(out, perm=[1, 0])
return out
def position_encoding_init(n_position,
d_pos_vec,
position_rate=1.0,
sinusoidal=True):
""" Init the sinusoid position encoding table """
# keep idx 0 for padding token position encoding zero vector
position_enc = np.array([[
position_rate * pos / np.power(10000, 2 * (i // 2) / d_pos_vec)
for i in range(d_pos_vec)
] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
if sinusoidal:
position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
return position_enc
class PositionEmbedding(dg.Layer):
def __init__(self,
name_scope,
n_position,
d_pos_vec,
position_rate=1.0,
is_sparse=False,
is_distributed=False,
param_attr=None,
max_norm=None,
padding_idx=None,
dtype="float32"):
super(PositionEmbedding, self).__init__(name_scope, dtype=dtype)
self.embed = dg.Embedding(
self.full_name(),
size=(n_position, d_pos_vec),
is_sparse=is_sparse,
is_distributed=is_distributed,
padding_idx=None,
param_attr=param_attr,
dtype=dtype)
self.set_weight(
position_encoding_init(
n_position,
d_pos_vec,
position_rate=position_rate,
sinusoidal=False).astype(dtype))
self._is_sparse = is_sparse
self._is_distributed = is_distributed
self._remote_prefetch = self._is_sparse and (not self._is_distributed)
if self._remote_prefetch:
assert self._is_sparse is True and self._is_distributed is False
self._padding_idx = (-1 if padding_idx is None else padding_idx if
padding_idx >= 0 else (n_position + padding_idx))
self._position_rate = position_rate
self._max_norm = max_norm
self._dtype = dtype
def set_weight(self, array):
assert self.embed._w.shape == list(array.shape), "shape does not match"
self.embed._w._ivar.value().get_tensor().set(
array, fluid.framework._current_expected_place())
def forward(self, indices, speaker_position_rate=None):
"""
Args:
indices (Variable): Shape (B, T, 1), dtype: int64, position
indices, where B means the batch size, T means the time steps.
speaker_position_rate (Variable | float, optional), position
rate. It can be a float point number or a Variable with
shape (1,), then this speaker_position_rate is used for every
example. It can also be a Variable with shape (B, 1), which
contains a speaker position rate for each speaker.
Returns:
out (Variable): Shape(B, C_pos), position embedding, where C_pos
means position embedding size.
"""
rad = fluid.layers.transpose(self.embed._w, perm=[1, 0])
batch_size = indices.shape[0]
if speaker_position_rate is None:
weight = compute_position_embedding(rad)
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="lookup_table",
inputs={"Ids": indices,
"W": weight},
outputs={"Out": out},
attrs={
"is_sparse": self._is_sparse,
"is_distributed": self._is_distributed,
"remote_prefetch": self._remote_prefetch,
"padding_idx":
self._padding_idx, # special value for lookup table op
})
return out
elif (np.isscalar(speaker_position_rate) or
isinstance(speaker_position_rate, fluid.framework.Variable) and
speaker_position_rate.shape == [1, 1]):
# # make a weight
# scale the weight (the operand for sin & cos)
if np.isscalar(speaker_position_rate):
scaled_rad = fluid.layers.scale(rad, speaker_position_rate)
else:
scaled_rad = fluid.layers.elementwise_mul(
rad, speaker_position_rate[0])
weight = compute_position_embedding(scaled_rad)
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="lookup_table",
inputs={"Ids": indices,
"W": weight},
outputs={"Out": out},
attrs={
"is_sparse": self._is_sparse,
"is_distributed": self._is_distributed,
"remote_prefetch": self._remote_prefetch,
"padding_idx":
self._padding_idx, # special value for lookup table op
})
return out
elif np.prod(speaker_position_rate.shape) > 1:
assert speaker_position_rate.shape == [batch_size, 1]
outputs = []
for i in range(batch_size):
rate = speaker_position_rate[i] # rate has shape [1]
scaled_rad = fluid.layers.elementwise_mul(rad, rate)
weight = compute_position_embedding(scaled_rad)
out = self._helper.create_variable_for_type_inference(
self._dtype)
sequence = indices[i]
self._helper.append_op(
type="lookup_table",
inputs={"Ids": sequence,
"W": weight},
outputs={"Out": out},
attrs={
"is_sparse": self._is_sparse,
"is_distributed": self._is_distributed,
"remote_prefetch": self._remote_prefetch,
"padding_idx": -1,
})
outputs.append(out)
out = fluid.layers.stack(outputs)
return out
else:
raise Exception("Then you can just use position rate at init")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from os.path import dirname, join
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
def _load(checkpoint_path):
"""
Load saved state dict and optimizer state(optional).
"""
state_dict, optimizer_state = dg.load_persistables(dirname=checkpoint_path)
return state_dict, optimizer_state
def load_checkpoint(path, model, optimizer=None, reset_optimizer=True):
"""
layers like FC, Conv*, ... the Layer does not initialize their parameters
before first run.
1. if you want to load only a part of a saved whole model, to part of an
existing model, just pass the part as the target model , and path of the
saved whole model as source path.
2. if you want to load exactly from what is saved, just passed the model
and path as expected.
The rule of thumb is:
1. loading to a model works with name, a unique global name.
2. loading from a directory works with file structure, each parameter is
saved in a file. Loading a file from directory A/ would `create` a
corresponding Variable for each saved parameter, whose name is the file's
relative path from directory A/.
"""
print("Load checkpoint from: {}".format(path))
state_dict, optimizer_state = _load(path)
model.load_dict(state_dict)
if not reset_optimizer and optimizer is not None:
if optimizer_state is not None:
print("[loading] Load optimizer state from {}".format(path))
optimizer.load(optimizer_state)
return model
def _load_embedding(path, model):
print("[loading] Loading embedding from {}".format(path))
state_dict, optimizer_state = _load(path)
key = os.path.join(model.full_name(), "ConvS2S_0/Encoder_0/Embedding_0.w_0")
tensor = model.state_dict()[key]._ivar.value().get_tensor()
tensor.set(state_dict[key], fluid.framework._current_expected_place())
def save_checkpoint(model, optimizer, checkpoint_dir, global_step):
checkpoint_path = join(checkpoint_dir,
"checkpoint_step{:09d}.model".format(global_step))
dg.save_persistables(
model.state_dict(), dirname=checkpoint_path, optimizers=optimizer)
print("[checkpoint] Saved checkpoint:", checkpoint_path)
此差异已折叠。
此差异已折叠。
Source: hparam.py copied from tensorflow v1.12.0.
https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
with the following:
wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
Once all other tensorflow dependencies of these file are removed, the class keeps its goal. Functions not available due to this process are not used in this project.
此差异已折叠。
此差异已折叠。
此差异已折叠。
{
"name": "deepvoice3",
"frontend": "en",
"replace_pronunciation_prob": 0.5,
"builder": "deepvoice3",
"n_speakers": 1,
"speaker_embed_dim": 16,
"num_mels": 80,
"fmin": 125,
"fmax": 7600,
"fft_size": 1024,
"hop_size": 256,
"sample_rate": 22050,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"rescaling": false,
"rescaling_max": 0.999,
"allow_clipping_in_normalization": true,
"downsample_step": 4,
"outputs_per_step": 1,
"embedding_weight_std": 0.1,
"speaker_embedding_weight_std": 0.01,
"padding_idx": 0,
"max_positions": 512,
"dropout": 0.050000000000000044,
"kernel_size": 3,
"text_embed_dim": 256,
"encoder_channels": 512,
"decoder_channels": 256,
"converter_channels": 256,
"query_position_rate": 1.0,
"key_position_rate": 1.385,
"key_projection": true,
"value_projection": true,
"use_memory_mask": true,
"trainable_positional_encodings": false,
"freeze_embedding": false,
"use_decoder_state_for_postnet_input": true,
"pin_memory": true,
"num_workers": 2,
"masked_loss_weight": 0.5,
"priority_freq": 3000,
"priority_freq_weight": 0.0,
"binary_divergence_weight": 0.1,
"use_guided_attention": true,
"guided_attention_sigma": 0.2,
"batch_size": 16,
"adam_beta1": 0.5,
"adam_beta2": 0.9,
"adam_eps": 1e-06,
"initial_learning_rate": 0.0005,
"lr_schedule": "noam_learning_rate_decay",
"lr_schedule_kwargs": {},
"nepochs": 2000,
"weight_decay": 0.0,
"clip_thresh": 0.1,
"checkpoint_interval": 10000,
"eval_interval": 10000,
"save_optimizer_state": true,
"force_monotonic_attention": true,
"window_ahead": 3,
"window_backward": 1,
"power": 1.4
}
numba==0.45.1
numpy==1.16.4
nltk==3.4.4
scipy
unidecode==1.1.1
inflect==2.1.0
librosa==0.7.0
tqdm==4.35.0
tensorboardX==1.8
matplotlib
requests==2.22.0
lws
nnmnkwii
tensorboard
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册