diff --git a/examples/aishell3/ernie_sat/local/train.sh b/examples/aishell3/ernie_sat/local/train.sh index f90db91505d7ff337824fc716212f566754cb5d8..30720e8f5b7ed289be91765e42a6e05bdf279c4a 100755 --- a/examples/aishell3/ernie_sat/local/train.sh +++ b/examples/aishell3/ernie_sat/local/train.sh @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ + --ngpu=2 \ --phones-dict=dump/phone_id_map.txt \ No newline at end of file diff --git a/examples/aishell3_vctk/ernie_sat/local/train.sh b/examples/aishell3_vctk/ernie_sat/local/train.sh index f90db91505d7ff337824fc716212f566754cb5d8..30720e8f5b7ed289be91765e42a6e05bdf279c4a 100755 --- a/examples/aishell3_vctk/ernie_sat/local/train.sh +++ b/examples/aishell3_vctk/ernie_sat/local/train.sh @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ + --ngpu=2 \ --phones-dict=dump/phone_id_map.txt \ No newline at end of file diff --git a/examples/vctk/ernie_sat/conf/default.yaml b/examples/vctk/ernie_sat/conf/default.yaml index 74c847a5ffcd31fb408875dc7b0fa03c92f0a78e..b61c81703f7a0009c55c0040df50d4db1865154a 100644 --- a/examples/vctk/ernie_sat/conf/default.yaml +++ b/examples/vctk/ernie_sat/conf/default.yaml @@ -79,7 +79,7 @@ grad_clip: 1.0 ########################################################### # TRAINING SETTING # ########################################################### -max_epoch: 200 +max_epoch: 600 num_snapshots: 5 ########################################################### @@ -160,4 +160,4 @@ token_list: - UH0 - AW0 - OY0 -- \ No newline at end of file +- diff --git a/examples/vctk/ernie_sat/local/train.sh b/examples/vctk/ernie_sat/local/train.sh index f90db91505d7ff337824fc716212f566754cb5d8..30720e8f5b7ed289be91765e42a6e05bdf279c4a 100755 --- a/examples/vctk/ernie_sat/local/train.sh +++ b/examples/vctk/ernie_sat/local/train.sh @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ + --ngpu=2 \ --phones-dict=dump/phone_id_map.txt \ No newline at end of file diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 9c964d8e95396249a72f5a2537847e09f53b4c70..05471167f05e7249ee310b55dcda788c38688119 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -29,20 +29,13 @@ from paddlespeech.t2s.modules.nets_utils import phones_text_masking # 因为要传参数,所以需要额外构建 -def build_erniesat_collate_fn( - mlm_prob: float=0.8, - mean_phn_span: int=8, - seg_emb: bool=False, - text_masking: bool=False, - epoch: int=-1, ): - - if epoch == -1: - mlm_prob_factor = 1 - else: - mlm_prob_factor = 0.8 +def build_erniesat_collate_fn(mlm_prob: float=0.8, + mean_phn_span: int=8, + seg_emb: bool=False, + text_masking: bool=False): return ErnieSATCollateFn( - mlm_prob=mlm_prob * mlm_prob_factor, + mlm_prob=mlm_prob, mean_phn_span=mean_phn_span, seg_emb=seg_emb, text_masking=text_masking) diff --git a/paddlespeech/t2s/exps/ernie_sat/synthesize.py b/paddlespeech/t2s/exps/ernie_sat/synthesize.py index 56f26a8bbd04f9b7e4023347ed82411d32ff7b07..2e358294889734c9920a754250ac2946506be62e 100644 --- a/paddlespeech/t2s/exps/ernie_sat/synthesize.py +++ b/paddlespeech/t2s/exps/ernie_sat/synthesize.py @@ -73,8 +73,7 @@ def evaluate(args): mlm_prob=erniesat_config.mlm_prob, mean_phn_span=erniesat_config.mean_phn_span, seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm', - text_masking=False, - epoch=-1) + text_masking=False) gen_raw = True erniesat_mu, erniesat_std = np.load(args.erniesat_stat) diff --git a/paddlespeech/t2s/exps/ernie_sat/train.py b/paddlespeech/t2s/exps/ernie_sat/train.py index 020b0d0fa06f34e280d269adb5624dd416352af7..5d8eadb684ea616bca9ed47d951c6c2a5bc2ee15 100644 --- a/paddlespeech/t2s/exps/ernie_sat/train.py +++ b/paddlespeech/t2s/exps/ernie_sat/train.py @@ -84,8 +84,7 @@ def train_sp(args, config): mlm_prob=config.mlm_prob, mean_phn_span=config.mean_phn_span, seg_emb=config.model['enc_input_layer'] == 'sega_mlm', - text_masking=config["model"]["text_masking"], - epoch=config["max_epoch"]) + text_masking=config["model"]["text_masking"]) train_sampler = DistributedBatchSampler( train_dataset,