PaddlePaddle / DeepSpeech

Commit 19180d35
Authored on Oct 10, 2022 by tianhao zhang
Parent: 6e429f05

    format wav2vec2 demo

Showing 15 changed files with 558 additions and 485 deletions (+558 / -485)
Changed files:

  .flake8                                                             +1   -1
  examples/librispeech/README.md                                      +1   -1
  paddlespeech/audio/transform/spectrogram.py                         +30  -0
  paddlespeech/audio/transform/transformation.py                      +1   -0
  paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py                      +6   -6
  paddlespeech/s2t/exps/wav2vec2/model.py                             +35  -28
  paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py               +7   -8
  paddlespeech/s2t/models/wav2vec2/modules/activations.py             +14  -9
  paddlespeech/s2t/models/wav2vec2/modules/containers.py              +5   -7
  paddlespeech/s2t/models/wav2vec2/modules/linear.py                  +8   -9
  paddlespeech/s2t/models/wav2vec2/modules/modeling_outputs.py        +23  -15
  paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py       +301 -239
  paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py    +15  -21
  paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py  +78  -89
  paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py                    +33  -52
.flake8  (+1 -1)

@@ -33,7 +33,7 @@ filename =
 # Specify a list of codes to ignore.
 ignore =
     W503
-    E252,E262,E127,E265,E126,E266,E241,E261,E128,E125
+    E252,E262,E127,E265,E126,E266,E241,E261,E128,E125,E129
     W291,W293,W605
     E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
...
examples/librispeech/README.md  (+1 -1)

@@ -3,7 +3,7 @@
 * asr0 - deepspeech2 Streaming/Non-Streaming
 * asr1 - transformer/conformer Streaming/Non-Streaming
 * asr2 - transformer/conformer Streaming/Non-Streaming with Kaldi feature
+* asr3 - wav2vecASR, ASR model with pre-trained wav2vec2 and CTC

 ## Data
 | Data Subset | Duration in Seconds |
...
paddlespeech/audio/transform/spectrogram.py  (+30 -0)

@@ -382,6 +382,36 @@ class LogMelSpectrogramKaldi():
        return mat


class WavProcess():
    def __init__(self, dither=0.1):
        """
        Args:
            dither (float): Dithering constant

        Returns:
        """
        self.dither = dither

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode.

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D)
        """
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")
        waveform = np.expand_dims(x, -1)
        return waveform


class LogMelSpectrogramKaldi_decay():
    def __init__(
        self,
...
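A quick way to sanity-check the new WavProcess transform is sketched below; it assumes only NumPy plus the class definition from the hunk above, and is not an official test. Note that, as reconstructed, the dither value is selected in train mode but never actually applied to the waveform; only the shape expansion takes effect.

    import numpy as np

    proc = WavProcess(dither=0.1)
    x = np.zeros(16000, dtype=np.float32)   # a 1-D mono waveform, shape (Ti,)
    y = proc(x, train=True)                 # train=True keeps dither=0.1, but it is unused
    assert y.shape == (16000, 1)            # the waveform is expanded to (T, 1)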
paddlespeech/audio/transform/transformation.py  (+1 -0)

@@ -41,6 +41,7 @@ import_alias = dict(
     utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
     fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
     spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
+    wav_process="paddlespeech.audio.transform.spectrogram:WavProcess",
     stft="paddlespeech.audio.transform.spectrogram:Stft",
     istft="paddlespeech.audio.transform.spectrogram:IStft",
     stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",
...
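For context, import_alias maps a short transform name to a "module:Class" path that is imported lazily when a preprocessing config references it. A minimal sketch of how such an alias can be resolved is shown below; the resolve helper is illustrative, not the library's actual loader.

    import importlib

    import_alias = dict(
        wav_process="paddlespeech.audio.transform.spectrogram:WavProcess")

    def resolve(alias: str):
        # split "package.module:Class" into its two halves and import it
        module_path, class_name = import_alias[alias].split(":")
        return getattr(importlib.import_module(module_path), class_name)

    # WavProcess = resolve("wav_process")   # requires paddlespeech to be installed
    # proc = WavProcess(dither=0.1)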
paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py  (+6 -6)

@@ -27,6 +27,7 @@ from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig

logger = Log(__name__).getlog()


class Wav2vec2Infer():
    def __init__(self, config, args):
        self.args = args
...
@@ -34,8 +35,7 @@ class Wav2vec2Infer():
        self.audio_file = args.audio_file
        self.text_feature = TextFeaturizer(
            unit_type=config.unit_type, vocab=config.vocab_filepath)
        paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')

        # model
...
@@ -63,10 +63,10 @@ class Wav2vec2Infer():
        xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
        decode_config = self.config.decode
        result_transcripts, result_tokenids = self.model.decode(
            xs,
            text_feature=self.text_feature,
            decoding_method=decode_config.decoding_method,
            beam_size=decode_config.beam_size)
        rsl = result_transcripts[0]
        utt = Path(self.audio_file).name
        logger.info(f"hyp: {utt} {rsl}")
...
paddlespeech/s2t/exps/wav2vec2/model.py  (+35 -28)

@@ -18,53 +18,53 @@ import time
from collections import defaultdict
from collections import OrderedDict
from contextlib import nullcontext

import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist

from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
from paddlespeech.s2t.training.reporter import report
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
from paddlespeech.s2t.training.timer import Timer
from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils import error_rate
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils import mp_tools
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig

logger = Log(__name__).getlog()


class Wav2Vec2ASRTrainer(Trainer):
    def __init__(self, config, args):
        super().__init__(config, args)
        self.avg_train_loss = 0

    def train_batch(self, batch_index, batch, msg):
        train_conf = self.config
        start = time.time()

        # forward
        utt, wav, wavs_lens, target, target_lens = batch
        wavs_lens_rate = wavs_lens / wav.shape[1]
        target_lens_rate = target_lens / target.shape[1]
        wav = wav[:, :, 0]
        wav = self.speech_augmentation(wav, wavs_lens_rate)
        loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
        # pring(wav, wavs_lens_rate, target, target_lens_rate)
        # loss div by `batch_size * accum_grad`
        loss /= train_conf.accum_grad
        losses_np = {'loss': float(loss) * train_conf.accum_grad}

        # loss backward
...
@@ -108,15 +108,16 @@ class Wav2Vec2ASRTrainer(Trainer):
    def valid(self):
        self.model.eval()
        if not self.use_streamdata:
            logger.info(
                f"Valid Total Examples: {len(self.valid_loader.dataset)}")
        valid_losses = defaultdict(list)
        num_seen_utts = 1
        total_loss = 0.0
        for i, batch in enumerate(self.valid_loader):
            utt, wav, wavs_lens, target, target_lens = batch
            wavs_lens_rate = wavs_lens / wav.shape[1]
            target_lens_rate = target_lens / target.shape[1]
            wav = wav[:, :, 0]
            loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
            if paddle.isfinite(loss):
...
@@ -134,7 +135,8 @@ class Wav2Vec2ASRTrainer(Trainer):
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                if not self.use_streamdata:
                    msg += "batch: {}/{}, ".format(i + 1,
                                                   len(self.valid_loader))
                msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                 for k, v in valid_dump.items())
                logger.info(msg)
...
@@ -155,7 +157,8 @@ class Wav2Vec2ASRTrainer(Trainer):
        self.before_train()
        if not self.use_streamdata:
            logger.info(
                f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.n_epoch:
            with Timer("Epoch-Train Time Cost: {}"):
                self.model.train()
...
@@ -223,14 +226,18 @@ class Wav2Vec2ASRTrainer(Trainer):
        config = self.config.clone()
        self.use_streamdata = config.get("use_stream_data", False)
        if self.train:
            self.train_loader = DataLoaderFactory.get_dataloader(
                'train', config, self.args)
            self.valid_loader = DataLoaderFactory.get_dataloader(
                'valid', config, self.args)
            logger.info("Setup train/valid Dataloader!")
        else:
            decode_batch_size = config.get('decode', dict()).get(
                'decode_batch_size', 1)
            self.test_loader = DataLoaderFactory.get_dataloader(
                'test', config, self.args)
            self.align_loader = DataLoaderFactory.get_dataloader(
                'align', config, self.args)
            logger.info("Setup test/align Dataloader!")

    def setup_model(self):
...
@@ -248,7 +255,7 @@ class Wav2Vec2ASRTrainer(Trainer):
        model = Wav2vec2ASR.from_config(model_conf)
        if self.parallel:
            model = paddle.DataParallel(model, find_unused_parameters=True)
        logger.info(f"{model}")
        layer_tools.print_params(model, logger.info)
...
@@ -312,14 +319,14 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
        self.text_featurizer = TextFeaturizer(
            unit_type=config.unit_type, vocab=config.vocab_filepath)
        self.vocab_list = self.text_featurizer.vocab_list

    def id2token(self, texts, texts_len):
        """ ord() id to chr() chr """
        trans = []
        for text, n in zip(texts, texts_len):
            n = n.numpy().item()
            ids = text[:n]
            trans.append(
                self.text_featurizer.defeaturize(ids.numpy().tolist()))
        return trans

    def compute_metrics(self,
...
@@ -337,10 +344,10 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
        start_time = time.time()
        target_transcripts = self.id2token(texts, texts_len)
        result_transcripts, result_tokenids = self.model.decode(
            audio,
            text_feature=self.text_featurizer,
            decoding_method=decode_cfg.decoding_method,
            beam_size=decode_cfg.beam_size)
        decode_time = time.time() - start_time

        for utt, target, result, rec_tids in zip(
...
@@ -432,4 +439,4 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
                "decode_method":
                self.config.decode.decoding_method,
            })
            f.write(data + '\n')
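The train_batch hunk above divides the loss by accum_grad before backward and multiplies it back only for logging. A pure-Python toy showing why this scaling makes summed gradients behave like the average over the accumulation window; the values are illustrative, and the real trainer does this with paddle tensors.

    accum_grad = 4
    micro_losses = [2.0, 4.0, 6.0, 8.0]

    # backward() runs on loss / accum_grad, so gradients summed over the
    # accumulation window equal the average over that window
    grad_sum = sum(loss / accum_grad for loss in micro_losses)
    assert grad_sum == sum(micro_losses) / len(micro_losses)  # 5.0

    # for reporting, the code multiplies back: float(loss) * accum_grad
    logged = [(loss / accum_grad) * accum_grad for loss in micro_losses]
    assert logged == micro_losses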
paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py  (+7 -8)

@@ -3,6 +3,7 @@ Authors
 * Elena Rastorgueva 2020
 """
 import paddle
+
 from paddlespeech.s2t.models.wav2vec2.modules import containers
 from paddlespeech.s2t.models.wav2vec2.modules import linear
...
@@ -27,12 +28,11 @@ class VanillaNN(containers.Sequential):
    """

    def __init__(
            self,
            input_shape,
            activation=paddle.nn.LeakyReLU,
            dnn_blocks=2,
            dnn_neurons=512, ):
        super().__init__(input_shape=input_shape)

        for block_index in range(dnn_blocks):
...
@@ -40,6 +40,5 @@ class VanillaNN(containers.Sequential):
                linear.Linear,
                n_neurons=dnn_neurons,
                bias=True,
                layer_name="linear", )
            self.append(activation(), layer_name="act")
paddlespeech/s2t/models/wav2vec2/modules/activations.py  (+14 -9)

@@ -11,12 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math

-from packaging import version
-from paddle import Tensor, nn
+from paddle import nn
+from paddle import Tensor

 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()
...
@@ -29,7 +27,9 @@ class NewGELUActivation(nn.Layer):
    """

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1.0 + paddle.tanh(
            math.sqrt(2.0 / math.pi) *
            (input + 0.044715 * paddle.pow(input, 3.0))))


class GELUActivation(nn.Layer):
...
@@ -40,7 +40,7 @@ class GELUActivation(nn.Layer):
    Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool=False):
        super().__init__()
        self.act = nn.functional.gelu
...
@@ -57,7 +57,9 @@ class FastGELUActivation(nn.Layer):
    """

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1.0 + paddle.tanh(
            input * 0.7978845608 * (1.0 + 0.044715 * input * input)))


class QuickGELUActivation(nn.Layer):
...
@@ -84,7 +86,8 @@ class ClippedGELUActivation(nn.Layer):
    def __init__(self, min: float, max: float):
        if min > max:
            raise ValueError(
                f"min should be < max (got min: {min}, max: {max})")

        super().__init__()
        self.min = min
...
@@ -161,7 +164,9 @@ def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError(
            f"function {activation_string} not found in ACT2FN mapping "
            f"{list(ACT2FN.keys())}")


# For backwards compatibility with: from activations import gelu_python
...
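NewGELUActivation and FastGELUActivation both use the tanh approximation of GELU rather than the exact erf form. The standalone NumPy check below, which is not part of the diff, shows how close the approximation is; the 1e-2 tolerance is deliberately loose.

    import math
    import numpy as np

    x = np.linspace(-4.0, 4.0, 9)

    # tanh approximation, exactly as in NewGELUActivation.forward above
    gelu_tanh = 0.5 * x * (1.0 + np.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * np.power(x, 3.0))))
    # exact GELU via the Gaussian CDF
    gelu_erf = 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

    assert np.max(np.abs(gelu_tanh - gelu_erf)) < 1e-2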
paddlespeech/s2t/models/wav2vec2/modules/containers.py  (+5 -7)

-import paddle
 import inspect
-import logging
-import operator
-import functools
+
+import paddle


 class Sequential(paddle.nn.LayerDict):
     """A sequence of modules with potentially inferring shape on construction.
...
@@ -98,13 +97,12 @@ class Sequential(paddle.nn.LayerDict):
        # Finally, append the layer.
        try:
            self[layer_name] = layer
            # self.add_module(layer_name, layer)
        except TypeError:
            raise ValueError(
                "Must pass `input_shape` at initialization and use "
                "modules that take `input_shape` to infer shape when "
                "using `append()`.")

    def get_output_shape(self):
        """Returns expected shape of the output.
...
paddlespeech/s2t/models/wav2vec2/modules/linear.py  (+8 -9)

@@ -3,10 +3,10 @@ Authors
 * Mirco Ravanelli 2020
 * Davide Borra 2021
 """
 import logging

 import paddle
+import paddle.nn as nn

 from paddlespeech.s2t.modules import align

 logger = logging.getLogger(__name__)
...
@@ -37,13 +37,12 @@ class Linear(paddle.nn.Layer):
    """

    def __init__(
            self,
            n_neurons,
            input_shape=None,
            input_size=None,
            bias=True,
            combine_dims=False, ):
        super().__init__()
        self.combine_dims = combine_dims
...
paddlespeech/s2t/models/wav2vec2/modules/modeling_outputs.py  (+23 -15)

@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
-from typing import Optional, Tuple
 from collections import OrderedDict
+from dataclasses import dataclass
 from dataclasses import fields
+from typing import Optional
+from typing import Tuple

 import paddle
...
@@ -41,10 +41,13 @@ class ModelOutput(OrderedDict):
        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            raise ValueError(
                f"{self.__class__.__name__} should not have more than one required field."
            )

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(
            getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and not paddle.is_tensor(first_field):
            if isinstance(first_field, dict):
...
@@ -61,11 +64,9 @@ class ModelOutput(OrderedDict):
            # set the associated fields
            if first_field_iterator:
                for element in iterator:
                    if (not isinstance(element, (list, tuple)) or
                            not len(element) == 2 or
                            not isinstance(element[0], str)):
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
...
@@ -79,16 +80,23 @@ class ModelOutput(OrderedDict):
                self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
        )

    def setdefault(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
        )

    def pop(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``update`` on a {self.__class__.__name__} instance."
        )

    def __getitem__(self, k):
        if isinstance(k, str):
...
paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
浏览文件 @
19180d35
...
@@ -13,24 +13,19 @@
...
@@ -13,24 +13,19 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
""" Paddle Wav2Vec2 model."""
""" Paddle Wav2Vec2 model."""
import
math
import
warnings
import
paddle
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
from
typing
import
Tuple
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddlespeech.s2t.models.wav2vec2.modules.activations
import
ACT2FN
from
paddlespeech.s2t.models.wav2vec2.modules.activations
import
ACT2FN
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs
import
(
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs
import
BaseModelOutput
BaseModelOutput
,
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs
import
ModelOutput
Wav2Vec2BaseModelOutput
,
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs
import
Wav2Vec2BaseModelOutput
ModelOutput
)
import
pdb
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
...
@@ -78,12 +73,11 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
...
@@ -78,12 +73,11 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
def
_compute_mask_indices
(
def
_compute_mask_indices
(
shape
:
Tuple
[
int
,
int
],
shape
:
Tuple
[
int
,
int
],
mask_prob
:
float
,
mask_prob
:
float
,
mask_length
:
int
,
mask_length
:
int
,
attention_mask
:
Optional
[
paddle
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
paddle
.
Tensor
]
=
None
,
min_masks
:
int
=
0
,
min_masks
:
int
=
0
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
"""
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
...
@@ -109,8 +103,7 @@ def _compute_mask_indices(
...
@@ -109,8 +103,7 @@ def _compute_mask_indices(
if
mask_length
>
sequence_length
:
if
mask_length
>
sequence_length
:
raise
ValueError
(
raise
ValueError
(
f
"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`:
{
mask_length
}
"
f
"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`:
{
mask_length
}
"
f
" and `sequence_length`:
{
sequence_length
}
`"
f
" and `sequence_length`:
{
sequence_length
}
`"
)
)
# epsilon is used for probabilistic rounding
# epsilon is used for probabilistic rounding
epsilon
=
np
.
random
.
rand
(
1
).
item
()
epsilon
=
np
.
random
.
rand
(
1
).
item
()
...
@@ -131,11 +124,9 @@ def _compute_mask_indices(
...
@@ -131,11 +124,9 @@ def _compute_mask_indices(
return
num_masked_span
return
num_masked_span
# compute number of masked spans in batch
# compute number of masked spans in batch
input_lengths
=
(
input_lengths
=
(
attention_mask
.
sum
(
-
1
).
detach
().
tolist
()
attention_mask
.
sum
(
-
1
).
detach
().
tolist
()
if
attention_mask
is
not
None
else
if
attention_mask
is
not
None
[
sequence_length
for
_
in
range
(
batch_size
)])
else
[
sequence_length
for
_
in
range
(
batch_size
)]
)
# SpecAugment mask to fill
# SpecAugment mask to fill
spec_aug_mask
=
np
.
zeros
((
batch_size
,
sequence_length
),
dtype
=
np
.
bool
)
spec_aug_mask
=
np
.
zeros
((
batch_size
,
sequence_length
),
dtype
=
np
.
bool
)
...
@@ -152,8 +143,9 @@ def _compute_mask_indices(
...
@@ -152,8 +143,9 @@ def _compute_mask_indices(
# get random indices to mask
# get random indices to mask
spec_aug_mask_idx
=
np
.
random
.
choice
(
spec_aug_mask_idx
=
np
.
random
.
choice
(
np
.
arange
(
input_length
-
(
mask_length
-
1
)),
num_masked_span
,
replace
=
False
np
.
arange
(
input_length
-
(
mask_length
-
1
)),
)
num_masked_span
,
replace
=
False
)
# pick first sampled index that will serve as a dummy index to pad vector
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# to ensure same dimension for all batches due to probabilistic rounding
...
@@ -166,29 +158,33 @@ def _compute_mask_indices(
...
@@ -166,29 +158,33 @@ def _compute_mask_indices(
else
:
else
:
dummy_mask_idx
=
spec_aug_mask_idx
[
0
]
dummy_mask_idx
=
spec_aug_mask_idx
[
0
]
spec_aug_mask_idx
=
np
.
concatenate
(
spec_aug_mask_idx
=
np
.
concatenate
([
[
spec_aug_mask_idx
,
np
.
ones
(
max_num_masked_span
-
num_masked_span
,
dtype
=
np
.
int32
)
*
dummy_mask_idx
]
spec_aug_mask_idx
,
)
np
.
ones
(
max_num_masked_span
-
num_masked_span
,
dtype
=
np
.
int32
)
*
dummy_mask_idx
])
spec_aug_mask_idxs
.
append
(
spec_aug_mask_idx
)
spec_aug_mask_idxs
.
append
(
spec_aug_mask_idx
)
spec_aug_mask_idxs
=
np
.
array
(
spec_aug_mask_idxs
)
spec_aug_mask_idxs
=
np
.
array
(
spec_aug_mask_idxs
)
# expand masked indices to masked spans
# expand masked indices to masked spans
spec_aug_mask_idxs
=
np
.
broadcast_to
(
spec_aug_mask_idxs
=
np
.
broadcast_to
(
spec_aug_mask_idxs
[:,
:,
None
],
(
batch_size
,
max_num_masked_span
,
mask_length
)
spec_aug_mask_idxs
[:,
:,
None
],
)
(
batch_size
,
max_num_masked_span
,
mask_length
))
spec_aug_mask_idxs
=
spec_aug_mask_idxs
.
reshape
((
batch_size
,
max_num_masked_span
*
mask_length
))
spec_aug_mask_idxs
=
spec_aug_mask_idxs
.
reshape
(
(
batch_size
,
max_num_masked_span
*
mask_length
))
# add offset to the starting indexes so that indexes now create a span
# add offset to the starting indexes so that indexes now create a span
offsets
=
np
.
arange
(
mask_length
)[
None
,
None
,
:]
offsets
=
np
.
arange
(
mask_length
)[
None
,
None
,
:]
offsets
=
np
.
broadcast_to
(
offsets
,
(
batch_size
,
max_num_masked_span
,
mask_length
)).
reshape
(
offsets
=
np
.
broadcast_to
(
offsets
,
(
(
batch_size
,
max_num_masked_span
*
mask_length
)
batch_size
,
max_num_masked_span
,
mask_length
)).
reshape
(
)
(
batch_size
,
max_num_masked_span
*
mask_length
)
)
spec_aug_mask_idxs
=
spec_aug_mask_idxs
+
offsets
spec_aug_mask_idxs
=
spec_aug_mask_idxs
+
offsets
# ensure that we cannot have indices larger than sequence_length
# ensure that we cannot have indices larger than sequence_length
if
spec_aug_mask_idxs
.
max
()
>
sequence_length
-
1
:
if
spec_aug_mask_idxs
.
max
()
>
sequence_length
-
1
:
spec_aug_mask_idxs
[
spec_aug_mask_idxs
>
sequence_length
-
1
]
=
sequence_length
-
1
spec_aug_mask_idxs
[
spec_aug_mask_idxs
>
sequence_length
-
1
]
=
sequence_length
-
1
# scatter indices to mask
# scatter indices to mask
np
.
put_along_axis
(
spec_aug_mask
,
spec_aug_mask_idxs
,
1
,
-
1
)
np
.
put_along_axis
(
spec_aug_mask
,
spec_aug_mask_idxs
,
1
,
-
1
)
...
@@ -196,9 +192,9 @@ def _compute_mask_indices(
...
@@ -196,9 +192,9 @@ def _compute_mask_indices(
return
spec_aug_mask
return
spec_aug_mask
def
_sample_negative_indices
(
def
_sample_negative_indices
(
features_shape
:
Tuple
,
features_shape
:
Tuple
,
num_negatives
:
int
,
mask_time_indices
:
Optional
[
np
.
ndarray
]
=
None
num_negatives
:
int
,
):
mask_time_indices
:
Optional
[
np
.
ndarray
]
=
None
):
"""
"""
Sample `num_negatives` vectors from feature vectors.
Sample `num_negatives` vectors from feature vectors.
"""
"""
...
@@ -208,23 +204,28 @@ def _sample_negative_indices(
...
@@ -208,23 +204,28 @@ def _sample_negative_indices(
sequence_length_range
=
np
.
arange
(
sequence_length
)
sequence_length_range
=
np
.
arange
(
sequence_length
)
# get `num_negatives` random vector indices from the same utterance
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices
=
np
.
zeros
(
shape
=
(
batch_size
,
sequence_length
,
num_negatives
),
dtype
=
np
.
int32
)
sampled_negative_indices
=
np
.
zeros
(
shape
=
(
batch_size
,
sequence_length
,
num_negatives
),
dtype
=
np
.
int32
)
mask_time_indices
=
(
mask_time_indices
=
(
mask_time_indices
.
astype
(
np
.
bool
)
mask_time_indices
.
astype
(
np
.
bool
)
if
mask_time_indices
is
not
None
else
np
.
ones
(
features_shape
,
dtype
=
np
.
bool
)
if
mask_time_indices
is
not
None
else
)
np
.
ones
(
features_shape
,
dtype
=
np
.
bool
)
)
for
batch_idx
in
range
(
batch_size
):
for
batch_idx
in
range
(
batch_size
):
high
=
mask_time_indices
[
batch_idx
].
sum
()
-
1
high
=
mask_time_indices
[
batch_idx
].
sum
()
-
1
mapped_masked_indices
=
sequence_length_range
[
mask_time_indices
[
batch_idx
]]
mapped_masked_indices
=
sequence_length_range
[
mask_time_indices
[
batch_idx
]]
feature_indices
=
np
.
broadcast_to
(
np
.
arange
(
high
+
1
)[:,
None
],
(
high
+
1
,
num_negatives
))
feature_indices
=
np
.
broadcast_to
(
sampled_indices
=
np
.
random
.
randint
(
0
,
high
,
size
=
(
high
+
1
,
num_negatives
))
np
.
arange
(
high
+
1
)[:,
None
],
(
high
+
1
,
num_negatives
))
sampled_indices
=
np
.
random
.
randint
(
0
,
high
,
size
=
(
high
+
1
,
num_negatives
))
# avoid sampling the same positive vector, but keep the distribution uniform
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_indices
[
sampled_indices
>=
feature_indices
]
+=
1
sampled_indices
[
sampled_indices
>=
feature_indices
]
+=
1
# remap to actual indices
# remap to actual indices
sampled_negative_indices
[
batch_idx
][
mask_time_indices
[
batch_idx
]]
=
mapped_masked_indices
[
sampled_indices
]
sampled_negative_indices
[
batch_idx
][
mask_time_indices
[
batch_idx
]]
=
mapped_masked_indices
[
sampled_indices
]
# correct for batch size
# correct for batch size
sampled_negative_indices
[
batch_idx
]
+=
batch_idx
*
sequence_length
sampled_negative_indices
[
batch_idx
]
+=
batch_idx
*
sequence_length
...
@@ -243,8 +244,7 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Layer):
...
@@ -243,8 +244,7 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Layer):
self
.
out_conv_dim
,
self
.
out_conv_dim
,
kernel_size
=
config
.
conv_kernel
[
layer_id
],
kernel_size
=
config
.
conv_kernel
[
layer_id
],
stride
=
config
.
conv_stride
[
layer_id
],
stride
=
config
.
conv_stride
[
layer_id
],
bias_attr
=
config
.
conv_bias
,
bias_attr
=
config
.
conv_bias
,
)
)
self
.
activation
=
ACT2FN
[
config
.
feat_extract_activation
]
self
.
activation
=
ACT2FN
[
config
.
feat_extract_activation
]
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
...
@@ -264,8 +264,7 @@ class Wav2Vec2LayerNormConvLayer(nn.Layer):
...
@@ -264,8 +264,7 @@ class Wav2Vec2LayerNormConvLayer(nn.Layer):
self
.
out_conv_dim
,
self
.
out_conv_dim
,
kernel_size
=
config
.
conv_kernel
[
layer_id
],
kernel_size
=
config
.
conv_kernel
[
layer_id
],
stride
=
config
.
conv_stride
[
layer_id
],
stride
=
config
.
conv_stride
[
layer_id
],
bias_attr
=
config
.
conv_bias
,
bias_attr
=
config
.
conv_bias
,
)
)
self
.
layer_norm
=
nn
.
LayerNorm
(
self
.
out_conv_dim
)
self
.
layer_norm
=
nn
.
LayerNorm
(
self
.
out_conv_dim
)
self
.
activation
=
ACT2FN
[
config
.
feat_extract_activation
]
self
.
activation
=
ACT2FN
[
config
.
feat_extract_activation
]
...
@@ -290,11 +289,11 @@ class Wav2Vec2GroupNormConvLayer(nn.Layer):
...
@@ -290,11 +289,11 @@ class Wav2Vec2GroupNormConvLayer(nn.Layer):
self
.
out_conv_dim
,
self
.
out_conv_dim
,
kernel_size
=
config
.
conv_kernel
[
layer_id
],
kernel_size
=
config
.
conv_kernel
[
layer_id
],
stride
=
config
.
conv_stride
[
layer_id
],
stride
=
config
.
conv_stride
[
layer_id
],
bias_attr
=
config
.
conv_bias
,
bias_attr
=
config
.
conv_bias
,
)
)
self
.
activation
=
ACT2FN
[
config
.
feat_extract_activation
]
self
.
activation
=
ACT2FN
[
config
.
feat_extract_activation
]
self
.
layer_norm
=
nn
.
GroupNorm
(
num_groups
=
self
.
out_conv_dim
,
num_channels
=
self
.
out_conv_dim
)
self
.
layer_norm
=
nn
.
GroupNorm
(
num_groups
=
self
.
out_conv_dim
,
num_channels
=
self
.
out_conv_dim
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
hidden_states
=
self
.
conv
(
hidden_states
)
hidden_states
=
self
.
conv
(
hidden_states
)
...
@@ -311,8 +310,7 @@ class Wav2Vec2PositionalConvEmbedding(nn.Layer):
...
@@ -311,8 +310,7 @@ class Wav2Vec2PositionalConvEmbedding(nn.Layer):
config
.
hidden_size
,
config
.
hidden_size
,
kernel_size
=
config
.
num_conv_pos_embeddings
,
kernel_size
=
config
.
num_conv_pos_embeddings
,
padding
=
config
.
num_conv_pos_embeddings
//
2
,
padding
=
config
.
num_conv_pos_embeddings
//
2
,
groups
=
config
.
num_conv_pos_embedding_groups
,
groups
=
config
.
num_conv_pos_embedding_groups
,
)
)
self
.
conv
=
nn
.
utils
.
weight_norm
(
self
.
conv
,
name
=
"weight"
,
dim
=
2
)
self
.
conv
=
nn
.
utils
.
weight_norm
(
self
.
conv
,
name
=
"weight"
,
dim
=
2
)
...
@@ -337,7 +335,7 @@ class Wav2Vec2SamePadLayer(nn.Layer):
...
@@ -337,7 +335,7 @@ class Wav2Vec2SamePadLayer(nn.Layer):
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
if
self
.
num_pad_remove
>
0
:
if
self
.
num_pad_remove
>
0
:
hidden_states
=
hidden_states
[:,
:,
:
-
self
.
num_pad_remove
]
hidden_states
=
hidden_states
[:,
:,
:
-
self
.
num_pad_remove
]
return
hidden_states
return
hidden_states
...
@@ -349,11 +347,13 @@ class Wav2Vec2FeatureEncoder(nn.Layer):
...
@@ -349,11 +347,13 @@ class Wav2Vec2FeatureEncoder(nn.Layer):
if
config
.
feat_extract_norm
==
"group"
:
if
config
.
feat_extract_norm
==
"group"
:
conv_layers
=
[
Wav2Vec2GroupNormConvLayer
(
config
,
layer_id
=
0
)]
+
[
conv_layers
=
[
Wav2Vec2GroupNormConvLayer
(
config
,
layer_id
=
0
)]
+
[
Wav2Vec2NoLayerNormConvLayer
(
config
,
layer_id
=
i
+
1
)
for
i
in
range
(
config
.
num_feat_extract_layers
-
1
)
Wav2Vec2NoLayerNormConvLayer
(
config
,
layer_id
=
i
+
1
)
for
i
in
range
(
config
.
num_feat_extract_layers
-
1
)
]
]
elif
config
.
feat_extract_norm
==
"layer"
:
elif
config
.
feat_extract_norm
==
"layer"
:
conv_layers
=
[
conv_layers
=
[
Wav2Vec2LayerNormConvLayer
(
config
,
layer_id
=
i
)
for
i
in
range
(
config
.
num_feat_extract_layers
)
Wav2Vec2LayerNormConvLayer
(
config
,
layer_id
=
i
)
for
i
in
range
(
config
.
num_feat_extract_layers
)
]
]
else
:
else
:
raise
ValueError
(
raise
ValueError
(
...
@@ -373,10 +373,12 @@ class Wav2Vec2FeatureEncoder(nn.Layer):
...
@@ -373,10 +373,12 @@ class Wav2Vec2FeatureEncoder(nn.Layer):
return
hidden_states
return
hidden_states
class
Wav2Vec2FeatureProjection
(
nn
.
Layer
):
class
Wav2Vec2FeatureProjection
(
nn
.
Layer
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
()
super
().
__init__
()
self
.
layer_norm
=
nn
.
LayerNorm
(
config
.
conv_dim
[
-
1
],
epsilon
=
config
.
layer_norm_eps
)
self
.
layer_norm
=
nn
.
LayerNorm
(
config
.
conv_dim
[
-
1
],
epsilon
=
config
.
layer_norm_eps
)
self
.
projection
=
nn
.
Linear
(
config
.
conv_dim
[
-
1
],
config
.
hidden_size
)
self
.
projection
=
nn
.
Linear
(
config
.
conv_dim
[
-
1
],
config
.
hidden_size
)
self
.
dropout
=
nn
.
Dropout
(
config
.
feat_proj_dropout
)
self
.
dropout
=
nn
.
Dropout
(
config
.
feat_proj_dropout
)
...
@@ -393,13 +395,12 @@ class Wav2Vec2Attention(nn.Layer):
...
@@ -393,13 +395,12 @@ class Wav2Vec2Attention(nn.Layer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
def
__init__
(
self
,
self
,
embed_dim
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
num_heads
:
int
,
dropout
:
float
=
0.0
,
dropout
:
float
=
0.0
,
is_decoder
:
bool
=
False
,
is_decoder
:
bool
=
False
,
bias
:
bool
=
True
,
bias
:
bool
=
True
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
embed_dim
=
embed_dim
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
...
@@ -409,8 +410,7 @@ class Wav2Vec2Attention(nn.Layer):
...
@@ -409,8 +410,7 @@ class Wav2Vec2Attention(nn.Layer):
if
(
self
.
head_dim
*
num_heads
)
!=
self
.
embed_dim
:
if
(
self
.
head_dim
*
num_heads
)
!=
self
.
embed_dim
:
raise
ValueError
(
raise
ValueError
(
f
"embed_dim must be divisible by num_heads (got `embed_dim`:
{
self
.
embed_dim
}
"
f
"embed_dim must be divisible by num_heads (got `embed_dim`:
{
self
.
embed_dim
}
"
f
" and `num_heads`:
{
num_heads
}
)."
f
" and `num_heads`:
{
num_heads
}
)."
)
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
is_decoder
=
is_decoder
self
.
is_decoder
=
is_decoder
...
@@ -420,17 +420,18 @@ class Wav2Vec2Attention(nn.Layer):
...
@@ -420,17 +420,18 @@ class Wav2Vec2Attention(nn.Layer):
self
.
out_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
,
bias_attr
=
bias
)
self
.
out_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
,
bias_attr
=
bias
)
def
_shape
(
self
,
tensor
:
paddle
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
def
_shape
(
self
,
tensor
:
paddle
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
paddle
.
reshape
(
tensor
,
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
)).
transpose
([
0
,
2
,
1
,
3
])
return
paddle
.
reshape
(
tensor
,
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
)).
transpose
([
0
,
2
,
1
,
3
])
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
paddle
.
Tensor
,
hidden_states
:
paddle
.
Tensor
,
key_value_states
:
Optional
[
paddle
.
Tensor
]
=
None
,
key_value_states
:
Optional
[
paddle
.
Tensor
]
=
None
,
past_key_value
:
Optional
[
Tuple
[
paddle
.
Tensor
]]
=
None
,
past_key_value
:
Optional
[
Tuple
[
paddle
.
Tensor
]]
=
None
,
attention_mask
:
Optional
[
paddle
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
paddle
.
Tensor
]
=
None
,
layer_head_mask
:
Optional
[
paddle
.
Tensor
]
=
None
,
layer_head_mask
:
Optional
[
paddle
.
Tensor
]
=
None
,
output_attentions
:
bool
=
False
,
output_attentions
:
bool
=
False
,
)
->
Tuple
[
paddle
.
Tensor
,
Optional
[
)
->
Tuple
[
paddle
.
Tensor
,
Optional
[
paddle
.
Tensor
],
Optional
[
Tuple
[
paddle
.
Tensor
]]]:
paddle
.
Tensor
],
Optional
[
Tuple
[
paddle
.
Tensor
]]]:
"""Input shape: Batch x Time x Channel"""
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# if key_value_states are provided this layer is used as a cross-attention layer
...
@@ -455,7 +456,8 @@ class Wav2Vec2Attention(nn.Layer):
...
@@ -455,7 +456,8 @@ class Wav2Vec2Attention(nn.Layer):
key_states
=
self
.
_shape
(
self
.
k_proj
(
hidden_states
),
-
1
,
bsz
)
key_states
=
self
.
_shape
(
self
.
k_proj
(
hidden_states
),
-
1
,
bsz
)
value_states
=
self
.
_shape
(
self
.
v_proj
(
hidden_states
),
-
1
,
bsz
)
value_states
=
self
.
_shape
(
self
.
v_proj
(
hidden_states
),
-
1
,
bsz
)
key_states
=
paddle
.
concat
([
past_key_value
[
0
],
key_states
],
axis
=
2
)
key_states
=
paddle
.
concat
([
past_key_value
[
0
],
key_states
],
axis
=
2
)
value_states
=
paddle
.
concat
([
past_key_value
[
1
],
value_states
],
axis
=
2
)
value_states
=
paddle
.
concat
(
[
past_key_value
[
1
],
value_states
],
axis
=
2
)
else
:
else
:
# self_attention
# self_attention
key_states
=
self
.
_shape
(
self
.
k_proj
(
hidden_states
),
-
1
,
bsz
)
key_states
=
self
.
_shape
(
self
.
k_proj
(
hidden_states
),
-
1
,
bsz
)
...
@@ -472,60 +474,68 @@ class Wav2Vec2Attention(nn.Layer):
...
@@ -472,60 +474,68 @@ class Wav2Vec2Attention(nn.Layer):
past_key_value
=
(
key_states
,
value_states
)
past_key_value
=
(
key_states
,
value_states
)
proj_shape
=
(
bsz
*
self
.
num_heads
,
-
1
,
self
.
head_dim
)
proj_shape
=
(
bsz
*
self
.
num_heads
,
-
1
,
self
.
head_dim
)
query_states
=
self
.
_shape
(
query_states
,
tgt_len
,
bsz
).
reshape
(
proj_shape
)
query_states
=
self
.
_shape
(
query_states
,
tgt_len
,
bsz
).
reshape
(
proj_shape
)
key_states
=
key_states
.
reshape
(
proj_shape
)
key_states
=
key_states
.
reshape
(
proj_shape
)
value_states
=
value_states
.
reshape
(
proj_shape
)
value_states
=
value_states
.
reshape
(
proj_shape
)
src_len
=
key_states
.
shape
[
1
]
src_len
=
key_states
.
shape
[
1
]
attn_weights
=
paddle
.
bmm
(
query_states
,
key_states
.
transpose
([
0
,
2
,
1
]))
attn_weights
=
paddle
.
bmm
(
query_states
,
key_states
.
transpose
([
0
,
2
,
1
]))
if
attn_weights
.
shape
!=
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
]:
if
attn_weights
.
shape
!=
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
]:
raise
ValueError
(
raise
ValueError
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is"
f
"
{
attn_weights
.
shape
}
"
f
"
{
attn_weights
.
shape
}
"
)
)
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
if
attention_mask
.
shape
!=
[
bsz
,
1
,
tgt_len
,
src_len
]:
if
attention_mask
.
shape
!=
[
bsz
,
1
,
tgt_len
,
src_len
]:
raise
ValueError
(
raise
ValueError
(
f
"Attention mask should be of size
{
[
bsz
,
1
,
tgt_len
,
src_len
]
}
, but is
{
attention_mask
.
shape
}
"
f
"Attention mask should be of size
{
[
bsz
,
1
,
tgt_len
,
src_len
]
}
, but is
{
attention_mask
.
shape
}
"
)
)
attn_weights
=
attn_weights
.
reshape
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
+
attention_mask
attn_weights
=
attn_weights
.
reshape
(
bsz
,
self
.
num_heads
,
tgt_len
,
attn_weights
=
attn_weights
.
reshape
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
src_len
)
+
attention_mask
attn_weights
=
attn_weights
.
reshape
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
,
axis
=-
1
)
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
,
axis
=-
1
)
if
layer_head_mask
is
not
None
:
if
layer_head_mask
is
not
None
:
if
layer_head_mask
.
shape
!=
[
self
.
num_heads
,]:
if
layer_head_mask
.
shape
!=
[
self
.
num_heads
,
]:
raise
ValueError
(
raise
ValueError
(
f
"Head mask for a single layer should be of size
{
[
self
.
num_heads
,]
}
, but is"
f
"Head mask for a single layer should be of size
{
[
self
.
num_heads
,]
}
, but is"
f
"
{
layer_head_mask
.
shape
}
"
f
"
{
layer_head_mask
.
shape
}
"
)
)
attn_weights
=
layer_head_mask
.
reshape
(
attn_weights
=
layer_head_mask
.
reshape
((
1
,
-
1
,
1
,
1
))
*
attn_weights
.
reshape
((
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
(
1
,
-
1
,
1
,
1
))
*
attn_weights
.
reshape
(
attn_weights
=
attn_weights
.
reshape
((
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
attn_weights
.
reshape
(
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
if
output_attentions
:
if
output_attentions
:
# this operation is a bit awkward, but it's required to
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
# twice and have to be reused in the following
attn_weights_reshaped
=
attn_weights
.
reshape
((
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights_reshaped
=
attn_weights
.
reshape
(
attn_weights
=
attn_weights_reshaped
.
reshape
((
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
))
attn_weights
=
attn_weights_reshaped
.
reshape
(
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
))
else
:
else
:
attn_weights_reshaped
=
None
attn_weights_reshaped
=
None
attn_probs
=
nn
.
functional
.
dropout
(
attn_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_probs
=
nn
.
functional
.
dropout
(
attn_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_output
=
paddle
.
bmm
(
attn_probs
,
value_states
)
attn_output
=
paddle
.
bmm
(
attn_probs
,
value_states
)
if
attn_output
.
shape
!=
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
]:
if
attn_output
.
shape
!=
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
]:
raise
ValueError
(
raise
ValueError
(
f
"`attn_output` should be of size
{
[
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
]
}
, but is"
f
"`attn_output` should be of size
{
[
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
]
}
, but is"
f
"
{
attn_output
.
shape
}
"
f
"
{
attn_output
.
shape
}
"
)
)
attn_output
=
attn_output
.
reshape
((
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
))
attn_output
=
attn_output
.
reshape
(
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
))
attn_output
=
attn_output
.
transpose
([
0
,
2
,
1
,
3
])
attn_output
=
attn_output
.
transpose
([
0
,
2
,
1
,
3
])
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
...
@@ -542,13 +552,15 @@ class Wav2Vec2FeedForward(nn.Layer):
...
@@ -542,13 +552,15 @@ class Wav2Vec2FeedForward(nn.Layer):
super
().
__init__
()
super
().
__init__
()
self
.
intermediate_dropout
=
nn
.
Dropout
(
config
.
activation_dropout
)
self
.
intermediate_dropout
=
nn
.
Dropout
(
config
.
activation_dropout
)
self
.
intermediate_dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
intermediate_dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
if
isinstance
(
config
.
hidden_act
,
str
):
if
isinstance
(
config
.
hidden_act
,
str
):
self
.
intermediate_act_fn
=
ACT2FN
[
config
.
hidden_act
]
self
.
intermediate_act_fn
=
ACT2FN
[
config
.
hidden_act
]
else
:
else
:
self
.
intermediate_act_fn
=
config
.
hidden_act
self
.
intermediate_act_fn
=
config
.
hidden_act
self
.
output_dense
=
nn
.
Linear
(
config
.
intermediate_size
,
config
.
hidden_size
)
self
.
output_dense
=
nn
.
Linear
(
config
.
intermediate_size
,
config
.
hidden_size
)
self
.
output_dropout
=
nn
.
Dropout
(
config
.
hidden_dropout
)
self
.
output_dropout
=
nn
.
Dropout
(
config
.
hidden_dropout
)
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
...
@@ -568,18 +580,23 @@ class Wav2Vec2EncoderLayer(nn.Layer):
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False, )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, epsilon=config.layer_norm_eps)
        self.feed_forward = Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size, epsilon=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions)
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
...
@@ -587,10 +604,10 @@ class Wav2Vec2EncoderLayer(nn.Layer):
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (attn_weights, )

        return outputs
...
@@ -602,27 +619,33 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Layer):
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False, )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, epsilon=config.layer_norm_eps)
        self.feed_forward = Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size, epsilon=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions)
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(
            self.final_layer_norm(hidden_states))

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (attn_weights, )

        return outputs
...
@@ -632,33 +655,38 @@ class Wav2Vec2Encoder(nn.Layer):
        super().__init__()
        self.config = config
        self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, epsilon=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.LayerList([
            Wav2Vec2EncoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.gradient_checkpointing = False

    def forward(
            self,
            hidden_states,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True, ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(
                1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

            # extend attention_mask
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
                dtype=hidden_states.dtype)
            attention_mask = attention_mask * np.iinfo(np.float32).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0], 1, attention_mask.shape[-1],
                attention_mask.shape[-1])

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
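The mask extension above converts a 0/1 padding mask into an additive bias on the attention scores: real positions contribute 0, padded positions a huge negative number that vanishes under softmax. One caveat when sketching this step standalone: `np.iinfo` is defined only for integer dtypes, so the minimal version below uses `np.finfo` instead; everything else is an assumption-free restatement of the lines above:

import numpy as np
import paddle

# 0/1 padding mask for 2 sequences of length 4 (1 = real token).
attention_mask = paddle.to_tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype='float32')

# Invert and scale: padded positions become a large negative bias.
bias = (1.0 - attention_mask[:, None, None, :]) * np.finfo(np.float32).min
print(bias.shape)  # [2, 1, 1, 4], broadcastable against attention scores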
...
@@ -669,13 +697,14 @@ class Wav2Vec2Encoder(nn.Layer):
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = np.random.uniform(0, 1)

            skip_the_layer = True if self.training and (
                dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer:  # or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    # create gradient checkpointing function
...
@@ -686,26 +715,30 @@ class Wav2Vec2Encoder(nn.Layer):
                        return custom_forward
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        output_attentions=output_attentions)
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1], )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions, )
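LayerDrop (the arXiv link above) randomly skips whole encoder layers during training, which makes the network robust to dropping layers at inference. A self-contained sketch of just the control flow, detached from this model:

import numpy as np

def run_layers(layers, x, layerdrop=0.1, training=True):
    """Apply each layer in turn, randomly skipping some at train time (LayerDrop)."""
    for layer in layers:
        if training and np.random.uniform(0, 1) < layerdrop:
            continue  # skip this layer entirely
        x = layer(x)
    return x

print(run_layers([lambda v: v + 1] * 4, 0, layerdrop=0.5))  # somewhere between 0 and 4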


class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
...
@@ -713,35 +746,39 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
        super().__init__()
        self.config = config
        self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, epsilon=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.LayerList([
            Wav2Vec2EncoderLayerStableLayerNorm(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.gradient_checkpointing = False

    def forward(
            self,
            hidden_states,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True, ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens are not attended to
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat_interleave(
                hidden_states.shape[2], axis=2)
            hidden_states[~expand_attention_mask] = 0

            # extend attention_mask
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
                dtype=hidden_states.dtype)
            attention_mask = attention_mask * np.iinfo(np.float32).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0], 1, attention_mask.shape[-1],
                attention_mask.shape[-1])

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
position_embeddings
...
@@ -749,13 +786,14 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
...
@@ -749,13 +786,14 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
for
layer
in
self
.
layers
:
for
layer
in
self
.
layers
:
if
output_hidden_states
:
if
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
hidden_states
,)
all_hidden_states
=
all_hidden_states
+
(
hidden_states
,
)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability
=
np
.
random
.
uniform
(
0
,
1
)
dropout_probability
=
np
.
random
.
uniform
(
0
,
1
)
skip_the_layer
=
True
if
self
.
training
and
(
dropout_probability
<
self
.
config
.
layerdrop
)
else
False
skip_the_layer
=
True
if
self
.
training
and
(
if
not
skip_the_layer
:
# or deepspeed_zero3_is_enabled:
dropout_probability
<
self
.
config
.
layerdrop
)
else
False
if
not
skip_the_layer
:
# or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
# under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if
self
.
gradient_checkpointing
and
self
.
training
:
if
self
.
gradient_checkpointing
and
self
.
training
:
...
@@ -767,28 +805,32 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
                        return custom_forward
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        output_attentions=output_attentions)
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1], )

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions, )


class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
...
@@ -810,9 +852,13 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
        # storage for codebook variables (codewords)
        self.codevectors = paddle.static.create_parameter(
            shape=[
                1, self.num_groups * self.num_vars,
                config.codevector_dim // self.num_groups
            ],
            dtype='float32')
        self.weight_proj = nn.Linear(config.conv_dim[-1],
                                     self.num_groups * self.num_vars)

        # can be decayed for training
        self.temperature = 2
...
@@ -826,7 +872,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
        else:
            marginal_probs = probs.mean(dim=0)

        perplexity = paddle.exp(-paddle.sum(
            marginal_probs * paddle.log(marginal_probs + 1e-7), dim=-1)).sum()
        return perplexity

    def forward(self, hidden_states, mask_time_indices=None):
...
@@ -834,35 +881,45 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.reshape(
            (batch_size * sequence_length * self.num_groups, -1))

        if self.training:
            # sample code vector probs via gumbel in differentiateable way
            codevector_probs = nn.functional.gumbel_softmax(
                hidden_states.float(), tau=self.temperature,
                hard=True).type_as(hidden_states)

            # compute perplexity
            codevector_soft_dist = paddle.softmax(
                hidden_states.reshape((batch_size * sequence_length,
                                       self.num_groups, -1)).float(),
                axis=-1)
            perplexity = self._compute_perplexity(codevector_soft_dist,
                                                  mask_time_indices)
        else:
            # take argmax in non-differentiable way
            # comptute hard codevector distribution (one hot)
            codevector_idx = hidden_states.argmax(dim=-1)
            codevector_probs = hidden_states.new_zeros(
                *hidden_states.shape).scatter_(-1,
                                               codevector_idx.reshape((-1, 1)),
                                               1.0)
            codevector_probs = codevector_probs.reshape(
                (batch_size * sequence_length, self.num_groups, -1))
            perplexity = self._compute_perplexity(codevector_probs,
                                                  mask_time_indices)

        codevector_probs = codevector_probs.reshape(
            (batch_size * sequence_length, -1))
        # use probs to retrieve codevectors
        codevectors_per_group = codevector_probs.unsqueeze(
            -1) * self.codevectors
        codevectors = codevectors_per_group.reshape(
            (batch_size * sequence_length, self.num_groups, self.num_vars, -1))
        codevectors = codevectors.sum(-2).reshape(
            (batch_size, sequence_length, -1))

        return codevectors, perplexity
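The trainable branch above hinges on Gumbel-softmax with `hard=True`: the forward pass emits a one-hot codeword choice while gradients flow through the soft distribution (the straight-through trick). A small standalone check of that behaviour, using `paddle.nn.functional.gumbel_softmax` with invented shapes:

import paddle
import paddle.nn.functional as F

paddle.seed(0)
logits = paddle.randn([4, 8])  # 4 frames, 8 codebook entries

# hard=True returns one-hot rows in the forward pass.
one_hot = F.gumbel_softmax(logits, temperature=2.0, hard=True)
print(one_hot.sum(axis=-1))  # every row sums to exactly 1.0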
...
@@ -878,7 +935,9 @@ class Wav2Vec2Adapter(nn.Layer):
        else:
            self.proj = self.proj_layer_norm = None

        self.layers = nn.LayerList(
            Wav2Vec2AdapterLayer(config)
            for _ in range(config.num_adapter_layers))
        self.layerdrop = config.layerdrop

    def forward(self, hidden_states):
...
@@ -906,8 +965,7 @@ class Wav2Vec2AdapterLayer(nn.Layer):
            2 * config.output_hidden_size,
            config.adapter_kernel_size,
            stride=config.adapter_stride,
            padding=1, )

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)
...
@@ -916,7 +974,7 @@ class Wav2Vec2AdapterLayer(nn.Layer):
        return hidden_states


class Wav2Vec2Model(nn.Layer):
    def __init__(self, config):
        super().__init__()
        self.config = config
...
@@ -925,9 +983,13 @@ class Wav2Vec2Model(nn.Layer):
        # model only needs masking vector if mask prob is > 0.0
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            # self.masked_spec_embed = nn.Parameter(paddle.Tensor(config.hidden_size).uniform_())
            #self.masked_spec_embed = paddle.uniform([config.hidden_size])
            self.masked_spec_embed = paddle.static.create_parameter(
                shape=[config.hidden_size],
                dtype='float32',
                default_initializer=paddle.nn.initializer.Uniform(
                    low=0, high=1.0))

        if config.do_stable_layer_norm:
            self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
        else:
...
@@ -946,11 +1008,10 @@ class Wav2Vec2Model(nn.Layer):
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
            self,
            hidden_states: paddle.Tensor,
            mask_time_indices: Optional[paddle.Tensor]=None,
            attention_mask: Optional[paddle.Tensor]=None, ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
...
@@ -963,17 +1024,19 @@ class Wav2Vec2Model(nn.Layer):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(
                hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks, )
            mask_time_indices = paddle.to_tensor(
                mask_time_indices, dtype=paddle.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(
                hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
...
@@ -981,27 +1044,28 @@ class Wav2Vec2Model(nn.Layer):
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks, )
            mask_feature_indices = paddle.to_tensor(
                mask_feature_indices, dtype=paddle.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(
                -1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states
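A toy illustration of the time-masking idea behind `_mask_hidden_states`: masked frames are overwritten with a single learned embedding. The spans below are hand-picked for clarity; the real model samples them randomly in `_compute_mask_indices`, and the zero vector stands in for the learned `masked_spec_embed`:

import paddle

hidden_states = paddle.randn([1, 10, 4])  # (batch, time, feature)
mask_embed = paddle.zeros([4])            # stand-in for masked_spec_embed

# Hand-picked time steps to mask, instead of random sampling.
mask_time_indices = paddle.to_tensor(
    [[False, False, True, True, True, False, False, True, True, False]])
hidden_states[mask_time_indices] = mask_embed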
    def forward(
            self,
            input_values: Optional[paddle.Tensor],
            attention_mask: Optional[paddle.Tensor]=None,
            mask_time_indices: Optional[paddle.Tensor]=None,
            output_attentions: Optional[bool]=None,
            output_hidden_states: Optional[bool]=None,
            return_dict: Optional[bool]=None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose([0, 2, 1])
...
@@ -1009,20 +1073,20 @@ class Wav2Vec2Model(nn.Layer):
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False)

        hidden_states, extract_features = self.feature_projection(
            extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states,
            mask_time_indices=mask_time_indices,
            attention_mask=attention_mask)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict, )

        hidden_states = encoder_outputs[0]
...
@@ -1036,20 +1100,21 @@ class Wav2Vec2Model(nn.Layer):
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions, )

    def post_init(self):
        """
        A method executed at the end of each Transformer model initialization, to execute code that needs the model's
        modules properly initialized (such as weight initialization).
        """
        # self.init_weights()
        # self._backward_compatibility_gradient_checkpointing()
        pass


class Wav2Vec2ConfigPure():
    model_type = "wav2vec2"

    def __init__(self, config):
        self.output_attentions = False
        self.output_hidden_states = False
...
@@ -1084,17 +1149,14 @@ class Wav2Vec2ConfigPure():
        self.do_stable_layer_norm = config.do_stable_layer_norm
        self.use_weighted_layer_sum = config.use_weighted_layer_sum

        if ((len(self.conv_stride) != self.num_feat_extract_layers) or
            (len(self.conv_kernel) != self.num_feat_extract_layers) or
            (len(self.conv_dim) != self.num_feat_extract_layers)):
            raise ValueError(
                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`.")

        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
        self.apply_spec_augment = config.apply_spec_augment
...
paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py View file @ 19180d35
...
@@ -7,10 +7,8 @@ Authors
 * Samuele Cornell 2020
 * Sarthak Yadav 2022
"""
-import paddle
 import math
-from packaging import version
 import numpy as np
+import paddle


def blackman_window(window_length, periodic=True):
...
@@ -90,15 +88,14 @@ def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):


def convolve1d(
        waveform,
        kernel,
        padding=0,
        pad_type="constant",
        stride=1,
        groups=1,
        use_fft=False,
        rotation_index=0, ):
    """Use paddle.nn.functional to perform 1d padding and conv.
    Arguments
    ---------
...
@@ -150,8 +147,7 @@ def convolve1d(
    # Padding can be a tuple (left_pad, right_pad) or an int
    if isinstance(padding, tuple):
        waveform = paddle.nn.functional.pad(
            x=waveform, pad=padding, mode=pad_type, data_format='NCL')

    # This approach uses FFT, which is more efficient if the kernel is large
    if use_fft:
...
@@ -165,9 +161,7 @@ def convolve1d(
        # Perform rotation to ensure alignment
        zeros = paddle.zeros(
            [kernel.shape[0], kernel.shape[1], zero_length],
            dtype=kernel.dtype)
        after_index = kernel[..., rotation_index:]
        before_index = kernel[..., :rotation_index]
        kernel = paddle.concat((after_index, zeros, before_index), axis=-1)
...
@@ -185,12 +179,12 @@ def convolve1d(
        weight=kernel,
        stride=stride,
        groups=groups,
        padding=padding if not isinstance(padding, tuple) else 0, )

    # Return time dimension to the second dimension.
    return convolved.transpose([0, 2, 1])


def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
    """Returns a notch filter constructed from a high-pass and low-pass filter.
    (from https://tomroelandts.com/articles/
...
@@ -224,7 +218,8 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
            return paddle.sin(x) / x

        # The zero is at the middle index
        return paddle.concat(
            [_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])

    # Compute a low-pass filter with cutoff frequency notch_freq.
    hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
...
@@ -239,4 +234,3 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):

    # Adding filters creates notch filter
    return (hlpf + hhpf).view(1, -1, 1)
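Taken together, `notch_filter` builds a band-reject kernel and `convolve1d` applies it. A quick sketch of the pair on a random waveform (the shapes follow the `[batch, time, channels]` convention used in this file; the frequency 0.25 is an arbitrary choice in the normalized 0..1 band):

import paddle
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import convolve1d
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import notch_filter

signal = paddle.randn([1, 16000, 1])            # one second at 16 kHz
kernel = notch_filter(0.25, filter_width=101)   # reject a narrow band around 0.25
filtered = convolve1d(signal, kernel, padding=101 // 2)
print(filtered.shape)  # the padding keeps the time length unchanged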
paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py View file @ 19180d35
 import math

 import paddle
 import paddle.nn as nn
-import paddle.nn.functional as F
-from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import (
-    compute_amplitude, convolve1d, notch_filter)
+from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import compute_amplitude
+from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import convolve1d
+from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import notch_filter


class SpeedPerturb(nn.Layer):
    """Slightly speed up or slow down an audio signal.
...
@@ -36,8 +37,10 @@ class SpeedPerturb(nn.Layer):
    """

    def __init__(
            self,
            orig_freq,
            speeds=[90, 100, 110],
            perturb_prob=1.0, ):
        super().__init__()
        self.orig_freq = orig_freq
        self.speeds = speeds
...
@@ -70,14 +73,15 @@ class SpeedPerturb(nn.Layer):
        # Don't perturb (return early) 1-`perturb_prob` portion of the batches
        if paddle.rand([1]) > self.perturb_prob:
            return waveform.clone()

        # Perform a random perturbation
        self.samp_index = paddle.randint(len(self.speeds), shape=(1, ))[0]
        perturbed_waveform = self.resamplers[self.samp_index](waveform)

        return perturbed_waveform
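A minimal usage sketch of `SpeedPerturb`, assuming a `(batch, time)` waveform as in its callers further down; each call draws one of the listed speeds (percentages of the original rate) at random:

import paddle
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpeedPerturb

perturb = SpeedPerturb(orig_freq=16000, speeds=[90, 100, 110])
wav = paddle.sin(paddle.linspace(0, 100, 16000)).unsqueeze(0)  # (1, 16000)
out = perturb(wav)
print(out.shape)  # roughly 16000 * 100 / speed samples for the sampled speed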


class Resample(nn.Layer):
    """This class resamples an audio signal using sinc-based interpolation.
...
@@ -94,9 +98,12 @@ class Resample(nn.Layer):
        Controls the sharpness of the filter, larger numbers result in a
        sharper filter, but they are less efficient. Values from 4 to 10 are allowed.
    """

    def __init__(
            self,
            orig_freq=16000,
            new_freq=16000,
            lowpass_filter_width=6, ):
        super().__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
...
@@ -193,8 +200,7 @@ class Resample(nn.Layer):
        window_size = self.weights.shape[1]
        tot_output_samp = self._output_samples(wave_len)
        resampled_waveform = paddle.zeros(
            (batch_size, num_channels, tot_output_samp))
        # self.weights = self.weights.to(waveforms.device)

        # Check weights are on correct device
...
@@ -222,28 +228,25 @@ class Resample(nn.Layer):
            right_padding = max(0, end_index + 1 - current_wave_len)
            left_padding = max(0, -first_index)
            wave_to_conv = paddle.nn.functional.pad(
                wave_to_conv, (left_padding, right_padding), data_format='NCL')
            conv_wave = paddle.nn.functional.conv1d(
                x=wave_to_conv,
                weight=self.weights[i].repeat(num_channels, 1, 1),
                stride=self.conv_stride,
                groups=num_channels, )

            # we want conv_wave[:, i] to be at
            # output[:, i + n*conv_transpose_stride]
            dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
                conv_wave, eye, stride=self.conv_transpose_stride)

            # pad dilated_conv_wave so it reaches the output length if needed.
            left_padding = i
            previous_padding = left_padding + dilated_conv_wave.shape[-1]
            right_padding = max(0, tot_output_samp - previous_padding)
            dilated_conv_wave = paddle.nn.functional.pad(
                dilated_conv_wave, (left_padding, right_padding),
                data_format='NCL')
            dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]

            resampled_waveform += dilated_conv_wave
...
@@ -326,9 +329,7 @@ class Resample(nn.Layer):
        window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)
        assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
        output_t = paddle.arange(start=0.0, end=self.output_samples)
        output_t /= self.new_freq
        min_t = output_t - window_width
        max_t = output_t + window_width
...
@@ -346,23 +347,16 @@ class Resample(nn.Layer):
        inside_window_indices = delta_t.abs() < (window_width)
        # raised-cosine (Hanning) window with width `window_width`
        weights[inside_window_indices] = 0.5 * (1 + paddle.cos(
            2 * math.pi * lowpass_cutoff / self.lowpass_filter_width *
            delta_t[inside_window_indices]))

        t_eq_zero_indices = delta_t == 0.0
        t_not_eq_zero_indices = ~t_eq_zero_indices

        # sinc filter function
        weights[t_not_eq_zero_indices] *= paddle.sin(
            2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (
                math.pi * delta_t[t_not_eq_zero_indices])

        # limit of the function at t = 0
        weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
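In plain math, each interpolation weight above is a Hann-windowed sinc: w(t) = 0.5 * (1 + cos(2*pi*f_c*t / width)) * sin(2*pi*f_c*t) / (pi*t) for |t| inside the window, and the line just above applies the t = 0 limit, 2*f_c. A tiny numpy check of that limit (the cutoff value is an arbitrary choice for illustration):

import numpy as np

f_c = 0.5  # normalized lowpass cutoff, assumed here
t = np.array([1e-9, 1e-6, 1e-3])
print(np.sin(2 * np.pi * f_c * t) / (np.pi * t))  # tends to 2 * f_c = 1.0 as t -> 0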
...
@@ -405,14 +399,13 @@ class DropFreq(nn.Layer):
    """

    def __init__(
            self,
            drop_freq_low=1e-14,
            drop_freq_high=1,
            drop_count_low=1,
            drop_count_high=2,
            drop_width=0.05,
            drop_prob=1, ):
        super().__init__()
        self.drop_freq_low = drop_freq_low
        self.drop_freq_high = drop_freq_high
...
@@ -443,14 +436,14 @@ class DropFreq(nn.Layer):
        # Pick number of frequencies to drop
        drop_count = paddle.randint(
            low=self.drop_count_low,
            high=self.drop_count_high + 1,
            shape=(1, ), )

        # Pick a frequency to drop
        drop_range = self.drop_freq_high - self.drop_freq_low
        drop_frequency = (
            paddle.rand(drop_count) * drop_range + self.drop_freq_low)

        # Filter parameters
        filter_length = 101
        pad = filter_length // 2
...
@@ -461,8 +454,9 @@ class DropFreq(nn.Layer):
        # Subtract each frequency
        for frequency in drop_frequency:
            notch_kernel = notch_filter(
                frequency,
                filter_length,
                self.drop_width, )
            drop_filter = convolve1d(drop_filter, notch_kernel, pad)

        # Apply filter
...
@@ -471,6 +465,7 @@ class DropFreq(nn.Layer):
        # Remove channels dimension if added
        return dropped_waveform.squeeze(-1)
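A minimal usage sketch of `DropFreq` on a batch of raw waveforms (2-D input gets a channel dimension added and removed internally, as the squeeze above suggests):

import paddle
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import DropFreq

dropper = DropFreq(drop_count_low=1, drop_count_high=2, drop_width=0.05)
signal = paddle.randn([4, 16000])  # (batch, time)
augmented = dropper(signal)
print(augmented.shape)  # unchanged: [4, 16000]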


class DropChunk(nn.Layer):
    """This class drops portions of the input signal.
    Using `DropChunk` as an augmentation strategy helps a models learn to rely
...
@@ -515,16 +510,15 @@ class DropChunk(nn.Layer):
    """

    def __init__(
            self,
            drop_length_low=100,
            drop_length_high=1000,
            drop_count_low=1,
            drop_count_high=10,
            drop_start=0,
            drop_end=None,
            drop_prob=1,
            noise_factor=0.0, ):
        super().__init__()
        self.drop_length_low = drop_length_low
        self.drop_length_high = drop_length_high
...
@@ -580,8 +574,7 @@ class DropChunk(nn.Layer):
        drop_times = paddle.randint(
            low=self.drop_count_low,
            high=self.drop_count_high + 1,
            shape=(batch_size, ), )

        # Iterate batch to set mask
        for i in range(batch_size):
...
@@ -592,8 +585,7 @@ class DropChunk(nn.Layer):
            length = paddle.randint(
                low=self.drop_length_low,
                high=self.drop_length_high + 1,
                shape=(drop_times[i], ), )

            # Compute range of starting locations
            start_min = self.drop_start
...
@@ -608,15 +600,16 @@ class DropChunk(nn.Layer):
            # Pick starting locations
            start = paddle.randint(
                low=start_min,
                high=start_max + 1,
                shape=(drop_times[i], ), )

            end = start + length

            # Update waveform
            if not self.noise_factor:
                for j in range(drop_times[i]):
                    dropped_waveform[i, start[j]:end[j]] = 0.0
            else:
                # Uniform distribution of -2 to +2 * avg amplitude should
                # preserve the average for normalization
...
@@ -625,7 +618,7 @@ class DropChunk(nn.Layer):
                    # zero-center the noise distribution
                    noise_vec = paddle.rand([length[j]])
                    noise_vec = 2 * noise_max * noise_vec - noise_max
                    dropped_waveform[i, start[j]:end[j]] = noise_vec

        return dropped_waveform
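A minimal usage sketch of `DropChunk`; `lengths` holds the relative valid length of each row, so all-ones means no padding:

import paddle
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import DropChunk

dropper = DropChunk(drop_length_low=100, drop_length_high=500,
                    drop_count_low=1, drop_count_high=3)
signal = paddle.randn([4, 16000])           # (batch, time)
augmented = dropper(signal, paddle.ones([4]))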
...
@@ -679,37 +672,33 @@ class TimeDomainSpecAugment(nn.Layer):
    """

    def __init__(
            self,
            perturb_prob=1.0,
            drop_freq_prob=1.0,
            drop_chunk_prob=1.0,
            speeds=[95, 100, 105],
            sample_rate=16000,
            drop_freq_count_low=0,
            drop_freq_count_high=3,
            drop_chunk_count_low=0,
            drop_chunk_count_high=5,
            drop_chunk_length_low=1000,
            drop_chunk_length_high=2000,
            drop_chunk_noise_factor=0, ):
        super().__init__()
        self.speed_perturb = SpeedPerturb(
            perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
        self.drop_freq = DropFreq(
            drop_prob=drop_freq_prob,
            drop_count_low=drop_freq_count_low,
            drop_count_high=drop_freq_count_high, )
        self.drop_chunk = DropChunk(
            drop_prob=drop_chunk_prob,
            drop_count_low=drop_chunk_count_low,
            drop_count_high=drop_chunk_count_high,
            drop_length_low=drop_chunk_length_low,
            drop_length_high=drop_chunk_length_high,
            noise_factor=drop_chunk_noise_factor, )

    def forward(self, waveforms, lengths):
        """Returns the distorted waveforms.
...
@@ -724,4 +713,4 @@ class TimeDomainSpecAugment(nn.Layer):
        waveforms = self.speed_perturb(waveforms)
        waveforms = self.drop_freq(waveforms)
        waveforms = self.drop_chunk(waveforms, lengths)
        return waveforms
\ No newline at end of file
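A minimal end-to-end sketch of `TimeDomainSpecAugment`, chaining the three augmentations above in one call:

import paddle
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment

augment = TimeDomainSpecAugment(sample_rate=16000, speeds=[95, 100, 105])
wavs = paddle.randn([2, 16000])             # (batch, time)
augmented = augment(wavs, paddle.ones([2]))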
paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py View file @ 19180d35
import
numpy
as
np
from
collections
import
defaultdict
import
os
from
typing
import
Dict
from
typing
import
Dict
from
typing
import
List
from
typing
import
List
from
typing
import
Optional
from
typing
import
Tuple
from
typing
import
Tuple
import
paddle
import
paddle
import
paddle.nn
as
nn
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2
import
Wav2Vec2ConfigPure
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2
import
Wav2Vec2ConfigPure
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2
import
Wav2Vec2Model
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2
import
Wav2Vec2Model
from
paddlespeech.s2t.modules.mask
import
make_pad_mask
from
paddlespeech.s2t.utils.utility
import
log_add
from
collections
import
defaultdict
from
paddlespeech.s2t.models.wav2vec2.modules.VanillaNN
import
VanillaNN
from
paddlespeech.s2t.models.wav2vec2.modules.VanillaNN
import
VanillaNN
from
paddlespeech.s2t.modules.ctc
import
CTCDecoderBase
as
CTC
from
paddlespeech.s2t.modules.ctc
import
CTCDecoderBase
as
CTC
from
paddlespeech.s2t.utils.ctc_utils
import
remove_duplicates_and_blank
from
paddlespeech.s2t.utils.ctc_utils
import
remove_duplicates_and_blank
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.utils.utility
import
log_add
class
Wav2vec2ASR
(
nn
.
Layer
):
class
Wav2vec2ASR
(
nn
.
Layer
):
def
__init__
(
self
,
config
:
dict
):
def
__init__
(
self
,
config
:
dict
):
super
().
__init__
()
super
().
__init__
()
wav2vec2_config
=
Wav2Vec2ConfigPure
(
config
)
wav2vec2_config
=
Wav2Vec2ConfigPure
(
config
)
wav2vec2
=
Wav2Vec2Model
(
wav2vec2_config
)
wav2vec2
=
Wav2Vec2Model
(
wav2vec2_config
)
model_dict
=
paddle
.
load
(
config
.
wav2vec2_params_path
)
model_dict
=
paddle
.
load
(
config
.
wav2vec2_params_path
)
...
@@ -36,8 +30,16 @@ class Wav2vec2ASR(nn.Layer):
...
@@ -36,8 +30,16 @@ class Wav2vec2ASR(nn.Layer):
for
parm
in
wav2vec2
.
parameters
():
for
parm
in
wav2vec2
.
parameters
():
parm
.
trainable
=
False
parm
.
trainable
=
False
self
.
wav2vec2
=
wav2vec2
self
.
wav2vec2
=
wav2vec2
self
.
enc
=
VanillaNN
(
input_shape
=
[
None
,
None
,
wav2vec2_config
.
hidden_size
],
activation
=
nn
.
LeakyReLU
,
dnn_blocks
=
config
.
dnn_blocks
,
dnn_neurons
=
config
.
dnn_neurons
)
self
.
enc
=
VanillaNN
(
self
.
ctc
=
CTC
(
odim
=
config
.
output_dim
,
enc_n_units
=
config
.
dnn_neurons
,
blank_id
=
config
.
blank_id
,
dropout_rate
=
config
.
ctc_dropout_rate
,
reduction
=
True
)
input_shape
=
[
None
,
None
,
wav2vec2_config
.
hidden_size
],
activation
=
nn
.
LeakyReLU
,
dnn_blocks
=
config
.
dnn_blocks
,
dnn_neurons
=
config
.
dnn_neurons
)
self
.
ctc
=
CTC
(
odim
=
config
.
output_dim
,
enc_n_units
=
config
.
dnn_neurons
,
blank_id
=
config
.
blank_id
,
dropout_rate
=
config
.
ctc_dropout_rate
,
reduction
=
True
)
    def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
        if self.normalize_wav:
...
@@ -51,25 +53,27 @@ class Wav2vec2ASR(nn.Layer):
        x = self.enc(feats)
        x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
        target_lens = (target_lens_rate * target.shape[1]).round().astype(paddle.int64)
        ctc_loss = self.ctc(x, x_lens, target, target_lens)
        return ctc_loss
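forward receives utterance lengths as rates, i.e. fractions of the padded batch length, and the two lines above convert them back into absolute frame and token counts before the CTC loss. A small numeric sketch of that conversion (values are illustrative):

import paddle

# Batch padded to 200 encoder frames; the two utterances fill
# 100% and 75% of the padded length respectively.
wavs_lens_rate = paddle.to_tensor([1.0, 0.75])
num_frames = 200  # plays the role of x.shape[1] after the encoder

x_lens = (wavs_lens_rate * num_frames).round().astype(paddle.int64)
print(x_lens.numpy())  # [200 150]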
    @paddle.no_grad()
    def decode(self,
               feats: paddle.Tensor,
               text_feature: Dict[str, int],
               decoding_method: str,
               beam_size: int):
        batch_size = feats.shape[0]
-        if decoding_method is 'ctc_prefix_beam_search' and batch_size > 1:
+        if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
            logger.error(
                f'decoding mode {decoding_method} must be running with batch_size == 1'
            )
            logger.error(f"current batch_size is {batch_size}")
            sys.exit(1)
        if decoding_method == 'ctc_greedy_search':
            hyps = self.ctc_greedy_search(feats)
            res = [text_feature.defeaturize(hyp) for hyp in hyps]
...
@@ -79,13 +83,12 @@ class Wav2vec2ASR(nn.Layer):
        # with other batch decoding mode
        elif decoding_method == 'ctc_prefix_beam_search':
            assert feats.shape[0] == 1
            hyp = self.ctc_prefix_beam_search(feats, beam_size)
            res = [text_feature.defeaturize(hyp)]
            res_tokenids = [hyp]
        else:
            raise ValueError(
                f"wav2vec2 not support decoding method: {decoding_method}")
        return res, res_tokenids
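A hedged usage sketch of decode via the greedy path; `model` and `text_feature` are assumed to exist already (built elsewhere from a config and a text featurizer), and the waveform shape follows the (batch, num_samples, 1) layout implied by the `wav[:, :, 0]` slicing below:

import paddle

# `wav` is a mono waveform batch shaped (batch, num_samples, 1);
# random data here only to show the call shape.
wav = paddle.randn([1, 16000, 1])

res, res_tokenids = model.decode(
    wav,
    text_feature=text_feature,
    decoding_method='ctc_greedy_search',
    beam_size=10)  # beam_size is unused by the greedy path
print(res[0])  # decoded transcript string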
...
@@ -94,8 +97,7 @@ class Wav2vec2ASR(nn.Layer):
        model = cls(config)
        return model

    def ctc_greedy_search(self, wav) -> List[List[int]]:
        """ Apply CTC greedy search
        Args:
            speech (paddle.Tensor): (batch, max_len)
...
@@ -104,7 +106,7 @@ class Wav2vec2ASR(nn.Layer):
            List[List[int]]: best path result
        """
        batch_size = wav.shape[0]
        wav = wav[:, :, 0]
        if self.normalize_wav:
            wav = F.layer_norm(wav, wav.shape[1:])
        # Extract wav2vec output
...
@@ -124,7 +126,10 @@ class Wav2vec2ASR(nn.Layer):
        return hyps
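The elided greedy-search body boils down to a frame-wise argmax over the CTC log-probabilities followed by collapsing repeats and stripping blanks. A condensed sketch, assuming `ctc_probs` shaped (batch, time, vocab) and a blank id of 0 as used throughout this file:

import paddle
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank

def greedy_decode(ctc_probs):
    """Frame-wise argmax followed by CTC collapsing (blank id 0)."""
    topk_index = ctc_probs.argmax(axis=2)  # (B, T): best token per frame
    # Collapse repeated tokens and remove blanks to get final paths.
    return [remove_duplicates_and_blank(hyp.tolist()) for hyp in topk_index]

# e.g. on random log-probs shaped (batch=2, time=50, vocab=29):
print(greedy_decode(paddle.randn([2, 50, 29])))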
    def _ctc_prefix_beam_search(
            self,
            wav,
            beam_size,
            blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
        """ CTC prefix beam search inner implementation
        Args:
            speech (paddle.Tensor): (batch, max_len, feat_dim)
...
@@ -142,7 +147,7 @@ class Wav2vec2ASR(nn.Layer):
            paddle.Tensor: encoder output, (1, max_len, encoder_dim),
                it will be used for rescoring in attention rescoring mode
        """
        wav = wav[:, :, 0]
        if self.normalize_wav:
            wav = F.layer_norm(wav, wav.shape[1:])
...
@@ -219,29 +224,5 @@ class Wav2vec2ASR(nn.Layer):
        Returns:
            List[int]: CTC prefix beam search nbest results
        """
        hyps = self._ctc_prefix_beam_search(wav, beam_size)
        return hyps[0][0]
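The elided inner loop of _ctc_prefix_beam_search is the standard CTC prefix beam search recurrence, which is what the defaultdict and log_add imports at the top of the file serve. A condensed, simplified sketch of one time step, tracking per-prefix log-probabilities of ending in blank (pb) and non-blank (pnb); the real method also handles the encoder forward pass and returns the encoder output for rescoring:

import math
from collections import defaultdict

from paddlespeech.s2t.utils.utility import log_add

def prefix_beam_step(logp_t, cur_hyps, beam_size, blank_id=0):
    # logp_t: per-token log-probs at this frame.
    # cur_hyps: list of (prefix_tuple, (pb, pnb)) pairs.
    next_hyps = defaultdict(lambda: (float('-inf'), float('-inf')))
    for prefix, (pb, pnb) in cur_hyps:
        for u, ps in enumerate(logp_t):
            if u == blank_id:                 # blank extends both paths
                nb, nnb = next_hyps[prefix]
                next_hyps[prefix] = (log_add([nb, pb + ps, pnb + ps]), nnb)
            elif prefix and u == prefix[-1]:  # repeat of the last token
                nb, nnb = next_hyps[prefix]
                next_hyps[prefix] = (nb, log_add([nnb, pnb + ps]))
                n_prefix = prefix + (u, )
                nb, nnb = next_hyps[n_prefix]
                next_hyps[n_prefix] = (nb, log_add([nnb, pb + ps]))
            else:                             # new token extends the prefix
                n_prefix = prefix + (u, )
                nb, nnb = next_hyps[n_prefix]
                next_hyps[n_prefix] = (nb, log_add([nnb, pb + ps, pnb + ps]))
    # Keep the beam_size best prefixes by total probability.
    return sorted(next_hyps.items(),
                  key=lambda x: log_add(list(x[1])),
                  reverse=True)[:beam_size]

# One step over a toy 3-token vocabulary (index 0 is the blank):
cur = [(tuple(), (0.0, float('-inf')))]  # empty prefix, log P = 0
print(prefix_beam_step(
    [math.log(0.6), math.log(0.3), math.log(0.1)], cur, beam_size=2))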
-    # @jit.to_static
-    # def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
-    #     """ Export interface for c++ call, apply linear transform and log
-    #         softmax before ctc
-    #     Args:
-    #         xs (paddle.Tensor): encoder output, (B, T, D)
-    #     Returns:
-    #         paddle.Tensor: activation before ctc
-    #     """
-    #     return self.ctc.log_softmax(xs)
-
-    # def _get_data(self):
-    #     data_dir = "data"
-    #     wavs = np.load(os.path.join(data_dir, "wavs.npy"))
-    #     wavs_lens = np.load(os.path.join(data_dir, "wavs_lens.npy"))
-    #     tokens = np.load(os.path.join(data_dir, "tokens.npy"))
-    #     tokens_lens = np.load(os.path.join(data_dir, "tokens_lens.npy"))
-    #     batch = (paddle.to_tensor(wavs), paddle.to_tensor(wavs_lens, dtype='float32'),
-    #              paddle.to_tensor(tokens, dtype='int32'), paddle.to_tensor(tokens_lens, dtype='float32'))
-    #     return batch