"""Wrapper class for the vocoder model trained with parallel_wavegan repo."""
import logging
import os
from pathlib import Path
from typing import Optional
from typing import Union

import torch
import yaml


class TorchPWGAN(torch.nn.Module):
    """Wrapper class to load the vocoder trained with parallel_wavegan repo."""

    def __init__(
            self,
            model_file: Union[Path, str],
            config_file: Optional[Union[Path, str]]=None, ):
        """Initialize the TorchPWGAN vocoder wrapper.

        Args:
            model_file (Union[Path, str]): Checkpoint trained with the
                parallel_wavegan repo; passed to
                ``parallel_wavegan.utils.load_model``.
            config_file (Optional[Union[Path, str]]): YAML training config.
                If None, ``config.yml`` in the same directory as
                ``model_file`` is used.

        Raises:
            ImportError: If ``parallel_wavegan`` is not installed.

        """
        super().__init__()
        try:
            from parallel_wavegan.utils import load_model
        except ImportError:
            logging.error(
                "`parallel_wavegan` is not installed. "
                "Please install via `pip install -U parallel_wavegan`.")
            raise
        if config_file is None:
            # Fall back to the config stored next to the checkpoint.
            config_file = Path(model_file).parent / "config.yml"
        # NOTE(security): yaml.Loader can construct arbitrary Python objects,
        # so only load configs from trusted sources. It is kept instead of
        # yaml.safe_load because saved configs may contain Python-specific
        # tags -- confirm before switching.
        with open(config_file, encoding="utf-8") as f:
            config = yaml.load(f, Loader=yaml.Loader)
        self.fs = config["sampling_rate"]
        self.vocoder = load_model(model_file, config)
        # Not every generator implements remove_weight_norm, so probe first.
        if hasattr(self.vocoder, "remove_weight_norm"):
            self.vocoder.remove_weight_norm()
        # A `mean` attribute presumably means feature statistics were loaded
        # into the model, so inference should normalize the input itself.
        self.normalize_before = hasattr(self.vocoder, "mean")

    @torch.no_grad()
    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Generate waveform with pretrained vocoder.

        Args:
            feats (Tensor): Feature tensor (T_feats, #mels).

        Returns:
            Tensor: Generated waveform tensor (T_wav).

        """
        wav = self.vocoder.inference(
            feats, normalize_before=self.normalize_before, )
        # reshape (not view): flattening must not fail if the vocoder
        # returns a non-contiguous tensor.
        return wav.reshape(-1)