PaddlePaddle / DeepSpeech

Commit 0c7abc1f
Authored Jun 27, 2022 by huangyuxin
Parent: c7a7b113

    add training scripts
Showing 20 changed files with 1620 additions and 132 deletions (+1620 -132).
Changed files:

examples/wenetspeech/asr1/conf/conformer.yaml              +13   -15
examples/wenetspeech/asr1/local/data.sh                    +49   -76
examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh    +2    -2
paddlespeech/audio/streamdata/__init__.py                   +3    -3
paddlespeech/audio/streamdata/autodecode.py               +445    -0
paddlespeech/audio/streamdata/cache.py                      +2    -2
paddlespeech/audio/streamdata/compat.py                     +1    -1
paddlespeech/audio/streamdata/extradatasets.py            +141    -0
paddlespeech/audio/streamdata/filters.py                    +2    -2
paddlespeech/audio/streamdata/gopen.py                    +340    -0
paddlespeech/audio/streamdata/handlers.py                  +47    -0
paddlespeech/audio/streamdata/mix.py                       +85    -0
paddlespeech/audio/streamdata/paddle_utils.py               +0    -0
paddlespeech/audio/streamdata/pipeline.py                   +1    -2
paddlespeech/audio/streamdata/shardlists.py                 +0    -0
paddlespeech/audio/streamdata/tariterators.py               +2    -2
paddlespeech/audio/streamdata/utils.py                      +0    -0
paddlespeech/audio/streamdata/writer.py                   +450    -0
paddlespeech/s2t/io/dataloader.py                          +35   -26
setup.py                                                    +2    -1
examples/wenetspeech/asr1/conf/conformer.yaml

 ############################################
 # Network Architecture #
 ############################################
-cmvn_file:
-cmvn_file_type: "json"
 # encoder related
 encoder: conformer
...
@@ -43,9 +42,9 @@ model_conf:
 ###########################################
 # Data #
 ###########################################
-train_manifest: data/manifest.train
-dev_manifest: data/manifest.dev
-test_manifest: data/manifest.test
+train_manifest: data/train_l/data.list
+dev_manifest: data/dev/data.list
+test_manifest: data/test_meeting/data.list

 ###########################################
 # Dataloader #
...
@@ -54,23 +53,22 @@ use_stream_data: True
 unit_type: 'char'
 vocab_filepath: data/lang_char/vocab.txt
+cmvn_file: data/mean_std.json
 preprocess_config: conf/preprocess.yaml
 spm_model_prefix: ''
 feat_dim: 80
 stride_ms: 10.0
 window_ms: 25.0
 dither: 0.1
 sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
-batch_size: 64
+batch_size: 32
 minlen_in: 10
-maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
+maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automatically removed
 minlen_out: 0
-maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
+maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed
 resample_rate: 16000
-shuffle_size: 10000
-sort_size: 500
-num_workers: 4
-prefetch_factor: 100
+shuffle_size: 1500
+sort_size: 1000
+num_workers: 0
+prefetch_factor: 10
 dist_sampler: True
 num_encs: 1
 augment_conf:
...
@@ -90,10 +88,10 @@ augment_conf:
 ###########################################
 # Training #
 ###########################################
-n_epoch: 240
-accum_grad: 16
+n_epoch: 30
+accum_grad: 32
 global_grad_clip: 5.0
-log_interval: 1
+log_interval: 100
 checkpoint:
   kbest_n: 50
   latest_n: 5
...
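The dataloader block above is consumed by the new StreamDataLoader (see paddlespeech/s2t/io/dataloader.py later in this diff). As a minimal sketch, these fields can be inspected with pyyaml, which this commit adds to setup.py; the loading code below is illustrative, not the repo's actual config machinery:

    import yaml

    # Illustrative only: PaddleSpeech has its own config loader; this just shows
    # the streaming-dataloader fields of conformer.yaml once parsed.
    with open("examples/wenetspeech/asr1/conf/conformer.yaml") as f:
        conf = yaml.safe_load(f)

    for key in ("train_manifest", "batch_size", "maxlen_in", "maxlen_out",
                "shuffle_size", "sort_size", "num_workers", "prefetch_factor"):
        print(key, "=", conf.get(key))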
examples/wenetspeech/asr1/local/data.sh

...
@@ -2,6 +2,8 @@
 # Copyright 2021 Mobvoi Inc(Author: Di Wu, Binbin Zhang)
 #                NPU, ASLP Group (Author: Qijie Shao)
 #
+# Modified from wenet(https://github.com/wenet-e2e/wenet)

 stage=-1
 stop_stage=100
...
@@ -30,7 +32,7 @@ mkdir -p data
 TARGET_DIR=${MAIN_ROOT}/dataset
 mkdir -p ${TARGET_DIR}

-if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
     # download data
     echo "Please follow https://github.com/wenet-e2e/WenetSpeech to download the data."
     exit 0;
...
@@ -44,86 +46,57 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         data || exit 1;
 fi

-if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
-    # generate manifests
-    python3 ${TARGET_DIR}/aishell/aishell.py \
-    --manifest_prefix="data/manifest" \
-    --target_dir="${TARGET_DIR}/aishell"
-    if [ $? -ne 0 ]; then
-        echo "Prepare Aishell failed. Terminated."
-        exit 1
-    fi
-    for dataset in train dev test; do
-        mv data/manifest.${dataset} data/manifest.${dataset}.raw
-    done
-fi
-
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-    # compute mean and stddev for normalizer
-    if $cmvn; then
-        full_size=`cat data/${train_set}/wav.scp | wc -l`
-        sampling_size=$((full_size / cmvn_sampling_divisor))
-        shuf -n $sampling_size data/$train_set/wav.scp \
-            > data/$train_set/wav.scp.sampled
-        num_workers=$(nproc)
-        python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
-        --manifest_path="data/manifest.train.raw" \
-        --spectrum_type="fbank" \
-        --feat_dim=80 \
-        --delta_delta=false \
-        --stride_ms=10 \
-        --window_ms=25 \
-        --sample_rate=16000 \
-        --use_dB_normalization=False \
-        --num_samples=-1 \
-        --num_workers=${num_workers} \
-        --output_path="data/mean_std.json"
-        if [ $? -ne 0 ]; then
-            echo "Compute mean and stddev failed. Terminated."
-            exit 1
-        fi
-    fi
-fi
-
-dict=data/dict/lang_char.txt
+dict=data/lang_char/vocab.txt
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    # download data, generate manifests
-    # build vocabulary
-    python3 ${MAIN_ROOT}/utils/build_vocab.py \
-    --unit_type="char" \
-    --count_threshold=0 \
-    --vocab_path="data/lang_char/vocab.txt" \
-    --manifest_paths "data/manifest.train.raw"
-    if [ $? -ne 0 ]; then
-        echo "Build vocabulary failed. Terminated."
-        exit 1
-    fi
+    echo "Make a dictionary"
+    echo "dictionary: ${dict}"
+    mkdir -p $(dirname $dict)
+    echo "<blank>" > ${dict} # 0 will be used for "blank" in CTC
+    echo "<unk>" >> ${dict} # <unk> must be 1
+    echo "▁" >> ${dict} # ▁ is for space
+    utils/text2token.py -s 1 -n 1 --space "▁" data/${train_set}/text \
+        | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -a -v -e '^\s*$' \
+        | grep -v "▁" \
+        | awk '{print $0}' >> ${dict} \
+        || exit 1;
+    echo "<eos>" >> $dict
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    # format manifest with tokenids, vocab size
-    for dataset in train dev test; do
-    {
-        python3 ${MAIN_ROOT}/utils/format_data.py \
-        --cmvn_path "data/mean_std.json" \
-        --unit_type "char" \
-        --vocab_path="data/vocab.txt" \
-        --manifest_path="data/manifest.${dataset}.raw" \
-        --output_path="data/manifest.${dataset}"
-        if [ $? -ne 0 ]; then
-            echo "Formt mnaifest failed. Terminated."
-            exit 1
-        fi
-    } &
-    done
-    wait
+    echo "Compute cmvn"
+    # Here we use all the training data, you can sample some some data to save time
+    # BUG!!! We should use the segmented data for CMVN
+    if $cmvn; then
+        full_size=`cat data/${train_set}/wav.scp | wc -l`
+        sampling_size=$((full_size / cmvn_sampling_divisor))
+        shuf -n $sampling_size data/$train_set/wav.scp \
+            > data/$train_set/wav.scp.sampled
+        python3 utils/compute_cmvn_stats.py \
+        --num_workers 16 \
+        --train_config $train_config \
+        --in_scp data/$train_set/wav.scp.sampled \
+        --out_cmvn data/$train_set/mean_std.json \
+        || exit 1;
+    fi
 fi

+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "Making shards, please wait..."
+    RED='\033[0;31m'
+    NOCOLOR='\033[0m'
+    echo -e "It requires ${RED}1.2T ${NOCOLOR}space for $shards_dir, please make sure you have enough space"
+    echo -e "It takes about ${RED}12 ${NOCOLOR}hours with 32 threads"
+    for x in $dev_set $test_sets ${train_set}; do
+        dst=$shards_dir/$x
+        mkdir -p $dst
+        utils/make_filted_shard_list.py --resample 16000 --num_utts_per_shard 1000 \
+            --do_filter --num_node 1 --num_gpus_per_node 8 \
+            --num_threads 32 --segments data/$x/segments \
+            data/$x/wav.scp data/$x/text \
+            $(realpath $dst) data/$x/data.list
+    done
+fi
+
-echo "Aishell data preparation done."
+echo "Wenetspeech data preparation done."
 exit 0
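Stage 1 above distills the training transcripts into a character dictionary with a shell pipeline. As a rough Python equivalent of what that stage produces (a sketch only: build_char_dict is a hypothetical helper, and a Kaldi-style text file of `utt-id transcript` lines is assumed):

    def build_char_dict(text_path: str, dict_path: str) -> None:
        """Sketch of data.sh stage 1: a char vocabulary with CTC special tokens."""
        chars = set()
        with open(text_path, encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split(maxsplit=1)
                if len(parts) < 2:
                    continue  # skip utterances without a transcript
                # text2token.py splits the transcript into single characters
                chars.update(c for c in parts[1] if not c.isspace() and c != "▁")
        with open(dict_path, "w", encoding="utf-8") as f:
            f.write("<blank>\n")  # 0 will be used for "blank" in CTC
            f.write("<unk>\n")    # <unk> must be 1
            f.write("▁\n")        # ▁ is for space
            for c in sorted(chars):
                f.write(c + "\n")
            f.write("<eos>\n")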
examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh

...
@@ -24,7 +24,7 @@ stage=1
 prefix=
 train_subset=L

-. ./tools/parse_options.sh || exit 1;
+. ./utils/parse_options.sh || exit 1;

 filter_by_id () {
     idlist=$1
...
@@ -132,4 +132,4 @@ if [ $stage -le 2 ]; then
     done
 fi

-echo "$0: Done"
\ No newline at end of file
+echo "$0: Done"
paddlespeech/audio/stream_data/__init__.py → paddlespeech/audio/streamdata/__init__.py

...
@@ -11,7 +11,7 @@ from .cache import (
     pipe_cleaner,
 )
 from .compat import WebDataset, WebLoader, FluidWrapper
-from webdataset.extradatasets import MockDataset, with_epoch, with_length
+from .extradatasets import MockDataset, with_epoch, with_length
 from .filters import (
     associate,
     batched,
...
@@ -65,5 +65,5 @@ from .shardlists import (
 )
 from .tariterators import tarfile_samples, tarfile_to_samples
 from .utils import PipelineStage, repeatedly
-from webdataset.writer import ShardWriter, TarWriter, numpy_dumps
-from webdataset.mix import RandomMix, RoundRobin
+from .writer import ShardWriter, TarWriter, numpy_dumps
+from .mix import RandomMix, RoundRobin
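The net effect of this hunk is that the package no longer imports anything from the external webdataset distribution; everything resolves inside the vendored paddlespeech.audio.streamdata package. For example (a sketch mirroring the imports above):

    # After this commit, all of these resolve inside the vendored package
    # rather than the external `webdataset` distribution:
    from paddlespeech.audio.streamdata import (
        MockDataset, with_epoch, with_length,   # now from .extradatasets
        ShardWriter, TarWriter, numpy_dumps,    # now from .writer
        RandomMix, RoundRobin,                  # now from .mix
    )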
paddlespeech/audio/streamdata/autodecode.py (new file, 0 → 100644)

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Automatically decode webdataset samples."""

import io, json, os, pickle, re, tempfile
from functools import partial

import numpy as np

"""Extensions passed on to the image decoder."""
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()


################################################################
# handle basic datatypes
################################################################


def paddle_loads(data):
    """Load data using paddle.loads, importing paddle only if needed.

    :param data: data to be decoded
    """
    import io
    import paddle

    stream = io.BytesIO(data)
    return paddle.load(stream)


def tenbin_loads(data):
    from . import tenbin

    return tenbin.decode_buffer(data)


def msgpack_loads(data):
    import msgpack

    return msgpack.unpackb(data)


def npy_loads(data):
    import numpy.lib.format

    stream = io.BytesIO(data)
    return numpy.lib.format.read_array(stream)


def cbor_loads(data):
    import cbor

    return cbor.loads(data)


decoders = {
    "txt": lambda data: data.decode("utf-8"),
    "text": lambda data: data.decode("utf-8"),
    "transcript": lambda data: data.decode("utf-8"),
    "cls": lambda data: int(data),
    "cls2": lambda data: int(data),
    "index": lambda data: int(data),
    "inx": lambda data: int(data),
    "id": lambda data: int(data),
    "json": lambda data: json.loads(data),
    "jsn": lambda data: json.loads(data),
    "pyd": lambda data: pickle.loads(data),
    "pickle": lambda data: pickle.loads(data),
    "pdparams": lambda data: paddle_loads(data),
    "ten": tenbin_loads,
    "tb": tenbin_loads,
    "mp": msgpack_loads,
    "msg": msgpack_loads,
    "npy": npy_loads,
    "npz": lambda data: np.load(io.BytesIO(data)),
    "cbor": cbor_loads,
}


def basichandlers(key, data):
    """Handle basic file decoding.

    This function is usually part of the post= decoders.
    This handles the following forms of decoding:

    - txt -> unicode string
    - cls cls2 class count index inx id -> int
    - json jsn -> JSON decoding
    - pyd pickle -> pickle decoding
    - pdparams -> paddle.loads
    - ten tenbin -> fast tensor loading
    - mp messagepack msg -> messagepack decoding
    - npy -> Python NPY decoding

    :param key: file name extension
    :param data: binary data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension in decoders:
        return decoders[extension](data)
    return None


################################################################
# Generic extension handler.
################################################################


def call_extension_handler(key, data, f, extensions):
    """Call the function f with the given data if the key matches the extensions.

    :param key: actual key found in the sample
    :param data: binary data
    :param f: decoder function
    :param extensions: list of matching extensions
    """
    extension = key.lower().split(".")
    for target in extensions:
        target = target.split(".")
        if len(target) > len(extension):
            continue
        if extension[-len(target):] == target:
            return f(data)
    return None


def handle_extension(extensions, f):
    """Return a decoder function for the list of extensions.

    Extensions can be a space separated list of extensions.
    Extensions can contain dots, in which case the corresponding number
    of extension components must be present in the key given to f.
    Comparisons are case insensitive.

    Examples:
    handle_extension("jpg jpeg", my_decode_jpg)  # invoked for any file.jpg
    handle_extension("seg.jpg", special_case_jpg)  # invoked only for file.seg.jpg
    """
    extensions = extensions.lower().split()
    return partial(call_extension_handler, f=f, extensions=extensions)


################################################################
# handle images
################################################################

imagespecs = {
    "l8": ("numpy", "uint8", "l"),
    "rgb8": ("numpy", "uint8", "rgb"),
    "rgba8": ("numpy", "uint8", "rgba"),
    "l": ("numpy", "float", "l"),
    "rgb": ("numpy", "float", "rgb"),
    "rgba": ("numpy", "float", "rgba"),
    "paddlel8": ("paddle", "uint8", "l"),
    "paddlergb8": ("paddle", "uint8", "rgb"),
    "paddlergba8": ("paddle", "uint8", "rgba"),
    "paddlel": ("paddle", "float", "l"),
    "paddlergb": ("paddle", "float", "rgb"),
    "paddle": ("paddle", "float", "rgb"),
    "paddlergba": ("paddle", "float", "rgba"),
    "pill": ("pil", None, "l"),
    "pil": ("pil", None, "rgb"),
    "pilrgb": ("pil", None, "rgb"),
    "pilrgba": ("pil", None, "rgba"),
}


class ImageHandler:
    """Decode image data using the given `imagespec`.

    The `imagespec` specifies whether the image is decoded
    to numpy/paddle/pi, decoded to uint8/float, and decoded
    to l/rgb/rgba:

    - l8: numpy uint8 l
    - rgb8: numpy uint8 rgb
    - rgba8: numpy uint8 rgba
    - l: numpy float l
    - rgb: numpy float rgb
    - rgba: numpy float rgba
    - paddlel8: paddle uint8 l
    - paddlergb8: paddle uint8 rgb
    - paddlergba8: paddle uint8 rgba
    - paddlel: paddle float l
    - paddlergb: paddle float rgb
    - paddle: paddle float rgb
    - paddlergba: paddle float rgba
    - pill: pil None l
    - pil: pil None rgb
    - pilrgb: pil None rgb
    - pilrgba: pil None rgba
    """

    def __init__(self, imagespec, extensions=image_extensions):
        """Create an image handler.

        :param imagespec: short string indicating the type of decoding
        :param extensions: list of extensions the image handler is invoked for
        """
        if imagespec not in list(imagespecs.keys()):
            raise ValueError("Unknown imagespec: %s" % imagespec)
        self.imagespec = imagespec.lower()
        self.extensions = extensions

    def __call__(self, key, data):
        """Perform image decoding.

        :param key: file name extension
        :param data: binary data
        """
        import PIL.Image

        extension = re.sub(r".*[.]", "", key)
        if extension.lower() not in self.extensions:
            return None
        imagespec = self.imagespec
        atype, etype, mode = imagespecs[imagespec]
        with io.BytesIO(data) as stream:
            img = PIL.Image.open(stream)
            img.load()
            img = img.convert(mode.upper())
        if atype == "pil":
            return img
        elif atype == "numpy":
            result = np.asarray(img)
            if result.dtype != np.uint8:
                raise ValueError("ImageHandler: numpy image must be uint8")
            if etype == "uint8":
                return result
            else:
                return result.astype("f") / 255.0
        elif atype == "paddle":
            import paddle

            result = np.asarray(img)
            if result.dtype != np.uint8:
                raise ValueError("ImageHandler: paddle image must be uint8")
            if etype == "uint8":
                result = np.array(result.transpose(2, 0, 1))
                return paddle.tensor(result)
            else:
                result = np.array(result.transpose(2, 0, 1))
                return paddle.tensor(result) / 255.0
        return None


def imagehandler(imagespec, extensions=image_extensions):
    """Create an image handler.

    This is just a lower case alias for ImageHander.

    :param imagespec: textual image spec
    :param extensions: list of extensions the handler should be applied for
    """
    return ImageHandler(imagespec, extensions)


################################################################
# torch video
################################################################

'''
def torch_video(key, data):
    """Decode video using the torchvideo library.

    :param key: file name extension
    :param data: data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
        return None

    import torchvision.io

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
        return torchvision.io.read_video(fname, pts_unit="sec")
'''


################################################################
# paddleaudio
################################################################


def paddle_audio(key, data):
    """Decode audio using the paddleaudio library.

    :param key: file name extension
    :param data: data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
        return None

    import paddleaudio

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
        return paddleaudio.load(fname)


################################################################
# special class for continuing decoding
################################################################


class Continue:
    """Special class for continuing decoding.

    This is mostly used for decompression, as in:

        def decompressor(key, data):
            if key.endswith(".gz"):
                return Continue(key[:-3], decompress(data))
            return None
    """

    def __init__(self, key, data):
        """__init__.

        :param key:
        :param data:
        """
        self.key, self.data = key, data


def gzfilter(key, data):
    """Decode .gz files.

    This decodes compressed files and the continues decoding.

    :param key: file name extension
    :param data: binary data
    """
    import gzip

    if not key.endswith(".gz"):
        return None
    decompressed = gzip.open(io.BytesIO(data)).read()
    return Continue(key[:-3], decompressed)


################################################################
# decode entire training amples
################################################################

default_pre_handlers = [gzfilter]
default_post_handlers = [basichandlers]


class Decoder:
    """Decode samples using a list of handlers.

    For each key/data item, this iterates through the list of
    handlers until some handler returns something other than None.
    """

    def __init__(self, handlers, pre=None, post=None, only=None, partial=False):
        """Create a Decoder.

        :param handlers: main list of handlers
        :param pre: handlers called before the main list (.gz handler by default)
        :param post: handlers called after the main list (default handlers by default)
        :param only: a list of extensions; when give, only ignores files with those extensions
        :param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes)
        """
        if isinstance(only, str):
            only = only.split()
        self.only = only if only is None else set(only)
        if pre is None:
            pre = default_pre_handlers
        if post is None:
            post = default_post_handlers
        assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
        assert all(callable(h) for h in pre), f"one of {pre} not callable"
        assert all(callable(h) for h in post), f"one of {post} not callable"
        self.handlers = pre + handlers + post
        self.partial = partial

    def decode1(self, key, data):
        """Decode a single field of a sample.

        :param key: file name extension
        :param data: binary data
        """
        key = "." + key
        for f in self.handlers:
            result = f(key, data)
            if isinstance(result, Continue):
                key, data = result.key, result.data
                continue
            if result is not None:
                return result
        return data

    def decode(self, sample):
        """Decode an entire sample.

        :param sample: the sample, a dictionary of key value pairs
        """
        result = {}
        assert isinstance(sample, dict), sample
        for k, v in list(sample.items()):
            if k[0] == "_":
                if isinstance(v, bytes):
                    v = v.decode("utf-8")
                result[k] = v
                continue
            if self.only is not None and k not in self.only:
                result[k] = v
                continue
            assert v is not None
            if self.partial:
                if isinstance(v, bytes):
                    result[k] = self.decode1(k, v)
                else:
                    result[k] = v
            else:
                assert isinstance(v, bytes)
                result[k] = self.decode1(k, v)
        return result

    def __call__(self, sample):
        """Decode an entire sample.

        :param sample: the sample
        """
        assert isinstance(sample, dict), (len(sample), sample)
        return self.decode(sample)
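A Decoder combines pre-handlers (gzfilter), user handlers, and basichandlers; a small self-contained sketch of the dispatch (the sample dict is illustrative):

    from paddlespeech.audio.streamdata.autodecode import Decoder, handle_extension

    # One custom handler plus the defaults: gzfilter (pre) strips ".gz" via
    # Continue, basichandlers (post) covers txt/json/cls/npy/... extensions.
    decoder = Decoder([handle_extension("transcript", lambda b: b.decode("utf-8"))])

    sample = {"__key__": "utt0001", "transcript": b"hello world", "id": b"17"}
    decoded = decoder(sample)
    # decoded["transcript"] == "hello world"  (custom handler)
    # decoded["id"] == 17                     (basichandlers maps "id" to int)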
paddlespeech/audio/stream_data/cache.py → paddlespeech/audio/streamdata/cache.py

...
@@ -6,8 +6,8 @@ import itertools, os, random, re, sys
 from urllib.parse import urlparse

 from . import filters
-from webdataset import gopen
-from webdataset.handlers import reraise_exception
+from . import gopen
+from .handlers import reraise_exception
 from .tariterators import tar_file_and_group_expander

 default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
...
paddlespeech/audio/stream_data/compat.py → paddlespeech/audio/streamdata/compat.py

...
@@ -8,7 +8,7 @@ from typing import List

 import braceexpand, yaml

-from webdataset import autodecode
+from . import autodecode
 from . import cache, filters, shardlists, tariterators
 from .filters import reraise_exception
 from .pipeline import DataPipeline
...
paddlespeech/audio/streamdata/extradatasets.py (new file, 0 → 100644)

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Train PyTorch models directly from POSIX tar archive.

Code works locally or over HTTP connections.
"""

import itertools as itt
import os
import random
import sys

import braceexpand

from . import utils
from .paddle_utils import IterableDataset
from .utils import PipelineStage


class MockDataset(IterableDataset):
    """MockDataset.

    A mock dataset for performance testing and unit testing.
    """

    def __init__(self, sample, length):
        """Create a mock dataset instance.

        :param sample: the sample to be returned repeatedly
        :param length: the length of the mock dataset
        """
        self.sample = sample
        self.length = length

    def __iter__(self):
        """Return an iterator over this mock dataset."""
        for i in range(self.length):
            yield self.sample


class repeatedly(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset."""

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        """Create an instance of Repeatedly.

        :param nepochs: repeat for a maximum of nepochs
        :param nbatches: repeat for a maximum of nbatches
        """
        self.source = source
        self.length = length
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source."""
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )


class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.

    This will continuously iterate through the original dataset, but
    impose new epoch boundaries at the given length/nominal.
    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.
    """

    def __init__(self, dataset, length):
        """Chop the dataset to the given length.

        :param dataset: IterableDataset
        :param length: declared length of the dataset
        :param nominal: nominal length of dataset (if different from declared)
        """
        super().__init__()
        self.length = length
        self.source = None

    def __getstate__(self):
        """Return the pickled state of the dataset.

        This resets the dataset iterator, since that can't be pickled.
        """
        result = dict(self.__dict__)
        result["source"] = None
        return result

    def invoke(self, dataset):
        """Return an iterator over the dataset.

        This iterator returns as many samples as given by the `length`
        parameter.
        """
        if self.source is None:
            self.source = iter(dataset)
        for i in range(self.length):
            try:
                sample = next(self.source)
            except StopIteration:
                self.source = iter(dataset)
                try:
                    sample = next(self.source)
                except StopIteration:
                    return
            yield sample
        self.source = None


class with_length(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset."""

    def __init__(self, dataset, length):
        """Create an instance of Repeatedly.

        :param dataset: source dataset
        :param length: stated length
        """
        super().__init__()
        self.dataset = dataset
        self.length = length

    def invoke(self, dataset):
        """Return an iterator that iterates repeatedly over a source."""
        return iter(dataset)

    def __len__(self):
        """Return the user specified length."""
        return self.length
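with_epoch gives an unbounded streaming source a fixed nominal epoch length; a sketch of the behavior (the toy generator is illustrative, and note that the constructor only uses its length argument):

    from paddlespeech.audio.streamdata.extradatasets import with_epoch

    def endless_source():
        i = 0
        while True:  # an unbounded stream, like a repeating shard list
            yield {"id": i}
            i += 1

    stage = with_epoch(None, 5)  # the dataset argument is unused here
    epoch = list(stage.invoke(endless_source()))
    assert len(epoch) == 5       # the stream is chopped at the epoch boundary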
paddlespeech/audio/stream_data/filters.py → paddlespeech/audio/streamdata/filters.py

...
@@ -21,7 +21,7 @@ from functools import reduce, wraps

 import numpy as np

-from webdataset import autodecode
+from . import autodecode
 from . import utils
 from .paddle_utils import PaddleTensor
 from .utils import PipelineStage
...
@@ -932,4 +932,4 @@ def _placeholder(source):
     for data in source:
         yield data

-placeholder = pipelinefilter(_placeholder)
\ No newline at end of file
+placeholder = pipelinefilter(_placeholder)
paddlespeech/audio/streamdata/gopen.py (new file, 0 → 100644)

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Open URLs by calling subcommands."""

import os, sys, re
from subprocess import PIPE, Popen
from urllib.parse import urlparse

# global used for printing additional node information during verbose output
info = {}


class Pipe:
    """Wrapper class for subprocess.Pipe.

    This class looks like a stream from the outside, but it checks
    subprocess status and handles timeouts with exceptions.
    This way, clients of the class do not need to know that they are
    dealing with subprocesses.

    :param *args: passed to `subprocess.Pipe`
    :param **kw: passed to `subprocess.Pipe`
    :param timeout: timeout for closing/waiting
    :param ignore_errors: don't raise exceptions on subprocess errors
    :param ignore_status: list of status codes to ignore
    """

    def __init__(
        self,
        *args,
        mode=None,
        timeout=7200.0,
        ignore_errors=False,
        ignore_status=[],
        **kw,
    ):
        """Create an IO Pipe."""
        self.ignore_errors = ignore_errors
        self.ignore_status = [0] + ignore_status
        self.timeout = timeout
        self.args = (args, kw)
        if mode[0] == "r":
            self.proc = Popen(*args, stdout=PIPE, **kw)
            self.stream = self.proc.stdout
            if self.stream is None:
                raise ValueError(f"{args}: couldn't open")
        elif mode[0] == "w":
            self.proc = Popen(*args, stdin=PIPE, **kw)
            self.stream = self.proc.stdin
            if self.stream is None:
                raise ValueError(f"{args}: couldn't open")
        self.status = None

    def __str__(self):
        return f"<Pipe {self.args}>"

    def check_status(self):
        """Poll the process and handle any errors."""
        status = self.proc.poll()
        if status is not None:
            self.wait_for_child()

    def wait_for_child(self):
        """Check the status variable and raise an exception if necessary."""
        verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
        if self.status is not None and verbose:
            # print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr)
            return
        self.status = self.proc.wait()
        if verbose:
            print(
                f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
                file=sys.stderr,
            )
        if self.status not in self.ignore_status and not self.ignore_errors:
            raise Exception(f"{self.args}: exit {self.status} (read) {info}")

    def read(self, *args, **kw):
        """Wrap stream.read and checks status."""
        result = self.stream.read(*args, **kw)
        self.check_status()
        return result

    def write(self, *args, **kw):
        """Wrap stream.write and checks status."""
        result = self.stream.write(*args, **kw)
        self.check_status()
        return result

    def readLine(self, *args, **kw):
        """Wrap stream.readLine and checks status."""
        result = self.stream.readLine(*args, **kw)
        self.status = self.proc.poll()
        self.check_status()
        return result

    def close(self):
        """Wrap stream.close, wait for the subprocess, and handle errors."""
        self.stream.close()
        self.status = self.proc.wait(self.timeout)
        self.wait_for_child()

    def __enter__(self):
        """Context handler."""
        return self

    def __exit__(self, etype, value, traceback):
        """Context handler."""
        self.close()


def set_options(obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None):
    """Set options for Pipes.

    This function can be called on any stream. It will set pipe options only
    when its argument is a pipe.

    :param obj: any kind of stream
    :param timeout: desired timeout
    :param ignore_errors: desired ignore_errors setting
    :param ignore_status: desired ignore_status setting
    :param handler: desired error handler
    """
    if not isinstance(obj, Pipe):
        return False
    if timeout is not None:
        obj.timeout = timeout
    if ignore_errors is not None:
        obj.ignore_errors = ignore_errors
    if ignore_status is not None:
        obj.ignore_status = ignore_status
    if handler is not None:
        obj.handler = handler
    return True


def gopen_file(url, mode="rb", bufsize=8192):
    """Open a file.

    This works for local files, files over HTTP, and pipe: files.

    :param url: URL to be opened
    :param mode: mode to open it with
    :param bufsize: requested buffer size
    """
    return open(url, mode)


def gopen_pipe(url, mode="rb", bufsize=8192):
    """Use gopen to open a pipe.

    :param url: a pipe: URL
    :param mode: desired mode
    :param bufsize: desired buffer size
    """
    assert url.startswith("pipe:")
    cmd = url[5:]
    if mode[0] == "r":
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_curl(url, mode="rb", bufsize=8192):
    """Open a URL with `curl`.

    :param url: url (usually, http:// etc.)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        cmd = f"curl -s -L '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"curl -s -L -T - '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_htgs(url, mode="rb", bufsize=8192):
    """Open a URL with `curl`.

    :param url: url (usually, http:// etc.)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        url = re.sub(r"(?i)^htgs://", "gs://", url)
        cmd = f"curl -s -L '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        raise ValueError(f"{mode}: cannot write")
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_gsutil(url, mode="rb", bufsize=8192):
    """Open a URL with `curl`.

    :param url: url (usually, http:// etc.)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        cmd = f"gsutil cat '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23],
        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"gsutil cp - '{url}'"
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26],
        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_error(url, *args, **kw):
    """Raise a value error.

    :param url: url
    :param args: other arguments
    :param kw: other keywords
    """
    raise ValueError(f"{url}: no gopen handler defined")


"""A dispatch table mapping URL schemes to handlers."""
gopen_schemes = dict(
    __default__=gopen_error,
    pipe=gopen_pipe,
    http=gopen_curl,
    https=gopen_curl,
    sftp=gopen_curl,
    ftps=gopen_curl,
    scp=gopen_curl,
    gs=gopen_gsutil,
    htgs=gopen_htgs,
)


def gopen(url, mode="rb", bufsize=8192, **kw):
    """Open the URL.

    This uses the `gopen_schemes` dispatch table to dispatch based
    on scheme.

    Support for the following schemes is built-in: pipe, file,
    http, https, sftp, ftps, scp.

    When no scheme is given the url is treated as a file.

    You can use the OPEN_VERBOSE argument to get info about
    files being opened.

    :param url: the source URL
    :param mode: the mode ("rb", "r")
    :param bufsize: the buffer size
    """
    global fallback_gopen
    verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
    if verbose:
        print("GOPEN", url, info, file=sys.stderr)
    assert mode in ["rb", "wb"], mode
    if url == "-":
        if mode == "rb":
            return sys.stdin.buffer
        elif mode == "wb":
            return sys.stdout.buffer
        else:
            raise ValueError(f"unknown mode {mode}")
    pr = urlparse(url)
    if pr.scheme == "":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(url, mode, buffering=bufsize)
    if pr.scheme == "file":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(pr.path, mode, buffering=bufsize)
    handler = gopen_schemes["__default__"]
    handler = gopen_schemes.get(pr.scheme, handler)
    return handler(url, mode, bufsize, **kw)


def reader(url, **kw):
    """Open url with gopen and mode "rb".

    :param url: source URL
    :param kw: other keywords forwarded to gopen
    """
    return gopen(url, "rb", **kw)
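gopen gives every storage backend a uniform stream interface; a short usage sketch (paths and URL are illustrative):

    from paddlespeech.audio.streamdata.gopen import gopen

    # A bare path falls through to plain open():
    with gopen("data/train_l/shard_000000.tar", "rb") as stream:
        head = stream.read(512)

    # A URL with a scheme is dispatched through gopen_schemes; http/https run
    # a `curl` subprocess wrapped in Pipe, so the tar never touches local disk:
    # stream = gopen("https://example.com/shard_000000.tar", "rb")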
paddlespeech/audio/streamdata/handlers.py (new file, 0 → 100644)

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Pluggable exception handlers.

These are functions that take an exception as an argument and then return...

- the exception (in order to re-raise it)
- True (in order to continue and ignore the exception)
- False (in order to ignore the exception and stop processing)

They are used as handler= arguments in much of the library.
"""

import time, warnings


def reraise_exception(exn):
    """Call in an exception handler to re-raise the exception."""
    raise exn


def ignore_and_continue(exn):
    """Call in an exception handler to ignore any exception and continue."""
    return True


def warn_and_continue(exn):
    """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return True


def ignore_and_stop(exn):
    """Call in an exception handler to ignore any exception and stop further processing."""
    return False


def warn_and_stop(exn):
    """Call in an exception handler to ignore any exception and stop further processing."""
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return False
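These handlers encode a simple convention: return True to skip the failing sample and continue, False to stop, or raise. A sketch of how a pipeline stage typically consumes a handler= argument (apply_stage is a hypothetical illustration, not a function in this module):

    from paddlespeech.audio.streamdata.handlers import warn_and_continue

    def apply_stage(samples, fn, handler=warn_and_continue):
        """Hypothetical stage showing the handler convention used by the library."""
        for sample in samples:
            try:
                yield fn(sample)
            except Exception as exn:
                if handler(exn):   # True: warn/ignore and keep going
                    continue
                break              # False: ignore and stop further processing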
paddlespeech/audio/streamdata/mix.py (new file, 0 → 100644)

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes for mixing samples from multiple sources."""

import itertools, os, random, time, sys
from functools import reduce, wraps

import numpy as np

from . import autodecode, utils
from .paddle_utils import PaddleTensor, IterableDataset
from .utils import PipelineStage


def round_robin_shortest(*sources):
    i = 0
    while True:
        try:
            sample = next(sources[i % len(sources)])
            yield sample
        except StopIteration:
            break
        i += 1


def round_robin_longest(*sources):
    i = 0
    while len(sources) > 0:
        try:
            sample = next(sources[i])
            i += 1
            yield sample
        except StopIteration:
            del sources[i]


class RoundRobin(IterableDataset):
    def __init__(self, datasets, longest=False):
        self.datasets = datasets
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        if self.longest:
            return round_robin_longest(*sources)
        else:
            return round_robin_shortest(*sources)


def random_samples(sources, probs=None, longest=False):
    if probs is None:
        probs = [1] * len(sources)
    else:
        probs = list(probs)
    while len(sources) > 0:
        cum = (np.array(probs) / np.sum(probs)).cumsum()
        r = random.random()
        i = np.searchsorted(cum, r)
        try:
            yield next(sources[i])
        except StopIteration:
            if longest:
                del sources[i]
                del probs[i]
            else:
                break


class RandomMix(IterableDataset):
    def __init__(self, datasets, probs=None, longest=False):
        self.datasets = datasets
        self.probs = probs
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        return random_samples(sources, self.probs, longest=self.longest)
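A sketch of mixing two sources with RandomMix (toy lists stand in for shard pipelines):

    from paddlespeech.audio.streamdata.mix import RandomMix

    ds_a = [{"src": "a", "i": i} for i in range(3)]
    ds_b = [{"src": "b", "i": i} for i in range(30)]

    # Draw from ds_a and ds_b with 1:9 odds; with longest=False, iteration
    # stops as soon as the first source is exhausted.
    for sample in RandomMix([ds_a, ds_b], probs=[1, 9], longest=False):
        print(sample["src"], sample["i"])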
paddlespeech/audio/stream_data/paddle_utils.py → paddlespeech/audio/streamdata/paddle_utils.py

File moved.
paddlespeech/audio/stream_data/pipeline.py → paddlespeech/audio/streamdata/pipeline.py

...
@@ -10,8 +10,7 @@ from typing import List

 import braceexpand, yaml

-from webdataset import autodecode, extradatasets as eds, filters, shardlists, tariterators
-from webdataset.handlers import reraise_exception
+from .handlers import reraise_exception
 from .paddle_utils import DataLoader, IterableDataset
 from .utils import PipelineStage
...
paddlespeech/audio/stream_data/shardlists.py → paddlespeech/audio/streamdata/shardlists.py

File moved.
paddlespeech/audio/stream_data/tariterators.py → paddlespeech/audio/streamdata/tariterators.py

...
@@ -14,8 +14,8 @@ import random, re, tarfile

 import braceexpand

 from . import filters
-from webdataset import gopen
-from webdataset.handlers import reraise_exception
+from . import gopen
+from .handlers import reraise_exception

 trace = False
 meta_prefix = "__"
...
paddlespeech/audio/stream_data/utils.py → paddlespeech/audio/streamdata/utils.py

File moved.
paddlespeech/audio/streamdata/writer.py (new file, 0 → 100644)

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes and functions for writing tar files and WebDataset files."""

import io, json, pickle, re, tarfile, time
from typing import Any, Callable, Optional, Union

import numpy as np

from . import gopen


def imageencoder(image: Any, format: str = "PNG"):  # skipcq: PYL-W0622
    """Compress an image using PIL and return it as a string.

    Can handle float or uint8 images.

    :param image: ndarray representing an image
    :param format: compression format (PNG, JPEG, PPM)
    """
    import PIL

    assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image)

    if isinstance(image, np.ndarray):
        if image.dtype in [np.dtype("f"), np.dtype("d")]:
            if not (np.amin(image) > -0.001 and np.amax(image) < 1.001):
                raise ValueError(
                    f"image values out of range {np.amin(image)} {np.amax(image)}"
                )
            image = np.clip(image, 0.0, 1.0)
            image = np.array(image * 255.0, "uint8")
        assert image.ndim in [2, 3]
        if image.ndim == 3:
            assert image.shape[2] in [1, 3]
        image = PIL.Image.fromarray(image)
    if format.upper() == "JPG":
        format = "JPEG"
    elif format.upper() in ["IMG", "IMAGE"]:
        format = "PPM"
    if format == "JPEG":
        opts = dict(quality=100)
    else:
        opts = {}
    with io.BytesIO() as result:
        image.save(result, format=format, **opts)
        return result.getvalue()


def bytestr(data: Any):
    """Convert data into a bytestring.

    Uses str and ASCII encoding for data that isn't already in string format.

    :param data: data
    """
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("ascii")
    return str(data).encode("ascii")


def paddle_dumps(data: Any):
    """Dump data into a bytestring using paddle.dumps.

    This delays importing paddle until needed.

    :param data: data to be dumped
    """
    import io
    import paddle

    stream = io.BytesIO()
    paddle.save(data, stream)
    return stream.getvalue()


def numpy_dumps(data: np.ndarray):
    """Dump data into a bytestring using numpy npy format.

    :param data: data to be dumped
    """
    import io
    import numpy.lib.format

    stream = io.BytesIO()
    numpy.lib.format.write_array(stream, data)
    return stream.getvalue()


def numpy_npz_dumps(data: np.ndarray):
    """Dump data into a bytestring using numpy npz format.

    :param data: data to be dumped
    """
    import io

    stream = io.BytesIO()
    np.savez_compressed(stream, **data)
    return stream.getvalue()


def tenbin_dumps(x):
    from . import tenbin

    if isinstance(x, list):
        return memoryview(tenbin.encode_buffer(x))
    else:
        return memoryview(tenbin.encode_buffer([x]))


def cbor_dumps(x):
    import cbor

    return cbor.dumps(x)


def mp_dumps(x):
    import msgpack

    return msgpack.packb(x)


def add_handlers(d, keys, value):
    if isinstance(keys, str):
        keys = keys.split()
    for k in keys:
        d[k] = value


def make_handlers():
    """Create a list of handlers for encoding data."""
    handlers = {}
    add_handlers(handlers, "cls cls2 class count index inx id",
                 lambda x: str(x).encode("ascii"))
    add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "pyd pickle", pickle.dumps)
    add_handlers(handlers, "pdparams", paddle_dumps)
    add_handlers(handlers, "npy", numpy_dumps)
    add_handlers(handlers, "npz", numpy_npz_dumps)
    add_handlers(handlers, "ten tenbin tb", tenbin_dumps)
    add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
    add_handlers(handlers, "mp msgpack msg", mp_dumps)
    add_handlers(handlers, "cbor", cbor_dumps)
    add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
    add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
    add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
    add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
    add_handlers(handlers, "ppm", lambda data: imageencoder(data, "ppm"))
    return handlers


default_handlers = make_handlers()


def encode_based_on_extension1(data: Any, tname: str, handlers: dict):
    """Encode data based on its extension and a dict of handlers.

    :param data: data
    :param tname: file extension
    :param handlers: handlers
    """
    if tname[0] == "_":
        if not isinstance(data, str):
            raise ValueError("the values of metadata must be of string type")
        return data
    extension = re.sub(r".*\.", "", tname).lower()
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode("utf-8")
    handler = handlers.get(extension)
    if handler is None:
        raise ValueError(f"no handler found for {extension}")
    return handler(data)


def encode_based_on_extension(sample: dict, handlers: dict):
    """Encode an entire sample with a collection of handlers.

    :param sample: data sample (a dict)
    :param handlers: handlers for encoding
    """
    return {
        k: encode_based_on_extension1(v, k, handlers)
        for k, v in list(sample.items())
    }


def make_encoder(spec: Union[bool, str, dict, Callable]):
    """Make an encoder function from a specification.

    :param spec: specification
    """
    if spec is False or spec is None:

        def encoder(x):
            """Do not encode at all."""
            return x

    elif callable(spec):
        encoder = spec
    elif isinstance(spec, dict):

        def f(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, spec)

        encoder = f

    elif spec is True:
        handlers = default_handlers

        def g(sample):
            """Encode based on extension."""
            return encode_based_on_extension(sample, handlers)

        encoder = g

    else:
        raise ValueError(f"{spec}: unknown decoder spec")
    if not callable(encoder):
        raise ValueError(f"{spec} did not yield a callable encoder")
    return encoder


class TarWriter:
    """A class for writing dictionaries to tar files.

    :param fileobj: fileobj: file name for tar file (.tgz/.tar) or open file descriptor
    :param encoder: sample encoding (Default value = True)
    :param compress: (Default value = None)

    `True` will use an encoder that behaves similar to the automatic
    decoder for `Dataset`. `False` disables encoding and expects byte strings
    (except for metadata, which must be strings). The `encoder` argument can
    also be a `callable`, or a dictionary mapping extensions to encoders.

    The following code will add two file to the tar archive: `a/b.png` and
    `a/b.output.png`.

    ```Python
    tarwriter = TarWriter(stream)
    image = imread("b.jpg")
    image2 = imread("b.out.jpg")
    sample = {"__key__": "a/b", "png": image, "output.png": image2}
    tarwriter.write(sample)
    ```
    """

    def __init__(
        self,
        fileobj,
        user: str = "bigdata",
        group: str = "bigdata",
        mode: int = 0o0444,
        compress: Optional[bool] = None,
        encoder: Union[None, bool, Callable] = True,
        keep_meta: bool = False,
    ):
        """Create a tar writer.

        :param fileobj: stream to write data to
        :param user: user for tar files
        :param group: group for tar files
        :param mode: mode for tar files
        :param compress: desired compression
        :param encoder: encoder function
        :param keep_meta: keep metadata (entries starting with "_")
        """
        if isinstance(fileobj, str):
            if compress is False:
                tarmode = "w|"
            elif compress is True:
                tarmode = "w|gz"
            else:
                tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
            fileobj = gopen.gopen(fileobj, "wb")
            self.own_fileobj = fileobj
        else:
            tarmode = "w|gz" if compress is True else "w|"
            self.own_fileobj = None
        self.encoder = make_encoder(encoder)
        self.keep_meta = keep_meta
        self.stream = fileobj
        self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)

        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file."""
        self.tarstream.close()
        if self.own_fileobj is not None:
            self.own_fileobj.close()
            self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        :param obj: dictionary of objects to be stored
        :returns: size of the entry
        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        for k, v in list(obj.items()):
            if k[0] == "_":
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to a bytes after encoding ({type(v)})"
                )
        key = obj["__key__"]
        for k in sorted(obj.keys()):
            if k == "__key__":
                continue
            if not self.keep_meta and k[0] == "_":
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            now = time.time()
            ti = tarfile.TarInfo(key + "." + k)
            ti.size = len(v)
            ti.mtime = now
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
            stream = io.BytesIO(v)
            self.tarstream.addfile(ti, stream)
            total += ti.size
        return total


class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(
        self,
        pattern: str,
        maxcount: int = 100000,
        maxsize: float = 3e9,
        post: Optional[Callable] = None,
        start_shard: int = 0,
        **kw,
    ):
        """Create a ShardWriter.

        :param pattern: output file pattern
        :param maxcount: maximum number of records per shard (Default value = 100000)
        :param maxsize: maximum size of each shard (Default value = 3e9)
        :param kw: other options passed to TarWriter
        """
        self.verbose = 1
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post

        self.tarstream = None
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total,
            )
        self.shard += 1
        stream = open(self.fname, "wb")
        self.tarstream = TarWriter(stream, **self.kw)
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        :param obj: sample to be written
        """
        if (
            self.tarstream is None
            or self.count >= self.maxcount
            or self.size >= self.maxsize
        ):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead)."""
        if self.tarstream is not None:
            self.tarstream.close()
            assert self.fname is not None
            if callable(self.post):
                self.post(self.fname)
            self.tarstream = None

    def close(self):
        """Close the stream."""
        self.finish()
        del self.tarstream
        del self.shard
        del self.count
        del self.size

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()
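A sketch of writing WebDataset-style shards with ShardWriter (the pattern and sample fields are illustrative):

    import os
    from paddlespeech.audio.streamdata.writer import ShardWriter

    os.makedirs("shards", exist_ok=True)  # ShardWriter opens the first shard eagerly
    # Split samples into tar shards of at most 1000 records each.
    with ShardWriter("shards/shard-%06d.tar", maxcount=1000) as sink:
        for utt_id, wav_bytes, text in [("utt0001", b"\x00\x01", "hello")]:
            sink.write({
                "__key__": utt_id,
                "wav": wav_bytes,  # bytes pass through the default encoder as-is
                "txt": text,       # str values are utf-8 encoded
            })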
paddlespeech/s2t/io/dataloader.py

...
@@ -18,6 +18,7 @@ from typing import Text

 import jsonlines
 import numpy as np
+import paddle
 from paddle.io import BatchSampler
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
...
@@ -28,7 +29,7 @@ from paddlespeech.s2t.io.dataset import TransformDataset
 from paddlespeech.s2t.io.reader import LoadInputsAndTargets
 from paddlespeech.s2t.utils.log import Log
-import paddlespeech.audio.stream_data as stream_data
+import paddlespeech.audio.streamdata as streamdata
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer

 __all__ = ["BatchDataLoader"]
...
@@ -101,38 +102,46 @@ class StreamDataLoader():
                     shardlist.append(line.strip())

         if self.dist_sampler:
-            base_dataset = stream_data.DataPipeline(
-                stream_data.SimpleShardList(shardlist),
-                stream_data.split_by_node,
-                stream_data.split_by_worker,
-                stream_data.tarfile_to_samples(stream_data.reraise_exception)
+            base_dataset = streamdata.DataPipeline(
+                streamdata.SimpleShardList(shardlist),
+                streamdata.split_by_node,
+                streamdata.split_by_worker,
+                streamdata.tarfile_to_samples(streamdata.reraise_exception)
             )
         else:
-            base_dataset = stream_data.DataPipeline(
-                stream_data.SimpleShardList(shardlist),
-                stream_data.split_by_worker,
-                stream_data.tarfile_to_samples(stream_data.reraise_exception)
+            base_dataset = streamdata.DataPipeline(
+                streamdata.SimpleShardList(shardlist),
+                streamdata.split_by_worker,
+                streamdata.tarfile_to_samples(streamdata.reraise_exception)
             )

         self.dataset = base_dataset.append_list(
-            stream_data.tokenize(symbol_table),
-            stream_data.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in),
-            stream_data.resample(resample_rate=resample_rate),
-            stream_data.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
-            stream_data.spec_aug(**augment_conf) if train_mode else stream_data.placeholder(),  # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
-            stream_data.shuffle(shuffle_size),
-            stream_data.sort(sort_size=sort_size),
-            stream_data.batched(batch_size),
-            stream_data.padding(),
-            stream_data.cmvn(cmvn_file)
+            streamdata.tokenize(symbol_table),
+            streamdata.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in),
+            streamdata.resample(resample_rate=resample_rate),
+            streamdata.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
+            streamdata.spec_aug(**augment_conf) if train_mode else streamdata.placeholder(),  # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
+            streamdata.shuffle(shuffle_size),
+            streamdata.sort(sort_size=sort_size),
+            streamdata.batched(batch_size),
+            streamdata.padding(),
+            streamdata.cmvn(cmvn_file)
        )

-        self.loader = stream_data.WebLoader(
-            self.dataset,
-            num_workers=self.n_iter_processes,
-            prefetch_factor=self.prefetch_factor,
-            batch_size=None
-        )
+        if paddle.__version__ >= '2.3.2':
+            self.loader = streamdata.WebLoader(
+                self.dataset,
+                num_workers=self.n_iter_processes,
+                prefetch_factor=self.prefetch_factor,
+                batch_size=None)
+        else:
+            self.loader = streamdata.WebLoader(
+                self.dataset,
+                num_workers=self.n_iter_processes,
+                batch_size=None)

     def __iter__(self):
         return self.loader.__iter__()
...
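The replacement wraps WebLoader construction in a paddle version check, presumably because prefetch_factor is only accepted by newer paddle.io.DataLoader releases. One caveat worth noting, shown as a sketch of the guard in isolation:

    import paddle

    # Caveat: this compares version strings lexicographically, which orders
    # '2.3.2' after '2.2.x' correctly but would misorder e.g. '2.10.0';
    # packaging.version.parse is the robust alternative.
    loader_kwargs = dict(num_workers=4, batch_size=None)
    if paddle.__version__ >= '2.3.2':
        loader_kwargs["prefetch_factor"] = 10  # only accepted by newer DataLoader
    # loader = streamdata.WebLoader(dataset, **loader_kwargs)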
setup.py

...
@@ -38,7 +38,8 @@ base = [
     "pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2",
     "sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10", "textgrid",
     "timer", "tqdm", "typeguard", "visualdl", "webrtcvad",
-    "yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8', 'webdataset'
+    "yacs~=0.1.8", "prettytable", "zhon", "colorlog", "pathos == 0.2.8",
+    "braceexpand", "pyyaml"
 ]

 server = [
...