Commit c7a7b113
Author: huangyuxin
Date: Jun 24, 2022
Parent: 8f5e6109

support multi-gpu training with webdataset

Showing 21 changed files with 341 additions and 1930 deletions (+341 / -1930)
examples/wenetspeech/asr1/conf/conformer.yaml          +25   -10
paddlespeech/audio/stream_data/__init__.py              +2    -1
paddlespeech/audio/stream_data/filters.py              +31    -8
paddlespeech/audio/stream_data/pipeline.py              +6    -0
paddlespeech/audio/stream_data/shardlists.py            +2    -0
paddlespeech/audio/utils/log.py                         +2    -1
paddlespeech/audio/utils/tensor_utils.py                +0    -3
paddlespeech/s2t/exps/u2/model.py                     +185   -84
paddlespeech/s2t/io/dataloader.py                      +87    -0
paddlespeech/s2t/io/reader.py                           +1    -1
paddlespeech/s2t/transform/__init__.py                  +0   -13
paddlespeech/s2t/transform/add_deltas.py                +0   -54
paddlespeech/s2t/transform/channel_selector.py          +0   -57
paddlespeech/s2t/transform/cmvn.py                      +0  -201
paddlespeech/s2t/transform/functional.py                +0   -86
paddlespeech/s2t/transform/perturb.py                   +0  -471
paddlespeech/s2t/transform/spec_augment.py              +0  -214
paddlespeech/s2t/transform/spectrogram.py               +0  -475
paddlespeech/s2t/transform/transform_interface.py       +0   -35
paddlespeech/s2t/transform/transformation.py            +0  -158
paddlespeech/s2t/transform/wpe.py                       +0   -58
examples/wenetspeech/asr1/conf/conformer.yaml

@@ -50,26 +50,41 @@ test_manifest: data/manifest.test
 ###########################################
 #                Dataloader               #
 ###########################################
-vocab_filepath: data/lang_char/vocab.txt
+use_stream_data: True
 unit_type: 'char'
+vocab_filepath: data/lang_char/vocab.txt
+cmvn_file: data/mean_std.json
 preprocess_config: conf/preprocess.yaml
 spm_model_prefix: ''
 feat_dim: 80
 stride_ms: 10.0
 window_ms: 25.0
+dither: 0.1
 sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
 batch_size: 64
+minlen_in: 10
 maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
+minlen_out: 0
 maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
-minibatches: 0 # for debug
-batch_count: auto
-batch_bins: 0
-batch_frames_in: 0
-batch_frames_out: 0
-batch_frames_inout: 0
-num_workers: 0
-subsampling_factor: 1
+resample_rate: 16000
+shuffle_size: 10000
+sort_size: 500
+num_workers: 4
+prefetch_factor: 100
+dist_sampler: True
 num_encs: 1
+augment_conf:
+    max_w: 80
+    w_inplace: True
+    w_mode: "PIL"
+    max_f: 30
+    num_f_mask: 2
+    f_inplace: True
+    f_replace_with_zero: False
+    max_t: 40
+    num_t_mask: 2
+    t_inplace: True
+    t_replace_with_zero: False
 ###########################################
...
@@ -78,7 +93,7 @@ num_encs: 1
 n_epoch: 240
 accum_grad: 16
 global_grad_clip: 5.0
-log_interval: 100
+log_interval: 1
 checkpoint:
     kbest_n: 50
     latest_n: 5
...
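For orientation, the keys added above are exactly the ones the new StreamDataLoader consumes. A minimal sketch of pulling them out of the YAML the same way the trainer does with config.get (assuming PyYAML is installed; the path below is only illustrative):

import yaml

# Illustrative path; any conformer.yaml containing the keys above works the same way.
with open("examples/wenetspeech/asr1/conf/conformer.yaml") as f:
    config = yaml.safe_load(f)

# The trainer only switches to the stream dataloader when this flag is present and true.
use_stream_data = config.get("use_stream_data", False)

# Keys read by StreamDataLoader (see paddlespeech/s2t/io/dataloader.py below).
stream_keys = ["vocab_filepath", "unit_type", "cmvn_file", "feat_dim",
               "window_ms", "stride_ms", "dither", "minlen_in", "maxlen_in",
               "minlen_out", "maxlen_out", "resample_rate", "shuffle_size",
               "sort_size", "num_workers", "prefetch_factor", "dist_sampler",
               "augment_conf"]
print(use_stream_data, {k: config.get(k) for k in stream_keys})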
paddlespeech/audio/stream_data/__init__.py

@@ -41,7 +41,8 @@ from .filters import (
     spec_aug,
     sort,
     padding,
-    cmvn
+    cmvn,
+    placeholder,
 )

 from webdataset.handlers import (
     ignore_and_continue,
...
paddlespeech/audio/stream_data/filters.py

@@ -758,27 +758,44 @@ def _compute_fbank(source,
 compute_fbank = pipelinefilter(_compute_fbank)

-def _spec_aug(source, num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80):
+def _spec_aug(source,
+              max_w=5,
+              w_inplace=True,
+              w_mode="PIL",
+              max_f=30,
+              num_f_mask=2,
+              f_inplace=True,
+              f_replace_with_zero=False,
+              max_t=40,
+              num_t_mask=2,
+              t_inplace=True,
+              t_replace_with_zero=False, ):
     """ Do spec augmentation
         Inplace operation

         Args:
             source: Iterable[{fname, feat, label}]
-            num_t_mask: number of time mask to apply
+            max_w: max width of time warp
+            w_inplace: whether to inplace the original data while time warping
+            w_mode: time warp mode
             max_f: max width of freq mask
+            num_f_mask: number of freq mask to apply
+            f_inplace: whether to inplace the original data while frequency masking
+            f_replace_with_zero: use zero to mask
             max_t: max width of time mask
-            max_f: max width of freq mask
-            max_w: max width of time warp
+            num_t_mask: number of time mask to apply
+            t_inplace: whether to inplace the original data while time masking
+            t_replace_with_zero: use zero to mask

         Returns
             Iterable[{fname, feat, label}]
     """
     for sample in source:
         x = sample['feat']
         x = x.numpy()
-        x = time_warp(x, max_time_warp=max_w, inplace=True, mode="PIL")
-        x = freq_mask(x, F=max_f, n_mask=num_f_mask, inplace=True, replace_with_zero=False)
-        x = time_mask(x, T=max_t, n_mask=num_t_mask, inplace=True, replace_with_zero=False)
+        x = time_warp(x, max_time_warp=max_w, inplace=w_inplace, mode=w_mode)
+        x = freq_mask(x, F=max_f, n_mask=num_f_mask, inplace=f_inplace,
+                      replace_with_zero=f_replace_with_zero)
+        x = time_mask(x, T=max_t, n_mask=num_t_mask, inplace=t_inplace,
+                      replace_with_zero=t_replace_with_zero)
         sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
         yield sample
...
@@ -910,3 +927,9 @@ def _cmvn(source, cmvn_file):
                label_lengths)

 cmvn = pipelinefilter(_cmvn)
+
+def _placeholder(source):
+    for data in source:
+        yield data
+
+placeholder = pipelinefilter(_placeholder)
\ No newline at end of file
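The new placeholder stage above is the smallest possible pipeline filter: a generator that passes samples through untouched. A custom stage follows the same shape. The sketch below is illustrative only (the frame-clipping logic and the name _clip_frames are made up; only the pipelinefilter wrapper and the {fname, feat, label} sample dict come from the file above, and it assumes a PaddleSpeech checkout with this commit is importable):

from paddlespeech.audio.stream_data.filters import pipelinefilter

def _clip_frames(source, max_frames=2000):
    """Drop the tail of overly long feature matrices (illustrative only)."""
    for sample in source:
        feat = sample['feat']          # feature tensor of shape (T, D)
        if feat.shape[0] > max_frames:
            sample['feat'] = feat[:max_frames]
        yield sample

# Wrapped the same way as spec_aug / cmvn / placeholder above.
clip_frames = pipelinefilter(_clip_frames)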
paddlespeech/audio/stream_data/pipeline.py

@@ -89,6 +89,12 @@ class DataPipeline(IterableDataset, PipelineStage):
     def append(self, f):
         """Append a pipeline stage (modifies the object)."""
         self.pipeline.append(f)
         return self

+    def append_list(self, *args):
+        for arg in args:
+            self.pipeline.append(arg)
+        return self
+
     def compose(self, *args):
         """Append a pipeline stage to a copy of the pipeline and returns the copy."""
...
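append_list is only a convenience over calling append once per stage, which is how StreamDataLoader (below) attaches its whole processing chain in one call. A self-contained sketch of the equivalence using a stand-in class (DemoPipeline is illustrative, not the real DataPipeline):

class DemoPipeline:
    """Stand-in with the same append/append_list contract as DataPipeline."""
    def __init__(self):
        self.pipeline = []

    def append(self, f):
        self.pipeline.append(f)
        return self

    def append_list(self, *args):
        for arg in args:
            self.pipeline.append(arg)
        return self

stage_a = lambda source: (s for s in source)   # trivial pass-through stages
stage_b = lambda source: (s for s in source)

p1 = DemoPipeline().append(stage_a).append(stage_b)
p2 = DemoPipeline().append_list(stage_a, stage_b)
assert p1.pipeline == p2.pipeline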
paddlespeech/audio/stream_data/shardlists.py

@@ -24,6 +24,8 @@ from .filters import pipelinefilter
 from .paddle_utils import IterableDataset
+from ..utils.log import Logger

+logger = Logger(__name__)

 def expand_urls(urls):
     if isinstance(urls, str):
         urllist = urls.split("::")
...
paddlespeech/audio/utils/log.py

@@ -65,6 +65,7 @@ class Logger(object):
     def __init__(self, name: str=None):
         name = 'PaddleAudio' if not name else name
+        self.name = name
         self.logger = logging.getLogger(name)

         for key, conf in log_config.items():
...
@@ -101,7 +102,7 @@ class Logger(object):
         if not self.is_enable:
             return

-        self.logger.log(log_level, msg)
+        self.logger.log(log_level, self.name + " | " + msg)

     @contextlib.contextmanager
     def use_terminator(self, terminator: str):
...
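With the change above, every record emitted through this wrapper is prefixed with the logger's name, which makes interleaved logs from multiple dataloader workers or GPUs easier to attribute. A minimal sketch of the resulting call, assuming the per-level helpers (info, etc.) that Logger builds from log_config; the message text is illustrative:

from paddlespeech.audio.utils.log import Logger

logger = Logger(__name__)
# After this commit the emitted record becomes "<logger name> | shard list loaded".
logger.info("shard list loaded")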
paddlespeech/audio/utils/tensor_utils.py

@@ -93,9 +93,6 @@ def pad_sequence(sequences: List[paddle.Tensor],
     for i, tensor in enumerate(sequences):
         length = tensor.shape[0]
         # use index notation to prevent duplicate references to the tensor
-        logger.info(
-            f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}")
         if batch_first:
             # TODO (Hui Zhang): set_value op not supprot `end==start`
             # TODO (Hui Zhang): set_value op not support int16
...
paddlespeech/s2t/exps/u2/model.py

@@ -26,6 +26,7 @@ from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import BatchDataLoader
+from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.models.u2 import U2Model
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
...
@@ -106,7 +107,8 @@ class U2Trainer(Trainer):
     @paddle.no_grad()
     def valid(self):
         self.model.eval()
-        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
...
@@ -132,7 +134,7 @@ class U2Trainer(Trainer):
                 msg = f"Valid: Rank: {dist.get_rank()}, "
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
-                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                # msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in valid_dump.items())
                 logger.info(msg)
...
@@ -152,7 +154,8 @@ class U2Trainer(Trainer):
         self.before_train()

-        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
...
@@ -170,7 +173,8 @@ class U2Trainer(Trainer):
                     self.train_batch(batch_index, batch, msg)
                     self.after_train_batch()
                     report('iter', batch_index + 1)
-                    report('total', len(self.train_loader))
+                    if not self.use_streamdata:
+                        report('total', len(self.train_loader))
                     report('reader_cost', dataload_time)
                     observation['batch_cost'] = observation['reader_cost'] + observation['step_cost']
...
@@ -218,92 +222,188 @@ class U2Trainer(Trainer):
(this hunk rewrites setup_dataloader; the resulting method is shown in full, and the
BatchDataLoader construction that used to be the unconditional body now lives in the
else branches)

    def setup_dataloader(self):
        config = self.config.clone()
        self.use_streamdata = config.get("use_stream_data", False)
        if self.train:
            if self.use_streamdata:
                self.train_loader = StreamDataLoader(
                    manifest_file=config.train_manifest,
                    train_mode=True,
                    unit_type=config.unit_type,
                    batch_size=config.batch_size,
                    num_mel_bins=config.feat_dim,
                    frame_length=config.window_ms,
                    frame_shift=config.stride_ms,
                    dither=config.dither,
                    minlen_in=config.minlen_in,
                    maxlen_in=config.maxlen_in,
                    minlen_out=config.minlen_out,
                    maxlen_out=config.maxlen_out,
                    resample_rate=config.resample_rate,
                    augment_conf=config.augment_conf,  # dict
                    shuffle_size=config.shuffle_size,
                    sort_size=config.sort_size,
                    n_iter_processes=config.num_workers,
                    prefetch_factor=config.prefetch_factor,
                    dist_sampler=config.get('dist_sampler', False),
                    cmvn_file=config.cmvn_file,
                    vocab_filepath=config.vocab_filepath, )
                self.valid_loader = StreamDataLoader(
                    manifest_file=config.dev_manifest,
                    train_mode=False,
                    unit_type=config.unit_type,
                    batch_size=config.batch_size,
                    num_mel_bins=config.feat_dim,
                    frame_length=config.window_ms,
                    frame_shift=config.stride_ms,
                    dither=config.dither,
                    minlen_in=config.minlen_in,
                    maxlen_in=config.maxlen_in,
                    minlen_out=config.minlen_out,
                    maxlen_out=config.maxlen_out,
                    resample_rate=config.resample_rate,
                    augment_conf=config.augment_conf,  # dict
                    shuffle_size=config.shuffle_size,
                    sort_size=config.sort_size,
                    n_iter_processes=config.num_workers,
                    prefetch_factor=config.prefetch_factor,
                    dist_sampler=config.get('dist_sampler', False),
                    cmvn_file=config.cmvn_file,
                    vocab_filepath=config.vocab_filepath, )
            else:
                # train/valid dataset, return token ids
                self.train_loader = BatchDataLoader(
                    json_file=config.train_manifest,
                    train_mode=True,
                    sortagrad=config.sortagrad,
                    batch_size=config.batch_size,
                    maxlen_in=config.maxlen_in,
                    maxlen_out=config.maxlen_out,
                    minibatches=config.minibatches,
                    mini_batch_size=self.args.ngpu,
                    batch_count=config.batch_count,
                    batch_bins=config.batch_bins,
                    batch_frames_in=config.batch_frames_in,
                    batch_frames_out=config.batch_frames_out,
                    batch_frames_inout=config.batch_frames_inout,
                    preprocess_conf=config.preprocess_config,
                    n_iter_processes=config.num_workers,
                    subsampling_factor=1,
                    num_encs=1,
                    dist_sampler=config.get('dist_sampler', False),
                    shortest_first=False)
                self.valid_loader = BatchDataLoader(
                    json_file=config.dev_manifest,
                    train_mode=False,
                    sortagrad=False,
                    batch_size=config.batch_size,
                    maxlen_in=float('inf'),
                    maxlen_out=float('inf'),
                    minibatches=0,
                    mini_batch_size=self.args.ngpu,
                    batch_count='auto',
                    batch_bins=0,
                    batch_frames_in=0,
                    batch_frames_out=0,
                    batch_frames_inout=0,
                    preprocess_conf=config.preprocess_config,
                    n_iter_processes=config.num_workers,
                    subsampling_factor=1,
                    num_encs=1,
                    dist_sampler=config.get('dist_sampler', False),
                    shortest_first=False)
            logger.info("Setup train/valid Dataloader!")
        else:
            decode_batch_size = config.get('decode', dict()).get(
                'decode_batch_size', 1)
            # test dataset, return raw text
            if self.use_streamdata:
                self.test_loader = StreamDataLoader(
                    manifest_file=config.test_manifest,
                    train_mode=False,
                    unit_type=config.unit_type,
                    batch_size=config.batch_size,
                    num_mel_bins=config.feat_dim,
                    frame_length=config.window_ms,
                    frame_shift=config.stride_ms,
                    dither=0.0,
                    minlen_in=0.0,
                    maxlen_in=float('inf'),
                    minlen_out=0,
                    maxlen_out=float('inf'),
                    resample_rate=config.resample_rate,
                    augment_conf=config.augment_conf,  # dict
                    shuffle_size=config.shuffle_size,
                    sort_size=config.sort_size,
                    n_iter_processes=config.num_workers,
                    prefetch_factor=config.prefetch_factor,
                    dist_sampler=config.get('dist_sampler', False),
                    cmvn_file=config.cmvn_file,
                    vocab_filepath=config.vocab_filepath, )
                self.align_loader = StreamDataLoader(
                    manifest_file=config.test_manifest,
                    train_mode=False,
                    unit_type=config.unit_type,
                    batch_size=config.batch_size,
                    num_mel_bins=config.feat_dim,
                    frame_length=config.window_ms,
                    frame_shift=config.stride_ms,
                    dither=0.0,
                    minlen_in=0.0,
                    maxlen_in=float('inf'),
                    minlen_out=0,
                    maxlen_out=float('inf'),
                    resample_rate=config.resample_rate,
                    augment_conf=config.augment_conf,  # dict
                    shuffle_size=config.shuffle_size,
                    sort_size=config.sort_size,
                    n_iter_processes=config.num_workers,
                    prefetch_factor=config.prefetch_factor,
                    dist_sampler=config.get('dist_sampler', False),
                    cmvn_file=config.cmvn_file,
                    vocab_filepath=config.vocab_filepath, )
            else:
                self.test_loader = BatchDataLoader(
                    json_file=config.test_manifest,
                    train_mode=False,
                    sortagrad=False,
                    batch_size=decode_batch_size,
                    maxlen_in=float('inf'),
                    maxlen_out=float('inf'),
                    minibatches=0,
                    mini_batch_size=1,
                    batch_count='auto',
                    batch_bins=0,
                    batch_frames_in=0,
                    batch_frames_out=0,
                    batch_frames_inout=0,
                    preprocess_conf=config.preprocess_config,
                    n_iter_processes=1,
                    subsampling_factor=1,
                    num_encs=1)
                self.align_loader = BatchDataLoader(
                    json_file=config.test_manifest,
                    train_mode=False,
                    sortagrad=False,
                    batch_size=decode_batch_size,
                    maxlen_in=float('inf'),
                    maxlen_out=float('inf'),
                    minibatches=0,
                    mini_batch_size=1,
                    batch_count='auto',
                    batch_bins=0,
                    batch_frames_in=0,
                    batch_frames_out=0,
                    batch_frames_inout=0,
                    preprocess_conf=config.preprocess_config,
                    n_iter_processes=1,
                    subsampling_factor=1,
                    num_encs=1)
            logger.info("Setup test/align Dataloader!")

    def setup_model(self):
...
@@ -452,7 +552,8 @@ class U2Tester(U2Trainer):
     def test(self):
         assert self.args.result_file
         self.model.eval()
-        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")

         stride_ms = self.config.stride_ms
         error_rate_type = None
...
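The multi-GPU behaviour hinges on the dist_sampler flag passed into the loaders above: when it is true, the stream pipeline prepends split_by_node (each distributed rank sees only its own subset of tar shards) in addition to split_by_worker (each dataloader worker sees a further subset). A back-of-the-envelope sketch of that kind of split, written out directly rather than through webdataset (shard names are made up, and the round-robin policy here is only an illustration of the idea; the real operators consult the distributed rank and worker info at runtime):

# Illustrative only: how N shards end up spread over ranks and workers.
shards = [f"train_{i:03d}.tar" for i in range(8)]   # hypothetical shard names

def round_robin(items, index, total):
    return items[index::total]

world_size, num_workers = 2, 2
for rank in range(world_size):
    per_rank = round_robin(shards, rank, world_size)        # split_by_node analogue
    for worker in range(num_workers):
        per_worker = round_robin(per_rank, worker, num_workers)  # split_by_worker analogue
        print(f"rank {rank} worker {worker}: {per_worker}")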
paddlespeech/s2t/io/dataloader.py

@@ -28,6 +28,9 @@ from paddlespeech.s2t.io.dataset import TransformDataset
 from paddlespeech.s2t.io.reader import LoadInputsAndTargets
 from paddlespeech.s2t.utils.log import Log
+import paddlespeech.audio.stream_data as stream_data
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer

 __all__ = ["BatchDataLoader"]

 logger = Log(__name__).getlog()
...
@@ -56,6 +59,90 @@ def batch_collate(x):
     """
     return x[0]

(the whole StreamDataLoader class below is newly added)

class StreamDataLoader():
    def __init__(self,
                 manifest_file: str,
                 train_mode: bool,
                 unit_type: str='char',
                 batch_size: int=0,
                 num_mel_bins=80,
                 frame_length=25,
                 frame_shift=10,
                 dither=0.0,
                 minlen_in: float=0.0,
                 maxlen_in: float=float('inf'),
                 minlen_out: float=0.0,
                 maxlen_out: float=float('inf'),
                 resample_rate: int=16000,
                 augment_conf: dict=None,
                 shuffle_size: int=10000,
                 sort_size: int=1000,
                 n_iter_processes: int=1,
                 prefetch_factor: int=2,
                 dist_sampler: bool=False,
                 cmvn_file="data/mean_std.json",
                 vocab_filepath='data/lang_char/vocab.txt'):
        self.manifest_file = manifest_file
        self.train_model = train_mode
        self.batch_size = batch_size
        self.prefetch_factor = prefetch_factor
        self.dist_sampler = dist_sampler
        self.n_iter_processes = n_iter_processes

        text_featurizer = TextFeaturizer(unit_type, vocab_filepath)
        symbol_table = text_featurizer.vocab_dict
        self.feat_dim = num_mel_bins
        self.vocab_size = text_featurizer.vocab_size

        # The list of shard
        shardlist = []
        with open(manifest_file, "r") as f:
            for line in f.readlines():
                shardlist.append(line.strip())

        if self.dist_sampler:
            base_dataset = stream_data.DataPipeline(
                stream_data.SimpleShardList(shardlist),
                stream_data.split_by_node,
                stream_data.split_by_worker,
                stream_data.tarfile_to_samples(stream_data.reraise_exception))
        else:
            base_dataset = stream_data.DataPipeline(
                stream_data.SimpleShardList(shardlist),
                stream_data.split_by_worker,
                stream_data.tarfile_to_samples(stream_data.reraise_exception))

        self.dataset = base_dataset.append_list(
            stream_data.tokenize(symbol_table),
            stream_data.data_filter(
                frame_shift=frame_shift,
                max_length=maxlen_in,
                min_length=minlen_in,
                token_max_length=maxlen_out,
                token_min_length=minlen_in),
            stream_data.resample(resample_rate=resample_rate),
            stream_data.compute_fbank(
                num_mel_bins=num_mel_bins,
                frame_length=frame_length,
                frame_shift=frame_shift,
                dither=dither),
            stream_data.spec_aug(**augment_conf)
            if train_mode else stream_data.placeholder(),  # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
            stream_data.shuffle(shuffle_size),
            stream_data.sort(sort_size=sort_size),
            stream_data.batched(batch_size),
            stream_data.padding(),
            stream_data.cmvn(cmvn_file))

        self.loader = stream_data.WebLoader(
            self.dataset,
            num_workers=self.n_iter_processes,
            prefetch_factor=self.prefetch_factor,
            batch_size=None)

    def __iter__(self):
        return self.loader.__iter__()

    def __call__(self):
        return self.__iter__()

    def __len__(self):
        logger.info(
            "Stream dataloader does not support calculate the length of the dataset")
        return -1


class BatchDataLoader():
    def __init__(self,
...
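A minimal usage sketch of the new class, mirroring how setup_dataloader builds it above. All paths are placeholders; the manifest is expected to list one WebDataset-style tar shard per line, since the pipeline starts from SimpleShardList plus tarfile_to_samples:

from paddlespeech.s2t.io.dataloader import StreamDataLoader

# Placeholder paths: a manifest listing tar shards, plus the vocab and CMVN
# files referenced by the wenetspeech config above.
train_loader = StreamDataLoader(
    manifest_file="data/train/shards.list",
    train_mode=True,
    unit_type="char",
    batch_size=64,
    num_mel_bins=80,
    frame_length=25.0,
    frame_shift=10.0,
    dither=0.1,
    minlen_in=10,
    maxlen_in=512,
    minlen_out=0,
    maxlen_out=150,
    resample_rate=16000,
    augment_conf=dict(max_w=80, w_inplace=True, w_mode="PIL",
                      max_f=30, num_f_mask=2, f_inplace=True,
                      f_replace_with_zero=False, max_t=40, num_t_mask=2,
                      t_inplace=True, t_replace_with_zero=False),
    shuffle_size=10000,
    sort_size=500,
    n_iter_processes=4,
    prefetch_factor=100,
    dist_sampler=True,          # prepends split_by_node for multi-GPU runs
    cmvn_file="data/mean_std.json",
    vocab_filepath="data/lang_char/vocab.txt")

for batch in train_loader:
    # Each batch is the padded, CMVN-normalised tuple produced by the
    # batched() + padding() + cmvn() stages of the pipeline.
    break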
paddlespeech/s2t/io/reader.py

@@ -19,7 +19,7 @@ import numpy as np
 import soundfile

 from .utility import feat_type
-from paddlespeech.s2t.transform.transformation import Transformation
+from paddlespeech.audio.transform.transformation import Transformation
 from paddlespeech.s2t.utils.log import Log
 # from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation
...
paddlespeech/s2t/transform/__init__.py (deleted, 100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/s2t/transform/add_deltas.py (deleted, 100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (same header as __init__.py above).
# Modified from espnet(https://github.com/espnet/espnet)
import numpy as np


def delta(feat, window):
    assert window > 0
    delta_feat = np.zeros_like(feat)
    for i in range(1, window + 1):
        delta_feat[:-i] += i * feat[i:]
        delta_feat[i:] += -i * feat[:-i]
        delta_feat[-i:] += i * feat[-1]
        delta_feat[:i] += -i * feat[0]
    delta_feat /= 2 * sum(i**2 for i in range(1, window + 1))
    return delta_feat


def add_deltas(x, window=2, order=2):
    """
    Args:
        x (np.ndarray): speech feat, (T, D).

    Return:
        np.ndarray: (T, (1+order)*D)
    """
    feats = [x]
    for _ in range(order):
        feats.append(delta(feats[-1], window))
    return np.concatenate(feats, axis=1)


class AddDeltas():
    def __init__(self, window=2, order=2):
        self.window = window
        self.order = order

    def __repr__(self):
        return "{name}(window={window}, order={order}".format(
            name=self.__class__.__name__, window=self.window, order=self.order)

    def __call__(self, x):
        return add_deltas(x, window=self.window, order=self.order)
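A quick sanity check of add_deltas from the removed module: with order=2 the output stacks the static features with their first and second derivatives, so (T, D) becomes (T, 3*D). Minimal sketch with random features; the import path assumes the transform modules now live under paddlespeech.audio.transform, as the reader.py import change above suggests:

import numpy as np
# Assumed new location of the removed module (see the reader.py change above).
from paddlespeech.audio.transform.add_deltas import add_deltas

x = np.random.randn(100, 80).astype(np.float32)   # (T, D) features
y = add_deltas(x, window=2, order=2)
print(y.shape)   # (100, 240): static + delta + delta-delta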
paddlespeech/s2t/transform/channel_selector.py (deleted, 100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (same header as __init__.py above).
# Modified from espnet(https://github.com/espnet/espnet)
import numpy


class ChannelSelector():
    """Select 1ch from multi-channel signal"""

    def __init__(self, train_channel="random", eval_channel=0, axis=1):
        self.train_channel = train_channel
        self.eval_channel = eval_channel
        self.axis = axis

    def __repr__(self):
        return ("{name}(train_channel={train_channel}, "
                "eval_channel={eval_channel}, axis={axis})".format(
                    name=self.__class__.__name__,
                    train_channel=self.train_channel,
                    eval_channel=self.eval_channel,
                    axis=self.axis, ))

    def __call__(self, x, train=True):
        # Assuming x: [Time, Channel] by default
        if x.ndim <= self.axis:
            # If the dimension is insufficient, then unsqueeze
            # (e.g [Time] -> [Time, 1])
            ind = tuple(
                slice(None) if i < x.ndim else None
                for i in range(self.axis + 1))
            x = x[ind]

        if train:
            channel = self.train_channel
        else:
            channel = self.eval_channel

        if channel == "random":
            ch = numpy.random.randint(0, x.shape[self.axis])
        else:
            ch = channel

        ind = tuple(
            slice(None) if i != self.axis else ch for i in range(x.ndim))
        return x[ind]
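ChannelSelector simply indexes one channel out of a (Time, Channel) signal, randomly during training and deterministically at evaluation time. A minimal sketch (same import-path assumption as above):

import numpy as np
# Assumed new location of the removed module (see the reader.py change above).
from paddlespeech.audio.transform.channel_selector import ChannelSelector

selector = ChannelSelector(train_channel="random", eval_channel=0, axis=1)
x = np.random.randn(16000, 2)             # (Time, Channel) stereo signal
print(selector(x, train=False).shape)     # (16000,): channel 0 is picked at eval time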
paddlespeech/s2t/transform/cmvn.py (deleted, 100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (same header as __init__.py above).
# Modified from espnet(https://github.com/espnet/espnet)
import io
import json

import h5py
import kaldiio
import numpy as np


class CMVN():
    "Apply Global/Spk CMVN/iverserCMVN."

    def __init__(self,
                 stats,
                 norm_means=True,
                 norm_vars=False,
                 filetype="mat",
                 utt2spk=None,
                 spk2utt=None,
                 reverse=False,
                 std_floor=1.0e-20, ):
        self.stats_file = stats
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.reverse = reverse

        if isinstance(stats, dict):
            stats_dict = dict(stats)
        else:
            # Use for global CMVN
            if filetype == "mat":
                stats_dict = {None: kaldiio.load_mat(stats)}
            # Use for global CMVN
            elif filetype == "npy":
                stats_dict = {None: np.load(stats)}
            # Use for speaker CMVN
            elif filetype == "ark":
                self.accept_uttid = True
                stats_dict = dict(kaldiio.load_ark(stats))
            # Use for speaker CMVN
            elif filetype == "hdf5":
                self.accept_uttid = True
                stats_dict = h5py.File(stats)
            else:
                raise ValueError("Not supporting filetype={}".format(filetype))

        if utt2spk is not None:
            self.utt2spk = {}
            with io.open(utt2spk, "r", encoding="utf-8") as f:
                for line in f:
                    utt, spk = line.rstrip().split(None, 1)
                    self.utt2spk[utt] = spk
        elif spk2utt is not None:
            self.utt2spk = {}
            with io.open(spk2utt, "r", encoding="utf-8") as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        self.utt2spk[utt] = spk
        else:
            self.utt2spk = None

        # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
        # and the first vector contains the sum of feats and the second is
        # the sum of squares. The last value of the first, i.e. stats[0,-1],
        # is the number of samples for this statistics.
        self.bias = {}
        self.scale = {}
        for spk, stats in stats_dict.items():
            assert len(stats) == 2, stats.shape

            count = stats[0, -1]

            # If the feature has two or more dimensions
            if not (np.isscalar(count) or isinstance(count, (int, float))):
                # The first is only used
                count = count.flatten()[0]

            mean = stats[0, :-1] / count
            # V(x) = E(x^2) - (E(x))^2
            var = stats[1, :-1] / count - mean * mean
            std = np.maximum(np.sqrt(var), std_floor)
            self.bias[spk] = -mean
            self.scale[spk] = 1 / std

    def __repr__(self):
        return ("{name}(stats_file={stats_file}, "
                "norm_means={norm_means}, norm_vars={norm_vars}, "
                "reverse={reverse})".format(
                    name=self.__class__.__name__,
                    stats_file=self.stats_file,
                    norm_means=self.norm_means,
                    norm_vars=self.norm_vars,
                    reverse=self.reverse, ))

    def __call__(self, x, uttid=None):
        if self.utt2spk is not None:
            spk = self.utt2spk[uttid]
        else:
            spk = uttid

        if not self.reverse:
            # apply cmvn
            if self.norm_means:
                x = np.add(x, self.bias[spk])
            if self.norm_vars:
                x = np.multiply(x, self.scale[spk])
        else:
            # apply reverse cmvn
            if self.norm_vars:
                x = np.divide(x, self.scale[spk])
            if self.norm_means:
                x = np.subtract(x, self.bias[spk])
        return x


class UtteranceCMVN():
    "Apply Utterance CMVN"

    def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

    def __repr__(self):
        return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
            name=self.__class__.__name__,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars, )

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        square_sums = (x**2).sum(axis=0)
        mean = x.mean(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)

        if self.norm_vars:
            var = square_sums / x.shape[0] - mean**2
            std = np.maximum(np.sqrt(var), self.std_floor)
            x = np.divide(x, std)

        return x


class GlobalCMVN():
    "Apply Global CMVN"

    def __init__(self,
                 cmvn_path,
                 norm_means=True,
                 norm_vars=True,
                 std_floor=1.0e-20):
        # cmvn_path: Option[str, dict]
        cmvn = cmvn_path
        self.cmvn = cmvn
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor
        if isinstance(cmvn, dict):
            cmvn_stats = cmvn
        else:
            with open(cmvn) as f:
                cmvn_stats = json.load(f)
        self.count = cmvn_stats['frame_num']
        self.mean = np.array(cmvn_stats['mean_stat']) / self.count
        self.square_sums = np.array(cmvn_stats['var_stat'])
        self.var = self.square_sums / self.count - self.mean**2
        self.std = np.maximum(np.sqrt(self.var), self.std_floor)

    def __repr__(self):
        return f"""{self.__class__.__name__}(
            cmvn_path={self.cmvn},
            norm_means={self.norm_means},
            norm_vars={self.norm_vars},)"""

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        if self.norm_means:
            x = np.subtract(x, self.mean)

        if self.norm_vars:
            x = np.divide(x, self.std)
        return x
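GlobalCMVN accepts either the path to a mean_std.json file (as in the wenetspeech config above) or a stats dict with frame_num, mean_stat, and var_stat. A minimal sketch building the stats from random features and checking the normalisation (same import-path assumption as above):

import numpy as np
# Assumed new location of the removed module (see the reader.py change above).
from paddlespeech.audio.transform.cmvn import GlobalCMVN

D, frame_num = 80, 1000
feats = np.random.randn(frame_num, D).astype(np.float32)
stats = {
    'frame_num': frame_num,
    'mean_stat': feats.sum(axis=0).tolist(),       # per-dim sum of features
    'var_stat': (feats**2).sum(axis=0).tolist(),   # per-dim sum of squares
}
cmvn = GlobalCMVN(cmvn_path=stats, norm_means=True, norm_vars=True)
normed = cmvn(feats)
print(normed.mean(), normed.std())   # approximately 0 and 1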
paddlespeech/s2t/transform/functional.py (deleted, 100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (same header as __init__.py above).
# Modified from espnet(https://github.com/espnet/espnet)
import inspect

from paddlespeech.s2t.transform.transform_interface import TransformInterface
from paddlespeech.s2t.utils.check_kwargs import check_kwargs


class FuncTrans(TransformInterface):
    """Functional Transformation

    WARNING:
        Builtin or C/C++ functions may not work properly
        because this class heavily depends on the `inspect` module.

    Usage:
        >>> def foo_bar(x, a=1, b=2):
        ...     '''Foo bar
        ...     :param x: input
        ...     :param int a: default 1
        ...     :param int b: default 2
        ...     '''
        ...     return x + a - b
        >>> class FooBar(FuncTrans):
        ...     _func = foo_bar
        ...     __doc__ = foo_bar.__doc__
    """

    _func = None

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        check_kwargs(self.func, kwargs)

    def __call__(self, x):
        return self.func(x, **self.kwargs)

    @classmethod
    def add_arguments(cls, parser):
        fname = cls._func.__name__.replace("_", "-")
        group = parser.add_argument_group(fname + " transformation setting")
        for k, v in cls.default_params().items():
            # TODO(karita): get help and choices from docstring?
            attr = k.replace("_", "-")
            group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
        return parser

    @property
    def func(self):
        return type(self)._func

    @classmethod
    def default_params(cls):
        try:
            d = dict(inspect.signature(cls._func).parameters)
        except ValueError:
            d = dict()
        return {
            k: v.default
            for k, v in d.items() if v.default != inspect.Parameter.empty
        }

    def __repr__(self):
        params = self.default_params()
        params.update(**self.kwargs)
        ret = self.__class__.__name__ + "("
        if len(params) == 0:
            return ret + ")"
        for k, v in params.items():
            ret += "{}={}, ".format(k, v)
        return ret[:-2] + ")"
paddlespeech/s2t/transform/perturb.py (deleted, 100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (same header as __init__.py above).
# Modified from espnet(https://github.com/espnet/espnet)
import librosa
import numpy
import scipy
import soundfile

from paddlespeech.s2t.io.reader import SoundHDF5File


class SpeedPerturbation():
    """SpeedPerturbation

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just to resample the input,
    i.e pitch and tempo are changed both.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    Warning:
        This function is very slow because of resampling.
        I recommmend to apply speed-perturb outside the training using sox.
    """

    def __init__(self,
                 lower=0.9,
                 upper=1.1,
                 utt2ratio=None,
                 keep_length=True,
                 res_type="kaiser_best",
                 seed=None, ):
        self.res_type = res_type
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            self.utt2ratio = {}
            # Use the scheduled ratio for each utterances
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            # The ratio is given on runtime randomly
            self.lower = lower
            self.upper = upper

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, " \
                   "keep_length={}, res_type={})".format(
                       self.__class__.__name__,
                       self.lower,
                       self.upper,
                       self.keep_length,
                       self.res_type, )
        else:
            return "{}({}, res_type={})".format(
                self.__class__.__name__, self.utt2ratio_file, self.res_type)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        # Note1: resample requires the sampling-rate of input and output,
        #        but actually only the ratio is used.
        y = librosa.resample(
            x, orig_sr=ratio, target_sr=1, res_type=self.res_type)

        if self.keep_length:
            diff = abs(len(x) - len(y))
            if len(y) > len(x):
                # Truncate noise
                y = y[diff // 2:-((diff + 1) // 2)]
            elif len(y) < len(x):
                # Assume the time-axis is the first: (Time, Channel)
                pad_width = [(diff // 2, (diff + 1) // 2)] + [
                    (0, 0) for _ in range(y.ndim - 1)
                ]
                y = numpy.pad(
                    y, pad_width=pad_width, constant_values=0, mode="constant")
        return y


class SpeedPerturbationSox():
    """SpeedPerturbationSox

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just to resample the input,
    i.e pitch and tempo are changed both.

    To speed up or slow down the sound of a file,
    use speed to modify the pitch and the duration of the file.
    This raises the speed and reduces the time.
    The default factor is 1.0 which makes no change to the audio.
    2.0 doubles speed, thus time length is cut by a half and pitch is one interval higher.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    tempo option:
        sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9

    speed option:
        sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9

    If we use speed option like above, the pitch of audio also will be changed,
    but the tempo option does not change the pitch.
    """

    def __init__(self,
                 lower=0.9,
                 upper=1.1,
                 utt2ratio=None,
                 keep_length=True,
                 sr=16000,
                 seed=None, ):
        self.sr = sr
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)

        try:
            import soxbindings as sox
        except ImportError:
            try:
                from paddlespeech.s2t.utils import dynamic_pip_install
                package = "sox"
                dynamic_pip_install.install(package)
                package = "soxbindings"
                if sys.platform != "win32":
                    dynamic_pip_install.install(package)
                import soxbindings as sox
            except Exception:
                raise RuntimeError(
                    "Can not install soxbindings on your system.")
        self.sox = sox

        if utt2ratio is not None:
            self.utt2ratio = {}
            # Use the scheduled ratio for each utterances
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            # The ratio is given on runtime randomly
            self.lower = lower
            self.upper = upper

    def __repr__(self):
        if self.utt2ratio is None:
            return f"""{self.__class__.__name__}(
                lower={self.lower},
                upper={self.upper},
                keep_length={self.keep_length},
                sample_rate={self.sr})"""
        else:
            return f"""{self.__class__.__name__}(
                utt2ratio={self.utt2ratio_file},
                sample_rate={self.sr})"""

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        tfm = self.sox.Transformer()
        tfm.set_globals(multithread=False)
        tfm.speed(ratio)
        y = tfm.build_array(input_array=x, sample_rate_in=self.sr)

        if self.keep_length:
            diff = abs(len(x) - len(y))
            if len(y) > len(x):
                # Truncate noise
                y = y[diff // 2:-((diff + 1) // 2)]
            elif len(y) < len(x):
                # Assume the time-axis is the first: (Time, Channel)
                pad_width = [(diff // 2, (diff + 1) // 2)] + [
                    (0, 0) for _ in range(y.ndim - 1)
                ]
                y = numpy.pad(
                    y, pad_width=pad_width, constant_values=0, mode="constant")

        if y.ndim == 2 and x.ndim == 1:
            # (T, C) -> (T)
            y = y.sequence(1)
        return y


class BandpassPerturbation():
    """BandpassPerturbation

    Randomly dropout along the frequency axis.

    The original idea comes from the following:
        "randomly-selected frequency band was cut off under the constraint of
         leaving at least 1,000 Hz band within the range of less than 4,000Hz."
        (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
         everyday home environments using multiple microphone arrays;
         http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
    """

    def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)
        # x_stft: (Time, Channel, Freq)
        self.axes = axes

    def __repr__(self):
        return "{}(lower={}, upper={})".format(self.__class__.__name__,
                                               self.lower, self.upper)

    def __call__(self, x_stft, uttid=None, train=True):
        if not train:
            return x_stft

        if x_stft.ndim == 1:
            raise RuntimeError("Input in time-freq domain: "
                               "(Time, Channel, Freq) or (Time, Freq)")

        ratio = self.state.uniform(self.lower, self.upper)
        axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes]
        shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]

        mask = self.state.randn(*shape) > ratio
        x_stft *= mask
        return x_stft


class VolumePerturbation():
    def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True,
                 seed=None):
        self.dbunit = dbunit
        self.utt2ratio_file = utt2ratio
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances
            self.utt2ratio = {}
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            # The ratio is given on runtime randomly
            self.utt2ratio = None

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)

        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)
        if self.dbunit:
            ratio = 10**(ratio / 20)
        return x * ratio


class NoiseInjection():
    """Add isotropic noise"""

    def __init__(self,
                 utt2noise=None,
                 lower=-20,
                 upper=-5,
                 utt2ratio=None,
                 filetype="list",
                 dbunit=True,
                 seed=None, ):
        self.utt2noise_file = utt2noise
        self.utt2ratio_file = utt2ratio
        self.filetype = filetype
        self.dbunit = dbunit
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterances
            self.utt2ratio = {}
            with open(utt2noise, "r") as f:
                for line in f:
                    utt, snr = line.rstrip().split(None, 1)
                    snr = float(snr)
                    self.utt2ratio[utt] = snr
        else:
            # The ratio is given on runtime randomly
            self.utt2ratio = None

        if utt2noise is not None:
            self.utt2noise = {}
            if filetype == "list":
                with open(utt2noise, "r") as f:
                    for line in f:
                        utt, filename = line.rstrip().split(None, 1)
                        signal, rate = soundfile.read(filename, dtype="int16")
                        # Load all files in memory
                        self.utt2noise[utt] = (signal, rate)
            elif filetype == "sound.hdf5":
                self.utt2noise = SoundHDF5File(utt2noise, "r")
            else:
                raise ValueError(filetype)
        else:
            self.utt2noise = None

        if utt2noise is not None and utt2ratio is not None:
            if set(self.utt2ratio) != set(self.utt2noise):
                raise RuntimeError("The uttids mismatch between {} and {}".
                                   format(utt2ratio, utt2noise))

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        x = x.astype(numpy.float32)

        # 1. Get ratio of noise to signal in sound pressure level
        if uttid is not None and self.utt2ratio is not None:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        if self.dbunit:
            ratio = 10**(ratio / 20)
        scale = ratio * numpy.sqrt((x**2).mean())

        # 2. Get noise
        if self.utt2noise is not None:
            # Get noise from the external source
            if uttid is not None:
                noise, rate = self.utt2noise[uttid]
            else:
                # Randomly select the noise source
                noise = self.state.choice(list(self.utt2noise.values()))
            # Normalize the level
            noise /= numpy.sqrt((noise**2).mean())

            # Adjust the noise length
            diff = abs(len(x) - len(noise))
            offset = self.state.randint(0, diff)
            if len(noise) > len(x):
                # Truncate noise
                noise = noise[offset:-(diff - offset)]
            else:
                noise = numpy.pad(
                    noise, pad_width=[offset, diff - offset], mode="wrap")
        else:
            # Generate white noise
            noise = self.state.normal(0, 1, x.shape)

        # 3. Add noise to signal
        return x + noise * scale


class RIRConvolve():
    def __init__(self, utt2rir, filetype="list"):
        self.utt2rir_file = utt2rir
        self.filetype = filetype

        self.utt2rir = {}
        if filetype == "list":
            with open(utt2rir, "r") as f:
                for line in f:
                    utt, filename = line.rstrip().split(None, 1)
                    signal, rate = soundfile.read(filename, dtype="int16")
                    self.utt2rir[utt] = (signal, rate)
        elif filetype == "sound.hdf5":
            self.utt2rir = SoundHDF5File(utt2rir, "r")
        else:
            raise NotImplementedError(filetype)

    def __repr__(self):
        return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)

        if x.ndim != 1:
            # Must be single channel
            raise RuntimeError(
                "Input x must be one dimensional array, but got {}".format(
                    x.shape))

        rir, rate = self.utt2rir[uttid]
        if rir.ndim == 2:
            # FIXME(kamo): Use chainer.convolution_1d?
            # return [Time, Channel]
            return numpy.stack(
                [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
        else:
            return scipy.convolve(x, rir, mode="same")
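Of the removed perturbations, NoiseInjection is the easiest to exercise standalone, since with no utt2noise source it falls back to adding scaled white noise. A minimal sketch (same import-path assumption as above):

import numpy as np
# Assumed new location of the removed module (see the reader.py change above).
from paddlespeech.audio.transform.perturb import NoiseInjection

inject = NoiseInjection(lower=-20, upper=-5, dbunit=True, seed=0)
x = np.random.randn(16000).astype(np.float32)   # 1 s of fake 16 kHz audio
y = inject(x, train=True)                       # white noise scaled by a -20..-5 dB ratio
print(y.shape, float(np.abs(y - x).mean()) > 0)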
paddlespeech/s2t/transform/spec_augment.py (deleted, 100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (same header as __init__.py above).
# Modified from espnet(https://github.com/espnet/espnet)
"""Spec Augment module for preprocessing i.e., data augmentation"""
import random

import numpy
from PIL import Image
from PIL.Image import BICUBIC

from paddlespeech.s2t.transform.functional import FuncTrans


def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
    """time warp for spec augment
    move random center frame by the random width ~ uniform(-window, window)
    :param numpy.ndarray x: spectrogram (time, freq)
    :param int max_time_warp: maximum time frames to warp
    :param bool inplace: overwrite x with the result
    :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp"
        (slow, differentiable)
    :returns numpy.ndarray: time warped spectrogram (time, freq)
    """
    window = max_time_warp
    if window == 0:
        return x

    if mode == "PIL":
        t = x.shape[0]
        if t - window <= window:
            return x
        # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
        center = random.randrange(window, t - window)
        warped = random.randrange(center - window, center + window) + 1  # 1 ... t - 1

        left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
        right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC)
        if inplace:
            x[:warped] = left
            x[warped:] = right
            return x
        return numpy.concatenate((left, right), 0)
    elif mode == "sparse_image_warp":
        import paddle

        from espnet.utils import spec_augment

        # TODO(karita): make this differentiable again
        return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
    else:
        raise NotImplementedError("unknown resize mode: " + mode +
                                  ", choose one from (PIL, sparse_image_warp).")


class TimeWarp(FuncTrans):
    _func = time_warp
    __doc__ = time_warp.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)


def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
    """freq mask for spec agument

    :param numpy.ndarray x: (time, freq)
    :param int n_mask: the number of masks
    :param bool inplace: overwrite
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    if inplace:
        cloned = x
    else:
        cloned = x.copy()

    num_mel_channels = cloned.shape[1]
    fs = numpy.random.randint(0, F, size=(n_mask, 2))

    for f, mask_end in fs:
        f_zero = random.randrange(0, num_mel_channels - f)
        mask_end += f_zero

        # avoids randrange error if values are equal and range is empty
        if f_zero == f_zero + f:
            continue

        if replace_with_zero:
            cloned[:, f_zero:mask_end] = 0
        else:
            cloned[:, f_zero:mask_end] = cloned.mean()
    return cloned


class FreqMask(FuncTrans):
    _func = freq_mask
    __doc__ = freq_mask.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)


def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
    """freq mask for spec agument

    :param numpy.ndarray spec: (time, freq)
    :param int n_mask: the number of masks
    :param bool inplace: overwrite
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    if inplace:
        cloned = spec
    else:
        cloned = spec.copy()
    len_spectro = cloned.shape[0]
    ts = numpy.random.randint(0, T, size=(n_mask, 2))
    for t, mask_end in ts:
        # avoid randint range error
        if len_spectro - t <= 0:
            continue
        t_zero = random.randrange(0, len_spectro - t)

        # avoids randrange error if values are equal and range is empty
        if t_zero == t_zero + t:
            continue

        mask_end += t_zero
        if replace_with_zero:
            cloned[t_zero:mask_end] = 0
        else:
            cloned[t_zero:mask_end] = cloned.mean()
    return cloned


class TimeMask(FuncTrans):
    _func = time_mask
    __doc__ = time_mask.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)


def spec_augment(x,
                 resize_mode="PIL",
                 max_time_warp=80,
                 max_freq_width=27,
                 n_freq_mask=2,
                 max_time_width=100,
                 n_time_mask=2,
                 inplace=True,
                 replace_with_zero=True, ):
    """spec agument

    apply random time warping and time/freq masking
    default setting is based on LD (Librispeech double) in Table 2
        https://arxiv.org/pdf/1904.08779.pdf

    :param numpy.ndarray x: (time, freq)
    :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp"
        (slow, differentiable)
    :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W)
    :param int freq_mask_width: maximum width of the random freq mask (F)
    :param int n_freq_mask: the number of the random freq mask (m_F)
    :param int time_mask_width: maximum width of the random time mask (T)
    :param int n_time_mask: the number of the random time mask (m_T)
    :param bool inplace: overwrite intermediate array
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    assert isinstance(x, numpy.ndarray)
    assert x.ndim == 2
    x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
    x = freq_mask(
        x,
        max_freq_width,
        n_freq_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
    x = time_mask(
        x,
        max_time_width,
        n_time_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
    return x


class SpecAugment(FuncTrans):
    _func = spec_augment
    __doc__ = spec_augment.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)
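The composite spec_augment helper applies time warping plus frequency and time masking in one call, which is the same recipe the new stream_data spec_aug filter drives through its augment_conf keys. A minimal sketch on random log-mel features (same import-path assumption as above):

import numpy as np
# Assumed new location of the removed module (see the reader.py change above).
from paddlespeech.audio.transform.spec_augment import spec_augment

x = np.random.randn(300, 80).astype(np.float32)   # (time, freq) log-mel features
y = spec_augment(x, resize_mode="PIL", max_time_warp=5,
                 max_freq_width=27, n_freq_mask=2,
                 max_time_width=40, n_time_mask=2,
                 inplace=False, replace_with_zero=True)
print(y.shape)   # still (300, 80), with random time/freq bands masked to zero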
paddlespeech/s2t/transform/spectrogram.py
已删除
100644 → 0
浏览文件 @
8f5e6109
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import librosa
import numpy as np
import paddle
from python_speech_features import logfbank

import paddlespeech.audio.compliance.kaldi as kaldi


def stft(x,
         n_fft,
         n_shift,
         win_length=None,
         window="hann",
         center=True,
         pad_mode="reflect"):
    # x: [Time, Channel]
    if x.ndim == 1:
        single_channel = True
        # x: [Time] -> [Time, Channel]
        x = x[:, None]
    else:
        single_channel = False
    x = x.astype(np.float32)

    # FIXME(kamo): librosa.stft can't use multi-channel?
    # x: [Time, Channel, Freq]
    x = np.stack(
        [
            librosa.stft(
                y=x[:, ch],
                n_fft=n_fft,
                hop_length=n_shift,
                win_length=win_length,
                window=window,
                center=center,
                pad_mode=pad_mode, ).T for ch in range(x.shape[1])
        ],
        axis=1, )

    if single_channel:
        # x: [Time, Channel, Freq] -> [Time, Freq]
        x = x[:, 0]
    return x


def istft(x, n_shift, win_length=None, window="hann", center=True):
    # x: [Time, Channel, Freq]
    if x.ndim == 2:
        single_channel = True
        # x: [Time, Freq] -> [Time, Channel, Freq]
        x = x[:, None, :]
    else:
        single_channel = False

    # x: [Time, Channel]
    x = np.stack(
        [
            librosa.istft(
                stft_matrix=x[:, ch].T,  # [Time, Freq] -> [Freq, Time]
                hop_length=n_shift,
                win_length=win_length,
                window=window,
                center=center, ) for ch in range(x.shape[1])
        ],
        axis=1, )

    if single_channel:
        # x: [Time, Channel] -> [Time]
        x = x[:, 0]
    return x


def stft2logmelspectrogram(x_stft,
                           fs,
                           n_mels,
                           n_fft,
                           fmin=None,
                           fmax=None,
                           eps=1e-10):
    # x_stft: (Time, Channel, Freq) or (Time, Freq)
    fmin = 0 if fmin is None else fmin
    fmax = fs / 2 if fmax is None else fmax

    # spc: (Time, Channel, Freq) or (Time, Freq)
    spc = np.abs(x_stft)
    # mel_basis: (Mel_freq, Freq)
    mel_basis = librosa.filters.mel(
        sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
    # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
    lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))

    return lmspc


def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
    # x: (Time, Channel) -> spc: (Time, Channel, Freq)
    spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
    return spc


def logmelspectrogram(
        x,
        fs,
        n_mels,
        n_fft,
        n_shift,
        win_length=None,
        window="hann",
        fmin=None,
        fmax=None,
        eps=1e-10,
        pad_mode="reflect", ):
    # stft: (Time, Channel, Freq) or (Time, Freq)
    x_stft = stft(
        x,
        n_fft=n_fft,
        n_shift=n_shift,
        win_length=win_length,
        window=window,
        pad_mode=pad_mode, )

    return stft2logmelspectrogram(
        x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax,
        eps=eps)


class Spectrogram():
    def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window

    def __repr__(self):
        return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
                "win_length={win_length}, window={window})".format(
                    name=self.__class__.__name__,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window, ))

    def __call__(self, x):
        return spectrogram(
            x,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window, )


class LogMelSpectrogram():
    def __init__(
            self,
            fs,
            n_mels,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            fmin=None,
            fmax=None,
            eps=1e-10, ):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        return (
            "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
            "n_shift={n_shift}, win_length={win_length}, window={window}, "
            "fmin={fmin}, fmax={fmax}, eps={eps}))".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_fft=self.n_fft,
                n_shift=self.n_shift,
                win_length=self.win_length,
                window=self.window,
                fmin=self.fmin,
                fmax=self.fmax,
                eps=self.eps, ))

    def __call__(self, x):
        return logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window, )


class Stft2LogMelSpectrogram():
    def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
                "fmin={fmin}, fmax={fmax}, eps={eps}))".format(
                    name=self.__class__.__name__,
                    fs=self.fs,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    eps=self.eps, ))

    def __call__(self, x):
        return stft2logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            fmin=self.fmin,
            fmax=self.fmax, )


class Stft():
    def __init__(
            self,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            center=True,
            pad_mode="reflect", ):
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.center = center
        self.pad_mode = pad_mode

    def __repr__(self):
        return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
                "win_length={win_length}, window={window},"
                "center={center}, pad_mode={pad_mode})".format(
                    name=self.__class__.__name__,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center,
                    pad_mode=self.pad_mode, ))

    def __call__(self, x):
        return stft(
            x,
            self.n_fft,
            self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode, )


class IStft():
    def __init__(self, n_shift, win_length=None, window="hann", center=True):
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.center = center

    def __repr__(self):
        return ("{name}(n_shift={n_shift}, "
                "win_length={win_length}, window={window},"
                "center={center})".format(
                    name=self.__class__.__name__,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center, ))

    def __call__(self, x):
        return istft(
            x,
            self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center, )


class LogMelSpectrogramKaldi():
    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_shift=160,  # unit:sample, 10ms
            win_length=400,  # unit:sample, 25ms
            energy_floor=0.0,
            dither=0.1):
        """
        The Kaldi implementation of LogMelSpectrogram
        Args:
            fs (int): sample rate of the audio
            n_mels (int): number of mel filter banks
            n_shift (int): number of points in a frame shift
            win_length (int): number of points in a frame windows
            energy_floor (float): Floor on energy in Spectrogram computation (absolute)
            dither (float): Dithering constant

        Returns:
            LogMelSpectrogramKaldi
        """
        self.fs = fs
        self.n_mels = n_mels
        num_point_ms = fs / 1000
        self.n_frame_length = win_length / num_point_ms
        self.n_frame_shift = n_shift / num_point_ms
        self.energy_floor = energy_floor
        self.dither = dither

    def __repr__(self):
        return (
            "{name}(fs={fs}, n_mels={n_mels}, "
            "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
            "dither={dither}))".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_frame_shift=self.n_frame_shift,
                n_frame_length=self.n_frame_length,
                dither=self.dither, ))

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode.

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D)
        """
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")
        waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32)
        mat = kaldi.fbank(
            waveform,
            n_mels=self.n_mels,
            frame_length=self.n_frame_length,
            frame_shift=self.n_frame_shift,
            dither=dither,
            energy_floor=self.energy_floor,
            sr=self.fs)
        mat = np.squeeze(mat.numpy())
        return mat


class LogMelSpectrogramKaldi_decay():
    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_fft=512,  # fft point
            n_shift=160,  # unit:sample, 10ms
            win_length=400,  # unit:sample, 25ms
            window="povey",
            fmin=20,
            fmax=None,
            eps=1e-10,
            dither=1.0):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        if n_shift > win_length:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        self.n_shift = n_shift / fs  # unit: ms
        self.win_length = win_length / fs  # unit: ms

        self.window = window
        self.fmin = fmin
        if fmax is None:
            fmax_ = self.fs / 2
        elif fmax > int(self.fs / 2):
            raise ValueError("fmax must not be greater than half of "
                             "sample rate.")
        else:
            # Note: the original code only set fmax_ in the fmax-is-None branch,
            # which raised NameError for any valid user-supplied fmax.
            fmax_ = fmax
        self.fmax = fmax_

        self.eps = eps
        self.remove_dc_offset = True
        self.preemph = 0.97
        self.dither = dither  # only work in train mode

    def __repr__(self):
        return (
            "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
            "n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, "
            "fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither}))".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_fft=self.n_fft,
                n_shift=self.n_shift,
                preemph=self.preemph,
                win_length=self.win_length,
                window=self.window,
                fmin=self.fmin,
                fmax=self.fmax,
                eps=self.eps,
                dither=self.dither, ))

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode.

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D)
        """
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")

        if x.dtype in np.sctypes['float']:
            # PCM32 -> PCM16
            bits = np.iinfo(np.int16).bits
            x = x * 2**(bits - 1)

        # logfbank need PCM16 input
        y = logfbank(
            signal=x,
            samplerate=self.fs,
            winlen=self.win_length,  # unit ms
            winstep=self.n_shift,  # unit ms
            nfilt=self.n_mels,
            nfft=self.n_fft,
            lowfreq=self.fmin,
            highfreq=self.fmax,
            dither=dither,
            remove_dc_offset=self.remove_dc_offset,
            preemph=self.preemph,
            wintype=self.window)
        return y
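As a rough illustration of how the Kaldi-style extractor above was typically driven: the one-second random waveform and the parameter values below are fabricated for the example, and since this commit deletes the file, the import only works on earlier revisions:

import numpy as np

from paddlespeech.s2t.transform.spectrogram import LogMelSpectrogramKaldi  # removed by this commit

# 80-dim log-mel fbank features from a mono 16 kHz waveform.
extractor = LogMelSpectrogramKaldi(
    fs=16000, n_mels=80, n_shift=160, win_length=400, dither=0.1)

waveform = np.random.randn(16000).astype(np.float32)  # ~1 s of fake audio
feats = extractor(waveform, train=True)  # dithering is applied only in train mode
print(feats.shape)  # roughly (100, 80) with a 10 ms frame shift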
paddlespeech/s2t/transform/transform_interface.py
deleted 100644 → 0
View file @ 8f5e6109
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)


class TransformInterface:
    """Transform Interface"""

    def __call__(self, x):
        raise NotImplementedError("__call__ method is not implemented")

    @classmethod
    def add_arguments(cls, parser):
        return parser

    def __repr__(self):
        return self.__class__.__name__ + "()"


class Identity(TransformInterface):
    """Identity Function"""

    def __call__(self, x):
        return x
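To show what the interface expects of a concrete transform, here is a small hypothetical subclass; MeanNormalize is not part of the codebase and exists only for illustration, and the import path is the module deleted by this commit:

import numpy as np

from paddlespeech.s2t.transform.transform_interface import TransformInterface  # removed by this commit


class MeanNormalize(TransformInterface):
    """Subtract the per-utterance mean from a (time, freq) feature matrix."""

    def __call__(self, x):
        return x - x.mean(axis=0, keepdims=True)


norm = MeanNormalize()
feats = np.random.randn(100, 80).astype(np.float32)
print(repr(norm))                    # "MeanNormalize()" via the base __repr__
print(norm(feats).mean(axis=0)[:3])  # per-dimension means are now ~0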
paddlespeech/s2t/transform/transformation.py
deleted 100644 → 0
View file @ 8f5e6109
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Transformation module."""
import copy
import io
import logging
from collections import OrderedDict
from collections.abc import Sequence
from inspect import signature

import yaml

from paddlespeech.s2t.utils.dynamic_import import dynamic_import

import_alias = dict(
    identity="paddlespeech.s2t.transform.transform_interface:Identity",
    time_warp="paddlespeech.s2t.transform.spec_augment:TimeWarp",
    time_mask="paddlespeech.s2t.transform.spec_augment:TimeMask",
    freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask",
    spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment",
    speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation",
    speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox",
    volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation",
    noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection",
    bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation",
    rir_convolve="paddlespeech.s2t.transform.perturb:RIRConvolve",
    delta="paddlespeech.s2t.transform.add_deltas:AddDeltas",
    cmvn="paddlespeech.s2t.transform.cmvn:CMVN",
    utterance_cmvn="paddlespeech.s2t.transform.cmvn:UtteranceCMVN",
    fbank="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogram",
    spectrogram="paddlespeech.s2t.transform.spectrogram:Spectrogram",
    stft="paddlespeech.s2t.transform.spectrogram:Stft",
    istft="paddlespeech.s2t.transform.spectrogram:IStft",
    stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram",
    wpe="paddlespeech.s2t.transform.wpe:WPE",
    channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector",
    fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi",
    cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN")


class Transformation():
    """Apply some functions to the mini-batch

    Examples:
        >>> kwargs = {"process": [{"type": "fbank",
        ...                        "n_mels": 80,
        ...                        "fs": 16000},
        ...                       {"type": "cmvn",
        ...                        "stats": "data/train/cmvn.ark",
        ...                        "norm_vars": True},
        ...                       {"type": "delta", "window": 2, "order": 2}]}
        >>> transform = Transformation(kwargs)
        >>> bs = 10
        >>> xs = [np.random.randn(100, 80).astype(np.float32)
        ...       for _ in range(bs)]
        >>> xs = transform(xs)
    """

    def __init__(self, conffile=None):
        if conffile is not None:
            if isinstance(conffile, dict):
                self.conf = copy.deepcopy(conffile)
            else:
                with io.open(conffile, encoding="utf-8") as f:
                    self.conf = yaml.safe_load(f)
                    assert isinstance(self.conf, dict), type(self.conf)
        else:
            self.conf = {"mode": "sequential", "process": []}

        self.functions = OrderedDict()
        if self.conf.get("mode", "sequential") == "sequential":
            for idx, process in enumerate(self.conf["process"]):
                assert isinstance(process, dict), type(process)
                opts = dict(process)
                process_type = opts.pop("type")
                class_obj = dynamic_import(process_type, import_alias)
                # TODO(karita): assert issubclass(class_obj, TransformInterface)
                try:
                    self.functions[idx] = class_obj(**opts)
                except TypeError:
                    try:
                        signa = signature(class_obj)
                    except ValueError:
                        # Some function, e.g. built-in function, are failed
                        pass
                    else:
                        logging.error("Expected signature: {}({})".format(
                            class_obj.__name__, signa))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

    def __repr__(self):
        rep = "\n" + "\n".join("    {}: {}".format(k, v)
                               for k, v in self.functions.items())
        return "{}({})".format(self.__class__.__name__, rep)

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Return new mini-batch

        :param Union[Sequence[np.ndarray], np.ndarray] xs:
        :param Union[Sequence[str], str] uttid_list:
        :return: batch:
        :rtype: List[np.ndarray]
        """
        if not isinstance(xs, Sequence):
            is_batch = False
            xs = [xs]
        else:
            is_batch = True

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list for _ in range(len(xs))]

        if self.conf.get("mode", "sequential") == "sequential":
            for idx in range(len(self.conf["process"])):
                func = self.functions[idx]
                # TODO(karita): use TrainingTrans and UttTrans to check __call__ args
                # Derive only the args which the func has
                try:
                    param = signature(func).parameters
                except ValueError:
                    # Some function, e.g. built-in function, are failed
                    param = {}
                _kwargs = {k: v for k, v in kwargs.items() if k in param}
                try:
                    if uttid_list is not None and "uttid" in param:
                        xs = [
                            func(x, u, **_kwargs)
                            for x, u in zip(xs, uttid_list)
                        ]
                    else:
                        xs = [func(x, **_kwargs) for x in xs]
                except Exception:
                    logging.fatal("Catch a exception from {}th func: {}".format(
                        idx, func))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

        if is_batch:
            return xs
        else:
            return xs[0]
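For context, a sketch of how Transformation turned the alias table above into a runnable pipeline. The option values are illustrative: fbank_kaldi maps to LogMelSpectrogramKaldi shown earlier, and passing n_freq_mask/n_time_mask through the spec_augment alias assumes FuncTrans forwards constructor keywords to the wrapped function, so treat the exact option names as assumptions rather than a recipe:

import numpy as np

from paddlespeech.s2t.transform.transformation import Transformation  # removed by this commit

conf = {
    "process": [
        {"type": "fbank_kaldi", "fs": 16000, "n_mels": 80},
        {"type": "spec_augment", "n_freq_mask": 2, "n_time_mask": 2},
    ]
}
transform = Transformation(conf)

wav = np.random.randn(16000).astype(np.float32)
# "train" is forwarded only to transforms whose __call__ accepts a "train" argument.
feats = transform(wav, train=True)
print(feats.shape)  # (frames, 80) log-mel features with masking applied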
paddlespeech/s2t/transform/wpe.py
deleted 100644 → 0
View file @ 8f5e6109
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from nara_wpe.wpe import wpe


class WPE(object):
    def __init__(self,
                 taps=10,
                 delay=3,
                 iterations=3,
                 psd_context=0,
                 statistics_mode="full"):
        self.taps = taps
        self.delay = delay
        self.iterations = iterations
        self.psd_context = psd_context
        self.statistics_mode = statistics_mode

    def __repr__(self):
        return ("{name}(taps={taps}, delay={delay}"
                "iterations={iterations}, psd_context={psd_context}, "
                "statistics_mode={statistics_mode})".format(
                    name=self.__class__.__name__,
                    taps=self.taps,
                    delay=self.delay,
                    iterations=self.iterations,
                    psd_context=self.psd_context,
                    statistics_mode=self.statistics_mode, ))

    def __call__(self, xs):
        """Return enhanced

        :param np.ndarray xs: (Time, Channel, Frequency)
        :return: enhanced_xs
        :rtype: np.ndarray
        """
        # nara_wpe.wpe: (F, C, T)
        xs = wpe(
            xs.transpose((2, 1, 0)),
            taps=self.taps,
            delay=self.delay,
            iterations=self.iterations,
            psd_context=self.psd_context,
            statistics_mode=self.statistics_mode, )
        return xs.transpose(2, 1, 0)
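A minimal sketch of applying the WPE wrapper above to a fake multi-channel STFT laid out as (Time, Channel, Frequency), matching the docstring; the array sizes are arbitrary, nara_wpe must be installed, and the import path is the module this commit deletes:

import numpy as np

from paddlespeech.s2t.transform.wpe import WPE  # removed by this commit

dereverb = WPE(taps=10, delay=3, iterations=3)

# Fake complex STFT: 300 frames, 4 channels, 257 frequency bins.
stft_frames = (np.random.randn(300, 4, 257) +
               1j * np.random.randn(300, 4, 257)).astype(np.complex64)

enhanced = dereverb(stft_frames)
print(enhanced.shape)  # (300, 4, 257): same layout as the input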