提交 97965f4c 作者: 小湉湉

fix mlm_prob, test=tts

上级 c1395e3a
...@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ ...@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \ --dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \ --config=${config_path} \
--output-dir=${train_output_path} \ --output-dir=${train_output_path} \
--ngpu=1 \ --ngpu=2 \
--phones-dict=dump/phone_id_map.txt --phones-dict=dump/phone_id_map.txt
\ No newline at end of file
...@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ ...@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \ --dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \ --config=${config_path} \
--output-dir=${train_output_path} \ --output-dir=${train_output_path} \
--ngpu=1 \ --ngpu=2 \
--phones-dict=dump/phone_id_map.txt --phones-dict=dump/phone_id_map.txt
\ No newline at end of file
...@@ -79,7 +79,7 @@ grad_clip: 1.0 ...@@ -79,7 +79,7 @@ grad_clip: 1.0
########################################################### ###########################################################
# TRAINING SETTING # # TRAINING SETTING #
########################################################### ###########################################################
max_epoch: 200 max_epoch: 600
num_snapshots: 5 num_snapshots: 5
########################################################### ###########################################################
......
...@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \ ...@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \ --dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \ --config=${config_path} \
--output-dir=${train_output_path} \ --output-dir=${train_output_path} \
--ngpu=1 \ --ngpu=2 \
--phones-dict=dump/phone_id_map.txt --phones-dict=dump/phone_id_map.txt
\ No newline at end of file
...@@ -29,20 +29,13 @@ from paddlespeech.t2s.modules.nets_utils import phones_text_masking ...@@ -29,20 +29,13 @@ from paddlespeech.t2s.modules.nets_utils import phones_text_masking
# 因为要传参数,所以需要额外构建 # 因为要传参数,所以需要额外构建
def build_erniesat_collate_fn( def build_erniesat_collate_fn(mlm_prob: float=0.8,
mlm_prob: float=0.8,
mean_phn_span: int=8, mean_phn_span: int=8,
seg_emb: bool=False, seg_emb: bool=False,
text_masking: bool=False, text_masking: bool=False):
epoch: int=-1, ):
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
return ErnieSATCollateFn( return ErnieSATCollateFn(
mlm_prob=mlm_prob * mlm_prob_factor, mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span, mean_phn_span=mean_phn_span,
seg_emb=seg_emb, seg_emb=seg_emb,
text_masking=text_masking) text_masking=text_masking)
......
...@@ -73,8 +73,7 @@ def evaluate(args): ...@@ -73,8 +73,7 @@ def evaluate(args):
mlm_prob=erniesat_config.mlm_prob, mlm_prob=erniesat_config.mlm_prob,
mean_phn_span=erniesat_config.mean_phn_span, mean_phn_span=erniesat_config.mean_phn_span,
seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm', seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
text_masking=False, text_masking=False)
epoch=-1)
gen_raw = True gen_raw = True
erniesat_mu, erniesat_std = np.load(args.erniesat_stat) erniesat_mu, erniesat_std = np.load(args.erniesat_stat)
......
...@@ -84,8 +84,7 @@ def train_sp(args, config): ...@@ -84,8 +84,7 @@ def train_sp(args, config):
mlm_prob=config.mlm_prob, mlm_prob=config.mlm_prob,
mean_phn_span=config.mean_phn_span, mean_phn_span=config.mean_phn_span,
seg_emb=config.model['enc_input_layer'] == 'sega_mlm', seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
text_masking=config["model"]["text_masking"], text_masking=config["model"]["text_masking"])
epoch=config["max_epoch"])
train_sampler = DistributedBatchSampler( train_sampler = DistributedBatchSampler(
train_dataset, train_dataset,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册