PaddlePaddle / DeepSpeech

Commit 8f5e6109 — authored Jun 22, 2022 by huangyuxin
Parent: e04cd188

    new feature: Add webdataset in audio

Showing 24 changed files with 4,349 additions and 1 deletion:
paddlespeech/audio/stream_data/__init__.py            +68   −0
paddlespeech/audio/stream_data/cache.py               +190  −0
paddlespeech/audio/stream_data/compat.py              +170  −0
paddlespeech/audio/stream_data/filters.py             +912  −0
paddlespeech/audio/stream_data/paddle_utils.py        +33   −0
paddlespeech/audio/stream_data/pipeline.py            +127  −0
paddlespeech/audio/stream_data/shardlists.py          +257  −0
paddlespeech/audio/stream_data/tariterators.py        +283  −0
paddlespeech/audio/stream_data/utils.py               +128  −0
paddlespeech/audio/transform/__init__.py              +13   −0
paddlespeech/audio/transform/add_deltas.py            +54   −0
paddlespeech/audio/transform/channel_selector.py      +57   −0
paddlespeech/audio/transform/cmvn.py                  +201  −0
paddlespeech/audio/transform/functional.py            +86   −0
paddlespeech/audio/transform/perturb.py               +561  −0
paddlespeech/audio/transform/spec_augment.py          +214  −0
paddlespeech/audio/transform/spectrogram.py           +475  −0
paddlespeech/audio/transform/transform_interface.py   +35   −0
paddlespeech/audio/transform/transformation.py        +158  −0
paddlespeech/audio/transform/wpe.py                   +58   −0
paddlespeech/audio/utils/check_kwargs.py              +35   −0
paddlespeech/audio/utils/dynamic_import.py            +38   −0
paddlespeech/audio/utils/tensor_utils.py              +195  −0
setup.py                                              +1    −1
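Taken together, the stream_data files below port the WebDataset-style streaming loader to Paddle and add audio-specific pipeline stages. A hedged sketch of how the new API might be driven end to end — the shard pattern is invented, and the `compute_fbank`/`sort` arguments are assumptions, since `filters.py` is collapsed in this capture:

```python
# Hypothetical usage of the streaming API added by this commit.
# "shards-{000..009}.tar" is an invented URL pattern; the fbank/sort
# arguments are guesses, since filters.py is not expanded on this page.
import paddlespeech.audio.stream_data as stream_data

dataset = (
    stream_data.WebDataset("shards-{000..009}.tar", shardshuffle=True)
    .shuffle(1000)      # buffer-level sample shuffling
    .compute_fbank()    # waveform -> filterbank features
    .sort(size=500)     # sort a window of samples by length
    .batched(16))       # collate into batches

for batch in dataset:   # nothing is read until iteration starts
    break
```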
paddlespeech/audio/stream_data/__init__.py (new file, mode 100644)

```python
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa
from .cache import (
    cached_tarfile_samples,
    cached_tarfile_to_samples,
    lru_cleanup,
    pipe_cleaner,
)
from .compat import WebDataset, WebLoader, FluidWrapper
from webdataset.extradatasets import MockDataset, with_epoch, with_length
from .filters import (
    associate,
    batched,
    decode,
    detshuffle,
    extract_keys,
    getfirst,
    info,
    map,
    map_dict,
    map_tuple,
    pipelinefilter,
    rename,
    rename_keys,
    rsample,
    select,
    shuffle,
    slice,
    to_tuple,
    transform_with,
    unbatched,
    xdecode,
    data_filter,
    tokenize,
    resample,
    compute_fbank,
    spec_aug,
    sort,
    padding,
    cmvn,
)
from webdataset.handlers import (
    ignore_and_continue,
    ignore_and_stop,
    reraise_exception,
    warn_and_continue,
    warn_and_stop,
)
from .pipeline import DataPipeline
from .shardlists import (
    MultiShardSample,
    ResampledShards,
    SimpleShardList,
    non_empty,
    resampled,
    shardspec,
    single_node_only,
    split_by_node,
    split_by_worker,
)
from .tariterators import tarfile_samples, tarfile_to_samples
from .utils import PipelineStage, repeatedly
from webdataset.writer import ShardWriter, TarWriter, numpy_dumps
from webdataset.mix import RandomMix, RoundRobin
```
paddlespeech/audio/stream_data/cache.py (new file, mode 100644)

```python
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import itertools, os, random, re, sys, time  # "time" added: used for retry backoff below but missing from the original diff
from urllib.parse import urlparse

from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception
from .tariterators import tar_file_and_group_expander

default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))


def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
    """Performs cleanup of the file cache in cache_dir using an LRU strategy,
    keeping the total size of all remaining files below cache_size."""
    if not os.path.exists(cache_dir):
        return
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            total_size += os.path.getsize(os.path.join(dirpath, filename))
    if total_size <= cache_size:
        return
    # sort files by last access time
    files = []
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            files.append(os.path.join(dirpath, filename))
    files.sort(key=keyfn, reverse=True)
    # delete files until we're under the cache size
    while len(files) > 0 and total_size > cache_size:
        fname = files.pop()
        total_size -= os.path.getsize(fname)
        if verbose:
            print("# deleting %s" % fname, file=sys.stderr)
        os.remove(fname)


def download(url, dest, chunk_size=1024 ** 2, verbose=False):
    """Download a file from `url` to `dest`."""
    temp = dest + f".temp{os.getpid()}"
    with gopen.gopen(url) as stream:
        with open(temp, "wb") as f:
            while True:
                data = stream.read(chunk_size)
                if not data:
                    break
                f.write(data)
    os.rename(temp, dest)


def pipe_cleaner(spec):
    """Guess the actual URL from a "pipe:" specification."""
    if spec.startswith("pipe:"):
        spec = spec[5:]
        words = spec.split(" ")
        for word in words:
            if re.match(r"^(https?|gs|ais|s3)", word):
                return word
    return spec


def get_file_cached(
    spec,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    verbose=False,
):
    if cache_size == -1:
        cache_size = default_cache_size
    if cache_dir is None:
        cache_dir = default_cache_dir
    url = url_to_name(spec)
    parsed = urlparse(url)
    dirname, filename = os.path.split(parsed.path)
    dirname = dirname.lstrip("/")
    dirname = re.sub(r"[:/|;]", "_", dirname)
    destdir = os.path.join(cache_dir, dirname)
    os.makedirs(destdir, exist_ok=True)
    dest = os.path.join(cache_dir, dirname, filename)
    if not os.path.exists(dest):
        if verbose:
            print("# downloading %s to %s" % (url, dest), file=sys.stderr)
        lru_cleanup(cache_dir, cache_size, verbose=verbose)
        download(spec, dest, verbose=verbose)
    return dest


def get_filetype(fname):
    with os.popen("file '%s'" % fname) as f:
        ftype = f.read()
    return ftype


def check_tar_format(fname):
    """Check whether a file is a tar archive."""
    ftype = get_filetype(fname)
    return "tar archive" in ftype or "gzip compressed" in ftype


verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))


def cached_url_opener(
    data,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    validator=check_tar_format,
    verbose=False,
    always=False,
):
    """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
    verbose = verbose or verbose_cache
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        attempts = 5
        try:
            if not always and os.path.exists(url):
                dest = url
            else:
                dest = get_file_cached(
                    url,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                    url_to_name=url_to_name,
                    verbose=verbose,
                )
            if verbose:
                print("# opening %s" % dest, file=sys.stderr)
            assert os.path.exists(dest)
            if not validator(dest):
                ftype = get_filetype(dest)
                with open(dest, "rb") as f:
                    data = f.read(200)
                os.remove(dest)
                raise ValueError(
                    "%s (%s) is not a tar archive, but a %s, contains %s"
                    % (dest, url, ftype, repr(data)))
            try:
                stream = open(dest, "rb")
                sample.update(stream=stream)
                yield sample
            except FileNotFoundError as exn:
                # dealing with race conditions in lru_cleanup
                attempts -= 1
                if attempts > 0:
                    time.sleep(random.random() * 10)
                    continue
                raise exn
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break


def cached_tarfile_samples(
    src,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    verbose=False,
    url_to_name=pipe_cleaner,
    always=False,
):
    streams = cached_url_opener(
        src,
        handler=handler,
        cache_size=cache_size,
        cache_dir=cache_dir,
        verbose=verbose,
        url_to_name=url_to_name,
        always=always,
    )
    samples = tar_file_and_group_expander(streams, handler=handler)
    return samples


cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
```
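As a small illustration of the eviction policy above (not from the diff): `lru_cleanup` sorts by `getctime` newest-first and pops from the end, so the oldest files go first. Paths here are temporary and invented.

```python
import os, tempfile
# Demo of lru_cleanup from cache.py above: overfill a directory, then
# shrink it back under a 4 KiB budget; the oldest shards are deleted.
cache_dir = tempfile.mkdtemp()
for i in range(10):
    with open(os.path.join(cache_dir, f"shard{i}.tar"), "wb") as f:
        f.write(b"\0" * 1024)  # 1 KiB each, 10 KiB total
lru_cleanup(cache_dir, cache_size=4 * 1024, verbose=True)
print(sorted(os.listdir(cache_dir)))  # roughly the 4 newest shards remain
```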
paddlespeech/audio/stream_data/compat.py (new file, mode 100644)

```python
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
from dataclasses import dataclass
from itertools import islice
from typing import List

import braceexpand, yaml

from webdataset import autodecode
from . import cache, filters, shardlists, tariterators
from .filters import reraise_exception
from .pipeline import DataPipeline
from .paddle_utils import DataLoader, IterableDataset


class FluidInterface:
    def batched(self, batchsize):
        return self.compose(filters.batched(batchsize))

    def dynamic_batched(self, max_frames_in_batch):
        # "filters" was misspelled as the builtin "filter" in the original diff
        return self.compose(filters.dynamic_batched(max_frames_in_batch))

    def unbatched(self):
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None)

    def unlisted(self):
        return self.compose(filters.unlisted())

    def log_keys(self, logfile=None):
        return self.compose(filters.log_keys(logfile))

    def shuffle(self, size, **kw):
        if size < 1:
            return self
        else:
            return self.compose(filters.shuffle(size, **kw))

    def map(self, f, handler=reraise_exception):
        return self.compose(filters.map(f, handler=handler))

    def decode(self, *args, pre=None, post=None, only=None, partial=False,
               handler=reraise_exception):
        handlers = [
            autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args
        ]
        decoder = autodecode.Decoder(
            handlers, pre=pre, post=post, only=only, partial=partial)
        return self.map(decoder, handler=handler)

    def map_dict(self, handler=reraise_exception, **kw):
        return self.compose(filters.map_dict(handler=handler, **kw))

    def select(self, predicate, **kw):
        return self.compose(filters.select(predicate, **kw))

    def to_tuple(self, *args, handler=reraise_exception):
        return self.compose(filters.to_tuple(*args, handler=handler))

    def map_tuple(self, *args, handler=reraise_exception):
        return self.compose(filters.map_tuple(*args, handler=handler))

    def slice(self, *args):
        return self.compose(filters.slice(*args))

    def rename(self, **kw):
        return self.compose(filters.rename(**kw))

    def rsample(self, p=0.5):
        return self.compose(filters.rsample(p))

    def rename_keys(self, *args, **kw):
        return self.compose(filters.rename_keys(*args, **kw))

    def extract_keys(self, *args, **kw):
        return self.compose(filters.extract_keys(*args, **kw))

    def xdecode(self, *args, **kw):
        return self.compose(filters.xdecode(*args, **kw))

    def data_filter(self, *args, **kw):
        return self.compose(filters.data_filter(*args, **kw))

    def tokenize(self, *args, **kw):
        return self.compose(filters.tokenize(*args, **kw))

    def resample(self, *args, **kw):
        return self.compose(filters.resample(*args, **kw))

    def compute_fbank(self, *args, **kw):
        return self.compose(filters.compute_fbank(*args, **kw))

    def spec_aug(self, *args, **kw):
        return self.compose(filters.spec_aug(*args, **kw))

    def sort(self, size=500):
        return self.compose(filters.sort(size))

    def padding(self):
        return self.compose(filters.padding())

    def cmvn(self, cmvn_file):
        return self.compose(filters.cmvn(cmvn_file))


class WebDataset(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(
        self,
        urls,
        handler=reraise_exception,
        resampled=False,
        repeat=False,
        shardshuffle=None,
        cache_size=0,
        cache_dir=None,
        detshuffle=False,
        nodesplitter=shardlists.single_node_only,
        verbose=False,
    ):
        super().__init__()
        if isinstance(urls, IterableDataset):
            assert not resampled
            self.append(urls)
        elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
            with (open(urls)) as stream:
                spec = yaml.safe_load(stream)
            assert "datasets" in spec
            self.append(shardlists.MultiShardSample(spec))
        elif isinstance(urls, dict):
            assert "datasets" in urls
            self.append(shardlists.MultiShardSample(urls))
        elif resampled:
            self.append(shardlists.ResampledShards(urls))
        else:
            self.append(shardlists.SimpleShardList(urls))
            self.append(nodesplitter)
            self.append(shardlists.split_by_worker)
            if shardshuffle is True:
                shardshuffle = 100
            if shardshuffle is not None:
                if detshuffle:
                    self.append(filters.detshuffle(shardshuffle))
                else:
                    self.append(filters.shuffle(shardshuffle))
        if cache_size == 0:
            self.append(tariterators.tarfile_to_samples(handler=handler))
        else:
            assert cache_size == -1 or cache_size > 0
            self.append(
                cache.cached_tarfile_to_samples(
                    handler=handler,
                    verbose=verbose,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                ))


class FluidWrapper(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(self, initial):
        super().__init__()
        self.append(initial)


class WebLoader(DataPipeline, FluidInterface):
    def __init__(self, *args, **kw):
        super().__init__(DataLoader(*args, **kw))
```
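The `cache_size` argument is what routes `WebDataset` between the two tar readers above: `0` streams shards directly through `tariterators.tarfile_to_samples`, anything else goes through `cache.cached_tarfile_to_samples`, with `-1` meaning the default (effectively unlimited) budget. A construction-only sketch with invented shard names:

```python
# Construction only; no shard is opened until iteration.
ds_stream = WebDataset("shards-{000..009}.tar", cache_size=0)
ds_cached = WebDataset("shards-{000..009}.tar", cache_size=-1, cache_dir="./_cache")
```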
paddlespeech/audio/stream_data/filters.py (new file, mode 100644)

(Diff collapsed on the original page and not shown here; +912 lines.)
paddlespeech/audio/stream_data/paddle_utils.py (new file, mode 100644)

```python
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Mock implementations of paddle interfaces when paddle is not available."""

try:
    from paddle.io import DataLoader, IterableDataset
except ModuleNotFoundError:

    class IterableDataset:
        """Empty implementation of IterableDataset when paddle is not available."""
        pass

    class DataLoader:
        """Empty implementation of DataLoader when paddle is not available."""
        pass


try:
    from paddle import Tensor as PaddleTensor
except ModuleNotFoundError:

    # The original diff named this fallback "TorchTensor", which would leave
    # PaddleTensor undefined when paddle is missing; renamed accordingly.
    class PaddleTensor:
        """Empty implementation of PaddleTensor when paddle is not available."""
        pass
```
paddlespeech/audio/stream_data/pipeline.py (new file, mode 100644)

```python
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List

import braceexpand, yaml

from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
from webdataset.handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage


def add_length_method(obj):
    def length(self):
        return self.size

    Combined = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": length},
    )
    obj.__class__ = Combined
    return obj


class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.pipeline = []
        self.length = -1
        self.repetitions = 1
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage."""
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            result = f(*args, **kwargs)
            return result
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline."""
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions."""
        for i in range(self.repetitions):
            for sample in self.iterator1():
                yield sample

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested."""
        if self.repetitions != 1:
            if self.nsamples > 0:
                return islice(self.iterator(), self.nsamples)
            else:
                return self.iterator()
        else:
            return self.iterator()

    def stage(self, i):
        """Return pipeline stage i."""
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object)."""
        self.pipeline.append(f)

    def compose(self, *args):
        """Append a pipeline stage to a copy of the pipeline and returns the copy."""
        result = copy.copy(self)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        PyTorch IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.
        """
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The two arguments mean the same thing."""
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given #epochs up to the given #samples."""
        if nepochs > 0:
            self.repetitions = nepochs
            self.nsamples = nbatches
        else:
            self.repetitions = sys.maxsize
            self.nsamples = nbatches
        return self
```
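Two construction details worth noting (not from the diff): a bare Python list passed to `__init__` is spliced into the stage list via the `extend` branch, so a list *source* must arrive wrapped or as a callable; and `compose` makes a shallow `copy.copy`, so the copy shares the underlying stage list. A minimal self-contained sketch using only constructs `invoke` accepts:

```python
# Minimal DataPipeline: a callable source plus a generator-style filter.
def source():
    yield from range(1, 6)

def double(src):
    for x in src:
        yield 2 * x

pipe = DataPipeline(source, double)
print(list(pipe))  # [2, 4, 6, 8, 10]
```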
paddlespeech/audio/stream_data/shardlists.py (new file, mode 100644)

```python
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Train PyTorch models directly from POSIX tar archive.

Code works locally or over HTTP connections.
"""
import os, random, sys, time
from dataclasses import dataclass, field
from itertools import islice
from typing import List

import braceexpand, yaml

from . import utils
from .filters import pipelinefilter
from .paddle_utils import IterableDataset


def expand_urls(urls):
    if isinstance(urls, str):
        urllist = urls.split("::")
        result = []
        for url in urllist:
            result.extend(braceexpand.braceexpand(url))
        return result
    else:
        return list(urls)


class SimpleShardList(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(self, urls, seed=None):
        """Iterate through the list of shards.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls = expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.seed = seed

    def __len__(self):
        return len(self.urls)

    def __iter__(self):
        """Return an iterator over the shards."""
        urls = self.urls.copy()
        if self.seed is not None:
            random.Random(self.seed).shuffle(urls)
        for url in urls:
            yield dict(url=url)


def split_by_node(src, group=None):
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    if world_size > 1:
        for s in islice(src, rank, None, world_size):
            yield s
    else:
        for s in src:
            yield s


def single_node_only(src, group=None):
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    if world_size > 1:
        raise ValueError("input pipeline needs to be reconfigured for multinode training")
    for s in src:
        yield s


def split_by_worker(src):
    rank, world_size, worker, num_workers = utils.paddle_worker_info()
    if num_workers > 1:
        for s in islice(src, worker, None, num_workers):
            yield s
    else:
        for s in src:
            yield s


def resampled_(src, n=sys.maxsize):
    import random

    seed = time.time()
    try:
        seed = open("/dev/random", "rb").read(20)
    except Exception as exn:
        print(repr(exn)[:50], file=sys.stderr)
    rng = random.Random(seed)
    print("# resampled loading", file=sys.stderr)
    items = list(src)
    print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
    for i in range(n):
        yield rng.choice(items)


resampled = pipelinefilter(resampled_)


def non_empty(src):
    count = 0
    for s in src:
        yield s
        count += 1
    if count == 0:
        raise ValueError("pipeline stage received no data at all and this was declared as an error")


@dataclass
class MSSource:
    """Class representing a data source."""

    name: str = ""
    perepoch: int = -1
    resample: bool = False
    urls: List[str] = field(default_factory=list)


default_rng = random.Random()


def expand(s):
    return os.path.expanduser(os.path.expandvars(s))


# NOTE: this first definition is immediately shadowed by the one below;
# both appear in the original diff.
class MultiShardSample(IterableDataset):
    def __init__(self, fname):
        """Construct a shardlist from multiple sources using a YAML spec."""
        self.epoch = -1


class MultiShardSample(IterableDataset):
    def __init__(self, fname):
        """Construct a shardlist from multiple sources using a YAML spec."""
        self.epoch = -1
        self.parse_spec(fname)

    def parse_spec(self, fname):
        self.rng = default_rng  # capture default_rng if we fork
        if isinstance(fname, dict):
            spec = fname
            fname = "{dict}"
        else:
            with open(fname) as stream:
                spec = yaml.safe_load(stream)
        assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
        prefix = expand(spec.get("prefix", ""))
        self.sources = []
        for ds in spec["datasets"]:
            assert set(ds.keys()).issubset(
                set("buckets name shards resample choose".split())), list(ds.keys())
            buckets = ds.get("buckets", spec.get("buckets", []))
            if isinstance(buckets, str):
                buckets = [buckets]
            buckets = [expand(s) for s in buckets]
            if buckets == []:
                buckets = [""]
            assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
            bucket = buckets[0]
            name = ds.get("name", "@" + bucket)
            urls = ds["shards"]
            if isinstance(urls, str):
                urls = [urls]
            # urls = [u for url in urls for u in braceexpand.braceexpand(url)]
            urls = [
                prefix + os.path.join(bucket, u)
                for url in urls
                for u in braceexpand.braceexpand(expand(url))
            ]
            resample = ds.get("resample", -1)
            nsample = ds.get("choose", -1)
            if nsample > len(urls):
                raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
            if (nsample > 0) and (resample > 0):
                raise ValueError("specify only one of perepoch or choose")
            entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
            self.sources.append(entry)
            print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)

    def set_epoch(self, seed):
        """Set the current epoch (for consistent shard selection among nodes)."""
        self.rng = random.Random(seed)

    def get_shards_for_epoch(self):
        result = []
        for source in self.sources:
            if source.resample > 0:
                # sample with replacement
                l = self.rng.choices(source.urls, k=source.resample)
            elif source.perepoch > 0:
                # sample without replacement
                l = list(source.urls)
                self.rng.shuffle(l)
                l = l[:source.perepoch]
            else:
                l = list(source.urls)
            result += l
        self.rng.shuffle(result)
        return result

    def __iter__(self):
        shards = self.get_shards_for_epoch()
        for shard in shards:
            yield dict(url=shard)


def shardspec(spec):
    if spec.endswith(".yaml"):
        return MultiShardSample(spec)
    else:
        return SimpleShardList(spec)


class ResampledShards(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls = expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = -1

    def __iter__(self):
        """Return an iterator over the shards."""
        self.epoch += 1
        if self.deterministic:
            seed = utils.make_seed(self.worker_seed(), self.epoch)
        else:
            seed = utils.make_seed(
                self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
        if os.environ.get("WDS_SHOW_SEED", "0") == "1":
            print(f"# ResampledShards seed {seed}")
        self.rng = random.Random(seed)
        for _ in range(self.nshards):
            index = self.rng.randint(0, len(self.urls) - 1)
            yield dict(url=self.urls[index])
```
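`MultiShardSample.parse_spec` accepts either a YAML file path or an already-parsed dict with `prefix` and `datasets`, where each dataset entry carries `shards` (brace-expandable) plus optional `name`, `choose` (sample without replacement per epoch), or `resample` (with replacement). A sketch with invented paths, exercising the dict path so no file is needed:

```python
# Hypothetical spec dict; parse_spec() accepts it directly.
spec = {
    "prefix": "/data/",
    "datasets": [
        {"name": "clean", "shards": "clean-{000..099}.tar", "choose": 50},
        {"name": "noisy", "shards": ["noisy-000.tar", "noisy-001.tar"]},
    ],
}
shards = MultiShardSample(spec)
for sample in shards:     # 50 clean shards + 2 noisy shards, shuffled
    print(sample["url"])  # e.g. /data/clean-042.tar
```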
paddlespeech/audio/stream_data/tariterators.py (new file, mode 100644)

```python
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Low level iteration functions for tar archives."""
import random, re, tarfile

import braceexpand

from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception

trace = False
meta_prefix = "__"
meta_suffix = "__"

from ... import audio as paddleaudio
import paddle
import numpy as np

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


def base_plus_ext(path):
    """Split off all file extensions.

    Returns base, allext.

    :param path: path with extensions
    :param returns: path with all extensions removed
    """
    match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    if not match:
        return None, None
    return match.group(1), match.group(2)


def valid_sample(sample):
    """Check whether a sample is valid.

    :param sample: sample to be checked
    """
    return (
        sample is not None
        and isinstance(sample, dict)
        and len(list(sample.keys())) > 0
        and not sample.get("__bad__", False)
    )


# FIXME: UNUSED
def shardlist(urls, *, shuffle=False):
    """Given a list of URLs, yields that list, possibly shuffled."""
    if isinstance(urls, str):
        urls = braceexpand.braceexpand(urls)
    else:
        urls = list(urls)
    if shuffle:
        random.shuffle(urls)
    for url in urls:
        yield dict(url=url)


def url_opener(data, handler=reraise_exception, **kw):
    """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        try:
            stream = gopen.gopen(url, **kw)
            sample.update(stream=stream)
            yield sample
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break


def tar_file_iterator(fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception):
    """Iterate over tar file, yielding filename, content pairs for the given tar stream.

    :param fileobj: byte stream suitable for tarfile
    :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
    """
    stream = tarfile.open(fileobj=fileobj, mode="r:*")
    for tarinfo in stream:
        fname = tarinfo.name
        try:
            if not tarinfo.isreg():
                continue
            if fname is None:
                continue
            if (
                "/" not in fname
                and fname.startswith(meta_prefix)
                and fname.endswith(meta_suffix)
            ):
                # skipping metadata for now
                continue
            if skip_meta is not None and re.match(skip_meta, fname):
                continue
            name = tarinfo.name
            pos = name.rfind('.')
            assert pos > 0
            prefix, postfix = name[:pos], name[pos + 1:]
            if postfix == 'wav':
                waveform, sample_rate = paddleaudio.load(
                    stream.extractfile(tarinfo), normal=False)
                result = dict(fname=prefix, wav=waveform, sample_rate=sample_rate)
            else:
                txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
                result = dict(fname=prefix, txt=txt)
            # result = dict(fname=fname, data=data)
            yield result
            stream.members = []
        except Exception as exn:
            if hasattr(exn, "args") and len(exn.args) > 0:
                exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
            if handler(exn):
                continue
            else:
                break
    del stream


def tar_file_and_group_iterator(fileobj,
                                skip_meta=r"__[^/]*__($|/)",
                                handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    And groups the file with same prefix

    Args:
        data: Iterable[{src, stream}]

    Returns:
        Iterable[{key, wav, txt, sample_rate}]
    """
    stream = tarfile.open(fileobj=fileobj, mode="r:*")
    prev_prefix = None
    example = {}
    valid = True
    for tarinfo in stream:
        name = tarinfo.name
        pos = name.rfind('.')
        assert pos > 0
        prefix, postfix = name[:pos], name[pos + 1:]
        if prev_prefix is not None and prefix != prev_prefix:
            example['fname'] = prev_prefix
            if valid:
                yield example
            example = {}
            valid = True
        with stream.extractfile(tarinfo) as file_obj:
            try:
                if postfix == 'txt':
                    example['txt'] = file_obj.read().decode('utf8').strip()
                elif postfix in AUDIO_FORMAT_SETS:
                    waveform, sample_rate = paddleaudio.load(file_obj, normal=False)
                    waveform = paddle.to_tensor(
                        np.expand_dims(np.array(waveform), 0), dtype=paddle.float32)
                    example['wav'] = waveform
                    example['sample_rate'] = sample_rate
                else:
                    example[postfix] = file_obj.read()
            except Exception as exn:
                if hasattr(exn, "args") and len(exn.args) > 0:
                    exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                if handler(exn):
                    continue
                else:
                    break
                valid = False  # NOTE: unreachable as written (both branches above exit); kept as in the original diff
                # logging.warning('error to parse {}'.format(name))
        prev_prefix = prefix
    if prev_prefix is not None:
        example['fname'] = prev_prefix
        yield example
    stream.close()


def tar_file_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over (filename, file_contents).
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_iterator(source["stream"]):
                assert (
                    isinstance(sample, dict) and "data" in sample and "fname" in sample
                )
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break


def tar_file_and_group_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over (filename, file_contents).
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_and_group_iterator(source["stream"]):
                assert (
                    isinstance(sample, dict)
                    and "wav" in sample
                    and "txt" in sample
                    and "fname" in sample
                )
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break


def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if trace:
            print(
                prefix,
                suffix,
                current_sample.keys() if isinstance(current_sample, dict) else None,
            )
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        if current_sample is None or prefix != current_sample["__key__"]:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffix in current_sample:
            raise ValueError(
                f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample


def tarfile_samples(src, handler=reraise_exception):
    streams = url_opener(src, handler=handler)
    samples = tar_file_and_group_expander(streams, handler=handler)
    return samples


tarfile_to_samples = filters.pipelinefilter(tarfile_samples)
```
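`tar_file_and_group_iterator` assumes the members of one utterance sit adjacently in the archive and share a dot-delimited prefix (`utt1.wav`, `utt1.txt`, ...). A self-contained sketch of that layout, using text members only so it runs without paddleaudio:

```python
import io, tarfile
# Build an in-memory shard whose members share prefixes, then group it.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tar:
    for key in ["utt1", "utt2"]:
        payload = f"transcript of {key}".encode("utf8")
        info = tarfile.TarInfo(name=f"{key}.txt")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
buf.seek(0)

for sample in tar_file_and_group_iterator(buf):
    print(sample)  # {'txt': 'transcript of utt1', 'fname': 'utt1'}, then utt2
```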
paddlespeech/audio/stream_data/utils.py (new file, mode 100644)

```python
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions."""
import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Optional, Union


def make_seed(*args):
    seed = 0
    for arg in args:
        seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
    return seed


class PipelineStage:
    def invoke(self, *args, **kw):
        raise NotImplementedError


def identity(x: Any) -> Any:
    """Return the argument as is."""
    return x


def safe_eval(s: str, expr: str = "{}"):
    """Evaluate the given expression more safely."""
    if re.sub("[^A-Za-z0-9_]", "", s) != s:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    return eval(expr.format(s))


def lookup_sym(sym: str, modules: list):
    """Look up a symbol in a list of modules."""
    for mname in modules:
        module = importlib.import_module(mname, package="webdataset")
        result = getattr(module, sym, None)
        if result is not None:
            return result
    return None


def repeatedly0(loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize):
    """Repeatedly returns batches from a DataLoader."""
    for epoch in range(nepochs):
        for sample in itt.islice(loader, nbatches):
            yield sample


def guess_batchsize(batch: Union[tuple, list]):
    """Guess the batch size by looking at the length of the first element in a tuple."""
    return len(batch[0])


def repeatedly(
    source: Iterator,
    nepochs: int = None,
    nbatches: int = None,
    nsamples: int = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield samples from an iterator."""
    epoch = 0
    batch = 0
    total = 0
    while True:
        for sample in source:
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                total += guess_batchsize(sample)
                if total >= nsamples:
                    return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return


def paddle_worker_info(group=None):
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        try:
            import paddle.distributed
            group = group or paddle.distributed.get_group()
            rank = paddle.distributed.get_rank()
            world_size = paddle.distributed.get_world_size()
        except ModuleNotFoundError:
            pass
    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        worker = int(os.environ["WORKER"])
        num_workers = int(os.environ["NUM_WORKERS"])
    else:
        try:
            # The original diff wrote "import paddle.io.get_worker_info", which
            # always raises ModuleNotFoundError (get_worker_info is a function,
            # not a module); importing the package instead makes the lookup work.
            import paddle.io
            worker_info = paddle.io.get_worker_info()
            if worker_info is not None:
                worker = worker_info.id
                num_workers = worker_info.num_workers
        except ModuleNotFoundError:
            pass

    return rank, world_size, worker, num_workers


def paddle_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = paddle_worker_info(group=group)
    return rank * 1000 + worker
```
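`make_seed` folds its arguments into a 31-bit value, so every (rank, worker) pair derived from `paddle_worker_info` gets a stable but distinct random stream. A quick check (note that string hashing is salted per process in Python 3, so cross-run stability only holds for integer arguments):

```python
# Same inputs, same seed; different inputs, (almost surely) different seeds.
assert make_seed(1, 2, 3) == make_seed(1, 2, 3)
assert make_seed(0, 0) != make_seed(0, 1)

# Worker info can be forced through environment variables:
import os
os.environ["RANK"], os.environ["WORLD_SIZE"] = "2", "8"
print(paddle_worker_info())  # (2, 8, <worker>, <num_workers>)
```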
paddlespeech/audio/transform/__init__.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
paddlespeech/audio/transform/add_deltas.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy as np


def delta(feat, window):
    assert window > 0
    delta_feat = np.zeros_like(feat)
    for i in range(1, window + 1):
        delta_feat[:-i] += i * feat[i:]
        delta_feat[i:] += -i * feat[:-i]
        delta_feat[-i:] += i * feat[-1]
        delta_feat[:i] += -i * feat[0]
    delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1))
    return delta_feat


def add_deltas(x, window=2, order=2):
    """
    Args:
        x (np.ndarray): speech feat, (T, D).

    Return:
        np.ndarray: (T, (1+order)*D)
    """
    feats = [x]
    for _ in range(order):
        feats.append(delta(feats[-1], window))
    return np.concatenate(feats, axis=1)


class AddDeltas():
    def __init__(self, window=2, order=2):
        self.window = window
        self.order = order

    def __repr__(self):
        # closing ")" added to the format string; it was missing in the original diff
        return "{name}(window={window}, order={order})".format(
            name=self.__class__.__name__, window=self.window, order=self.order)

    def __call__(self, x):
        return add_deltas(x, window=self.window, order=self.order)
```
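A tiny worked check of the delta recursion (not from the diff): on a linear ramp the interior first-order deltas are exactly 1, since the numerator sums to 2·Σi² and the normalizer is exactly 2·Σi².

```python
import numpy as np
# Ramp feature: constant slope 1, so interior first-order deltas are 1.
feat = np.arange(6, dtype=np.float64).reshape(6, 1)   # (T=6, D=1)
out = add_deltas(feat, window=2, order=2)
print(out.shape)    # (6, 3): [static, delta, delta-delta] stacked along D
print(out[2:4, 1])  # [1. 1.]  (interior rows; edges use replicated padding)
```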
paddlespeech/audio/transform/channel_selector.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy


class ChannelSelector():
    """Select 1ch from multi-channel signal"""

    def __init__(self, train_channel="random", eval_channel=0, axis=1):
        self.train_channel = train_channel
        self.eval_channel = eval_channel
        self.axis = axis

    def __repr__(self):
        return ("{name}(train_channel={train_channel}, "
                "eval_channel={eval_channel}, axis={axis})".format(
                    name=self.__class__.__name__,
                    train_channel=self.train_channel,
                    eval_channel=self.eval_channel,
                    axis=self.axis,
                ))

    def __call__(self, x, train=True):
        # Assuming x: [Time, Channel] by default
        if x.ndim <= self.axis:
            # If the dimension is insufficient, then unsqueeze
            # (e.g [Time] -> [Time, 1])
            ind = tuple(
                slice(None) if i < x.ndim else None for i in range(self.axis + 1))
            x = x[ind]

        if train:
            channel = self.train_channel
        else:
            channel = self.eval_channel

        if channel == "random":
            ch = numpy.random.randint(0, x.shape[self.axis])
        else:
            ch = channel

        ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim))
        return x[ind]
```
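Usage sketch for `ChannelSelector` (shapes invented): at eval time a fixed channel is taken, at train time a random one; 1-D input is first unsqueezed to `(Time, 1)`.

```python
import numpy as np
sel = ChannelSelector(train_channel="random", eval_channel=0, axis=1)
x = np.random.randn(100, 4)              # (Time, Channel)
print(sel(x, train=False).shape)         # (100,) -> always channel 0
print(sel(x, train=True).shape)          # (100,) -> a random channel
print(sel(np.random.randn(100)).shape)   # (100,) after unsqueeze to (100, 1)
```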
paddlespeech/audio/transform/cmvn.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import io
import json

import h5py
import kaldiio
import numpy as np


class CMVN():
    "Apply global/speaker CMVN or inverse CMVN."

    def __init__(
        self,
        stats,
        norm_means=True,
        norm_vars=False,
        filetype="mat",
        utt2spk=None,
        spk2utt=None,
        reverse=False,
        std_floor=1.0e-20,
    ):
        self.stats_file = stats
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.reverse = reverse

        if isinstance(stats, dict):
            stats_dict = dict(stats)
        else:
            # Use for global CMVN
            if filetype == "mat":
                stats_dict = {None: kaldiio.load_mat(stats)}
            # Use for global CMVN
            elif filetype == "npy":
                stats_dict = {None: np.load(stats)}
            # Use for speaker CMVN
            elif filetype == "ark":
                self.accept_uttid = True
                stats_dict = dict(kaldiio.load_ark(stats))
            # Use for speaker CMVN
            elif filetype == "hdf5":
                self.accept_uttid = True
                stats_dict = h5py.File(stats)
            else:
                raise ValueError("Not supporting filetype={}".format(filetype))

        if utt2spk is not None:
            self.utt2spk = {}
            with io.open(utt2spk, "r", encoding="utf-8") as f:
                for line in f:
                    utt, spk = line.rstrip().split(None, 1)
                    self.utt2spk[utt] = spk
        elif spk2utt is not None:
            self.utt2spk = {}
            with io.open(spk2utt, "r", encoding="utf-8") as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        self.utt2spk[utt] = spk
        else:
            self.utt2spk = None

        # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
        # and the first vector contains the sum of feats and the second is
        # the sum of squares. The last value of the first, i.e. stats[0,-1],
        # is the number of samples for this statistics.
        self.bias = {}
        self.scale = {}
        for spk, stats in stats_dict.items():
            assert len(stats) == 2, stats.shape
            count = stats[0, -1]

            # If the feature has two or more dimensions
            if not (np.isscalar(count) or isinstance(count, (int, float))):
                # The first is only used
                count = count.flatten()[0]

            mean = stats[0, :-1] / count
            # V(x) = E(x^2) - (E(x))^2
            var = stats[1, :-1] / count - mean * mean
            std = np.maximum(np.sqrt(var), std_floor)
            self.bias[spk] = -mean
            self.scale[spk] = 1 / std

    def __repr__(self):
        return ("{name}(stats_file={stats_file}, "
                "norm_means={norm_means}, norm_vars={norm_vars}, "
                "reverse={reverse})".format(
                    name=self.__class__.__name__,
                    stats_file=self.stats_file,
                    norm_means=self.norm_means,
                    norm_vars=self.norm_vars,
                    reverse=self.reverse,
                ))

    def __call__(self, x, uttid=None):
        if self.utt2spk is not None:
            spk = self.utt2spk[uttid]
        else:
            spk = uttid

        if not self.reverse:
            # apply cmvn
            if self.norm_means:
                x = np.add(x, self.bias[spk])
            if self.norm_vars:
                x = np.multiply(x, self.scale[spk])
        else:
            # apply reverse cmvn
            if self.norm_vars:
                x = np.divide(x, self.scale[spk])
            if self.norm_means:
                x = np.subtract(x, self.bias[spk])
        return x


class UtteranceCMVN():
    "Apply Utterance CMVN"

    def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

    def __repr__(self):
        return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
            name=self.__class__.__name__,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars,
        )

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        square_sums = (x ** 2).sum(axis=0)
        mean = x.mean(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)

        if self.norm_vars:
            var = square_sums / x.shape[0] - mean ** 2
            std = np.maximum(np.sqrt(var), self.std_floor)
            x = np.divide(x, std)

        return x


class GlobalCMVN():
    "Apply Global CMVN"

    def __init__(self, cmvn_path, norm_means=True, norm_vars=True, std_floor=1.0e-20):
        # cmvn_path: Option[str, dict]
        cmvn = cmvn_path
        self.cmvn = cmvn
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

        if isinstance(cmvn, dict):
            cmvn_stats = cmvn
        else:
            with open(cmvn) as f:
                cmvn_stats = json.load(f)
        self.count = cmvn_stats['frame_num']
        self.mean = np.array(cmvn_stats['mean_stat']) / self.count
        self.square_sums = np.array(cmvn_stats['var_stat'])
        self.var = self.square_sums / self.count - self.mean ** 2
        self.std = np.maximum(np.sqrt(self.var), self.std_floor)

    def __repr__(self):
        return f"""{self.__class__.__name__}(
            cmvn_path={self.cmvn},
            norm_means={self.norm_means},
            norm_vars={self.norm_vars},)"""

    def __call__(self, x, uttid=None):
        # x: [Time, Dim]
        if self.norm_means:
            x = np.subtract(x, self.mean)

        if self.norm_vars:
            x = np.divide(x, self.std)
        return x
```
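`GlobalCMVN` also accepts an in-memory dict in the same layout as the JSON file it otherwise loads: `mean_stat`/`var_stat` are per-dimension sums of x and x² over `frame_num` frames. A synthetic round-trip with invented data:

```python
import numpy as np
x = np.random.randn(1000, 8) * 3.0 + 5.0  # mean ~5, std ~3 per dimension
stats = {
    "frame_num": x.shape[0],
    "mean_stat": x.sum(axis=0).tolist(),
    "var_stat": (x ** 2).sum(axis=0).tolist(),
}
cmvn = GlobalCMVN(stats, norm_means=True, norm_vars=True)
y = cmvn(x)
print(y.mean(axis=0).round(2))  # ~0 per dimension
print(y.std(axis=0).round(2))   # ~1 per dimension
```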
paddlespeech/audio/transform/functional.py (new file, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import inspect

from paddlespeech.audio.transform.transform_interface import TransformInterface
from paddlespeech.audio.utils.check_kwargs import check_kwargs


class FuncTrans(TransformInterface):
    """Functional Transformation

    WARNING:
        Builtin or C/C++ functions may not work properly
        because this class heavily depends on the `inspect` module.

    Usage:
        >>> def foo_bar(x, a=1, b=2):
        ...     '''Foo bar
        ...     :param x: input
        ...     :param int a: default 1
        ...     :param int b: default 2
        ...     '''
        ...     return x + a - b
        >>> class FooBar(FuncTrans):
        ...     _func = foo_bar
        ...     __doc__ = foo_bar.__doc__
    """

    _func = None

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        check_kwargs(self.func, kwargs)

    def __call__(self, x):
        return self.func(x, **self.kwargs)

    @classmethod
    def add_arguments(cls, parser):
        fname = cls._func.__name__.replace("_", "-")
        group = parser.add_argument_group(fname + " transformation setting")
        for k, v in cls.default_params().items():
            # TODO(karita): get help and choices from docstring?
            attr = k.replace("_", "-")
            group.add_argument(f"--{fname}-{attr}", default=v, type=type(v))
        return parser

    @property
    def func(self):
        return type(self)._func

    @classmethod
    def default_params(cls):
        try:
            d = dict(inspect.signature(cls._func).parameters)
        except ValueError:
            d = dict()
        return {
            k: v.default
            for k, v in d.items() if v.default != inspect.Parameter.empty
        }

    def __repr__(self):
        params = self.default_params()
        params.update(**self.kwargs)
        ret = self.__class__.__name__ + "("
        if len(params) == 0:
            return ret + ")"
        for k, v in params.items():
            ret += "{}={}, ".format(k, v)
        return ret[:-2] + ")"
```
paddlespeech/audio/transform/perturb.py
0 → 100644
浏览文件 @
8f5e6109
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import
librosa
import
numpy
import
scipy
import
soundfile
import
io
import
os
import
h5py
import
numpy
as
np
class
SoundHDF5File
():
"""Collecting sound files to a HDF5 file
>>> f = SoundHDF5File('a.flac.h5', mode='a')
>>> array = np.random.randint(0, 100, 100, dtype=np.int16)
>>> f['id'] = (array, 16000)
>>> array, rate = f['id']
:param: str filepath:
:param: str mode:
:param: str format: The type used when saving wav. flac, nist, htk, etc.
:param: str dtype:
"""
def
__init__
(
self
,
filepath
,
mode
=
"r+"
,
format
=
None
,
dtype
=
"int16"
,
**
kwargs
):
self
.
filepath
=
filepath
self
.
mode
=
mode
self
.
dtype
=
dtype
self
.
file
=
h5py
.
File
(
filepath
,
mode
,
**
kwargs
)
if
format
is
None
:
# filepath = a.flac.h5 -> format = flac
second_ext
=
os
.
path
.
splitext
(
os
.
path
.
splitext
(
filepath
)[
0
])[
1
]
format
=
second_ext
[
1
:]
if
format
.
upper
()
not
in
soundfile
.
available_formats
():
# If not found, flac is selected
format
=
"flac"
# This format affects only saving
self
.
format
=
format
def
__repr__
(
self
):
return
'<SoundHDF5 file "{}" (mode {}, format {}, type {})>'
.
format
(
self
.
filepath
,
self
.
mode
,
self
.
format
,
self
.
dtype
)
def
create_dataset
(
self
,
name
,
shape
=
None
,
data
=
None
,
**
kwds
):
f
=
io
.
BytesIO
()
array
,
rate
=
data
soundfile
.
write
(
f
,
array
,
rate
,
format
=
self
.
format
)
self
.
file
.
create_dataset
(
name
,
shape
=
shape
,
data
=
np
.
void
(
f
.
getvalue
()),
**
kwds
)
def
__setitem__
(
self
,
name
,
data
):
self
.
create_dataset
(
name
,
data
=
data
)
def
__getitem__
(
self
,
key
):
data
=
self
.
file
[
key
][()]
f
=
io
.
BytesIO
(
data
.
tobytes
())
array
,
rate
=
soundfile
.
read
(
f
,
dtype
=
self
.
dtype
)
return
array
,
rate
def
keys
(
self
):
return
self
.
file
.
keys
()
def
values
(
self
):
for
k
in
self
.
file
:
yield
self
[
k
]
def
items
(
self
):
for
k
in
self
.
file
:
yield
k
,
self
[
k
]
def
__iter__
(
self
):
return
iter
(
self
.
file
)
def
__contains__
(
self
,
item
):
return
item
in
self
.
file
def
__len__
(
self
,
item
):
return
len
(
self
.
file
)
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
file
.
close
()
def
close
(
self
):
self
.
file
.
close
()
class
SpeedPerturbation
():
"""SpeedPerturbation
The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
and sox-speed just to resample the input,
i.e pitch and tempo are changed both.
"Why use speed option instead of tempo -s in SoX for speed perturbation"
https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8
Warning:
This function is very slow because of resampling.
I recommmend to apply speed-perturb outside the training using sox.
"""
def
__init__
(
self
,
lower
=
0.9
,
upper
=
1.1
,
utt2ratio
=
None
,
keep_length
=
True
,
res_type
=
"kaiser_best"
,
seed
=
None
,
):
self
.
res_type
=
res_type
self
.
keep_length
=
keep_length
self
.
state
=
numpy
.
random
.
RandomState
(
seed
)
if
utt2ratio
is
not
None
:
self
.
utt2ratio
=
{}
# Use the scheduled ratio for each utterances
self
.
utt2ratio_file
=
utt2ratio
self
.
lower
=
None
self
.
upper
=
None
self
.
accept_uttid
=
True
with
open
(
utt2ratio
,
"r"
)
as
f
:
for
line
in
f
:
utt
,
ratio
=
line
.
rstrip
().
split
(
None
,
1
)
ratio
=
float
(
ratio
)
self
.
utt2ratio
[
utt
]
=
ratio
else
:
self
.
utt2ratio
=
None
# The ratio is given on runtime randomly
self
.
lower
=
lower
self
.
upper
=
upper
def
__repr__
(
self
):
if
self
.
utt2ratio
is
None
:
return
"{}(lower={}, upper={}, "
"keep_length={}, res_type={})"
.
format
(
self
.
__class__
.
__name__
,
self
.
lower
,
self
.
upper
,
self
.
keep_length
,
self
.
res_type
,
)
else
:
return
"{}({}, res_type={})"
.
format
(
self
.
__class__
.
__name__
,
self
.
utt2ratio_file
,
self
.
res_type
)
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
x
x
=
x
.
astype
(
numpy
.
float32
)
if
self
.
accept_uttid
:
ratio
=
self
.
utt2ratio
[
uttid
]
else
:
ratio
=
self
.
state
.
uniform
(
self
.
lower
,
self
.
upper
)
# Note1: resample requires the sampling-rate of input and output,
# but actually only the ratio is used.
y
=
librosa
.
resample
(
x
,
orig_sr
=
ratio
,
target_sr
=
1
,
res_type
=
self
.
res_type
)
if
self
.
keep_length
:
diff
=
abs
(
len
(
x
)
-
len
(
y
))
if
len
(
y
)
>
len
(
x
):
# Truncate noise
y
=
y
[
diff
//
2
:
-
((
diff
+
1
)
//
2
)]
elif
len
(
y
)
<
len
(
x
):
# Assume the time-axis is the first: (Time, Channel)
pad_width
=
[(
diff
//
2
,
(
diff
+
1
)
//
2
)]
+
[
(
0
,
0
)
for
_
in
range
(
y
.
ndim
-
1
)
]
y
=
numpy
.
pad
(
y
,
pad_width
=
pad_width
,
constant_values
=
0
,
mode
=
"constant"
)
return
y
class SpeedPerturbationSox():
    """SpeedPerturbationSox

    The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
    and sox-speed just resamples the input,
    i.e. both pitch and tempo are changed.

    To speed up or slow down the sound of a file,
    use speed to modify both the pitch and the duration of the file.
    Raising the speed reduces the duration.
    The default factor is 1.0, which makes no change to the audio.
    2.0 doubles the speed, so the time length is cut in half and the pitch
    is one octave higher.

    "Why use speed option instead of tempo -s in SoX for speed perturbation"
    https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

    tempo option:
        sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9

    speed option:
        sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9

    With the speed option above, the pitch of the audio is also changed,
    whereas the tempo option does not change the pitch.
    """

    def __init__(
            self,
            lower=0.9,
            upper=1.1,
            utt2ratio=None,
            keep_length=True,
            sr=16000,
            seed=None, ):
        self.sr = sr
        self.keep_length = keep_length
        self.state = numpy.random.RandomState(seed)

        try:
            import soxbindings as sox
        except ImportError:
            try:
                from paddlespeech.s2t.utils import dynamic_pip_install
                package = "sox"
                dynamic_pip_install.install(package)
                package = "soxbindings"
                if sys.platform != "win32":
                    dynamic_pip_install.install(package)
                import soxbindings as sox
            except Exception:
                raise RuntimeError(
                    "Can not install soxbindings on your system.")
        self.sox = sox

        if utt2ratio is not None:
            self.utt2ratio = {}
            # Use the scheduled ratio for each utterance
            self.utt2ratio_file = utt2ratio
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            self.utt2ratio = None
            # The ratio is drawn randomly at runtime
            self.lower = lower
            self.upper = upper
            # NOTE: without this flag, __call__ raised AttributeError when
            # no utt2ratio file was given.
            self.accept_uttid = False

    def __repr__(self):
        if self.utt2ratio is None:
            return f"""{self.__class__.__name__}(
                lower={self.lower},
                upper={self.upper},
                keep_length={self.keep_length},
                sample_rate={self.sr})"""
        else:
            return f"""{self.__class__.__name__}(
                utt2ratio={self.utt2ratio_file},
                sample_rate={self.sr})"""

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        tfm = self.sox.Transformer()
        tfm.set_globals(multithread=False)
        tfm.speed(ratio)
        y = tfm.build_array(input_array=x, sample_rate_in=self.sr)

        if self.keep_length:
            diff = abs(len(x) - len(y))
            if len(y) > len(x):
                # Truncate noise
                y = y[diff // 2:-((diff + 1) // 2)]
            elif len(y) < len(x):
                # Assume the time-axis is the first: (Time, Channel)
                pad_width = [(diff // 2, (diff + 1) // 2)] + [
                    (0, 0) for _ in range(y.ndim - 1)
                ]
                y = numpy.pad(
                    y, pad_width=pad_width, constant_values=0, mode="constant")

        if y.ndim == 2 and x.ndim == 1:
            # (T, C) -> (T)
            # NOTE: the original called the non-existent `y.sequence(1)`;
            # squeezing the channel axis is the intended behavior.
            y = y.squeeze(1)
        return y
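A minimal usage sketch (not part of the commit; it assumes soxbindings is installed and uses random audio purely for illustration):

# Hypothetical usage sketch for SpeedPerturbationSox.
import numpy

from paddlespeech.audio.transform.perturb import SpeedPerturbationSox

perturb = SpeedPerturbationSox(lower=0.9, upper=1.1, sr=16000, seed=0)
wav = numpy.random.randn(16000).astype(numpy.float32)
out = perturb(wav, train=True)  # pitch and tempo shift together, as with `sox ... speed`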
class BandpassPerturbation():
    """BandpassPerturbation

    Randomly drop out along the frequency axis.

    The original idea comes from the following:
        "a randomly-selected frequency band was cut off under the constraint of
         leaving at least 1,000 Hz band within the range of less than 4,000 Hz."
        (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for
         everyday home environments using multiple microphone arrays;
         http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf)
    """

    def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )):
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)
        # x_stft: (Time, Channel, Freq)
        self.axes = axes

    def __repr__(self):
        return "{}(lower={}, upper={})".format(self.__class__.__name__,
                                               self.lower, self.upper)

    def __call__(self, x_stft, uttid=None, train=True):
        if not train:
            return x_stft

        if x_stft.ndim == 1:
            raise RuntimeError("Input in time-freq domain: "
                               "(Time, Channel, Freq) or (Time, Freq)")

        ratio = self.state.uniform(self.lower, self.upper)
        # NOTE: a negative axis index is normalized with `ndim + i`; the
        # original used `ndim - i`, which mapped -1 past the last axis.
        axes = [i if i >= 0 else x_stft.ndim + i for i in self.axes]
        shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)]

        mask = self.state.randn(*shape) > ratio
        x_stft *= mask
        return x_stft
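A minimal usage sketch (not part of the commit; the random STFT magnitude array stands in for real features):

# Hypothetical usage sketch: mask random frequency bins of an STFT-domain input.
import numpy

from paddlespeech.audio.transform.perturb import BandpassPerturbation

perturb = BandpassPerturbation(lower=0.0, upper=0.75, seed=0)
x_stft = numpy.random.randn(100, 257)  # (Time, Freq)
out = perturb(x_stft, train=True)  # masked in place along the last axis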
class VolumePerturbation():
    def __init__(self,
                 lower=-1.6,
                 upper=1.6,
                 utt2ratio=None,
                 dbunit=True,
                 seed=None):
        self.dbunit = dbunit
        self.utt2ratio_file = utt2ratio
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterance
            self.utt2ratio = {}
            self.lower = None
            self.upper = None
            self.accept_uttid = True

            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, ratio = line.rstrip().split(None, 1)
                    ratio = float(ratio)
                    self.utt2ratio[utt] = ratio
        else:
            # The ratio is drawn randomly at runtime
            self.utt2ratio = None
            # NOTE: without this flag, __call__ raised AttributeError when
            # no utt2ratio file was given.
            self.accept_uttid = False

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)
        if self.accept_uttid:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)
        if self.dbunit:
            ratio = 10**(ratio / 20)
        return x * ratio
class NoiseInjection():
    """Add isotropic noise"""

    def __init__(
            self,
            utt2noise=None,
            lower=-20,
            upper=-5,
            utt2ratio=None,
            filetype="list",
            dbunit=True,
            seed=None, ):
        self.utt2noise_file = utt2noise
        self.utt2ratio_file = utt2ratio
        self.filetype = filetype
        self.dbunit = dbunit
        self.lower = lower
        self.upper = upper
        self.state = numpy.random.RandomState(seed)

        if utt2ratio is not None:
            # Use the scheduled ratio for each utterance
            self.utt2ratio = {}
            # NOTE: the original opened `utt2noise` here, but the SNR
            # schedule lives in the `utt2ratio` file.
            with open(utt2ratio, "r") as f:
                for line in f:
                    utt, snr = line.rstrip().split(None, 1)
                    snr = float(snr)
                    self.utt2ratio[utt] = snr
        else:
            # The ratio is drawn randomly at runtime
            self.utt2ratio = None

        if utt2noise is not None:
            self.utt2noise = {}
            if filetype == "list":
                with open(utt2noise, "r") as f:
                    for line in f:
                        utt, filename = line.rstrip().split(None, 1)
                        signal, rate = soundfile.read(filename, dtype="int16")
                        # Load all files in memory
                        self.utt2noise[utt] = (signal, rate)

            elif filetype == "sound.hdf5":
                self.utt2noise = SoundHDF5File(utt2noise, "r")
            else:
                raise ValueError(filetype)
        else:
            self.utt2noise = None

        if utt2noise is not None and utt2ratio is not None:
            if set(self.utt2ratio) != set(self.utt2noise):
                raise RuntimeError("The uttids mismatch between {} and {}".
                                   format(utt2ratio, utt2noise))

    def __repr__(self):
        if self.utt2ratio is None:
            return "{}(lower={}, upper={}, dbunit={})".format(
                self.__class__.__name__, self.lower, self.upper, self.dbunit)
        else:
            return '{}("{}", dbunit={})'.format(
                self.__class__.__name__, self.utt2ratio_file, self.dbunit)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        x = x.astype(numpy.float32)

        # 1. Get ratio of noise to signal in sound pressure level
        if uttid is not None and self.utt2ratio is not None:
            ratio = self.utt2ratio[uttid]
        else:
            ratio = self.state.uniform(self.lower, self.upper)

        if self.dbunit:
            ratio = 10**(ratio / 20)
        scale = ratio * numpy.sqrt((x**2).mean())

        # 2. Get noise
        if self.utt2noise is not None:
            # Get noise from the external source
            if uttid is not None:
                noise, rate = self.utt2noise[uttid]
            else:
                # Randomly select the noise source
                noise = self.state.choice(list(self.utt2noise.values()))
            # Normalize the level
            # NOTE: out-of-place division; in-place `/=` fails on the
            # int16 arrays loaded above.
            noise = noise / numpy.sqrt((noise**2).mean())

            # Adjust the noise length
            diff = abs(len(x) - len(noise))
            offset = self.state.randint(0, diff)
            if len(noise) > len(x):
                # Truncate noise
                noise = noise[offset:-(diff - offset)]
            else:
                noise = numpy.pad(
                    noise, pad_width=[offset, diff - offset], mode="wrap")

        else:
            # Generate white noise
            noise = self.state.normal(0, 1, x.shape)

        # 3. Add noise to signal
        return x + noise * scale
class RIRConvolve():
    def __init__(self, utt2rir, filetype="list"):
        self.utt2rir_file = utt2rir
        self.filetype = filetype

        self.utt2rir = {}
        if filetype == "list":
            with open(utt2rir, "r") as f:
                for line in f:
                    utt, filename = line.rstrip().split(None, 1)
                    signal, rate = soundfile.read(filename, dtype="int16")
                    self.utt2rir[utt] = (signal, rate)

        elif filetype == "sound.hdf5":
            self.utt2rir = SoundHDF5File(utt2rir, "r")
        else:
            raise NotImplementedError(filetype)

    def __repr__(self):
        return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x

        x = x.astype(numpy.float32)

        if x.ndim != 1:
            # Must be single channel
            raise RuntimeError(
                "Input x must be one dimensional array, but got {}".format(
                    x.shape))

        rir, rate = self.utt2rir[uttid]
        if rir.ndim == 2:
            # FIXME(kamo): Use chainer.convolution_1d?
            # return [Time, Channel]
            return numpy.stack(
                [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
        else:
            return scipy.convolve(x, rir, mode="same")
paddlespeech/audio/transform/spec_augment.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Spec Augment module for preprocessing i.e., data augmentation"""
import random

import numpy
from PIL import Image
from PIL.Image import BICUBIC

from .functional import FuncTrans


def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
    """time warp for spec augment

    move random center frame by the random width ~ uniform(-window, window)

    :param numpy.ndarray x: spectrogram (time, freq)
    :param int max_time_warp: maximum time frames to warp
    :param bool inplace: overwrite x with the result
    :param str mode: "PIL" (default, fast, not differentiable) or
        "sparse_image_warp" (slow, differentiable)
    :returns numpy.ndarray: time warped spectrogram (time, freq)
    """
    window = max_time_warp
    if window == 0:
        return x

    if mode == "PIL":
        t = x.shape[0]
        if t - window <= window:
            return x
        # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
        center = random.randrange(window, t - window)
        warped = random.randrange(center - window, center + window) + 1  # 1 ... t - 1

        left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
                                                  BICUBIC)
        right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
                                                   BICUBIC)
        if inplace:
            x[:warped] = left
            x[warped:] = right
            return x
        return numpy.concatenate((left, right), 0)
    elif mode == "sparse_image_warp":
        import paddle

        from espnet.utils import spec_augment

        # TODO(karita): make this differentiable again
        return spec_augment.time_warp(paddle.to_tensor(x), window).numpy()
    else:
        raise NotImplementedError(
            "unknown resize mode: " + mode +
            ", choose one from (PIL, sparse_image_warp).")


class TimeWarp(FuncTrans):
    _func = time_warp
    __doc__ = time_warp.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)


def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
    """freq mask for spec augment

    :param numpy.ndarray x: (time, freq)
    :param int F: maximum width of each mask
    :param int n_mask: the number of masks
    :param bool inplace: overwrite
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    if inplace:
        cloned = x
    else:
        cloned = x.copy()

    num_mel_channels = cloned.shape[1]
    fs = numpy.random.randint(0, F, size=(n_mask, 2))

    for f, mask_end in fs:
        f_zero = random.randrange(0, num_mel_channels - f)
        mask_end += f_zero

        # avoids randrange error if values are equal and range is empty
        if f_zero == f_zero + f:
            continue

        if replace_with_zero:
            cloned[:, f_zero:mask_end] = 0
        else:
            cloned[:, f_zero:mask_end] = cloned.mean()
    return cloned


class FreqMask(FuncTrans):
    _func = freq_mask
    __doc__ = freq_mask.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)


def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False):
    """time mask for spec augment

    :param numpy.ndarray spec: (time, freq)
    :param int T: maximum width of each mask
    :param int n_mask: the number of masks
    :param bool inplace: overwrite
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    if inplace:
        cloned = spec
    else:
        cloned = spec.copy()
    len_spectro = cloned.shape[0]
    ts = numpy.random.randint(0, T, size=(n_mask, 2))
    for t, mask_end in ts:
        # avoid randint range error
        if len_spectro - t <= 0:
            continue
        t_zero = random.randrange(0, len_spectro - t)

        # avoids randrange error if values are equal and range is empty
        if t_zero == t_zero + t:
            continue

        mask_end += t_zero
        if replace_with_zero:
            cloned[t_zero:mask_end] = 0
        else:
            cloned[t_zero:mask_end] = cloned.mean()
    return cloned


class TimeMask(FuncTrans):
    _func = time_mask
    __doc__ = time_mask.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)


def spec_augment(
        x,
        resize_mode="PIL",
        max_time_warp=80,
        max_freq_width=27,
        n_freq_mask=2,
        max_time_width=100,
        n_time_mask=2,
        inplace=True,
        replace_with_zero=True, ):
    """spec augment

    apply random time warping and time/freq masking
    default setting is based on LD (Librispeech double) in Table 2
    https://arxiv.org/pdf/1904.08779.pdf

    :param numpy.ndarray x: (time, freq)
    :param str resize_mode: "PIL" (fast, nondifferentiable) or
        "sparse_image_warp" (slow, differentiable)
    :param int max_time_warp: maximum frames to warp the center frame in
        spectrogram (W)
    :param int max_freq_width: maximum width of the random freq mask (F)
    :param int n_freq_mask: the number of the random freq mask (m_F)
    :param int max_time_width: maximum width of the random time mask (T)
    :param int n_time_mask: the number of the random time mask (m_T)
    :param bool inplace: overwrite intermediate array
    :param bool replace_with_zero: pad zero on mask if true else use mean
    """
    assert isinstance(x, numpy.ndarray)
    assert x.ndim == 2
    x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode)
    x = freq_mask(
        x,
        max_freq_width,
        n_freq_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
    x = time_mask(
        x,
        max_time_width,
        n_time_mask,
        inplace=inplace,
        replace_with_zero=replace_with_zero, )
    return x


class SpecAugment(FuncTrans):
    _func = spec_augment
    __doc__ = spec_augment.__doc__

    def __call__(self, x, train):
        if not train:
            return x
        return super().__call__(x)
paddlespeech/audio/transform/spectrogram.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import librosa
import numpy as np
import paddle
from python_speech_features import logfbank

from ..compliance import kaldi


def stft(x,
         n_fft,
         n_shift,
         win_length=None,
         window="hann",
         center=True,
         pad_mode="reflect"):
    # x: [Time, Channel]
    if x.ndim == 1:
        single_channel = True
        # x: [Time] -> [Time, Channel]
        x = x[:, None]
    else:
        single_channel = False
    x = x.astype(np.float32)

    # FIXME(kamo): librosa.stft can't use multi-channel?
    # x: [Time, Channel, Freq]
    x = np.stack(
        [
            librosa.stft(
                y=x[:, ch],
                n_fft=n_fft,
                hop_length=n_shift,
                win_length=win_length,
                window=window,
                center=center,
                pad_mode=pad_mode, ).T for ch in range(x.shape[1])
        ],
        axis=1, )

    if single_channel:
        # x: [Time, Channel, Freq] -> [Time, Freq]
        x = x[:, 0]
    return x


def istft(x, n_shift, win_length=None, window="hann", center=True):
    # x: [Time, Channel, Freq]
    if x.ndim == 2:
        single_channel = True
        # x: [Time, Freq] -> [Time, Channel, Freq]
        x = x[:, None, :]
    else:
        single_channel = False

    # x: [Time, Channel]
    x = np.stack(
        [
            librosa.istft(
                stft_matrix=x[:, ch].T,  # [Time, Freq] -> [Freq, Time]
                hop_length=n_shift,
                win_length=win_length,
                window=window,
                center=center, ) for ch in range(x.shape[1])
        ],
        axis=1, )

    if single_channel:
        # x: [Time, Channel] -> [Time]
        x = x[:, 0]
    return x


def stft2logmelspectrogram(x_stft,
                           fs,
                           n_mels,
                           n_fft,
                           fmin=None,
                           fmax=None,
                           eps=1e-10):
    # x_stft: (Time, Channel, Freq) or (Time, Freq)
    fmin = 0 if fmin is None else fmin
    fmax = fs / 2 if fmax is None else fmax

    # spc: (Time, Channel, Freq) or (Time, Freq)
    spc = np.abs(x_stft)
    # mel_basis: (Mel_freq, Freq)
    mel_basis = librosa.filters.mel(
        sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
    # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
    lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))

    return lmspc


def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
    # x: (Time, Channel) -> spc: (Time, Channel, Freq)
    spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
    return spc


def logmelspectrogram(
        x,
        fs,
        n_mels,
        n_fft,
        n_shift,
        win_length=None,
        window="hann",
        fmin=None,
        fmax=None,
        eps=1e-10,
        pad_mode="reflect", ):
    # stft: (Time, Channel, Freq) or (Time, Freq)
    x_stft = stft(
        x,
        n_fft=n_fft,
        n_shift=n_shift,
        win_length=win_length,
        window=window,
        pad_mode=pad_mode, )

    return stft2logmelspectrogram(
        x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax,
        eps=eps)


class Spectrogram():
    def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window

    def __repr__(self):
        return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
                "win_length={win_length}, window={window})".format(
                    name=self.__class__.__name__,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window, ))

    def __call__(self, x):
        return spectrogram(
            x,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window, )


class LogMelSpectrogram():
    def __init__(
            self,
            fs,
            n_mels,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            fmin=None,
            fmax=None,
            eps=1e-10, ):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        return (
            "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
            "n_shift={n_shift}, win_length={win_length}, window={window}, "
            "fmin={fmin}, fmax={fmax}, eps={eps})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_fft=self.n_fft,
                n_shift=self.n_shift,
                win_length=self.win_length,
                window=self.window,
                fmin=self.fmin,
                fmax=self.fmax,
                eps=self.eps, ))

    def __call__(self, x):
        return logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            n_shift=self.n_shift,
            win_length=self.win_length,
            window=self.window, )


class Stft2LogMelSpectrogram():
    def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        self.eps = eps

    def __repr__(self):
        return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
                "fmin={fmin}, fmax={fmax}, eps={eps})".format(
                    name=self.__class__.__name__,
                    fs=self.fs,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    eps=self.eps, ))

    def __call__(self, x):
        return stft2logmelspectrogram(
            x,
            fs=self.fs,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            fmin=self.fmin,
            fmax=self.fmax, )


class Stft():
    def __init__(
            self,
            n_fft,
            n_shift,
            win_length=None,
            window="hann",
            center=True,
            pad_mode="reflect", ):
        self.n_fft = n_fft
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.center = center
        self.pad_mode = pad_mode

    def __repr__(self):
        return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
                "win_length={win_length}, window={window}, "
                "center={center}, pad_mode={pad_mode})".format(
                    name=self.__class__.__name__,
                    n_fft=self.n_fft,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center,
                    pad_mode=self.pad_mode, ))

    def __call__(self, x):
        return stft(
            x,
            self.n_fft,
            self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode, )


class IStft():
    def __init__(self, n_shift, win_length=None, window="hann", center=True):
        self.n_shift = n_shift
        self.win_length = win_length
        self.window = window
        self.center = center

    def __repr__(self):
        return ("{name}(n_shift={n_shift}, "
                "win_length={win_length}, window={window}, "
                "center={center})".format(
                    name=self.__class__.__name__,
                    n_shift=self.n_shift,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center, ))

    def __call__(self, x):
        return istft(
            x,
            self.n_shift,
            win_length=self.win_length,
            window=self.window,
            center=self.center, )


class LogMelSpectrogramKaldi():
    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_shift=160,  # unit: sample, 10ms
            win_length=400,  # unit: sample, 25ms
            energy_floor=0.0,
            dither=0.1):
        """
        The Kaldi implementation of LogMelSpectrogram
        Args:
            fs (int): sample rate of the audio
            n_mels (int): number of mel filter banks
            n_shift (int): number of points in a frame shift
            win_length (int): number of points in a frame window
            energy_floor (float): Floor on energy in Spectrogram computation (absolute)
            dither (float): Dithering constant

        Returns:
            LogMelSpectrogramKaldi
        """
        self.fs = fs
        self.n_mels = n_mels
        num_point_ms = fs / 1000
        self.n_frame_length = win_length / num_point_ms
        self.n_frame_shift = n_shift / num_point_ms
        self.energy_floor = energy_floor
        self.dither = dither

    def __repr__(self):
        return (
            "{name}(fs={fs}, n_mels={n_mels}, "
            "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
            "dither={dither})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_frame_shift=self.n_frame_shift,
                n_frame_length=self.n_frame_length,
                dither=self.dither, ))

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode.

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D)
        """
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")
        waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32)
        mat = kaldi.fbank(
            waveform,
            n_mels=self.n_mels,
            frame_length=self.n_frame_length,
            frame_shift=self.n_frame_shift,
            dither=dither,
            energy_floor=self.energy_floor,
            sr=self.fs)
        mat = np.squeeze(mat.numpy())
        return mat


class LogMelSpectrogramKaldi_decay():
    def __init__(
            self,
            fs=16000,
            n_mels=80,
            n_fft=512,  # fft point
            n_shift=160,  # unit: sample, 10ms
            win_length=400,  # unit: sample, 25ms
            window="povey",
            fmin=20,
            fmax=None,
            eps=1e-10,
            dither=1.0):
        self.fs = fs
        self.n_mels = n_mels
        self.n_fft = n_fft
        if n_shift > win_length:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        self.n_shift = n_shift / fs  # unit: s (logfbank's winstep)
        self.win_length = win_length / fs  # unit: s (logfbank's winlen)
        self.window = window
        self.fmin = fmin
        # NOTE: the original left `fmax_` unbound when a valid fmax was
        # given, raising NameError below; assign it in every branch.
        if fmax is None:
            fmax_ = self.fs / 2
        elif fmax > int(self.fs / 2):
            raise ValueError("fmax must not be greater than half of "
                             "sample rate.")
        else:
            fmax_ = fmax
        self.fmax = fmax_

        self.eps = eps
        self.remove_dc_offset = True
        self.preemph = 0.97
        self.dither = dither  # only works in train mode

    def __repr__(self):
        return (
            "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
            "n_shift={n_shift}, win_length={win_length}, preemph={preemph}, "
            "window={window}, fmin={fmin}, fmax={fmax}, eps={eps}, "
            "dither={dither})".format(
                name=self.__class__.__name__,
                fs=self.fs,
                n_mels=self.n_mels,
                n_fft=self.n_fft,
                n_shift=self.n_shift,
                preemph=self.preemph,
                win_length=self.win_length,
                window=self.window,
                fmin=self.fmin,
                fmax=self.fmax,
                eps=self.eps,
                dither=self.dither, ))

    def __call__(self, x, train):
        """
        Args:
            x (np.ndarray): shape (Ti,)
            train (bool): True, train mode.

        Raises:
            ValueError: not support (Ti, C)

        Returns:
            np.ndarray: (T, D)
        """
        dither = self.dither if train else 0.0
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")

        if x.dtype in np.sctypes['float']:
            # PCM32 -> PCM16
            bits = np.iinfo(np.int16).bits
            x = x * 2**(bits - 1)

        # logfbank needs PCM16 input
        y = logfbank(
            signal=x,
            samplerate=self.fs,
            winlen=self.win_length,  # unit: s
            winstep=self.n_shift,  # unit: s
            nfilt=self.n_mels,
            nfft=self.n_fft,
            lowfreq=self.fmin,
            highfreq=self.fmax,
            dither=dither,
            remove_dc_offset=self.remove_dc_offset,
            preemph=self.preemph,
            wintype=self.window)
        return y
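A minimal usage sketch of the functional API above (not part of the commit; the random one-second waveform and the frame settings are illustrative only):

# Hypothetical usage sketch: 80-dim log-mel features from a 1-second waveform.
import numpy as np

from paddlespeech.audio.transform.spectrogram import logmelspectrogram

wav = np.random.randn(16000).astype(np.float32)
feat = logmelspectrogram(wav, fs=16000, n_mels=80, n_fft=512, n_shift=160,
                         win_length=400)
print(feat.shape)  # (T, 80), with T roughly 16000 / 160 frames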
paddlespeech/audio/transform/transform_interface.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
class TransformInterface:
    """Transform Interface"""

    def __call__(self, x):
        raise NotImplementedError("__call__ method is not implemented")

    @classmethod
    def add_arguments(cls, parser):
        return parser

    def __repr__(self):
        return self.__class__.__name__ + "()"


class Identity(TransformInterface):
    """Identity Function"""

    def __call__(self, x):
        return x
paddlespeech/audio/transform/transformation.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Transformation module."""
import copy
import io
import logging
from collections import OrderedDict
from collections.abc import Sequence
from inspect import signature

import yaml

from ..utils.dynamic_import import dynamic_import

import_alias = dict(
    identity="paddlespeech.audio.transform.transform_interface:Identity",
    time_warp="paddlespeech.audio.transform.spec_augment:TimeWarp",
    time_mask="paddlespeech.audio.transform.spec_augment:TimeMask",
    freq_mask="paddlespeech.audio.transform.spec_augment:FreqMask",
    spec_augment="paddlespeech.audio.transform.spec_augment:SpecAugment",
    speed_perturbation="paddlespeech.audio.transform.perturb:SpeedPerturbation",
    speed_perturbation_sox="paddlespeech.audio.transform.perturb:SpeedPerturbationSox",
    volume_perturbation="paddlespeech.audio.transform.perturb:VolumePerturbation",
    noise_injection="paddlespeech.audio.transform.perturb:NoiseInjection",
    bandpass_perturbation="paddlespeech.audio.transform.perturb:BandpassPerturbation",
    rir_convolve="paddlespeech.audio.transform.perturb:RIRConvolve",
    delta="paddlespeech.audio.transform.add_deltas:AddDeltas",
    cmvn="paddlespeech.audio.transform.cmvn:CMVN",
    utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
    fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
    spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
    stft="paddlespeech.audio.transform.spectrogram:Stft",
    istft="paddlespeech.audio.transform.spectrogram:IStft",
    stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",
    wpe="paddlespeech.audio.transform.wpe:WPE",
    channel_selector="paddlespeech.audio.transform.channel_selector:ChannelSelector",
    fbank_kaldi="paddlespeech.audio.transform.spectrogram:LogMelSpectrogramKaldi",
    cmvn_json="paddlespeech.audio.transform.cmvn:GlobalCMVN")


class Transformation():
    """Apply some functions to the mini-batch

    Examples:
        >>> kwargs = {"process": [{"type": "fbank",
        ...                        "n_mels": 80,
        ...                        "fs": 16000},
        ...                       {"type": "cmvn",
        ...                        "stats": "data/train/cmvn.ark",
        ...                        "norm_vars": True},
        ...                       {"type": "delta", "window": 2, "order": 2}]}
        >>> transform = Transformation(kwargs)
        >>> bs = 10
        >>> xs = [np.random.randn(100, 80).astype(np.float32)
        ...       for _ in range(bs)]
        >>> xs = transform(xs)
    """

    def __init__(self, conffile=None):
        if conffile is not None:
            if isinstance(conffile, dict):
                self.conf = copy.deepcopy(conffile)
            else:
                with io.open(conffile, encoding="utf-8") as f:
                    self.conf = yaml.safe_load(f)
                    assert isinstance(self.conf, dict), type(self.conf)
        else:
            self.conf = {"mode": "sequential", "process": []}

        self.functions = OrderedDict()
        if self.conf.get("mode", "sequential") == "sequential":
            for idx, process in enumerate(self.conf["process"]):
                assert isinstance(process, dict), type(process)
                opts = dict(process)
                process_type = opts.pop("type")
                class_obj = dynamic_import(process_type, import_alias)
                # TODO(karita): assert issubclass(class_obj, TransformInterface)
                try:
                    self.functions[idx] = class_obj(**opts)
                except TypeError:
                    try:
                        signa = signature(class_obj)
                    except ValueError:
                        # Some functions, e.g. built-in functions, fail here
                        pass
                    else:
                        logging.error("Expected signature: {}({})".format(
                            class_obj.__name__, signa))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

    def __repr__(self):
        rep = "\n" + "\n".join("    {}: {}".format(k, v)
                               for k, v in self.functions.items())
        return "{}({})".format(self.__class__.__name__, rep)

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Return new mini-batch

        :param Union[Sequence[np.ndarray], np.ndarray] xs:
        :param Union[Sequence[str], str] uttid_list:
        :return: batch:
        :rtype: List[np.ndarray]
        """
        if not isinstance(xs, Sequence):
            is_batch = False
            xs = [xs]
        else:
            is_batch = True

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list for _ in range(len(xs))]

        if self.conf.get("mode", "sequential") == "sequential":
            for idx in range(len(self.conf["process"])):
                func = self.functions[idx]
                # TODO(karita): use TrainingTrans and UttTrans to check __call__ args
                # Derive only the args which the func has
                try:
                    param = signature(func).parameters
                except ValueError:
                    # Some functions, e.g. built-in functions, fail here
                    param = {}
                _kwargs = {k: v for k, v in kwargs.items() if k in param}
                try:
                    if uttid_list is not None and "uttid" in param:
                        xs = [
                            func(x, u, **_kwargs)
                            for x, u in zip(xs, uttid_list)
                        ]
                    else:
                        xs = [func(x, **_kwargs) for x in xs]
                except Exception:
                    logging.fatal("Caught an exception from the {}th func: {}".
                                  format(idx, func))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

        if is_batch:
            return xs
        else:
            return xs[0]
paddlespeech/audio/transform/wpe.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from nara_wpe.wpe import wpe


class WPE(object):
    def __init__(self,
                 taps=10,
                 delay=3,
                 iterations=3,
                 psd_context=0,
                 statistics_mode="full"):
        self.taps = taps
        self.delay = delay
        self.iterations = iterations
        self.psd_context = psd_context
        self.statistics_mode = statistics_mode

    def __repr__(self):
        return ("{name}(taps={taps}, delay={delay}, "
                "iterations={iterations}, psd_context={psd_context}, "
                "statistics_mode={statistics_mode})".format(
                    name=self.__class__.__name__,
                    taps=self.taps,
                    delay=self.delay,
                    iterations=self.iterations,
                    psd_context=self.psd_context,
                    statistics_mode=self.statistics_mode, ))

    def __call__(self, xs):
        """Return enhanced

        :param np.ndarray xs: (Time, Channel, Frequency)
        :return: enhanced_xs
        :rtype: np.ndarray
        """
        # nara_wpe.wpe expects (F, C, T)
        xs = wpe(
            xs.transpose((2, 1, 0)),
            taps=self.taps,
            delay=self.delay,
            iterations=self.iterations,
            psd_context=self.psd_context,
            statistics_mode=self.statistics_mode, )
        return xs.transpose(2, 1, 0)
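A minimal usage sketch (not part of the commit; it assumes nara_wpe is installed, and the random complex array stands in for a real multi-channel STFT):

# Hypothetical usage sketch: dereverberate a random multi-channel STFT.
import numpy as np

from paddlespeech.audio.transform.wpe import WPE

wpe_transform = WPE(taps=10, delay=3, iterations=3)
xs = np.random.randn(100, 4, 257) + 1j * np.random.randn(100, 4, 257)  # (T, C, F)
enhanced = wpe_transform(xs)  # same (T, C, F) shape after the round-trip transpose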
paddlespeech/audio/utils/check_kwargs.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import inspect


def check_kwargs(func, kwargs, name=None):
    """check kwargs are valid for func

    If kwargs are invalid, raise TypeError just as Python does by default.

    :param function func: function to be validated
    :param dict kwargs: keyword arguments for func
    :param str name: name used in TypeError (default is func name)
    """
    try:
        params = inspect.signature(func).parameters
    except ValueError:
        return
    if name is None:
        name = func.__name__
    for k in kwargs.keys():
        if k not in params:
            raise TypeError(
                f"{name}() got an unexpected keyword argument '{k}'")
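A minimal usage sketch (not part of the commit; `fbank` is a hypothetical local function used only to show both outcomes):

# Hypothetical usage sketch: reject unknown keyword arguments before a call.
from paddlespeech.audio.utils.check_kwargs import check_kwargs

def fbank(n_mels=80, fs=16000):
    pass

check_kwargs(fbank, {"n_mels": 40})  # passes silently
check_kwargs(fbank, {"nmels": 40})   # raises TypeError: fbank() got an unexpected keyword argument 'nmels'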
paddlespeech/audio/utils/dynamic_import.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import importlib

__all__ = ["dynamic_import"]


def dynamic_import(import_path, alias=dict()):
    """dynamic import module and class

    :param str import_path: syntax 'module_name:class_name'
        e.g., 'paddlespeech.s2t.models.u2:U2Model'
    :param dict alias: shortcut for registered class
    :return: imported class
    """
    if import_path not in alias and ":" not in import_path:
        raise ValueError(
            "import_path should be one of {} or "
            'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
            "{}".format(set(alias), import_path))
    if ":" not in import_path:
        import_path = alias[import_path]

    module_name, objname = import_path.split(":")
    m = importlib.import_module(module_name)
    return getattr(m, objname)
paddlespeech/audio/utils/tensor_utils.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unility functions for Transformer."""
from
typing
import
List
from
typing
import
Tuple
import
paddle
from
.log
import
Logger
__all__
=
[
"pad_sequence"
,
"add_sos_eos"
,
"th_accuracy"
,
"has_tensor"
]
logger
=
Logger
(
__name__
)
def
has_tensor
(
val
):
if
isinstance
(
val
,
(
list
,
tuple
)):
for
item
in
val
:
if
has_tensor
(
item
):
return
True
elif
isinstance
(
val
,
dict
):
for
k
,
v
in
val
.
items
():
print
(
k
)
if
has_tensor
(
v
):
return
True
else
:
return
paddle
.
is_tensor
(
val
)
def
pad_sequence
(
sequences
:
List
[
paddle
.
Tensor
],
batch_first
:
bool
=
False
,
padding_value
:
float
=
0.0
)
->
paddle
.
Tensor
:
r
"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is list of
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> from paddle.nn.utils.rnn import pad_sequence
>>> a = paddle.ones(25, 300)
>>> b = paddle.ones(22, 300)
>>> c = paddle.ones(15, 300)
>>> pad_sequence([a, b, c]).shape
paddle.Tensor([25, 3, 300])
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
Args:
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns:
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Tensor of size ``B x T x *`` otherwise
"""
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size
=
paddle
.
shape
(
sequences
[
0
])
# (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:]
trailing_dims
=
tuple
(
max_size
[
1
:].
numpy
().
tolist
())
if
sequences
[
0
].
ndim
>=
2
else
()
max_len
=
max
([
s
.
shape
[
0
]
for
s
in
sequences
])
if
batch_first
:
out_dims
=
(
len
(
sequences
),
max_len
)
+
trailing_dims
else
:
out_dims
=
(
max_len
,
len
(
sequences
))
+
trailing_dims
out_tensor
=
paddle
.
full
(
out_dims
,
padding_value
,
sequences
[
0
].
dtype
)
for
i
,
tensor
in
enumerate
(
sequences
):
length
=
tensor
.
shape
[
0
]
# use index notation to prevent duplicate references to the tensor
logger
.
info
(
f
"length
{
length
}
, out_tensor
{
out_tensor
.
shape
}
, tensor
{
tensor
.
shape
}
"
)
if
batch_first
:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor
if
length
!=
0
:
out_tensor
[
i
,
:
length
]
=
tensor
else
:
out_tensor
[
i
,
length
]
=
tensor
else
:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[:length, i, ...] = tensor
if
length
!=
0
:
out_tensor
[:
length
,
i
]
=
tensor
else
:
out_tensor
[
length
,
i
]
=
tensor
return
out_tensor
def
add_sos_eos
(
ys_pad
:
paddle
.
Tensor
,
sos
:
int
,
eos
:
int
,
ignore_id
:
int
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
eos (int): index of <eeos>
ignore_id (int): index of padding
Returns:
ys_in (paddle.Tensor) : (B, Lmax + 1)
ys_out (paddle.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=paddle.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B
=
ys_pad
.
shape
[
0
]
_sos
=
paddle
.
ones
([
B
,
1
],
dtype
=
ys_pad
.
dtype
)
*
sos
_eos
=
paddle
.
ones
([
B
,
1
],
dtype
=
ys_pad
.
dtype
)
*
eos
ys_in
=
paddle
.
cat
([
_sos
,
ys_pad
],
dim
=
1
)
mask_pad
=
(
ys_in
==
ignore_id
)
ys_in
=
ys_in
.
masked_fill
(
mask_pad
,
eos
)
ys_out
=
paddle
.
cat
([
ys_pad
,
_eos
],
dim
=
1
)
ys_out
=
ys_out
.
masked_fill
(
mask_pad
,
eos
)
mask_eos
=
(
ys_out
==
ignore_id
)
ys_out
=
ys_out
.
masked_fill
(
mask_eos
,
eos
)
ys_out
=
ys_out
.
masked_fill
(
mask_pad
,
ignore_id
)
return
ys_in
,
ys_out
def
th_accuracy
(
pad_outputs
:
paddle
.
Tensor
,
pad_targets
:
paddle
.
Tensor
,
ignore_label
:
int
)
->
float
:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred
=
pad_outputs
.
view
(
pad_targets
.
shape
[
0
],
pad_targets
.
shape
[
1
],
pad_outputs
.
shape
[
1
]).
argmax
(
2
)
mask
=
pad_targets
!=
ignore_label
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator
=
(
pad_pred
.
masked_select
(
mask
)
==
pad_targets
.
masked_select
(
mask
))
numerator
=
paddle
.
sum
(
numerator
.
type_as
(
pad_targets
))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator
=
paddle
.
sum
(
mask
.
type_as
(
pad_targets
))
return
float
(
numerator
)
/
float
(
denominator
)
setup.py
@@ -38,7 +38,7 @@ base = [
     "pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2",
     "sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10",
     "textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad",
-    "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8'
+    "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8', 'webdataset'
 ]
 server = [