提交 19180d35 编写于 作者: T tianhao zhang

format wav2vec2 demo

上级 6e429f05
...@@ -33,7 +33,7 @@ filename = ...@@ -33,7 +33,7 @@ filename =
# Specify a list of codes to ignore. # Specify a list of codes to ignore.
ignore = ignore =
W503 W503
E252,E262,E127,E265,E126,E266,E241,E261,E128,E125 E252,E262,E127,E265,E126,E266,E241,E261,E128,E125,E129
W291,W293,W605 W291,W293,W605
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying # shebang has extra meaning in fbcode lints, so I think it's not worth trying
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
* asr0 - deepspeech2 Streaming/Non-Streaming * asr0 - deepspeech2 Streaming/Non-Streaming
* asr1 - transformer/conformer Streaming/Non-Streaming * asr1 - transformer/conformer Streaming/Non-Streaming
* asr2 - transformer/conformer Streaming/Non-Streaming with Kaldi feature * asr2 - transformer/conformer Streaming/Non-Streaming with Kaldi feature
* asr3 - wav2vecASR, ASR model with pre-trained wav2vec2 and CTC
## Data ## Data
| Data Subset | Duration in Seconds | | Data Subset | Duration in Seconds |
......
...@@ -382,6 +382,36 @@ class LogMelSpectrogramKaldi(): ...@@ -382,6 +382,36 @@ class LogMelSpectrogramKaldi():
return mat return mat
class WavProcess():
    """Turn a 1-D waveform into a 2-D array by appending a trailing axis.

    The waveform samples themselves are passed through unchanged; only a
    length-1 last axis is added.
    """

    def __init__(self, dither=0.1):
        """
        Args:
            dither (float): Dithering constant. It is stored on the
                instance, but this class does not apply it to the waveform.
        """
        self.dither = dither

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode. Accepted as part of the call
                signature; the returned array does not depend on it.

        Raises:
            ValueError: if ``x`` is not 1-D, e.g. (Ti, C) is not supported.

        Returns:
            np.ndarray: shape (Ti, 1)
        """
        # Only single-channel, 1-D input is supported.
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")
        # Append a trailing axis: (Ti,) -> (Ti, 1).
        return np.expand_dims(x, -1)
class LogMelSpectrogramKaldi_decay(): class LogMelSpectrogramKaldi_decay():
def __init__( def __init__(
self, self,
......
...@@ -41,6 +41,7 @@ import_alias = dict( ...@@ -41,6 +41,7 @@ import_alias = dict(
utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN", utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram", fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram", spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
wav_process="paddlespeech.audio.transform.spectrogram:WavProcess",
stft="paddlespeech.audio.transform.spectrogram:Stft", stft="paddlespeech.audio.transform.spectrogram:Stft",
istft="paddlespeech.audio.transform.spectrogram:IStft", istft="paddlespeech.audio.transform.spectrogram:IStft",
stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram", stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",
......
...@@ -27,6 +27,7 @@ from paddlespeech.s2t.utils.log import Log ...@@ -27,6 +27,7 @@ from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
class Wav2vec2Infer(): class Wav2vec2Infer():
def __init__(self, config, args): def __init__(self, config, args):
self.args = args self.args = args
...@@ -34,8 +35,7 @@ class Wav2vec2Infer(): ...@@ -34,8 +35,7 @@ class Wav2vec2Infer():
self.audio_file = args.audio_file self.audio_file = args.audio_file
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=config.unit_type, unit_type=config.unit_type, vocab=config.vocab_filepath)
vocab=config.vocab_filepath)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
# model # model
......
...@@ -18,38 +18,38 @@ import time ...@@ -18,38 +18,38 @@ import time
from collections import defaultdict from collections import defaultdict
from collections import OrderedDict from collections import OrderedDict
from contextlib import nullcontext from contextlib import nullcontext
from paddlespeech.s2t.utils import mp_tools
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
from paddlespeech.s2t.utils import error_rate from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope from paddlespeech.s2t.training.reporter import ObsScope
from paddlespeech.s2t.training.reporter import report from paddlespeech.s2t.training.reporter import report
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
from paddlespeech.s2t.training.timer import Timer from paddlespeech.s2t.training.timer import Timer
from paddlespeech.s2t.training.trainer import Trainer from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils import error_rate
from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils import mp_tools
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
class Wav2Vec2ASRTrainer(Trainer): class Wav2Vec2ASRTrainer(Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
self.avg_train_loss = 0 self.avg_train_loss = 0
def train_batch(self, batch_index, batch, msg): def train_batch(self, batch_index, batch, msg):
train_conf = self.config train_conf = self.config
start = time.time() start = time.time()
...@@ -58,7 +58,7 @@ class Wav2Vec2ASRTrainer(Trainer): ...@@ -58,7 +58,7 @@ class Wav2Vec2ASRTrainer(Trainer):
utt, wav, wavs_lens, target, target_lens = batch utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1] wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1] target_lens_rate = target_lens / target.shape[1]
wav = wav[:,:,0] wav = wav[:, :, 0]
wav = self.speech_augmentation(wav, wavs_lens_rate) wav = self.speech_augmentation(wav, wavs_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate) loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
# pring(wav, wavs_lens_rate, target, target_lens_rate) # pring(wav, wavs_lens_rate, target, target_lens_rate)
...@@ -108,7 +108,8 @@ class Wav2Vec2ASRTrainer(Trainer): ...@@ -108,7 +108,8 @@ class Wav2Vec2ASRTrainer(Trainer):
def valid(self): def valid(self):
self.model.eval() self.model.eval()
if not self.use_streamdata: if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") logger.info(
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list) valid_losses = defaultdict(list)
num_seen_utts = 1 num_seen_utts = 1
total_loss = 0.0 total_loss = 0.0
...@@ -116,7 +117,7 @@ class Wav2Vec2ASRTrainer(Trainer): ...@@ -116,7 +117,7 @@ class Wav2Vec2ASRTrainer(Trainer):
utt, wav, wavs_lens, target, target_lens = batch utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1] wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1] target_lens_rate = target_lens / target.shape[1]
wav = wav[:,:,0] wav = wav[:, :, 0]
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate) loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
if paddle.isfinite(loss): if paddle.isfinite(loss):
...@@ -134,7 +135,8 @@ class Wav2Vec2ASRTrainer(Trainer): ...@@ -134,7 +135,8 @@ class Wav2Vec2ASRTrainer(Trainer):
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata: if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += "batch: {}/{}, ".format(i + 1,
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items()) for k, v in valid_dump.items())
logger.info(msg) logger.info(msg)
...@@ -155,7 +157,8 @@ class Wav2Vec2ASRTrainer(Trainer): ...@@ -155,7 +157,8 @@ class Wav2Vec2ASRTrainer(Trainer):
self.before_train() self.before_train()
if not self.use_streamdata: if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch: while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"): with Timer("Epoch-Train Time Cost: {}"):
self.model.train() self.model.train()
...@@ -223,14 +226,18 @@ class Wav2Vec2ASRTrainer(Trainer): ...@@ -223,14 +226,18 @@ class Wav2Vec2ASRTrainer(Trainer):
config = self.config.clone() config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False) self.use_streamdata = config.get("use_stream_data", False)
if self.train: if self.train:
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) self.train_loader = DataLoaderFactory.get_dataloader(
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) 'train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader(
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
decode_batch_size = config.get('decode', dict()).get( decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1) 'decode_batch_size', 1)
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) self.test_loader = DataLoaderFactory.get_dataloader('test', config,
self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) self.args)
self.align_loader = DataLoaderFactory.get_dataloader(
'align', config, self.args)
logger.info("Setup test/align Dataloader!") logger.info("Setup test/align Dataloader!")
def setup_model(self): def setup_model(self):
...@@ -312,14 +319,14 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer): ...@@ -312,14 +319,14 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
self.text_featurizer = TextFeaturizer( self.text_featurizer = TextFeaturizer(
unit_type=config.unit_type, vocab=config.vocab_filepath) unit_type=config.unit_type, vocab=config.vocab_filepath)
self.vocab_list = self.text_featurizer.vocab_list self.vocab_list = self.text_featurizer.vocab_list
def id2token(self, texts, texts_len): def id2token(self, texts, texts_len):
""" ord() id to chr() chr """ """ ord() id to chr() chr """
trans = [] trans = []
for text, n in zip(texts, texts_len): for text, n in zip(texts, texts_len):
n = n.numpy().item() n = n.numpy().item()
ids = text[:n] ids = text[:n]
trans.append( trans.append(self.text_featurizer.defeaturize(ids.numpy().tolist()))
self.text_featurizer.defeaturize(ids.numpy().tolist()))
return trans return trans
def compute_metrics(self, def compute_metrics(self,
......
...@@ -3,6 +3,7 @@ Authors ...@@ -3,6 +3,7 @@ Authors
* Elena Rastorgueva 2020 * Elena Rastorgueva 2020
""" """
import paddle import paddle
from paddlespeech.s2t.models.wav2vec2.modules import containers from paddlespeech.s2t.models.wav2vec2.modules import containers
from paddlespeech.s2t.models.wav2vec2.modules import linear from paddlespeech.s2t.models.wav2vec2.modules import linear
...@@ -31,8 +32,7 @@ class VanillaNN(containers.Sequential): ...@@ -31,8 +32,7 @@ class VanillaNN(containers.Sequential):
input_shape, input_shape,
activation=paddle.nn.LeakyReLU, activation=paddle.nn.LeakyReLU,
dnn_blocks=2, dnn_blocks=2,
dnn_neurons=512, dnn_neurons=512, ):
):
super().__init__(input_shape=input_shape) super().__init__(input_shape=input_shape)
for block_index in range(dnn_blocks): for block_index in range(dnn_blocks):
...@@ -40,6 +40,5 @@ class VanillaNN(containers.Sequential): ...@@ -40,6 +40,5 @@ class VanillaNN(containers.Sequential):
linear.Linear, linear.Linear,
n_neurons=dnn_neurons, n_neurons=dnn_neurons,
bias=True, bias=True,
layer_name="linear", layer_name="linear", )
)
self.append(activation(), layer_name="act") self.append(activation(), layer_name="act")
...@@ -11,12 +11,10 @@ ...@@ -11,12 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
from packaging import version from paddle import nn
from paddle import Tensor, nn from paddle import Tensor
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
...@@ -29,7 +27,9 @@ class NewGELUActivation(nn.Layer): ...@@ -29,7 +27,9 @@ class NewGELUActivation(nn.Layer):
""" """
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0)))) return 0.5 * input * (1.0 + paddle.tanh(
math.sqrt(2.0 / math.pi) *
(input + 0.044715 * paddle.pow(input, 3.0))))
class GELUActivation(nn.Layer): class GELUActivation(nn.Layer):
...@@ -40,7 +40,7 @@ class GELUActivation(nn.Layer): ...@@ -40,7 +40,7 @@ class GELUActivation(nn.Layer):
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
""" """
def __init__(self, use_gelu_python: bool = False): def __init__(self, use_gelu_python: bool=False):
super().__init__() super().__init__()
self.act = nn.functional.gelu self.act = nn.functional.gelu
...@@ -57,7 +57,9 @@ class FastGELUActivation(nn.Layer): ...@@ -57,7 +57,9 @@ class FastGELUActivation(nn.Layer):
""" """
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) return 0.5 * input * (
1.0 + paddle.tanh(input * 0.7978845608 *
(1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Layer): class QuickGELUActivation(nn.Layer):
...@@ -84,7 +86,8 @@ class ClippedGELUActivation(nn.Layer): ...@@ -84,7 +86,8 @@ class ClippedGELUActivation(nn.Layer):
def __init__(self, min: float, max: float): def __init__(self, min: float, max: float):
if min > max: if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})") raise ValueError(
f"min should be < max (got min: {min}, max: {max})")
super().__init__() super().__init__()
self.min = min self.min = min
...@@ -161,7 +164,9 @@ def get_activation(activation_string): ...@@ -161,7 +164,9 @@ def get_activation(activation_string):
if activation_string in ACT2FN: if activation_string in ACT2FN:
return ACT2FN[activation_string] return ACT2FN[activation_string]
else: else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") raise KeyError(
f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}"
)
# For backwards compatibility with: from activations import gelu_python # For backwards compatibility with: from activations import gelu_python
......
import paddle
import inspect import inspect
import logging
import operator import paddle
import functools
class Sequential(paddle.nn.LayerDict): class Sequential(paddle.nn.LayerDict):
"""A sequence of modules with potentially inferring shape on construction. """A sequence of modules with potentially inferring shape on construction.
...@@ -103,8 +102,7 @@ class Sequential(paddle.nn.LayerDict): ...@@ -103,8 +102,7 @@ class Sequential(paddle.nn.LayerDict):
raise ValueError( raise ValueError(
"Must pass `input_shape` at initialization and use " "Must pass `input_shape` at initialization and use "
"modules that take `input_shape` to infer shape when " "modules that take `input_shape` to infer shape when "
"using `append()`." "using `append()`.")
)
def get_output_shape(self): def get_output_shape(self):
"""Returns expected shape of the output. """Returns expected shape of the output.
......
...@@ -3,10 +3,10 @@ Authors ...@@ -3,10 +3,10 @@ Authors
* Mirco Ravanelli 2020 * Mirco Ravanelli 2020
* Davide Borra 2021 * Davide Borra 2021
""" """
import logging import logging
import paddle import paddle
import paddle.nn as nn
from paddlespeech.s2t.modules import align from paddlespeech.s2t.modules import align
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -42,8 +42,7 @@ class Linear(paddle.nn.Layer): ...@@ -42,8 +42,7 @@ class Linear(paddle.nn.Layer):
input_shape=None, input_shape=None,
input_size=None, input_size=None,
bias=True, bias=True,
combine_dims=False, combine_dims=False, ):
):
super().__init__() super().__init__()
self.combine_dims = combine_dims self.combine_dims = combine_dims
......
...@@ -11,12 +11,12 @@ ...@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass
from dataclasses import fields from dataclasses import fields
from typing import Optional
from typing import Tuple
import paddle import paddle
...@@ -41,10 +41,13 @@ class ModelOutput(OrderedDict): ...@@ -41,10 +41,13 @@ class ModelOutput(OrderedDict):
if not len(class_fields): if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.") raise ValueError(f"{self.__class__.__name__} has no fields.")
if not all(field.default is None for field in class_fields[1:]): if not all(field.default is None for field in class_fields[1:]):
raise ValueError(f"{self.__class__.__name__} should not have more than one required field.") raise ValueError(
f"{self.__class__.__name__} should not have more than one required field."
)
first_field = getattr(self, class_fields[0].name) first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) other_fields_are_none = all(
getattr(self, field.name) is None for field in class_fields[1:])
if other_fields_are_none and not paddle.is_tensor(first_field): if other_fields_are_none and not paddle.is_tensor(first_field):
if isinstance(first_field, dict): if isinstance(first_field, dict):
...@@ -61,11 +64,9 @@ class ModelOutput(OrderedDict): ...@@ -61,11 +64,9 @@ class ModelOutput(OrderedDict):
# set the associated fields # set the associated fields
if first_field_iterator: if first_field_iterator:
for element in iterator: for element in iterator:
if ( if (not isinstance(element, (list, tuple)) or
not isinstance(element, (list, tuple)) not len(element) == 2 or
or not len(element) == 2 not isinstance(element[0], str)):
or not isinstance(element[0], str)
):
break break
setattr(self, element[0], element[1]) setattr(self, element[0], element[1])
if element[1] is not None: if element[1] is not None:
...@@ -79,16 +80,23 @@ class ModelOutput(OrderedDict): ...@@ -79,16 +80,23 @@ class ModelOutput(OrderedDict):
self[field.name] = v self[field.name] = v
def __delitem__(self, *args, **kwargs): def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") raise Exception(
f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
)
def setdefault(self, *args, **kwargs): def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") raise Exception(
f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
)
def pop(self, *args, **kwargs): def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") raise Exception(
f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs): def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") raise Exception(
f"You cannot use ``update`` on a {self.__class__.__name__} instance."
)
def __getitem__(self, k): def __getitem__(self, k):
if isinstance(k, str): if isinstance(k, str):
......
...@@ -7,10 +7,8 @@ Authors ...@@ -7,10 +7,8 @@ Authors
* Samuele Cornell 2020 * Samuele Cornell 2020
* Sarthak Yadav 2022 * Sarthak Yadav 2022
""" """
import paddle
import math
from packaging import version
import numpy as np import numpy as np
import paddle
def blackman_window(window_length, periodic=True): def blackman_window(window_length, periodic=True):
...@@ -97,8 +95,7 @@ def convolve1d( ...@@ -97,8 +95,7 @@ def convolve1d(
stride=1, stride=1,
groups=1, groups=1,
use_fft=False, use_fft=False,
rotation_index=0, rotation_index=0, ):
):
"""Use paddle.nn.functional to perform 1d padding and conv. """Use paddle.nn.functional to perform 1d padding and conv.
Arguments Arguments
--------- ---------
...@@ -150,8 +147,7 @@ def convolve1d( ...@@ -150,8 +147,7 @@ def convolve1d(
# Padding can be a tuple (left_pad, right_pad) or an int # Padding can be a tuple (left_pad, right_pad) or an int
if isinstance(padding, tuple): if isinstance(padding, tuple):
waveform = paddle.nn.functional.pad( waveform = paddle.nn.functional.pad(
x=waveform, pad=padding, mode=pad_type, data_format='NCL' x=waveform, pad=padding, mode=pad_type, data_format='NCL')
)
# This approach uses FFT, which is more efficient if the kernel is large # This approach uses FFT, which is more efficient if the kernel is large
if use_fft: if use_fft:
...@@ -165,9 +161,7 @@ def convolve1d( ...@@ -165,9 +161,7 @@ def convolve1d(
# Perform rotation to ensure alignment # Perform rotation to ensure alignment
zeros = paddle.zeros( zeros = paddle.zeros(
[kernel.shape[0], kernel.shape[1], zero_length], [kernel.shape[0], kernel.shape[1], zero_length], dtype=kernel.dtype)
dtype=kernel.dtype
)
after_index = kernel[..., rotation_index:] after_index = kernel[..., rotation_index:]
before_index = kernel[..., :rotation_index] before_index = kernel[..., :rotation_index]
kernel = paddle.concat((after_index, zeros, before_index), axis=-1) kernel = paddle.concat((after_index, zeros, before_index), axis=-1)
...@@ -185,12 +179,12 @@ def convolve1d( ...@@ -185,12 +179,12 @@ def convolve1d(
weight=kernel, weight=kernel,
stride=stride, stride=stride,
groups=groups, groups=groups,
padding=padding if not isinstance(padding, tuple) else 0, padding=padding if not isinstance(padding, tuple) else 0, )
)
# Return time dimension to the second dimension. # Return time dimension to the second dimension.
return convolved.transpose([0, 2, 1]) return convolved.transpose([0, 2, 1])
def notch_filter(notch_freq, filter_width=101, notch_width=0.05): def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
"""Returns a notch filter constructed from a high-pass and low-pass filter. """Returns a notch filter constructed from a high-pass and low-pass filter.
(from https://tomroelandts.com/articles/ (from https://tomroelandts.com/articles/
...@@ -224,7 +218,8 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05): ...@@ -224,7 +218,8 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
return paddle.sin(x) / x return paddle.sin(x) / x
# The zero is at the middle index # The zero is at the middle index
return paddle.concat([_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1 :])]) return paddle.concat(
[_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])
# Compute a low-pass filter with cutoff frequency notch_freq. # Compute a low-pass filter with cutoff frequency notch_freq.
hlpf = sinc(3 * (notch_freq - notch_width) * inputs) hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
...@@ -239,4 +234,3 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05): ...@@ -239,4 +234,3 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
# Adding filters creates notch filter # Adding filters creates notch filter
return (hlpf + hhpf).view(1, -1, 1) return (hlpf + hhpf).view(1, -1, 1)
import math import math
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import ( from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import compute_amplitude
compute_amplitude, from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import convolve1d
convolve1d, from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import notch_filter
notch_filter)
class SpeedPerturb(nn.Layer): class SpeedPerturb(nn.Layer):
"""Slightly speed up or slow down an audio signal. """Slightly speed up or slow down an audio signal.
...@@ -36,8 +37,10 @@ class SpeedPerturb(nn.Layer): ...@@ -36,8 +37,10 @@ class SpeedPerturb(nn.Layer):
""" """
def __init__( def __init__(
self, orig_freq, speeds=[90, 100, 110], perturb_prob=1.0, self,
): orig_freq,
speeds=[90, 100, 110],
perturb_prob=1.0, ):
super().__init__() super().__init__()
self.orig_freq = orig_freq self.orig_freq = orig_freq
self.speeds = speeds self.speeds = speeds
...@@ -73,11 +76,12 @@ class SpeedPerturb(nn.Layer): ...@@ -73,11 +76,12 @@ class SpeedPerturb(nn.Layer):
return waveform.clone() return waveform.clone()
# Perform a random perturbation # Perform a random perturbation
self.samp_index = paddle.randint(len(self.speeds), shape=(1,))[0] self.samp_index = paddle.randint(len(self.speeds), shape=(1, ))[0]
perturbed_waveform = self.resamplers[self.samp_index](waveform) perturbed_waveform = self.resamplers[self.samp_index](waveform)
return perturbed_waveform return perturbed_waveform
class Resample(nn.Layer): class Resample(nn.Layer):
"""This class resamples an audio signal using sinc-based interpolation. """This class resamples an audio signal using sinc-based interpolation.
...@@ -94,9 +98,12 @@ class Resample(nn.Layer): ...@@ -94,9 +98,12 @@ class Resample(nn.Layer):
Controls the sharpness of the filter, larger numbers result in a Controls the sharpness of the filter, larger numbers result in a
sharper filter, but they are less efficient. Values from 4 to 10 are allowed. sharper filter, but they are less efficient. Values from 4 to 10 are allowed.
""" """
def __init__( def __init__(
self, orig_freq=16000, new_freq=16000, lowpass_filter_width=6, self,
): orig_freq=16000,
new_freq=16000,
lowpass_filter_width=6, ):
super().__init__() super().__init__()
self.orig_freq = orig_freq self.orig_freq = orig_freq
self.new_freq = new_freq self.new_freq = new_freq
...@@ -193,8 +200,7 @@ class Resample(nn.Layer): ...@@ -193,8 +200,7 @@ class Resample(nn.Layer):
window_size = self.weights.shape[1] window_size = self.weights.shape[1]
tot_output_samp = self._output_samples(wave_len) tot_output_samp = self._output_samples(wave_len)
resampled_waveform = paddle.zeros( resampled_waveform = paddle.zeros(
(batch_size, num_channels, tot_output_samp) (batch_size, num_channels, tot_output_samp))
)
# self.weights = self.weights.to(waveforms.device) # self.weights = self.weights.to(waveforms.device)
# Check weights are on correct device # Check weights are on correct device
...@@ -222,28 +228,25 @@ class Resample(nn.Layer): ...@@ -222,28 +228,25 @@ class Resample(nn.Layer):
right_padding = max(0, end_index + 1 - current_wave_len) right_padding = max(0, end_index + 1 - current_wave_len)
left_padding = max(0, -first_index) left_padding = max(0, -first_index)
wave_to_conv = paddle.nn.functional.pad( wave_to_conv = paddle.nn.functional.pad(
wave_to_conv, (left_padding, right_padding), data_format='NCL' wave_to_conv, (left_padding, right_padding), data_format='NCL')
)
conv_wave = paddle.nn.functional.conv1d( conv_wave = paddle.nn.functional.conv1d(
x=wave_to_conv, x=wave_to_conv,
weight=self.weights[i].repeat(num_channels, 1, 1), weight=self.weights[i].repeat(num_channels, 1, 1),
stride=self.conv_stride, stride=self.conv_stride,
groups=num_channels, groups=num_channels, )
)
# we want conv_wave[:, i] to be at # we want conv_wave[:, i] to be at
# output[:, i + n*conv_transpose_stride] # output[:, i + n*conv_transpose_stride]
dilated_conv_wave = paddle.nn.functional.conv1d_transpose( dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
conv_wave, eye, stride=self.conv_transpose_stride conv_wave, eye, stride=self.conv_transpose_stride)
)
# pad dilated_conv_wave so it reaches the output length if needed. # pad dilated_conv_wave so it reaches the output length if needed.
left_padding = i left_padding = i
previous_padding = left_padding + dilated_conv_wave.shape[-1] previous_padding = left_padding + dilated_conv_wave.shape[-1]
right_padding = max(0, tot_output_samp - previous_padding) right_padding = max(0, tot_output_samp - previous_padding)
dilated_conv_wave = paddle.nn.functional.pad( dilated_conv_wave = paddle.nn.functional.pad(
dilated_conv_wave, (left_padding, right_padding), data_format='NCL' dilated_conv_wave, (left_padding, right_padding),
) data_format='NCL')
dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp] dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]
resampled_waveform += dilated_conv_wave resampled_waveform += dilated_conv_wave
...@@ -326,9 +329,7 @@ class Resample(nn.Layer): ...@@ -326,9 +329,7 @@ class Resample(nn.Layer):
window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff) window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)
assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2 assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
output_t = paddle.arange( output_t = paddle.arange(start=0.0, end=self.output_samples)
start=0.0, end=self.output_samples
)
output_t /= self.new_freq output_t /= self.new_freq
min_t = output_t - window_width min_t = output_t - window_width
max_t = output_t + window_width max_t = output_t + window_width
...@@ -346,23 +347,16 @@ class Resample(nn.Layer): ...@@ -346,23 +347,16 @@ class Resample(nn.Layer):
inside_window_indices = delta_t.abs() < (window_width) inside_window_indices = delta_t.abs() < (window_width)
# raised-cosine (Hanning) window with width `window_width` # raised-cosine (Hanning) window with width `window_width`
weights[inside_window_indices] = 0.5 * ( weights[inside_window_indices] = 0.5 * (1 + paddle.cos(
1 2 * math.pi * lowpass_cutoff / self.lowpass_filter_width *
+ paddle.cos( delta_t[inside_window_indices]))
2
* math.pi
* lowpass_cutoff
/ self.lowpass_filter_width
* delta_t[inside_window_indices]
)
)
t_eq_zero_indices = delta_t == 0.0 t_eq_zero_indices = delta_t == 0.0
t_not_eq_zero_indices = ~t_eq_zero_indices t_not_eq_zero_indices = ~t_eq_zero_indices
# sinc filter function # sinc filter function
weights[t_not_eq_zero_indices] *= paddle.sin( weights[t_not_eq_zero_indices] *= paddle.sin(
2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices] 2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (
) / (math.pi * delta_t[t_not_eq_zero_indices]) math.pi * delta_t[t_not_eq_zero_indices])
# limit of the function at t = 0 # limit of the function at t = 0
weights[t_eq_zero_indices] *= 2 * lowpass_cutoff weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
...@@ -411,8 +405,7 @@ class DropFreq(nn.Layer): ...@@ -411,8 +405,7 @@ class DropFreq(nn.Layer):
drop_count_low=1, drop_count_low=1,
drop_count_high=2, drop_count_high=2,
drop_width=0.05, drop_width=0.05,
drop_prob=1, drop_prob=1, ):
):
super().__init__() super().__init__()
self.drop_freq_low = drop_freq_low self.drop_freq_low = drop_freq_low
self.drop_freq_high = drop_freq_high self.drop_freq_high = drop_freq_high
...@@ -443,14 +436,14 @@ class DropFreq(nn.Layer): ...@@ -443,14 +436,14 @@ class DropFreq(nn.Layer):
# Pick number of frequencies to drop # Pick number of frequencies to drop
drop_count = paddle.randint( drop_count = paddle.randint(
low=self.drop_count_low, high=self.drop_count_high + 1, shape=(1,), low=self.drop_count_low,
) high=self.drop_count_high + 1,
shape=(1, ), )
# Pick a frequency to drop # Pick a frequency to drop
drop_range = self.drop_freq_high - self.drop_freq_low drop_range = self.drop_freq_high - self.drop_freq_low
drop_frequency = ( drop_frequency = (
paddle.rand(drop_count) * drop_range + self.drop_freq_low paddle.rand(drop_count) * drop_range + self.drop_freq_low)
)
# Filter parameters # Filter parameters
filter_length = 101 filter_length = 101
pad = filter_length // 2 pad = filter_length // 2
...@@ -461,8 +454,9 @@ class DropFreq(nn.Layer): ...@@ -461,8 +454,9 @@ class DropFreq(nn.Layer):
# Subtract each frequency # Subtract each frequency
for frequency in drop_frequency: for frequency in drop_frequency:
notch_kernel = notch_filter( notch_kernel = notch_filter(
frequency, filter_length, self.drop_width, frequency,
) filter_length,
self.drop_width, )
drop_filter = convolve1d(drop_filter, notch_kernel, pad) drop_filter = convolve1d(drop_filter, notch_kernel, pad)
# Apply filter # Apply filter
...@@ -471,6 +465,7 @@ class DropFreq(nn.Layer): ...@@ -471,6 +465,7 @@ class DropFreq(nn.Layer):
# Remove channels dimension if added # Remove channels dimension if added
return dropped_waveform.squeeze(-1) return dropped_waveform.squeeze(-1)
class DropChunk(nn.Layer): class DropChunk(nn.Layer):
"""This class drops portions of the input signal. """This class drops portions of the input signal.
Using `DropChunk` as an augmentation strategy helps a model learn to rely Using `DropChunk` as an augmentation strategy helps a model learn to rely
...@@ -523,8 +518,7 @@ class DropChunk(nn.Layer): ...@@ -523,8 +518,7 @@ class DropChunk(nn.Layer):
drop_start=0, drop_start=0,
drop_end=None, drop_end=None,
drop_prob=1, drop_prob=1,
noise_factor=0.0, noise_factor=0.0, ):
):
super().__init__() super().__init__()
self.drop_length_low = drop_length_low self.drop_length_low = drop_length_low
self.drop_length_high = drop_length_high self.drop_length_high = drop_length_high
...@@ -580,8 +574,7 @@ class DropChunk(nn.Layer): ...@@ -580,8 +574,7 @@ class DropChunk(nn.Layer):
drop_times = paddle.randint( drop_times = paddle.randint(
low=self.drop_count_low, low=self.drop_count_low,
high=self.drop_count_high + 1, high=self.drop_count_high + 1,
shape=(batch_size,), shape=(batch_size, ), )
)
# Iterate batch to set mask # Iterate batch to set mask
for i in range(batch_size): for i in range(batch_size):
...@@ -592,8 +585,7 @@ class DropChunk(nn.Layer): ...@@ -592,8 +585,7 @@ class DropChunk(nn.Layer):
length = paddle.randint( length = paddle.randint(
low=self.drop_length_low, low=self.drop_length_low,
high=self.drop_length_high + 1, high=self.drop_length_high + 1,
shape=(drop_times[i],), shape=(drop_times[i], ), )
)
# Compute range of starting locations # Compute range of starting locations
start_min = self.drop_start start_min = self.drop_start
...@@ -608,15 +600,16 @@ class DropChunk(nn.Layer): ...@@ -608,15 +600,16 @@ class DropChunk(nn.Layer):
# Pick starting locations # Pick starting locations
start = paddle.randint( start = paddle.randint(
low=start_min, high=start_max + 1, shape=(drop_times[i],), low=start_min,
) high=start_max + 1,
shape=(drop_times[i], ), )
end = start + length end = start + length
# Update waveform # Update waveform
if not self.noise_factor: if not self.noise_factor:
for j in range(drop_times[i]): for j in range(drop_times[i]):
dropped_waveform[i, start[j] : end[j]] = 0.0 dropped_waveform[i, start[j]:end[j]] = 0.0
else: else:
# Uniform distribution of -2 to +2 * avg amplitude should # Uniform distribution of -2 to +2 * avg amplitude should
# preserve the average for normalization # preserve the average for normalization
...@@ -625,7 +618,7 @@ class DropChunk(nn.Layer): ...@@ -625,7 +618,7 @@ class DropChunk(nn.Layer):
# zero-center the noise distribution # zero-center the noise distribution
noise_vec = paddle.rand([length[j]]) noise_vec = paddle.rand([length[j]])
noise_vec = 2 * noise_max * noise_vec - noise_max noise_vec = 2 * noise_max * noise_vec - noise_max
dropped_waveform[i, start[j] : end[j]] = noise_vec dropped_waveform[i, start[j]:end[j]] = noise_vec
return dropped_waveform return dropped_waveform
...@@ -691,25 +684,21 @@ class TimeDomainSpecAugment(nn.Layer): ...@@ -691,25 +684,21 @@ class TimeDomainSpecAugment(nn.Layer):
drop_chunk_count_high=5, drop_chunk_count_high=5,
drop_chunk_length_low=1000, drop_chunk_length_low=1000,
drop_chunk_length_high=2000, drop_chunk_length_high=2000,
drop_chunk_noise_factor=0, drop_chunk_noise_factor=0, ):
):
super().__init__() super().__init__()
self.speed_perturb = SpeedPerturb( self.speed_perturb = SpeedPerturb(
perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
)
self.drop_freq = DropFreq( self.drop_freq = DropFreq(
drop_prob=drop_freq_prob, drop_prob=drop_freq_prob,
drop_count_low=drop_freq_count_low, drop_count_low=drop_freq_count_low,
drop_count_high=drop_freq_count_high, drop_count_high=drop_freq_count_high, )
)
self.drop_chunk = DropChunk( self.drop_chunk = DropChunk(
drop_prob=drop_chunk_prob, drop_prob=drop_chunk_prob,
drop_count_low=drop_chunk_count_low, drop_count_low=drop_chunk_count_low,
drop_count_high=drop_chunk_count_high, drop_count_high=drop_chunk_count_high,
drop_length_low=drop_chunk_length_low, drop_length_low=drop_chunk_length_low,
drop_length_high=drop_chunk_length_high, drop_length_high=drop_chunk_length_high,
noise_factor=drop_chunk_noise_factor, noise_factor=drop_chunk_noise_factor, )
)
def forward(self, waveforms, lengths): def forward(self, waveforms, lengths):
"""Returns the distorted waveforms. """Returns the distorted waveforms.
......
import numpy as np from collections import defaultdict
import os
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Optional
from typing import Tuple from typing import Tuple
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model
from paddlespeech.s2t.modules.mask import make_pad_mask
from paddlespeech.s2t.utils.utility import log_add
from collections import defaultdict
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from yacs.config import CfgNode from paddlespeech.s2t.utils.utility import log_add
class Wav2vec2ASR(nn.Layer): class Wav2vec2ASR(nn.Layer):
def __init__(self, config: dict): def __init__(self, config: dict):
...@@ -36,8 +30,16 @@ class Wav2vec2ASR(nn.Layer): ...@@ -36,8 +30,16 @@ class Wav2vec2ASR(nn.Layer):
for parm in wav2vec2.parameters(): for parm in wav2vec2.parameters():
parm.trainable = False parm.trainable = False
self.wav2vec2 = wav2vec2 self.wav2vec2 = wav2vec2
self.enc = VanillaNN(input_shape=[None,None,wav2vec2_config.hidden_size], activation=nn.LeakyReLU, dnn_blocks=config.dnn_blocks, dnn_neurons=config.dnn_neurons) self.enc = VanillaNN(
self.ctc = CTC(odim=config.output_dim, enc_n_units=config.dnn_neurons, blank_id=config.blank_id, dropout_rate=config.ctc_dropout_rate, reduction=True) input_shape=[None, None, wav2vec2_config.hidden_size],
activation=nn.LeakyReLU,
dnn_blocks=config.dnn_blocks,
dnn_neurons=config.dnn_neurons)
self.ctc = CTC(odim=config.output_dim,
enc_n_units=config.dnn_neurons,
blank_id=config.blank_id,
dropout_rate=config.ctc_dropout_rate,
reduction=True)
def forward(self, wav, wavs_lens_rate, target, target_lens_rate): def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
if self.normalize_wav: if self.normalize_wav:
...@@ -51,7 +53,8 @@ class Wav2vec2ASR(nn.Layer): ...@@ -51,7 +53,8 @@ class Wav2vec2ASR(nn.Layer):
x = self.enc(feats) x = self.enc(feats)
x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64) x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
target_lens = (target_lens_rate * target.shape[1]).round().astype(paddle.int64) target_lens = (target_lens_rate *
target.shape[1]).round().astype(paddle.int64)
ctc_loss = self.ctc(x, x_lens, target, target_lens) ctc_loss = self.ctc(x, x_lens, target, target_lens)
return ctc_loss return ctc_loss
...@@ -63,7 +66,8 @@ class Wav2vec2ASR(nn.Layer): ...@@ -63,7 +66,8 @@ class Wav2vec2ASR(nn.Layer):
decoding_method: str, decoding_method: str,
beam_size: int): beam_size: int):
batch_size = feats.shape[0] batch_size = feats.shape[0]
if decoding_method is 'ctc_prefix_beam_search' and batch_size > 1:
if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
logger.error( logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1' f'decoding mode {decoding_method} must be running with batch_size == 1'
) )
...@@ -79,13 +83,12 @@ class Wav2vec2ASR(nn.Layer): ...@@ -79,13 +83,12 @@ class Wav2vec2ASR(nn.Layer):
# with other batch decoding mode # with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search': elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1 assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search( hyp = self.ctc_prefix_beam_search(feats, beam_size)
feats,
beam_size)
res = [text_feature.defeaturize(hyp)] res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp] res_tokenids = [hyp]
else: else:
raise ValueError(f"wav2vec2 not support decoding method: {decoding_method}") raise ValueError(
f"wav2vec2 not support decoding method: {decoding_method}")
return res, res_tokenids return res, res_tokenids
...@@ -94,8 +97,7 @@ class Wav2vec2ASR(nn.Layer): ...@@ -94,8 +97,7 @@ class Wav2vec2ASR(nn.Layer):
model = cls(config) model = cls(config)
return model return model
def ctc_greedy_search( def ctc_greedy_search(self, wav) -> List[List[int]]:
self, wav) -> List[List[int]]:
""" Apply CTC greedy search """ Apply CTC greedy search
Args: Args:
speech (paddle.Tensor): (batch, max_len) speech (paddle.Tensor): (batch, max_len)
...@@ -104,7 +106,7 @@ class Wav2vec2ASR(nn.Layer): ...@@ -104,7 +106,7 @@ class Wav2vec2ASR(nn.Layer):
List[List[int]]: best path result List[List[int]]: best path result
""" """
batch_size = wav.shape[0] batch_size = wav.shape[0]
wav = wav[:,:,0] wav = wav[:, :, 0]
if self.normalize_wav: if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:]) wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output # Extract wav2vec output
...@@ -124,7 +126,10 @@ class Wav2vec2ASR(nn.Layer): ...@@ -124,7 +126,10 @@ class Wav2vec2ASR(nn.Layer):
return hyps return hyps
def _ctc_prefix_beam_search( def _ctc_prefix_beam_search(
self, wav, beam_size, blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]: self,
wav,
beam_size,
blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
""" CTC prefix beam search inner implementation """ CTC prefix beam search inner implementation
Args: Args:
speech (paddle.Tensor): (batch, max_len, feat_dim) speech (paddle.Tensor): (batch, max_len, feat_dim)
...@@ -142,7 +147,7 @@ class Wav2vec2ASR(nn.Layer): ...@@ -142,7 +147,7 @@ class Wav2vec2ASR(nn.Layer):
paddle.Tensor: encoder output, (1, max_len, encoder_dim), paddle.Tensor: encoder output, (1, max_len, encoder_dim),
it will be used for rescoring in attention rescoring mode it will be used for rescoring in attention rescoring mode
""" """
wav = wav[:,:,0] wav = wav[:, :, 0]
if self.normalize_wav: if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:]) wav = F.layer_norm(wav, wav.shape[1:])
...@@ -219,29 +224,5 @@ class Wav2vec2ASR(nn.Layer): ...@@ -219,29 +224,5 @@ class Wav2vec2ASR(nn.Layer):
Returns: Returns:
List[int]: CTC prefix beam search nbest results List[int]: CTC prefix beam search nbest results
""" """
hyps = self._ctc_prefix_beam_search( hyps = self._ctc_prefix_beam_search(wav, beam_size)
wav, beam_size)
return hyps[0][0] return hyps[0][0]
# @jit.to_static
# def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
# """ Export interface for c++ call, apply linear transform and log
# softmax before ctc
# Args:
# xs (paddle.Tensor): encoder output, (B, T, D)
# Returns:
# paddle.Tensor: activation before ctc
# """
# return self.ctc.log_softmax(xs)
# def _get_data(self):
# data_dir = "data"
# wavs = np.load(os.path.join(data_dir, "wavs.npy"))
# wavs_lens = np.load(os.path.join(data_dir, "wavs_lens.npy"))
# tokens = np.load(os.path.join(data_dir, "tokens.npy"))
# tokens_lens = np.load(os.path.join(data_dir, "tokens_lens.npy"))
# batch = (paddle.to_tensor(wavs), paddle.to_tensor(wavs_lens, dtype='float32'),
# paddle.to_tensor(tokens, dtype='int32'), paddle.to_tensor(tokens_lens, dtype='float32'))
# return batch
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册