未验证 提交 2e51e0da 编写于 作者: H HuangLiangJie 提交者: GitHub

[TTS]Fix attention bugs and sort VITS data with feats_lengths (#2770)

上级 6725bcd8
......@@ -187,7 +187,7 @@ def main():
record["spk_emb"] = str(item["spk_emb"])
output_metadata.append(record)
-output_metadata.sort(key=itemgetter('utt_id'))
+output_metadata.sort(key=itemgetter('feats_lengths'))
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
with jsonlines.open(output_metadata_path, 'w') as writer:
for item in output_metadata:
......
......@@ -166,7 +166,7 @@ def process_sentences(config,
if record:
results.append(record)
-results.sort(key=itemgetter("utt_id"))
+results.sort(key=itemgetter("feats_lengths"))
with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
for item in results:
writer.write(item)
......
......@@ -24,13 +24,13 @@ import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
-from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.models.vits import VITSEvaluator
from paddlespeech.t2s.models.vits import VITSUpdater
......@@ -107,12 +107,12 @@ def train_sp(args, config):
converters=converters, )
# collate function and dataloader
-train_sampler = DistributedBatchSampler(
+train_sampler = ErnieSATSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)
-dev_sampler = DistributedBatchSampler(
+dev_sampler = ErnieSATSampler(
dev_dataset,
batch_size=config.batch_size,
shuffle=False,
......
......@@ -196,7 +196,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
if self.zero_triu:
ones = paddle.ones((t1, t2))
-x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
+x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
return x
......@@ -299,7 +299,7 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
if self.zero_triu:
ones = paddle.ones((t1, t2))
-x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
+x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册