Commit 133ee7db authored by 小湉湉

rename num_speakers to spk_num

Parent a97c7b52
@@ -3,7 +3,7 @@
 set -e
 source path.sh
-gpus=4,5
+gpus=0,1
 stage=0
 stop_stage=100
...
@@ -46,14 +46,14 @@ def evaluate(args, fastspeech2_config, pwg_config):
     print("vocab_size:", vocab_size)
     with open(args.speaker_dict, 'rt') as f:
         spk_id = [line.strip().split() for line in f.readlines()]
-    num_speakers = len(spk_id)
-    print("num_speakers:", num_speakers)
+    spk_num = len(spk_id)
+    print("spk_num:", spk_num)
     odim = fastspeech2_config.n_mels
     model = FastSpeech2(
         idim=vocab_size,
         odim=odim,
-        num_speakers=num_speakers,
+        spk_num=spk_num,
         **fastspeech2_config["model"])
     model.set_state_dict(
...
@@ -51,14 +51,14 @@ def evaluate(args, fastspeech2_config, pwg_config):
     print("vocab_size:", vocab_size)
     with open(args.speaker_dict, 'rt') as f:
         spk_id = [line.strip().split() for line in f.readlines()]
-    num_speakers = len(spk_id)
-    print("num_speakers:", num_speakers)
+    spk_num = len(spk_id)
+    print("spk_num:", spk_num)
     odim = fastspeech2_config.n_mels
     model = FastSpeech2(
         idim=vocab_size,
         odim=odim,
-        num_speakers=num_speakers,
+        spk_num=spk_num,
         **fastspeech2_config["model"])
     model.set_state_dict(
...
@@ -40,19 +40,19 @@ def evaluate(args, fastspeech2_config, pwg_config):
     fields = ["utt_id", "text"]
-    num_speakers = None
+    spk_num = None
     if args.speaker_dict is not None:
         print("multiple speaker fastspeech2!")
         with open(args.speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
-        num_speakers = len(spk_id)
+        spk_num = len(spk_id)
         fields += ["spk_id"]
     elif args.voice_cloning:
         print("voice cloning!")
         fields += ["spk_emb"]
     else:
         print("single speaker fastspeech2!")
-    print("num_speakers:", num_speakers)
+    print("spk_num:", spk_num)
     test_dataset = DataTable(data=test_metadata, fields=fields)
@@ -65,7 +65,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
     model = FastSpeech2(
         idim=vocab_size,
         odim=odim,
-        num_speakers=num_speakers,
+        spk_num=spk_num,
         **fastspeech2_config["model"])
     model.set_state_dict(
...
@@ -62,13 +62,13 @@ def train_sp(args, config):
         "pitch", "energy"
     ]
     converters = {"speech": np.load, "pitch": np.load, "energy": np.load}
-    num_speakers = None
+    spk_num = None
     if args.speaker_dict is not None:
         print("multiple speaker fastspeech2!")
         collate_fn = fastspeech2_multi_spk_batch_fn
         with open(args.speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
-        num_speakers = len(spk_id)
+        spk_num = len(spk_id)
         fields += ["spk_id"]
     elif args.voice_cloning:
         print("Training voice cloning!")
@@ -78,7 +78,7 @@ def train_sp(args, config):
     else:
         print("single speaker fastspeech2!")
         collate_fn = fastspeech2_single_spk_batch_fn
-    print("num_speakers:", num_speakers)
+    print("spk_num:", spk_num)
     # dataloader has been too verbose
     logging.getLogger("DataLoader").disabled = True
@@ -129,10 +129,7 @@ def train_sp(args, config):
     odim = config.n_mels
     model = FastSpeech2(
-        idim=vocab_size,
-        odim=odim,
-        num_speakers=num_speakers,
-        **config["model"])
+        idim=vocab_size, odim=odim, spk_num=spk_num, **config["model"])
     if world_size > 1:
         model = DataParallel(model)
     print("model done!")
...
@@ -96,7 +96,7 @@ class FastSpeech2(nn.Layer):
             pitch_embed_dropout: float=0.5,
             stop_gradient_from_pitch_predictor: bool=False,
             # spk emb
-            num_speakers: int=None,
+            spk_num: int=None,
             spk_embed_dim: int=None,
             spk_embed_integration_type: str="add",
             # tone emb
@@ -146,9 +146,9 @@
         # initialize parameters
         initialize(self, init_type)
-        if self.spk_embed_dim and num_speakers:
+        if spk_num and self.spk_embed_dim:
             self.spk_embedding_table = nn.Embedding(
-                num_embeddings=num_speakers,
+                num_embeddings=spk_num,
                 embedding_dim=self.spk_embed_dim,
                 padding_idx=self.padding_idx)
...
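
For readers following the API change, here is a minimal, self-contained sketch (not part of this commit) of the convention it standardizes on: spk_num is simply the number of entries in the speaker-id map, and in the multi-speaker case it sizes the model's speaker embedding table. The speaker list and embedding dimension below are made-up placeholders; nn refers to paddle.nn, as in the diff above.

    import paddle.nn as nn

    # Hypothetical contents of a speaker-id map: one "speaker_name speaker_id" pair per line.
    speaker_dict_lines = ["speaker_0 0", "speaker_1 1", "speaker_2 2"]
    spk_id = [line.strip().split() for line in speaker_dict_lines]
    spk_num = len(spk_id)
    print("spk_num:", spk_num)  # -> spk_num: 3

    # With spk_num known, a speaker embedding table can be sized accordingly,
    # mirroring what FastSpeech2 does when both spk_num and spk_embed_dim are set.
    spk_embed_dim = 256  # placeholder; the real value comes from the model config
    spk_embedding_table = nn.Embedding(
        num_embeddings=spk_num, embedding_dim=spk_embed_dim)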