Commit 133ee7db authored by 小湉湉

rename num_speakers to spk_num

Parent a97c7b52
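This commit renames the multi-speaker argument from num_speakers to spk_num in the synthesis and training scripts and in the FastSpeech2 model itself. A minimal sketch of the new convention follows; the helper name load_spk_num and the file name speaker_id_map.txt are illustrative assumptions, not code from this repository.

# Minimal sketch (not part of the commit): derive spk_num from a speaker-dict
# file with one "<speaker> <id>" entry per line, as the scripts below do.
def load_spk_num(speaker_dict_path: str) -> int:
    with open(speaker_dict_path, 'rt') as f:
        spk_id = [line.strip().split() for line in f.readlines()]
    return len(spk_id)


spk_num = load_spk_num("speaker_id_map.txt")  # hypothetical path

# Keyword change introduced by this commit:
#   before: FastSpeech2(idim=vocab_size, odim=odim, num_speakers=spk_num, ...)
#   after:  FastSpeech2(idim=vocab_size, odim=odim, spk_num=spk_num, ...)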
@@ -3,7 +3,7 @@
 set -e
 source path.sh
-gpus=4,5
+gpus=0,1
 stage=0
 stop_stage=100
...
@@ -46,14 +46,14 @@ def evaluate(args, fastspeech2_config, pwg_config):
     print("vocab_size:", vocab_size)
     with open(args.speaker_dict, 'rt') as f:
         spk_id = [line.strip().split() for line in f.readlines()]
-    num_speakers = len(spk_id)
-    print("num_speakers:", num_speakers)
+    spk_num = len(spk_id)
+    print("spk_num:", spk_num)
     odim = fastspeech2_config.n_mels
     model = FastSpeech2(
         idim=vocab_size,
         odim=odim,
-        num_speakers=num_speakers,
+        spk_num=spk_num,
         **fastspeech2_config["model"])
     model.set_state_dict(
...
@@ -51,14 +51,14 @@ def evaluate(args, fastspeech2_config, pwg_config):
     print("vocab_size:", vocab_size)
     with open(args.speaker_dict, 'rt') as f:
         spk_id = [line.strip().split() for line in f.readlines()]
-    num_speakers = len(spk_id)
-    print("num_speakers:", num_speakers)
+    spk_num = len(spk_id)
+    print("spk_num:", spk_num)
     odim = fastspeech2_config.n_mels
     model = FastSpeech2(
         idim=vocab_size,
         odim=odim,
-        num_speakers=num_speakers,
+        spk_num=spk_num,
         **fastspeech2_config["model"])
     model.set_state_dict(
...
@@ -40,19 +40,19 @@ def evaluate(args, fastspeech2_config, pwg_config):
     fields = ["utt_id", "text"]
-    num_speakers = None
+    spk_num = None
     if args.speaker_dict is not None:
         print("multiple speaker fastspeech2!")
         with open(args.speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
-        num_speakers = len(spk_id)
+        spk_num = len(spk_id)
         fields += ["spk_id"]
     elif args.voice_cloning:
         print("voice cloning!")
         fields += ["spk_emb"]
     else:
         print("single speaker fastspeech2!")
-    print("num_speakers:", num_speakers)
+    print("spk_num:", spk_num)
     test_dataset = DataTable(data=test_metadata, fields=fields)
@@ -65,7 +65,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
     model = FastSpeech2(
         idim=vocab_size,
         odim=odim,
-        num_speakers=num_speakers,
+        spk_num=spk_num,
         **fastspeech2_config["model"])
     model.set_state_dict(
...
@@ -62,13 +62,13 @@ def train_sp(args, config):
         "pitch", "energy"
     ]
     converters = {"speech": np.load, "pitch": np.load, "energy": np.load}
-    num_speakers = None
+    spk_num = None
     if args.speaker_dict is not None:
         print("multiple speaker fastspeech2!")
         collate_fn = fastspeech2_multi_spk_batch_fn
         with open(args.speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
-        num_speakers = len(spk_id)
+        spk_num = len(spk_id)
         fields += ["spk_id"]
     elif args.voice_cloning:
         print("Training voice cloning!")
@@ -78,7 +78,7 @@ def train_sp(args, config):
     else:
         print("single speaker fastspeech2!")
         collate_fn = fastspeech2_single_spk_batch_fn
-    print("num_speakers:", num_speakers)
+    print("spk_num:", spk_num)
     # dataloader has been too verbose
     logging.getLogger("DataLoader").disabled = True
@@ -129,10 +129,7 @@ def train_sp(args, config):
     odim = config.n_mels
     model = FastSpeech2(
-        idim=vocab_size,
-        odim=odim,
-        num_speakers=num_speakers,
-        **config["model"])
+        idim=vocab_size, odim=odim, spk_num=spk_num, **config["model"])
     if world_size > 1:
         model = DataParallel(model)
     print("model done!")
...
@@ -96,7 +96,7 @@ class FastSpeech2(nn.Layer):
             pitch_embed_dropout: float=0.5,
             stop_gradient_from_pitch_predictor: bool=False,
             # spk emb
-            num_speakers: int=None,
+            spk_num: int=None,
             spk_embed_dim: int=None,
             spk_embed_integration_type: str="add",
             # tone emb
@@ -146,9 +146,9 @@ class FastSpeech2(nn.Layer):
         # initialize parameters
         initialize(self, init_type)
-        if self.spk_embed_dim and num_speakers:
+        if spk_num and self.spk_embed_dim:
             self.spk_embedding_table = nn.Embedding(
-                num_embeddings=num_speakers,
+                num_embeddings=spk_num,
                 embedding_dim=self.spk_embed_dim,
                 padding_idx=self.padding_idx)
...
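For context, the renamed argument is what gates creation of the speaker-embedding table inside FastSpeech2. A minimal Paddle sketch follows, with illustrative values (spk_num=10, spk_embed_dim=256, padding_idx=0) rather than values from this repository's configs.

import paddle
import paddle.nn as nn

spk_num, spk_embed_dim, padding_idx = 10, 256, 0  # illustrative values

spk_embedding_table = None
if spk_num and spk_embed_dim:
    # The table is only built when both a speaker count and an embedding
    # dimension are configured, mirroring the check in FastSpeech2.__init__.
    spk_embedding_table = nn.Embedding(
        num_embeddings=spk_num,
        embedding_dim=spk_embed_dim,
        padding_idx=padding_idx)

# A speaker id indexes the table; the resulting embedding is integrated with
# the encoder output ("add" or "concat", per spk_embed_integration_type).
spk_id = paddle.to_tensor([3], dtype='int64')
spk_emb = spk_embedding_table(spk_id)  # shape: [1, spk_embed_dim]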