未验证 提交 2e51e0da 编写于 作者: H HuangLiangJie 提交者: GitHub

[TTS]Fix attention bugs and sort VITS data with feats_lengths (#2770)

上级 6725bcd8
...@@ -187,7 +187,7 @@ def main(): ...@@ -187,7 +187,7 @@ def main():
record["spk_emb"] = str(item["spk_emb"]) record["spk_emb"] = str(item["spk_emb"])
output_metadata.append(record) output_metadata.append(record)
output_metadata.sort(key=itemgetter('utt_id')) output_metadata.sort(key=itemgetter('feats_lengths'))
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl" output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
with jsonlines.open(output_metadata_path, 'w') as writer: with jsonlines.open(output_metadata_path, 'w') as writer:
for item in output_metadata: for item in output_metadata:
......
...@@ -166,7 +166,7 @@ def process_sentences(config, ...@@ -166,7 +166,7 @@ def process_sentences(config,
if record: if record:
results.append(record) results.append(record)
results.sort(key=itemgetter("utt_id")) results.sort(key=itemgetter("feats_lengths"))
with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer: with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
for item in results: for item in results:
writer.write(item) writer.write(item)
......
...@@ -24,13 +24,13 @@ import yaml ...@@ -24,13 +24,13 @@ import yaml
from paddle import DataParallel from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam from paddle.optimizer import Adam
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
from paddlespeech.t2s.models.vits import VITS from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.models.vits import VITSEvaluator from paddlespeech.t2s.models.vits import VITSEvaluator
from paddlespeech.t2s.models.vits import VITSUpdater from paddlespeech.t2s.models.vits import VITSUpdater
...@@ -107,12 +107,12 @@ def train_sp(args, config): ...@@ -107,12 +107,12 @@ def train_sp(args, config):
converters=converters, ) converters=converters, )
# collate function and dataloader # collate function and dataloader
train_sampler = DistributedBatchSampler( train_sampler = ErnieSATSampler(
train_dataset, train_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=True, shuffle=True,
drop_last=True) drop_last=True)
dev_sampler = DistributedBatchSampler( dev_sampler = ErnieSATSampler(
dev_dataset, dev_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=False, shuffle=False,
......
...@@ -196,7 +196,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -196,7 +196,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
if self.zero_triu: if self.zero_triu:
ones = paddle.ones((t1, t2)) ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :] x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
return x return x
...@@ -299,7 +299,7 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -299,7 +299,7 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
if self.zero_triu: if self.zero_triu:
ones = paddle.ones((t1, t2)) ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :] x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
return x return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册