Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
19180d35
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
19180d35
编写于
10月 10, 2022
作者:
T
tianhao zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format wav2vec2 demo
上级
6e429f05
变更
15
展开全部
隐藏空白更改
内联
并排
Showing
15 changed file
with
558 addition
and
485 deletion
+558
-485
.flake8
.flake8
+1
-1
examples/librispeech/README.md
examples/librispeech/README.md
+1
-1
paddlespeech/audio/transform/spectrogram.py
paddlespeech/audio/transform/spectrogram.py
+30
-0
paddlespeech/audio/transform/transformation.py
paddlespeech/audio/transform/transformation.py
+1
-0
paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py
paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py
+6
-6
paddlespeech/s2t/exps/wav2vec2/model.py
paddlespeech/s2t/exps/wav2vec2/model.py
+35
-28
paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
+7
-8
paddlespeech/s2t/models/wav2vec2/modules/activations.py
paddlespeech/s2t/models/wav2vec2/modules/activations.py
+14
-9
paddlespeech/s2t/models/wav2vec2/modules/containers.py
paddlespeech/s2t/models/wav2vec2/modules/containers.py
+5
-7
paddlespeech/s2t/models/wav2vec2/modules/linear.py
paddlespeech/s2t/models/wav2vec2/modules/linear.py
+8
-9
paddlespeech/s2t/models/wav2vec2/modules/modeling_outputs.py
paddlespeech/s2t/models/wav2vec2/modules/modeling_outputs.py
+23
-15
paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
...lespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
+301
-239
paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py
...peech/s2t/models/wav2vec2/processing/signal_processing.py
+15
-21
paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
...ech/s2t/models/wav2vec2/processing/speech_augmentation.py
+78
-89
paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+33
-52
未找到文件。
.flake8
浏览文件 @
19180d35
...
...
@@ -33,7 +33,7 @@ filename =
# Specify a list of codes to ignore.
ignore =
W503
E252,E262,E127,E265,E126,E266,E241,E261,E128,E125
E252,E262,E127,E265,E126,E266,E241,E261,E128,E125
,E129
W291,W293,W605
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
...
...
examples/librispeech/README.md
浏览文件 @
19180d35
...
...
@@ -3,7 +3,7 @@
*
asr0 - deepspeech2 Streaming/Non-Streaming
*
asr1 - transformer/conformer Streaming/Non-Streaming
*
asr2 - transformer/conformer Streaming/Non-Streaming with Kaldi feature
*
asr3 - wav2vecASR, ASR model with pre-trained wav2vec2 and CTC
## Data
| Data Subset | Duration in Seconds |
...
...
paddlespeech/audio/transform/spectrogram.py
浏览文件 @
19180d35
...
...
@@ -382,6 +382,36 @@ class LogMelSpectrogramKaldi():
return
mat
class
WavProcess
():
def
__init__
(
self
,
dither
=
0.1
):
"""
Args:
dither (float): Dithering constant
Returns:
"""
self
.
dither
=
dither
def
__call__
(
self
,
x
,
train
):
"""
Args:
x (np.ndarray): shape (Ti,)
train (bool): True, train mode.
Raises:
ValueError: not support (Ti, C)
Returns:
np.ndarray: (T, D)
"""
dither
=
self
.
dither
if
train
else
0.0
if
x
.
ndim
!=
1
:
raise
ValueError
(
"Not support x: [Time, Channel]"
)
waveform
=
np
.
expand_dims
(
x
,
-
1
)
return
waveform
class
LogMelSpectrogramKaldi_decay
():
def
__init__
(
self
,
...
...
paddlespeech/audio/transform/transformation.py
浏览文件 @
19180d35
...
...
@@ -41,6 +41,7 @@ import_alias = dict(
utterance_cmvn
=
"paddlespeech.audio.transform.cmvn:UtteranceCMVN"
,
fbank
=
"paddlespeech.audio.transform.spectrogram:LogMelSpectrogram"
,
spectrogram
=
"paddlespeech.audio.transform.spectrogram:Spectrogram"
,
wav_process
=
"paddlespeech.audio.transform.spectrogram:WavProcess"
,
stft
=
"paddlespeech.audio.transform.spectrogram:Stft"
,
istft
=
"paddlespeech.audio.transform.spectrogram:IStft"
,
stft2fbank
=
"paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram"
,
...
...
paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py
浏览文件 @
19180d35
...
...
@@ -27,6 +27,7 @@ from paddlespeech.s2t.utils.log import Log
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
logger
=
Log
(
__name__
).
getlog
()
class
Wav2vec2Infer
():
def
__init__
(
self
,
config
,
args
):
self
.
args
=
args
...
...
@@ -34,8 +35,7 @@ class Wav2vec2Infer():
self
.
audio_file
=
args
.
audio_file
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
config
.
unit_type
,
vocab
=
config
.
vocab_filepath
)
unit_type
=
config
.
unit_type
,
vocab
=
config
.
vocab_filepath
)
paddle
.
set_device
(
'gpu'
if
self
.
args
.
ngpu
>
0
else
'cpu'
)
# model
...
...
@@ -63,10 +63,10 @@ class Wav2vec2Infer():
xs
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
).
unsqueeze
(
axis
=
0
)
decode_config
=
self
.
config
.
decode
result_transcripts
,
result_tokenids
=
self
.
model
.
decode
(
xs
,
text_feature
=
self
.
text_feature
,
decoding_method
=
decode_config
.
decoding_method
,
beam_size
=
decode_config
.
beam_size
)
xs
,
text_feature
=
self
.
text_feature
,
decoding_method
=
decode_config
.
decoding_method
,
beam_size
=
decode_config
.
beam_size
)
rsl
=
result_transcripts
[
0
]
utt
=
Path
(
self
.
audio_file
).
name
logger
.
info
(
f
"hyp:
{
utt
}
{
rsl
}
"
)
...
...
paddlespeech/s2t/exps/wav2vec2/model.py
浏览文件 @
19180d35
...
...
@@ -18,53 +18,53 @@ import time
from
collections
import
defaultdict
from
collections
import
OrderedDict
from
contextlib
import
nullcontext
from
paddlespeech.s2t.utils
import
mp_tools
import
jsonlines
import
numpy
as
np
import
paddle
from
paddle
import
distributed
as
dist
from
paddlespeech.s2t.frontend.featurizer
import
TextFeaturizer
from
paddlespeech.s2t.io.dataloader
import
BatchDataLoader
from
paddlespeech.s2t.io.dataloader
import
StreamDataLoader
from
paddlespeech.s2t.io.dataloader
import
DataLoaderFactory
from
paddlespeech.s2t.
models.wav2vec2.wav2vec2_ASR
import
Wav2vec2ASR
from
paddlespeech.s2t.
io.dataloader
import
StreamDataLoader
from
paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation
import
TimeDomainSpecAugment
from
paddlespeech.s2t.utils
import
error_rate
from
paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR
import
Wav2vec2ASR
from
paddlespeech.s2t.training.optimizer
import
OptimizerFactory
from
paddlespeech.s2t.training.reporter
import
ObsScope
from
paddlespeech.s2t.training.reporter
import
report
from
paddlespeech.s2t.training.scheduler
import
LRSchedulerFactory
from
paddlespeech.s2t.training.timer
import
Timer
from
paddlespeech.s2t.training.trainer
import
Trainer
from
paddlespeech.s2t.utils
.utility
import
UpdateConfig
from
paddlespeech.s2t.utils
import
error_rate
from
paddlespeech.s2t.utils
import
layer_tools
from
paddlespeech.s2t.utils
import
mp_tools
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
logger
=
Log
(
__name__
).
getlog
()
class
Wav2Vec2ASRTrainer
(
Trainer
):
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
self
.
avg_train_loss
=
0
def
train_batch
(
self
,
batch_index
,
batch
,
msg
):
train_conf
=
self
.
config
start
=
time
.
time
()
# forward
utt
,
wav
,
wavs_lens
,
target
,
target_lens
=
batch
wavs_lens_rate
=
wavs_lens
/
wav
.
shape
[
1
]
wavs_lens_rate
=
wavs_lens
/
wav
.
shape
[
1
]
target_lens_rate
=
target_lens
/
target
.
shape
[
1
]
wav
=
wav
[:,
:,
0
]
wav
=
wav
[:,
:,
0
]
wav
=
self
.
speech_augmentation
(
wav
,
wavs_lens_rate
)
loss
=
self
.
model
(
wav
,
wavs_lens_rate
,
target
,
target_lens_rate
)
# pring(wav, wavs_lens_rate, target, target_lens_rate)
# loss div by `batch_size * accum_grad`
loss
/=
train_conf
.
accum_grad
losses_np
=
{
'loss'
:
float
(
loss
)
*
train_conf
.
accum_grad
}
# loss backward
...
...
@@ -108,15 +108,16 @@ class Wav2Vec2ASRTrainer(Trainer):
def
valid
(
self
):
self
.
model
.
eval
()
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
total_loss
=
0.0
for
i
,
batch
in
enumerate
(
self
.
valid_loader
):
utt
,
wav
,
wavs_lens
,
target
,
target_lens
=
batch
wavs_lens_rate
=
wavs_lens
/
wav
.
shape
[
1
]
wavs_lens_rate
=
wavs_lens
/
wav
.
shape
[
1
]
target_lens_rate
=
target_lens
/
target
.
shape
[
1
]
wav
=
wav
[:,
:,
0
]
wav
=
wav
[:,
:,
0
]
loss
=
self
.
model
(
wav
,
wavs_lens_rate
,
target
,
target_lens_rate
)
if
paddle
.
isfinite
(
loss
):
...
...
@@ -134,7 +135,8 @@ class Wav2Vec2ASRTrainer(Trainer):
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
if
not
self
.
use_streamdata
:
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
', '
.
join
(
'{}: {:>.6f}'
.
format
(
k
,
v
)
for
k
,
v
in
valid_dump
.
items
())
logger
.
info
(
msg
)
...
...
@@ -155,7 +157,8 @@ class Wav2Vec2ASRTrainer(Trainer):
self
.
before_train
()
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
n_epoch
:
with
Timer
(
"Epoch-Train Time Cost: {}"
):
self
.
model
.
train
()
...
...
@@ -223,14 +226,18 @@ class Wav2Vec2ASRTrainer(Trainer):
config
=
self
.
config
.
clone
()
self
.
use_streamdata
=
config
.
get
(
"use_stream_data"
,
False
)
if
self
.
train
:
self
.
train_loader
=
DataLoaderFactory
.
get_dataloader
(
'train'
,
config
,
self
.
args
)
self
.
valid_loader
=
DataLoaderFactory
.
get_dataloader
(
'valid'
,
config
,
self
.
args
)
self
.
train_loader
=
DataLoaderFactory
.
get_dataloader
(
'train'
,
config
,
self
.
args
)
self
.
valid_loader
=
DataLoaderFactory
.
get_dataloader
(
'valid'
,
config
,
self
.
args
)
logger
.
info
(
"Setup train/valid Dataloader!"
)
else
:
decode_batch_size
=
config
.
get
(
'decode'
,
dict
()).
get
(
'decode_batch_size'
,
1
)
self
.
test_loader
=
DataLoaderFactory
.
get_dataloader
(
'test'
,
config
,
self
.
args
)
self
.
align_loader
=
DataLoaderFactory
.
get_dataloader
(
'align'
,
config
,
self
.
args
)
self
.
test_loader
=
DataLoaderFactory
.
get_dataloader
(
'test'
,
config
,
self
.
args
)
self
.
align_loader
=
DataLoaderFactory
.
get_dataloader
(
'align'
,
config
,
self
.
args
)
logger
.
info
(
"Setup test/align Dataloader!"
)
def
setup_model
(
self
):
...
...
@@ -248,7 +255,7 @@ class Wav2Vec2ASRTrainer(Trainer):
model
=
Wav2vec2ASR
.
from_config
(
model_conf
)
if
self
.
parallel
:
model
=
paddle
.
DataParallel
(
model
,
find_unused_parameters
=
True
)
model
=
paddle
.
DataParallel
(
model
,
find_unused_parameters
=
True
)
logger
.
info
(
f
"
{
model
}
"
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
...
...
@@ -312,14 +319,14 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
self
.
text_featurizer
=
TextFeaturizer
(
unit_type
=
config
.
unit_type
,
vocab
=
config
.
vocab_filepath
)
self
.
vocab_list
=
self
.
text_featurizer
.
vocab_list
def
id2token
(
self
,
texts
,
texts_len
):
""" ord() id to chr() chr """
trans
=
[]
for
text
,
n
in
zip
(
texts
,
texts_len
):
n
=
n
.
numpy
().
item
()
ids
=
text
[:
n
]
trans
.
append
(
self
.
text_featurizer
.
defeaturize
(
ids
.
numpy
().
tolist
()))
trans
.
append
(
self
.
text_featurizer
.
defeaturize
(
ids
.
numpy
().
tolist
()))
return
trans
def
compute_metrics
(
self
,
...
...
@@ -337,10 +344,10 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
start_time
=
time
.
time
()
target_transcripts
=
self
.
id2token
(
texts
,
texts_len
)
result_transcripts
,
result_tokenids
=
self
.
model
.
decode
(
audio
,
text_feature
=
self
.
text_featurizer
,
decoding_method
=
decode_cfg
.
decoding_method
,
beam_size
=
decode_cfg
.
beam_size
)
audio
,
text_feature
=
self
.
text_featurizer
,
decoding_method
=
decode_cfg
.
decoding_method
,
beam_size
=
decode_cfg
.
beam_size
)
decode_time
=
time
.
time
()
-
start_time
for
utt
,
target
,
result
,
rec_tids
in
zip
(
...
...
@@ -432,4 +439,4 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
"decode_method"
:
self
.
config
.
decode
.
decoding_method
,
})
f
.
write
(
data
+
'
\n
'
)
\ No newline at end of file
f
.
write
(
data
+
'
\n
'
)
paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
浏览文件 @
19180d35
...
...
@@ -3,6 +3,7 @@ Authors
* Elena Rastorgueva 2020
"""
import
paddle
from
paddlespeech.s2t.models.wav2vec2.modules
import
containers
from
paddlespeech.s2t.models.wav2vec2.modules
import
linear
...
...
@@ -27,12 +28,11 @@ class VanillaNN(containers.Sequential):
"""
def
__init__
(
self
,
input_shape
,
activation
=
paddle
.
nn
.
LeakyReLU
,
dnn_blocks
=
2
,
dnn_neurons
=
512
,
):
self
,
input_shape
,
activation
=
paddle
.
nn
.
LeakyReLU
,
dnn_blocks
=
2
,
dnn_neurons
=
512
,
):
super
().
__init__
(
input_shape
=
input_shape
)
for
block_index
in
range
(
dnn_blocks
):
...
...
@@ -40,6 +40,5 @@ class VanillaNN(containers.Sequential):
linear
.
Linear
,
n_neurons
=
dnn_neurons
,
bias
=
True
,
layer_name
=
"linear"
,
)
layer_name
=
"linear"
,
)
self
.
append
(
activation
(),
layer_name
=
"act"
)
paddlespeech/s2t/models/wav2vec2/modules/activations.py
浏览文件 @
19180d35
...
...
@@ -11,12 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
from
packaging
import
version
from
paddle
import
Tensor
,
nn
from
paddle
import
nn
from
paddle
import
Tensor
from
paddlespeech.s2t.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
...
...
@@ -29,7 +27,9 @@ class NewGELUActivation(nn.Layer):
"""
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
return
0.5
*
input
*
(
1.0
+
paddle
.
tanh
(
math
.
sqrt
(
2.0
/
math
.
pi
)
*
(
input
+
0.044715
*
paddle
.
pow
(
input
,
3.0
))))
return
0.5
*
input
*
(
1.0
+
paddle
.
tanh
(
math
.
sqrt
(
2.0
/
math
.
pi
)
*
(
input
+
0.044715
*
paddle
.
pow
(
input
,
3.0
))))
class
GELUActivation
(
nn
.
Layer
):
...
...
@@ -40,7 +40,7 @@ class GELUActivation(nn.Layer):
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def
__init__
(
self
,
use_gelu_python
:
bool
=
False
):
def
__init__
(
self
,
use_gelu_python
:
bool
=
False
):
super
().
__init__
()
self
.
act
=
nn
.
functional
.
gelu
...
...
@@ -57,7 +57,9 @@ class FastGELUActivation(nn.Layer):
"""
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
return
0.5
*
input
*
(
1.0
+
paddle
.
tanh
(
input
*
0.7978845608
*
(
1.0
+
0.044715
*
input
*
input
)))
return
0.5
*
input
*
(
1.0
+
paddle
.
tanh
(
input
*
0.7978845608
*
(
1.0
+
0.044715
*
input
*
input
)))
class
QuickGELUActivation
(
nn
.
Layer
):
...
...
@@ -84,7 +86,8 @@ class ClippedGELUActivation(nn.Layer):
def
__init__
(
self
,
min
:
float
,
max
:
float
):
if
min
>
max
:
raise
ValueError
(
f
"min should be < max (got min:
{
min
}
, max:
{
max
}
)"
)
raise
ValueError
(
f
"min should be < max (got min:
{
min
}
, max:
{
max
}
)"
)
super
().
__init__
()
self
.
min
=
min
...
...
@@ -161,7 +164,9 @@ def get_activation(activation_string):
if
activation_string
in
ACT2FN
:
return
ACT2FN
[
activation_string
]
else
:
raise
KeyError
(
f
"function
{
activation_string
}
not found in ACT2FN mapping
{
list
(
ACT2FN
.
keys
())
}
"
)
raise
KeyError
(
f
"function
{
activation_string
}
not found in ACT2FN mapping
{
list
(
ACT2FN
.
keys
())
}
"
)
# For backwards compatibility with: from activations import gelu_python
...
...
paddlespeech/s2t/models/wav2vec2/modules/containers.py
浏览文件 @
19180d35
import
paddle
import
inspect
import
logging
import
operator
import
functools
import
paddle
class
Sequential
(
paddle
.
nn
.
LayerDict
):
"""A sequence of modules with potentially inferring shape on construction.
...
...
@@ -98,13 +97,12 @@ class Sequential(paddle.nn.LayerDict):
# Finally, append the layer.
try
:
self
[
layer_name
]
=
layer
# self.add_module(layer_name, layer)
# self.add_module(layer_name, layer)
except
TypeError
:
raise
ValueError
(
"Must pass `input_shape` at initialization and use "
"modules that take `input_shape` to infer shape when "
"using `append()`."
)
"using `append()`."
)
def
get_output_shape
(
self
):
"""Returns expected shape of the output.
...
...
paddlespeech/s2t/models/wav2vec2/modules/linear.py
浏览文件 @
19180d35
...
...
@@ -3,10 +3,10 @@ Authors
* Mirco Ravanelli 2020
* Davide Borra 2021
"""
import
logging
import
paddle
import
paddle.nn
as
nn
from
paddlespeech.s2t.modules
import
align
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -37,13 +37,12 @@ class Linear(paddle.nn.Layer):
"""
def
__init__
(
self
,
n_neurons
,
input_shape
=
None
,
input_size
=
None
,
bias
=
True
,
combine_dims
=
False
,
):
self
,
n_neurons
,
input_shape
=
None
,
input_size
=
None
,
bias
=
True
,
combine_dims
=
False
,
):
super
().
__init__
()
self
.
combine_dims
=
combine_dims
...
...
paddlespeech/s2t/models/wav2vec2/modules/modeling_outputs.py
浏览文件 @
19180d35
...
...
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
from
collections
import
OrderedDict
from
dataclasses
import
dataclass
from
dataclasses
import
fields
from
typing
import
Optional
from
typing
import
Tuple
import
paddle
...
...
@@ -41,10 +41,13 @@ class ModelOutput(OrderedDict):
if
not
len
(
class_fields
):
raise
ValueError
(
f
"
{
self
.
__class__
.
__name__
}
has no fields."
)
if
not
all
(
field
.
default
is
None
for
field
in
class_fields
[
1
:]):
raise
ValueError
(
f
"
{
self
.
__class__
.
__name__
}
should not have more than one required field."
)
raise
ValueError
(
f
"
{
self
.
__class__
.
__name__
}
should not have more than one required field."
)
first_field
=
getattr
(
self
,
class_fields
[
0
].
name
)
other_fields_are_none
=
all
(
getattr
(
self
,
field
.
name
)
is
None
for
field
in
class_fields
[
1
:])
other_fields_are_none
=
all
(
getattr
(
self
,
field
.
name
)
is
None
for
field
in
class_fields
[
1
:])
if
other_fields_are_none
and
not
paddle
.
is_tensor
(
first_field
):
if
isinstance
(
first_field
,
dict
):
...
...
@@ -61,11 +64,9 @@ class ModelOutput(OrderedDict):
# set the associated fields
if
first_field_iterator
:
for
element
in
iterator
:
if
(
not
isinstance
(
element
,
(
list
,
tuple
))
or
not
len
(
element
)
==
2
or
not
isinstance
(
element
[
0
],
str
)
):
if
(
not
isinstance
(
element
,
(
list
,
tuple
))
or
not
len
(
element
)
==
2
or
not
isinstance
(
element
[
0
],
str
)):
break
setattr
(
self
,
element
[
0
],
element
[
1
])
if
element
[
1
]
is
not
None
:
...
...
@@ -79,16 +80,23 @@ class ModelOutput(OrderedDict):
self
[
field
.
name
]
=
v
def
__delitem__
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``__delitem__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
raise
Exception
(
f
"You cannot use ``__delitem__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
setdefault
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``setdefault`` on a
{
self
.
__class__
.
__name__
}
instance."
)
raise
Exception
(
f
"You cannot use ``setdefault`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
pop
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``pop`` on a
{
self
.
__class__
.
__name__
}
instance."
)
raise
Exception
(
f
"You cannot use ``pop`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
update
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``update`` on a
{
self
.
__class__
.
__name__
}
instance."
)
raise
Exception
(
f
"You cannot use ``update`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
__getitem__
(
self
,
k
):
if
isinstance
(
k
,
str
):
...
...
paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
浏览文件 @
19180d35
此差异已折叠。
点击以展开。
paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py
浏览文件 @
19180d35
...
...
@@ -7,10 +7,8 @@ Authors
* Samuele Cornell 2020
* Sarthak Yadav 2022
"""
import
paddle
import
math
from
packaging
import
version
import
numpy
as
np
import
paddle
def
blackman_window
(
window_length
,
periodic
=
True
):
...
...
@@ -90,15 +88,14 @@ def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
def
convolve1d
(
waveform
,
kernel
,
padding
=
0
,
pad_type
=
"constant"
,
stride
=
1
,
groups
=
1
,
use_fft
=
False
,
rotation_index
=
0
,
):
waveform
,
kernel
,
padding
=
0
,
pad_type
=
"constant"
,
stride
=
1
,
groups
=
1
,
use_fft
=
False
,
rotation_index
=
0
,
):
"""Use paddle.nn.functional to perform 1d padding and conv.
Arguments
---------
...
...
@@ -150,8 +147,7 @@ def convolve1d(
# Padding can be a tuple (left_pad, right_pad) or an int
if
isinstance
(
padding
,
tuple
):
waveform
=
paddle
.
nn
.
functional
.
pad
(
x
=
waveform
,
pad
=
padding
,
mode
=
pad_type
,
data_format
=
'NCL'
)
x
=
waveform
,
pad
=
padding
,
mode
=
pad_type
,
data_format
=
'NCL'
)
# This approach uses FFT, which is more efficient if the kernel is large
if
use_fft
:
...
...
@@ -165,9 +161,7 @@ def convolve1d(
# Perform rotation to ensure alignment
zeros
=
paddle
.
zeros
(
[
kernel
.
shape
[
0
],
kernel
.
shape
[
1
],
zero_length
],
dtype
=
kernel
.
dtype
)
[
kernel
.
shape
[
0
],
kernel
.
shape
[
1
],
zero_length
],
dtype
=
kernel
.
dtype
)
after_index
=
kernel
[...,
rotation_index
:]
before_index
=
kernel
[...,
:
rotation_index
]
kernel
=
paddle
.
concat
((
after_index
,
zeros
,
before_index
),
axis
=-
1
)
...
...
@@ -185,12 +179,12 @@ def convolve1d(
weight
=
kernel
,
stride
=
stride
,
groups
=
groups
,
padding
=
padding
if
not
isinstance
(
padding
,
tuple
)
else
0
,
)
padding
=
padding
if
not
isinstance
(
padding
,
tuple
)
else
0
,
)
# Return time dimension to the second dimension.
return
convolved
.
transpose
([
0
,
2
,
1
])
def
notch_filter
(
notch_freq
,
filter_width
=
101
,
notch_width
=
0.05
):
"""Returns a notch filter constructed from a high-pass and low-pass filter.
(from https://tomroelandts.com/articles/
...
...
@@ -224,7 +218,8 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
return
paddle
.
sin
(
x
)
/
x
# The zero is at the middle index
return
paddle
.
concat
([
_sinc
(
x
[:
pad
]),
paddle
.
ones
([
1
]),
_sinc
(
x
[
pad
+
1
:])])
return
paddle
.
concat
(
[
_sinc
(
x
[:
pad
]),
paddle
.
ones
([
1
]),
_sinc
(
x
[
pad
+
1
:])])
# Compute a low-pass filter with cutoff frequency notch_freq.
hlpf
=
sinc
(
3
*
(
notch_freq
-
notch_width
)
*
inputs
)
...
...
@@ -239,4 +234,3 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
# Adding filters creates notch filter
return
(
hlpf
+
hhpf
).
view
(
1
,
-
1
,
1
)
paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
浏览文件 @
19180d35
import
math
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddlespeech.s2t.models.wav2vec2.processing.signal_processing
import
(
compute_amplitude
,
convolve1d
,
notch_filter
)
from
paddlespeech.s2t.models.wav2vec2.processing.signal_processing
import
compute_amplitude
from
paddlespeech.s2t.models.wav2vec2.processing.signal_processing
import
convolve1d
from
paddlespeech.s2t.models.wav2vec2.processing.signal_processing
import
notch_filter
class
SpeedPerturb
(
nn
.
Layer
):
"""Slightly speed up or slow down an audio signal.
...
...
@@ -36,8 +37,10 @@ class SpeedPerturb(nn.Layer):
"""
def
__init__
(
self
,
orig_freq
,
speeds
=
[
90
,
100
,
110
],
perturb_prob
=
1.0
,
):
self
,
orig_freq
,
speeds
=
[
90
,
100
,
110
],
perturb_prob
=
1.0
,
):
super
().
__init__
()
self
.
orig_freq
=
orig_freq
self
.
speeds
=
speeds
...
...
@@ -70,14 +73,15 @@ class SpeedPerturb(nn.Layer):
# Don't perturb (return early) 1-`perturb_prob` portion of the batches
if
paddle
.
rand
([
1
])
>
self
.
perturb_prob
:
return
waveform
.
clone
()
# Perform a random perturbation
self
.
samp_index
=
paddle
.
randint
(
len
(
self
.
speeds
),
shape
=
(
1
,))[
0
]
self
.
samp_index
=
paddle
.
randint
(
len
(
self
.
speeds
),
shape
=
(
1
,
))[
0
]
perturbed_waveform
=
self
.
resamplers
[
self
.
samp_index
](
waveform
)
return
perturbed_waveform
class
Resample
(
nn
.
Layer
):
"""This class resamples an audio signal using sinc-based interpolation.
...
...
@@ -94,9 +98,12 @@ class Resample(nn.Layer):
Controls the sharpness of the filter, larger numbers result in a
sharper filter, but they are less efficient. Values from 4 to 10 are allowed.
"""
def
__init__
(
self
,
orig_freq
=
16000
,
new_freq
=
16000
,
lowpass_filter_width
=
6
,
):
self
,
orig_freq
=
16000
,
new_freq
=
16000
,
lowpass_filter_width
=
6
,
):
super
().
__init__
()
self
.
orig_freq
=
orig_freq
self
.
new_freq
=
new_freq
...
...
@@ -193,8 +200,7 @@ class Resample(nn.Layer):
window_size
=
self
.
weights
.
shape
[
1
]
tot_output_samp
=
self
.
_output_samples
(
wave_len
)
resampled_waveform
=
paddle
.
zeros
(
(
batch_size
,
num_channels
,
tot_output_samp
)
)
(
batch_size
,
num_channels
,
tot_output_samp
))
# self.weights = self.weights.to(waveforms.device)
# Check weights are on correct device
...
...
@@ -222,28 +228,25 @@ class Resample(nn.Layer):
right_padding
=
max
(
0
,
end_index
+
1
-
current_wave_len
)
left_padding
=
max
(
0
,
-
first_index
)
wave_to_conv
=
paddle
.
nn
.
functional
.
pad
(
wave_to_conv
,
(
left_padding
,
right_padding
),
data_format
=
'NCL'
)
wave_to_conv
,
(
left_padding
,
right_padding
),
data_format
=
'NCL'
)
conv_wave
=
paddle
.
nn
.
functional
.
conv1d
(
x
=
wave_to_conv
,
weight
=
self
.
weights
[
i
].
repeat
(
num_channels
,
1
,
1
),
stride
=
self
.
conv_stride
,
groups
=
num_channels
,
)
groups
=
num_channels
,
)
# we want conv_wave[:, i] to be at
# output[:, i + n*conv_transpose_stride]
dilated_conv_wave
=
paddle
.
nn
.
functional
.
conv1d_transpose
(
conv_wave
,
eye
,
stride
=
self
.
conv_transpose_stride
)
conv_wave
,
eye
,
stride
=
self
.
conv_transpose_stride
)
# pad dilated_conv_wave so it reaches the output length if needed.
left_padding
=
i
previous_padding
=
left_padding
+
dilated_conv_wave
.
shape
[
-
1
]
right_padding
=
max
(
0
,
tot_output_samp
-
previous_padding
)
dilated_conv_wave
=
paddle
.
nn
.
functional
.
pad
(
dilated_conv_wave
,
(
left_padding
,
right_padding
),
data_format
=
'NCL'
)
dilated_conv_wave
,
(
left_padding
,
right_padding
),
data_format
=
'NCL'
)
dilated_conv_wave
=
dilated_conv_wave
[...,
:
tot_output_samp
]
resampled_waveform
+=
dilated_conv_wave
...
...
@@ -326,9 +329,7 @@ class Resample(nn.Layer):
window_width
=
self
.
lowpass_filter_width
/
(
2.0
*
lowpass_cutoff
)
assert
lowpass_cutoff
<
min
(
self
.
orig_freq
,
self
.
new_freq
)
/
2
output_t
=
paddle
.
arange
(
start
=
0.0
,
end
=
self
.
output_samples
)
output_t
=
paddle
.
arange
(
start
=
0.0
,
end
=
self
.
output_samples
)
output_t
/=
self
.
new_freq
min_t
=
output_t
-
window_width
max_t
=
output_t
+
window_width
...
...
@@ -346,23 +347,16 @@ class Resample(nn.Layer):
inside_window_indices
=
delta_t
.
abs
()
<
(
window_width
)
# raised-cosine (Hanning) window with width `window_width`
weights
[
inside_window_indices
]
=
0.5
*
(
1
+
paddle
.
cos
(
2
*
math
.
pi
*
lowpass_cutoff
/
self
.
lowpass_filter_width
*
delta_t
[
inside_window_indices
]
)
)
weights
[
inside_window_indices
]
=
0.5
*
(
1
+
paddle
.
cos
(
2
*
math
.
pi
*
lowpass_cutoff
/
self
.
lowpass_filter_width
*
delta_t
[
inside_window_indices
]))
t_eq_zero_indices
=
delta_t
==
0.0
t_not_eq_zero_indices
=
~
t_eq_zero_indices
# sinc filter function
weights
[
t_not_eq_zero_indices
]
*=
paddle
.
sin
(
2
*
math
.
pi
*
lowpass_cutoff
*
delta_t
[
t_not_eq_zero_indices
]
)
/
(
math
.
pi
*
delta_t
[
t_not_eq_zero_indices
])
2
*
math
.
pi
*
lowpass_cutoff
*
delta_t
[
t_not_eq_zero_indices
]
)
/
(
math
.
pi
*
delta_t
[
t_not_eq_zero_indices
])
# limit of the function at t = 0
weights
[
t_eq_zero_indices
]
*=
2
*
lowpass_cutoff
...
...
@@ -405,14 +399,13 @@ class DropFreq(nn.Layer):
"""
def
__init__
(
self
,
drop_freq_low
=
1e-14
,
drop_freq_high
=
1
,
drop_count_low
=
1
,
drop_count_high
=
2
,
drop_width
=
0.05
,
drop_prob
=
1
,
):
self
,
drop_freq_low
=
1e-14
,
drop_freq_high
=
1
,
drop_count_low
=
1
,
drop_count_high
=
2
,
drop_width
=
0.05
,
drop_prob
=
1
,
):
super
().
__init__
()
self
.
drop_freq_low
=
drop_freq_low
self
.
drop_freq_high
=
drop_freq_high
...
...
@@ -443,14 +436,14 @@ class DropFreq(nn.Layer):
# Pick number of frequencies to drop
drop_count
=
paddle
.
randint
(
low
=
self
.
drop_count_low
,
high
=
self
.
drop_count_high
+
1
,
shape
=
(
1
,),
)
low
=
self
.
drop_count_low
,
high
=
self
.
drop_count_high
+
1
,
shape
=
(
1
,
),
)
# Pick a frequency to drop
drop_range
=
self
.
drop_freq_high
-
self
.
drop_freq_low
drop_frequency
=
(
paddle
.
rand
(
drop_count
)
*
drop_range
+
self
.
drop_freq_low
)
paddle
.
rand
(
drop_count
)
*
drop_range
+
self
.
drop_freq_low
)
# Filter parameters
filter_length
=
101
pad
=
filter_length
//
2
...
...
@@ -461,8 +454,9 @@ class DropFreq(nn.Layer):
# Subtract each frequency
for
frequency
in
drop_frequency
:
notch_kernel
=
notch_filter
(
frequency
,
filter_length
,
self
.
drop_width
,
)
frequency
,
filter_length
,
self
.
drop_width
,
)
drop_filter
=
convolve1d
(
drop_filter
,
notch_kernel
,
pad
)
# Apply filter
...
...
@@ -471,6 +465,7 @@ class DropFreq(nn.Layer):
# Remove channels dimension if added
return
dropped_waveform
.
squeeze
(
-
1
)
class
DropChunk
(
nn
.
Layer
):
"""This class drops portions of the input signal.
Using `DropChunk` as an augmentation strategy helps a models learn to rely
...
...
@@ -515,16 +510,15 @@ class DropChunk(nn.Layer):
"""
def
__init__
(
self
,
drop_length_low
=
100
,
drop_length_high
=
1000
,
drop_count_low
=
1
,
drop_count_high
=
10
,
drop_start
=
0
,
drop_end
=
None
,
drop_prob
=
1
,
noise_factor
=
0.0
,
):
self
,
drop_length_low
=
100
,
drop_length_high
=
1000
,
drop_count_low
=
1
,
drop_count_high
=
10
,
drop_start
=
0
,
drop_end
=
None
,
drop_prob
=
1
,
noise_factor
=
0.0
,
):
super
().
__init__
()
self
.
drop_length_low
=
drop_length_low
self
.
drop_length_high
=
drop_length_high
...
...
@@ -580,8 +574,7 @@ class DropChunk(nn.Layer):
drop_times
=
paddle
.
randint
(
low
=
self
.
drop_count_low
,
high
=
self
.
drop_count_high
+
1
,
shape
=
(
batch_size
,),
)
shape
=
(
batch_size
,
),
)
# Iterate batch to set mask
for
i
in
range
(
batch_size
):
...
...
@@ -592,8 +585,7 @@ class DropChunk(nn.Layer):
length
=
paddle
.
randint
(
low
=
self
.
drop_length_low
,
high
=
self
.
drop_length_high
+
1
,
shape
=
(
drop_times
[
i
],),
)
shape
=
(
drop_times
[
i
],
),
)
# Compute range of starting locations
start_min
=
self
.
drop_start
...
...
@@ -608,15 +600,16 @@ class DropChunk(nn.Layer):
# Pick starting locations
start
=
paddle
.
randint
(
low
=
start_min
,
high
=
start_max
+
1
,
shape
=
(
drop_times
[
i
],),
)
low
=
start_min
,
high
=
start_max
+
1
,
shape
=
(
drop_times
[
i
],
),
)
end
=
start
+
length
# Update waveform
if
not
self
.
noise_factor
:
for
j
in
range
(
drop_times
[
i
]):
dropped_waveform
[
i
,
start
[
j
]
:
end
[
j
]]
=
0.0
dropped_waveform
[
i
,
start
[
j
]
:
end
[
j
]]
=
0.0
else
:
# Uniform distribution of -2 to +2 * avg amplitude should
# preserve the average for normalization
...
...
@@ -625,7 +618,7 @@ class DropChunk(nn.Layer):
# zero-center the noise distribution
noise_vec
=
paddle
.
rand
([
length
[
j
]])
noise_vec
=
2
*
noise_max
*
noise_vec
-
noise_max
dropped_waveform
[
i
,
start
[
j
]
:
end
[
j
]]
=
noise_vec
dropped_waveform
[
i
,
start
[
j
]
:
end
[
j
]]
=
noise_vec
return
dropped_waveform
...
...
@@ -679,37 +672,33 @@ class TimeDomainSpecAugment(nn.Layer):
"""
def __init__(
        self,
        perturb_prob=1.0,
        drop_freq_prob=1.0,
        drop_chunk_prob=1.0,
        speeds=None,
        sample_rate=16000,
        drop_freq_count_low=0,
        drop_freq_count_high=3,
        drop_chunk_count_low=0,
        drop_chunk_count_high=5,
        drop_chunk_length_low=1000,
        drop_chunk_length_high=2000,
        drop_chunk_noise_factor=0, ):
    """Build the three time-domain augmentation stages.

    Arguments
    ---------
    perturb_prob : float
        Probability passed to ``SpeedPerturb``.
    drop_freq_prob : float
        Probability passed to ``DropFreq``.
    drop_chunk_prob : float
        Probability passed to ``DropChunk``.
    speeds : list of int, optional
        Speeds (as percentages of the original rate) for ``SpeedPerturb``.
        Defaults to ``[95, 100, 105]`` when ``None``.
    sample_rate : int
        Sampling rate of the input waveforms; used as ``orig_freq`` of
        ``SpeedPerturb``.
    drop_freq_count_low : int
        Lower bound of frequency-drop count, passed to ``DropFreq``.
    drop_freq_count_high : int
        Upper bound of frequency-drop count, passed to ``DropFreq``.
    drop_chunk_count_low : int
        Lower bound of chunk-drop count, passed to ``DropChunk``.
    drop_chunk_count_high : int
        Upper bound of chunk-drop count, passed to ``DropChunk``.
    drop_chunk_length_low : int
        Lower bound of dropped-chunk length, passed to ``DropChunk``.
    drop_chunk_length_high : int
        Upper bound of dropped-chunk length, passed to ``DropChunk``.
    drop_chunk_noise_factor : float
        Noise factor for dropped chunks, passed to ``DropChunk``.
    """
    super().__init__()
    # FIX: the default was the mutable literal `speeds=[95, 100, 105]`,
    # shared across all calls; a None sentinel creates a fresh list per
    # call while keeping the same default behavior.
    if speeds is None:
        speeds = [95, 100, 105]
    self.speed_perturb = SpeedPerturb(
        perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
    self.drop_freq = DropFreq(
        drop_prob=drop_freq_prob,
        drop_count_low=drop_freq_count_low,
        drop_count_high=drop_freq_count_high, )
    self.drop_chunk = DropChunk(
        drop_prob=drop_chunk_prob,
        drop_count_low=drop_chunk_count_low,
        drop_count_high=drop_chunk_count_high,
        drop_length_low=drop_chunk_length_low,
        drop_length_high=drop_chunk_length_high,
        noise_factor=drop_chunk_noise_factor, )
def forward(self, waveforms, lengths):
    """Return the distorted waveforms.

    Applies, in order: speed perturbation, frequency dropping, and
    chunk dropping.

    Arguments
    ---------
    waveforms : paddle.Tensor
        Batch of waveforms to distort.
        # assumes shape `[batch, time]` or `[batch, time, channels]`,
        # matching SpeedPerturb/DropFreq/DropChunk — TODO confirm.
    lengths : paddle.Tensor
        Relative lengths of each waveform; consumed by ``drop_chunk``.

    Returns
    -------
    The augmented waveforms.
    """
    perturbed = self.speed_perturb(waveforms)
    freq_dropped = self.drop_freq(perturbed)
    return self.drop_chunk(freq_dropped, lengths)
paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
浏览文件 @
19180d35
import
numpy
as
np
import
os
from
collections
import
defaultdict
from
typing
import
Dict
from
typing
import
List
from
typing
import
Optional
from
typing
import
Tuple
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2
import
Wav2Vec2ConfigPure
from
paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2
import
Wav2Vec2Model
from
paddlespeech.s2t.modules.mask
import
make_pad_mask
from
paddlespeech.s2t.utils.utility
import
log_add
from
collections
import
defaultdict
from
paddlespeech.s2t.models.wav2vec2.modules.VanillaNN
import
VanillaNN
from
paddlespeech.s2t.modules.ctc
import
CTCDecoderBase
as
CTC
from
paddlespeech.s2t.utils.ctc_utils
import
remove_duplicates_and_blank
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.utils.utility
import
log_add
class
Wav2vec2ASR
(
nn
.
Layer
):
def
__init__
(
self
,
config
:
dict
):
super
().
__init__
()
wav2vec2_config
=
Wav2Vec2ConfigPure
(
config
)
wav2vec2
=
Wav2Vec2Model
(
wav2vec2_config
)
model_dict
=
paddle
.
load
(
config
.
wav2vec2_params_path
)
...
...
@@ -36,8 +30,16 @@ class Wav2vec2ASR(nn.Layer):
for
parm
in
wav2vec2
.
parameters
():
parm
.
trainable
=
False
self
.
wav2vec2
=
wav2vec2
self
.
enc
=
VanillaNN
(
input_shape
=
[
None
,
None
,
wav2vec2_config
.
hidden_size
],
activation
=
nn
.
LeakyReLU
,
dnn_blocks
=
config
.
dnn_blocks
,
dnn_neurons
=
config
.
dnn_neurons
)
self
.
ctc
=
CTC
(
odim
=
config
.
output_dim
,
enc_n_units
=
config
.
dnn_neurons
,
blank_id
=
config
.
blank_id
,
dropout_rate
=
config
.
ctc_dropout_rate
,
reduction
=
True
)
self
.
enc
=
VanillaNN
(
input_shape
=
[
None
,
None
,
wav2vec2_config
.
hidden_size
],
activation
=
nn
.
LeakyReLU
,
dnn_blocks
=
config
.
dnn_blocks
,
dnn_neurons
=
config
.
dnn_neurons
)
self
.
ctc
=
CTC
(
odim
=
config
.
output_dim
,
enc_n_units
=
config
.
dnn_neurons
,
blank_id
=
config
.
blank_id
,
dropout_rate
=
config
.
ctc_dropout_rate
,
reduction
=
True
)
def
forward
(
self
,
wav
,
wavs_lens_rate
,
target
,
target_lens_rate
):
if
self
.
normalize_wav
:
...
...
@@ -51,25 +53,27 @@ class Wav2vec2ASR(nn.Layer):
x
=
self
.
enc
(
feats
)
x_lens
=
(
wavs_lens_rate
*
x
.
shape
[
1
]).
round
().
astype
(
paddle
.
int64
)
target_lens
=
(
target_lens_rate
*
target
.
shape
[
1
]).
round
().
astype
(
paddle
.
int64
)
target_lens
=
(
target_lens_rate
*
target
.
shape
[
1
]).
round
().
astype
(
paddle
.
int64
)
ctc_loss
=
self
.
ctc
(
x
,
x_lens
,
target
,
target_lens
)
return
ctc_loss
@
paddle
.
no_grad
()
def
decode
(
self
,
def
decode
(
self
,
feats
:
paddle
.
Tensor
,
text_feature
:
Dict
[
str
,
int
],
decoding_method
:
str
,
beam_size
:
int
):
batch_size
=
feats
.
shape
[
0
]
if
decoding_method
is
'ctc_prefix_beam_search'
and
batch_size
>
1
:
if
decoding_method
==
'ctc_prefix_beam_search'
and
batch_size
>
1
:
logger
.
error
(
f
'decoding mode
{
decoding_method
}
must be running with batch_size == 1'
)
logger
.
error
(
f
"current batch_size is
{
batch_size
}
"
)
sys
.
exit
(
1
)
if
decoding_method
==
'ctc_greedy_search'
:
hyps
=
self
.
ctc_greedy_search
(
feats
)
res
=
[
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
hyps
]
...
...
@@ -79,13 +83,12 @@ class Wav2vec2ASR(nn.Layer):
# with other batch decoding mode
elif
decoding_method
==
'ctc_prefix_beam_search'
:
assert
feats
.
shape
[
0
]
==
1
hyp
=
self
.
ctc_prefix_beam_search
(
feats
,
beam_size
)
hyp
=
self
.
ctc_prefix_beam_search
(
feats
,
beam_size
)
res
=
[
text_feature
.
defeaturize
(
hyp
)]
res_tokenids
=
[
hyp
]
else
:
raise
ValueError
(
f
"wav2vec2 not support decoding method:
{
decoding_method
}
"
)
raise
ValueError
(
f
"wav2vec2 not support decoding method:
{
decoding_method
}
"
)
return
res
,
res_tokenids
...
...
@@ -94,8 +97,7 @@ class Wav2vec2ASR(nn.Layer):
model
=
cls
(
config
)
return
model
def
ctc_greedy_search
(
self
,
wav
)
->
List
[
List
[
int
]]:
def
ctc_greedy_search
(
self
,
wav
)
->
List
[
List
[
int
]]:
""" Apply CTC greedy search
Args:
speech (paddle.Tensor): (batch, max_len)
...
...
@@ -104,7 +106,7 @@ class Wav2vec2ASR(nn.Layer):
List[List[int]]: best path result
"""
batch_size
=
wav
.
shape
[
0
]
wav
=
wav
[:,
:,
0
]
wav
=
wav
[:,
:,
0
]
if
self
.
normalize_wav
:
wav
=
F
.
layer_norm
(
wav
,
wav
.
shape
[
1
:])
# Extract wav2vec output
...
...
@@ -124,7 +126,10 @@ class Wav2vec2ASR(nn.Layer):
return
hyps
def
_ctc_prefix_beam_search
(
self
,
wav
,
beam_size
,
blank_id
:
int
=
0
,
)
->
Tuple
[
List
[
Tuple
[
int
,
float
]],
paddle
.
Tensor
]:
self
,
wav
,
beam_size
,
blank_id
:
int
=
0
,
)
->
Tuple
[
List
[
Tuple
[
int
,
float
]],
paddle
.
Tensor
]:
""" CTC prefix beam search inner implementation
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
...
...
@@ -142,7 +147,7 @@ class Wav2vec2ASR(nn.Layer):
paddle.Tensor: encoder output, (1, max_len, encoder_dim),
it will be used for rescoring in attention rescoring mode
"""
wav
=
wav
[:,
:,
0
]
wav
=
wav
[:,
:,
0
]
if
self
.
normalize_wav
:
wav
=
F
.
layer_norm
(
wav
,
wav
.
shape
[
1
:])
...
...
@@ -219,29 +224,5 @@ class Wav2vec2ASR(nn.Layer):
Returns:
List[int]: CTC prefix beam search nbest results
"""
hyps
=
self
.
_ctc_prefix_beam_search
(
wav
,
beam_size
)
hyps
=
self
.
_ctc_prefix_beam_search
(
wav
,
beam_size
)
return
hyps
[
0
][
0
]
# @jit.to_static
# def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
# """ Export interface for c++ call, apply linear transform and log
# softmax before ctc
# Args:
# xs (paddle.Tensor): encoder output, (B, T, D)
# Returns:
# paddle.Tensor: activation before ctc
# """
# return self.ctc.log_softmax(xs)
# def _get_data(self):
# data_dir = "data"
# wavs = np.load(os.path.join(data_dir, "wavs.npy"))
# wavs_lens = np.load(os.path.join(data_dir, "wavs_lens.npy"))
# tokens = np.load(os.path.join(data_dir, "tokens.npy"))
# tokens_lens = np.load(os.path.join(data_dir, "tokens_lens.npy"))
# batch = (paddle.to_tensor(wavs), paddle.to_tensor(wavs_lens, dtype='float32'),
# paddle.to_tensor(tokens, dtype='int32'), paddle.to_tensor(tokens_lens, dtype='float32'))
# return batch
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录