PaddlePaddle / DeepSpeech
Commit 0a2e367f (unverified)
Authored Mar 21, 2023 by 小湉湉; committed via GitHub on Mar 21, 2023
Parent: 880c172d

[TTS]clean starganv2 vc model code and add docstring (#2987)

* clean code
* add docstring
Showing 4 changed files with 176 additions and 433 deletions (+176 -433)
paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/layers.py   +5 -223
paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py    +3 -4
paddlespeech/t2s/models/starganv2_vc/JDCNet/model.py          +51 -75
paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py          +117 -131
paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/layers.py
...
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import paddle
import paddle.nn.functional as F
import paddleaudio.functional as audio_F
...
@@ -46,7 +44,8 @@ class LinearNorm(nn.Layer):
            self.linear_layer.weight, gain=_calculate_gain(w_init_gain))

    def forward(self, x: paddle.Tensor):
-       return self.linear_layer(x)
+       out = self.linear_layer(x)
+       return out


class ConvNorm(nn.Layer):
...
@@ -82,85 +81,6 @@ class ConvNorm(nn.Layer):
        return conv_signal


class CausualConv(nn.Layer):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int=1,
                 stride: int=1,
                 padding: int=1,
                 dilation: int=1,
                 bias: bool=True,
                 w_init_gain: str='linear',
                 param=None):
        super().__init__()
        if padding is None:
            assert (kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2) * 2
        else:
            self.padding = padding * 2
        self.conv = nn.Conv1D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            bias_attr=bias)
        xavier_uniform_(
            self.conv.weight, gain=_calculate_gain(w_init_gain, param=param))

    def forward(self, x: paddle.Tensor):
        x = self.conv(x)
        x = x[:, :, :-self.padding]
        return x
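A note on the causal convolution above: the layer pads the input and then slices `self.padding` samples off the right end, so each output frame depends only on current and past frames. A minimal, self-contained sketch of the same trick (hypothetical sizes, not the class's exact padding arithmetic):

import paddle
import paddle.nn as nn

# hypothetical sizes for illustration
B, C, T = 2, 8, 100
kernel_size, dilation = 3, 1
pad = dilation * (kernel_size - 1)   # pad by the kernel's receptive "look-back"

conv = nn.Conv1D(C, C, kernel_size, padding=pad, dilation=dilation)
x = paddle.randn([B, C, T])
y = conv(x)[:, :, :-pad]             # trim the rightmost frames -> causal output
print(y.shape)                       # [2, 8, 100]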
class CausualBlock(nn.Layer):
    def __init__(self,
                 hidden_dim: int,
                 n_conv: int=3,
                 dropout_p: float=0.2,
                 activ: str='lrelu'):
        super().__init__()
        self.blocks = nn.LayerList([
            self._get_conv(
                hidden_dim=hidden_dim,
                dilation=3**i,
                activ=activ,
                dropout_p=dropout_p) for i in range(n_conv)
        ])

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self,
                  hidden_dim: int,
                  dilation: int,
                  activ: str='lrelu',
                  dropout_p: float=0.2):
        layers = [
            CausualConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation),
            _get_activation_fn(activ),
            nn.BatchNorm1D(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                padding=1,
                dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)
class ConvBlock(nn.Layer):
    def __init__(self,
                 hidden_dim: int,
...
@@ -264,13 +184,14 @@ class Attention(nn.Layer):
        """
        Args:
            query:
-               decoder output (batch, n_mel_channels * n_frames_per_step)
+               decoder output (B, n_mel_channels * n_frames_per_step)
            processed_memory:
                processed encoder outputs (B, T_in, attention_dim)
            attention_weights_cat:
                cumulative and prev. att weights (B, 2, max_time)
        Returns:
-           Tensor: alignment (batch, max_time)
+           Tensor:
+               alignment (B, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
...
@@ -316,144 +237,6 @@ class Attention(nn.Layer):
        return attention_context, attention_weights


class ForwardAttentionV2(nn.Layer):
    def __init__(self,
                 attention_rnn_dim: int,
                 embedding_dim: int,
                 attention_dim: int,
                 attention_location_n_filters: int,
                 attention_location_kernel_size: int):
        super().__init__()
        self.query_layer = LinearNorm(
            in_dim=attention_rnn_dim,
            out_dim=attention_dim,
            bias=False,
            w_init_gain='tanh')
        self.memory_layer = LinearNorm(
            in_dim=embedding_dim,
            out_dim=attention_dim,
            bias=False,
            w_init_gain='tanh')
        self.v = LinearNorm(in_dim=attention_dim, out_dim=1, bias=False)
        self.location_layer = LocationLayer(
            attention_n_filters=attention_location_n_filters,
            attention_kernel_size=attention_location_kernel_size,
            attention_dim=attention_dim)
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self,
                               query: paddle.Tensor,
                               processed_memory: paddle.Tensor,
                               attention_weights_cat: paddle.Tensor):
        """
        Args:
            query:
                decoder output (batch, n_mel_channels * n_frames_per_step)
            processed_memory:
                processed encoder outputs (B, T_in, attention_dim)
            attention_weights_cat:
                prev. and cumulative att weights (B, 2, max_time)
        Returns:
            Tensor: alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(
            paddle.tanh(processed_query + processed_attention_weights +
                        processed_memory))
        energies = energies.squeeze(-1)
        return energies

    def forward(self,
                attention_hidden_state: paddle.Tensor,
                memory: paddle.Tensor,
                processed_memory: paddle.Tensor,
                attention_weights_cat: paddle.Tensor,
                mask: paddle.Tensor,
                log_alpha: paddle.Tensor):
        """
        Args:
            attention_hidden_state:
                attention rnn last output
            memory:
                encoder outputs
            processed_memory:
                processed encoder outputs
            attention_weights_cat:
                previous and cumulative attention weights
            mask:
                binary mask for padded data
        """
        log_energy = self.get_alignment_energies(
            query=attention_hidden_state,
            processed_memory=processed_memory,
            attention_weights_cat=attention_weights_cat)
        if mask is not None:
            log_energy[:] = paddle.where(
                mask,
                paddle.full(log_energy.shape, self.score_mask_value,
                            log_energy.dtype), log_energy)
        log_alpha_shift_padded = []
        max_time = log_energy.shape[1]
        for sft in range(2):
            shifted = log_alpha[:, :max_time - sft]
            shift_padded = F.pad(shifted, (sft, 0), 'constant',
                                 self.score_mask_value)
            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
        biased = paddle.logsumexp(paddle.concat(log_alpha_shift_padded, 2), 2)
        log_alpha_new = biased + log_energy
        attention_weights = F.softmax(log_alpha_new, axis=1)
        attention_context = paddle.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)
        return attention_context, attention_weights, log_alpha_new
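In the forward pass above, `log_alpha` is the running log-alignment; at each decoder step the alignment may either stay at the same encoder position or advance by one, and the two shifted copies are merged with `logsumexp` before the new energies are added. A small numerical sketch of that recursion (hypothetical sizes, written independently of the class):

import paddle
import paddle.nn.functional as F

B, T = 2, 6                                    # hypothetical batch size / encoder length
neg_inf = -float(1e20)

log_alpha = paddle.full([B, T], neg_inf, dtype='float32')
log_alpha[:, 0] = 0.0                          # alignment starts at the first position
log_energy = paddle.randn([B, T])              # stand-in for the attention energies

# candidate 1: stay at the same position; candidate 2: advance by one position
stay = log_alpha
move = paddle.concat(
    [paddle.full([B, 1], neg_inf, dtype='float32'), log_alpha[:, :T - 1]], axis=1)
biased = paddle.logsumexp(paddle.stack([stay, move], axis=2), axis=2)

log_alpha_new = biased + log_energy
attention_weights = F.softmax(log_alpha_new, axis=1)   # (B, T)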
class PhaseShuffle2D(nn.Layer):
    def __init__(self, n: int=2):
        super().__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x: paddle.Tensor, move: int=None):
        # x.size = (B, C, M, L)
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :, :move]
            right = x[:, :, :, move:]
            shuffled = paddle.concat([right, left], axis=3)
        return shuffled


class PhaseShuffle1D(nn.Layer):
    def __init__(self, n: int=2):
        super().__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x: paddle.Tensor, move: int=None):
        # x.size = (B, C, M, L)
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :move]
            right = x[:, :, move:]
            shuffled = paddle.concat([right, left], axis=2)
        return shuffled
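Both phase-shuffle layers split the input at a random offset along the last axis and swap the two halves, which is the same as a circular shift of that axis. A minimal check of the equivalence (hypothetical sizes):

import paddle

x = paddle.arange(12, dtype='float32').reshape([1, 2, 6])   # (B, C, L)
move = 2

# split-and-swap, as in PhaseShuffle1D
left, right = x[:, :, :move], x[:, :, move:]
shuffled = paddle.concat([right, left], axis=2)

# the same result expressed as a circular shift
rolled = paddle.roll(x, shifts=-move, axis=2)
print(bool(paddle.all(shuffled == rolled)))                 # True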
class MFCC(nn.Layer):
    def __init__(self, n_mfcc: int=40, n_mels: int=80):
        super().__init__()
...
@@ -473,7 +256,6 @@ class MFCC(nn.Layer):
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = paddle.matmul(mel_specgram.transpose([0, 2, 1]),
                             self.dct_mat).transpose([0, 2, 1])

        # unpack batch
        if unsqueezed:
            mfcc = mfcc.squeeze(0)
...
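The MFCC step above is a single matrix product between the mel spectrogram and a fixed DCT basis held in `self.dct_mat` (built in the constructor from `paddleaudio.functional`). A self-contained sketch with an explicitly constructed orthonormal DCT-II matrix (hypothetical sizes; not the module's own helper):

import math

import paddle

n_mels, n_mfcc, T = 80, 40, 120                            # hypothetical sizes

# orthonormal DCT-II basis of shape (n_mels, n_mfcc)
n = paddle.arange(n_mels, dtype='float32').unsqueeze(1)    # (n_mels, 1)
k = paddle.arange(n_mfcc, dtype='float32').unsqueeze(0)    # (1, n_mfcc)
dct_mat = paddle.cos(math.pi / n_mels * (n + 0.5) * k) * math.sqrt(2.0 / n_mels)
dct_mat[:, 0] = dct_mat[:, 0] / math.sqrt(2.0)             # rescale the DC column

mel_specgram = paddle.randn([1, n_mels, T])                # (channel, n_mels, time)
mfcc = paddle.matmul(mel_specgram.transpose([0, 2, 1]),
                     dct_mat).transpose([0, 2, 1])         # (channel, n_mfcc, time)
print(mfcc.shape)                                          # [1, 40, 120]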
paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py
...
@@ -99,7 +99,7 @@ class ASRCNN(nn.Layer):
            unmask_future_steps (int):
                unmasking future step size.
        Return:
-           mask (paddle.BoolTensor):
+           Tensor (paddle.Tensor(bool)):
                mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
        """
        index_tensor = paddle.arange(out_length).unsqueeze(0).expand(
...
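The docstring above defines the rule mask[i, j] = True iff i > j + unmask_future_steps, i.e. each output step may only attend a limited number of steps into the future. A minimal sketch of building that mask with broadcasting (hypothetical sizes; it mirrors the documented rule rather than the exact implementation):

import paddle

out_length, unmask_future_steps = 5, 1
i = paddle.arange(out_length).unsqueeze(1)   # row index, shape (out_length, 1)
j = paddle.arange(out_length).unsqueeze(0)   # column index, shape (1, out_length)

# mask[i, j] = True if i > j + unmask_future_steps else False
mask = i > j + unmask_future_steps
print(mask.astype('int32'))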
@@ -194,9 +194,8 @@ class ASRS2S(nn.Layer):
            logit_outputs += [logit]
            alignments += [attention_weights]

-       hidden_outputs, logit_outputs, alignments = \
-           self.parse_decoder_outputs(hidden_outputs, logit_outputs, alignments)
+       hidden_outputs, logit_outputs, alignments = self.parse_decoder_outputs(
+           hidden_outputs, logit_outputs, alignments)

        return hidden_outputs, logit_outputs, alignments
...
paddlespeech/t2s/models/starganv2_vc/JDCNet/model.py
...
@@ -33,10 +33,9 @@ class JDCNet(nn.Layer):
        super().__init__()
        self.seq_len = seq_len
        self.num_class = num_class

-       # input = (b, 1, 31, 513), b = batch size
+       # input: (B, num_class, T, n_mels)
        self.conv_block = nn.Sequential(
-           # out: (b, 64, 31, 513)
+           # output: (B, out_channels, T, n_mels)
            nn.Conv2D(
                in_channels=1,
                out_channels=64,
...
@@ -45,127 +44,99 @@ class JDCNet(nn.Layer):
                bias_attr=False),
            nn.BatchNorm2D(num_features=64),
            nn.LeakyReLU(leaky_relu_slope),
-           # (b, 64, 31, 513)
+           # out: (B, out_channels, T, n_mels)
            nn.Conv2D(64, 64, 3, padding=1, bias_attr=False), )

        # res blocks
-       # (b, 128, 31, 128)
+       # output: (B, out_channels, T, n_mels // 2)
        self.res_block1 = ResBlock(in_channels=64, out_channels=128)
-       # (b, 192, 31, 32)
+       # output: (B, out_channels, T, n_mels // 4)
        self.res_block2 = ResBlock(in_channels=128, out_channels=192)
-       # (b, 256, 31, 8)
+       # output: (B, out_channels, T, n_mels // 8)
        self.res_block3 = ResBlock(in_channels=192, out_channels=256)

        # pool block
        self.pool_block = nn.Sequential(
            nn.BatchNorm2D(num_features=256),
            nn.LeakyReLU(leaky_relu_slope),
-           # (b, 256, 31, 2)
+           # (B, num_features, T, 2)
            nn.MaxPool2D(kernel_size=(1, 4)),
            nn.Dropout(p=0.5), )

        # maxpool layers (for auxiliary network inputs)
        # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
        self.maxpool1 = nn.MaxPool2D(kernel_size=(1, 40))
        # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
        self.maxpool2 = nn.MaxPool2D(kernel_size=(1, 20))
        # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
        self.maxpool3 = nn.MaxPool2D(kernel_size=(1, 10))

        # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
        self.detector_conv = nn.Sequential(
            nn.Conv2D(
                in_channels=640,
                out_channels=256,
                kernel_size=1,
                bias_attr=False),
            nn.BatchNorm2D(256),
            nn.LeakyReLU(leaky_relu_slope),
            nn.Dropout(p=0.5), )

-       # input: (b, 31, 512) - resized from (b, 256, 31, 2)
-       # output: (b, 31, 512)
+       # input: (B, T, input_size), resized from (B, input_size // 2, T, 2)
+       # output: (B, T, input_size)
        self.bilstm_classifier = nn.LSTM(
            input_size=512,
            hidden_size=256,
            time_major=False,
            direction='bidirectional')

        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
        # output: (b, 31, 512)
        self.bilstm_detector = nn.LSTM(
            input_size=512,
            hidden_size=256,
            time_major=False,
            direction='bidirectional')

-       # input: (b * 31, 512)
-       # output: (b * 31, num_class)
+       # input: (B * T, in_features)
+       # output: (B * T, num_class)
        self.classifier = nn.Linear(
            in_features=512, out_features=self.num_class)

        # input: (b * 31, 512)
        # output: (b * 31, 2) - binary classifier
        self.detector = nn.Linear(in_features=512, out_features=2)

        # initialize weights
        self.apply(self.init_weights)

    def get_feature_GAN(self, x: paddle.Tensor):
-       seq_len = x.shape[-2]
-       x = x.astype(paddle.float32).transpose(
-           [0, 1, 3, 2] if len(x.shape) == 4 else [0, 2, 1])
+       """Calculate feature_GAN.
+       Args:
+           x(Tensor(float32)):
+               Shape (B, num_class, n_mels, T).
+       Returns:
+           Tensor:
+               Shape (B, num_features, n_mels // 8, T).
+       """
+       x = x.astype(paddle.float32)
+       x = x.transpose([0, 1, 3, 2] if len(x.shape) == 4 else [0, 2, 1])

        convblock_out = self.conv_block(x)
        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

-       return poolblock_out.transpose(
-           [0, 1, 3, 2] if len(poolblock_out.shape) == 4 else [0, 2, 1])
+       GAN_feature = poolblock_out.transpose(
+           [0, 1, 3, 2] if len(poolblock_out.shape) == 4 else [0, 2, 1])
+       return GAN_feature

    def forward(self, x: paddle.Tensor):
-       """
+       """Calculate forward propagation.
        Args:
            x(Tensor(float32)):
                Shape (B, num_class, n_mels, seq_len).
        Returns:
-           classification_prediction, detection_prediction
-           sizes: (b, 31, 722), (b, 31, 2)
+           Tensor:
+               classifier output consists of predicted pitch classes per frame.
+               Shape: (B, seq_len, num_class).
+           Tensor:
+               GAN_feature. Shape: (B, num_features, n_mels // 8, seq_len)
+           Tensor:
+               poolblock_out. Shape (B, seq_len, 512)
        """
        ###############################
        # forward pass for classifier #
        ###############################
        # (B, num_class, n_mels, T) -> (B, num_class, T, n_mels)
        x = x.transpose(
            [0, 1, 3, 2] if len(x.shape) == 4 else [0, 2, 1]).astype(paddle.float32)

        convblock_out = self.conv_block(x)
        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)
        GAN_feature = poolblock_out.transpose(
            [0, 1, 3, 2] if len(poolblock_out.shape) == 4 else [0, 2, 1])
        poolblock_out = self.pool_block[2](poolblock_out)

-       # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
+       # (B, 256, seq_len, 2) => (B, seq_len, 256, 2) => (B, seq_len, 512)
        classifier_out = poolblock_out.transpose([0, 2, 1, 3]).reshape(
            (-1, self.seq_len, 512))
-       self.bilstm_classifier.flatten_parameters()
-       classifier_out, _ = self.bilstm_classifier(
-           classifier_out)  # ignore the hidden states
-       classifier_out = classifier_out.reshape((-1, 512))  # (b * 31, 512)
+       # ignore the hidden states
+       classifier_out, _ = self.bilstm_classifier(classifier_out)
+       # (B * seq_len, 512)
+       classifier_out = classifier_out.reshape((-1, 512))

        classifier_out = self.classifier(classifier_out)
-       classifier_out = classifier_out.reshape(
-           (-1, self.seq_len, self.num_class))  # (b, 31, num_class)
-
-       # sizes: (b, 31, 722), (b, 31, 2)
-       # classifier output consists of predicted pitch classes per frame
-       # detector output consists of: (isvoice, notvoice) estimates per frame
+       # (B, seq_len, num_class)
+       classifier_out = classifier_out.reshape(
+           (-1, self.seq_len, self.num_class))

        return paddle.abs(classifier_out.squeeze()), GAN_feature, poolblock_out

    @staticmethod
...
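The classifier head in `forward` above is mostly shape bookkeeping: (B, 256, seq_len, 2) is rearranged to (B, seq_len, 512), passed through the BiLSTM, flattened to (B * seq_len, 512) for the Linear layer, and folded back to (B, seq_len, num_class). A shape-only sketch of that pipeline (hypothetical sizes and freshly created layers, not the model's own):

import paddle
import paddle.nn as nn

B, T, num_class = 2, 31, 722                      # hypothetical sizes
poolblock_out = paddle.randn([B, 256, T, 2])      # pooled feature map

x = poolblock_out.transpose([0, 2, 1, 3]).reshape((-1, T, 512))   # (B, T, 512)
bilstm = nn.LSTM(input_size=512, hidden_size=256,
                 time_major=False, direction='bidirectional')
x, _ = bilstm(x)                                  # (B, T, 512); hidden states ignored
x = x.reshape((-1, 512))                          # (B * T, 512)
x = nn.Linear(512, num_class)(x)                  # (B * T, num_class)
x = x.reshape((-1, T, num_class))                 # (B, T, num_class)
print(x.shape)                                    # [2, 31, 722]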
@@ -188,10 +159,9 @@ class ResBlock(nn.Layer):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
-                leaky_relu_slope=0.01):
+                leaky_relu_slope: float=0.01):
        super().__init__()
        self.downsample = in_channels != out_channels

        # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
        self.pre_conv = nn.Sequential(
            nn.BatchNorm2D(num_features=in_channels),
...
@@ -215,7 +185,6 @@ class ResBlock(nn.Layer):
                kernel_size=3,
                padding=1,
                bias_attr=False), )

        # 1 x 1 convolution layer to match the feature dimensions
        self.conv1by1 = None
        if self.downsample:
...
@@ -226,6 +195,13 @@ class ResBlock(nn.Layer):
                bias_attr=False)

    def forward(self, x: paddle.Tensor):
+       """Calculate forward propagation.
+       Args:
+           x(Tensor(float32)): Shape (B, in_channels, T, n_mels).
+       Returns:
+           Tensor:
+               The residual output, Shape (B, out_channels, T, n_mels // 2).
+       """
        x = self.pre_conv(x)
        if self.downsample:
            x = self.conv(x) + self.conv1by1(x)
...
paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py
...
@@ -19,31 +19,36 @@ This work is licensed under the Creative Commons Attribution-NonCommercial
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""
# import copy
import math

import paddle
import paddle.nn.functional as F
from paddle import nn

from paddlespeech.utils.initialize import _calculate_gain
from paddlespeech.utils.initialize import xavier_uniform_

# from munch import Munch


class DownSample(nn.Layer):
    def __init__(self, layer_type: str):
        super().__init__()
        self.layer_type = layer_type

-   def forward(self, x):
+   def forward(self, x: paddle.Tensor):
        """Calculate forward propagation.
        Args:
            x(Tensor(float32)): Shape (B, dim_in, n_mels, T).
        Returns:
            Tensor:
                layer_type == 'none': Shape (B, dim_in, n_mels, T)
                layer_type == 'timepreserve': Shape (B, dim_in, n_mels // 2, T)
                layer_type == 'half': Shape (B, dim_in, n_mels // 2, T // 2)
        """
        if self.layer_type == 'none':
            return x
        elif self.layer_type == 'timepreserve':
-           return F.avg_pool2d(x, (2, 1))
+           out = F.avg_pool2d(x, (2, 1))
+           return out
        elif self.layer_type == 'half':
-           return F.avg_pool2d(x, 2)
+           out = F.avg_pool2d(x, 2)
+           return out
        else:
            raise RuntimeError(
                'Got unexpected downsampletype %s, expected is [none, timepreserve, half]'
...
@@ -55,13 +60,24 @@ class UpSample(nn.Layer):
        super().__init__()
        self.layer_type = layer_type

-   def forward(self, x):
+   def forward(self, x: paddle.Tensor):
        """Calculate forward propagation.
        Args:
            x(Tensor(float32)): Shape (B, dim_in, n_mels, T).
        Returns:
            Tensor:
                layer_type == 'none': Shape (B, dim_in, n_mels, T)
                layer_type == 'timepreserve': Shape (B, dim_in, n_mels * 2, T)
                layer_type == 'half': Shape (B, dim_in, n_mels * 2, T * 2)
        """
        if self.layer_type == 'none':
            return x
        elif self.layer_type == 'timepreserve':
-           return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
+           out = F.interpolate(x, scale_factor=(2, 1), mode='nearest')
+           return out
        elif self.layer_type == 'half':
-           return F.interpolate(x, scale_factor=2, mode='nearest')
+           out = F.interpolate(x, scale_factor=2, mode='nearest')
+           return out
        else:
            raise RuntimeError(
                'Got unexpected upsampletype %s, expected is [none, timepreserve, half]'
...
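A quick shape check of the two layer types (hypothetical input size): 'timepreserve' pools or upsamples only the frequency axis, while 'half' rescales both frequency and time.

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 64, 80, 128])    # (B, dim_in, n_mels, T), hypothetical sizes

print(F.avg_pool2d(x, (2, 1)).shape)  # [2, 64, 40, 128]  DownSample 'timepreserve'
print(F.avg_pool2d(x, 2).shape)       # [2, 64, 40, 64]   DownSample 'half'
print(F.interpolate(x, scale_factor=(2, 1), mode='nearest').shape)  # [2, 64, 160, 128]
print(F.interpolate(x, scale_factor=2, mode='nearest').shape)       # [2, 64, 160, 256]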
@@ -127,9 +143,19 @@ class ResBlk(nn.Layer):
        return x

    def forward(self, x: paddle.Tensor):
        """Calculate forward propagation.
        Args:
            x(Tensor(float32)): Shape (B, dim_in, n_mels, T).
        Returns:
            Tensor:
                downsample == 'none': Shape (B, dim_in, n_mels, T).
                downsample == 'timepreserve': Shape (B, dim_out, n_mels // 2, T).
                downsample == 'half': Shape (B, dim_out, n_mels // 2, T // 2).
        """
        x = self._shortcut(x) + self._residual(x)
        # unit variance
-       return x / math.sqrt(2)
+       out = x / math.sqrt(2)
+       return out


class AdaIN(nn.Layer):
...
@@ -140,12 +166,21 @@ class AdaIN(nn.Layer):
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x: paddle.Tensor, s: paddle.Tensor):
        """Calculate forward propagation.
        Args:
            x(Tensor(float32)): Shape (B, style_dim, n_mels, T).
            s(Tensor(float32)): Shape (style_dim, ).
        Returns:
            Tensor:
                Shape (B, style_dim, n_mels, T).
        """
        if len(s.shape) == 1:
            s = s[None]
        h = self.fc(s)
        h = h.reshape((h.shape[0], h.shape[1], 1, 1))
        gamma, beta = paddle.split(h, 2, axis=1)
-       return (1 + gamma) * self.norm(x) + beta
+       out = (1 + gamma) * self.norm(x) + beta
+       return out
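AdaIN conditions the normalized feature map on a style vector: a Linear layer maps the style to per-channel gamma/beta, which scale and shift an instance-normalized input. A standalone sketch of that computation (hypothetical sizes; `InstanceNorm2D` stands in for `self.norm`):

import paddle
import paddle.nn as nn

B, C, H, W, style_dim = 2, 64, 80, 128, 64        # hypothetical sizes

norm = nn.InstanceNorm2D(C, weight_attr=False, bias_attr=False)
fc = nn.Linear(style_dim, C * 2)

x = paddle.randn([B, C, H, W])
s = paddle.randn([B, style_dim])

h = fc(s).reshape((B, C * 2, 1, 1))
gamma, beta = paddle.split(h, 2, axis=1)          # each (B, C, 1, 1)
out = (1 + gamma) * norm(x) + beta                # style-dependent scale and shift
print(out.shape)                                  # [2, 64, 80, 128]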
class AdainResBlk(nn.Layer):
...
@@ -162,6 +197,7 @@ class AdainResBlk(nn.Layer):
        self.upsample = UpSample(layer_type=upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
+       self.layer_type = upsample

    def _build_weights(self, dim_in: int, dim_out: int, style_dim: int=64):
        self.conv1 = nn.Conv2D(
...
@@ -204,6 +240,18 @@ class AdainResBlk(nn.Layer):
        return x

    def forward(self, x: paddle.Tensor, s: paddle.Tensor):
        """Calculate forward propagation.
        Args:
            x(Tensor(float32)):
                Shape (B, dim_in, n_mels, T).
            s(Tensor(float32)):
                Shape (64,).
        Returns:
            Tensor:
                upsample == 'none': Shape (B, dim_out, n_mels, T).
                upsample == 'timepreserve': Shape (B, dim_out, n_mels * 2, T).
                upsample == 'half': Shape (B, dim_out, n_mels * 2, T * 2).
        """
        out = self._residual(x, s)
        if self.w_hpf == 0:
            out = (out + self._shortcut(x)) / math.sqrt(2)
...
@@ -219,7 +267,8 @@ class HighPass(nn.Layer):
    def forward(self, x: paddle.Tensor):
        filter = self.filter.unsqueeze(0).unsqueeze(1).tile(
            [x.shape[1], 1, 1, 1])
-       return F.conv2d(x, filter, padding=1, groups=x.shape[1])
+       out = F.conv2d(x, filter, padding=1, groups=x.shape[1])
+       return out
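HighPass applies one fixed kernel to every channel independently by tiling it to shape (C, 1, 3, 3) and running a grouped (depthwise) convolution with groups=C. A minimal sketch with an illustrative Laplacian-style kernel (the layer's real `self.filter` is set in its constructor, which is not shown in this hunk):

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 64, 80, 128])                       # (B, C, n_mels, T), hypothetical

kernel = paddle.to_tensor([[-1.0, -1.0, -1.0],
                           [-1.0, 8.0, -1.0],
                           [-1.0, -1.0, -1.0]]) / 8.0    # illustrative high-pass kernel
weight = kernel.unsqueeze(0).unsqueeze(1).tile([x.shape[1], 1, 1, 1])   # (C, 1, 3, 3)

out = F.conv2d(x, weight, padding=1, groups=x.shape[1])  # depthwise filtering
print(out.shape)                                         # [2, 64, 80, 128]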
class Generator(nn.Layer):
...
@@ -276,12 +325,10 @@ class Generator(nn.Layer):
                    w_hpf=w_hpf,
                    upsample=_downtype))
            # stack-like
            dim_in = dim_out

        # bottleneck blocks (encoder)
        for _ in range(2):
            self.encode.append(
                ResBlk(dim_in=dim_out, dim_out=dim_out, normalize=True))

        # F0 blocks
        if F0_channel != 0:
            self.decode.insert(0,
...
@@ -290,7 +337,6 @@ class Generator(nn.Layer):
                    dim_out=dim_out,
                    style_dim=style_dim,
                    w_hpf=w_hpf))

        # bottleneck blocks (decoder)
        for _ in range(2):
            self.decode.insert(0,
...
@@ -299,7 +345,6 @@ class Generator(nn.Layer):
                    dim_out=dim_out + int(F0_channel / 2),
                    style_dim=style_dim,
                    w_hpf=w_hpf))

        if F0_channel != 0:
            self.F0_conv = nn.Sequential(
                ResBlk(
...
@@ -307,7 +352,6 @@ class Generator(nn.Layer):
                    dim_out=int(F0_channel / 2),
                    normalize=True,
                    downsample="half"), )

        if w_hpf > 0:
            self.hpf = HighPass(w_hpf)
...
@@ -316,26 +360,44 @@ class Generator(nn.Layer):
                s: paddle.Tensor,
                masks: paddle.Tensor=None,
                F0: paddle.Tensor=None):
        """Calculate forward propagation.
        Args:
            x(Tensor(float32)):
                Shape (B, 1, n_mels, T).
            s(Tensor(float32)):
                Shape (64,).
            masks:
                None.
            F0:
                Shape (B, num_features(256), n_mels // 8, T).
        Returns:
            Tensor:
                output of generator. Shape (B, 1, n_mels, T // 4 * 4)
        """
        x = self.stem(x)
        cache = {}
        # output: (B, max_conv_dim, n_mels // 16, T // 4)
        for block in self.encode:
            if (masks is not None) and (x.shape[2] in [32, 64, 128]):
                cache[x.shape[2]] = x
            x = block(x)

        if F0 is not None:
            # input: (B, num_features(256), n_mels // 8, T)
            # output: (B, num_features(256) // 2, n_mels // 16, T // 2)
            F0 = self.F0_conv(F0)
            # output: (B, num_features(256) // 2, n_mels // 16, T // 4)
            F0 = F.adaptive_avg_pool2d(F0, [x.shape[-2], x.shape[-1]])
            x = paddle.concat([x, F0], axis=1)

        # input: (B, max_conv_dim+num_features(256) // 2, n_mels // 16, T // 4 * 4)
        # output: (B, dim_in, n_mels, T // 4 * 4)
        for block in self.decode:
            x = block(x, s)
            if (masks is not None) and (x.shape[2] in [32, 64, 128]):
                mask = masks[0] if x.shape[2] in [32] else masks[1]
                mask = F.interpolate(mask, size=x.shape[2], mode='bilinear')
                x = x + self.hpf(mask * cache[x.shape[2]])

-       return self.to_out(x)
+       out = self.to_out(x)
+       return out
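In the F0 branch above, the pitch features are first compressed by `F0_conv` and then forced onto the encoder output's spatial grid with `adaptive_avg_pool2d`, so the two feature maps can be concatenated along channels. A shape-only sketch of that alignment step (hypothetical sizes):

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 512, 5, 48])    # encoder output (B, max_conv_dim, n_mels // 16, T // 4), hypothetical
F0 = paddle.randn([2, 128, 10, 96])  # pitch features after F0_conv, still on a finer grid

# resample the pitch features to the encoder grid, then concatenate channel-wise
F0 = F.adaptive_avg_pool2d(F0, [x.shape[-2], x.shape[-1]])   # [2, 128, 5, 48]
x = paddle.concat([x, F0], axis=1)                           # [2, 640, 5, 48]
print(x.shape)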
class MappingNetwork(nn.Layer):
...
@@ -366,14 +428,25 @@ class MappingNetwork(nn.Layer):
        ])

    def forward(self, z: paddle.Tensor, y: paddle.Tensor):
        """Calculate forward propagation.
        Args:
            z(Tensor(float32)):
                Shape (B, 1, n_mels, T).
            y(Tensor(float32)):
                speaker label. Shape (B, ).
        Returns:
            Tensor:
                Shape (style_dim, )
        """
        h = self.shared(z)
        out = []
        for layer in self.unshared:
            out += [layer(h)]
-       # (batch, num_domains, style_dim)
+       # (B, num_domains, style_dim)
        out = paddle.stack(out, axis=1)
        idx = paddle.arange(y.shape[0])
-       # (batch, style_dim)
+       # (style_dim, )
        s = out[idx, y]
        return s
...
@@ -419,15 +492,25 @@ class StyleEncoder(nn.Layer):
            self.unshared.append(nn.Linear(dim_out, style_dim))

    def forward(self, x: paddle.Tensor, y: paddle.Tensor):
        """Calculate forward propagation.
        Args:
            x(Tensor(float32)):
                Shape (B, 1, n_mels, T).
            y(Tensor(float32)):
                speaker label. Shape (B, ).
        Returns:
            Tensor:
                Shape (style_dim, )
        """
        h = self.shared(x)
        h = h.reshape((h.shape[0], -1))
        out = []
        for layer in self.unshared:
            out += [layer(h)]
-       # (batch, num_domains, style_dim)
+       # (B, num_domains, style_dim)
        out = paddle.stack(out, axis=1)
        idx = paddle.arange(y.shape[0])
-       # (batch, style_dim)
+       # (style_dim, )
        s = out[idx, y]
        return s
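In both MappingNetwork and StyleEncoder, every domain (speaker) has its own unshared head; stacking the heads gives a (B, num_domains, style_dim) tensor, and `out[idx, y]` then picks, for each sample, the style vector of that sample's own domain. A minimal sketch of that gather (hypothetical sizes):

import paddle

B, num_domains, style_dim = 4, 10, 64            # hypothetical sizes
out = paddle.randn([B, num_domains, style_dim])  # one candidate style per domain
y = paddle.to_tensor([3, 0, 7, 7])               # domain (speaker) label per sample

idx = paddle.arange(y.shape[0])
s = out[idx, y]                                  # row i takes the style of domain y[i]
print(s.shape)                                   # [4, 64]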
...
@@ -454,25 +537,12 @@ class Discriminator(nn.Layer):
        self.num_domains = num_domains

    def forward(self, x: paddle.Tensor, y: paddle.Tensor):
-       return self.dis(x, y)
+       out = self.dis(x, y)
+       return out

    def classifier(self, x: paddle.Tensor):
-       return self.cls.get_feature(x)
-
-
-class LinearNorm(nn.Layer):
-    def __init__(self,
-                 in_dim: int,
-                 out_dim: int,
-                 bias: bool=True,
-                 w_init_gain: str='linear'):
-        super().__init__()
-        self.linear_layer = nn.Linear(in_dim, out_dim, bias_attr=bias)
-        xavier_uniform_(
-            self.linear_layer.weight, gain=_calculate_gain(w_init_gain))
-
-    def forward(self, x):
-        return self.linear_layer(x)
+       out = self.cls.get_feature(x)
+       return out


class Discriminator2D(nn.Layer):
...
@@ -520,97 +590,13 @@ class Discriminator2D(nn.Layer):
    def get_feature(self, x: paddle.Tensor):
        out = self.main(x)
-       # (batch, num_domains)
+       # (B, num_domains)
        out = out.reshape((out.shape[0], -1))
        return out

    def forward(self, x: paddle.Tensor, y: paddle.Tensor):
        out = self.get_feature(x)
        idx = paddle.arange(y.shape[0])
-       # (batch)
+       # (B,) ?
        out = out[idx, y]
        return out
'''
def build_model(args, F0_model: nn.Layer, ASR_model: nn.Layer):
generator = Generator(
dim_in=args.dim_in,
style_dim=args.style_dim,
max_conv_dim=args.max_conv_dim,
w_hpf=args.w_hpf,
F0_channel=args.F0_channel)
mapping_network = MappingNetwork(
latent_dim=args.latent_dim,
style_dim=args.style_dim,
num_domains=args.num_domains,
hidden_dim=args.max_conv_dim)
style_encoder = StyleEncoder(
dim_in=args.dim_in,
style_dim=args.style_dim,
num_domains=args.num_domains,
max_conv_dim=args.max_conv_dim)
discriminator = Discriminator(
dim_in=args.dim_in,
num_domains=args.num_domains,
max_conv_dim=args.max_conv_dim,
n_repeat=args.n_repeat)
generator_ema = copy.deepcopy(generator)
mapping_network_ema = copy.deepcopy(mapping_network)
style_encoder_ema = copy.deepcopy(style_encoder)
nets = Munch(
generator=generator,
mapping_network=mapping_network,
style_encoder=style_encoder,
discriminator=discriminator,
f0_model=F0_model,
asr_model=ASR_model)
nets_ema = Munch(
generator=generator_ema,
mapping_network=mapping_network_ema,
style_encoder=style_encoder_ema)
return nets, nets_ema
class StarGANv2VC(nn.Layer):
def __init__(
self,
# spk_num
num_domains: int=20,
dim_in: int=64,
style_dim: int=64,
latent_dim: int=16,
max_conv_dim: int=512,
n_repeat: int=4,
w_hpf: int=0,
F0_channel: int=256):
super().__init__()
self.generator = Generator(
dim_in=dim_in,
style_dim=style_dim,
max_conv_dim=max_conv_dim,
w_hpf=w_hpf,
F0_channel=F0_channel)
# MappingNetwork and StyleEncoder are used to generate reference_embeddings
self.mapping_network = MappingNetwork(
latent_dim=latent_dim,
style_dim=style_dim,
num_domains=num_domains,
hidden_dim=max_conv_dim)
self.style_encoder = StyleEncoder(
dim_in=dim_in,
style_dim=style_dim,
num_domains=num_domains,
max_conv_dim=max_conv_dim)
self.discriminator = Discriminator(
dim_in=dim_in,
num_domains=num_domains,
max_conv_dim=max_conv_dim,
repeat_num=n_repeat)
'''