Commit 4b7786f2 authored by 小湉湉

add vits network scripts, test=tts

Parent 93964960
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=vits
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
\ No newline at end of file
......@@ -20,15 +20,14 @@ from scipy.interpolate import interp1d
class LogMelFBank():
def __init__(self,
sr=24000,
n_fft=2048,
hop_length=300,
win_length=None,
window="hann",
n_mels=80,
fmin=80,
fmax=7600,
eps=1e-10):
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
window: str="hann",
n_mels: int=80,
fmin: int=80,
fmax: int=7600):
self.sr = sr
# stft
self.n_fft = n_fft
......@@ -54,7 +53,7 @@ class LogMelFBank():
fmax=self.fmax)
return mel_filter
def _stft(self, wav):
def _stft(self, wav: np.ndarray):
D = librosa.core.stft(
wav,
n_fft=self.n_fft,
......@@ -65,11 +64,11 @@ class LogMelFBank():
pad_mode=self.pad_mode)
return D
def _spectrogram(self, wav):
def _spectrogram(self, wav: np.ndarray):
D = self._stft(wav)
return np.abs(D)
def _mel_spectrogram(self, wav):
def _mel_spectrogram(self, wav: np.ndarray):
S = self._spectrogram(wav)
mel = np.dot(self.mel_filter, S)
return mel
......@@ -90,14 +89,18 @@ class LogMelFBank():
class Pitch():
def __init__(self, sr=24000, hop_length=300, f0min=80, f0max=7600):
def __init__(self,
sr: int=24000,
hop_length: int=300,
f0min: int=80,
f0max: int=7600):
self.sr = sr
self.hop_length = hop_length
self.f0min = f0min
self.f0max = f0max
def _convert_to_continuous_f0(self, f0: np.array) -> np.array:
def _convert_to_continuous_f0(self, f0: np.ndarray) -> np.ndarray:
if (f0 == 0).all():
print("All frames seems to be unvoiced.")
return f0
......@@ -120,9 +123,9 @@ class Pitch():
return f0
def _calculate_f0(self,
input: np.array,
use_continuous_f0=True,
use_log_f0=True) -> np.array:
input: np.ndarray,
use_continuous_f0: bool=True,
use_log_f0: bool=True) -> np.ndarray:
input = input.astype(np.float64)  # np.float is removed in NumPy >= 1.24; pyworld expects float64
frame_period = 1000 * self.hop_length / self.sr
f0, timeaxis = pyworld.dio(
......@@ -139,7 +142,8 @@ class Pitch():
f0[nonzero_idxs] = np.log(f0[nonzero_idxs])
return f0.reshape(-1)
def _average_by_duration(self, input: np.array, d: np.array) -> np.array:
def _average_by_duration(self, input: np.ndarray,
d: np.ndarray) -> np.ndarray:
d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant')
arr_list = []
for start, end in zip(d_cumsum[:-1], d_cumsum[1:]):
......@@ -154,11 +158,11 @@ class Pitch():
return arr_list
def get_pitch(self,
wav,
use_continuous_f0=True,
use_log_f0=True,
use_token_averaged_f0=True,
duration=None):
wav: np.ndarray,
use_continuous_f0: bool=True,
use_log_f0: bool=True,
use_token_averaged_f0: bool=True,
duration: np.ndarray=None):
f0 = self._calculate_f0(wav, use_continuous_f0, use_log_f0)
if use_token_averaged_f0 and duration is not None:
f0 = self._average_by_duration(f0, duration)
......@@ -167,13 +171,13 @@ class Pitch():
class Energy():
def __init__(self,
sr=24000,
n_fft=2048,
hop_length=300,
win_length=None,
window="hann",
center=True,
pad_mode="reflect"):
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
window: str="hann",
center: bool=True,
pad_mode: str="reflect"):
self.sr = sr
self.n_fft = n_fft
......@@ -183,7 +187,7 @@ class Energy():
self.center = center
self.pad_mode = pad_mode
def _stft(self, wav):
def _stft(self, wav: np.ndarray):
D = librosa.core.stft(
wav,
n_fft=self.n_fft,
......@@ -194,7 +198,7 @@ class Energy():
pad_mode=self.pad_mode)
return D
def _calculate_energy(self, input):
def _calculate_energy(self, input: np.ndarray):
input = input.astype(np.float32)
input_stft = self._stft(input)
input_power = np.abs(input_stft)**2
......@@ -203,7 +207,8 @@ class Energy():
np.sum(input_power, axis=0), a_min=1.0e-10, a_max=float('inf')))
return energy
def _average_by_duration(self, input: np.array, d: np.array) -> np.array:
def _average_by_duration(self, input: np.ndarray,
d: np.ndarray) -> np.ndarray:
d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant')
arr_list = []
for start, end in zip(d_cumsum[:-1], d_cumsum[1:]):
......@@ -214,8 +219,49 @@ class Energy():
arr_list = np.expand_dims(np.array(arr_list), 0).T
return arr_list
def get_energy(self, wav, use_token_averaged_energy=True, duration=None):
def get_energy(self,
wav: np.ndarray,
use_token_averaged_energy: bool=True,
duration: np.ndarray=None):
energy = self._calculate_energy(wav)
if use_token_averaged_energy and duration is not None:
energy = self._average_by_duration(energy, duration)
return energy
class LinearSpectrogram():
def __init__(
self,
n_fft: int=1024,
win_length: int=None,
hop_length: int=256,
window: str="hann",
center: bool=True, ):
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.center = center
self.pad_mode = "reflect"
def _stft(self, wav: np.ndarray):
D = librosa.core.stft(
wav,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode)
return D
def _spectrogram(self, wav: np.ndarray):
D = self._stft(wav)
return np.abs(D)
def get_linear_spectrogram(self, wav: np.ndarray):
linear_spectrogram = self._spectrogram(wav)
linear_spectrogram = np.clip(
linear_spectrogram, a_min=1e-10, a_max=float("inf"))
return linear_spectrogram.T
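# A minimal usage sketch of the extractors defined above, assuming this
# module is importable; the waveform is synthetic and all sizes are
# illustrative, not prescribed by the commit.
import numpy as np

rng = np.random.default_rng(0)
wav = rng.standard_normal(12000).astype(np.float32)  # 0.5 s at 24 kHz

spec_extractor = LinearSpectrogram(n_fft=1024, hop_length=256)
pitch_extractor = Pitch(sr=24000, hop_length=300, f0min=80, f0max=7600)
energy_extractor = Energy(sr=24000, n_fft=2048, hop_length=300)

spec = spec_extractor.get_linear_spectrogram(wav)  # (n_frames, n_fft // 2 + 1)
f0 = pitch_extractor.get_pitch(wav, use_token_averaged_f0=False)  # per-frame (log-)F0
energy = energy_extractor.get_energy(wav, use_token_averaged_energy=False)  # per-frame energy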
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -195,7 +195,7 @@ class Frontend():
new_initials.append(initials[i])
return new_initials, new_finals
def _p2id(self, phonemes: List[str]) -> np.array:
def _p2id(self, phonemes: List[str]) -> np.ndarray:
# replace unk phone with sp
phonemes = [
phn if phn in self.vocab_phones else "sp" for phn in phonemes
......@@ -203,7 +203,7 @@ class Frontend():
phone_ids = [self.vocab_phones[item] for item in phonemes]
return np.array(phone_ids, np.int64)
def _t2id(self, tones: List[str]) -> np.array:
def _t2id(self, tones: List[str]) -> np.ndarray:
# replace unk phone with sp
tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
tone_ids = [self.vocab_tones[item] for item in tones]
......
......@@ -16,6 +16,7 @@ import copy
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
import paddle
import paddle.nn.functional as F
......@@ -34,6 +35,7 @@ class HiFiGANGenerator(nn.Layer):
in_channels: int=80,
out_channels: int=1,
channels: int=512,
global_channels: int=-1,
kernel_size: int=7,
upsample_scales: List[int]=(8, 8, 2, 2),
upsample_kernel_sizes: List[int]=(16, 16, 4, 4),
......@@ -51,6 +53,7 @@ class HiFiGANGenerator(nn.Layer):
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
channels (int): Number of hidden representation channels.
global_channels (int): Number of global conditioning channels.
kernel_size (int): Kernel size of initial and final conv layer.
upsample_scales (list): List of upsampling scales.
upsample_kernel_sizes (list): List of kernel sizes for upsampling layers.
......@@ -119,6 +122,9 @@ class HiFiGANGenerator(nn.Layer):
padding=(kernel_size - 1) // 2, ),
nn.Tanh(), )
if global_channels > 0:
self.global_conv = nn.Conv1D(global_channels, channels, 1)
nn.initializer.set_global_initializer(None)
# apply weight norm
......@@ -128,15 +134,18 @@ class HiFiGANGenerator(nn.Layer):
# reset parameters
self.reset_parameters()
def forward(self, c):
def forward(self, c, g: Optional[paddle.Tensor]=None):
"""Calculate forward propagation.
Args:
c (Tensor): Input tensor (B, in_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T).
"""
c = self.input_conv(c)
if g is not None:
c = c + self.global_conv(g)
for i in range(self.num_upsamples):
c = self.upsamples[i](c)
# initialize
......@@ -187,16 +196,19 @@ class HiFiGANGenerator(nn.Layer):
self.apply(_remove_weight_norm)
def inference(self, c):
def inference(self, c, g: Optional[paddle.Tensor]=None):
"""Perform inference.
Args:
c (Tensor): Input tensor (T, in_channels).
g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
Returns:
Tensor:
Output tensor (T ** prod(upsample_scales), out_channels).
"""
c = self.forward(c.transpose([1, 0]).unsqueeze(0))
if g is not None:
g = g.unsqueeze(0)
c = self.forward(c.transpose([1, 0]).unsqueeze(0), g=g)
return c.squeeze(0).transpose([1, 0])
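# A minimal sketch of the new global conditioning path, with hypothetical
# sizes; `g` would typically be a speaker embedding in a multi-speaker setup.
import paddle

generator = HiFiGANGenerator(in_channels=80, channels=512, global_channels=256)

mel = paddle.randn([2, 80, 50])      # (B, in_channels, T)
spk_emb = paddle.randn([2, 256, 1])  # (B, global_channels, 1)

wav = generator(mel, g=spk_emb)      # (B, out_channels, T * prod(upsample_scales))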
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stochastic duration predictor modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.t2s.models.vits.flow import ConvFlow
from paddlespeech.t2s.models.vits.flow import DilatedDepthSeparableConv
from paddlespeech.t2s.models.vits.flow import ElementwiseAffineFlow
from paddlespeech.t2s.models.vits.flow import FlipFlow
from paddlespeech.t2s.models.vits.flow import LogFlow
class StochasticDurationPredictor(nn.Layer):
"""Stochastic duration predictor module.
This is a module of stochastic duration predictor described in `Conditional
Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2106.06103
"""
def __init__(
self,
channels: int=192,
kernel_size: int=3,
dropout_rate: float=0.5,
flows: int=4,
dds_conv_layers: int=3,
global_channels: int=-1, ):
"""Initialize StochasticDurationPredictor module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
dropout_rate (float): Dropout rate.
flows (int): Number of flows.
dds_conv_layers (int): Number of conv layers in DDS conv.
global_channels (int): Number of global conditioning channels.
"""
super().__init__()
self.pre = nn.Conv1D(channels, channels, 1)
self.dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate, )
self.proj = nn.Conv1D(channels, channels, 1)
self.log_flow = LogFlow()
self.flows = nn.LayerList()
self.flows.append(ElementwiseAffineFlow(2))
for i in range(flows):
self.flows.append(
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers, ))
self.flows.append(FlipFlow())
self.post_pre = nn.Conv1D(1, channels, 1)
self.post_dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate, )
self.post_proj = nn.Conv1D(channels, channels, 1)
self.post_flows = nn.LayerList()
self.post_flows.append(ElementwiseAffineFlow(2))
for i in range(flows):
self.post_flows.append(
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers, ))
self.post_flows.append(FlipFlow())
if global_channels > 0:
self.global_conv = nn.Conv1D(global_channels, channels, 1)
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
w: Optional[paddle.Tensor]=None,
g: Optional[paddle.Tensor]=None,
inverse: bool=False,
noise_scale: float=1.0, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T_text).
x_mask (Tensor): Mask tensor (B, 1, T_text).
w (Optional[Tensor]): Duration tensor (B, 1, T_text).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
inverse (bool): Whether to inverse the flow.
noise_scale (float): Noise scale value.
Returns:
Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
If inverse, log-duration tensor (B, 1, T_text).
"""
# stop gradient
# x = x.detach()
x = self.pre(x)
if g is not None:
# stop gradient
x = x + self.global_conv(g.detach())
x = self.dds(x, x_mask)
x = self.proj(x) * x_mask
if not inverse:
assert w is not None, "w must be provided."
h_w = self.post_pre(w)
h_w = self.post_dds(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (paddle.randn([paddle.shape(w)[0], 2, paddle.shape(w)[2]]) *
x_mask)
z_q = e_q
logdet_tot_q = 0.0
for i, flow in enumerate(self.post_flows):
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = paddle.split(z_q, [1, 1], 1)
u = F.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += paddle.sum(
(F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2])
logq = (paddle.sum(-0.5 *
(math.log(2 * math.pi) +
(e_q**2)) * x_mask, [1, 2]) - logdet_tot_q)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = paddle.concat([z0, z1], 1)
for flow in self.flows:
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
logdet_tot = logdet_tot + logdet
nll = (paddle.sum(0.5 * (math.log(2 * math.pi) +
(z**2)) * x_mask, [1, 2]) - logdet_tot)
# (B,)
return nll + logq
else:
flows = list(reversed(self.flows))
# remove a useless vflow
flows = flows[:-2] + [flows[-1]]
z = (paddle.randn([paddle.shape(x)[0], 2, paddle.shape(x)[2]]) *
noise_scale)
for flow in flows:
z = flow(z, x_mask, g=x, inverse=inverse)
z0, z1 = paddle.split(z, 2, axis=1)
logw = z0
return logw
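# A sketch of both directions of the duration predictor, with hypothetical
# shapes: training computes an NLL from ground-truth durations `w`, while
# `inverse=True` samples log-durations from noise.
import paddle

sdp = StochasticDurationPredictor(channels=192)

x = paddle.randn([2, 192, 10])       # text hidden states (B, channels, T_text)
x_mask = paddle.ones([2, 1, 10])
w = paddle.rand([2, 1, 10]) * 9 + 1  # ground-truth durations in frames

nll = sdp(x, x_mask, w=w)                               # (B,)
log_w = sdp(x, x_mask, inverse=True, noise_scale=0.8)   # (B, 1, T_text)
dur = paddle.ceil(paddle.exp(log_w) * x_mask)           # log-durations -> frame counts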
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic Flow modules used in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.transform import piecewise_rational_quadratic_transform
class FlipFlow(nn.Layer):
"""Flip flow module."""
def forward(self, x: paddle.Tensor, *args, inverse: bool=False, **kwargs
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Flipped tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
x = paddle.flip(x, [1])
if not inverse:
logdet = paddle.zeros(paddle.shape(x)[0], dtype=x.dtype)
return x, logdet
else:
return x
class LogFlow(nn.Layer):
"""Log flow module."""
def forward(self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
inverse: bool=False,
eps: float=1e-5,
**kwargs
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
eps (float): Epsilon for log.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = paddle.log(paddle.clip(x, min=eps)) * x_mask
logdet = paddle.sum(-y, [1, 2])
return y, logdet
else:
x = paddle.exp(x) * x_mask
return x
class ElementwiseAffineFlow(nn.Layer):
"""Elementwise affine flow module."""
def __init__(self, channels: int):
"""Initialize ElementwiseAffineFlow module.
Args:
channels (int): Number of channels.
"""
super().__init__()
self.channels = channels
m = paddle.zeros([channels, 1])
self.m = paddle.create_parameter(
shape=m.shape,
dtype=str(m.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(m))
logs = paddle.zeros([channels, 1])
self.logs = paddle.create_parameter(
shape=logs.shape,
dtype=str(logs.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(logs))
def forward(self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
inverse: bool=False,
**kwargs
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = self.m + paddle.exp(self.logs) * x
y = y * x_mask
logdet = paddle.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * paddle.exp(-self.logs) * x_mask
return x
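# A round-trip sketch: applying the flow and then its inverse recovers the
# input (exactly at initialization, since `m` and `logs` start at zero).
import paddle

flow = ElementwiseAffineFlow(channels=2)
x = paddle.randn([1, 2, 5])
x_mask = paddle.ones([1, 1, 5])

y, logdet = flow(x, x_mask)                 # forward: output and log-determinant
x_rec = flow(y, x_mask, inverse=True)       # inverse undoes the affine map
print(float(paddle.abs(x - x_rec).max()))   # ~0 up to numerical error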
class Transpose(nn.Layer):
"""Transpose module for paddle.nn.Sequential()."""
def __init__(self, dim1: int, dim2: int):
"""Initialize Transpose module."""
super().__init__()
self.dim1 = dim1
self.dim2 = dim2
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
"""Transpose."""
len_dim = len(x.shape)
orig_perm = list(range(len_dim))
new_perm = orig_perm[:]
temp = new_perm[self.dim1]
new_perm[self.dim1] = new_perm[self.dim2]
new_perm[self.dim2] = temp
return paddle.transpose(x, new_perm)
class DilatedDepthSeparableConv(nn.Layer):
"""Dilated depth-separable conv module."""
def __init__(
self,
channels: int,
kernel_size: int,
layers: int,
dropout_rate: float=0.0,
eps: float=1e-5, ):
"""Initialize DilatedDepthSeparableConv module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
dropout_rate (float): Dropout rate.
eps (float): Epsilon for layer norm.
"""
super().__init__()
self.convs = nn.LayerList()
for i in range(layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs.append(
nn.Sequential(
nn.Conv1D(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding, ),
Transpose(1, 2),
nn.LayerNorm(channels, epsilon=eps),
Transpose(1, 2),
nn.GELU(),
nn.Conv1D(
channels,
channels,
1, ),
Transpose(1, 2),
nn.LayerNorm(channels, epsilon=eps),
Transpose(1, 2),
nn.GELU(),
nn.Dropout(dropout_rate), ))
def forward(self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, channels, T).
"""
if g is not None:
x = x + g
for f in self.convs:
y = f(x * x_mask)
x = x + y
return x * x_mask
class ConvFlow(nn.Layer):
"""Convolutional flow module."""
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
layers: int,
bins: int=10,
tail_bound: float=5.0, ):
"""Initialize ConvFlow module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
bins (int): Number of bins.
tail_bound (float): Tail bound value.
"""
super().__init__()
self.half_channels = in_channels // 2
self.hidden_channels = hidden_channels
self.bins = bins
self.tail_bound = tail_bound
self.input_conv = nn.Conv1D(
self.half_channels,
hidden_channels,
1, )
self.dds_conv = DilatedDepthSeparableConv(
hidden_channels,
kernel_size,
layers,
dropout_rate=0.0, )
self.proj = nn.Conv1D(
hidden_channels,
self.half_channels * (bins * 3 - 1),
1, )
# self.proj.weight.data.zero_()
# self.proj.bias.data.zero_()
weight = paddle.zeros(paddle.shape(self.proj.weight))
self.proj.weight = paddle.create_parameter(
shape=weight.shape,
dtype=str(weight.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(weight))
bias = paddle.zeros(paddle.shape(self.proj.bias))
self.proj.bias = paddle.create_parameter(
shape=bias.shape,
dtype=str(bias.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(bias))
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None,
inverse: bool=False,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = x.split(2, 1)
h = self.input_conv(xa)
h = self.dds_conv(h, x_mask, g=g)
# (B, half_channels * (bins * 3 - 1), T)
h = self.proj(h) * x_mask
b, c, t = xa.shape
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
h = h.reshape([b, c, -1, t]).transpose([0, 1, 3, 2])
denom = math.sqrt(self.hidden_channels)
unnorm_widths = h[..., :self.bins] / denom
unnorm_heights = h[..., self.bins:2 * self.bins] / denom
unnorm_derivatives = h[..., 2 * self.bins:]
xb, logdet_abs = piecewise_rational_quadratic_transform(
xb,
unnorm_widths,
unnorm_heights,
unnorm_derivatives,
inverse=inverse,
tails="linear",
tail_bound=self.tail_bound, )
x = paddle.concat([xa, xb], 1) * x_mask
logdet = paddle.sum(logdet_abs * x_mask, [1, 2])
if not inverse:
return x, logdet
else:
return x
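# A shape sketch of ConvFlow with hypothetical sizes: the first half of the
# channels conditions a rational-quadratic spline applied to the second half.
import paddle

flow = ConvFlow(in_channels=2, hidden_channels=192, kernel_size=3, layers=3)

z = paddle.randn([2, 2, 10])
z_mask = paddle.ones([2, 1, 10])

z_out, logdet = flow(z, z_mask)             # forward: (B, 2, T) and (B,)
z_rec = flow(z_out, z_mask, inverse=True)   # inverse direction recovers z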
(This file's diff is collapsed.)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Maximum path calculation module.
This code is based on https://github.com/jaywalnut310/vits.
"""
import warnings
import numpy as np
import paddle
from numba import njit
from numba import prange
try:
from .core import maximum_path_c
is_cython_available = True
except ImportError:
is_cython_available = False
warnings.warn(
"Cython version is not available. Falling back to the 'EXPERIMENTAL' numba version. "
"If you want to use the cython version, please build it as follows: "
"`cd paddlespeech/t2s/models/vits/monotonic_align; python setup.py build_ext --inplace`"
)
def maximum_path(neg_x_ent: paddle.Tensor,
attn_mask: paddle.Tensor) -> paddle.Tensor:
"""Calculate maximum path.
Args:
neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
attn_mask (Tensor): Attention mask (B, T_feats, T_text).
Returns:
Tensor: Maximum path tensor (B, T_feats, T_text).
"""
dtype = neg_x_ent.dtype
neg_x_ent = neg_x_ent.numpy().astype(np.float32)
path = np.zeros(neg_x_ent.shape, dtype=np.int32)
t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
if is_cython_available:
maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
else:
maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
return paddle.cast(paddle.to_tensor(path), dtype=dtype)
@njit
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
"""Calculate a single maximum path with numba."""
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or
value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@njit(parallel=True)
def maximum_path_numba(paths, values, t_ys, t_xs):
"""Calculate batch maximum path with numba."""
for i in prange(paths.shape[0]):
maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
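# A toy run of the alignment search, with hypothetical sizes: 4 feature
# frames aligned monotonically to 3 text tokens (numba fallback or cython,
# whichever is available).
import paddle

neg_x_ent = paddle.randn([1, 4, 3])  # scores (B, T_feats, T_text)
attn_mask = paddle.ones([1, 4, 3])   # all positions valid

path = maximum_path(neg_x_ent, attn_mask)
print(path[0])  # 0/1 monotonic alignment; one token per feature frame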
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Maximum path calculation module with cython optimization.
This code is copied from https://github.com/jaywalnut310/vits, with the code format modified.
"""
cimport cython
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
cdef int x
cdef int y
cdef float v_prev
cdef float v_cur
cdef float tmp
cdef int index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
cdef int b = paths.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup cython code."""
from Cython.Build import cythonize
from setuptools import Extension
from setuptools import setup
from setuptools.command.build_ext import build_ext as _build_ext
class build_ext(_build_ext):
"""Overwrite build_ext."""
def finalize_options(self):
"""Prevent numpy from thinking it is still in its setup process."""
_build_ext.finalize_options(self)
__builtins__.__NUMPY_SETUP__ = False
import numpy
self.include_dirs.append(numpy.get_include())
exts = [Extension(
name="core",
sources=["core.pyx"], )]
setup(
name="monotonic_align",
ext_modules=cythonize(exts, language_level=3),
cmdclass={"build_ext": build_ext}, )
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional
from typing import Tuple
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
class PosteriorEncoder(nn.Layer):
"""Posterior encoder module in VITS.
This is a module of posterior encoder described in `Conditional Variational
Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2106.06103
"""
def __init__(
self,
in_channels: int=513,
out_channels: int=192,
hidden_channels: int=192,
kernel_size: int=5,
layers: int=16,
stacks: int=1,
base_dilation: int=1,
global_channels: int=-1,
dropout_rate: float=0.0,
bias: bool=True,
use_weight_norm: bool=True, ):
"""Initilialize PosteriorEncoder module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size in WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of repeat stacking of WaveNet.
base_dilation (int): Base dilation factor.
global_channels (int): Number of global conditioning channels.
dropout_rate (float): Dropout rate.
bias (bool): Whether to use bias parameters in conv.
use_weight_norm (bool): Whether to apply weight norm.
"""
super().__init__()
# define modules
self.input_conv = nn.Conv1D(in_channels, hidden_channels, 1)
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True, )
self.proj = nn.Conv1D(hidden_channels, out_channels * 2, 1)
def forward(
self,
x: paddle.Tensor,
x_lengths: paddle.Tensor,
g: Optional[paddle.Tensor]=None
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_feats).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
Tensor: Projected mean tensor (B, out_channels, T_feats).
Tensor: Projected scale tensor (B, out_channels, T_feats).
Tensor: Mask tensor for input tensor (B, 1, T_feats).
"""
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
x = self.input_conv(x) * x_mask
x = self.encoder(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = paddle.split(stats, 2, axis=1)
z = (m + paddle.randn(paddle.shape(m)) * paddle.exp(logs)) * x_mask
return z, m, logs, x_mask
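# A usage sketch with hypothetical shapes: the encoder maps padded linear
# spectrograms to a sampled latent plus the posterior mean and log-scale.
import paddle

encoder = PosteriorEncoder(in_channels=513, out_channels=192, hidden_channels=192)

feats = paddle.randn([2, 513, 50])         # (B, in_channels, T_feats)
feats_lengths = paddle.to_tensor([50, 38])

z, m, logs, y_mask = encoder(feats, feats_lengths)
# z: (B, 192, 50) latent sample; m, logs: posterior stats; y_mask: (B, 1, 50)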
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Residual affine coupling modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.flow import FlipFlow
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
class ResidualAffineCouplingBlock(nn.Layer):
"""Residual affine coupling block module.
This is a module of the residual affine coupling block, which is used as the "Flow" in
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2106.06103
"""
def __init__(
self,
in_channels: int=192,
hidden_channels: int=192,
flows: int=4,
kernel_size: int=5,
base_dilation: int=1,
layers: int=4,
global_channels: int=-1,
dropout_rate: float=0.0,
use_weight_norm: bool=True,
bias: bool=True,
use_only_mean: bool=True, ):
"""Initilize ResidualAffineCouplingBlock module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
flows (int): Number of flows.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
bias (bool): Whether to use bias parameters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
super().__init__()
self.flows = nn.LayerList()
for i in range(flows):
self.flows.append(
ResidualAffineCouplingLayer(
in_channels=in_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
base_dilation=base_dilation,
layers=layers,
stacks=1,
global_channels=global_channels,
dropout_rate=dropout_rate,
use_weight_norm=use_weight_norm,
bias=bias,
use_only_mean=use_only_mean, ))
self.flows.append(FlipFlow())
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None,
inverse: bool=False, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
"""
if not inverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, inverse=inverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, inverse=inverse)
return x
class ResidualAffineCouplingLayer(nn.Layer):
"""Residual affine coupling layer."""
def __init__(
self,
in_channels: int=192,
hidden_channels: int=192,
kernel_size: int=5,
base_dilation: int=1,
layers: int=5,
stacks: int=1,
global_channels: int=-1,
dropout_rate: float=0.0,
use_weight_norm: bool=True,
bias: bool=True,
use_only_mean: bool=True, ):
"""Initialzie ResidualAffineCouplingLayer module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of stacks of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
bias (bool): Whether to use bias parameters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
assert in_channels % 2 == 0, "in_channels should be divisible by 2"
super().__init__()
self.half_channels = in_channels // 2
self.use_only_mean = use_only_mean
# define modules
self.input_conv = nn.Conv1D(
self.half_channels,
hidden_channels,
1, )
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True, )
if use_only_mean:
self.proj = nn.Conv1D(
hidden_channels,
self.half_channels,
1, )
else:
self.proj = nn.Conv1D(
hidden_channels,
self.half_channels * 2,
1, )
# self.proj.weight.data.zero_()
# self.proj.bias.data.zero_()
weight = paddle.zeros(paddle.shape(self.proj.weight))
self.proj.weight = paddle.create_parameter(
shape=weight.shape,
dtype=str(weight.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(weight))
bias = paddle.zeros(paddle.shape(self.proj.bias))
self.proj.bias = paddle.create_parameter(
shape=bias.shape,
dtype=str(bias.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(bias))
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None,
inverse: bool=False,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = paddle.split(x, 2, axis=1)
h = self.input_conv(xa) * x_mask
h = self.encoder(h, x_mask, g=g)
stats = self.proj(h) * x_mask
if not self.use_only_mean:
m, logs = paddle.split(stats, 2, axis=1)
else:
m = stats
logs = paddle.zeros(paddle.shape(m))
if not inverse:
xb = m + xb * paddle.exp(logs) * x_mask
x = paddle.concat([xa, xb], 1)
logdet = paddle.sum(logs, [1, 2])
return x, logdet
else:
xb = (xb - m) * paddle.exp(-logs) * x_mask
x = paddle.concat([xa, xb], 1)
return x
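# A round-trip sketch with hypothetical shapes: the coupling block is an
# invertible map, so running it forward and then with `inverse=True`
# recovers the input up to numerical error.
import paddle

flow = ResidualAffineCouplingBlock(in_channels=192, hidden_channels=192, flows=4)

z = paddle.randn([2, 192, 30])
z_mask = paddle.ones([2, 1, 30])

z_p = flow(z, z_mask)                    # posterior latent -> prior space
z_rec = flow(z_p, z_mask, inverse=True)  # inverse recovers z
print(float(paddle.abs(z - z_rec).max()))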
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Tuple
import paddle
from paddle import nn
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder as Encoder
class TextEncoder(nn.Layer):
"""Text encoder module in VITS.
This is a module of text encoder described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`_.
Instead of the relative positional Transformer, we use the conformer architecture as
the encoder module, which contains additional convolution layers.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2106.06103
"""
def __init__(
self,
vocabs: int,
attention_dim: int=192,
attention_heads: int=2,
linear_units: int=768,
blocks: int=6,
positionwise_layer_type: str="conv1d",
positionwise_conv_kernel_size: int=3,
positional_encoding_layer_type: str="rel_pos",
self_attention_layer_type: str="rel_selfattn",
activation_type: str="swish",
normalize_before: bool=True,
use_macaron_style: bool=False,
use_conformer_conv: bool=False,
conformer_kernel_size: int=7,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.0,
attention_dropout_rate: float=0.0, ):
"""Initialize TextEncoder module.
Args:
vocabs (int): Vocabulary size.
attention_dim (int): Attention dimension.
attention_heads (int): Number of attention heads.
linear_units (int): Number of linear units of positionwise layers.
blocks (int): Number of encoder blocks.
positionwise_layer_type (str): Positionwise layer type.
positionwise_conv_kernel_size (int): Positionwise layer's kernel size.
positional_encoding_layer_type (str): Positional encoding layer type.
self_attention_layer_type (str): Self-attention layer type.
activation_type (str): Activation function type.
normalize_before (bool): Whether to apply LayerNorm before attention.
use_macaron_style (bool): Whether to use macaron style components.
use_conformer_conv (bool): Whether to use conformer conv layers.
conformer_kernel_size (int): Conformer's conv kernel size.
dropout_rate (float): Dropout rate.
positional_dropout_rate (float): Dropout rate for positional encoding.
attention_dropout_rate (float): Dropout rate for attention.
"""
super().__init__()
# store for forward
self.attention_dim = attention_dim
# define modules
self.emb = nn.Embedding(vocabs, attention_dim)
dist = paddle.distribution.Normal(loc=0.0, scale=attention_dim**-0.5)
w = dist.sample(self.emb.weight.shape)
self.emb.weight.set_value(w)
self.encoder = Encoder(
idim=-1,
input_layer=None,
attention_dim=attention_dim,
attention_heads=attention_heads,
linear_units=linear_units,
num_blocks=blocks,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
normalize_before=normalize_before,
positionwise_layer_type=positionwise_layer_type,
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
macaron_style=use_macaron_style,
pos_enc_layer_type=positional_encoding_layer_type,
selfattention_layer_type=self_attention_layer_type,
activation_type=activation_type,
use_cnn_module=use_conformer_conv,
cnn_module_kernel=conformer_kernel_size, )
self.proj = nn.Conv1D(attention_dim, attention_dim * 2, 1)
def forward(
self,
x: paddle.Tensor,
x_lengths: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input index tensor (B, T_text).
x_lengths (Tensor): Length tensor (B,).
Returns:
Tensor: Encoded hidden representation (B, attention_dim, T_text).
Tensor: Projected mean tensor (B, attention_dim, T_text).
Tensor: Projected scale tensor (B, attention_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
x = self.emb(x) * math.sqrt(self.attention_dim)
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
# the encoder assumes channel-last input (B, T_text, attention_dim),
# while the mask shape should be (B, 1, T_text)
x, _ = self.encoder(x, x_mask)
# convert to channel-first (B, attention_dim, T_text)
x = paddle.transpose(x, [0, 2, 1])
stats = self.proj(x) * x_mask
m, logs = paddle.split(stats, 2, axis=1)
return x, m, logs, x_mask
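# A usage sketch with a hypothetical vocabulary of 100 phone ids and a
# batch of 2 padded sequences.
import paddle

text_encoder = TextEncoder(vocabs=100, attention_dim=192)

phone_ids = paddle.randint(0, 100, [2, 8])  # (B, T_text)
phone_lengths = paddle.to_tensor([8, 5])

h, m_p, logs_p, x_mask = text_encoder(phone_ids, phone_lengths)
# h: (B, 192, T_text) hidden states; m_p, logs_p parameterize the prior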
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flow-related transformation.
This code is based on https://github.com/bayesiains/nflows.
"""
import numpy as np
import paddle
from paddle.nn import functional as F
from paddlespeech.t2s.modules.nets_utils import paddle_gather
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs)
return outputs, logabsdet
def mask_preprocess(x, mask):
B, C, T, bins = paddle.shape(x)
new_x = paddle.zeros([mask.sum(), bins])
for i in range(bins):
new_x[:, i] = x[:, :, :, i][mask]
return new_x
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = paddle.zeros(paddle.shape(inputs))
logabsdet = paddle.zeros(paddle.shape(inputs))
if tails == "linear":
unnormalized_derivatives = F.pad(
unnormalized_derivatives,
pad=[0] * (len(unnormalized_derivatives.shape) - 1) * 2 + [1, 1])
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
unnormalized_widths = mask_preprocess(unnormalized_widths,
inside_interval_mask)
unnormalized_heights = mask_preprocess(unnormalized_heights,
inside_interval_mask)
unnormalized_derivatives = mask_preprocess(unnormalized_derivatives,
inside_interval_mask)
(outputs[inside_interval_mask],
logabsdet[inside_interval_mask], ) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative, )
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
if paddle.min(inputs) < left or paddle.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, axis=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = paddle.cumsum(widths, axis=-1)
cumwidths = F.pad(
cumwidths,
pad=[0] * (len(cumwidths.shape) - 1) * 2 + [1, 0],
mode="constant",
value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, axis=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = paddle.cumsum(heights, axis=-1)
cumheights = F.pad(
cumheights,
pad=[0] * (len(cumheights.shape) - 1) * 2 + [1, 0],
mode="constant",
value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = _searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = _searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = paddle_gather(cumwidths, -1, bin_idx)[..., 0]
input_bin_widths = paddle_gather(widths, -1, bin_idx)[..., 0]
input_cumheights = paddle_gather(cumheights, -1, bin_idx)[..., 0]
delta = heights / widths
input_delta = paddle_gather(delta, -1, bin_idx)[..., 0]
input_derivatives = paddle_gather(derivatives, -1, bin_idx)[..., 0]
input_derivatives_plus_one = paddle_gather(derivatives[..., 1:], -1,
bin_idx)[..., 0]
input_heights = paddle_gather(heights, -1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - paddle.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
) * theta_one_minus_theta)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2) + 2 * input_delta *
theta_one_minus_theta + input_derivatives * (1 - root).pow(2))
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2) +
input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
) * theta_one_minus_theta)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2) + 2 * input_delta *
theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
denominator)
return outputs, logabsdet
def _searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1
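# A round-trip check of the constrained spline on inputs inside its default
# [0, 1] domain; shapes are hypothetical (4 scalars, 10 bins).
import paddle

inputs = paddle.uniform([4], min=0.05, max=0.95)
unnorm_widths = paddle.randn([4, 10])
unnorm_heights = paddle.randn([4, 10])
unnorm_derivs = paddle.randn([4, 11])  # bins + 1 knot derivatives

y, logabsdet = rational_quadratic_spline(inputs, unnorm_widths, unnorm_heights, unnorm_derivs)
x_rec, _ = rational_quadratic_spline(y, unnorm_widths, unnorm_heights, unnorm_derivs, inverse=True)
print(float(paddle.abs(inputs - x_rec).max()))  # ~0 up to numerical error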
(This file's diff is collapsed.)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet (https://github.com/espnet/espnet)
import math
from typing import Optional
from typing import Tuple
import paddle
import paddle.nn.functional as F
from paddle import nn
class ResidualBlock(nn.Layer):
"""Residual block module in WaveNet."""
def __init__(
self,
kernel_size: int=3,
residual_channels: int=64,
gate_channels: int=128,
skip_channels: int=64,
aux_channels: int=80,
global_channels: int=-1,
dropout_rate: float=0.0,
dilation: int=1,
bias: bool=True,
scale_residual: bool=False, ):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
residual_channels (int): Number of channels for residual connection.
gate_channels (int): Number of channels for gated activation.
skip_channels (int): Number of channels for skip connection.
aux_channels (int): Number of local conditioning channels.
global_channels (int): Number of global conditioning channels.
dropout_rate (float): Dropout rate.
dilation (int): Dilation factor.
bias (bool): Whether to add bias parameter in convolution layers.
scale_residual (bool): Whether to scale the residual outputs.
"""
super().__init__()
self.dropout_rate = dropout_rate
self.residual_channels = residual_channels
self.skip_channels = skip_channels
self.scale_residual = scale_residual
# check
assert (
kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
assert gate_channels % 2 == 0
# dilation conv
padding = (kernel_size - 1) // 2 * dilation
self.conv = nn.Conv1D(
residual_channels,
gate_channels,
kernel_size,
padding=padding,
dilation=dilation,
bias_attr=bias, )
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = nn.Conv1D(
aux_channels, gate_channels, kernel_size=1, bias_attr=False)
else:
self.conv1x1_aux = None
# global conditioning
if global_channels > 0:
self.conv1x1_glo = nn.Conv1D(
global_channels, gate_channels, kernel_size=1, bias_attr=False)
else:
self.conv1x1_glo = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
# NOTE: concat two convs into a single conv for the efficiency
# (integrate res 1x1 + skip 1x1 convs)
self.conv1x1_out = nn.Conv1D(
gate_out_channels,
residual_channels + skip_channels,
kernel_size=1,
bias_attr=bias)
def forward(
self,
x: paddle.Tensor,
x_mask: Optional[paddle.Tensor]=None,
c: Optional[paddle.Tensor]=None,
g: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor for residual connection (B, residual_channels, T).
Tensor: Output tensor for skip connection (B, skip_channels, T).
"""
residual = x
x = F.dropout(x, p=self.dropout_rate, training=self.training)
x = self.conv(x)
        # split into two parts for gated activation
splitdim = 1
xa, xb = paddle.split(x, 2, axis=splitdim)
# local conditioning
if c is not None:
c = self.conv1x1_aux(c)
ca, cb = paddle.split(c, 2, axis=splitdim)
xa, xb = xa + ca, xb + cb
# global conditioning
if g is not None:
g = self.conv1x1_glo(g)
ga, gb = paddle.split(g, 2, axis=splitdim)
xa, xb = xa + ga, xb + gb
x = paddle.tanh(xa) * F.sigmoid(xb)
# residual + skip 1x1 conv
x = self.conv1x1_out(x)
if x_mask is not None:
x = x * x_mask
# split integrated conv results
x, s = paddle.split(
x, [self.residual_channels, self.skip_channels], axis=1)
# for residual connection
x = x + residual
if self.scale_residual:
x = x * math.sqrt(0.5)
return x, s
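# --- Editor's sketch (not part of the original commit) ---
# A minimal shape check for ResidualBlock under its default channel sizes.
# The batch size, time length, and random tensors below are arbitrary
# assumptions used only to illustrate the gated-activation residual block.
if __name__ == "__main__":
    block = ResidualBlock(
        kernel_size=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80, )
    x = paddle.randn([2, 64, 100])  # (B, residual_channels, T)
    c = paddle.randn([2, 80, 100])  # (B, aux_channels, T)
    res, skip = block(x, c=c)
    # the residual path keeps the input shape; the skip path uses skip_channels
    assert tuple(res.shape) == (2, 64, 100)
    assert tuple(skip.shape) == (2, 64, 100)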
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import math
from typing import Optional
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.wavenet.residual_block import ResidualBlock
class WaveNet(nn.Layer):
"""WaveNet with global conditioning."""
def __init__(
self,
in_channels: int=1,
out_channels: int=1,
kernel_size: int=3,
layers: int=30,
stacks: int=3,
base_dilation: int=2,
residual_channels: int=64,
aux_channels: int=-1,
gate_channels: int=128,
skip_channels: int=64,
global_channels: int=-1,
dropout_rate: float=0.0,
bias: bool=True,
use_weight_norm: bool=True,
use_first_conv: bool=False,
use_last_conv: bool=False,
scale_residual: bool=False,
scale_skip_connect: bool=False, ):
"""Initialize WaveNet module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Kernel size of dilated convolution.
layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
base_dilation (int): Base dilation factor.
residual_channels (int): Number of channels in residual conv.
gate_channels (int): Number of channels in gated conv.
skip_channels (int): Number of channels in skip conv.
aux_channels (int): Number of channels for local conditioning feature.
global_channels (int): Number of channels for global conditioning feature.
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
bias (bool): Whether to use bias parameter in conv layer.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
            use_first_conv (bool): Whether to use the first 1x1 conv layer.
            use_last_conv (bool): Whether to use the last output conv layers.
scale_residual (bool): Whether to scale the residual outputs.
scale_skip_connect (bool): Whether to scale the skip connection outputs.
"""
super().__init__()
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
self.base_dilation = base_dilation
self.use_first_conv = use_first_conv
self.use_last_conv = use_last_conv
self.scale_skip_connect = scale_skip_connect
# check the number of layers and stacks
assert layers % stacks == 0
layers_per_stack = layers // stacks
# define first convolution
if self.use_first_conv:
self.first_conv = nn.Conv1D(
in_channels, residual_channels, kernel_size=1, bias_attr=True)
# define residual blocks
self.conv_layers = nn.LayerList()
for layer in range(layers):
dilation = base_dilation**(layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
global_channels=global_channels,
dilation=dilation,
dropout_rate=dropout_rate,
bias=bias,
scale_residual=scale_residual, )
self.conv_layers.append(conv)
# define output layers
if self.use_last_conv:
self.last_conv = nn.Sequential(
nn.ReLU(),
nn.Conv1D(
skip_channels, skip_channels, kernel_size=1,
bias_attr=True),
nn.ReLU(),
nn.Conv1D(
skip_channels, out_channels, kernel_size=1, bias_attr=True),
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(
self,
x: paddle.Tensor,
x_mask: Optional[paddle.Tensor]=None,
c: Optional[paddle.Tensor]=None,
g: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
(B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
(B, residual_channels, T).
"""
# encode to hidden representation
if self.use_first_conv:
x = self.first_conv(x)
# residual block
skips = 0.0
for f in self.conv_layers:
x, h = f(x, x_mask=x_mask, c=c, g=g)
skips = skips + h
x = skips
if self.scale_skip_connect:
x = x * math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
if self.use_last_conv:
x = self.last_conv(x)
return x
    def apply_weight_norm(self):
        """Recursively apply weight normalization to all the conv layers."""

        def _apply_weight_norm(layer):
            if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
                nn.utils.weight_norm(layer)

        self.apply(_apply_weight_norm)
    def remove_weight_norm(self):
        """Recursively remove weight normalization from all the conv layers."""

        def _remove_weight_norm(layer):
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                pass

        self.apply(_remove_weight_norm)
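# --- Editor's sketch (not part of the original commit) ---
# In VITS this WaveNet is typically used as a feature encoder without the
# first/last convs, so input and output both live in residual/skip channels.
# The layer counts and sizes below are arbitrary assumptions for a smoke test.
if __name__ == "__main__":
    net = WaveNet(
        layers=6,
        stacks=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        global_channels=32, )
    x = paddle.randn([2, 64, 50])  # (B, residual_channels, T)
    g = paddle.randn([2, 32, 1])  # (B, global_channels, 1)
    y = net(x, g=g)
    # without use_last_conv, the summed skip connections are returned directly
    assert tuple(y.shape) == (2, 64, 50)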
......@@ -1006,3 +1006,40 @@ class FeatureMatchLoss(nn.Layer):
feat_match_loss /= i + 1
return feat_match_loss
# loss for VITS
class KLDivergenceLoss(nn.Layer):
"""KL divergence loss."""
def forward(
self,
z_p: paddle.Tensor,
logs_q: paddle.Tensor,
m_p: paddle.Tensor,
logs_p: paddle.Tensor,
z_mask: paddle.Tensor,
) -> paddle.Tensor:
"""Calculate KL divergence loss.
Args:
z_p (Tensor): Flow hidden representation (B, H, T_feats).
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
z_mask (Tensor): Mask tensor (B, 1, T_feats).
Returns:
Tensor: KL divergence loss.
"""
z_p = paddle.cast(z_p, 'float32')
logs_q = paddle.cast(logs_q, 'float32')
m_p = paddle.cast(m_p, 'float32')
logs_p = paddle.cast(logs_p, 'float32')
z_mask = paddle.cast(z_mask, 'float32')
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * paddle.exp(-2.0 * logs_p)
kl = paddle.sum(kl * z_mask)
loss = kl / paddle.sum(z_mask)
return loss
\ No newline at end of file
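# --- Editor's sketch (not part of the original commit) ---
# The loss above is a masked Monte-Carlo estimate of KL(q || p): the cross
# term uses the flow sample z_p against the prior (m_p, logs_p), while the
# entropy of q contributes the analytic constant -0.5. The shapes below are
# arbitrary assumptions for a quick usage example.
if __name__ == "__main__":
    B, H, T = 2, 192, 40
    criterion = KLDivergenceLoss()
    z_p = paddle.randn([B, H, T])
    logs_q = paddle.randn([B, H, T])
    m_p = paddle.randn([B, H, T])
    logs_p = paddle.randn([B, H, T])
    z_mask = paddle.ones([B, 1, T])  # all frames valid
    loss = criterion(z_p, logs_q, m_p, logs_p, z_mask)
    print(float(loss))  # scalar, normalized by the number of valid frames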
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Tuple
import paddle
from paddle import nn
from typeguard import check_argument_types
......@@ -129,3 +131,66 @@ def initialize(model: nn.Layer, init: str):
nn.initializer.Constant())
else:
raise ValueError("Unknown initialization: " + init)
# for VITS
def get_random_segments(
        x: paddle.Tensor,
x_lengths: paddle.Tensor,
segment_size: int, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Get random segments.
Args:
x (Tensor): Input tensor (B, C, T).
x_lengths (Tensor): Length tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
Tensor: Start index tensor (B,).
"""
b, c, t = paddle.shape(x)
max_start_idx = x_lengths - segment_size
start_idxs = paddle.cast(paddle.rand([b]) * max_start_idx, 'int64')
segments = get_segments(x, start_idxs, segment_size)
return segments, start_idxs
def get_segments(
x: paddle.Tensor,
start_idxs: paddle.Tensor,
segment_size: int, ) -> paddle.Tensor:
"""Get segments.
Args:
x (Tensor): Input tensor (B, C, T).
start_idxs (Tensor): Start index tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
"""
b, c, t = paddle.shape(x)
segments = paddle.zeros([b, c, segment_size], dtype=x.dtype)
for i, start_idx in enumerate(start_idxs):
segments[i] = x[i, :, start_idx:start_idx + segment_size]
return segments
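# --- Editor's sketch (not part of the original commit) ---
# get_random_segments draws one start index per batch item, uniform over the
# valid range given its length, and get_segments crops fixed-size windows.
# The toy tensors below are arbitrary assumptions.
if __name__ == "__main__":
    x = paddle.arange(20, dtype='float32').reshape([2, 1, 10])  # (B, C, T)
    x_lengths = paddle.to_tensor([10, 8], dtype='int64')
    segs, starts = get_random_segments(x, x_lengths, segment_size=4)
    assert tuple(segs.shape) == (2, 1, 4)
    # deterministic crops from explicit start indices
    fixed = get_segments(x, paddle.to_tensor([0, 2], dtype='int64'), 4)
    print(fixed[0, 0].numpy())  # [0. 1. 2. 3.]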
# see https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
def paddle_gather(x: paddle.Tensor, dim: int,
                  index: paddle.Tensor) -> paddle.Tensor:
    """Paddle equivalent of torch.gather, built on paddle.gather_nd."""
index_shape = index.shape
index_flatten = index.flatten()
if dim < 0:
dim = len(x.shape) + dim
nd_index = []
for k in range(len(x.shape)):
if k == dim:
nd_index.append(index_flatten)
else:
reshape_shape = [1] * len(x.shape)
reshape_shape[k] = x.shape[k]
x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
x_arange = x_arange.reshape(reshape_shape)
dim_index = paddle.expand(x_arange, index_shape).flatten()
nd_index.append(dim_index)
ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
return paddle_out
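# --- Editor's sketch (not part of the original commit) ---
# paddle_gather mimics torch.gather: for dim=1, out[i][j] = x[i][index[i][j]].
# A small hand-checkable example:
if __name__ == "__main__":
    x = paddle.to_tensor([[1., 2.], [3., 4.]])
    index = paddle.to_tensor([[0, 0], [1, 0]], dtype='int64')
    out = paddle_gather(x, 1, index)
    print(out.numpy())  # [[1. 1.] [4. 3.]]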