diff --git a/docs/source/released_model.md b/docs/source/released_model.md index a7c6a036b455410cc7d88947ef6c99d7b867924c..78f5c92f0ed0aae46c5ec3a82d59c10bf4360064 100644 --- a/docs/source/released_model.md +++ b/docs/source/released_model.md @@ -3,14 +3,15 @@ ## Speech-to-Text Models ### Acoustic Model Released in paddle 2.X -Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech -:-------------:| :------------:| :-----: | -----: | :----------------- |:--------- | :---------- | :--------- -[Ds2 Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.2/aishell/s0/ds2_online_aishll_CER8.02_release.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.080218 |-| 151 h -[Ds2 Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds2.offline.cer6p65.release.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.065 |-| 151 h -[Conformer Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz) | Aishell Dataset | Char-based | 283 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention + CTC | 0.0594 |-| 151 h -[Conformer Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.release.tar.gz) | Aishell Dataset | Char-based | 284 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0547 |-| 151 h -[Conformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/conformer.release.tar.gz) | Librispeech Dataset | Word-based | 287 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention |-| 0.0325 | 960 h -[Transformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/transformer.release.tar.gz) | Librispeech Dataset | Word-based | 195 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention |-| 0.0544 | 960 h +Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | example link +:-------------:| :------------:| :-----: | -----: | :----------------- |:--------- | :---------- | :--------- | :----------- +[Ds2 Online Aishell S0 Model](https://deepspeech.bj.bcebos.com/release2.2/aishell/s0/ds2_online_aishll_CER8.02_release.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.080218 |-| 151 h | [D2 Online Aishell S0 Example](../../examples/aishell/s0) +[Ds2 Offline Aishell S0 Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds2.offline.cer6p65.release.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.065 |-| 151 h | [Ds2 Offline Aishell S0 Example](../../examples/aishell/s0) +[Conformer Online Aishell S1 Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz) | Aishell Dataset | Char-based | 283 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0594 |-| 151 h | [Conformer Online Aishell S1 Example](../../examples/aishell/s1) +[Conformer Offline Aishell S1 Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.release.tar.gz) | Aishell Dataset | Char-based | 284 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0547 |-| 151 h | [Conformer Offline Aishell S1 Example](../../examples/aishell/s1) +[Conformer Librispeech S1 Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/conformer.release.tar.gz) | Librispeech Dataset | subword-based | 287 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0325 | 960 h | [Conformer Librispeech S1 example](../../example/librispeech/s1) +[Transformer Librispeech S1 Model](https://deepspeech.bj.bcebos.com/release2.2/librispeech/s1/librispeech.s1.transformer.all.wer5p62.release.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0456 | 960 h | [Transformer Librispeech S1 example](../../example/librispeech/s1) +[Transformer Librispeech S2 Model](https://deepspeech.bj.bcebos.com/release2.2/librispeech/s2/libri_transformer_espnet_wer3p84.release.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention |-| 0.0384 | 960 h | [Transformer Librispeech S2 example](../../example/librispeech/s2) ### Acoustic Model Transformed from paddle 1.8 Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech diff --git a/examples/aishell3/vc1/README.md b/examples/aishell3/vc1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8c0aec3af0e46f3829cb999c0a177ad8213ae4f7 --- /dev/null +++ b/examples/aishell3/vc1/README.md @@ -0,0 +1,89 @@ +# FastSpeech2 + AISHELL-3 Voice Cloning +This example contains code used to train a [Tacotron2 ](https://arxiv.org/abs/1712.05884) model with [AISHELL-3](http://www.aishelltech.com/aishell_3). The trained model can be used in Voice Cloning Task, We refer to the model structure of [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf) . The general steps are as follows: +1. Speaker Encoder: We use a Speaker Verification to train a speaker encoder. Datasets used in this task are different from those used in Tacotron2, because the transcriptions are not needed, we use more datasets, refer to [ge2e](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/ge2e). +2. Synthesizer: Then, we use the trained speaker encoder to generate utterance embedding for each sentence in AISHELL-3. This embedding is a extra input of Tacotron2 which will be concated with encoder outputs. +3. Vocoder: We use WaveFlow as the neural Vocoder, refer to [waveflow](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0). + +## Get Started +Assume the path to the dataset is `~/datasets/data_aishell3`. +Assume the path to the MFA result of AISHELL-3 is `./alignment`. +Assume the path to the pretrained ge2e model is `ge2e_ckpt_path=./ge2e_ckpt_0.3/step-3000000` +Run the command below to +1. **source path**. +2. preprocess the dataset, +3. train the model. +4. start a voice cloning inference. +```bash +./run.sh +``` +### Preprocess the dataset +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${input} ${preprocess_path} ${alignment} ${ge2e_ckpt_path} +``` +#### generate utterance embedding + Use pretrained GE2E (speaker encoder) to generate utterance embedding for each sentence in AISHELL-3, which has the same file structure with wav files and the format is `.npy`. + +```bash +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/../ge2e/inference.py \ + --input=${input} \ + --output=${preprocess_path}/embed \ + --ngpu=1 \ + --checkpoint_path=${ge2e_ckpt_path} +fi +``` + +The computing time of utterance embedding can be x hours. +#### process wav +There are silence in the edge of AISHELL-3's wavs, and the audio amplitude is very small, so, we need to remove the silence and normalize the audio. You can the silence remove method based on volume or energy, but the effect is not very good, We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get the alignment of text and speech, then utilize the alignment results to remove the silence. + +We use Montreal Force Aligner 1.0. The label in aishell3 include pinyin,so the lexicon we provided to MFA is pinyin rather than Chinese characters. And the prosody marks(`$` and `%`) need to be removed. You shoud preprocess the dataset into the format which MFA needs, the texts have the same name with wavs and have the suffix `.lab`. + +We use [lexicon.txt](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/t2s/exps/voice_cloning/tacotron2_ge2e/lexicon.txt) as the lexicon. + +You can download the alignment results from here [alignment_aishell3.tar.gz](https://paddlespeech.bj.bcebos.com/Parakeet/alignment_aishell3.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) (use MFA1.x now) of our repo. + +```bash +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "Process wav ..." + python3 ${BIN_DIR}/process_wav.py \ + --input=${input}/wav \ + --output=${preprocess_path}/normalized_wav \ + --alignment=${alignment} +fi +``` + +#### preprocess transcription +We revert the transcription into `phones` and `tones`. It is worth noting that our processing here is different from that used for MFA, we separated the tones. This is a processing method, of course, you can only segment initials and vowels. + +```bash +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + python3 ${BIN_DIR}/preprocess_transcription.py \ + --input=${input} \ + --output=${preprocess_path} +fi +``` +The default input is `~/datasets/data_aishell3/train`,which contains `label_train-set.txt`, the processed results are `metadata.yaml` and `metadata.pickle`. the former is a text format for easy viewing, and the latter is a binary format for direct reading. +#### extract mel +```python +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + python3 ${BIN_DIR}/extract_mel.py \ + --input=${preprocess_path}/normalized_wav \ + --output=${preprocess_path}/mel +fi +``` + +### Train the model +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${preprocess_path} ${train_output_path} +``` + +Our model remve stop token prediction in Tacotron2, because of the problem of extremely unbalanced proportion of positive and negative samples of stop token prediction, and it's very sensitive to the clip of audio silence. We use the last symbol from the highest point of attention to the encoder side as the termination condition. + +In addition, in order to accelerate the convergence of the model, we add `guided attention loss` to induce the alignment between encoder and decoder to show diagonal lines faster. +### Infernece +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${ge2e_params_path} ${tacotron2_params_path} ${waveflow_params_path} ${vc_input} ${vc_output} +``` +## Pretrained Model +[tacotron2_aishell3_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_aishell3_ckpt_0.3.zip). diff --git a/examples/aishell3/vc1/conf/default.yaml b/examples/aishell3/vc1/conf/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bdd2a765e1dee40324a05226be7bc590448c7c49 --- /dev/null +++ b/examples/aishell3/vc1/conf/default.yaml @@ -0,0 +1,105 @@ +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### + +fs: 24000 # sr +n_fft: 2048 # FFT size. +n_shift: 300 # Hop size. +win_length: 1200 # Window length. + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. + +# Only used for feats_type != raw + +fmin: 80 # Minimum frequency of Mel basis. +fmax: 7600 # Maximum frequency of Mel basis. +n_mels: 80 # The number of mel basis. + +# Only used for the model using pitch features (e.g. FastSpeech2) +f0min: 80 # Maximum f0 for pitch extraction. +f0max: 400 # Minimum f0 for pitch extraction. + + +########################################################### +# DATA SETTING # +########################################################### +batch_size: 64 +num_workers: 2 + + +########################################################### +# MODEL SETTING # +########################################################### +model: + adim: 384 # attention dimension + aheads: 2 # number of attention heads + elayers: 4 # number of encoder layers + eunits: 1536 # number of encoder ff units + dlayers: 4 # number of decoder layers + dunits: 1536 # number of decoder ff units + positionwise_layer_type: conv1d # type of position-wise layer + positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer + duration_predictor_layers: 2 # number of layers of duration predictor + duration_predictor_chans: 256 # number of channels of duration predictor + duration_predictor_kernel_size: 3 # filter size of duration predictor + postnet_layers: 5 # number of layers of postnset + postnet_filts: 5 # filter size of conv layers in postnet + postnet_chans: 256 # number of channels of conv layers in postnet + use_masking: True # whether to apply masking for padded part in loss calculation + use_scaled_pos_enc: True # whether to use scaled positional encoding + encoder_normalize_before: True # whether to perform layer normalization before the input + decoder_normalize_before: True # whether to perform layer normalization before the input + reduction_factor: 1 # reduction factor + init_type: xavier_uniform # initialization type + init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding + init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding + transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer + transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding + transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer + transformer_dec_dropout_rate: 0.2 # dropout rate for transformer decoder layer + transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding + transformer_dec_attn_dropout_rate: 0.2 # dropout rate for transformer decoder attention layer + pitch_predictor_layers: 5 # number of conv layers in pitch predictor + pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor + pitch_predictor_kernel_size: 5 # kernel size of conv leyers in pitch predictor + pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor + pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch + pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch + stop_gradient_from_pitch_predictor: true # whether to stop the gradient from pitch predictor to encoder + energy_predictor_layers: 2 # number of conv layers in energy predictor + energy_predictor_chans: 256 # number of channels of conv layers in energy predictor + energy_predictor_kernel_size: 3 # kernel size of conv leyers in energy predictor + energy_predictor_dropout: 0.5 # dropout rate in energy predictor + energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy + energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy + stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to encoder + spk_embed_dim: 256 # speaker embedding dimension + spk_embed_integration_type: concat # speaker embedding integration type + + + +########################################################### +# UPDATER SETTING # +########################################################### +updater: + use_masking: True # whether to apply masking for padded part in loss calculation + + +########################################################### +# OPTIMIZER SETTING # +########################################################### +optimizer: + optim: adam # optimizer type + learning_rate: 0.001 # learning rate + +########################################################### +# TRAINING SETTING # +########################################################### +max_epoch: 200 +num_snapshots: 5 + + +########################################################### +# OTHER SETTING # +########################################################### +seed: 10086 diff --git a/examples/aishell3/vc1/local/preprocess.sh b/examples/aishell3/vc1/local/preprocess.sh new file mode 100755 index 0000000000000000000000000000000000000000..5f939a1a8203c51945f4dfa748ba5d92aba34337 --- /dev/null +++ b/examples/aishell3/vc1/local/preprocess.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 +ge2e_ckpt_path=$2 + +# gen speaker embedding +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${MAIN_ROOT}/paddlespeech/vector/exps/ge2e/inference.py \ + --input=~/datasets/data_aishell3/train/wav/ \ + --output=dump/embed \ + --checkpoint_path=${ge2e_ckpt_path} +fi + +# copy from tts3/preprocess +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./aishell3_alignment_tone \ + --output durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=aishell3 \ + --rootdir=~/datasets/data_aishell3/ \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --num-cpu=20 \ + --cut-sil=True \ + --spk_emb_dir=dump/embed +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="speech" + + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="pitch" + + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="energy" +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # normalize and covert phone/speaker to id, dev and test should use train's stats + echo "Normalize ..." + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt +fi diff --git a/examples/aishell3/vc1/local/synthesize.sh b/examples/aishell3/vc1/local/synthesize.sh new file mode 100755 index 0000000000000000000000000000000000000000..35478c784b76b76a5aeb8c57a9b3a7de6aa583c5 --- /dev/null +++ b/examples/aishell3/vc1/local/synthesize.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/synthesize.py \ + --fastspeech2-config=${config_path} \ + --fastspeech2-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --fastspeech2-stat=dump/train/speech_stats.npy \ + --pwg-config=pwg_aishell3_ckpt_0.5/default.yaml \ + --pwg-checkpoint=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --pwg-stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --test-metadata=dump/test/norm/metadata.jsonl \ + --output-dir=${train_output_path}/test \ + --phones-dict=dump/phone_id_map.txt \ + --voice-cloning=True diff --git a/examples/aishell3/vc1/local/train.sh b/examples/aishell3/vc1/local/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..c775fcadcceef12e05225c46aa53812e22aa2ee4 --- /dev/null +++ b/examples/aishell3/vc1/local/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=2 \ + --phones-dict=dump/phone_id_map.txt \ + --voice-cloning=True \ No newline at end of file diff --git a/examples/aishell3/vc1/local/voice_cloning.sh b/examples/aishell3/vc1/local/voice_cloning.sh new file mode 100755 index 0000000000000000000000000000000000000000..55bdd761ef845d7dc8084ce0cb80947f7b050656 --- /dev/null +++ b/examples/aishell3/vc1/local/voice_cloning.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +ge2e_params_path=$4 +ref_audio_dir=$5 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/voice_cloning.py \ + --fastspeech2-config=${config_path} \ + --fastspeech2-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --fastspeech2-stat=dump/train/speech_stats.npy \ + --pwg-config=pwg_aishell3_ckpt_0.5/default.yaml \ + --pwg-checkpoint=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --pwg-stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --ge2e_params_path=${ge2e_params_path} \ + --text="凯莫瑞安联合体的经济崩溃迫在眉睫。" \ + --input-dir=${ref_audio_dir} \ + --output-dir=${train_output_path}/vc_syn \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/aishell3/vc1/path.sh b/examples/aishell3/vc1/path.sh new file mode 100755 index 0000000000000000000000000000000000000000..fb7e8411c80cc8cbf1c65dffaaf771bda961e10e --- /dev/null +++ b/examples/aishell3/vc1/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=fastspeech2 +export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL} diff --git a/examples/aishell3/vc1/run.sh b/examples/aishell3/vc1/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..4eae1bdd87fadfe1663202cbd9700620e64cea37 --- /dev/null +++ b/examples/aishell3/vc1/run.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_482.pdz +ref_audio_dir=ref_audio + +# not include ".pdparams" here +ge2e_ckpt_path=./ge2e_ckpt_0.3/step-3000000 + +# include ".pdparams" here +ge2e_params_path=${ge2e_ckpt_path}.pdparams + +# with the following command, you can choice the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path} ${ge2e_ckpt_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize, vocoder is pwgan + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # synthesize, vocoder is pwgan + CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ge2e_params_path} ${ref_audio_dir} || exit -1 +fi diff --git a/paddlespeech/t2s/exps/fastspeech2/voice_cloning.py b/paddlespeech/t2s/exps/fastspeech2/voice_cloning.py new file mode 100644 index 0000000000000000000000000000000000000000..9fbd496418199d05b6319ab335f1a5437bb961d2 --- /dev/null +++ b/paddlespeech/t2s/exps/fastspeech2/voice_cloning.py @@ -0,0 +1,208 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from pathlib import Path + +import numpy as np +import paddle +import soundfile as sf +import yaml +from yacs.config import CfgNode + +from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.t2s.models.fastspeech2 import FastSpeech2 +from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Inference +from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator +from paddlespeech.t2s.models.parallel_wavegan import PWGInference +from paddlespeech.t2s.modules.normalizer import ZScore +from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor +from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder + + +def voice_cloning(args, fastspeech2_config, pwg_config): + # speaker encoder + p = SpeakerVerificationPreprocessor( + sampling_rate=16000, + audio_norm_target_dBFS=-30, + vad_window_length=30, + vad_moving_average_width=8, + vad_max_silence_length=6, + mel_window_length=25, + mel_window_step=10, + n_mels=40, + partial_n_frames=160, + min_pad_coverage=0.75, + partial_overlap_ratio=0.5) + print("Audio Processor Done!") + + speaker_encoder = LSTMSpeakerEncoder( + n_mels=40, num_layers=3, hidden_size=256, output_size=256) + speaker_encoder.set_state_dict(paddle.load(args.ge2e_params_path)) + speaker_encoder.eval() + print("GE2E Done!") + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + odim = fastspeech2_config.n_mels + model = FastSpeech2( + idim=vocab_size, odim=odim, **fastspeech2_config["model"]) + + model.set_state_dict( + paddle.load(args.fastspeech2_checkpoint)["main_params"]) + model.eval() + + vocoder = PWGGenerator(**pwg_config["generator_params"]) + vocoder.set_state_dict(paddle.load(args.pwg_checkpoint)["generator_params"]) + vocoder.remove_weight_norm() + vocoder.eval() + print("model done!") + + frontend = Frontend(phone_vocab_path=args.phones_dict) + print("frontend done!") + + stat = np.load(args.fastspeech2_stat) + mu, std = stat + mu = paddle.to_tensor(mu) + std = paddle.to_tensor(std) + fastspeech2_normalizer = ZScore(mu, std) + + stat = np.load(args.pwg_stat) + mu, std = stat + mu = paddle.to_tensor(mu) + std = paddle.to_tensor(std) + pwg_normalizer = ZScore(mu, std) + + fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model) + fastspeech2_inference.eval() + pwg_inference = PWGInference(pwg_normalizer, vocoder) + pwg_inference.eval() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + input_dir = Path(args.input_dir) + + sentence = args.text + + input_ids = frontend.get_input_ids(sentence, merge_sentences=True) + phone_ids = input_ids["phone_ids"][0] + + for name in os.listdir(input_dir): + utt_id = name.split(".")[0] + ref_audio_path = input_dir / name + mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_audio_path)) + # print("mel_sequences: ", mel_sequences.shape) + with paddle.no_grad(): + spk_emb = speaker_encoder.embed_utterance( + paddle.to_tensor(mel_sequences)) + # print("spk_emb shape: ", spk_emb.shape) + + with paddle.no_grad(): + wav = pwg_inference( + fastspeech2_inference(phone_ids, spk_emb=spk_emb)) + + sf.write( + str(output_dir / (utt_id + ".wav")), + wav.numpy(), + samplerate=fastspeech2_config.fs) + print(f"{utt_id} done!") + # Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb + random_spk_emb = np.random.rand(256) * 0.2 + random_spk_emb = paddle.to_tensor(random_spk_emb) + utt_id = "random_spk_emb" + with paddle.no_grad(): + wav = pwg_inference(fastspeech2_inference(phone_ids, spk_emb=spk_emb)) + sf.write( + str(output_dir / (utt_id + ".wav")), + wav.numpy(), + samplerate=fastspeech2_config.fs) + print(f"{utt_id} done!") + + +def main(): + # parse args and config and redirect to train_sp + parser = argparse.ArgumentParser(description="") + parser.add_argument( + "--fastspeech2-config", type=str, help="fastspeech2 config file.") + parser.add_argument( + "--fastspeech2-checkpoint", + type=str, + help="fastspeech2 checkpoint to load.") + parser.add_argument( + "--fastspeech2-stat", + type=str, + help="mean and standard deviation used to normalize spectrogram when training fastspeech2." + ) + parser.add_argument( + "--pwg-config", type=str, help="parallel wavegan config file.") + parser.add_argument( + "--pwg-checkpoint", + type=str, + help="parallel wavegan generator parameters to load.") + parser.add_argument( + "--pwg-stat", + type=str, + help="mean and standard deviation used to normalize spectrogram when training parallel wavegan." + ) + parser.add_argument( + "--phones-dict", + type=str, + default="phone_id_map.txt", + help="phone vocabulary file.") + parser.add_argument( + "--text", + type=str, + default="每当你觉得,想要批评什么人的时候,你切要记着,这个世界上的人,并非都具备你禀有的条件。", + help="text to synthesize, a line") + + parser.add_argument( + "--ge2e_params_path", type=str, help="ge2e params path.") + + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.") + + parser.add_argument( + "--input-dir", + type=str, + help="input dir of *.wav, the sample rate will be resample to 16k.") + parser.add_argument("--output-dir", type=str, help="output dir.") + + args = parser.parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + with open(args.fastspeech2_config) as f: + fastspeech2_config = CfgNode(yaml.safe_load(f)) + with open(args.pwg_config) as f: + pwg_config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(fastspeech2_config) + print(pwg_config) + + voice_cloning(args, fastspeech2_config, pwg_config) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/modules/attention.py b/paddlespeech/t2s/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..154625cc3c1d9426ed2d21edc3064798b73ccd3a --- /dev/null +++ b/paddlespeech/t2s/modules/attention.py @@ -0,0 +1,348 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import numpy as np +import paddle +from paddle import nn +from paddle.nn import functional as F + + +def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, + training=True): + r"""Scaled dot product attention with masking. + + Assume that q, k, v all have the same leading dimensions (denoted as * in + descriptions below). Dropout is applied to attention weights before + weighted sum of values. + + Parameters + ----------- + q : Tensor [shape=(\*, T_q, d)] + the query tensor. + k : Tensor [shape=(\*, T_k, d)] + the key tensor. + v : Tensor [shape=(\*, T_k, d_v)] + the value tensor. + mask : Tensor, [shape=(\*, T_q, T_k) or broadcastable shape], optional + the mask tensor, zeros correspond to paddings. Defaults to None. + + Returns + ---------- + out : Tensor [shape=(\*, T_q, d_v)] + the context vector. + attn_weights : Tensor [shape=(\*, T_q, T_k)] + the attention weights. + """ + d = q.shape[-1] # we only support imperative execution + qk = paddle.matmul(q, k, transpose_y=True) + scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d)) + + if mask is not None: + scaled_logit += paddle.scale((1.0 - mask), -1e9) # hard coded here + + attn_weights = F.softmax(scaled_logit, axis=-1) + attn_weights = F.dropout(attn_weights, dropout, training=training) + out = paddle.matmul(attn_weights, v) + return out, attn_weights + + +def drop_head(x, drop_n_heads, training=True): + """Drop n context vectors from multiple ones. + + Parameters + ---------- + x : Tensor [shape=(batch_size, num_heads, time_steps, channels)] + The input, multiple context vectors. + drop_n_heads : int [0<= drop_n_heads <= num_heads] + Number of vectors to drop. + training : bool + A flag indicating whether it is in training. If `False`, no dropout is + applied. + + Returns + ------- + Tensor + The output. + """ + if not training or (drop_n_heads == 0): + return x + + batch_size, num_heads, _, _ = x.shape + # drop all heads + if num_heads == drop_n_heads: + return paddle.zeros_like(x) + + mask = np.ones([batch_size, num_heads]) + mask[:, :drop_n_heads] = 0 + for subarray in mask: + np.random.shuffle(subarray) + scale = float(num_heads) / (num_heads - drop_n_heads) + mask = scale * np.reshape(mask, [batch_size, num_heads, 1, 1]) + out = x * paddle.to_tensor(mask) + return out + + +def _split_heads(x, num_heads): + batch_size, time_steps, _ = x.shape + x = paddle.reshape(x, [batch_size, time_steps, num_heads, -1]) + x = paddle.transpose(x, [0, 2, 1, 3]) + return x + + +def _concat_heads(x): + batch_size, _, time_steps, _ = x.shape + x = paddle.transpose(x, [0, 2, 1, 3]) + x = paddle.reshape(x, [batch_size, time_steps, -1]) + return x + + +# Standard implementations of Monohead Attention & Multihead Attention +class MonoheadAttention(nn.Layer): + """Monohead Attention module. + + Parameters + ---------- + model_dim : int + Feature size of the query. + dropout : float, optional + Dropout probability of scaled dot product attention and final context + vector. Defaults to 0.0. + k_dim : int, optional + Feature size of the key of each scaled dot product attention. If not + provided, it is set to `model_dim / num_heads`. Defaults to None. + v_dim : int, optional + Feature size of the key of each scaled dot product attention. If not + provided, it is set to `model_dim / num_heads`. Defaults to None. + """ + + def __init__(self, + model_dim: int, + dropout: float=0.0, + k_dim: int=None, + v_dim: int=None): + super(MonoheadAttention, self).__init__() + k_dim = k_dim or model_dim + v_dim = v_dim or model_dim + self.affine_q = nn.Linear(model_dim, k_dim) + self.affine_k = nn.Linear(model_dim, k_dim) + self.affine_v = nn.Linear(model_dim, v_dim) + self.affine_o = nn.Linear(v_dim, model_dim) + + self.model_dim = model_dim + self.dropout = dropout + + def forward(self, q, k, v, mask): + """Compute context vector and attention weights. + + Parameters + ----------- + q : Tensor [shape=(batch_size, time_steps_q, model_dim)] + The queries. + k : Tensor [shape=(batch_size, time_steps_k, model_dim)] + The keys. + v : Tensor [shape=(batch_size, time_steps_k, model_dim)] + The values. + mask : Tensor [shape=(batch_size, times_steps_q, time_steps_k] or broadcastable shape + The mask. + + Returns + ---------- + out : Tensor [shape=(batch_size, time_steps_q, model_dim)] + The context vector. + attention_weights : Tensor [shape=(batch_size, times_steps_q, time_steps_k)] + The attention weights. + """ + q = self.affine_q(q) # (B, T, C) + k = self.affine_k(k) + v = self.affine_v(v) + + context_vectors, attention_weights = scaled_dot_product_attention( + q, k, v, mask, self.dropout, self.training) + + out = self.affine_o(context_vectors) + return out, attention_weights + + +class MultiheadAttention(nn.Layer): + """Multihead Attention module. + + Parameters + ----------- + model_dim: int + The feature size of query. + num_heads : int + The number of attention heads. + dropout : float, optional + Dropout probability of scaled dot product attention and final context + vector. Defaults to 0.0. + k_dim : int, optional + Feature size of the key of each scaled dot product attention. If not + provided, it is set to ``model_dim / num_heads``. Defaults to None. + v_dim : int, optional + Feature size of the key of each scaled dot product attention. If not + provided, it is set to ``model_dim / num_heads``. Defaults to None. + + Raises + --------- + ValueError + If ``model_dim`` is not divisible by ``num_heads``. + """ + + def __init__(self, + model_dim: int, + num_heads: int, + dropout: float=0.0, + k_dim: int=None, + v_dim: int=None): + super(MultiheadAttention, self).__init__() + if model_dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + depth = model_dim // num_heads + k_dim = k_dim or depth + v_dim = v_dim or depth + self.affine_q = nn.Linear(model_dim, num_heads * k_dim) + self.affine_k = nn.Linear(model_dim, num_heads * k_dim) + self.affine_v = nn.Linear(model_dim, num_heads * v_dim) + self.affine_o = nn.Linear(num_heads * v_dim, model_dim) + + self.num_heads = num_heads + self.model_dim = model_dim + self.dropout = dropout + + def forward(self, q, k, v, mask): + """Compute context vector and attention weights. + + Parameters + ----------- + q : Tensor [shape=(batch_size, time_steps_q, model_dim)] + The queries. + k : Tensor [shape=(batch_size, time_steps_k, model_dim)] + The keys. + v : Tensor [shape=(batch_size, time_steps_k, model_dim)] + The values. + mask : Tensor [shape=(batch_size, times_steps_q, time_steps_k] or broadcastable shape + The mask. + + Returns + ---------- + out : Tensor [shape=(batch_size, time_steps_q, model_dim)] + The context vector. + attention_weights : Tensor [shape=(batch_size, times_steps_q, time_steps_k)] + The attention weights. + """ + q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C) + k = _split_heads(self.affine_k(k), self.num_heads) + v = _split_heads(self.affine_v(v), self.num_heads) + mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim + + context_vectors, attention_weights = scaled_dot_product_attention( + q, k, v, mask, self.dropout, self.training) + # NOTE: there is more sophisticated implementation: Scheduled DropHead + context_vectors = _concat_heads(context_vectors) # (B, T, h*C) + out = self.affine_o(context_vectors) + return out, attention_weights + + +class LocationSensitiveAttention(nn.Layer): + """Location Sensitive Attention module. + + Reference: `Attention-Based Models for Speech Recognition `_ + + Parameters + ----------- + d_query: int + The feature size of query. + d_key : int + The feature size of key. + d_attention : int + The feature size of dimension. + location_filters : int + Filter size of attention convolution. + location_kernel_size : int + Kernel size of attention convolution. + """ + + def __init__(self, + d_query: int, + d_key: int, + d_attention: int, + location_filters: int, + location_kernel_size: int): + super().__init__() + + self.query_layer = nn.Linear(d_query, d_attention, bias_attr=False) + self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False) + self.value = nn.Linear(d_attention, 1, bias_attr=False) + + # Location Layer + self.location_conv = nn.Conv1D( + 2, + location_filters, + kernel_size=location_kernel_size, + padding=int((location_kernel_size - 1) / 2), + bias_attr=False, + data_format='NLC') + self.location_layer = nn.Linear( + location_filters, d_attention, bias_attr=False) + + def forward(self, + query, + processed_key, + value, + attention_weights_cat, + mask=None): + """Compute context vector and attention weights. + + Parameters + ----------- + query : Tensor [shape=(batch_size, d_query)] + The queries. + processed_key : Tensor [shape=(batch_size, time_steps_k, d_attention)] + The keys after linear layer. + value : Tensor [shape=(batch_size, time_steps_k, d_key)] + The values. + attention_weights_cat : Tensor [shape=(batch_size, time_step_k, 2)] + Attention weights concat. + mask : Tensor, optional + The mask. Shape should be (batch_size, times_steps_k, 1). + Defaults to None. + + Returns + ---------- + attention_context : Tensor [shape=(batch_size, d_attention)] + The context vector. + attention_weights : Tensor [shape=(batch_size, time_steps_k)] + The attention weights. + """ + + processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1])) + processed_attention_weights = self.location_layer( + self.location_conv(attention_weights_cat)) + # (B, T_enc, 1) + alignment = self.value( + paddle.tanh(processed_attention_weights + processed_key + + processed_query)) + + if mask is not None: + alignment = alignment + (1.0 - mask) * -1e9 + + attention_weights = F.softmax(alignment, axis=1) + attention_context = paddle.matmul( + attention_weights, value, transpose_x=True) + + attention_weights = paddle.squeeze(attention_weights, axis=-1) + attention_context = paddle.squeeze(attention_context, axis=1) + + return attention_context, attention_weights diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py index fbb3a9a3d65f83fd43b19902c9e97137691f2a2d..879cdba63e87d6c898a1fd417a9f9e3958a9bcb2 100644 --- a/paddlespeech/t2s/modules/nets_utils.py +++ b/paddlespeech/t2s/modules/nets_utils.py @@ -17,14 +17,6 @@ from paddle import nn from typeguard import check_argument_types -class Swish(paddle.nn.Layer): - """Construct an Swish object.""" - - def forward(self, x): - """Return Swich activation function.""" - return x * paddle.nn.Sigmoid(x) - - def pad_list(xs, pad_value): """Perform padding for the list of tensors. @@ -168,7 +160,7 @@ def get_activation(act): "tanh": paddle.nn.Tanh, "relu": paddle.nn.ReLU, "selu": paddle.nn.SELU, - "swish": Swish, + "swish": paddle.nn.Swish, } return activation_funcs[act]() diff --git a/paddlespeech/t2s/modules/transformer/subsampling.py b/paddlespeech/t2s/modules/transformer/subsampling.py index 300b35beda72dda735629b525a0f00bb25129e94..506cfde68f265c6aab848c2cac45b3ab2fcdca90 100644 --- a/paddlespeech/t2s/modules/transformer/subsampling.py +++ b/paddlespeech/t2s/modules/transformer/subsampling.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from espnet(https://github.com/espnet/espnet) -# Conv2dSubsampling 测试通过 """Subsampling layer definition.""" import paddle @@ -98,8 +97,7 @@ class Conv2dSubsampling(paddle.nn.Layer): # (b, c, t, f) x = x.unsqueeze(1) x = self.conv(x) - b, c, t, f = x.shape - # x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + b, c, t, f = paddle.shape(x) x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) if x_mask is None: return x, None @@ -163,7 +161,7 @@ class Conv2dSubsampling2(paddle.nn.Layer): # (b, c, t, f) x = x.unsqueeze(1) x = self.conv(x) - b, c, t, f = x.shape + b, c, t, f = paddle.shape(x) x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) if x_mask is None: return x, None @@ -227,7 +225,7 @@ class Conv2dSubsampling6(paddle.nn.Layer): # (b, c, t, f) x = x.unsqueeze(1) x = self.conv(x) - b, c, t, f = x.shape + b, c, t, f = paddle.shape(x) x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) if x_mask is None: return x, None @@ -259,8 +257,8 @@ class Conv2dSubsampling8(paddle.nn.Layer): paddle.nn.Conv2D(odim, odim, 3, 2), paddle.nn.ReLU(), ) self.out = paddle.nn.Sequential( - paddle.nn.Linear(odim * ((( - (idim - 1) // 2 - 1) // 2 - 1) // 2), odim), + paddle.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), + odim), pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), ) @@ -284,7 +282,7 @@ class Conv2dSubsampling8(paddle.nn.Layer): # (b, c, t, f) x = x.unsqueeze(1) x = self.conv(x) - b, c, t, f = x.shape + b, c, t, f = paddle.shape(x) x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])) if x_mask is None: return x, None diff --git a/paddlespeech/xvector/__init__.py b/paddlespeech/vector/__init__.py similarity index 100% rename from paddlespeech/xvector/__init__.py rename to paddlespeech/vector/__init__.py diff --git a/paddlespeech/t2s/exps/ge2e/__init__.py b/paddlespeech/vector/exps/__init__.py similarity index 100% rename from paddlespeech/t2s/exps/ge2e/__init__.py rename to paddlespeech/vector/exps/__init__.py diff --git a/paddlespeech/vector/exps/ge2e/__init__.py b/paddlespeech/vector/exps/ge2e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abf198b97e6e818e1fbe59006f98492640bcee54 --- /dev/null +++ b/paddlespeech/vector/exps/ge2e/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/t2s/exps/ge2e/audio_processor.py b/paddlespeech/vector/exps/ge2e/audio_processor.py similarity index 100% rename from paddlespeech/t2s/exps/ge2e/audio_processor.py rename to paddlespeech/vector/exps/ge2e/audio_processor.py diff --git a/paddlespeech/t2s/exps/ge2e/config.py b/paddlespeech/vector/exps/ge2e/config.py similarity index 100% rename from paddlespeech/t2s/exps/ge2e/config.py rename to paddlespeech/vector/exps/ge2e/config.py diff --git a/paddlespeech/t2s/exps/ge2e/dataset_processors.py b/paddlespeech/vector/exps/ge2e/dataset_processors.py similarity index 98% rename from paddlespeech/t2s/exps/ge2e/dataset_processors.py rename to paddlespeech/vector/exps/ge2e/dataset_processors.py index a9320d9859067154333c3464a495fc4557379dfc..908c852b2ec8121838f249ed04310f714776cffb 100644 --- a/paddlespeech/t2s/exps/ge2e/dataset_processors.py +++ b/paddlespeech/vector/exps/ge2e/dataset_processors.py @@ -19,7 +19,7 @@ from typing import List import numpy as np from tqdm import tqdm -from paddlespeech.t2s.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor +from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor def _process_utterance(path_pair, processor: SpeakerVerificationPreprocessor): diff --git a/paddlespeech/t2s/exps/ge2e/inference.py b/paddlespeech/vector/exps/ge2e/inference.py similarity index 95% rename from paddlespeech/t2s/exps/ge2e/inference.py rename to paddlespeech/vector/exps/ge2e/inference.py index eed3b7947d6bf9c4f561ecedbe458281cd61ab97..7660de5e876529448b0e8f0e2a3f6185d15e9322 100644 --- a/paddlespeech/t2s/exps/ge2e/inference.py +++ b/paddlespeech/vector/exps/ge2e/inference.py @@ -18,9 +18,9 @@ import numpy as np import paddle import tqdm -from paddlespeech.t2s.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor -from paddlespeech.t2s.exps.ge2e.config import get_cfg_defaults -from paddlespeech.t2s.models.lstm_speaker_encoder import LSTMSpeakerEncoder +from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor +from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults +from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder def embed_utterance(processor, model, fpath_or_wav): diff --git a/paddlespeech/t2s/exps/ge2e/preprocess.py b/paddlespeech/vector/exps/ge2e/preprocess.py similarity index 87% rename from paddlespeech/t2s/exps/ge2e/preprocess.py rename to paddlespeech/vector/exps/ge2e/preprocess.py index 604ff0c6735f378cfda7052147823b3dd63a1780..dabe0ce7694547ed197a4d570bcec0399e9ac54e 100644 --- a/paddlespeech/t2s/exps/ge2e/preprocess.py +++ b/paddlespeech/vector/exps/ge2e/preprocess.py @@ -14,14 +14,13 @@ import argparse from pathlib import Path -from audio_processor import SpeakerVerificationPreprocessor - -from paddlespeech.t2s.exps.ge2e.config import get_cfg_defaults -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_aidatatang_200zh -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_librispeech -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_magicdata -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_voxceleb1 -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_voxceleb2 +from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor +from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults +from paddlespeech.vector.exps.ge2e.dataset_processors import process_aidatatang_200zh +from paddlespeech.vector.exps.ge2e.dataset_processors import process_librispeech +from paddlespeech.vector.exps.ge2e.dataset_processors import process_magicdata +from paddlespeech.vector.exps.ge2e.dataset_processors import process_voxceleb1 +from paddlespeech.vector.exps.ge2e.dataset_processors import process_voxceleb2 if __name__ == "__main__": parser = argparse.ArgumentParser( diff --git a/paddlespeech/vector/exps/ge2e/random_cycle.py b/paddlespeech/vector/exps/ge2e/random_cycle.py new file mode 100644 index 0000000000000000000000000000000000000000..290fd2fa274b66f7802cb0ab529d04099118f624 --- /dev/null +++ b/paddlespeech/vector/exps/ge2e/random_cycle.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random + + +def cycle(iterable): + # cycle('ABCD') --> A B C D A B C D A B C D ... + saved = [] + for element in iterable: + yield element + saved.append(element) + while saved: + for element in saved: + yield element + + +def random_cycle(iterable): + # cycle('ABCD') --> A B C D B C D A A D B C ... + saved = [] + for element in iterable: + yield element + saved.append(element) + random.shuffle(saved) + while saved: + for element in saved: + yield element + random.shuffle(saved) diff --git a/paddlespeech/vector/exps/ge2e/speaker_verification_dataset.py b/paddlespeech/vector/exps/ge2e/speaker_verification_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..194eb7f28fb485e8fc61ba25fb9c9fcb61bf1802 --- /dev/null +++ b/paddlespeech/vector/exps/ge2e/speaker_verification_dataset.py @@ -0,0 +1,131 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +from pathlib import Path + +import numpy as np +from paddle.io import BatchSampler +from paddle.io import Dataset + +from paddlespeech.vector.exps.ge2e.random_cycle import random_cycle + + +class MultiSpeakerMelDataset(Dataset): + """A 2 layer directory thatn contains mel spectrograms in *.npy format. + An Example file structure tree is shown below. We prefer to preprocess + raw datasets and organized them like this. + + dataset_root/ + speaker1/ + utterance1.npy + utterance2.npy + utterance3.npy + speaker2/ + utterance1.npy + utterance2.npy + utterance3.npy + """ + + def __init__(self, dataset_root: Path): + self.root = Path(dataset_root).expanduser() + speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] + + speaker_utterances = { + speaker_dir: list(speaker_dir.glob("*.npy")) + for speaker_dir in speaker_dirs + } + + self.speaker_dirs = speaker_dirs + self.speaker_to_utterances = speaker_utterances + + # meta data + self.num_speakers = len(self.speaker_dirs) + self.num_utterances = np.sum( + len(utterances) + for speaker, utterances in self.speaker_to_utterances.items()) + + def get_example_by_index(self, speaker_index, utterance_index): + speaker_dir = self.speaker_dirs[speaker_index] + fpath = self.speaker_to_utterances[speaker_dir][utterance_index] + return self[fpath] + + def __getitem__(self, fpath): + return np.load(fpath) + + def __len__(self): + return int(self.num_utterances) + + +class MultiSpeakerSampler(BatchSampler): + """A multi-stratal sampler designed for speaker verification task. + First, N speakers from all speakers are sampled randomly. Then, for each + speaker, randomly sample M utterances from their corresponding utterances. + """ + + def __init__(self, + dataset: MultiSpeakerMelDataset, + speakers_per_batch: int, + utterances_per_speaker: int): + self._speakers = list(dataset.speaker_dirs) + self._speaker_to_utterances = dataset.speaker_to_utterances + + self.speakers_per_batch = speakers_per_batch + self.utterances_per_speaker = utterances_per_speaker + + def __iter__(self): + # yield list of Paths + speaker_generator = iter(random_cycle(self._speakers)) + speaker_utterances_generator = { + s: iter(random_cycle(us)) + for s, us in self._speaker_to_utterances.items() + } + + while True: + speakers = [] + for _ in range(self.speakers_per_batch): + speakers.append(next(speaker_generator)) + + utterances = [] + for s in speakers: + us = speaker_utterances_generator[s] + for _ in range(self.utterances_per_speaker): + utterances.append(next(us)) + yield utterances + + +class RandomClip(object): + def __init__(self, frames): + self.frames = frames + + def __call__(self, spec): + # spec [T, C] + T = spec.shape[0] + start = random.randint(0, T - self.frames) + return spec[start:start + self.frames, :] + + +class Collate(object): + def __init__(self, num_frames): + self.random_crop = RandomClip(num_frames) + + def __call__(self, examples): + frame_clips = [self.random_crop(mel) for mel in examples] + batced_clips = np.stack(frame_clips) + return batced_clips + + +if __name__ == "__main__": + mydataset = MultiSpeakerMelDataset( + Path("/home/chenfeiyu/datasets/SV2TTS/encoder")) + print(mydataset.get_example_by_index(0, 10)) diff --git a/paddlespeech/vector/exps/ge2e/train.py b/paddlespeech/vector/exps/ge2e/train.py new file mode 100644 index 0000000000000000000000000000000000000000..bf1cf1074b5dec41f2287f5113b6facef9909283 --- /dev/null +++ b/paddlespeech/vector/exps/ge2e/train.py @@ -0,0 +1,123 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +from paddle import DataParallel +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn.clip import ClipGradByGlobalNorm +from paddle.optimizer import Adam + +from paddlespeech.t2s.training import default_argument_parser +from paddlespeech.t2s.training import ExperimentBase +from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults +from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import Collate +from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import MultiSpeakerMelDataset +from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import MultiSpeakerSampler +from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder + + +class Ge2eExperiment(ExperimentBase): + def setup_model(self): + config = self.config + model = LSTMSpeakerEncoder(config.data.n_mels, config.model.num_layers, + config.model.hidden_size, + config.model.embedding_size) + optimizer = Adam( + config.training.learning_rate_init, + parameters=model.parameters(), + grad_clip=ClipGradByGlobalNorm(3)) + self.model = DataParallel(model) if self.parallel else model + self.model_core = model + self.optimizer = optimizer + + def setup_dataloader(self): + config = self.config + train_dataset = MultiSpeakerMelDataset(self.args.data) + sampler = MultiSpeakerSampler(train_dataset, + config.training.speakers_per_batch, + config.training.utterances_per_speaker) + train_loader = DataLoader( + train_dataset, + batch_sampler=sampler, + collate_fn=Collate(config.data.partial_n_frames), + num_workers=16) + + self.train_dataset = train_dataset + self.train_loader = train_loader + + def train_batch(self): + start = time.time() + batch = self.read_batch() + data_loader_time = time.time() - start + + self.optimizer.clear_grad() + self.model.train() + specs = batch + loss, eer = self.model(specs, self.config.training.speakers_per_batch) + loss.backward() + self.model_core.do_gradient_ops() + self.optimizer.step() + iteration_time = time.time() - start + + # logging + loss_value = float(loss) + msg = "Rank: {}, ".format(dist.get_rank()) + msg += "step: {}, ".format(self.iteration) + msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, + iteration_time) + msg += 'loss: {:>.6f} err: {:>.6f}'.format(loss_value, eer) + self.logger.info(msg) + + if dist.get_rank() == 0: + self.visualizer.add_scalar("train/loss", loss_value, self.iteration) + self.visualizer.add_scalar("train/eer", eer, self.iteration) + self.visualizer.add_scalar("param/w", + float(self.model_core.similarity_weight), + self.iteration) + self.visualizer.add_scalar("param/b", + float(self.model_core.similarity_bias), + self.iteration) + + def valid(self): + pass + + +def main_sp(config, args): + exp = Ge2eExperiment(config, args) + exp.setup() + exp.resume_or_load() + exp.run() + + +def main(config, args): + if args.ngpu > 1: + dist.spawn(main_sp, args=(config, args), nprocs=args.ngpu) + else: + main_sp(config, args) + + +if __name__ == "__main__": + config = get_cfg_defaults() + parser = default_argument_parser() + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + print(args) + + main(config, args) diff --git a/paddlespeech/vector/models/__init__.py b/paddlespeech/vector/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e --- /dev/null +++ b/paddlespeech/vector/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/t2s/models/lstm_speaker_encoder.py b/paddlespeech/vector/models/lstm_speaker_encoder.py similarity index 100% rename from paddlespeech/t2s/models/lstm_speaker_encoder.py rename to paddlespeech/vector/models/lstm_speaker_encoder.py diff --git a/third_party/python_kaldi_features/setup.py b/third_party/python_kaldi_features/setup.py index 47c77718636963b1f2852118cba49957159d92d0..c76f23b5146b1b48ca683a4a5050a7b3c82613c3 100644 --- a/third_party/python_kaldi_features/setup.py +++ b/third_party/python_kaldi_features/setup.py @@ -3,12 +3,16 @@ try: except ImportError: from distutils.core import setup -setup(name='python_speech_features', - version='0.6', - description='Python Speech Feature extraction', - author='James Lyons', - author_email='james.lyons0@gmail.com', +with open("requirements.txt", encoding="utf-8-sig") as f: + requirements = f.readlines() + +setup(name='paddlespeech_feat', + version='0.0.1a', + description='python speech feature extraction in paddlespeech', + install_requires=requirements, + author="PaddlePaddle Speech and Language Team", + author_email="paddlesl@baidu.com", license='MIT', - url='https://github.com/jameslyons/python_speech_features', + url='https://github.com/PaddlePaddle/PaddleSpeech', packages=['python_speech_features'], )