diff --git a/README.md b/README.md
index 3970f79b55274beae95ccaf6a0a576a01b7ee4b1..50dac64cc1a5040b8b5636beb6f9b3c4e37a493f 100644
--- a/README.md
+++ b/README.md
@@ -424,6 +424,30 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+**Punctuation Restoration**
+
+
+
+
+ Task |
+ Dataset |
+ Model Type |
+ Link |
+
+
+
+
+
+ Punctuation Restoration |
+ IWLST2012_zh |
+ Ernie Linear |
+
+ iwslt2012-punc0
+ |
+
+
+
+
## Documents
Normally, [Speech SoTA](https://paperswithcode.com/area/speech), [Audio SoTA](https://paperswithcode.com/area/audio) and [Music SoTA](https://paperswithcode.com/area/music) give you an overview of the hot academic topics in the related area. To focus on the tasks in PaddleSpeech, you will find the following guidelines are helpful to grasp the core ideas.
diff --git a/README_cn.md b/README_cn.md
index b47e9e61d08eeb933c635e21c6db27d6231038de..14167864999c3be05a9cbff9c2b6fb941356c814 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -415,6 +415,30 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
+**标点恢复**
+
+
+
+
+ 任务 |
+ 数据集 |
+ 模型种类 |
+ 链接 |
+
+
+
+
+
+ 标点恢复 |
+ IWLST2012_zh |
+ Ernie Linear |
+
+ iwslt2012-punc0
+ |
+
+
+
+
## 教程文档
对于 PaddleSpeech 的所关注的任务,以下指南有助于帮助开发者快速入门,了解语音相关核心思想。
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index 9db6a4c449a67a5ce6207a2f29ffa65ead6d527f..a10b2674f6b4a022d2fb6f4aa8b00d34681a3498 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -1,11 +1,10 @@
-
# Released Models
## Speech-to-Text Models
### Speech Recognition Model
Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | Example Link
-:-------------:| :------------:| :-----: | -----: | :----------------- |:--------- | :---------- | :--------- | :-----------
+:-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----:
[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/aishell_ds2_online_cer8.00_release.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.080 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0)
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.064 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0)
[Conformer Online Aishell ASR1 Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz) | Aishell Dataset | Char-based | 283 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0594 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1)
@@ -17,22 +16,21 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
### Language Model based on NGram
Language Model | Training Data | Token-based | Size | Descriptions
-:-------------:| :------------:| :-----: | -----: | :-----------------
+:------------:| :------------:|:------------: | :------------: | :------------:
[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1;
About 1.85 billion n-grams;
'trie' binary with '-a 22 -q 8 -b 8'
[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4;
About 0.13 billion n-grams;
'probing' binary with default settings
[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning;
About 3.7 billion n-grams;
'probing' binary with default settings
### Speech Translation Models
-| Model | Training Data | Token-based | Size | Descriptions | BLEU | Example Link |
-| ------------------------------------------------------------ | ------------- | ----------- | ---- | ------------------------------------------------------------ | ----- | ------------------------------------------------------------ |
-| [Transformer FAT-ST MTL En-Zh](https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz) | Ted-En-Zh | Spm | | Encoder:Transformer, Decoder:Transformer,
Decoding method: Attention | 20.80 | [Transformer Ted-En-Zh ST1](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/ted_en_zh/st1) |
-
+| Model | Training Data | Token-based | Size | Descriptions | BLEU | Example Link |
+| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
+| [Transformer FAT-ST MTL En-Zh](https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz) | Ted-En-Zh| Spm| | Encoder:Transformer, Decoder:Transformer,
Decoding method: Attention | 20.80 | [Transformer Ted-En-Zh ST1](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/ted_en_zh/st1) |
## Text-to-Speech Models
### Acoustic Models
-Model Type | Dataset| Example Link | Pretrained Models|Static Models|Siize(static)
+Model Type | Dataset| Example Link | Pretrained Models|Static Models|Size (static)
:-------------:| :------------:| :-----: | :-----:| :-----:| :-----:
Tacotron2|LJSpeech|[tacotron2-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.3.zip)|||
TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)|||
@@ -44,8 +42,8 @@ FastSpeech2| LJSpeech |[fastspeech2-ljspeech](https://github.com/PaddlePaddle/Pa
FastSpeech2| VCTK |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/tts3)|[fastspeech2_nosil_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip)|||
### Vocoders
-Model Type | Dataset| Example Link | Pretrained Models| Static Models|Size(static)
-:-------------:| :------------:| :-----: | :-----:| :-----:| :-----:
+Model Type | Dataset| Example Link | Pretrained Models| Static Models|Size (static)
+:-----:| :-----:| :-----: | :-----:| :-----:| :-----:
WaveFlow| LJSpeech |[waveflow-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0)|[waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip)|||
Parallel WaveGAN| CSMSC |[PWGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1)|[pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip)|[pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip)|5.1MB|
Parallel WaveGAN| LJSpeech |[PWGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc1)|[pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip)|||
@@ -69,10 +67,15 @@ Model Type | Dataset| Example Link | Pretrained Models
PANN | Audioset| [audioset_tagging_cnn](https://github.com/qiuqiangkong/audioset_tagging_cnn) | [panns_cnn6.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn6.pdparams), [panns_cnn10.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn10.pdparams), [panns_cnn14.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn14.pdparams)
PANN | ESC-50 |[pann-esc50]("./examples/esc50/cls0")|[esc50_cnn6.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn6.tar.gz), [esc50_cnn10.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn10.tar.gz), [esc50_cnn14.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn14.tar.gz)
+## Punctuation Restoration Models
+Model Type | Dataset| Example Link | Pretrained Models
+:-------------:| :------------:| :-----: | :-----:
+Ernie Linear | IWLST2012_zh |[iwslt2012_punc0](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/iwslt2012/punc0)|[ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip)
+
## Speech Recognition Model from paddle 1.8
-| Acoustic Model |Training Data| Token-based | Size | Descriptions | CER | WER | Hours of speech |
-| :--------------: | :--------------: | :--------------: | :--------------: | :--------------: | :--------------: | :--------------: | :--------------: |
+| Acoustic Model |Training Data| Token-based | Size | Descriptions | CER | WER | Hours of speech |
+| :-----:| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
| [Ds2 Offline Aishell model](https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz) | Aishell Dataset | Char-based | 234 MB | 2 Conv + 3 bidirectional GRU layers | 0.0804 | — | 151 h |
-| [Ds2 Offline Librispeech model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz) | Librispeech Dataset | Word-based | 307 MB | 2 Conv + 3 bidirectional sharing weight RNN layers | — | 0.0685 | 960 h |
-| [Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz) | Baidu Internal English Dataset | Word-based | 273 MB | 2 Conv + 3 bidirectional GRU layers |— | 0.0541 | 8628 h |
+| [Ds2 Offline Librispeech model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz) | Librispeech Dataset | Word-based | 307 MB | 2 Conv + 3 bidirectional sharing weight RNN layers | — | 0.0685 | 960 h |
+| [Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz) | Baidu Internal English Dataset | Word-based | 273 MB | 2 Conv + 3 bidirectional GRU layers |— | 0.0541 | 8628 h|
diff --git a/examples/iwslt2012/punc0/README.md b/examples/iwslt2012/punc0/README.md
index 38ef36fe07f3ed37f68f231dd6af19e3b7711029..15ccea85de86f6b9d83cb852f5f3eecc7346fe54 100644
--- a/examples/iwslt2012/punc0/README.md
+++ b/examples/iwslt2012/punc0/README.md
@@ -1,17 +1,28 @@
-# 中文实验例程
-## 测试数据:
-- IWLST2012中文:test2012
+# Punctuation Restoration with IWLST2012
+## Get Started
+### Data Preprocessing
+```bash
+./run.sh --stage 0 --stop-stage 0
+```
+### Model Training
+```bash
+./run.sh --stage 1 --stop-stage 1
+```
+### Testing
+```bash
+./run.sh --stage 2 --stop-stage 2
+```
+### Punctuation Restoration
+```bash
+./run.sh --stage 3 --stop-stage 3
+```
+## Pretrained Model
+The pretrained model can be downloaded here [ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip).
-## 运行代码
-- 运行 `./run.sh 0 0 conf/ernie_linear.yaml 1`
-
-## 实验结果:
-- ErnieLinear
- - 实验配置:conf/ernie_linear.yaml
- - 测试结果
-
- | | COMMA | PERIOD | QUESTION | OVERALL |
- |-----------|-----------|-----------|-----------|--------- |
- |Precision | 0.471831 | 0.497679 | 0.830189 | 0.599899 |
- |Recall | 0.583172 | 0.641148 | 0.846154 | 0.690158 |
- |F1 | 0.521626 | 0.560376 | 0.838095 | 0.640033 |
+### Test Result
+- Ernie Linear
+ | |COMMA | PERIOD | QUESTION | OVERALL|
+ |:-----:|:-----:|:-----:|:-----:|:-----:|
+ |Precision |0.510955 |0.526462 |0.820755 |0.619391|
+ |Recall |0.517433 |0.564179 |0.861386 |0.647666|
+ |F1 |0.514173 |0.544669 |0.840580 |0.633141|
diff --git a/examples/iwslt2012/punc0/conf/default.yaml b/examples/iwslt2012/punc0/conf/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..74ced99327b38dff4863718b75a7bff8096061c1
--- /dev/null
+++ b/examples/iwslt2012/punc0/conf/default.yaml
@@ -0,0 +1,44 @@
+###########################################################
+# DATA SETTING #
+###########################################################
+dataset_type: Ernie
+train_path: data/iwslt2012_zh/train.txt
+dev_path: data/iwslt2012_zh/dev.txt
+test_path: data/iwslt2012_zh/test.txt
+batch_size: 64
+num_workers: 2
+data_params:
+ pretrained_token: ernie-1.0
+ punc_path: data/iwslt2012_zh/punc_vocab
+ seq_len: 100
+
+
+###########################################################
+# MODEL SETTING #
+###########################################################
+model_type: ErnieLinear
+model:
+ pretrained_token: ernie-1.0
+ num_classes: 4
+
+###########################################################
+# OPTIMIZER SETTING #
+###########################################################
+optimizer_params:
+ weight_decay: 1.0e-6 # weight decay coefficient.
+
+scheduler_params:
+ learning_rate: 1.0e-5 # learning rate.
+ gamma: 1.0 # scheduler gamma.
+
+###########################################################
+# TRAINING SETTING #
+###########################################################
+max_epoch: 20
+num_snapshots: 5
+
+###########################################################
+# OTHER SETTING #
+###########################################################
+num_snapshots: 10 # max number of snapshots to keep while training
+seed: 42 # random seed for paddle, random, and np.random
diff --git a/examples/iwslt2012/punc0/conf/ernie_linear.yaml b/examples/iwslt2012/punc0/conf/ernie_linear.yaml
deleted file mode 100644
index bf8921107d6067136d41ae8e11089e5027e2c481..0000000000000000000000000000000000000000
--- a/examples/iwslt2012/punc0/conf/ernie_linear.yaml
+++ /dev/null
@@ -1,36 +0,0 @@
-data:
- dataset_type: Ernie
- train_path: data/iwslt2012_zh/train.txt
- dev_path: data/iwslt2012_zh/dev.txt
- test_path: data/iwslt2012_zh/test.txt
- data_params:
- pretrained_token: ernie-1.0
- punc_path: data/iwslt2012_zh/punc_vocab
- seq_len: 100
- batch_size: 64
- sortagrad: True
- shuffle_method: batch_shuffle
- num_workers: 0
-
-checkpoint:
- kbest_n: 5
- latest_n: 10
- metric_type: F1
-
-model_type: ErnieLinear
-
-model_params:
- pretrained_token: ernie-1.0
- num_classes: 4
-
-training:
- n_epoch: 20
- lr: !!float 1e-5
- lr_decay: 1.0
- weight_decay: !!float 1e-06
- global_grad_clip: 5.0
- log_interval: 10
- log_path: log/train_ernie_linear.log
-
-testing:
- log_path: log/test_ernie_linear.log
diff --git a/examples/iwslt2012/punc0/local/avg.sh b/examples/iwslt2012/punc0/local/avg.sh
deleted file mode 100644
index b8c14c6623cc30ca167cfb26fcc9ade349cad288..0000000000000000000000000000000000000000
--- a/examples/iwslt2012/punc0/local/avg.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-#! /usr/bin/env bash
-
-if [ $# != 2 ]; then
- echo "usage: ${0} ckpt_dir avg_num"
- exit -1
-fi
-
-ckpt_dir=${1}
-average_num=${2}
-decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
-
-python3 -u ${BIN_DIR}/avg_model.py \
---dst_model ${decode_checkpoint} \
---ckpt_dir ${ckpt_dir} \
---num ${average_num} \
---val_best
-
-if [ $? -ne 0 ]; then
- echo "Failed in avg ckpt!"
- exit 1
-fi
-
-exit 0
\ No newline at end of file
diff --git a/examples/iwslt2012/punc0/local/data.sh b/examples/iwslt2012/punc0/local/data.sh
old mode 100644
new mode 100755
diff --git a/examples/iwslt2012/punc0/local/punc_restore.sh b/examples/iwslt2012/punc0/local/punc_restore.sh
new file mode 100755
index 0000000000000000000000000000000000000000..30a4f12f8040cbb1617aa267ff68fe0248e01951
--- /dev/null
+++ b/examples/iwslt2012/punc0/local/punc_restore.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+text=$4
+ckpt_prefix=${ckpt_name%.*}
+
+python3 ${BIN_DIR}/punc_restore.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --text=${text}
diff --git a/examples/iwslt2012/punc0/local/test.sh b/examples/iwslt2012/punc0/local/test.sh
old mode 100644
new mode 100755
index ee02246222b7ed6c7e780cde443abf537da7f14c..94e508b5bff9ae0d7b3253fafe3ebb943d44a97a
--- a/examples/iwslt2012/punc0/local/test.sh
+++ b/examples/iwslt2012/punc0/local/test.sh
@@ -1,26 +1,11 @@
-
#!/bin/bash
-if [ $# != 2 ];then
- echo "usage: ${0} config_path ckpt_path_prefix"
- exit -1
-fi
-
-ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
-echo "using $ngpu gpus..."
-
config_path=$1
-ckpt_prefix=$2
-
-python3 -u ${BIN_DIR}/test.py \
---ngpu 1 \
---config ${config_path} \
---result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix}
+train_output_path=$2
+ckpt_name=$3
-if [ $? -ne 0 ]; then
- echo "Failed in evaluation!"
- exit 1
-fi
+ckpt_prefix=${ckpt_name%.*}
-exit 0
+python3 ${BIN_DIR}/test.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name}
diff --git a/examples/iwslt2012/punc0/local/train.sh b/examples/iwslt2012/punc0/local/train.sh
old mode 100644
new mode 100755
index 9fabb8f75b2b00364bd6b878d9d18f058b8d91b9..85227eacbe518f59655a3b8013363c97d192b396
--- a/examples/iwslt2012/punc0/local/train.sh
+++ b/examples/iwslt2012/punc0/local/train.sh
@@ -1,28 +1,9 @@
#!/bin/bash
-if [ $# != 3 ];then
- echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name log_dir"
- exit -1
-fi
-
-ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
-echo "using $ngpu gpus..."
-
config_path=$1
-ckpt_name=$2
-log_dir=$3
-
-mkdir -p exp
-
-python3 -u ${BIN_DIR}/train.py \
---ngpu ${ngpu} \
---config ${config_path} \
---output_dir exp/${ckpt_name} \
---log_dir ${log_dir}
-
-if [ $? -ne 0 ]; then
- echo "Failed in training!"
- exit 1
-fi
+train_output_path=$2
-exit 0
+python3 ${BIN_DIR}/train.py \
+ --config=${config_path} \
+ --output-dir=${train_output_path} \
+ --ngpu=1
diff --git a/examples/iwslt2012/punc0/path.sh b/examples/iwslt2012/punc0/path.sh
old mode 100644
new mode 100755
index 8f67f9c938d36013bdb2f37f781ae3f5a9b43524..da790261f8da529f45cacf05c693c16dd20bd84b
--- a/examples/iwslt2012/punc0/path.sh
+++ b/examples/iwslt2012/punc0/path.sh
@@ -10,5 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
-MODEL=$1
+MODEL=ernie_linear
export BIN_DIR=${MAIN_ROOT}/paddlespeech/text/exps/${MODEL}
diff --git a/examples/iwslt2012/punc0/run.sh b/examples/iwslt2012/punc0/run.sh
index 8d786a198c5671b98dc615f3eed7a3c44eed64ab..0c14eb7e23a7c02204adb933381ec9343e7041b6 100755
--- a/examples/iwslt2012/punc0/run.sh
+++ b/examples/iwslt2012/punc0/run.sh
@@ -1,40 +1,35 @@
#!/bin/bash
set -e
+source path.sh
-if [ $# -ne 4 ]; then
- echo "usage: bash ./run.sh stage gpu train_config avg_num"
- echo "eg: bash ./run.sh 1 0 train_config 1"
- exit -1
-fi
-
-stage=$1
+gpus=0,1
+stage=0
stop_stage=100
-gpus=$2
-conf_path=$3
-avg_num=$4
-avg_ckpt=avg_${avg_num}
-ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
-log_dir=log
-source path.sh ${ckpt}
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_12840.pdz
+text=今天的天气真不错啊你下午有空吗我想约你一起去吃饭
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this can not be mixed use with `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
- bash ./local/data.sh
+ ./local/data.sh
fi
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
- # train model, all `ckpt` under `exp` dir
- CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${conf_path} ${ckpt} ${log_dir}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # train model, all `ckpt` under `train_output_path/checkpoints/` dir
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- # avg n best model
- bash ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- # test ckpt avg_n
- CUDA_VISIBLE_DEVICES=${gpus} bash ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
-fi
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/punc_restore.sh ${conf_path} ${train_output_path} ${ckpt_name} ${text}|| exit -1
+fi
\ No newline at end of file
diff --git a/paddlespeech/t2s/exps/fastspeech2/gen_gta_mel.py b/paddlespeech/t2s/exps/fastspeech2/gen_gta_mel.py
index 8a9ef370c0f2916149b62c50d2425e969b49a5cb..fa46fd55f287173b25906b01a74a18aae52ba0a6 100644
--- a/paddlespeech/t2s/exps/fastspeech2/gen_gta_mel.py
+++ b/paddlespeech/t2s/exps/fastspeech2/gen_gta_mel.py
@@ -132,7 +132,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
def str2bool(str):
return True if str.lower() == 'true' else False
diff --git a/paddlespeech/t2s/exps/fastspeech2/train.py b/paddlespeech/t2s/exps/fastspeech2/train.py
index fafded6fca0cef090d9ba1c844cdc526fe5d49d6..1dfa575a1810075c5dd276b3d65551308def2b3c 100644
--- a/paddlespeech/t2s/exps/fastspeech2/train.py
+++ b/paddlespeech/t2s/exps/fastspeech2/train.py
@@ -174,7 +174,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
diff --git a/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py b/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py
index 3bc11a603397b1d69b98abe303dddb32020ed35d..9ac6cbd34e55b5aa65fa2279480157a7064a42b5 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/hifigan/train.py
@@ -250,7 +250,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py b/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py
index a44d2d3c263b5ca6f78b40d06a87cafa36bc7798..3d0ff7d35cfa1770221dd628ec1efed036645061 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/train.py
@@ -239,7 +239,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py
index ca2e3f5501a77a5d191b666b09684ea58daa396c..f5affb50b065139b0840966f9cfce14225b29753 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize_from_wav.py
@@ -93,7 +93,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py
index 98b0ed717c89d2d5d7ccf2c910420bd03722beaa..a7881d6bbb2d8dee9e3030124d2f04f97f598570 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/train.py
@@ -216,7 +216,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
benchmark_group = parser.add_argument_group(
'benchmark', 'arguments related to benchmark.')
diff --git a/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py
index bc746467806da598e1e346358fa24e4cac9871c8..36e4d645701bfe6ad99f43faa2d187b6c87eff57 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py
@@ -232,7 +232,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/gan_vocoder/synthesize.py b/paddlespeech/t2s/exps/gan_vocoder/synthesize.py
index 6f4dc92dbc9a74f4735880e5ce297a898e3e1014..c60b9add2eb584029b53b1ae657f0a04edece028 100644
--- a/paddlespeech/t2s/exps/gan_vocoder/synthesize.py
+++ b/paddlespeech/t2s/exps/gan_vocoder/synthesize.py
@@ -42,7 +42,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py b/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py
index 2854d0555ad3041de9a6c4d35b6fd1c673a5042f..cb742c59587fa91f442d4ba5868c7b13a23fe085 100644
--- a/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py
@@ -173,7 +173,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
- parser.add_argument("--verbose", type=int, default=1, help="verbose")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
diff --git a/paddlespeech/t2s/exps/transformer_tts/synthesize.py b/paddlespeech/t2s/exps/transformer_tts/synthesize.py
index 666c3b7237f56894d804134e0674df0ee525a92b..7b6b1873fca65d4765df40ac2c8a233f0e33ae3a 100644
--- a/paddlespeech/t2s/exps/transformer_tts/synthesize.py
+++ b/paddlespeech/t2s/exps/transformer_tts/synthesize.py
@@ -118,7 +118,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py
index ba197f43cbda55733d6aa07236d1620fcc16dd15..0cd7d224e0b983d1461bfd42ffebd812f79380b4 100644
--- a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py
@@ -137,7 +137,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
diff --git a/paddlespeech/t2s/exps/transformer_tts/train.py b/paddlespeech/t2s/exps/transformer_tts/train.py
index 163339f4a99948da4ad647ab693bc68c141d6811..8695c06a9706a12c7b48fd59acdb5b6c16e572a1 100644
--- a/paddlespeech/t2s/exps/transformer_tts/train.py
+++ b/paddlespeech/t2s/exps/transformer_tts/train.py
@@ -165,7 +165,6 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
- parser.add_argument("--verbose", type=int, default=1, help="verbose.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
diff --git a/paddlespeech/text/utils/__init__.py b/paddlespeech/text/exps/__init__.py
similarity index 89%
rename from paddlespeech/text/utils/__init__.py
rename to paddlespeech/text/exps/__init__.py
index 185a92b8d94d3426d616c0624f0f2ee04339349e..abf198b97e6e818e1fbe59006f98492640bcee54 100644
--- a/paddlespeech/text/utils/__init__.py
+++ b/paddlespeech/text/exps/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/paddlespeech/text/training/__init__.py b/paddlespeech/text/exps/ernie_linear/__init__.py
similarity index 89%
rename from paddlespeech/text/training/__init__.py
rename to paddlespeech/text/exps/ernie_linear/__init__.py
index 185a92b8d94d3426d616c0624f0f2ee04339349e..abf198b97e6e818e1fbe59006f98492640bcee54 100644
--- a/paddlespeech/text/training/__init__.py
+++ b/paddlespeech/text/exps/ernie_linear/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/paddlespeech/text/exps/ernie_linear/punc_restore.py b/paddlespeech/text/exps/ernie_linear/punc_restore.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cb4d07199d790ba600834c836d383f6b6f19238
--- /dev/null
+++ b/paddlespeech/text/exps/ernie_linear/punc_restore.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import re
+
+import paddle
+import yaml
+from paddlenlp.transformers import ErnieTokenizer
+from yacs.config import CfgNode
+
+from paddlespeech.text.models.ernie_linear import ErnieLinear
+
+DefinedClassifier = {
+ 'ErnieLinear': ErnieLinear,
+}
+
+tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+
+
+def _clean_text(text, punc_list):
+ text = text.lower()
+ text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
+ text = re.sub(f'[{"".join([p for p in punc_list][1:])}]', '', text)
+ return text
+
+
+def preprocess(text, punc_list):
+ clean_text = _clean_text(text, punc_list)
+ assert len(clean_text) > 0, f'Invalid input string: {text}'
+ tokenized_input = tokenizer(
+ list(clean_text), return_length=True, is_split_into_words=True)
+ _inputs = dict()
+ _inputs['input_ids'] = tokenized_input['input_ids']
+ _inputs['seg_ids'] = tokenized_input['token_type_ids']
+ _inputs['seq_len'] = tokenized_input['seq_len']
+ return _inputs
+
+
+def test(args):
+ with open(args.config) as f:
+ config = CfgNode(yaml.safe_load(f))
+ print("========Args========")
+ print(yaml.safe_dump(vars(args)))
+ print("========Config========")
+ print(config)
+
+ punc_list = []
+ with open(config["data_params"]["punc_path"], 'r') as f:
+ for line in f:
+ punc_list.append(line.strip())
+
+ model = DefinedClassifier[config["model_type"]](**config["model"])
+ state_dict = paddle.load(args.checkpoint)
+ model.set_state_dict(state_dict["main_params"])
+ model.eval()
+ _inputs = preprocess(args.text, punc_list)
+ seq_len = _inputs['seq_len']
+ input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
+ seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)
+ logits, _ = model(input_ids, seg_ids)
+ preds = paddle.argmax(logits, axis=-1).squeeze(0)
+ tokens = tokenizer.convert_ids_to_tokens(
+ _inputs['input_ids'][1:seq_len - 1])
+ labels = preds[1:seq_len - 1].tolist()
+ assert len(tokens) == len(labels)
+ # add 0 for non punc
+ punc_list = [0] + punc_list
+ text = ''
+ for t, l in zip(tokens, labels):
+ text += t
+ if l != 0: # Non punc.
+ text += punc_list[l]
+ print("Punctuation Restoration Result:", text)
+ return text
+
+
+def main():
+ # parse args and config and redirect to train_sp
+ parser = argparse.ArgumentParser(description="Run Punctuation Restoration.")
+ parser.add_argument("--config", type=str, help="ErnieLinear config file.")
+ parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+ parser.add_argument("--text", type=str, help="raw text to be restored.")
+ parser.add_argument(
+ "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
+
+ args = parser.parse_args()
+
+ if args.ngpu == 0:
+ paddle.set_device("cpu")
+ elif args.ngpu > 0:
+ paddle.set_device("gpu")
+ else:
+ print("ngpu should >= 0 !")
+
+ test(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/paddlespeech/text/exps/ernie_linear/test.py b/paddlespeech/text/exps/ernie_linear/test.py
index 3cd507fbbe4fc6528a4cf1c1f6deb1ac33eed124..4302a1a3bfa80e8e8417fc7d5ec2786eda8417df 100644
--- a/paddlespeech/text/exps/ernie_linear/test.py
+++ b/paddlespeech/text/exps/ernie_linear/test.py
@@ -11,36 +11,110 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Evaluation for model."""
+import argparse
+
+import numpy as np
+import paddle
+import pandas as pd
import yaml
+from paddle import nn
+from paddle.io import DataLoader
+from sklearn.metrics import classification_report
+from sklearn.metrics import precision_recall_fscore_support
+from yacs.config import CfgNode
-from paddlespeech.s2t.utils.utility import print_arguments
-from paddlespeech.text.training.trainer import Tester
-from paddlespeech.text.utils.default_parser import default_argument_parser
+from paddlespeech.text.models.ernie_linear import ErnieLinear
+from paddlespeech.text.models.ernie_linear import PuncDataset
+from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer
+DefinedClassifier = {
+ 'ErnieLinear': ErnieLinear,
+}
-def main_sp(config, args):
- exp = Tester(config, args)
- exp.setup()
- exp.run_test()
+DefinedLoss = {
+ "ce": nn.CrossEntropyLoss,
+}
+DefinedDataset = {
+ 'Punc': PuncDataset,
+ 'Ernie': PuncDatasetFromErnieTokenizer,
+}
-def main(config, args):
- main_sp(config, args)
+def evaluation(y_pred, y_test):
+ precision, recall, f1, _ = precision_recall_fscore_support(
+ y_test, y_pred, average=None, labels=[1, 2, 3])
+ overall = precision_recall_fscore_support(
+ y_test, y_pred, average='macro', labels=[1, 2, 3])
+ result = pd.DataFrame(
+ np.array([precision, recall, f1]),
+ columns=list(['O', 'COMMA', 'PERIOD', 'QUESTION'])[1:],
+ index=['Precision', 'Recall', 'F1'])
+ result['OVERALL'] = overall[:3]
+ return result
+
+
+def test(args):
+ with open(args.config) as f:
+ config = CfgNode(yaml.safe_load(f))
+ print("========Args========")
+ print(yaml.safe_dump(vars(args)))
+ print("========Config========")
+ print(config)
+
+ test_dataset = DefinedDataset[config["dataset_type"]](
+ train_path=config["test_path"], **config["data_params"])
+ test_loader = DataLoader(
+ test_dataset,
+ batch_size=config.batch_size,
+ shuffle=False,
+ drop_last=False)
+ model = DefinedClassifier[config["model_type"]](**config["model"])
+ state_dict = paddle.load(args.checkpoint)
+ model.set_state_dict(state_dict["main_params"])
+ model.eval()
+
+ punc_list = []
+ for i in range(len(test_loader.dataset.id2punc)):
+ punc_list.append(test_loader.dataset.id2punc[i])
+
+ test_total_label = []
+ test_total_predict = []
+
+ for i, batch in enumerate(test_loader):
+ input, label = batch
+ label = paddle.reshape(label, shape=[-1])
+ y, logit = model(input)
+ pred = paddle.argmax(logit, axis=1)
+ test_total_label.extend(label.numpy().tolist())
+ test_total_predict.extend(pred.numpy().tolist())
+ t = classification_report(
+ test_total_label, test_total_predict, target_names=punc_list)
+ print(t)
+ t2 = evaluation(test_total_label, test_total_predict)
+ print('=========================================================')
+ print(t2)
+
+
+def main():
+ # parse args and config and redirect to train_sp
+ parser = argparse.ArgumentParser(description="Test a ErnieLinear model.")
+ parser.add_argument("--config", type=str, help="ErnieLinear config file.")
+ parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+ parser.add_argument(
+ "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
-if __name__ == "__main__":
- parser = default_argument_parser()
args = parser.parse_args()
- print_arguments(args, globals())
- # https://yaml.org/type/float.html
- with open(args.config, "r") as f:
- config = yaml.load(f, Loader=yaml.FullLoader)
+ if args.ngpu == 0:
+ paddle.set_device("cpu")
+ elif args.ngpu > 0:
+ paddle.set_device("gpu")
+ else:
+ print("ngpu should >= 0 !")
+
+ test(args)
- print(config)
- if args.dump_config:
- with open(args.dump_config, 'w') as f:
- print(config, file=f)
- main(config, args)
+if __name__ == "__main__":
+ main()
diff --git a/paddlespeech/text/exps/ernie_linear/train.py b/paddlespeech/text/exps/ernie_linear/train.py
index 09071438139bec882caaec6e212ceff57ba1095c..0d730d6666ae3a9ec8290e5c73ebcd45a0842302 100644
--- a/paddlespeech/text/exps/ernie_linear/train.py
+++ b/paddlespeech/text/exps/ernie_linear/train.py
@@ -11,40 +11,163 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Trainer for punctuation_restoration task."""
+import argparse
+import logging
+import os
+import shutil
+from pathlib import Path
+
+import paddle
import yaml
+from paddle import DataParallel
from paddle import distributed as dist
+from paddle import nn
+from paddle.io import DataLoader
+from paddle.optimizer import Adam
+from paddle.optimizer.lr import ExponentialDecay
+from yacs.config import CfgNode
-from paddlespeech.s2t.utils.utility import print_arguments
-from paddlespeech.text.training.trainer import Trainer
-from paddlespeech.text.utils.default_parser import default_argument_parser
+from paddlespeech.t2s.training.extensions.snapshot import Snapshot
+from paddlespeech.t2s.training.extensions.visualizer import VisualDL
+from paddlespeech.t2s.training.seeding import seed_everything
+from paddlespeech.t2s.training.trainer import Trainer
+from paddlespeech.text.models.ernie_linear import ErnieLinear
+from paddlespeech.text.models.ernie_linear import ErnieLinearEvaluator
+from paddlespeech.text.models.ernie_linear import ErnieLinearUpdater
+from paddlespeech.text.models.ernie_linear import PuncDataset
+from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer
+DefinedClassifier = {
+ 'ErnieLinear': ErnieLinear,
+}
-def main_sp(config, args):
- exp = Trainer(config, args)
- exp.setup()
- exp.run()
+DefinedLoss = {
+ "ce": nn.CrossEntropyLoss,
+}
+DefinedDataset = {
+ 'Punc': PuncDataset,
+ 'Ernie': PuncDatasetFromErnieTokenizer,
+}
-def main(config, args):
- if args.ngpu > 1:
- dist.spawn(main_sp, args=(config, args), nprocs=args.ngpu)
+
+def train_sp(args, config):
+ # decides device type and whether to run in parallel
+ # setup running environment correctly
+ if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
+ paddle.set_device("cpu")
else:
- main_sp(config, args)
+ paddle.set_device("gpu")
+ world_size = paddle.distributed.get_world_size()
+ if world_size > 1:
+ paddle.distributed.init_parallel_env()
+ # set the random seed, it is a must for multiprocess training
+ seed_everything(config.seed)
+
+ print(
+ f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
+ )
+ # dataloader has been too verbose
+ logging.getLogger("DataLoader").disabled = True
+ train_dataset = DefinedDataset[config["dataset_type"]](
+ train_path=config["train_path"], **config["data_params"])
+ dev_dataset = DefinedDataset[config["dataset_type"]](
+ train_path=config["dev_path"], **config["data_params"])
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=True,
+ num_workers=config.num_workers,
+ batch_size=config.batch_size)
+
+ dev_dataloader = DataLoader(
+ dev_dataset,
+ batch_size=config.batch_size,
+ shuffle=False,
+ drop_last=False,
+ num_workers=config.num_workers)
+
+ print("dataloaders done!")
+
+ model = DefinedClassifier[config["model_type"]](**config["model"])
+
+ if world_size > 1:
+ model = DataParallel(model)
+ print("model done!")
+
+ criterion = DefinedLoss[config["loss_type"]](
+ **config["loss"]) if "loss_type" in config else DefinedLoss["ce"]()
+
+ print("criterions done!")
+
+ lr_schedule = ExponentialDecay(**config["scheduler_params"])
+ optimizer = Adam(
+ learning_rate=lr_schedule,
+ parameters=model.parameters(),
+ weight_decay=paddle.regularizer.L2Decay(
+ config["optimizer_params"]["weight_decay"]))
+
+ print("optimizer done!")
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ if dist.get_rank() == 0:
+ config_name = args.config.split("/")[-1]
+ # copy conf to output_dir
+ shutil.copyfile(args.config, output_dir / config_name)
+
+ updater = ErnieLinearUpdater(
+ model=model,
+ criterion=criterion,
+ scheduler=lr_schedule,
+ optimizer=optimizer,
+ dataloader=train_dataloader,
+ output_dir=output_dir)
+
+ trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
+
+ evaluator = ErnieLinearEvaluator(
+ model=model,
+ criterion=criterion,
+ dataloader=dev_dataloader,
+ output_dir=output_dir)
+
+ if dist.get_rank() == 0:
+ trainer.extend(evaluator, trigger=(1, "epoch"))
+ trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
+ trainer.extend(
+ Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
+ # print(trainer.extensions)
+ trainer.run()
+
+
+def main():
+ # parse args and config and redirect to train_sp
+ parser = argparse.ArgumentParser(description="Train a ErnieLinear model.")
+ parser.add_argument("--config", type=str, help="ErnieLinear config file.")
+ parser.add_argument("--output-dir", type=str, help="output dir.")
+ parser.add_argument(
+ "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
-if __name__ == "__main__":
- parser = default_argument_parser()
args = parser.parse_args()
- print_arguments(args, globals())
- # https://yaml.org/type/float.html
- with open(args.config, "r") as f:
- config = yaml.load(f, Loader=yaml.FullLoader)
+ with open(args.config) as f:
+ config = CfgNode(yaml.safe_load(f))
+ print("========Args========")
+ print(yaml.safe_dump(vars(args)))
+ print("========Config========")
print(config)
- if args.dump_config:
- with open(args.dump_config, 'w') as f:
- print(config, file=f)
+ print(
+ f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
+ )
+
+ # dispatch
+ if args.ngpu > 1:
+ dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
+ else:
+ train_sp(args, config)
+
- main(config, args)
+if __name__ == "__main__":
+ main()
diff --git a/paddlespeech/text/models/ernie_linear/__init__.py b/paddlespeech/text/models/ernie_linear/__init__.py
index 93453ce7473645aba18e516300a15d93bdc614ba..0a10a6eb2d3119b89a7e88d98254571735b29cfa 100644
--- a/paddlespeech/text/models/ernie_linear/__init__.py
+++ b/paddlespeech/text/models/ernie_linear/__init__.py
@@ -11,4 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .model import ErnieLinear
+from .dataset import *
+from .ernie_linear import *
+from .ernie_linear_updater import *
diff --git a/paddlespeech/text/models/ernie_linear/dataset.py b/paddlespeech/text/models/ernie_linear/dataset.py
index 086e91bb890e82c49f9a00636c742bd44358203c..64c8d0bdfd34d32d016b539f147bdfa145b604f0 100644
--- a/paddlespeech/text/models/ernie_linear/dataset.py
+++ b/paddlespeech/text/models/ernie_linear/dataset.py
@@ -99,10 +99,8 @@ class PuncDatasetFromErnieTokenizer(Dataset):
self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
self.paddingID = self.tokenizer.pad_token_id
self.seq_len = seq_len
-
self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "])
self.id2punc = {k: v for (v, k) in self.punc2id.items()}
-
tmp_seqs = open(train_path, encoding='utf-8').readlines()
self.txt_seqs = [i for seq in tmp_seqs for i in seq.split()]
self.preprocess(self.txt_seqs)
@@ -125,6 +123,7 @@ class PuncDatasetFromErnieTokenizer(Dataset):
input_data = []
label = []
count = 0
+ print("Preprocessing in PuncDatasetFromErnieTokenizer...")
for i in range(len(txt_seqs) - 1):
word = txt_seqs[i]
punc = txt_seqs[i + 1]
diff --git a/paddlespeech/text/models/ernie_linear/model.py b/paddlespeech/text/models/ernie_linear/ernie_linear.py
similarity index 100%
rename from paddlespeech/text/models/ernie_linear/model.py
rename to paddlespeech/text/models/ernie_linear/ernie_linear.py
diff --git a/paddlespeech/text/models/ernie_linear/ernie_linear_updater.py b/paddlespeech/text/models/ernie_linear/ernie_linear_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b3d7410e04afff43ea16a821473fe05eb4f195d
--- /dev/null
+++ b/paddlespeech/text/models/ernie_linear/ernie_linear_updater.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+from sklearn.metrics import f1_score
+
+from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
+from paddlespeech.t2s.training.reporter import report
+from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+logging.basicConfig(
+ format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+ datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class ErnieLinearUpdater(StandardUpdater):
+ def __init__(self,
+ model: Layer,
+ criterion: Layer,
+ scheduler: LRScheduler,
+ optimizer: Optimizer,
+ dataloader: DataLoader,
+ output_dir=None):
+ super().__init__(model, optimizer, dataloader, init_state=None)
+ self.model = model
+ self.dataloader = dataloader
+
+ self.criterion = criterion
+ self.scheduler = scheduler
+ self.optimizer = optimizer
+
+ log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+ self.filehandler = logging.FileHandler(str(log_file))
+ logger.addHandler(self.filehandler)
+ self.logger = logger
+ self.msg = ""
+
+ def update_core(self, batch):
+ self.msg = "Rank: {}, ".format(dist.get_rank())
+ losses_dict = {}
+
+ input, label = batch
+ label = paddle.reshape(label, shape=[-1])
+ y, logit = self.model(input)
+ pred = paddle.argmax(logit, axis=1)
+
+ loss = self.criterion(y, label)
+
+ self.optimizer.clear_grad()
+ loss.backward()
+
+ self.optimizer.step()
+ self.scheduler.step()
+
+ F1_score = f1_score(
+ label.numpy().tolist(), pred.numpy().tolist(), average="macro")
+
+ report("train/loss", float(loss))
+ losses_dict["loss"] = float(loss)
+ report("train/F1_score", float(F1_score))
+ losses_dict["F1_score"] = float(F1_score)
+
+ self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+ for k, v in losses_dict.items())
+
+
+class ErnieLinearEvaluator(StandardEvaluator):
+ def __init__(self,
+ model: Layer,
+ criterion: Layer,
+ dataloader: DataLoader,
+ output_dir=None):
+ super().__init__(model, dataloader)
+ self.model = model
+ self.criterion = criterion
+ self.dataloader = dataloader
+
+ log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+ self.filehandler = logging.FileHandler(str(log_file))
+ logger.addHandler(self.filehandler)
+ self.logger = logger
+ self.msg = ""
+
+ def evaluate_core(self, batch):
+ self.msg = "Evaluate: "
+ losses_dict = {}
+
+ input, label = batch
+ label = paddle.reshape(label, shape=[-1])
+ y, logit = self.model(input)
+ pred = paddle.argmax(logit, axis=1)
+
+ loss = self.criterion(y, label)
+
+ F1_score = f1_score(
+ label.numpy().tolist(), pred.numpy().tolist(), average="macro")
+
+ report("eval/loss", float(loss))
+ losses_dict["loss"] = float(loss)
+ report("eval/F1_score", float(F1_score))
+ losses_dict["F1_score"] = float(F1_score)
+
+ self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+ for k, v in losses_dict.items())
+ self.logger.info(self.msg)
diff --git a/paddlespeech/text/training/trainer.py b/paddlespeech/text/training/trainer.py
deleted file mode 100644
index b5e6a563cc3d8f15ffad92f6fa3ff961104c4d54..0000000000000000000000000000000000000000
--- a/paddlespeech/text/training/trainer.py
+++ /dev/null
@@ -1,524 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import time
-from collections import defaultdict
-from pathlib import Path
-
-import numpy as np
-import paddle
-import paddle.nn as nn
-import pandas as pd
-from paddle import distributed as dist
-from paddle.io import DataLoader
-from sklearn.metrics import classification_report
-from sklearn.metrics import f1_score
-from sklearn.metrics import precision_recall_fscore_support
-
-from ...s2t.utils import layer_tools
-from ...s2t.utils import mp_tools
-from ...s2t.utils.checkpoint import Checkpoint
-from ...text.models import ErnieLinear
-from ...text.models.ernie_linear.dataset import PuncDataset
-from ...text.models.ernie_linear.dataset import PuncDatasetFromErnieTokenizer
-
-__all__ = ["Trainer", "Tester"]
-
-DefinedClassifier = {
- 'ErnieLinear': ErnieLinear,
-}
-
-DefinedLoss = {
- "ce": nn.CrossEntropyLoss,
-}
-
-DefinedDataset = {
- 'Punc': PuncDataset,
- 'Ernie': PuncDatasetFromErnieTokenizer,
-}
-
-
-class Trainer():
- def __init__(self, config, args):
- self.config = config
- self.args = args
- self.optimizer = None
- self.output_dir = None
- self.log_dir = None
- self.checkpoint_dir = None
- self.iteration = 0
- self.epoch = 0
-
- def setup(self):
- """Setup the experiment.
- """
- self.setup_log_dir()
- self.setup_logger()
- if self.args.ngpu > 0:
- paddle.set_device('gpu')
- else:
- paddle.set_device('cpu')
- if self.parallel:
- self.init_parallel()
-
- self.setup_output_dir()
- self.dump_config()
- self.setup_checkpointer()
-
- self.setup_model()
-
- self.setup_dataloader()
-
- self.iteration = 0
- self.epoch = 1
-
- @property
- def parallel(self):
- """A flag indicating whether the experiment should run with
- multiprocessing.
- """
- return self.args.ngpu > 1
-
- def init_parallel(self):
- """Init environment for multiprocess training.
- """
- dist.init_parallel_env()
-
- @mp_tools.rank_zero_only
- def save(self, tag=None, infos: dict=None):
- """Save checkpoint (model parameters and optimizer states).
-
- Args:
- tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None.
- infos (dict, optional): meta data to save. Defaults to None.
- """
-
- infos = infos if infos else dict()
- infos.update({
- "step": self.iteration,
- "epoch": self.epoch,
- "lr": self.optimizer.get_lr()
- })
- self.checkpointer.save_parameters(self.checkpoint_dir, self.iteration
- if tag is None else tag, self.model,
- self.optimizer, infos)
-
- def resume_or_scratch(self):
- """Resume from latest checkpoint at checkpoints in the output
- directory or load a specified checkpoint.
-
- If ``args.checkpoint_path`` is not None, load the checkpoint, else
- resume training.
- """
- scratch = None
- infos = self.checkpointer.load_parameters(
- self.model,
- self.optimizer,
- checkpoint_dir=self.checkpoint_dir,
- checkpoint_path=self.args.checkpoint_path)
- if infos:
- # restore from ckpt
- self.iteration = infos["step"]
- self.epoch = infos["epoch"]
- scratch = False
- else:
- self.iteration = 0
- self.epoch = 0
- scratch = True
-
- return scratch
-
- def new_epoch(self):
- """Reset the train loader seed and increment `epoch`.
- """
- self.epoch += 1
- if self.parallel:
- self.train_loader.batch_sampler.set_epoch(self.epoch)
-
- def train(self):
- """The training process control by epoch."""
- from_scratch = self.resume_or_scratch()
-
- if from_scratch:
- # save init model, i.e. 0 epoch
- self.save(tag="init")
-
- self.lr_scheduler.step(self.iteration)
- if self.parallel:
- self.train_loader.batch_sampler.set_epoch(self.epoch)
-
- self.logger.info(
- f"Train Total Examples: {len(self.train_loader.dataset)}")
- self.punc_list = []
- for i in range(len(self.train_loader.dataset.id2punc)):
- self.punc_list.append(self.train_loader.dataset.id2punc[i])
- while self.epoch < self.config["training"]["n_epoch"]:
- self.model.train()
- self.total_label_train = []
- self.total_predict_train = []
- try:
- data_start_time = time.time()
- for batch_index, batch in enumerate(self.train_loader):
- dataload_time = time.time() - data_start_time
- msg = "Train: Rank: {}, ".format(dist.get_rank())
- msg += "epoch: {}, ".format(self.epoch)
- msg += "step: {}, ".format(self.iteration)
- msg += "batch : {}/{}, ".format(batch_index + 1,
- len(self.train_loader))
- msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
- msg += "data time: {:>.3f}s, ".format(dataload_time)
- self.train_batch(batch_index, batch, msg)
- data_start_time = time.time()
- # t = classification_report(
- # self.total_label_train,
- # self.total_predict_train,
- # target_names=self.punc_list)
- # self.logger.info(t)
- except Exception as e:
- self.logger.error(e)
- raise e
-
- total_loss, F1_score = self.valid()
- self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
- format(self.epoch, total_loss, F1_score))
-
- self.save(
- tag=self.epoch, infos={"val_loss": total_loss,
- "F1": F1_score})
- # step lr every epoch
- self.lr_scheduler.step()
- self.new_epoch()
-
- def run(self):
- """The routine of the experiment after setup. This method is intended
- to be used by the user.
- """
- try:
- self.train()
- except KeyboardInterrupt:
- self.logger.info("Training was aborted by keybord interrupt.")
- self.save()
- exit(-1)
- finally:
- self.destory()
- self.logger.info("Training Done.")
-
- def setup_output_dir(self):
- """Create a directory used for output.
- """
- # output dir
- output_dir = Path(self.args.output_dir).expanduser()
- output_dir.mkdir(parents=True, exist_ok=True)
-
- self.output_dir = output_dir
-
- def setup_log_dir(self):
- """Create a directory used for logging.
- """
- # log dir
- log_dir = Path(self.args.log_dir).expanduser()
- log_dir.mkdir(parents=True, exist_ok=True)
-
- self.log_dir = log_dir
-
- def setup_checkpointer(self):
- """Create a directory used to save checkpoints into.
-
- It is "checkpoints" inside the output directory.
- """
- # checkpoint dir
- self.checkpointer = Checkpoint(self.config["checkpoint"]["kbest_n"],
- self.config["checkpoint"]["latest_n"])
-
- checkpoint_dir = self.output_dir / "checkpoints"
- checkpoint_dir.mkdir(exist_ok=True)
-
- self.checkpoint_dir = checkpoint_dir
-
- def setup_logger(self):
- LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
- format_str = logging.Formatter(
- '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
- )
- logging.basicConfig(
- filename=self.config["training"]["log_path"],
- level=logging.INFO,
- format=LOG_FORMAT)
- self.logger = logging.getLogger(__name__)
-
- self.logger.setLevel(logging.INFO)
- sh = logging.StreamHandler()
- sh.setFormatter(format_str)
- self.logger.addHandler(sh)
-
- self.logger.info('info')
-
- @mp_tools.rank_zero_only
- def destory(self):
- pass
-
- @mp_tools.rank_zero_only
- def dump_config(self):
- """Save the configuration used for this experiment.
-
- It is saved in to ``config.yaml`` in the output directory at the
- beginning of the experiment.
- """
- with open(self.output_dir / "config.yaml", "wt") as f:
- print(self.config, file=f)
-
- def train_batch(self, batch_index, batch_data, msg):
- start = time.time()
-
- input, label = batch_data
- label = paddle.reshape(label, shape=[-1])
- y, logit = self.model(input)
- pred = paddle.argmax(logit, axis=1)
- self.total_label_train.extend(label.numpy().tolist())
- self.total_predict_train.extend(pred.numpy().tolist())
- loss = self.crit(y, label)
-
- loss.backward()
- layer_tools.print_grads(self.model, print_func=None)
- self.optimizer.step()
- self.optimizer.clear_grad()
- iteration_time = time.time() - start
-
- losses_np = {
- "train_loss": float(loss),
- }
- msg += "train time: {:>.3f}s, ".format(iteration_time)
- msg += "batch size: {}, ".format(self.config["data"]["batch_size"])
- msg += ", ".join("{}: {:>.6f}".format(k, v)
- for k, v in losses_np.items())
- self.logger.info(msg)
- self.iteration += 1
-
- @paddle.no_grad()
- def valid(self):
- self.logger.info(
- f"Valid Total Examples: {len(self.valid_loader.dataset)}")
- self.model.eval()
- valid_losses = defaultdict(list)
- num_seen_utts = 1
- total_loss = 0.0
- valid_total_label = []
- valid_total_predict = []
- for i, batch in enumerate(self.valid_loader):
- input, label = batch
- label = paddle.reshape(label, shape=[-1])
- y, logit = self.model(input)
- pred = paddle.argmax(logit, axis=1)
- valid_total_label.extend(label.numpy().tolist())
- valid_total_predict.extend(pred.numpy().tolist())
- loss = self.crit(y, label)
-
- if paddle.isfinite(loss):
- num_utts = batch[1].shape[0]
- num_seen_utts += num_utts
- total_loss += float(loss) * num_utts
- valid_losses["val_loss"].append(float(loss))
-
- if (i + 1) % self.config["training"]["log_interval"] == 0:
- valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
- valid_dump["val_history_loss"] = total_loss / num_seen_utts
-
- # logging
- msg = f"Valid: Rank: {dist.get_rank()}, "
- msg += "epoch: {}, ".format(self.epoch)
- msg += "step: {}, ".format(self.iteration)
- msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
- msg += ", ".join("{}: {:>.6f}".format(k, v)
- for k, v in valid_dump.items())
- self.logger.info(msg)
-
- self.logger.info("Rank {} Val info val_loss {}".format(
- dist.get_rank(), total_loss / num_seen_utts))
- F1_score = f1_score(
- valid_total_label, valid_total_predict, average="macro")
- return total_loss / num_seen_utts, F1_score
-
- def setup_model(self):
- config = self.config
-
- model = DefinedClassifier[self.config["model_type"]](
- **self.config["model_params"])
- self.crit = DefinedLoss[self.config["loss_type"]](**self.config[
- "loss"]) if "loss_type" in self.config else DefinedLoss["ce"]()
-
- if self.parallel:
- model = paddle.DataParallel(model)
-
- # self.logger.info(f"{model}")
- # layer_tools.print_params(model, self.logger.info)
-
- lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
- learning_rate=config["training"]["lr"],
- gamma=config["training"]["lr_decay"],
- verbose=True)
- optimizer = paddle.optimizer.Adam(
- learning_rate=lr_scheduler,
- parameters=model.parameters(),
- weight_decay=paddle.regularizer.L2Decay(
- config["training"]["weight_decay"]))
-
- self.model = model
- self.optimizer = optimizer
- self.lr_scheduler = lr_scheduler
- self.logger.info("Setup model/criterion/optimizer/lr_scheduler!")
-
- def setup_dataloader(self):
- config = self.config["data"].copy()
- train_dataset = DefinedDataset[config["dataset_type"]](
- train_path=config["train_path"], **config["data_params"])
- dev_dataset = DefinedDataset[config["dataset_type"]](
- train_path=config["dev_path"], **config["data_params"])
-
- self.train_loader = DataLoader(
- train_dataset,
- num_workers=config["num_workers"],
- batch_size=config["batch_size"])
- self.valid_loader = DataLoader(
- dev_dataset,
- batch_size=config["batch_size"],
- shuffle=False,
- drop_last=False,
- num_workers=config["num_workers"])
- self.logger.info("Setup train/valid Dataloader!")
-
-
-class Tester(Trainer):
- def __init__(self, config, args):
- super().__init__(config, args)
-
- @mp_tools.rank_zero_only
- @paddle.no_grad()
- def test(self):
- self.logger.info(
- f"Test Total Examples: {len(self.test_loader.dataset)}")
- self.punc_list = []
- for i in range(len(self.test_loader.dataset.id2punc)):
- self.punc_list.append(self.test_loader.dataset.id2punc[i])
- self.model.eval()
- test_total_label = []
- test_total_predict = []
- with open(self.args.result_file, 'w') as fout:
- for i, batch in enumerate(self.test_loader):
- input, label = batch
- label = paddle.reshape(label, shape=[-1])
- y, logit = self.model(input)
- pred = paddle.argmax(logit, axis=1)
- test_total_label.extend(label.numpy().tolist())
- test_total_predict.extend(pred.numpy().tolist())
-
- # logging
- msg = "Test: "
- msg += "epoch: {}, ".format(self.epoch)
- msg += "step: {}, ".format(self.iteration)
- self.logger.info(msg)
- t = classification_report(
- test_total_label, test_total_predict, target_names=self.punc_list)
- print(t)
- t2 = self.evaluation(test_total_label, test_total_predict)
- print(t2)
-
- def evaluation(self, y_pred, y_test):
- precision, recall, f1, _ = precision_recall_fscore_support(
- y_test, y_pred, average=None, labels=[1, 2, 3])
- overall = precision_recall_fscore_support(
- y_test, y_pred, average='macro', labels=[1, 2, 3])
- result = pd.DataFrame(
- np.array([precision, recall, f1]),
- columns=list(['O', 'COMMA', 'PERIOD', 'QUESTION'])[1:],
- index=['Precision', 'Recall', 'F1'])
- result['OVERALL'] = overall[:3]
- return result
-
- def run_test(self):
- self.resume_or_scratch()
- try:
- self.test()
- except KeyboardInterrupt:
- self.logger.info("Testing was aborted by keybord interrupt.")
- exit(-1)
-
- def setup(self):
- """Setup the experiment.
- """
- if self.args.ngpu > 0:
- paddle.set_device('gpu')
- else:
- paddle.set_device('cpu')
- self.setup_logger()
- self.setup_output_dir()
- self.setup_checkpointer()
-
- self.setup_dataloader()
- self.setup_model()
-
- self.iteration = 0
- self.epoch = 0
-
- def setup_model(self):
- config = self.config
- model = DefinedClassifier[self.config["model_type"]](
- **self.config["model_params"])
-
- self.model = model
- self.logger.info("Setup model!")
-
- def setup_dataloader(self):
- config = self.config["data"].copy()
-
- test_dataset = DefinedDataset[config["dataset_type"]](
- train_path=config["test_path"], **config["data_params"])
-
- self.test_loader = DataLoader(
- test_dataset,
- batch_size=config["batch_size"],
- shuffle=False,
- drop_last=False)
- self.logger.info("Setup test Dataloader!")
-
- def setup_output_dir(self):
- """Create a directory used for output.
- """
- # output dir
- if self.args.output_dir:
- output_dir = Path(self.args.output_dir).expanduser()
- output_dir.mkdir(parents=True, exist_ok=True)
- else:
- output_dir = Path(
- self.args.checkpoint_path).expanduser().parent.parent
- output_dir.mkdir(parents=True, exist_ok=True)
-
- self.output_dir = output_dir
-
- def setup_logger(self):
- LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
- format_str = logging.Formatter(
- '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
- )
- logging.basicConfig(
- filename=self.config["testing"]["log_path"],
- level=logging.INFO,
- format=LOG_FORMAT)
- self.logger = logging.getLogger(__name__)
-
- self.logger.setLevel(logging.INFO)
- sh = logging.StreamHandler()
- sh.setFormatter(format_str)
- self.logger.addHandler(sh)
-
- self.logger.info('info')
diff --git a/paddlespeech/text/utils/default_parser.py b/paddlespeech/text/utils/default_parser.py
deleted file mode 100644
index 469157a69da6198932173dcbcbc38de845c9e97a..0000000000000000000000000000000000000000
--- a/paddlespeech/text/utils/default_parser.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-
-
-def default_argument_parser():
- r"""A simple yet genral argument parser for experiments with t2s.
-
- This is used in examples with t2s. And it is intended to be used by
- other experiments with t2s. It requires a minimal set of command line
- arguments to start a training script.
-
- The ``--config`` and ``--opts`` are used for overwrite the deault
- configuration.
-
- The ``--data`` and ``--output`` specifies the data path and output path.
- Resuming training from existing progress at the output directory is the
- intended default behavior.
-
- The ``--checkpoint_path`` specifies the checkpoint to load from.
-
- The ``--ngpu`` specifies how to run the training.
-
-
- See Also
- --------
- paddlespeech.t2s.training.experiment
- Returns
- -------
- argparse.ArgumentParser
- the parser
- """
- parser = argparse.ArgumentParser()
-
- # yapf: disable
- # data and output
- parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
- parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
- # parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
- parser.add_argument("--output_dir", metavar="OUTPUT_DIR", help="path to save checkpoint.")
- parser.add_argument("--log_dir", metavar="LOG_DIR", help="path to save logs.")
-
- # load from saved checkpoint
- parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
-
- # save jit model to
- parser.add_argument("--export_path", type=str, help="path of the jit model to save")
-
- # save asr result to
- parser.add_argument("--result_file", type=str, help="path of save the asr result")
-
- # running
- parser.add_argument("--ngpu", type=int, default=1, help="number of parallel processes to use. if ngpu=0, using cpu.")
-
- # overwrite extra config and default config
- # parser.add_argument("--opts", nargs=argparse.REMAINDER,
- # help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
- parser.add_argument("--opts", type=str, default=[], nargs='+',
- help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
- # yapd: enable
-
- return parser