Commit d1a25f6c, authored Jul 06, 2022 by Jackwaterveg; committed via GitHub on Jul 06, 2022
Merge pull request #2062 from Jackwaterveg/webdataset
[Audio] Add webdataset in paddlespeech.audio
Parents: 803fec21, 05d41523
Showing 51 changed files with 5,199 additions and 392 deletions (+5199 −392)
examples/wenetspeech/asr1/conf/conformer.yaml                        +20   -19
examples/wenetspeech/asr1/local/data.sh                              +49   -76
examples/wenetspeech/asr1/local/train.sh                             +68   -0
examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh             +2    -2
examples/wenetspeech/asr1/run.sh                                     +2    -1
paddlespeech/audio/streamdata/__init__.py                            +70   -0
paddlespeech/audio/streamdata/autodecode.py                          +445  -0
paddlespeech/audio/streamdata/cache.py                               +190  -0
paddlespeech/audio/streamdata/compat.py                              +170  -0
paddlespeech/audio/streamdata/extradatasets.py                       +141  -0
paddlespeech/audio/streamdata/filters.py                             +935  -0
paddlespeech/audio/streamdata/gopen.py                               +340  -0
paddlespeech/audio/streamdata/handlers.py                            +47   -0
paddlespeech/audio/streamdata/mix.py                                 +85   -0
paddlespeech/audio/streamdata/paddle_utils.py                        +33   -0
paddlespeech/audio/streamdata/pipeline.py                            +132  -0
paddlespeech/audio/streamdata/shardlists.py                          +261  -0
paddlespeech/audio/streamdata/tariterators.py                        +283  -0
paddlespeech/audio/streamdata/utils.py                               +132  -0
paddlespeech/audio/streamdata/writer.py                              +450  -0
paddlespeech/audio/text/text_featurizer.py                           +235  -0
paddlespeech/audio/text/utility.py                                   +393  -0
paddlespeech/audio/transform/__init__.py                             +0    -0
paddlespeech/audio/transform/add_deltas.py                           +0    -0
paddlespeech/audio/transform/channel_selector.py                     +0    -0
paddlespeech/audio/transform/cmvn.py                                 +0    -0
paddlespeech/audio/transform/functional.py                           +2    -2
paddlespeech/audio/transform/perturb.py                              +91   -1
paddlespeech/audio/transform/spec_augment.py                         +1    -1
paddlespeech/audio/transform/spectrogram.py                          +1    -1
paddlespeech/audio/transform/transform_interface.py                  +0    -0
paddlespeech/audio/transform/transformation.py                       +24   -24
paddlespeech/audio/transform/wpe.py                                  +0    -0
paddlespeech/audio/utils/check_kwargs.py                             +35   -0
paddlespeech/audio/utils/dynamic_import.py                           +38   -0
paddlespeech/audio/utils/log.py                                      +2    -1
paddlespeech/audio/utils/tensor_utils.py                             +192  -0
paddlespeech/cli/asr/infer.py                                        +1    -1
paddlespeech/s2t/exps/deepspeech2/model.py                           +1    -1
paddlespeech/s2t/exps/u2/bin/test_wav.py                             +1    -1
paddlespeech/s2t/exps/u2/model.py                                    +17   -87
paddlespeech/s2t/exps/u2_kaldi/model.py                              +28   -88
paddlespeech/s2t/exps/u2_st/model.py                                 +18   -75
paddlespeech/s2t/io/dataloader.py                                    +252  -1
paddlespeech/s2t/io/reader.py                                        +1    -1
paddlespeech/s2t/models/u2/u2.py                                     +3    -3
paddlespeech/s2t/models/u2_st/u2_st.py                               +2    -2
paddlespeech/server/engine/asr/online/onnx/asr_engine.py             +1    -1
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py  +1    -1
paddlespeech/server/engine/asr/online/python/asr_engine.py           +1    -1
setup.py                                                             +3    -1
examples/wenetspeech/asr1/conf/conformer.yaml

 ############################################
 # Network Architecture                     #
 ############################################
-cmvn_file:
 cmvn_file_type: "json"
 # encoder related
 encoder: conformer
...
@@ -43,40 +42,42 @@ model_conf:
 ###########################################
 # Data                                    #
 ###########################################
-train_manifest: data/manifest.train
-dev_manifest: data/manifest.dev
-test_manifest: data/manifest.test
+train_manifest: data/train_l/data.list
+dev_manifest: data/dev/data.list
+test_manifest: data/test_meeting/data.list

 ###########################################
 # Dataloader                              #
 ###########################################
-vocab_filepath: data/lang_char/vocab.txt
+use_stream_data: True
 unit_type: 'char'
+vocab_filepath: data/lang_char/vocab.txt
 preprocess_config: conf/preprocess.yaml
+cmvn_file: data/mean_std.json
 spm_model_prefix: ''
 feat_dim: 80
 stride_ms: 10.0
 window_ms: 25.0
+dither: 0.1
 sortagrad: 0   # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
-batch_size: 64
-maxlen_in: 512    # if input length > maxlen-in, batchsize is automatically reduced
-maxlen_out: 150   # if output length > maxlen-out, batchsize is automatically reduced
-minibatches: 0    # for debug
-batch_count: auto
-batch_bins: 0
-batch_frames_in: 0
-batch_frames_out: 0
-batch_frames_inout: 0
-num_workers: 8
-subsampling_factor: 1
+batch_size: 32
+minlen_in: 10
+maxlen_in: 1200   # if input length(number of frames) > maxlen-in, data is automatically removed
+minlen_out: 0
+maxlen_out: 150   # if output length(number of tokens) > maxlen-out, data is automatically removed
+resample_rate: 16000
+shuffle_size: 1500   # read number of 'shuffle_size' data as a chunk, shuffle the data in the chunk
+sort_size: 1000      # read number of 'sort_size' data as a chunk, sort the data in the chunk
+num_workers: 0
+prefetch_factor: 10
+dist_sampler: True
 num_encs: 1

 ###########################################
 # Training                                #
 ###########################################
-n_epoch: 240
-accum_grad: 16
+n_epoch: 32
+accum_grad: 32
 global_grad_clip: 5.0
 log_interval: 100
 checkpoint:
...
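The Dataloader block above is what the new streaming path consumes once use_stream_data is set. A quick way to sanity-check a locally edited config is to load it and print the streaming-related keys; this is only an illustrative sketch (PyYAML and the config path are assumed, and the key list simply mirrors the diff above, it is not an API of the PR):

# Sketch: inspect the streaming-related keys of the updated conformer.yaml.
import yaml

with open("examples/wenetspeech/asr1/conf/conformer.yaml") as f:
    conf = yaml.safe_load(f)

stream_keys = [
    "use_stream_data", "train_manifest", "vocab_filepath", "cmvn_file",
    "batch_size", "maxlen_in", "maxlen_out", "shuffle_size", "sort_size",
    "prefetch_factor", "dist_sampler",
]
for k in stream_keys:
    # Missing keys print None; the exact set depends on the config revision.
    print(f"{k}: {conf.get(k)}")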
examples/wenetspeech/asr1/local/data.sh

@@ -2,6 +2,8 @@
 # Copyright 2021 Mobvoi Inc(Author: Di Wu, Binbin Zhang)
 #                 NPU, ASLP Group (Author: Qijie Shao)
+#
+# Modified from wenet(https://github.com/wenet-e2e/wenet)

 stage=-1
 stop_stage=100
...
@@ -30,7 +32,7 @@ mkdir -p data
 TARGET_DIR=${MAIN_ROOT}/dataset
 mkdir -p ${TARGET_DIR}

-if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
     # download data
     echo "Please follow https://github.com/wenet-e2e/WenetSpeech to download the data."
     exit 0;
...
@@ -44,86 +46,57 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         data || exit 1;
 fi

-if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
-    # generate manifests
-    python3 ${TARGET_DIR}/aishell/aishell.py \
-    --manifest_prefix="data/manifest" \
-    --target_dir="${TARGET_DIR}/aishell"
-    if [ $? -ne 0 ]; then
-        echo "Prepare Aishell failed. Terminated."
-        exit 1
-    fi
-    for dataset in train dev test; do
-        mv data/manifest.${dataset} data/manifest.${dataset}.raw
-    done
-fi
-
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-    # compute mean and stddev for normalizer
-    if $cmvn; then
-        full_size=`cat data/${train_set}/wav.scp | wc -l`
-        sampling_size=$((full_size / cmvn_sampling_divisor))
-        shuf -n $sampling_size data/$train_set/wav.scp \
-            > data/$train_set/wav.scp.sampled
-        num_workers=$(nproc)
-        python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
-        --manifest_path="data/manifest.train.raw" \
-        --spectrum_type="fbank" \
-        --feat_dim=80 \
-        --delta_delta=false \
-        --stride_ms=10 \
-        --window_ms=25 \
-        --sample_rate=16000 \
-        --use_dB_normalization=False \
-        --num_samples=-1 \
-        --num_workers=${num_workers} \
-        --output_path="data/mean_std.json"
-        if [ $? -ne 0 ]; then
-            echo "Compute mean and stddev failed. Terminated."
-            exit 1
-        fi
-    fi
-fi
-
-dict=data/dict/lang_char.txt
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    # download data, generate manifests
-    # build vocabulary
-    python3 ${MAIN_ROOT}/utils/build_vocab.py \
-    --unit_type="char" \
-    --count_threshold=0 \
-    --vocab_path="data/lang_char/vocab.txt" \
-    --manifest_paths "data/manifest.train.raw"
-    if [ $? -ne 0 ]; then
-        echo "Build vocabulary failed. Terminated."
-        exit 1
-    fi
-fi
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-    # format manifest with tokenids, vocab size
-    for dataset in train dev test; do
-    {
-        python3 ${MAIN_ROOT}/utils/format_data.py \
-        --cmvn_path "data/mean_std.json" \
-        --unit_type "char" \
-        --vocab_path="data/vocab.txt" \
-        --manifest_path="data/manifest.${dataset}.raw" \
-        --output_path="data/manifest.${dataset}"
-        if [ $? -ne 0 ]; then
-            echo "Formt mnaifest failed. Terminated."
-            exit 1
-        fi
-    } &
-    done
-    wait
-fi
-
-echo "Aishell data preparation done."
-exit 0
+dict=data/lang_char/vocab.txt
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "Make a dictionary"
+    echo "dictionary: ${dict}"
+    mkdir -p $(dirname $dict)
+    echo "<blank>" > ${dict}  # 0 will be used for "blank" in CTC
+    echo "<unk>" >> ${dict}   # <unk> must be 1
+    echo "▁" >> ${dict}       # ▁ is for space
+    utils/text2token.py -s 1 -n 1 --space "▁" data/${train_set}/text \
+        | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -a -v -e '^\s*$' \
+        | grep -v "▁" \
+        | awk '{print $0}' >> ${dict} \
+        || exit 1;
+    echo "<eos>" >> $dict
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "Compute cmvn"
+    # Here we use all the training data, you can sample some some data to save time
+    # BUG!!! We should use the segmented data for CMVN
+    if $cmvn; then
+        full_size=`cat data/${train_set}/wav.scp | wc -l`
+        sampling_size=$((full_size / cmvn_sampling_divisor))
+        shuf -n $sampling_size data/$train_set/wav.scp \
+            > data/$train_set/wav.scp.sampled
+        python3 utils/compute_cmvn_stats.py \
+        --num_workers 16 \
+        --train_config $train_config \
+        --in_scp data/$train_set/wav.scp.sampled \
+        --out_cmvn data/$train_set/mean_std.json \
+        || exit 1;
+    fi
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "Making shards, please wait..."
+    RED='\033[0;31m'
+    NOCOLOR='\033[0m'
+    echo -e "It requires ${RED}1.2T ${NOCOLOR}space for $shards_dir, please make sure you have enough space"
+    echo -e "It takes about ${RED}12 ${NOCOLOR}hours with 32 threads"
+    for x in $dev_set $test_sets ${train_set}; do
+        dst=$shards_dir/$x
+        mkdir -p $dst
+        utils/make_filted_shard_list.py --num_node 1 --num_gpus_per_node 8 --num_utts_per_shard 1000 \
+            --do_filter --resample 16000 \
+            --num_threads 32 --segments data/$x/segments \
+            data/$x/wav.scp data/$x/text \
+            $(realpath $dst) data/$x/data.list
+    done
+fi
+
+echo "Wenetspeech data preparation done."
+exit 0
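The new CMVN stage only estimates statistics on a subsample of wav.scp: it keeps full_size / cmvn_sampling_divisor utterances before calling compute_cmvn_stats.py. A minimal Python sketch of that arithmetic, assuming a divisor of 20 (i.e. roughly 5% of the training data, a value taken from the script's defaults rather than from this hunk) and a hypothetical wav.scp path:

# Sketch of the CMVN subsampling step from the script above, in Python.
import random

def sample_wav_scp(wav_scp="data/train_l/wav.scp", cmvn_sampling_divisor=20):
    with open(wav_scp) as f:
        lines = f.readlines()
    sampling_size = len(lines) // cmvn_sampling_divisor   # full_size / divisor
    sampled = random.sample(lines, sampling_size)         # like `shuf -n $sampling_size`
    with open(wav_scp + ".sampled", "w") as f:
        f.writelines(sampled)
    return sampling_size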
examples/wenetspeech/asr1/local/train.sh (new file, mode 100755)

#!/bin/bash

profiler_options=
benchmark_batch_size=0
benchmark_max_step=0

# seed may break model convergence
seed=0

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

if [ ${seed} != 0 ]; then
    export FLAGS_cudnn_deterministic=True
    echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi

if [ $# -lt 2 ] && [ $# -gt 3 ];then
    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
    exit -1
fi

config_path=$1
ckpt_name=$2
ips=$3

if [ ! $ips ];then
  ips_config=
else
  ips_config="--ips="${ips}
fi
echo ${ips_config}

mkdir -p exp

if [ ${ngpu} == 0 ]; then
python3 -u ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
else
NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
fi

if [ ${seed} != 0 ]; then
    unset FLAGS_cudnn_deterministic
fi

if [ $? -ne 0 ]; then
    echo "Failed in training!"
    exit 1
fi

exit 0
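The single- versus multi-GPU branch above keys entirely off ngpu, derived from CUDA_VISIBLE_DEVICES with awk. A small Python equivalent of that counting logic, purely for reference (nothing here is part of the script):

# Equivalent of `ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')`.
# An unset or empty variable yields 0, which selects the non-distributed branch.
import os

def count_visible_gpus() -> int:
    devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
    return len([d for d in devices.split(",") if d]) if devices else 0

print(count_visible_gpus())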
examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh

@@ -24,7 +24,7 @@ stage=1
 prefix=
 train_subset=L

-. ./tools/parse_options.sh || exit 1;
+. ./utils/parse_options.sh || exit 1;

 filter_by_id () {
     idlist=$1
...
@@ -132,4 +132,4 @@ if [ $stage -le 2 ]; then
     done
 fi

-echo "$0: Done"
+echo "$0: Done"
\ No newline at end of file
examples/wenetspeech/asr1/run.sh

@@ -7,6 +7,7 @@ gpus=0,1,2,3,4,5,6,7
 stage=0
 stop_stage=100
 conf_path=conf/conformer.yaml
+ips=            #xxx.xxx.xxx.xxx,xxx.xxx.xxx.xxx
 decode_conf_path=conf/tuning/decode.yaml
 average_checkpoint=true
 avg_num=10
...
@@ -26,7 +27,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
paddlespeech/audio/streamdata/__init__.py (new file, mode 100644)

# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa

from .cache import (
    cached_tarfile_samples,
    cached_tarfile_to_samples,
    lru_cleanup,
    pipe_cleaner,
)
from .compat import WebDataset, WebLoader, FluidWrapper
from .extradatasets import MockDataset, with_epoch, with_length
from .filters import (
    associate,
    batched,
    decode,
    detshuffle,
    extract_keys,
    getfirst,
    info,
    map,
    map_dict,
    map_tuple,
    pipelinefilter,
    rename,
    rename_keys,
    audio_resample,
    select,
    shuffle,
    slice,
    to_tuple,
    transform_with,
    unbatched,
    xdecode,
    audio_data_filter,
    audio_tokenize,
    audio_resample,
    audio_compute_fbank,
    audio_spec_aug,
    sort,
    audio_padding,
    audio_cmvn,
    placeholder,
)
from .handlers import (
    ignore_and_continue,
    ignore_and_stop,
    reraise_exception,
    warn_and_continue,
    warn_and_stop,
)
from .pipeline import DataPipeline
from .shardlists import (
    MultiShardSample,
    ResampledShards,
    SimpleShardList,
    non_empty,
    resampled,
    shardspec,
    single_node_only,
    split_by_node,
    split_by_worker,
)
from .tariterators import tarfile_samples, tarfile_to_samples
from .utils import PipelineStage, repeatedly
from .writer import ShardWriter, TarWriter, numpy_dumps
from .mix import RandomMix, RoundRobin
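The package also re-exports the writing side (TarWriter, ShardWriter). A minimal sketch of writing one shard, assuming the writer keeps the upstream webdataset interface (one dict per sample, "__key__" plus one entry per file extension); the payload below is dummy data and the wav path is hypothetical:

# Minimal sketch of writing one tar shard with the re-exported TarWriter.
from paddlespeech.audio.streamdata import TarWriter

with TarWriter("demo_shard_000000.tar") as sink:
    sink.write({
        "__key__": "utt_0001",
        "wav": open("some.wav", "rb").read(),  # raw audio bytes (file assumed to exist)
        "txt": "今天天气怎么样",                 # transcript stored as text
    })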
paddlespeech/audio/streamdata/autodecode.py (new file, mode 100644)

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Automatically decode webdataset samples."""

import io, json, os, pickle, re, tempfile
from functools import partial

import numpy as np

"""Extensions passed on to the image decoder."""
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()

################################################################
# handle basic datatypes
################################################################

def paddle_loads(data):
    """Load data using paddle.loads, importing paddle only if needed.
    :param data: data to be decoded
    """
    import io
    import paddle

    stream = io.BytesIO(data)
    return paddle.load(stream)

def tenbin_loads(data):
    from . import tenbin

    return tenbin.decode_buffer(data)

def msgpack_loads(data):
    import msgpack

    return msgpack.unpackb(data)

def npy_loads(data):
    import numpy.lib.format

    stream = io.BytesIO(data)
    return numpy.lib.format.read_array(stream)

def cbor_loads(data):
    import cbor

    return cbor.loads(data)

decoders = {
    "txt": lambda data: data.decode("utf-8"),
    "text": lambda data: data.decode("utf-8"),
    "transcript": lambda data: data.decode("utf-8"),
    "cls": lambda data: int(data),
    "cls2": lambda data: int(data),
    "index": lambda data: int(data),
    "inx": lambda data: int(data),
    "id": lambda data: int(data),
    "json": lambda data: json.loads(data),
    "jsn": lambda data: json.loads(data),
    "pyd": lambda data: pickle.loads(data),
    "pickle": lambda data: pickle.loads(data),
    "pdparams": lambda data: paddle_loads(data),
    "ten": tenbin_loads,
    "tb": tenbin_loads,
    "mp": msgpack_loads,
    "msg": msgpack_loads,
    "npy": npy_loads,
    "npz": lambda data: np.load(io.BytesIO(data)),
    "cbor": cbor_loads,
}

def basichandlers(key, data):
    """Handle basic file decoding.
    This function is usually part of the post= decoders.
    This handles the following forms of decoding:
    - txt -> unicode string
    - cls cls2 class count index inx id -> int
    - json jsn -> JSON decoding
    - pyd pickle -> pickle decoding
    - pdparams -> paddle.loads
    - ten tenbin -> fast tensor loading
    - mp messagepack msg -> messagepack decoding
    - npy -> Python NPY decoding
    :param key: file name extension
    :param data: binary data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension in decoders:
        return decoders[extension](data)
    return None

################################################################
# Generic extension handler.
################################################################

def call_extension_handler(key, data, f, extensions):
    """Call the function f with the given data if the key matches the extensions.
    :param key: actual key found in the sample
    :param data: binary data
    :param f: decoder function
    :param extensions: list of matching extensions
    """
    extension = key.lower().split(".")
    for target in extensions:
        target = target.split(".")
        if len(target) > len(extension):
            continue
        if extension[-len(target):] == target:
            return f(data)
    return None

def handle_extension(extensions, f):
    """Return a decoder function for the list of extensions.
    Extensions can be a space separated list of extensions.
    Extensions can contain dots, in which case the corresponding number
    of extension components must be present in the key given to f.
    Comparisons are case insensitive.
    Examples:
    handle_extension("jpg jpeg", my_decode_jpg)  # invoked for any file.jpg
    handle_extension("seg.jpg", special_case_jpg)  # invoked only for file.seg.jpg
    """
    extensions = extensions.lower().split()
    return partial(call_extension_handler, f=f, extensions=extensions)

################################################################
# handle images
################################################################

imagespecs = {
    "l8": ("numpy", "uint8", "l"),
    "rgb8": ("numpy", "uint8", "rgb"),
    "rgba8": ("numpy", "uint8", "rgba"),
    "l": ("numpy", "float", "l"),
    "rgb": ("numpy", "float", "rgb"),
    "rgba": ("numpy", "float", "rgba"),
    "paddlel8": ("paddle", "uint8", "l"),
    "paddlergb8": ("paddle", "uint8", "rgb"),
    "paddlergba8": ("paddle", "uint8", "rgba"),
    "paddlel": ("paddle", "float", "l"),
    "paddlergb": ("paddle", "float", "rgb"),
    "paddle": ("paddle", "float", "rgb"),
    "paddlergba": ("paddle", "float", "rgba"),
    "pill": ("pil", None, "l"),
    "pil": ("pil", None, "rgb"),
    "pilrgb": ("pil", None, "rgb"),
    "pilrgba": ("pil", None, "rgba"),
}

class ImageHandler:
    """Decode image data using the given `imagespec`.
    The `imagespec` specifies whether the image is decoded
    to numpy/paddle/pi, decoded to uint8/float, and decoded
    to l/rgb/rgba:
    - l8: numpy uint8 l
    - rgb8: numpy uint8 rgb
    - rgba8: numpy uint8 rgba
    - l: numpy float l
    - rgb: numpy float rgb
    - rgba: numpy float rgba
    - paddlel8: paddle uint8 l
    - paddlergb8: paddle uint8 rgb
    - paddlergba8: paddle uint8 rgba
    - paddlel: paddle float l
    - paddlergb: paddle float rgb
    - paddle: paddle float rgb
    - paddlergba: paddle float rgba
    - pill: pil None l
    - pil: pil None rgb
    - pilrgb: pil None rgb
    - pilrgba: pil None rgba
    """

    def __init__(self, imagespec, extensions=image_extensions):
        """Create an image handler.
        :param imagespec: short string indicating the type of decoding
        :param extensions: list of extensions the image handler is invoked for
        """
        if imagespec not in list(imagespecs.keys()):
            raise ValueError("Unknown imagespec: %s" % imagespec)
        self.imagespec = imagespec.lower()
        self.extensions = extensions

    def __call__(self, key, data):
        """Perform image decoding.
        :param key: file name extension
        :param data: binary data
        """
        import PIL.Image

        extension = re.sub(r".*[.]", "", key)
        if extension.lower() not in self.extensions:
            return None
        imagespec = self.imagespec
        atype, etype, mode = imagespecs[imagespec]
        with io.BytesIO(data) as stream:
            img = PIL.Image.open(stream)
            img.load()
            img = img.convert(mode.upper())
        if atype == "pil":
            return img
        elif atype == "numpy":
            result = np.asarray(img)
            if result.dtype != np.uint8:
                raise ValueError("ImageHandler: numpy image must be uint8")
            if etype == "uint8":
                return result
            else:
                return result.astype("f") / 255.0
        elif atype == "paddle":
            import paddle

            result = np.asarray(img)
            if result.dtype != np.uint8:
                raise ValueError("ImageHandler: paddle image must be uint8")
            if etype == "uint8":
                result = np.array(result.transpose(2, 0, 1))
                return paddle.tensor(result)
            else:
                result = np.array(result.transpose(2, 0, 1))
                return paddle.tensor(result) / 255.0
        return None

def imagehandler(imagespec, extensions=image_extensions):
    """Create an image handler.
    This is just a lower case alias for ImageHander.
    :param imagespec: textual image spec
    :param extensions: list of extensions the handler should be applied for
    """
    return ImageHandler(imagespec, extensions)

################################################################
# torch video
################################################################

'''
def torch_video(key, data):
    """Decode video using the torchvideo library.
    :param key: file name extension
    :param data: data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
        return None

    import torchvision.io

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
            return torchvision.io.read_video(fname, pts_unit="sec")
'''

################################################################
# paddlespeech.audio
################################################################

def paddle_audio(key, data):
    """Decode audio using the paddlespeech.audio library.
    :param key: file name extension
    :param data: data to be decoded
    """
    extension = re.sub(r".*[.]", "", key)
    if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
        return None

    import paddlespeech.audio

    with tempfile.TemporaryDirectory() as dirname:
        fname = os.path.join(dirname, f"file.{extension}")
        with open(fname, "wb") as stream:
            stream.write(data)
            return paddlespeech.audio.load(fname)

################################################################
# special class for continuing decoding
################################################################

class Continue:
    """Special class for continuing decoding.
    This is mostly used for decompression, as in:
        def decompressor(key, data):
            if key.endswith(".gz"):
                return Continue(key[:-3], decompress(data))
            return None
    """

    def __init__(self, key, data):
        """__init__.
        :param key:
        :param data:
        """
        self.key, self.data = key, data

def gzfilter(key, data):
    """Decode .gz files.
    This decodes compressed files and the continues decoding.
    :param key: file name extension
    :param data: binary data
    """
    import gzip

    if not key.endswith(".gz"):
        return None
    decompressed = gzip.open(io.BytesIO(data)).read()
    return Continue(key[:-3], decompressed)

################################################################
# decode entire training amples
################################################################

default_pre_handlers = [gzfilter]
default_post_handlers = [basichandlers]

class Decoder:
    """Decode samples using a list of handlers.
    For each key/data item, this iterates through the list of
    handlers until some handler returns something other than None.
    """

    def __init__(self, handlers, pre=None, post=None, only=None, partial=False):
        """Create a Decoder.
        :param handlers: main list of handlers
        :param pre: handlers called before the main list (.gz handler by default)
        :param post: handlers called after the main list (default handlers by default)
        :param only: a list of extensions; when give, only ignores files with those extensions
        :param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes)
        """
        if isinstance(only, str):
            only = only.split()
        self.only = only if only is None else set(only)
        if pre is None:
            pre = default_pre_handlers
        if post is None:
            post = default_post_handlers
        assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
        assert all(callable(h) for h in pre), f"one of {pre} not callable"
        assert all(callable(h) for h in post), f"one of {post} not callable"
        self.handlers = pre + handlers + post
        self.partial = partial

    def decode1(self, key, data):
        """Decode a single field of a sample.
        :param key: file name extension
        :param data: binary data
        """
        key = "." + key
        for f in self.handlers:
            result = f(key, data)
            if isinstance(result, Continue):
                key, data = result.key, result.data
                continue
            if result is not None:
                return result
        return data

    def decode(self, sample):
        """Decode an entire sample.
        :param sample: the sample, a dictionary of key value pairs
        """
        result = {}
        assert isinstance(sample, dict), sample
        for k, v in list(sample.items()):
            if k[0] == "_":
                if isinstance(v, bytes):
                    v = v.decode("utf-8")
                result[k] = v
                continue
            if self.only is not None and k not in self.only:
                result[k] = v
                continue
            assert v is not None
            if self.partial:
                if isinstance(v, bytes):
                    result[k] = self.decode1(k, v)
                else:
                    result[k] = v
            else:
                assert isinstance(v, bytes)
                result[k] = self.decode1(k, v)
        return result

    def __call__(self, sample):
        """Decode an entire sample.
        :param sample: the sample
        """
        assert isinstance(sample, dict), (len(sample), sample)
        return self.decode(sample)
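Decoder walks its handler list per field and stops at the first non-None result, which is what the docstring's handle_extension example relies on. A small, self-contained illustration (the sample dict below is synthetic; in the real pipeline samples come out of the tar iterator):

# Illustrative use of handle_extension + Decoder from the module above.
from paddlespeech.audio.streamdata.autodecode import Decoder, handle_extension

def upper_transcript(data: bytes) -> str:
    return data.decode("utf-8").upper()

decoder = Decoder([handle_extension("txt transcript", upper_transcript)])
sample = {"__key__": "utt_0001", "txt": b"hello wenetspeech"}
print(decoder(sample))  # {'__key__': 'utt_0001', 'txt': 'HELLO WENETSPEECH'}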
paddlespeech/audio/streamdata/cache.py (new file, mode 100644)

# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset

import itertools, os, random, re, sys
from urllib.parse import urlparse

from . import filters
from . import gopen
from .handlers import reraise_exception
from .tariterators import tar_file_and_group_expander

default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))

def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
    """Performs cleanup of the file cache in cache_dir using an LRU strategy,
    keeping the total size of all remaining files below cache_size."""
    if not os.path.exists(cache_dir):
        return
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            total_size += os.path.getsize(os.path.join(dirpath, filename))
    if total_size <= cache_size:
        return
    # sort files by last access time
    files = []
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            files.append(os.path.join(dirpath, filename))
    files.sort(key=keyfn, reverse=True)
    # delete files until we're under the cache size
    while len(files) > 0 and total_size > cache_size:
        fname = files.pop()
        total_size -= os.path.getsize(fname)
        if verbose:
            print("# deleting %s" % fname, file=sys.stderr)
        os.remove(fname)

def download(url, dest, chunk_size=1024 ** 2, verbose=False):
    """Download a file from `url` to `dest`."""
    temp = dest + f".temp{os.getpid()}"
    with gopen.gopen(url) as stream:
        with open(temp, "wb") as f:
            while True:
                data = stream.read(chunk_size)
                if not data:
                    break
                f.write(data)
    os.rename(temp, dest)

def pipe_cleaner(spec):
    """Guess the actual URL from a "pipe:" specification."""
    if spec.startswith("pipe:"):
        spec = spec[5:]
        words = spec.split(" ")
        for word in words:
            if re.match(r"^(https?|gs|ais|s3)", word):
                return word
    return spec

def get_file_cached(
    spec, cache_size=-1, cache_dir=None, url_to_name=pipe_cleaner, verbose=False,
):
    if cache_size == -1:
        cache_size = default_cache_size
    if cache_dir is None:
        cache_dir = default_cache_dir
    url = url_to_name(spec)
    parsed = urlparse(url)
    dirname, filename = os.path.split(parsed.path)
    dirname = dirname.lstrip("/")
    dirname = re.sub(r"[:/|;]", "_", dirname)
    destdir = os.path.join(cache_dir, dirname)
    os.makedirs(destdir, exist_ok=True)
    dest = os.path.join(cache_dir, dirname, filename)
    if not os.path.exists(dest):
        if verbose:
            print("# downloading %s to %s" % (url, dest), file=sys.stderr)
        lru_cleanup(cache_dir, cache_size, verbose=verbose)
        download(spec, dest, verbose=verbose)
    return dest

def get_filetype(fname):
    with os.popen("file '%s'" % fname) as f:
        ftype = f.read()
    return ftype

def check_tar_format(fname):
    """Check whether a file is a tar archive."""
    ftype = get_filetype(fname)
    return "tar archive" in ftype or "gzip compressed" in ftype

verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))

def cached_url_opener(
    data,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    validator=check_tar_format,
    verbose=False,
    always=False,
):
    """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
    verbose = verbose or verbose_cache
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        attempts = 5
        try:
            if not always and os.path.exists(url):
                dest = url
            else:
                dest = get_file_cached(
                    url,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                    url_to_name=url_to_name,
                    verbose=verbose,
                )
            if verbose:
                print("# opening %s" % dest, file=sys.stderr)
            assert os.path.exists(dest)
            if not validator(dest):
                ftype = get_filetype(dest)
                with open(dest, "rb") as f:
                    data = f.read(200)
                os.remove(dest)
                raise ValueError(
                    "%s (%s) is not a tar archive, but a %s, contains %s"
                    % (dest, url, ftype, repr(data))
                )
            try:
                stream = open(dest, "rb")
                sample.update(stream=stream)
                yield sample
            except FileNotFoundError as exn:
                # dealing with race conditions in lru_cleanup
                attempts -= 1
                if attempts > 0:
                    time.sleep(random.random() * 10)
                    continue
                raise exn
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break

def cached_tarfile_samples(
    src,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    verbose=False,
    url_to_name=pipe_cleaner,
    always=False,
):
    streams = cached_url_opener(
        src,
        handler=handler,
        cache_size=cache_size,
        cache_dir=cache_dir,
        verbose=verbose,
        url_to_name=url_to_name,
        always=always,
    )
    samples = tar_file_and_group_expander(streams, handler=handler)
    return samples

cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
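pipe_cleaner only inspects the spec string to recover a cache key, while lru_cleanup trims an on-disk cache directory to a size budget. A quick sketch of both helpers (the URL and cache directory below are hypothetical; lru_cleanup is a no-op if the directory does not exist):

# Quick look at the caching helpers defined above.
from paddlespeech.audio.streamdata.cache import pipe_cleaner, lru_cleanup

spec = "pipe:curl -s -L https://example.com/shards/train-000000.tar"
print(pipe_cleaner(spec))  # -> https://example.com/shards/train-000000.tar

# Keep a local shard cache below roughly 10 GB.
lru_cleanup("./_cache", cache_size=10e9, verbose=True)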
paddlespeech/audio/streamdata/compat.py (new file, mode 100644)

# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset

from dataclasses import dataclass
from itertools import islice
from typing import List

import braceexpand, yaml

from . import autodecode
from . import cache, filters, shardlists, tariterators
from .filters import reraise_exception
from .pipeline import DataPipeline
from .paddle_utils import DataLoader, IterableDataset

class FluidInterface:
    def batched(self, batchsize):
        return self.compose(filters.batched(batchsize))

    def dynamic_batched(self, max_frames_in_batch):
        return self.compose(filter.dynamic_batched(max_frames_in_batch))

    def unbatched(self):
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None)

    def unlisted(self):
        return self.compose(filters.unlisted())

    def log_keys(self, logfile=None):
        return self.compose(filters.log_keys(logfile))

    def shuffle(self, size, **kw):
        if size < 1:
            return self
        else:
            return self.compose(filters.shuffle(size, **kw))

    def map(self, f, handler=reraise_exception):
        return self.compose(filters.map(f, handler=handler))

    def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
        handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
        decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
        return self.map(decoder, handler=handler)

    def map_dict(self, handler=reraise_exception, **kw):
        return self.compose(filters.map_dict(handler=handler, **kw))

    def select(self, predicate, **kw):
        return self.compose(filters.select(predicate, **kw))

    def to_tuple(self, *args, handler=reraise_exception):
        return self.compose(filters.to_tuple(*args, handler=handler))

    def map_tuple(self, *args, handler=reraise_exception):
        return self.compose(filters.map_tuple(*args, handler=handler))

    def slice(self, *args):
        return self.compose(filters.slice(*args))

    def rename(self, **kw):
        return self.compose(filters.rename(**kw))

    def rsample(self, p=0.5):
        return self.compose(filters.rsample(p))

    def rename_keys(self, *args, **kw):
        return self.compose(filters.rename_keys(*args, **kw))

    def extract_keys(self, *args, **kw):
        return self.compose(filters.extract_keys(*args, **kw))

    def xdecode(self, *args, **kw):
        return self.compose(filters.xdecode(*args, **kw))

    def audio_data_filter(self, *args, **kw):
        return self.compose(filters.audio_data_filter(*args, **kw))

    def audio_tokenize(self, *args, **kw):
        return self.compose(filters.audio_tokenize(*args, **kw))

    def resample(self, *args, **kw):
        return self.compose(filters.resample(*args, **kw))

    def audio_compute_fbank(self, *args, **kw):
        return self.compose(filters.audio_compute_fbank(*args, **kw))

    def audio_spec_aug(self, *args, **kw):
        return self.compose(filters.audio_spec_aug(*args, **kw))

    def sort(self, size=500):
        return self.compose(filters.sort(size))

    def audio_padding(self):
        return self.compose(filters.audio_padding())

    def audio_cmvn(self, cmvn_file):
        return self.compose(filters.audio_cmvn(cmvn_file))

class WebDataset(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(
        self,
        urls,
        handler=reraise_exception,
        resampled=False,
        repeat=False,
        shardshuffle=None,
        cache_size=0,
        cache_dir=None,
        detshuffle=False,
        nodesplitter=shardlists.single_node_only,
        verbose=False,
    ):
        super().__init__()
        if isinstance(urls, IterableDataset):
            assert not resampled
            self.append(urls)
        elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
            with (open(urls)) as stream:
                spec = yaml.safe_load(stream)
            assert "datasets" in spec
            self.append(shardlists.MultiShardSample(spec))
        elif isinstance(urls, dict):
            assert "datasets" in urls
            self.append(shardlists.MultiShardSample(urls))
        elif resampled:
            self.append(shardlists.ResampledShards(urls))
        else:
            self.append(shardlists.SimpleShardList(urls))
            self.append(nodesplitter)
            self.append(shardlists.split_by_worker)
            if shardshuffle is True:
                shardshuffle = 100
            if shardshuffle is not None:
                if detshuffle:
                    self.append(filters.detshuffle(shardshuffle))
                else:
                    self.append(filters.shuffle(shardshuffle))
        if cache_size == 0:
            self.append(tariterators.tarfile_to_samples(handler=handler))
        else:
            assert cache_size == -1 or cache_size > 0
            self.append(
                cache.cached_tarfile_to_samples(
                    handler=handler,
                    verbose=verbose,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                )
            )

class FluidWrapper(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(self, initial):
        super().__init__()
        self.append(initial)

class WebLoader(DataPipeline, FluidInterface):
    def __init__(self, *args, **kw):
        super().__init__(DataLoader(*args, **kw))
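WebDataset chains the shard list, node/worker splitting, optional shard shuffling, and tar expansion, and FluidInterface layers the chainable stages (including the audio_* ones) on top. A minimal end-to-end sketch under assumptions: the shard pattern is hypothetical, brace expansion and default tar grouping are assumed to behave as in upstream webdataset, and each sample is assumed to carry a .txt transcript. The actual training wiring lives in paddlespeech/s2t/io/dataloader.py, driven by conformer.yaml, not in this snippet.

# Sketch of the fluid interface added in compat.py.
import paddlespeech.audio.streamdata as streamdata

dataset = (
    streamdata.WebDataset("shards/train_l-{000000..000009}.tar", shardshuffle=True)
    .shuffle(500)                 # sample-level shuffle buffer
    .decode()                     # default handlers: gz + basichandlers (txt/json/npy/...)
    .to_tuple("__key__", "txt")   # assumes each sample has a .txt entry
)
for key, transcript in dataset:
    print(key, transcript)
    break

Further stages such as .audio_compute_fbank(...), .sort(...), .audio_padding() and .batched(...) compose the same way, taking their arguments from the config.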
paddlespeech/audio/streamdata/extradatasets.py (new file, mode 100644)

#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""

import itertools as itt
import os
import random
import sys

import braceexpand

from . import utils
from .paddle_utils import IterableDataset
from .utils import PipelineStage

class MockDataset(IterableDataset):
    """MockDataset.
    A mock dataset for performance testing and unit testing.
    """

    def __init__(self, sample, length):
        """Create a mock dataset instance.
        :param sample: the sample to be returned repeatedly
        :param length: the length of the mock dataset
        """
        self.sample = sample
        self.length = length

    def __iter__(self):
        """Return an iterator over this mock dataset."""
        for i in range(self.length):
            yield self.sample

class repeatedly(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset."""

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        """Create an instance of Repeatedly.
        :param nepochs: repeat for a maximum of nepochs
        :param nbatches: repeat for a maximum of nbatches
        """
        self.source = source
        self.length = length
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source."""
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )

class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.
    This will continuously iterate through the original dataset, but
    impose new epoch boundaries at the given length/nominal.
    This exists mainly as a workaround for the odd logic in DataLoader.
    It is also useful for choosing smaller nominal epoch sizes with
    very large datasets.
    """

    def __init__(self, dataset, length):
        """Chop the dataset to the given length.
        :param dataset: IterableDataset
        :param length: declared length of the dataset
        :param nominal: nominal length of dataset (if different from declared)
        """
        super().__init__()
        self.length = length
        self.source = None

    def __getstate__(self):
        """Return the pickled state of the dataset.
        This resets the dataset iterator, since that can't be pickled.
        """
        result = dict(self.__dict__)
        result["source"] = None
        return result

    def invoke(self, dataset):
        """Return an iterator over the dataset.
        This iterator returns as many samples as given by the `length`
        parameter.
        """
        if self.source is None:
            self.source = iter(dataset)
        for i in range(self.length):
            try:
                sample = next(self.source)
            except StopIteration:
                self.source = iter(dataset)
                try:
                    sample = next(self.source)
                except StopIteration:
                    return
            yield sample
        self.source = None

class with_length(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a dataset."""

    def __init__(self, dataset, length):
        """Create an instance of Repeatedly.
        :param dataset: source dataset
        :param length: stated length
        """
        super().__init__()
        self.dataset = dataset
        self.length = length

    def invoke(self, dataset):
        """Return an iterator that iterates repeatedly over a source."""
        return iter(dataset)

    def __len__(self):
        """Return the user specified length."""
        return self.length
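MockDataset is the easiest way to smoke-test a pipeline without real shards; with_epoch and with_length are pipeline stages (they expose invoke(), not plain iteration) and are meant to be appended inside a DataPipeline. The sample dict below is arbitrary:

# Tiny smoke test using MockDataset from the module above.
from paddlespeech.audio.streamdata.extradatasets import MockDataset

fake = MockDataset({"__key__": "utt_0", "feat_len": 123}, length=4)
print(sum(1 for _ in fake))  # 4 identical samples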
paddlespeech/audio/streamdata/filters.py
0 → 100644
浏览文件 @
d1a25f6c
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""A collection of iterators for data transformations.
These functions are plain iterator functions. You can find curried versions
in webdataset.filters, and you can find IterableDataset wrappers in
webdataset.processing.
"""
import
io
from
fnmatch
import
fnmatch
import
re
import
itertools
,
os
,
random
,
sys
,
time
from
functools
import
reduce
,
wraps
import
numpy
as
np
from
.
import
autodecode
from
.
import
utils
from
.paddle_utils
import
PaddleTensor
from
.utils
import
PipelineStage
from
..
import
backends
from
..compliance
import
kaldi
import
paddle
from
..transform.cmvn
import
GlobalCMVN
from
..utils.tensor_utils
import
pad_sequence
from
..transform.spec_augment
import
time_warp
from
..transform.spec_augment
import
time_mask
from
..transform.spec_augment
import
freq_mask
class
FilterFunction
(
object
):
"""Helper class for currying pipeline stages.
We use this roundabout construct becauce it can be pickled.
"""
def
__init__
(
self
,
f
,
*
args
,
**
kw
):
"""Create a curried function."""
self
.
f
=
f
self
.
args
=
args
self
.
kw
=
kw
def
__call__
(
self
,
data
):
"""Call the curried function with the given argument."""
return
self
.
f
(
data
,
*
self
.
args
,
**
self
.
kw
)
def
__str__
(
self
):
"""Compute a string representation."""
return
f
"<
{
self
.
f
.
__name__
}
{
self
.
args
}
{
self
.
kw
}
>"
def
__repr__
(
self
):
"""Compute a string representation."""
return
f
"<
{
self
.
f
.
__name__
}
{
self
.
args
}
{
self
.
kw
}
>"
class
RestCurried
(
object
):
"""Helper class for currying pipeline stages.
We use this roundabout construct because it can be pickled.
"""
def
__init__
(
self
,
f
):
"""Store the function for future currying."""
self
.
f
=
f
def
__call__
(
self
,
*
args
,
**
kw
):
"""Curry with the given arguments."""
return
FilterFunction
(
self
.
f
,
*
args
,
**
kw
)
def
pipelinefilter
(
f
):
"""Turn the decorated function into one that is partially applied for
all arguments other than the first."""
result
=
RestCurried
(
f
)
return
result
def
reraise_exception
(
exn
):
"""Reraises the given exception; used as a handler.
:param exn: exception
"""
raise
exn
def
identity
(
x
):
"""Return the argument."""
return
x
def
compose2
(
f
,
g
):
"""Compose two functions, g(f(x))."""
return
lambda
x
:
g
(
f
(
x
))
def
compose
(
*
args
):
"""Compose a sequence of functions (left-to-right)."""
return
reduce
(
compose2
,
args
)
def
pipeline
(
source
,
*
args
):
"""Write an input pipeline; first argument is source, rest are filters."""
if
len
(
args
)
==
0
:
return
source
return
compose
(
*
args
)(
source
)
def
getfirst
(
a
,
keys
,
default
=
None
,
missing_is_error
=
True
):
"""Get the first matching key from a dictionary.
Keys can be specified as a list, or as a string of keys separated by ';'.
"""
if
isinstance
(
keys
,
str
):
assert
" "
not
in
keys
keys
=
keys
.
split
(
";"
)
for
k
in
keys
:
if
k
in
a
:
return
a
[
k
]
if
missing_is_error
:
raise
ValueError
(
f
"didn't find
{
keys
}
in
{
list
(
a
.
keys
())
}
"
)
return
default
def
parse_field_spec
(
fields
):
"""Parse a specification for a list of fields to be extracted.
Keys are separated by spaces in the spec. Each key can itself
be composed of key alternatives separated by ';'.
"""
if
isinstance
(
fields
,
str
):
fields
=
fields
.
split
()
return
[
field
.
split
(
";"
)
for
field
in
fields
]
def
transform_with
(
sample
,
transformers
):
"""Transform a list of values using a list of functions.
sample: list of values
transformers: list of functions
If there are fewer transformers than inputs, or if a transformer
function is None, then the identity function is used for the
corresponding sample fields.
"""
if
transformers
is
None
or
len
(
transformers
)
==
0
:
return
sample
result
=
list
(
sample
)
assert
len
(
transformers
)
<=
len
(
sample
)
for
i
in
range
(
len
(
transformers
)):
# skipcq: PYL-C0200
f
=
transformers
[
i
]
if
f
is
not
None
:
result
[
i
]
=
f
(
sample
[
i
])
return
result
###
# Iterators
###
def
_info
(
data
,
fmt
=
None
,
n
=
3
,
every
=-
1
,
width
=
50
,
stream
=
sys
.
stderr
,
name
=
""
):
"""Print information about the samples that are passing through.
:param data: source iterator
:param fmt: format statement (using sample dict as keyword)
:param n: when to stop
:param every: how often to print
:param width: maximum width
:param stream: output stream
:param name: identifier printed before any output
"""
for
i
,
sample
in
enumerate
(
data
):
if
i
<
n
or
(
every
>
0
and
(
i
+
1
)
%
every
==
0
):
if
fmt
is
None
:
print
(
"---"
,
name
,
file
=
stream
)
for
k
,
v
in
sample
.
items
():
print
(
k
,
repr
(
v
)[:
width
],
file
=
stream
)
else
:
print
(
fmt
.
format
(
**
sample
),
file
=
stream
)
yield
sample
info
=
pipelinefilter
(
_info
)
def
pick
(
buf
,
rng
):
k
=
rng
.
randint
(
0
,
len
(
buf
)
-
1
)
sample
=
buf
[
k
]
buf
[
k
]
=
buf
[
-
1
]
buf
.
pop
()
return
sample
def
_shuffle
(
data
,
bufsize
=
1000
,
initial
=
100
,
rng
=
None
,
handler
=
None
):
"""Shuffle the data in the stream.
This uses a buffer of size `bufsize`. Shuffling at
startup is less random; this is traded off against
yielding samples quickly.
data: iterator
bufsize: buffer size for shuffling
returns: iterator
rng: either random module or random.Random instance
"""
if
rng
is
None
:
rng
=
random
.
Random
(
int
((
os
.
getpid
()
+
time
.
time
())
*
1e9
))
initial
=
min
(
initial
,
bufsize
)
buf
=
[]
for
sample
in
data
:
buf
.
append
(
sample
)
if
len
(
buf
)
<
bufsize
:
try
:
buf
.
append
(
next
(
data
))
# skipcq: PYL-R1708
except
StopIteration
:
pass
if
len
(
buf
)
>=
initial
:
yield
pick
(
buf
,
rng
)
while
len
(
buf
)
>
0
:
yield
pick
(
buf
,
rng
)
shuffle
=
pipelinefilter
(
_shuffle
)
class
detshuffle
(
PipelineStage
):
def
__init__
(
self
,
bufsize
=
1000
,
initial
=
100
,
seed
=
0
,
epoch
=-
1
):
self
.
bufsize
=
bufsize
self
.
initial
=
initial
self
.
seed
=
seed
self
.
epoch
=
epoch
def
run
(
self
,
src
):
self
.
epoch
+=
1
rng
=
random
.
Random
()
rng
.
seed
((
self
.
seed
,
self
.
epoch
))
return
_shuffle
(
src
,
self
.
bufsize
,
self
.
initial
,
rng
)
def
_select
(
data
,
predicate
):
"""Select samples based on a predicate.
:param data: source iterator
:param predicate: predicate (function)
"""
for
sample
in
data
:
if
predicate
(
sample
):
yield
sample
select
=
pipelinefilter
(
_select
)
def
_log_keys
(
data
,
logfile
=
None
):
import
fcntl
if
logfile
is
None
or
logfile
==
""
:
for
sample
in
data
:
yield
sample
else
:
with
open
(
logfile
,
"a"
)
as
stream
:
for
i
,
sample
in
enumerate
(
data
):
buf
=
f
"
{
i
}
\t
{
sample
.
get
(
'__worker__'
)
}
\t
{
sample
.
get
(
'__rank__'
)
}
\t
{
sample
.
get
(
'__key__'
)
}
\n
"
try
:
fcntl
.
flock
(
stream
.
fileno
(),
fcntl
.
LOCK_EX
)
stream
.
write
(
buf
)
finally
:
fcntl
.
flock
(
stream
.
fileno
(),
fcntl
.
LOCK_UN
)
yield
sample
log_keys
=
pipelinefilter
(
_log_keys
)
def
_decode
(
data
,
*
args
,
handler
=
reraise_exception
,
**
kw
):
"""Decode data based on the decoding functions given as arguments."""
decoder
=
lambda
x
:
autodecode
.
imagehandler
(
x
)
if
isinstance
(
x
,
str
)
else
x
handlers
=
[
decoder
(
x
)
for
x
in
args
]
f
=
autodecode
.
Decoder
(
handlers
,
**
kw
)
for
sample
in
data
:
assert
isinstance
(
sample
,
dict
),
sample
try
:
decoded
=
f
(
sample
)
except
Exception
as
exn
:
# skipcq: PYL-W0703
if
handler
(
exn
):
continue
else
:
break
yield
decoded
decode
=
pipelinefilter
(
_decode
)
def
_map
(
data
,
f
,
handler
=
reraise_exception
):
"""Map samples."""
for
sample
in
data
:
try
:
result
=
f
(
sample
)
except
Exception
as
exn
:
if
handler
(
exn
):
continue
else
:
break
if
result
is
None
:
continue
if
isinstance
(
sample
,
dict
)
and
isinstance
(
result
,
dict
):
result
[
"__key__"
]
=
sample
.
get
(
"__key__"
)
yield
result
map
=
pipelinefilter
(
_map
)
def
_rename
(
data
,
handler
=
reraise_exception
,
keep
=
True
,
**
kw
):
"""Rename samples based on keyword arguments."""
for
sample
in
data
:
try
:
if
not
keep
:
yield
{
k
:
getfirst
(
sample
,
v
,
missing_is_error
=
True
)
for
k
,
v
in
kw
.
items
()}
else
:
def
listify
(
v
):
return
v
.
split
(
";"
)
if
isinstance
(
v
,
str
)
else
v
to_be_replaced
=
{
x
for
v
in
kw
.
values
()
for
x
in
listify
(
v
)}
result
=
{
k
:
v
for
k
,
v
in
sample
.
items
()
if
k
not
in to_be_replaced}
            result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()})
            yield result
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break

rename = pipelinefilter(_rename)

def _associate(data, associator, **kw):
    """Associate additional data with samples."""
    for sample in data:
        if callable(associator):
            extra = associator(sample["__key__"])
        else:
            extra = associator.get(sample["__key__"], {})
        sample.update(extra)  # destructive
        yield sample

associate = pipelinefilter(_associate)

def _map_dict(data, handler=reraise_exception, **kw):
    """Map the entries in a dict sample with individual functions."""
    assert len(list(kw.keys())) > 0
    for key, f in kw.items():
        assert callable(f), (key, f)

    for sample in data:
        assert isinstance(sample, dict)
        try:
            for k, f in kw.items():
                sample[k] = f(sample[k])
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
        yield sample

map_dict = pipelinefilter(_map_dict)

def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None):
    """Convert dict samples to tuples."""
    if none_is_error is None:
        none_is_error = missing_is_error
    if len(args) == 1 and isinstance(args[0], str) and " " in args[0]:
        args = args[0].split()

    for sample in data:
        try:
            result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args])
            if none_is_error and any(x is None for x in result):
                raise ValueError(f"to_tuple {args} got {sample.keys()}")
            yield result
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break

to_tuple = pipelinefilter(_to_tuple)

def _map_tuple(data, *args, handler=reraise_exception):
    """Map the entries of a tuple with individual functions."""
    args = [f if f is not None else utils.identity for f in args]
    for f in args:
        assert callable(f), f
    for sample in data:
        assert isinstance(sample, (list, tuple))
        sample = list(sample)
        n = min(len(args), len(sample))
        try:
            for i in range(n):
                sample[i] = args[i](sample[i])
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
        yield tuple(sample)

map_tuple = pipelinefilter(_map_tuple)

def _unlisted(data):
    """Turn batched data back into unbatched data."""
    for batch in data:
        assert isinstance(batch, list), sample
        for sample in batch:
            yield sample

unlisted = pipelinefilter(_unlisted)

def _unbatched(data):
    """Turn batched data back into unbatched data."""
    for sample in data:
        assert isinstance(sample, (tuple, list)), sample
        assert len(sample) > 0
        for i in range(len(sample[0])):
            yield tuple(x[i] for x in sample)

unbatched = pipelinefilter(_unbatched)

def _rsample(data, p=0.5):
    """Randomly subsample a stream of data."""
    assert p >= 0.0 and p <= 1.0
    for sample in data:
        if random.uniform(0.0, 1.0) < p:
            yield sample

rsample = pipelinefilter(_rsample)

slice = pipelinefilter(itertools.islice)

def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False):
    for sample in source:
        result = []
        for pattern in patterns:
            pattern = pattern.split(";") if isinstance(pattern, str) else pattern
            matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)]
            if len(matches) == 0:
                if ignore_missing:
                    continue
                else:
                    raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.")
            if len(matches) > 1 and duplicate_is_error:
                raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.")
            value = sample[matches[0]]
            result.append(value)
        yield tuple(result)

extract_keys = pipelinefilter(_extract_keys)

def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw):
    renamings = [(pattern, output) for output, pattern in args]
    renamings += [(pattern, output) for output, pattern in kw.items()]
    for sample in source:
        new_sample = {}
        matched = {k: False for k, _ in renamings}
        for path, value in sample.items():
            fname = re.sub(r".*/", "", path)
            new_name = None
            for pattern, name in renamings[::-1]:
                if fnmatch(fname.lower(), pattern):
                    matched[pattern] = True
                    new_name = name
                    break
            if new_name is None:
                if keep_unselected:
                    new_sample[path] = value
                continue
            if new_name in new_sample:
                if duplicate_is_error:
                    raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.")
                continue
            new_sample[new_name] = value
        if must_match and not all(matched.values()):
            raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).")

        yield new_sample

rename_keys = pipelinefilter(_rename_keys)

def decode_bin(stream):
    return stream.read()

def decode_text(stream):
    binary = stream.read()
    return binary.decode("utf-8")

def decode_pickle(stream):
    return pickle.load(stream)

default_decoders = [
    ("*.bin", decode_bin),
    ("*.txt", decode_text),
    ("*.pyd", decode_pickle),
]

def find_decoder(decoders, path):
    fname = re.sub(r".*/", "", path)
    if fname.startswith("__"):
        return lambda x: x
    for pattern, fun in decoders[::-1]:
        if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern):
            return fun
    return None

def _xdecode(source, *args, must_decode=True, defaults=default_decoders, **kw):
    decoders = list(defaults) + list(args)
    decoders += [("*." + k, v) for k, v in kw.items()]
    for sample in source:
        new_sample = {}
        for path, data in sample.items():
            if path.startswith("__"):
                new_sample[path] = data
                continue
            decoder = find_decoder(decoders, path)
            if decoder is False:
                value = data
            elif decoder is None:
                if must_decode:
                    raise ValueError(f"No decoder found for {path}.")
                value = data
            else:
                if isinstance(data, bytes):
                    data = io.BytesIO(data)
                value = decoder(data)
            new_sample[path] = value
        yield new_sample

xdecode = pipelinefilter(_xdecode)
def _audio_data_filter(source,
                       frame_shift=10,
                       max_length=10240,
                       min_length=10,
                       token_max_length=200,
                       token_min_length=1,
                       min_output_input_ratio=0.0005,
                       max_output_input_ratio=1):
    """ Filter sample according to feature and label length.
        Inplace operation.

        Args:
            source: Iterable[{fname, wav, label, sample_rate}]
            frame_shift: length of frame shift (ms)
            max_length: drop utterance which is greater than max_length (10ms)
            min_length: drop utterance which is less than min_length (10ms)
            token_max_length: drop utterance which is greater than
                token_max_length, especially when use char unit for
                english modeling
            token_min_length: drop utterance which is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms)

        Returns:
            Iterable[{fname, wav, label, sample_rate}]
    """
    for sample in source:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'label' in sample
        # sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
        num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift)
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['label']) < token_min_length:
            continue
        if len(sample['label']) > token_max_length:
            continue
        if num_frames != 0:
            if len(sample['label']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['label']) / num_frames > max_output_input_ratio:
                continue
        yield sample

audio_data_filter = pipelinefilter(_audio_data_filter)

def _audio_tokenize(source,
                    symbol_table,
                    bpe_model=None,
                    non_lang_syms=None,
                    split_with_space=False):
    """ Decode text to chars or BPE.
        Inplace operation.

        Args:
            source: Iterable[{fname, wav, txt, sample_rate}]

        Returns:
            Iterable[{fname, wav, txt, tokens, label, sample_rate}]
    """
    if non_lang_syms is not None:
        non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
    else:
        non_lang_syms = {}
        non_lang_syms_pattern = None

    if bpe_model is not None:
        import sentencepiece as spm
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)
    else:
        sp = None

    for sample in source:
        assert 'txt' in sample
        txt = sample['txt'].strip()
        if non_lang_syms_pattern is not None:
            parts = non_lang_syms_pattern.split(txt.upper())
            parts = [w for w in parts if len(w.strip()) > 0]
        else:
            parts = [txt]

        label = []
        tokens = []
        for part in parts:
            if part in non_lang_syms:
                tokens.append(part)
            else:
                if bpe_model is not None:
                    tokens.extend(__tokenize_by_bpe_model(sp, part))
                else:
                    if split_with_space:
                        part = part.split(" ")
                    for ch in part:
                        if ch == ' ':
                            ch = "<space>"
                        tokens.append(ch)

        for ch in tokens:
            if ch in symbol_table:
                label.append(symbol_table[ch])
            elif '<unk>' in symbol_table:
                label.append(symbol_table['<unk>'])

        sample['tokens'] = tokens
        sample['label'] = label
        yield sample

audio_tokenize = pipelinefilter(_audio_tokenize)

def _audio_resample(source, resample_rate=16000):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{fname, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{fname, wav, label, sample_rate}]
    """
    for sample in source:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        if sample_rate != resample_rate:
            sample['sample_rate'] = resample_rate
            sample['wav'] = paddle.to_tensor(
                backends.soundfile_backend.resample(
                    waveform.numpy(), src_sr=sample_rate, target_sr=resample_rate))
        yield sample

audio_resample = pipelinefilter(_audio_resample)

def _audio_compute_fbank(source,
                         num_mel_bins=80,
                         frame_length=25,
                         frame_shift=10,
                         dither=0.0):
    """ Extract fbank.

        Args:
            source: Iterable[{fname, wav, label, sample_rate}]
            num_mel_bins: number of mel filter bank
            frame_length: length of one frame (ms)
            frame_shift: length of frame shift (ms)
            dither: value of dither

        Returns:
            Iterable[{fname, feat, label}]
    """
    for sample in source:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'fname' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        waveform = waveform * (1 << 15)
        # Only keep fname, feat, label
        mat = kaldi.fbank(waveform,
                          n_mels=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=dither,
                          energy_floor=0.0,
                          sr=sample_rate)
        yield dict(fname=sample['fname'], label=sample['label'], feat=mat)

audio_compute_fbank = pipelinefilter(_audio_compute_fbank)

def _audio_spec_aug(source,
                    max_w=5,
                    w_inplace=True,
                    w_mode="PIL",
                    max_f=30,
                    num_f_mask=2,
                    f_inplace=True,
                    f_replace_with_zero=False,
                    max_t=40,
                    num_t_mask=2,
                    t_inplace=True,
                    t_replace_with_zero=False):
    """ Do spec augmentation.
        Inplace operation.

        Args:
            source: Iterable[{fname, feat, label}]
            max_w: max width of time warp
            w_inplace: whether to inplace the original data while time warping
            w_mode: time warp mode
            max_f: max width of freq mask
            num_f_mask: number of freq mask to apply
            f_inplace: whether to inplace the original data while frequency masking
            f_replace_with_zero: use zero to mask
            max_t: max width of time mask
            num_t_mask: number of time mask to apply
            t_inplace: whether to inplace the original data while time masking
            t_replace_with_zero: use zero to mask

        Returns:
            Iterable[{fname, feat, label}]
    """
    for sample in source:
        x = sample['feat']
        x = x.numpy()
        x = time_warp(x, max_time_warp=max_w, inplace=w_inplace, mode=w_mode)
        x = freq_mask(x, F=max_f, n_mask=num_f_mask,
                      inplace=f_inplace, replace_with_zero=f_replace_with_zero)
        x = time_mask(x, T=max_t, n_mask=num_t_mask,
                      inplace=t_inplace, replace_with_zero=t_replace_with_zero)
        sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
        yield sample

audio_spec_aug = pipelinefilter(_audio_spec_aug)

def _sort(source, sort_size=500):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`.

        Args:
            source: Iterable[{fname, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{fname, feat, label}]
    """
    buf = []
    for sample in source:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['feat'].shape[0])
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['feat'].shape[0])
    for x in buf:
        yield x

sort = pipelinefilter(_sort)

def _batched(source, batch_size=16):
    """ Static batch the data by `batch_size`.

        Args:
            data: Iterable[{fname, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{fname, feat, label}]]
    """
    buf = []
    for sample in source:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf

batched = pipelinefilter(_batched)

def dynamic_batched(source, max_frames_in_batch=12000):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`.

        Args:
            source: Iterable[{fname, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{fname, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in source:
        assert 'feat' in sample
        assert isinstance(sample['feat'], paddle.Tensor)
        # paddle.Tensor exposes shape, not a callable size(); use shape[0] for frame count
        new_sample_frames = sample['feat'].shape[0]
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf

def _audio_padding(source):
    """ Padding the data into training data.

        Args:
            source: Iterable[List[{fname, feat, label}]]

        Returns:
            Iterable[Tuple(fname, feats, labels, feats lengths, label lengths)]
    """
    for sample in source:
        assert isinstance(sample, list)
        feats_length = paddle.to_tensor(
            [x['feat'].shape[0] for x in sample], dtype="int64")
        order = paddle.argsort(feats_length, descending=True)
        feats_lengths = paddle.to_tensor(
            [sample[i]['feat'].shape[0] for i in order], dtype="int64")
        sorted_feats = [sample[i]['feat'] for i in order]
        sorted_keys = [sample[i]['fname'] for i in order]
        sorted_labels = [
            paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
        ]
        label_lengths = paddle.to_tensor(
            [x.shape[0] for x in sorted_labels], dtype="int64")
        padded_feats = pad_sequence(sorted_feats, batch_first=True, padding_value=0)
        padding_labels = pad_sequence(sorted_labels, batch_first=True, padding_value=-1)

        yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
               label_lengths)

audio_padding = pipelinefilter(_audio_padding)

def _audio_cmvn(source, cmvn_file):
    global_cmvn = GlobalCMVN(cmvn_file)
    for batch in source:
        sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch
        padded_feats = padded_feats.numpy()
        padded_feats = global_cmvn(padded_feats)
        padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32)
        yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
               label_lengths)

audio_cmvn = pipelinefilter(_audio_cmvn)

def _placeholder(source):
    for data in source:
        yield data

placeholder = pipelinefilter(_placeholder)
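A rough sketch of how these audio stages are meant to be chained, generator to generator. `raw_samples` (dicts with fname/wav/txt/sample_rate, e.g. from `tarfile_to_samples` defined further below) and `symbol_table` are placeholders; the exact stage order is illustrative, not prescribed by this file.

```python
# Illustrative chaining of the audio filters above; inputs are placeholders.
samples = _audio_tokenize(raw_samples, symbol_table)           # adds tokens/label
samples = _audio_resample(samples, resample_rate=16000)
samples = _audio_data_filter(samples, max_length=10240, min_length=10)
samples = _audio_compute_fbank(samples, num_mel_bins=80, frame_shift=10)
samples = _sort(samples, sort_size=500)
batches = _batched(samples, batch_size=16)
padded = _audio_padding(batches)
for keys, feats, feat_lens, labels, label_lens in padded:
    pass  # feed one padded batch to the model
```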
paddlespeech/audio/streamdata/gopen.py
0 → 100644
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Open URLs by calling subcommands."""
import os, sys, re
from subprocess import PIPE, Popen
from urllib.parse import urlparse

# global used for printing additional node information during verbose output
info = {}

class Pipe:
    """Wrapper class for subprocess.Pipe.

    This class looks like a stream from the outside, but it checks
    subprocess status and handles timeouts with exceptions.
    This way, clients of the class do not need to know that they are
    dealing with subprocesses.

    :param *args: passed to `subprocess.Pipe`
    :param **kw: passed to `subprocess.Pipe`
    :param timeout: timeout for closing/waiting
    :param ignore_errors: don't raise exceptions on subprocess errors
    :param ignore_status: list of status codes to ignore
    """

    def __init__(
        self,
        *args,
        mode=None,
        timeout=7200.0,
        ignore_errors=False,
        ignore_status=[],
        **kw,
    ):
        """Create an IO Pipe."""
        self.ignore_errors = ignore_errors
        self.ignore_status = [0] + ignore_status
        self.timeout = timeout
        self.args = (args, kw)
        if mode[0] == "r":
            self.proc = Popen(*args, stdout=PIPE, **kw)
            self.stream = self.proc.stdout
            if self.stream is None:
                raise ValueError(f"{args}: couldn't open")
        elif mode[0] == "w":
            self.proc = Popen(*args, stdin=PIPE, **kw)
            self.stream = self.proc.stdin
            if self.stream is None:
                raise ValueError(f"{args}: couldn't open")
        self.status = None

    def __str__(self):
        return f"<Pipe {self.args}>"

    def check_status(self):
        """Poll the process and handle any errors."""
        status = self.proc.poll()
        if status is not None:
            self.wait_for_child()

    def wait_for_child(self):
        """Check the status variable and raise an exception if necessary."""
        verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
        if self.status is not None and verbose:
            # print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr)
            return
        self.status = self.proc.wait()
        if verbose:
            print(
                f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
                file=sys.stderr,
            )
        if self.status not in self.ignore_status and not self.ignore_errors:
            raise Exception(f"{self.args}: exit {self.status} (read) {info}")

    def read(self, *args, **kw):
        """Wrap stream.read and check status."""
        result = self.stream.read(*args, **kw)
        self.check_status()
        return result

    def write(self, *args, **kw):
        """Wrap stream.write and check status."""
        result = self.stream.write(*args, **kw)
        self.check_status()
        return result

    def readLine(self, *args, **kw):
        """Wrap stream.readLine and check status."""
        result = self.stream.readLine(*args, **kw)
        self.status = self.proc.poll()
        self.check_status()
        return result

    def close(self):
        """Wrap stream.close, wait for the subprocess, and handle errors."""
        self.stream.close()
        self.status = self.proc.wait(self.timeout)
        self.wait_for_child()

    def __enter__(self):
        """Context handler."""
        return self

    def __exit__(self, etype, value, traceback):
        """Context handler."""
        self.close()

def set_options(obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None):
    """Set options for Pipes.

    This function can be called on any stream. It will set pipe options only
    when its argument is a pipe.

    :param obj: any kind of stream
    :param timeout: desired timeout
    :param ignore_errors: desired ignore_errors setting
    :param ignore_status: desired ignore_status setting
    :param handler: desired error handler
    """
    if not isinstance(obj, Pipe):
        return False
    if timeout is not None:
        obj.timeout = timeout
    if ignore_errors is not None:
        obj.ignore_errors = ignore_errors
    if ignore_status is not None:
        obj.ignore_status = ignore_status
    if handler is not None:
        obj.handler = handler
    return True

def gopen_file(url, mode="rb", bufsize=8192):
    """Open a file.

    This works for local files, files over HTTP, and pipe: files.

    :param url: URL to be opened
    :param mode: mode to open it with
    :param bufsize: requested buffer size
    """
    return open(url, mode)

def gopen_pipe(url, mode="rb", bufsize=8192):
    """Use gopen to open a pipe.

    :param url: a pipe: URL
    :param mode: desired mode
    :param bufsize: desired buffer size
    """
    assert url.startswith("pipe:")
    cmd = url[5:]
    if mode[0] == "r":
        return Pipe(cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141])  # skipcq: BAN-B604
    elif mode[0] == "w":
        return Pipe(cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141])  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")

def gopen_curl(url, mode="rb", bufsize=8192):
    """Open a URL with `curl`.

    :param url: url (usually, http:// etc.)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        cmd = f"curl -s -L '{url}'"
        return Pipe(cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141, 23])  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"curl -s -L -T - '{url}'"
        return Pipe(cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141, 26])  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")

def gopen_htgs(url, mode="rb", bufsize=8192):
    """Open an htgs:// URL by rewriting it to gs:// and streaming it with `curl`.

    :param url: url (htgs://...)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        url = re.sub(r"(?i)^htgs://", "gs://", url)
        cmd = f"curl -s -L '{url}'"
        return Pipe(cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141, 23])  # skipcq: BAN-B604
    elif mode[0] == "w":
        raise ValueError(f"{mode}: cannot write")
    else:
        raise ValueError(f"{mode}: unknown mode")

def gopen_gsutil(url, mode="rb", bufsize=8192):
    """Open a gs:// URL with `gsutil`.

    :param url: url (usually, gs:// etc.)
    :param mode: file mode
    :param bufsize: buffer size
    """
    if mode[0] == "r":
        cmd = f"gsutil cat '{url}'"
        return Pipe(cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141, 23])  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"gsutil cp - '{url}'"
        return Pipe(cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141, 26])  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")

def gopen_error(url, *args, **kw):
    """Raise a value error.

    :param url: url
    :param args: other arguments
    :param kw: other keywords
    """
    raise ValueError(f"{url}: no gopen handler defined")

"""A dispatch table mapping URL schemes to handlers."""
gopen_schemes = dict(
    __default__=gopen_error,
    pipe=gopen_pipe,
    http=gopen_curl,
    https=gopen_curl,
    sftp=gopen_curl,
    ftps=gopen_curl,
    scp=gopen_curl,
    gs=gopen_gsutil,
    htgs=gopen_htgs,
)

def gopen(url, mode="rb", bufsize=8192, **kw):
    """Open the URL.

    This uses the `gopen_schemes` dispatch table to dispatch based
    on scheme.

    Support for the following schemes is built-in: pipe, file,
    http, https, sftp, ftps, scp, gs, htgs.

    When no scheme is given the url is treated as a file.

    You can use the GOPEN_VERBOSE environment variable to get info about
    files being opened.

    :param url: the source URL
    :param mode: the mode ("rb", "wb")
    :param bufsize: the buffer size
    """
    global fallback_gopen
    verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
    if verbose:
        print("GOPEN", url, info, file=sys.stderr)
    assert mode in ["rb", "wb"], mode
    if url == "-":
        if mode == "rb":
            return sys.stdin.buffer
        elif mode == "wb":
            return sys.stdout.buffer
        else:
            raise ValueError(f"unknown mode {mode}")
    pr = urlparse(url)
    if pr.scheme == "":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(url, mode, buffering=bufsize)
    if pr.scheme == "file":
        bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
        return open(pr.path, mode, buffering=bufsize)
    handler = gopen_schemes["__default__"]
    handler = gopen_schemes.get(pr.scheme, handler)
    return handler(url, mode, bufsize, **kw)

def reader(url, **kw):
    """Open url with gopen and mode "rb".

    :param url: source URL
    :param kw: other keywords forwarded to gopen
    """
    return gopen(url, "rb", **kw)
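For example, `gopen` falls back to plain `open()` for scheme-less paths and spawns a `curl` subprocess (wrapped in `Pipe`) for HTTP URLs; the paths below are placeholders.

```python
from paddlespeech.audio.streamdata.gopen import gopen, reader

with gopen("data/train-000000.tar", "rb") as f:            # local path: plain open()
    head = f.read(512)

stream = reader("https://example.com/train-000000.tar")    # HTTP: curl subprocess via Pipe
data = stream.read()
stream.close()                                             # waits for curl and checks its exit status
```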
paddlespeech/audio/streamdata/handlers.py
0 → 100644
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Pluggable exception handlers.
These are functions that take an exception as an argument and then return...
- the exception (in order to re-raise it)
- True (in order to continue and ignore the exception)
- False (in order to ignore the exception and stop processing)
They are used as handler= arguments in much of the library.
"""
import time, warnings

def reraise_exception(exn):
    """Call in an exception handler to re-raise the exception."""
    raise exn

def ignore_and_continue(exn):
    """Call in an exception handler to ignore any exception and continue."""
    return True

def warn_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return True

def ignore_and_stop(exn):
    """Call in an exception handler to ignore any exception and stop further processing."""
    return False

def warn_and_stop(exn):
    """Call in an exception handler to issue a warning for any exception and stop further processing."""
    warnings.warn(repr(exn))
    time.sleep(0.5)
    return False
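These handlers follow a simple protocol: return True to skip the failing item, False to stop the stream, or raise to abort. A minimal sketch of how the library's loops use them (`do_work` is a placeholder for real per-item processing):

```python
def process(stream, handler=warn_and_continue):
    for item in stream:
        try:
            yield do_work(item)   # placeholder per-item work
        except Exception as exn:
            if handler(exn):      # True: warn/ignore and keep going
                continue
            else:                 # False: stop processing quietly
                break
```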
paddlespeech/audio/streamdata/mix.py
0 → 100644
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes for mixing samples from multiple sources."""
import itertools, os, random, time, sys
from functools import reduce, wraps

import numpy as np

from . import autodecode, utils
from .paddle_utils import PaddleTensor, IterableDataset
from .utils import PipelineStage

def round_robin_shortest(*sources):
    i = 0
    while True:
        try:
            sample = next(sources[i % len(sources)])
            yield sample
        except StopIteration:
            break
        i += 1

def round_robin_longest(*sources):
    i = 0
    while len(sources) > 0:
        try:
            sample = next(sources[i])
            i += 1
            yield sample
        except StopIteration:
            del sources[i]

class RoundRobin(IterableDataset):
    def __init__(self, datasets, longest=False):
        self.datasets = datasets
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        if self.longest:
            return round_robin_longest(*sources)
        else:
            return round_robin_shortest(*sources)

def random_samples(sources, probs=None, longest=False):
    if probs is None:
        probs = [1] * len(sources)
    else:
        probs = list(probs)
    while len(sources) > 0:
        cum = (np.array(probs) / np.sum(probs)).cumsum()
        r = random.random()
        i = np.searchsorted(cum, r)
        try:
            yield next(sources[i])
        except StopIteration:
            if longest:
                del sources[i]
                del probs[i]
            else:
                break

class RandomMix(IterableDataset):
    def __init__(self, datasets, probs=None, longest=False):
        self.datasets = datasets
        self.probs = probs
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        return random_samples(sources, self.probs, longest=self.longest)
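A small sketch of mixing two iterable datasets; `train_clean` and `train_other` are placeholders for DataPipeline/IterableDataset instances.

```python
mixed = RandomMix([train_clean, train_other], probs=[0.7, 0.3])
for sample in mixed:          # with longest=False, stops at the first exhausted source
    pass

rr = RoundRobin([train_clean, train_other], longest=True)   # alternate until all sources are empty
```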
paddlespeech/audio/streamdata/paddle_utils.py
0 → 100644
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Mock implementations of paddle interfaces when paddle is not available."""
try:
    from paddle.io import DataLoader, IterableDataset
except ModuleNotFoundError:

    class IterableDataset:
        """Empty implementation of IterableDataset when paddle is not available."""
        pass

    class DataLoader:
        """Empty implementation of DataLoader when paddle is not available."""
        pass

try:
    from paddle import Tensor as PaddleTensor
except ModuleNotFoundError:

    class PaddleTensor:
        """Empty implementation of PaddleTensor when paddle is not available."""
        pass
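Downstream modules in this package import these names from the shim rather than from paddle directly, so they can still be imported when paddle is absent. A sketch (the subclass is purely illustrative):

```python
from paddlespeech.audio.streamdata.paddle_utils import IterableDataset

class ToyShardList(IterableDataset):      # works whether or not paddle is installed
    def __iter__(self):
        yield dict(url="train-000000.tar")   # placeholder shard
```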
paddlespeech/audio/streamdata/pipeline.py
0 → 100644
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List

import braceexpand, yaml

from .handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage

def add_length_method(obj):
    def length(self):
        return self.size

    Combined = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": length},
    )
    obj.__class__ = Combined
    return obj

class DataPipeline(IterableDataset, PipelineStage):
    """A pipeline starting with an IterableDataset and a series of filters."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.pipeline = []
        self.length = -1
        self.repetitions = 1
        self.nsamples = -1
        for arg in args:
            if arg is None:
                continue
            if isinstance(arg, list):
                self.pipeline.extend(arg)
            else:
                self.pipeline.append(arg)

    def invoke(self, f, *args, **kwargs):
        """Apply a pipeline stage, possibly to the output of a previous stage."""
        if isinstance(f, PipelineStage):
            return f.run(*args, **kwargs)
        if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
            return iter(f)
        if isinstance(f, list):
            return iter(f)
        if callable(f):
            result = f(*args, **kwargs)
            return result
        raise ValueError(f"{f}: not a valid pipeline stage")

    def iterator1(self):
        """Create an iterator through one epoch in the pipeline."""
        source = self.invoke(self.pipeline[0])
        for step in self.pipeline[1:]:
            source = self.invoke(step, source)
        return source

    def iterator(self):
        """Create an iterator through the entire dataset, using the given number of repetitions."""
        for i in range(self.repetitions):
            for sample in self.iterator1():
                yield sample

    def __iter__(self):
        """Create an iterator through the pipeline, repeating and slicing as requested."""
        if self.repetitions != 1:
            if self.nsamples > 0:
                return islice(self.iterator(), self.nsamples)
            else:
                return self.iterator()
        else:
            return self.iterator()

    def stage(self, i):
        """Return pipeline stage i."""
        return self.pipeline[i]

    def append(self, f):
        """Append a pipeline stage (modifies the object)."""
        self.pipeline.append(f)
        return self

    def append_list(self, *args):
        for arg in args:
            self.pipeline.append(arg)
        return self

    def compose(self, *args):
        """Append pipeline stages to a copy of the pipeline and return the copy."""
        result = copy.copy(self)
        for arg in args:
            result.append(arg)
        return result

    def with_length(self, n):
        """Add a __len__ method returning the desired value.

        This does not change the actual number of samples in an epoch.
        PyTorch IterableDataset should not have a __len__ method.
        This is provided only as a workaround for some broken training environments
        that require a __len__ method.
        """
        self.size = n
        return add_length_method(self)

    def with_epoch(self, nsamples=-1, nbatches=-1):
        """Change the epoch to return the given number of samples/batches.

        The two arguments mean the same thing."""
        self.repetitions = sys.maxsize
        self.nsamples = max(nsamples, nbatches)
        return self

    def repeat(self, nepochs=-1, nbatches=-1):
        """Repeat iterating through the dataset for the given #epochs up to the given #samples."""
        if nepochs > 0:
            self.repetitions = nepochs
            self.nsamples = nbatches
        else:
            self.repetitions = sys.maxsize
            self.nsamples = nbatches
        return self
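A sketch of composing a DataPipeline from the stages defined in the sibling modules below (shardlists.py, tariterators.py). The shard pattern is a placeholder, and the curried `tarfile_to_samples()` call assumes the `pipelinefilter` wrapper used throughout filters.py.

```python
pipeline = DataPipeline(
    SimpleShardList("shards/train-{000000..000009}.tar"),   # placeholder shard pattern
    split_by_worker,                                        # shard-level split across workers
    tarfile_to_samples(),                                   # tar -> {fname, wav, txt, sample_rate}
)
for sample in pipeline.with_epoch(nsamples=10000):          # cap one nominal "epoch" at 10k samples
    pass
```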
paddlespeech/audio/streamdata/shardlists.py
0 → 100644
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
import os, random, sys, time
from dataclasses import dataclass, field
from itertools import islice
from typing import List

import braceexpand, yaml

from . import utils
from .filters import pipelinefilter
from .paddle_utils import IterableDataset

from ..utils.log import Logger

logger = Logger(__name__)

def expand_urls(urls):
    if isinstance(urls, str):
        urllist = urls.split("::")
        result = []
        for url in urllist:
            result.extend(braceexpand.braceexpand(url))
        return result
    else:
        return list(urls)

class SimpleShardList(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(self, urls, seed=None):
        """Iterate through the list of shards.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls = expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.seed = seed

    def __len__(self):
        return len(self.urls)

    def __iter__(self):
        """Return an iterator over the shards."""
        urls = self.urls.copy()
        if self.seed is not None:
            random.Random(self.seed).shuffle(urls)
        for url in urls:
            yield dict(url=url)

def split_by_node(src, group=None):
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    logger.info(f"world_size:{world_size}, rank:{rank}")
    if world_size > 1:
        for s in islice(src, rank, None, world_size):
            yield s
    else:
        for s in src:
            yield s

def single_node_only(src, group=None):
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    if world_size > 1:
        raise ValueError("input pipeline needs to be reconfigured for multinode training")
    for s in src:
        yield s

def split_by_worker(src):
    rank, world_size, worker, num_workers = utils.paddle_worker_info()
    logger.info(f"num_workers:{num_workers}, worker:{worker}")
    if num_workers > 1:
        for s in islice(src, worker, None, num_workers):
            yield s
    else:
        for s in src:
            yield s

def resampled_(src, n=sys.maxsize):
    import random

    seed = time.time()
    try:
        seed = open("/dev/random", "rb").read(20)
    except Exception as exn:
        print(repr(exn)[:50], file=sys.stderr)
    rng = random.Random(seed)
    print("# resampled loading", file=sys.stderr)
    items = list(src)
    print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
    for i in range(n):
        yield rng.choice(items)

resampled = pipelinefilter(resampled_)

def non_empty(src):
    count = 0
    for s in src:
        yield s
        count += 1
    if count == 0:
        raise ValueError("pipeline stage received no data at all and this was declared as an error")

@dataclass
class MSSource:
    """Class representing a data source."""

    name: str = ""
    perepoch: int = -1
    resample: bool = False
    urls: List[str] = field(default_factory=list)

default_rng = random.Random()

def expand(s):
    return os.path.expanduser(os.path.expandvars(s))

class MultiShardSample(IterableDataset):
    def __init__(self, fname):
        """Construct a shardlist from multiple sources using a YAML spec."""
        self.epoch = -1
        self.parse_spec(fname)

    def parse_spec(self, fname):
        self.rng = default_rng  # capture default_rng if we fork
        if isinstance(fname, dict):
            spec = fname
            fname = "{dict}"
        else:
            with open(fname) as stream:
                spec = yaml.safe_load(stream)
        assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
        prefix = expand(spec.get("prefix", ""))
        self.sources = []
        for ds in spec["datasets"]:
            assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list(ds.keys())
            buckets = ds.get("buckets", spec.get("buckets", []))
            if isinstance(buckets, str):
                buckets = [buckets]
            buckets = [expand(s) for s in buckets]
            if buckets == []:
                buckets = [""]
            assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
            bucket = buckets[0]
            name = ds.get("name", "@" + bucket)
            urls = ds["shards"]
            if isinstance(urls, str):
                urls = [urls]
            # urls = [u for url in urls for u in braceexpand.braceexpand(url)]
            urls = [
                prefix + os.path.join(bucket, u)
                for url in urls
                for u in braceexpand.braceexpand(expand(url))
            ]
            resample = ds.get("resample", -1)
            nsample = ds.get("choose", -1)
            if nsample > len(urls):
                raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
            if (nsample > 0) and (resample > 0):
                raise ValueError("specify only one of perepoch or choose")
            entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
            self.sources.append(entry)
            print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)

    def set_epoch(self, seed):
        """Set the current epoch (for consistent shard selection among nodes)."""
        self.rng = random.Random(seed)

    def get_shards_for_epoch(self):
        result = []
        for source in self.sources:
            if source.resample > 0:
                # sample with replacement
                l = self.rng.choices(source.urls, k=source.resample)
            elif source.perepoch > 0:
                # sample without replacement
                l = list(source.urls)
                self.rng.shuffle(l)
                l = l[:source.perepoch]
            else:
                l = list(source.urls)
            result += l
        self.rng.shuffle(result)
        return result

    def __iter__(self):
        shards = self.get_shards_for_epoch()
        for shard in shards:
            yield dict(url=shard)

def shardspec(spec):
    if spec.endswith(".yaml"):
        return MultiShardSample(spec)
    else:
        return SimpleShardList(spec)

class ResampledShards(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls = expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = -1

    def __iter__(self):
        """Return an iterator over the shards."""
        self.epoch += 1
        if self.deterministic:
            seed = utils.make_seed(self.worker_seed(), self.epoch)
        else:
            seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
        if os.environ.get("WDS_SHOW_SEED", "0") == "1":
            print(f"# ResampledShards seed {seed}")
        self.rng = random.Random(seed)
        for _ in range(self.nshards):
            index = self.rng.randint(0, len(self.urls) - 1)
            yield dict(url=self.urls[index])
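Shard lists are just iterables of `dict(url=...)`; brace notation is expanded with braceexpand. A sketch with placeholder shard names:

```python
print(expand_urls("train-{000000..000002}.tar"))
# ['train-000000.tar', 'train-000001.tar', 'train-000002.tar']

shards = SimpleShardList("train-{000000..000002}.tar", seed=17)   # shuffled by the seed
epoch_shards = list(shards)                                       # three dict(url=...) entries

infinite = ResampledShards("train-{000000..000002}.tar", nshards=5)  # sample with replacement
```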
paddlespeech/audio/streamdata/tariterators.py
0 → 100644
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Low level iteration functions for tar archives."""
import random, re, tarfile

import braceexpand

from . import filters
from . import gopen
from .handlers import reraise_exception

trace = False
meta_prefix = "__"
meta_suffix = "__"

import paddlespeech
import paddle
import numpy as np

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])

def base_plus_ext(path):
    """Split off all file extensions.

    Returns base, allext.

    :param path: path with extensions
    :param returns: path with all extensions removed
    """
    match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    if not match:
        return None, None
    return match.group(1), match.group(2)

def valid_sample(sample):
    """Check whether a sample is valid.

    :param sample: sample to be checked
    """
    return (
        sample is not None
        and isinstance(sample, dict)
        and len(list(sample.keys())) > 0
        and not sample.get("__bad__", False)
    )

# FIXME: UNUSED
def shardlist(urls, *, shuffle=False):
    """Given a list of URLs, yields that list, possibly shuffled."""
    if isinstance(urls, str):
        urls = braceexpand.braceexpand(urls)
    else:
        urls = list(urls)
    if shuffle:
        random.shuffle(urls)
    for url in urls:
        yield dict(url=url)

def url_opener(data, handler=reraise_exception, **kw):
    """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        try:
            stream = gopen.gopen(url, **kw)
            sample.update(stream=stream)
            yield sample
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break

def tar_file_iterator(fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception):
    """Iterate over tar file, yielding filename, content pairs for the given tar stream.

    :param fileobj: byte stream suitable for tarfile
    :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
    """
    stream = tarfile.open(fileobj=fileobj, mode="r:*")
    for tarinfo in stream:
        fname = tarinfo.name
        try:
            if not tarinfo.isreg():
                continue
            if fname is None:
                continue
            if ("/" not in fname and fname.startswith(meta_prefix) and fname.endswith(meta_suffix)):
                # skipping metadata for now
                continue
            if skip_meta is not None and re.match(skip_meta, fname):
                continue

            name = tarinfo.name
            pos = name.rfind('.')
            assert pos > 0
            prefix, postfix = name[:pos], name[pos + 1:]
            if postfix == 'wav':
                waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False)
                result = dict(fname=prefix, wav=waveform, sample_rate=sample_rate)
            else:
                txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
                result = dict(fname=prefix, txt=txt)
            # result = dict(fname=fname, data=data)
            yield result
            stream.members = []
        except Exception as exn:
            if hasattr(exn, "args") and len(exn.args) > 0:
                exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
            if handler(exn):
                continue
            else:
                break
    del stream

def tar_file_and_group_iterator(fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception):
    """ Expand a stream of open tar files into a stream of tar file contents,
        and group the files with the same prefix.

        Args:
            data: Iterable[{src, stream}]

        Returns:
            Iterable[{key, wav, txt, sample_rate}]
    """
    stream = tarfile.open(fileobj=fileobj, mode="r:*")
    prev_prefix = None
    example = {}
    valid = True
    for tarinfo in stream:
        name = tarinfo.name
        pos = name.rfind('.')
        assert pos > 0
        prefix, postfix = name[:pos], name[pos + 1:]
        if prev_prefix is not None and prefix != prev_prefix:
            example['fname'] = prev_prefix
            if valid:
                yield example
            example = {}
            valid = True
        with stream.extractfile(tarinfo) as file_obj:
            try:
                if postfix == 'txt':
                    example['txt'] = file_obj.read().decode('utf8').strip()
                elif postfix in AUDIO_FORMAT_SETS:
                    waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False)
                    waveform = paddle.to_tensor(np.expand_dims(np.array(waveform), 0), dtype=paddle.float32)

                    example['wav'] = waveform
                    example['sample_rate'] = sample_rate
                else:
                    example[postfix] = file_obj.read()
            except Exception as exn:
                if hasattr(exn, "args") and len(exn.args) > 0:
                    exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                # mark the sample invalid before deciding whether to continue or stop
                valid = False
                # logging.warning('error to parse {}'.format(name))
                if handler(exn):
                    continue
                else:
                    break
        prev_prefix = prefix
    if prev_prefix is not None:
        example['fname'] = prev_prefix
        yield example
    stream.close()

def tar_file_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over (filename, file_contents).
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_iterator(source["stream"]):
                assert (isinstance(sample, dict) and "data" in sample and "fname" in sample)
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break

def tar_file_and_group_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over (filename, file_contents).
    """
    for source in data:
        url = source["url"]
        try:
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_and_group_iterator(source["stream"]):
                assert (isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample)
                sample["__url__"] = url
                yield sample
        except Exception as exn:
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            if handler(exn):
                continue
            else:
                break

def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if trace:
            print(
                prefix,
                suffix,
                current_sample.keys() if isinstance(current_sample, dict) else None,
            )
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        if current_sample is None or prefix != current_sample["__key__"]:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffix in current_sample:
            raise ValueError(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample

def tarfile_samples(src, handler=reraise_exception):
    streams = url_opener(src, handler=handler)
    samples = tar_file_and_group_expander(streams, handler=handler)
    return samples

tarfile_to_samples = filters.pipelinefilter(tarfile_samples)
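A sketch of reading one shard of paired `<utt>.wav` / `<utt>.txt` entries with the iterators above; the shard path is a placeholder.

```python
shard_stream = url_opener([dict(url="shards/train-000000.tar")])   # placeholder path
for sample in tar_file_and_group_expander(shard_stream):
    # sample: {"fname": ..., "wav": paddle.Tensor [1, T], "sample_rate": ..., "txt": ..., "__url__": ...}
    print(sample["fname"], sample["sample_rate"], sample["wav"].shape)
```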
paddlespeech/audio/streamdata/utils.py
0 → 100644
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions."""
import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Optional, Union

from ..utils.log import Logger

logger = Logger(__name__)

def make_seed(*args):
    seed = 0
    for arg in args:
        seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
    return seed

class PipelineStage:
    def invoke(self, *args, **kw):
        raise NotImplementedError

def identity(x: Any) -> Any:
    """Return the argument as is."""
    return x

def safe_eval(s: str, expr: str = "{}"):
    """Evaluate the given expression more safely."""
    if re.sub("[^A-Za-z0-9_]", "", s) != s:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    return eval(expr.format(s))

def lookup_sym(sym: str, modules: list):
    """Look up a symbol in a list of modules."""
    for mname in modules:
        module = importlib.import_module(mname, package="webdataset")
        result = getattr(module, sym, None)
        if result is not None:
            return result
    return None

def repeatedly0(loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize):
    """Repeatedly return batches from a DataLoader."""
    for epoch in range(nepochs):
        for sample in itt.islice(loader, nbatches):
            yield sample

def guess_batchsize(batch: Union[tuple, list]):
    """Guess the batch size by looking at the length of the first element in a tuple."""
    return len(batch[0])

def repeatedly(
    source: Iterator,
    nepochs: int = None,
    nbatches: int = None,
    nsamples: int = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield samples from an iterator."""
    epoch = 0
    batch = 0
    total = 0
    while True:
        for sample in source:
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                total += guess_batchsize(sample)
                if total >= nsamples:
                    return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return

def paddle_worker_info(group=None):
    """Return node and worker info for paddle and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        try:
            import paddle.distributed
            group = group or paddle.distributed.get_group()
            rank = paddle.distributed.get_rank()
            world_size = paddle.distributed.get_world_size()
        except ModuleNotFoundError:
            pass
    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        worker = int(os.environ["WORKER"])
        num_workers = int(os.environ["NUM_WORKERS"])
    else:
        try:
            from paddle.io import get_worker_info
            worker_info = paddle.io.get_worker_info()
            if worker_info is not None:
                worker = worker_info.id
                num_workers = worker_info.num_workers
        except ModuleNotFoundError as E:
            logger.info(f"not found {E}")
            exit(-1)

    return rank, world_size, worker, num_workers

def paddle_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = paddle_worker_info(group=group)
    return rank * 1000 + worker
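The RANK/WORLD_SIZE and WORKER/NUM_WORKERS environment variables take precedence over paddle's runtime queries; a brief sketch of deriving a per-worker seed:

```python
rank, world_size, worker, num_workers = paddle_worker_info()
seed = paddle_worker_seed()   # == rank * 1000 + worker
```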
paddlespeech/audio/streamdata/writer.py
0 → 100644
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes and functions for writing tar files and WebDataset files."""
import
io
,
json
,
pickle
,
re
,
tarfile
,
time
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
numpy
as
np
from
.
import
gopen
def
imageencoder
(
image
:
Any
,
format
:
str
=
"PNG"
):
# skipcq: PYL-W0622
"""Compress an image using PIL and return it as a string.
Can handle float or uint8 images.
:param image: ndarray representing an image
:param format: compression format (PNG, JPEG, PPM)
"""
import
PIL
assert
isinstance
(
image
,
(
PIL
.
Image
.
Image
,
np
.
ndarray
)),
type
(
image
)
if
isinstance
(
image
,
np
.
ndarray
):
if
image
.
dtype
in
[
np
.
dtype
(
"f"
),
np
.
dtype
(
"d"
)]:
if
not
(
np
.
amin
(
image
)
>
-
0.001
and
np
.
amax
(
image
)
<
1.001
):
raise
ValueError
(
f
"image values out of range
{
np
.
amin
(
image
)
}
{
np
.
amax
(
image
)
}
"
)
image
=
np
.
clip
(
image
,
0.0
,
1.0
)
image
=
np
.
array
(
image
*
255.0
,
"uint8"
)
assert
image
.
ndim
in
[
2
,
3
]
if
image
.
ndim
==
3
:
assert
image
.
shape
[
2
]
in
[
1
,
3
]
image
=
PIL
.
Image
.
fromarray
(
image
)
if
format
.
upper
()
==
"JPG"
:
format
=
"JPEG"
elif
format
.
upper
()
in
[
"IMG"
,
"IMAGE"
]:
format
=
"PPM"
if
format
==
"JPEG"
:
opts
=
dict
(
quality
=
100
)
else
:
opts
=
{}
with
io
.
BytesIO
()
as
result
:
image
.
save
(
result
,
format
=
format
,
**
opts
)
return
result
.
getvalue
()
def
bytestr
(
data
:
Any
):
"""Convert data into a bytestring.
Uses str and ASCII encoding for data that isn't already in string format.
:param data: data
"""
if
isinstance
(
data
,
bytes
):
return
data
if
isinstance
(
data
,
str
):
return
data
.
encode
(
"ascii"
)
return
str
(
data
).
encode
(
"ascii"
)
def
paddle_dumps
(
data
:
Any
):
"""Dump data into a bytestring using paddle.dumps.
This delays importing paddle until needed.
:param data: data to be dumped
"""
import
io
import
paddle
stream
=
io
.
BytesIO
()
paddle
.
save
(
data
,
stream
)
return
stream
.
getvalue
()
def
numpy_dumps
(
data
:
np
.
ndarray
):
"""Dump data into a bytestring using numpy npy format.
:param data: data to be dumped
"""
import
io
import
numpy.lib.format
stream
=
io
.
BytesIO
()
numpy
.
lib
.
format
.
write_array
(
stream
,
data
)
return
stream
.
getvalue
()
def
numpy_npz_dumps
(
data
:
np
.
ndarray
):
"""Dump data into a bytestring using numpy npz format.
:param data: data to be dumped
"""
import
io
stream
=
io
.
BytesIO
()
np
.
savez_compressed
(
stream
,
**
data
)
return
stream
.
getvalue
()
def
tenbin_dumps
(
x
):
from
.
import
tenbin
if
isinstance
(
x
,
list
):
return
memoryview
(
tenbin
.
encode_buffer
(
x
))
else
:
return
memoryview
(
tenbin
.
encode_buffer
([
x
]))
def
cbor_dumps
(
x
):
import
cbor
return
cbor
.
dumps
(
x
)
def
mp_dumps
(
x
):
import
msgpack
return
msgpack
.
packb
(
x
)
def
add_handlers
(
d
,
keys
,
value
):
if
isinstance
(
keys
,
str
):
keys
=
keys
.
split
()
for
k
in
keys
:
d
[
k
]
=
value
def
make_handlers
():
"""Create a list of handlers for encoding data."""
handlers
=
{}
add_handlers
(
handlers
,
"cls cls2 class count index inx id"
,
lambda
x
:
str
(
x
).
encode
(
"ascii"
)
)
add_handlers
(
handlers
,
"txt text transcript"
,
lambda
x
:
x
.
encode
(
"utf-8"
))
add_handlers
(
handlers
,
"html htm"
,
lambda
x
:
x
.
encode
(
"utf-8"
))
add_handlers
(
handlers
,
"pyd pickle"
,
pickle
.
dumps
)
add_handlers
(
handlers
,
"pdparams"
,
paddle_dumps
)
add_handlers
(
handlers
,
"npy"
,
numpy_dumps
)
add_handlers
(
handlers
,
"npz"
,
numpy_npz_dumps
)
add_handlers
(
handlers
,
"ten tenbin tb"
,
tenbin_dumps
)
add_handlers
(
handlers
,
"json jsn"
,
lambda
x
:
json
.
dumps
(
x
).
encode
(
"utf-8"
))
add_handlers
(
handlers
,
"mp msgpack msg"
,
mp_dumps
)
add_handlers
(
handlers
,
"cbor"
,
cbor_dumps
)
add_handlers
(
handlers
,
"jpg jpeg img image"
,
lambda
data
:
imageencoder
(
data
,
"jpg"
))
add_handlers
(
handlers
,
"png"
,
lambda
data
:
imageencoder
(
data
,
"png"
))
add_handlers
(
handlers
,
"pbm"
,
lambda
data
:
imageencoder
(
data
,
"pbm"
))
add_handlers
(
handlers
,
"pgm"
,
lambda
data
:
imageencoder
(
data
,
"pgm"
))
add_handlers
(
handlers
,
"ppm"
,
lambda
data
:
imageencoder
(
data
,
"ppm"
))
return
handlers
default_handlers
=
make_handlers
()
def
encode_based_on_extension1
(
data
:
Any
,
tname
:
str
,
handlers
:
dict
):
"""Encode data based on its extension and a dict of handlers.
:param data: data
:param tname: file extension
:param handlers: handlers
"""
if
tname
[
0
]
==
"_"
:
if
not
isinstance
(
data
,
str
):
raise
ValueError
(
"the values of metadata must be of string type"
)
return
data
extension
=
re
.
sub
(
r
".*\."
,
""
,
tname
).
lower
()
if
isinstance
(
data
,
bytes
):
return
data
if
isinstance
(
data
,
str
):
return
data
.
encode
(
"utf-8"
)
handler
=
handlers
.
get
(
extension
)
if
handler
is
None
:
raise
ValueError
(
f
"no handler found for
{
extension
}
"
)
return
handler
(
data
)
def
encode_based_on_extension
(
sample
:
dict
,
handlers
:
dict
):
"""Encode an entire sample with a collection of handlers.
:param sample: data sample (a dict)
:param handlers: handlers for encoding
"""
return
{
k
:
encode_based_on_extension1
(
v
,
k
,
handlers
)
for
k
,
v
in
list
(
sample
.
items
())
}
def
make_encoder
(
spec
:
Union
[
bool
,
str
,
dict
,
Callable
]):
"""Make an encoder function from a specification.
:param spec: specification
"""
if
spec
is
False
or
spec
is
None
:
def
encoder
(
x
):
"""Do not encode at all."""
return
x
elif
callable
(
spec
):
encoder
=
spec
elif
isinstance
(
spec
,
dict
):
def
f
(
sample
):
"""Encode based on extension."""
return
encode_based_on_extension
(
sample
,
spec
)
encoder
=
f
elif
spec
is
True
:
handlers
=
default_handlers
def
g
(
sample
):
"""Encode based on extension."""
return
encode_based_on_extension
(
sample
,
handlers
)
encoder
=
g
else
:
raise
ValueError
(
f
"
{
spec
}
: unknown decoder spec"
)
if
not
callable
(
encoder
):
raise
ValueError
(
f
"
{
spec
}
did not yield a callable encoder"
)
return
encoder
class TarWriter:
    """A class for writing dictionaries to tar files.

    :param fileobj: file name for tar file (.tgz/.tar) or open file descriptor
    :param encoder: sample encoding (Default value = True)
    :param compress: (Default value = None)

    `True` will use an encoder that behaves similar to the automatic
    decoder for `Dataset`. `False` disables encoding and expects byte strings
    (except for metadata, which must be strings). The `encoder` argument can
    also be a `callable`, or a dictionary mapping extensions to encoders.

    The following code will add two files to the tar archive: `a/b.png` and
    `a/b.output.png`.

    ```Python
    tarwriter = TarWriter(stream)
    image = imread("b.jpg")
    image2 = imread("b.out.jpg")
    sample = {"__key__": "a/b", "png": image, "output.png": image2}
    tarwriter.write(sample)
    ```
    """

    def __init__(
            self,
            fileobj,
            user: str="bigdata",
            group: str="bigdata",
            mode: int=0o0444,
            compress: Optional[bool]=None,
            encoder: Union[None, bool, Callable]=True,
            keep_meta: bool=False, ):
        """Create a tar writer.

        :param fileobj: stream to write data to
        :param user: user for tar files
        :param group: group for tar files
        :param mode: mode for tar files
        :param compress: desired compression
        :param encoder: encoder function
        :param keep_meta: keep metadata (entries starting with "_")
        """
        if isinstance(fileobj, str):
            if compress is False:
                tarmode = "w|"
            elif compress is True:
                tarmode = "w|gz"
            else:
                tarmode = "w|gz" if fileobj.endswith("gz") else "w|"
            fileobj = gopen.gopen(fileobj, "wb")
            self.own_fileobj = fileobj
        else:
            tarmode = "w|gz" if compress is True else "w|"
            self.own_fileobj = None
        self.encoder = make_encoder(encoder)
        self.keep_meta = keep_meta
        self.stream = fileobj
        self.tarstream = tarfile.open(fileobj=fileobj, mode=tarmode)

        self.user = user
        self.group = group
        self.mode = mode
        self.compress = compress

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context."""
        self.close()

    def close(self):
        """Close the tar file."""
        self.tarstream.close()
        if self.own_fileobj is not None:
            self.own_fileobj.close()
            self.own_fileobj = None

    def write(self, obj):
        """Write a dictionary to the tar file.

        :param obj: dictionary of objects to be stored
        :returns: size of the entry
        """
        total = 0
        obj = self.encoder(obj)
        if "__key__" not in obj:
            raise ValueError("object must contain a __key__")
        for k, v in list(obj.items()):
            if k[0] == "_":
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to a bytes after encoding ({type(v)})")
        key = obj["__key__"]
        for k in sorted(obj.keys()):
            if k == "__key__":
                continue
            if not self.keep_meta and k[0] == "_":
                continue
            v = obj[k]
            if isinstance(v, str):
                v = v.encode("utf-8")
            now = time.time()
            ti = tarfile.TarInfo(key + "." + k)
            ti.size = len(v)
            ti.mtime = now
            ti.mode = self.mode
            ti.uname = self.user
            ti.gname = self.group
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"converter didn't yield bytes: {k}, {type(v)}")
            stream = io.BytesIO(v)
            self.tarstream.addfile(ti, stream)
            total += ti.size
        return total
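A usage sketch (not part of the diff) for the filename path of TarWriter.__init__ above: passing a string opens it through gopen and a name ending in "gz" selects the "w|gz" tar mode; with encoder=False all non-metadata values must already be bytes. Shard name and payload are illustrative, and this assumes gopen accepts plain local paths:

    from paddlespeech.audio.streamdata.writer import TarWriter

    with TarWriter("shard-000000.tar.gz", encoder=False) as writer:
        writer.write({
            "__key__": "utt0",                 # metadata, skipped by the bytes check
            "txt": "hello".encode("utf-8"),    # stored as utt0.txt
            "bin": b"\x00\x01\x02",            # stored as utt0.bin
        })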
class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(
            self,
            pattern: str,
            maxcount: int=100000,
            maxsize: float=3e9,
            post: Optional[Callable]=None,
            start_shard: int=0,
            **kw, ):
        """Create a ShardWriter.

        :param pattern: output file pattern
        :param maxcount: maximum number of records per shard (Default value = 100000)
        :param maxsize: maximum size of each shard (Default value = 3e9)
        :param kw: other options passed to TarWriter
        """
        self.verbose = 1
        self.kw = kw
        self.maxcount = maxcount
        self.maxsize = maxsize
        self.post = post

        self.tarstream = None
        self.shard = start_shard
        self.pattern = pattern
        self.total = 0
        self.count = 0
        self.size = 0
        self.fname = None
        self.next_stream()

    def next_stream(self):
        """Close the current stream and move to the next."""
        self.finish()
        self.fname = self.pattern % self.shard
        if self.verbose:
            print(
                "# writing",
                self.fname,
                self.count,
                "%.1f GB" % (self.size / 1e9),
                self.total, )
        self.shard += 1
        stream = open(self.fname, "wb")
        self.tarstream = TarWriter(stream, **self.kw)
        self.count = 0
        self.size = 0

    def write(self, obj):
        """Write a sample.

        :param obj: sample to be written
        """
        if (self.tarstream is None or self.count >= self.maxcount or
                self.size >= self.maxsize):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
        self.total += 1
        self.size += size

    def finish(self):
        """Finish all writing (use close instead)."""
        if self.tarstream is not None:
            self.tarstream.close()
            assert self.fname is not None
            if callable(self.post):
                self.post(self.fname)
            self.tarstream = None

    def close(self):
        """Close the stream."""
        self.finish()
        del self.tarstream
        del self.shard
        del self.count
        del self.size

    def __enter__(self):
        """Enter context."""
        return self

    def __exit__(self, *args, **kw):
        """Exit context."""
        self.close()
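A sketch (not part of the diff) of sharding with ShardWriter: extra keyword arguments are forwarded to the per-shard TarWriter, and a new shard is opened whenever maxcount or maxsize is exceeded. Pattern, count, and payloads are illustrative:

    from paddlespeech.audio.streamdata.writer import ShardWriter

    with ShardWriter("train-%06d.tar", maxcount=1000, encoder=False) as sink:
        for i in range(3000):
            sink.write({"__key__": f"utt{i}", "txt": f"sample {i}".encode("utf-8")})
    # produces train-000000.tar, train-000001.tar, train-000002.tar (1000 samples each)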
paddlespeech/audio/text/text_featurizer.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from pprint import pformat
from typing import Union

import sentencepiece as spm

from .utility import BLANK
from .utility import EOS
from .utility import load_dict
from .utility import MASKCTC
from .utility import SOS
from .utility import SPACE
from .utility import UNK
from ..utils.log import Logger

logger = Logger(__name__)

__all__ = ["TextFeaturizer"]


class TextFeaturizer():
    def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
        """Text featurizer, for processing or extracting features from text.

        Currently, it supports char/word/sentence-piece level tokenizing and conversion into
        a list of token indices. Note that the token indexing order follows the
        given vocabulary file.

        Args:
            unit_type (str): unit type, e.g. char, word, spm
            vocab Option[str, list]: Filepath to load vocabulary for token indices conversion, or vocab list.
            spm_model_prefix (str, optional): spm model prefix. Defaults to None.
        """
        assert unit_type in ('char', 'spm', 'word')
        self.unit_type = unit_type
        self.unk = UNK
        self.maskctc = maskctc

        if vocab:
            self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
                vocab, maskctc)
            self.vocab_size = len(self.vocab_list)
        else:
            logger.warning("TextFeaturizer: not have vocab file or vocab list.")

        if unit_type == 'spm':
            spm_model = spm_model_prefix + '.model'
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_model)

    def tokenize(self, text, replace_space=True):
        if self.unit_type == 'char':
            tokens = self.char_tokenize(text, replace_space)
        elif self.unit_type == 'word':
            tokens = self.word_tokenize(text)
        else:  # spm
            tokens = self.spm_tokenize(text)
        return tokens

    def detokenize(self, tokens):
        if self.unit_type == 'char':
            text = self.char_detokenize(tokens)
        elif self.unit_type == 'word':
            text = self.word_detokenize(tokens)
        else:  # spm
            text = self.spm_detokenize(tokens)
        return text

    def featurize(self, text):
        """Convert text string to a list of token indices.

        Args:
            text (str): Text to process.

        Returns:
            List[int]: List of token indices.
        """
        tokens = self.tokenize(text)
        ids = []
        for token in tokens:
            if token not in self.vocab_dict:
                logger.debug(f"Text Token: {token} -> {self.unk}")
                token = self.unk
            ids.append(self.vocab_dict[token])
        return ids

    def defeaturize(self, idxs):
        """Convert a list of token indices to text string,
        ignore index after eos_id.

        Args:
            idxs (List[int]): List of token indices.

        Returns:
            str: Text.
        """
        tokens = []
        for idx in idxs:
            if idx == self.eos_id:
                break
            tokens.append(self._id2token[idx])
        text = self.detokenize(tokens)
        return text

    def char_tokenize(self, text, replace_space=True):
        """Character tokenizer.

        Args:
            text (str): text string.
            replace_space (bool): False only used by build_vocab.py.

        Returns:
            List[str]: tokens.
        """
        text = text.strip()
        if replace_space:
            text_list = [SPACE if item == " " else item for item in list(text)]
        else:
            text_list = list(text)
        return text_list

    def char_detokenize(self, tokens):
        """Character detokenizer.

        Args:
            tokens (List[str]): tokens.

        Returns:
            str: text string.
        """
        tokens = [t.replace(SPACE, " ") for t in tokens]
        return "".join(tokens)

    def word_tokenize(self, text):
        """Word tokenizer, separate by <space>."""
        return text.strip().split()

    def word_detokenize(self, tokens):
        """Word detokenizer, separate by <space>."""
        return " ".join(tokens)

    def spm_tokenize(self, text):
        """spm tokenize.

        Args:
            text (str): text string.

        Returns:
            List[str]: sentence pieces str code
        """
        stats = {"num_empty": 0, "num_filtered": 0}

        def valid(line):
            return True

        def encode(l):
            return self.sp.EncodeAsPieces(l)

        def encode_line(line):
            line = line.strip()
            if len(line) > 0:
                line = encode(line)
                if valid(line):
                    return line
                else:
                    stats["num_filtered"] += 1
            else:
                stats["num_empty"] += 1
            return None

        enc_line = encode_line(text)
        return enc_line

    def spm_detokenize(self, tokens, input_format='piece'):
        """spm detokenize.

        Args:
            ids (List[str]): tokens.

        Returns:
            str: text
        """
        if input_format == "piece":

            def decode(l):
                return "".join(self.sp.DecodePieces(l))
        elif input_format == "id":

            def decode(l):
                return "".join(self.sp.DecodeIds(l))

        return decode(tokens)

    def _load_vocabulary_from_file(self, vocab: Union[str, list],
                                   maskctc: bool):
        """Load vocabulary from file."""
        if isinstance(vocab, list):
            vocab_list = vocab
        else:
            vocab_list = load_dict(vocab, maskctc)
        assert vocab_list is not None
        logger.debug(f"Vocab: {pformat(vocab_list)}")

        id2token = dict(
            [(idx, token) for (idx, token) in enumerate(vocab_list)])
        token2id = dict(
            [(token, idx) for (idx, token) in enumerate(vocab_list)])

        blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
        maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
        unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
        eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
        sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
        space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1

        logger.info(f"BLANK id: {blank_id}")
        logger.info(f"UNK id: {unk_id}")
        logger.info(f"EOS id: {eos_id}")
        logger.info(f"SOS id: {sos_id}")
        logger.info(f"SPACE id: {space_id}")
        logger.info(f"MASKCTC id: {maskctc_id}")
        return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
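A small usage sketch (not part of the commit) for the char-level path of TextFeaturizer above; the in-memory vocab list is illustrative and every character of the input is present in it:

    from paddlespeech.audio.text.text_featurizer import TextFeaturizer

    vocab = ["<blank>", "h", "i", "<space>", "<eos>"]
    featurizer = TextFeaturizer(unit_type='char', vocab=vocab)
    tokens = featurizer.tokenize("hi hi")   # ['h', 'i', '<space>', 'h', 'i']
    ids = featurizer.featurize("hi hi")     # indices into vocab
    text = featurizer.defeaturize(ids)      # "hi hi"; decoding stops at <eos>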
paddlespeech/audio/text/utility.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
import json
import math
import sys
import tarfile
from collections import namedtuple
from typing import List
from typing import Optional
from typing import Text

import jsonlines
import numpy as np

from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()

__all__ = [
    "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
    "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
    "EOS", "UNK", "BLANK", "MASKCTC", "SPACE", "convert_samples_to_float32",
    "convert_samples_from_float32"
]

IGNORE_ID = -1
# `sos` and `eos` using same token
SOS = "<eos>"
EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
MASKCTC = "<mask>"
SPACE = "<space>"


def load_dict(dict_path: Optional[Text],
              maskctc=False) -> Optional[List[Text]]:
    if dict_path is None:
        return None

    with open(dict_path, "r") as f:
        dictionary = f.readlines()
    # first token is `<blank>`
    # multi line: `<blank> 0\n`
    # one line: `<blank>`
    # space is replaced with <space>
    char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
    if BLANK not in char_list:
        char_list.insert(0, BLANK)
    if EOS not in char_list:
        char_list.append(EOS)
    # for non-autoregressive maskctc model
    if maskctc and MASKCTC not in char_list:
        char_list.append(MASKCTC)
    return char_list


def read_manifest(
        manifest_path,
        max_input_len=float('inf'),
        min_input_len=0.0,
        max_output_len=float('inf'),
        min_output_len=0.0,
        max_output_input_ratio=float('inf'),
        min_output_input_ratio=0.0, ):
    """Load and parse manifest file.

    Args:
        manifest_path ([type]): Manifest file to load and parse.
        max_input_len ([type], optional): maximum input seq length,
            in seconds for raw wav, in frame numbers for feature data.
            Defaults to float('inf').
        min_input_len (float, optional): minimum input seq length,
            in seconds for raw wav, in frame numbers for feature data.
            Defaults to 0.0.
        max_output_len (float, optional): maximum output seq length,
            in modeling units. Defaults to float('inf').
        min_output_len (float, optional): minimum output seq length,
            in modeling units. Defaults to 0.0.
        max_output_input_ratio (float, optional): maximum output seq length /
            input seq length ratio. Defaults to float('inf').
        min_output_input_ratio (float, optional): minimum output seq length /
            input seq length ratio. Defaults to 0.0.

    Raises:
        IOError: If failed to parse the manifest.

    Returns:
        List[dict]: Manifest parsing results.
    """
    manifest = []
    with jsonlines.open(manifest_path, 'r') as reader:
        for json_data in reader:
            feat_len = json_data["input"][0]["shape"][
                0] if "input" in json_data and "shape" in json_data["input"][
                    0] else 1.0
            token_len = json_data["output"][0]["shape"][
                0] if "output" in json_data and "shape" in json_data["output"][
                    0] else 1.0
            conditions = [
                feat_len >= min_input_len,
                feat_len <= max_input_len,
                token_len >= min_output_len,
                token_len <= max_output_len,
                token_len / feat_len >= min_output_input_ratio,
                token_len / feat_len <= max_output_input_ratio,
            ]
            if all(conditions):
                manifest.append(json_data)
    return manifest


# Tar File read
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])


def parse_tar(file):
    """Parse a tar file to get a tarfile object
    and a map containing tarinfoes
    """
    result = {}
    f = tarfile.open(file)
    for tarinfo in f.getmembers():
        result[tarinfo.name] = tarinfo
    return f, result


def subfile_from_tar(file, local_data=None):
    """Get subfile object from tar.

    tar:tarpath#filename

    It will return a subfile object from tar file
    and cached tar file info for next reading request.
    """
    tarpath, filename = file.split(':', 1)[1].split('#', 1)

    if local_data is None:
        local_data = TarLocalData(tar2info={}, tar2object={})

    assert isinstance(local_data, TarLocalData)

    if 'tar2info' not in local_data.__dict__:
        local_data.tar2info = {}
    if 'tar2object' not in local_data.__dict__:
        local_data.tar2object = {}

    if tarpath not in local_data.tar2info:
        fobj, infos = parse_tar(tarpath)
        local_data.tar2info[tarpath] = infos
        local_data.tar2object[tarpath] = fobj
    else:
        fobj = local_data.tar2object[tarpath]
        infos = local_data.tar2info[tarpath]
    return fobj.extractfile(infos[filename])


def rms_to_db(rms: float):
    """Root Mean Square to dB.

    Args:
        rms ([float]): root mean square

    Returns:
        float: dB
    """
    return 20.0 * math.log10(max(1e-16, rms))


def rms_to_dbfs(rms: float):
    """Root Mean Square to dBFS.
    https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
    Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB.

    dB = dBFS + 3.0103
    dBFS = db - 3.0103
    e.g. 0 dB = -3.0103 dBFS

    Args:
        rms ([float]): root mean square

    Returns:
        float: dBFS
    """
    return rms_to_db(rms) - 3.0103


def max_dbfs(sample_data: np.ndarray):
    """Peak dBFS based on the maximum energy sample.

    Args:
        sample_data ([np.ndarray]): float array, [-1, 1].

    Returns:
        float: dBFS
    """
    # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
    return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))


def mean_dbfs(sample_data):
    """Peak dBFS based on the RMS energy.

    Args:
        sample_data ([np.ndarray]): float array, [-1, 1].

    Returns:
        float: dBFS
    """
    return rms_to_dbfs(
        math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))


def gain_db_to_ratio(gain_db: float):
    """dB to ratio

    Args:
        gain_db (float): gain in dB

    Returns:
        float: scale in amp
    """
    return math.pow(10.0, gain_db / 20.0)


def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
    """Normalize audio to dBFS.

    Args:
        sample_data (np.ndarray): input wave samples, [-1, 1].
        dbfs (float, optional): target dBFS. Defaults to -3.0103.

    Returns:
        np.ndarray: normalized wave
    """
    return np.maximum(
        np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)),
                   1.0), -1.0)


def _load_json_cmvn(json_cmvn_file):
    """ Load the json format cmvn stats file and calculate cmvn

    Args:
        json_cmvn_file: cmvn stats file in json format

    Returns:
        a numpy array of [means, vars]
    """
    with open(json_cmvn_file) as f:
        cmvn_stats = json.load(f)

    means = cmvn_stats['mean_stat']
    variance = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def _load_kaldi_cmvn(kaldi_cmvn_file):
    """ Load the kaldi format cmvn stats file and calculate cmvn

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
                compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of [means, vars]
    """
    means = []
    variance = []
    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary file start with '\0B'
        if fid.read(2) == '\0B':
            logger.error('kaldi cmvn binary file is not supported, please '
                         'recompute it by: compute-cmvn-stats --binary=false '
                         ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()
        assert (arr[0] == '[')
        assert (arr[-2] == '0')
        assert (arr[-1] == ']')
        feat_dim = int((len(arr) - 2 - 2) / 2)
        for i in range(1, feat_dim + 1):
            means.append(float(arr[i]))
        count = float(arr[feat_dim + 1])
        for i in range(feat_dim + 2, 2 * feat_dim + 2):
            variance.append(float(arr[i]))

    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def load_cmvn(cmvn_file: str, filetype: str):
    """load cmvn from file.

    Args:
        cmvn_file (str): cmvn path.
        filetype (str): file type, optional[npz, json, kaldi].

    Raises:
        ValueError: file type not support.

    Returns:
        Tuple[np.ndarray, np.ndarray]: mean, istd
    """
    assert filetype in ['npz', 'json', 'kaldi'], filetype
    filetype = filetype.lower()
    if filetype == "json":
        cmvn = _load_json_cmvn(cmvn_file)
    elif filetype == "kaldi":
        cmvn = _load_kaldi_cmvn(cmvn_file)
    elif filetype == "npz":
        eps = 1e-14
        npzfile = np.load(cmvn_file)
        mean = np.squeeze(npzfile["mean"])
        std = np.squeeze(npzfile["std"])
        istd = 1 / (std + eps)
        cmvn = [mean, istd]
    else:
        raise ValueError(f"cmvn file type no support: {filetype}")
    return cmvn[0], cmvn[1]


def convert_samples_to_float32(samples):
    """Convert sample type to float32.

    Audio sample type is usually integer or float-point.
    Integers will be scaled to [-1, 1] in float32.

    PCM16 -> PCM32
    """
    float32_samples = samples.astype('float32')
    if samples.dtype in np.sctypes['int']:
        bits = np.iinfo(samples.dtype).bits
        float32_samples *= (1. / 2**(bits - 1))
    elif samples.dtype in np.sctypes['float']:
        pass
    else:
        raise TypeError("Unsupported sample type: %s." % samples.dtype)
    return float32_samples


def convert_samples_from_float32(samples, dtype):
    """Convert sample type from float32 to dtype.

    Audio sample type is usually integer or float-point. For integer
    type, float32 will be rescaled from [-1, 1] to the maximum range
    supported by the integer type.

    PCM32 -> PCM16
    """
    dtype = np.dtype(dtype)
    output_samples = samples.copy()
    if dtype in np.sctypes['int']:
        bits = np.iinfo(dtype).bits
        output_samples *= (2**(bits - 1) / 1.)
        min_val = np.iinfo(dtype).min
        max_val = np.iinfo(dtype).max
        output_samples[output_samples > max_val] = max_val
        output_samples[output_samples < min_val] = min_val
    elif samples.dtype in np.sctypes['float']:
        min_val = np.finfo(dtype).min
        max_val = np.finfo(dtype).max
        output_samples[output_samples > max_val] = max_val
        output_samples[output_samples < min_val] = min_val
    else:
        raise TypeError("Unsupported sample type: %s." % samples.dtype)
    return output_samples.astype(dtype)
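A short sketch (not part of the commit) of the dBFS helpers above, peak-normalizing a low-amplitude sine to the default -3.0103 dBFS target; the waveform is illustrative:

    import numpy as np
    from paddlespeech.audio.text.utility import max_dbfs, normalize_audio

    wav = (0.1 * np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)).astype('float32')
    print(max_dbfs(wav))                      # about -23 dBFS for a 0.1-amplitude sine
    normalized = normalize_audio(wav, dbfs=-3.0103)
    print(max_dbfs(normalized))               # close to the -3.0103 dBFS target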
paddlespeech/s2t/transform/__init__.py → paddlespeech/audio/transform/__init__.py (file moved)
paddlespeech/s2t/transform/add_deltas.py → paddlespeech/audio/transform/add_deltas.py (file moved)
paddlespeech/s2t/transform/channel_selector.py → paddlespeech/audio/transform/channel_selector.py (file moved)
paddlespeech/s2t/transform/cmvn.py → paddlespeech/audio/transform/cmvn.py (file moved)
paddlespeech/s2t/transform/functional.py → paddlespeech/audio/transform/functional.py
@@ -14,8 +14,8 @@
 # Modified from espnet(https://github.com/espnet/espnet)
 import inspect

-from paddlespeech.s2t.transform.transform_interface import TransformInterface
+from paddlespeech.audio.transform.transform_interface import TransformInterface
-from paddlespeech.s2t.utils.check_kwargs import check_kwargs
+from paddlespeech.audio.utils.check_kwargs import check_kwargs


 class FuncTrans(TransformInterface):
paddlespeech/s2t/transform/perturb.py → paddlespeech/audio/transform/perturb.py
@@ -17,8 +17,97 @@ import numpy
 import scipy
 import soundfile

-from paddlespeech.s2t.io.reader import SoundHDF5File
+import io
+import os
+
+import h5py
+import numpy as np
+
+
+class SoundHDF5File():
+    """Collecting sound files to a HDF5 file
+
+    >>> f = SoundHDF5File('a.flac.h5', mode='a')
+    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
+    >>> f['id'] = (array, 16000)
+    >>> array, rate = f['id']
+
+    :param: str filepath:
+    :param: str mode:
+    :param: str format: The type used when saving wav. flac, nist, htk, etc.
+    :param: str dtype:
+    """
+
+    def __init__(self, filepath, mode="r+", format=None, dtype="int16",
+                 **kwargs):
+        self.filepath = filepath
+        self.mode = mode
+        self.dtype = dtype
+
+        self.file = h5py.File(filepath, mode, **kwargs)
+        if format is None:
+            # filepath = a.flac.h5 -> format = flac
+            second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
+            format = second_ext[1:]
+        if format.upper() not in soundfile.available_formats():
+            # If not found, flac is selected
+            format = "flac"
+
+        # This format affects only saving
+        self.format = format
+
+    def __repr__(self):
+        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
+            self.filepath, self.mode, self.format, self.dtype)
+
+    def create_dataset(self, name, shape=None, data=None, **kwds):
+        f = io.BytesIO()
+        array, rate = data
+        soundfile.write(f, array, rate, format=self.format)
+        self.file.create_dataset(
+            name, shape=shape, data=np.void(f.getvalue()), **kwds)
+
+    def __setitem__(self, name, data):
+        self.create_dataset(name, data=data)
+
+    def __getitem__(self, key):
+        data = self.file[key][()]
+        f = io.BytesIO(data.tobytes())
+        array, rate = soundfile.read(f, dtype=self.dtype)
+        return array, rate
+
+    def keys(self):
+        return self.file.keys()
+
+    def values(self):
+        for k in self.file:
+            yield self[k]
+
+    def items(self):
+        for k in self.file:
+            yield k, self[k]
+
+    def __iter__(self):
+        return iter(self.file)
+
+    def __contains__(self, item):
+        return item in self.file
+
+    def __len__(self, item):
+        return len(self.file)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.file.close()
+
+    def close(self):
+        self.file.close()
+
+
 class SpeedPerturbation():
     """SpeedPerturbation
@@ -469,3 +558,4 @@ class RIRConvolve():
                 [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
         else:
             return scipy.convolve(x, rir, mode="same")
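A sketch (not part of the diff) mirroring the SoundHDF5File docstring above: store a waveform plus its sample rate into an HDF5 container and read it back. File name, key, and data are illustrative, and the format is inferred from the ".flac.h5" double extension as in __init__:

    import numpy as np
    from paddlespeech.audio.transform.perturb import SoundHDF5File

    with SoundHDF5File('noises.flac.h5', mode='a') as f:
        f['noise-0'] = (np.random.randint(-100, 100, 1600, dtype=np.int16), 16000)
        array, rate = f['noise-0']    # decoded back to an int16 array and 16000 Hz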
paddlespeech/s2t/transform/spec_augment.py → paddlespeech/audio/transform/spec_augment.py
@@ -17,7 +17,7 @@ import random
 import numpy
 from PIL import Image

-from paddlespeech.s2t.transform.functional import FuncTrans
+from .functional import FuncTrans


 def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
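A sketch (not part of the diff) of calling time_warp with the signature shown in the context line above, assuming the espnet-style behavior of warping a 2-D (frames, bins) array; the dummy spectrogram and parameters are illustrative:

    import numpy as np
    from paddlespeech.audio.transform.spec_augment import time_warp

    spec = np.random.randn(200, 80).astype('float32')   # (frames, mel bins)
    warped = time_warp(spec, max_time_warp=5, inplace=False, mode="PIL")
    # warped has the same (frames, bins) shape, with frames stretched/compressed around a random center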
paddlespeech/s2t/transform/spectrogram.py → paddlespeech/audio/transform/spectrogram.py
@@ -17,7 +17,7 @@ import numpy as np
 import paddle
 from python_speech_features import logfbank

-import paddlespeech.audio.compliance.kaldi as kaldi
+from ..compliance import kaldi


 def stft(x,
paddlespeech/s2t/transform/transform_interface.py → paddlespeech/audio/transform/transform_interface.py (file moved)
paddlespeech/s2t/transform/transformation.py → paddlespeech/audio/transform/transformation.py
@@ -22,32 +22,32 @@ from inspect import signature
 import yaml

-from paddlespeech.s2t.utils.dynamic_import import dynamic_import
+from ..utils.dynamic_import import dynamic_import

 import_alias = dict(
-    identity="paddlespeech.s2t.transform.transform_interface:Identity",
+    identity="paddlespeech.audio.transform.transform_interface:Identity",
-    time_warp="paddlespeech.s2t.transform.spec_augment:TimeWarp",
+    time_warp="paddlespeech.audio.transform.spec_augment:TimeWarp",
-    time_mask="paddlespeech.s2t.transform.spec_augment:TimeMask",
+    time_mask="paddlespeech.audio.transform.spec_augment:TimeMask",
-    freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask",
+    freq_mask="paddlespeech.audio.transform.spec_augment:FreqMask",
-    spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment",
+    spec_augment="paddlespeech.audio.transform.spec_augment:SpecAugment",
-    speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation",
+    speed_perturbation="paddlespeech.audio.transform.perturb:SpeedPerturbation",
-    speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox",
+    speed_perturbation_sox="paddlespeech.audio.transform.perturb:SpeedPerturbationSox",
-    volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation",
+    volume_perturbation="paddlespeech.audio.transform.perturb:VolumePerturbation",
-    noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection",
+    noise_injection="paddlespeech.audio.transform.perturb:NoiseInjection",
-    bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation",
+    bandpass_perturbation="paddlespeech.audio.transform.perturb:BandpassPerturbation",
-    rir_convolve="paddlespeech.s2t.transform.perturb:RIRConvolve",
+    rir_convolve="paddlespeech.audio.transform.perturb:RIRConvolve",
-    delta="paddlespeech.s2t.transform.add_deltas:AddDeltas",
+    delta="paddlespeech.audio.transform.add_deltas:AddDeltas",
-    cmvn="paddlespeech.s2t.transform.cmvn:CMVN",
+    cmvn="paddlespeech.audio.transform.cmvn:CMVN",
-    utterance_cmvn="paddlespeech.s2t.transform.cmvn:UtteranceCMVN",
+    utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
-    fbank="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogram",
+    fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
-    spectrogram="paddlespeech.s2t.transform.spectrogram:Spectrogram",
+    spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
-    stft="paddlespeech.s2t.transform.spectrogram:Stft",
+    stft="paddlespeech.audio.transform.spectrogram:Stft",
-    istft="paddlespeech.s2t.transform.spectrogram:IStft",
+    istft="paddlespeech.audio.transform.spectrogram:IStft",
-    stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram",
+    stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",
-    wpe="paddlespeech.s2t.transform.wpe:WPE",
+    wpe="paddlespeech.audio.transform.wpe:WPE",
-    channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector",
+    channel_selector="paddlespeech.audio.transform.channel_selector:ChannelSelector",
-    fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi",
+    fbank_kaldi="paddlespeech.audio.transform.spectrogram:LogMelSpectrogramKaldi",
-    cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN")
+    cmvn_json="paddlespeech.audio.transform.cmvn:GlobalCMVN")


 class Transformation():
paddlespeech/s2t/transform/wpe.py → paddlespeech/audio/transform/wpe.py (file moved)
paddlespeech/audio/utils/check_kwargs.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import inspect


def check_kwargs(func, kwargs, name=None):
    """check kwargs are valid for func

    If kwargs are invalid, raise TypeError as same as python default
    :param function func: function to be validated
    :param dict kwargs: keyword arguments for func
    :param str name: name used in TypeError (default is func name)
    """
    try:
        params = inspect.signature(func).parameters
    except ValueError:
        return
    if name is None:
        name = func.__name__
    for k in kwargs.keys():
        if k not in params:
            raise TypeError(
                f"{name}() got an unexpected keyword argument '{k}'")
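A quick sketch (not part of the commit) of check_kwargs; the sample function and keyword names are illustrative:

    from paddlespeech.audio.utils.check_kwargs import check_kwargs

    def fbank(n_mels=80, n_fft=512):
        pass

    check_kwargs(fbank, {"n_mels": 40})        # valid keyword, no error
    try:
        check_kwargs(fbank, {"n_mel": 40})     # invalid keyword
    except TypeError as e:
        print(e)    # fbank() got an unexpected keyword argument 'n_mel'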
paddlespeech/audio/utils/dynamic_import.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import importlib

__all__ = ["dynamic_import"]


def dynamic_import(import_path, alias=dict()):
    """dynamic import module and class

    :param str import_path: syntax 'module_name:class_name'
        e.g., 'paddlespeech.s2t.models.u2:U2Model'
    :param dict alias: shortcut for registered class
    :return: imported class
    """
    if import_path not in alias and ":" not in import_path:
        raise ValueError(
            "import_path should be one of {} or "
            'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
            "{}".format(set(alias), import_path))
    if ":" not in import_path:
        import_path = alias[import_path]

    module_name, objname = import_path.split(":")
    m = importlib.import_module(module_name)
    return getattr(m, objname)
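A sketch (not part of the commit): a class can be resolved either from an explicit "module:Class" path or through an alias table such as import_alias in transformation.py; the alias dict below reuses the cmvn mapping from that table:

    from paddlespeech.audio.utils.dynamic_import import dynamic_import

    alias = {"cmvn": "paddlespeech.audio.transform.cmvn:CMVN"}
    cls1 = dynamic_import("paddlespeech.audio.transform.cmvn:CMVN")
    cls2 = dynamic_import("cmvn", alias)
    assert cls1 is cls2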
paddlespeech/audio/utils/log.py
@@ -65,6 +65,7 @@ class Logger(object):
     def __init__(self, name: str=None):
         name = 'PaddleAudio' if not name else name
+        self.name = name
         self.logger = logging.getLogger(name)

         for key, conf in log_config.items():
@@ -101,7 +102,7 @@ class Logger(object):
         if not self.is_enable:
             return

-        self.logger.log(log_level, msg)
+        self.logger.log(log_level, self.name + " | " + msg)

     @contextlib.contextmanager
     def use_terminator(self, terminator: str):
paddlespeech/audio/utils/tensor_utils.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unility functions for Transformer."""
from
typing
import
List
from
typing
import
Tuple
import
paddle
from
.log
import
Logger
__all__
=
[
"pad_sequence"
,
"add_sos_eos"
,
"th_accuracy"
,
"has_tensor"
]
logger
=
Logger
(
__name__
)
def
has_tensor
(
val
):
if
isinstance
(
val
,
(
list
,
tuple
)):
for
item
in
val
:
if
has_tensor
(
item
):
return
True
elif
isinstance
(
val
,
dict
):
for
k
,
v
in
val
.
items
():
print
(
k
)
if
has_tensor
(
v
):
return
True
else
:
return
paddle
.
is_tensor
(
val
)
def
pad_sequence
(
sequences
:
List
[
paddle
.
Tensor
],
batch_first
:
bool
=
False
,
padding_value
:
float
=
0.0
)
->
paddle
.
Tensor
:
r
"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is list of
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> from paddle.nn.utils.rnn import pad_sequence
>>> a = paddle.ones(25, 300)
>>> b = paddle.ones(22, 300)
>>> c = paddle.ones(15, 300)
>>> pad_sequence([a, b, c]).shape
paddle.Tensor([25, 3, 300])
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
Args:
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns:
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Tensor of size ``B x T x *`` otherwise
"""
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size
=
paddle
.
shape
(
sequences
[
0
])
# (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:]
trailing_dims
=
tuple
(
max_size
[
1
:].
numpy
().
tolist
())
if
sequences
[
0
].
ndim
>=
2
else
()
max_len
=
max
([
s
.
shape
[
0
]
for
s
in
sequences
])
if
batch_first
:
out_dims
=
(
len
(
sequences
),
max_len
)
+
trailing_dims
else
:
out_dims
=
(
max_len
,
len
(
sequences
))
+
trailing_dims
out_tensor
=
paddle
.
full
(
out_dims
,
padding_value
,
sequences
[
0
].
dtype
)
for
i
,
tensor
in
enumerate
(
sequences
):
length
=
tensor
.
shape
[
0
]
# use index notation to prevent duplicate references to the tensor
if
batch_first
:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor
if
length
!=
0
:
out_tensor
[
i
,
:
length
]
=
tensor
else
:
out_tensor
[
i
,
length
]
=
tensor
else
:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[:length, i, ...] = tensor
if
length
!=
0
:
out_tensor
[:
length
,
i
]
=
tensor
else
:
out_tensor
[
length
,
i
]
=
tensor
return
out_tensor
def
add_sos_eos
(
ys_pad
:
paddle
.
Tensor
,
sos
:
int
,
eos
:
int
,
ignore_id
:
int
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
eos (int): index of <eeos>
ignore_id (int): index of padding
Returns:
ys_in (paddle.Tensor) : (B, Lmax + 1)
ys_out (paddle.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=paddle.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B
=
ys_pad
.
shape
[
0
]
_sos
=
paddle
.
ones
([
B
,
1
],
dtype
=
ys_pad
.
dtype
)
*
sos
_eos
=
paddle
.
ones
([
B
,
1
],
dtype
=
ys_pad
.
dtype
)
*
eos
ys_in
=
paddle
.
cat
([
_sos
,
ys_pad
],
dim
=
1
)
mask_pad
=
(
ys_in
==
ignore_id
)
ys_in
=
ys_in
.
masked_fill
(
mask_pad
,
eos
)
ys_out
=
paddle
.
cat
([
ys_pad
,
_eos
],
dim
=
1
)
ys_out
=
ys_out
.
masked_fill
(
mask_pad
,
eos
)
mask_eos
=
(
ys_out
==
ignore_id
)
ys_out
=
ys_out
.
masked_fill
(
mask_eos
,
eos
)
ys_out
=
ys_out
.
masked_fill
(
mask_pad
,
ignore_id
)
return
ys_in
,
ys_out
def
th_accuracy
(
pad_outputs
:
paddle
.
Tensor
,
pad_targets
:
paddle
.
Tensor
,
ignore_label
:
int
)
->
float
:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred
=
pad_outputs
.
view
(
pad_targets
.
shape
[
0
],
pad_targets
.
shape
[
1
],
pad_outputs
.
shape
[
1
]).
argmax
(
2
)
mask
=
pad_targets
!=
ignore_label
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator
=
(
pad_pred
.
masked_select
(
mask
)
==
pad_targets
.
masked_select
(
mask
))
numerator
=
paddle
.
sum
(
numerator
.
type_as
(
pad_targets
))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator
=
paddle
.
sum
(
mask
.
type_as
(
pad_targets
))
return
float
(
numerator
)
/
float
(
denominator
)
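A sketch (not part of the commit) echoing the pad_sequence docstring example above, batching three variable-length tensors; it assumes a working paddle installation:

    import paddle
    from paddlespeech.audio.utils.tensor_utils import pad_sequence

    a = paddle.ones([25, 300])
    b = paddle.ones([22, 300])
    c = paddle.ones([15, 300])
    padded = pad_sequence([a, b, c], batch_first=True, padding_value=0.0)
    print(padded.shape)   # [3, 25, 300]; shorter sequences are zero-padded to length 25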
paddlespeech/cli/asr/infer.py
@@ -33,8 +33,8 @@ from ..log import logger
 from ..utils import CLI_TIMER
 from ..utils import stats_wrapper
 from ..utils import timer_register
+from paddlespeech.audio.transform.transformation import Transformation
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
-from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.utility import UpdateConfig

 __all__ = ['ASRExecutor']
paddlespeech/s2t/exps/deepspeech2/model.py
@@ -23,7 +23,7 @@ import paddle
 from paddle import distributed as dist
 from paddle import inference

-from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.audio.text.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import BatchDataLoader
 from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel
 from paddlespeech.s2t.models.ds2 import DeepSpeech2Model
paddlespeech/s2t/exps/u2/bin/test_wav.py
@@ -20,10 +20,10 @@ import paddle
 import soundfile
 from yacs.config import CfgNode

+from paddlespeech.audio.transform.transformation import Transformation
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.models.u2 import U2Model
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.s2t.utils.utility import UpdateConfig

 logger = Log(__name__).getlog()
paddlespeech/s2t/exps/u2/model.py
@@ -26,6 +26,8 @@ from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import BatchDataLoader
+from paddlespeech.s2t.io.dataloader import StreamDataLoader
+from paddlespeech.s2t.io.dataloader import DataLoaderFactory
 from paddlespeech.s2t.models.u2 import U2Model
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
@@ -106,7 +108,8 @@ class U2Trainer(Trainer):
     @paddle.no_grad()
     def valid(self):
         self.model.eval()
-        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
@@ -132,7 +135,8 @@ class U2Trainer(Trainer):
                 msg = f"Valid: Rank: {dist.get_rank()}, "
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
-                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                if not self.use_streamdata:
+                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in valid_dump.items())
                 logger.info(msg)
@@ -152,7 +156,8 @@ class U2Trainer(Trainer):
         self.before_train()
-        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
@@ -170,7 +175,8 @@ class U2Trainer(Trainer):
                     self.train_batch(batch_index, batch, msg)
                     self.after_train_batch()
                     report('iter', batch_index + 1)
-                    report('total', len(self.train_loader))
+                    if not self.use_streamdata:
+                        report('total', len(self.train_loader))
                     report('reader_cost', dataload_time)
                     observation['batch_cost'] = observation[
                         'reader_cost'] + observation['step_cost']
@@ -191,7 +197,6 @@ class U2Trainer(Trainer):
             except Exception as e:
                 logger.error(e)
                 raise e
             with Timer("Eval Time Cost: {}"):
                 total_loss, num_seen_utts = self.valid()
                 if dist.get_world_size() > 1:
@@ -218,92 +223,16 @@ class U2Trainer(Trainer):
     def setup_dataloader(self):
         config = self.config.clone()
+        self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
-            # train/valid dataset, return token ids
-            self.train_loader = BatchDataLoader(
-                json_file=config.train_manifest, train_mode=True,
-                sortagrad=config.sortagrad, batch_size=config.batch_size,
-                maxlen_in=config.maxlen_in, maxlen_out=config.maxlen_out,
-                minibatches=config.minibatches, mini_batch_size=self.args.ngpu,
-                batch_count=config.batch_count, batch_bins=config.batch_bins,
-                batch_frames_in=config.batch_frames_in,
-                batch_frames_out=config.batch_frames_out,
-                batch_frames_inout=config.batch_frames_inout,
-                preprocess_conf=config.preprocess_config,
-                n_iter_processes=config.num_workers, subsampling_factor=1,
-                num_encs=1, dist_sampler=config.get('dist_sampler', False),
-                shortest_first=False)
-            self.valid_loader = BatchDataLoader(
-                json_file=config.dev_manifest, train_mode=False,
-                sortagrad=False, batch_size=config.batch_size,
-                maxlen_in=float('inf'), maxlen_out=float('inf'),
-                minibatches=0, mini_batch_size=self.args.ngpu,
-                batch_count='auto', batch_bins=0, batch_frames_in=0,
-                batch_frames_out=0, batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,
-                n_iter_processes=config.num_workers, subsampling_factor=1,
-                num_encs=1, dist_sampler=config.get('dist_sampler', False),
-                shortest_first=False)
+            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get(
                 'decode_batch_size', 1)
-            # test dataset, return raw text
-            self.test_loader = BatchDataLoader(
-                json_file=config.test_manifest, train_mode=False,
-                sortagrad=False, batch_size=decode_batch_size,
-                maxlen_in=float('inf'), maxlen_out=float('inf'),
-                minibatches=0, mini_batch_size=1, batch_count='auto',
-                batch_bins=0, batch_frames_in=0, batch_frames_out=0,
-                batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,
-                n_iter_processes=1, subsampling_factor=1, num_encs=1)
-            self.align_loader = BatchDataLoader(
-                json_file=config.test_manifest, train_mode=False,
-                sortagrad=False, batch_size=decode_batch_size,
-                maxlen_in=float('inf'), maxlen_out=float('inf'),
-                minibatches=0, mini_batch_size=1, batch_count='auto',
-                batch_bins=0, batch_frames_in=0, batch_frames_out=0,
-                batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,
-                n_iter_processes=1, subsampling_factor=1, num_encs=1)
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
+            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
            logger.info("Setup test/align Dataloader!")

     def setup_model(self):
@@ -452,7 +381,8 @@ class U2Tester(U2Trainer):
     def test(self):
         assert self.args.result_file
         self.model.eval()
-        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
         stride_ms = self.config.stride_ms
         error_rate_type = None
paddlespeech/s2t/exps/u2_kaldi/model.py
浏览文件 @
d1a25f6c
...
@@ -25,7 +25,7 @@ from paddle import distributed as dist
...
@@ -25,7 +25,7 @@ from paddle import distributed as dist
from
paddlespeech.s2t.frontend.featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.utility
import
load_dict
from
paddlespeech.s2t.frontend.utility
import
load_dict
from
paddlespeech.s2t.io.dataloader
import
BatchDataLoader
from
paddlespeech.s2t.io.dataloader
import
DataLoaderFactory
from
paddlespeech.s2t.models.u2
import
U2Model
from
paddlespeech.s2t.models.u2
import
U2Model
from
paddlespeech.s2t.training.optimizer
import
OptimizerFactory
from
paddlespeech.s2t.training.optimizer
import
OptimizerFactory
from
paddlespeech.s2t.training.scheduler
import
LRSchedulerFactory
from
paddlespeech.s2t.training.scheduler
import
LRSchedulerFactory
...
@@ -104,7 +104,8 @@ class U2Trainer(Trainer):
...
@@ -104,7 +104,8 @@ class U2Trainer(Trainer):
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
valid
(
self
):
def
valid
(
self
):
self
.
model
.
eval
()
self
.
model
.
eval
()
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
valid_losses
=
defaultdict
(
list
)
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
num_seen_utts
=
1
total_loss
=
0.0
total_loss
=
0.0
...
@@ -131,7 +132,8 @@ class U2Trainer(Trainer):
...
@@ -131,7 +132,8 @@ class U2Trainer(Trainer):
msg
=
f
"Valid: Rank:
{
dist
.
get_rank
()
}
, "
msg
=
f
"Valid: Rank:
{
dist
.
get_rank
()
}
, "
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
if
not
self
.
use_streamdata
:
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
', '
.
join
(
'{}: {:>.6f}'
.
format
(
k
,
v
)
msg
+=
', '
.
join
(
'{}: {:>.6f}'
.
format
(
k
,
v
)
for
k
,
v
in
valid_dump
.
items
())
for
k
,
v
in
valid_dump
.
items
())
logger
.
info
(
msg
)
logger
.
info
(
msg
)
...
@@ -150,8 +152,8 @@ class U2Trainer(Trainer):
...
@@ -150,8 +152,8 @@ class U2Trainer(Trainer):
# paddle.jit.save(script_model, script_model_path)
# paddle.jit.save(script_model, script_model_path)
self
.
before_train
()
self
.
before_train
()
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
n_epoch
:
while
self
.
epoch
<
self
.
config
.
n_epoch
:
with
Timer
(
"Epoch-Train Time Cost: {}"
):
with
Timer
(
"Epoch-Train Time Cost: {}"
):
self
.
model
.
train
()
self
.
model
.
train
()
...
@@ -162,7 +164,8 @@ class U2Trainer(Trainer):
...
@@ -162,7 +164,8 @@ class U2Trainer(Trainer):
msg
=
"Train: Rank: {}, "
.
format
(
dist
.
get_rank
())
msg
=
"Train: Rank: {}, "
.
format
(
dist
.
get_rank
())
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
msg
+=
"batch : {}/{}, "
.
format
(
batch_index
+
1
,
if
not
self
.
use_streamdata
:
msg
+=
"batch : {}/{}, "
.
format
(
batch_index
+
1
,
len
(
self
.
train_loader
))
len
(
self
.
train_loader
))
msg
+=
"lr: {:>.8f}, "
.
format
(
self
.
lr_scheduler
())
msg
+=
"lr: {:>.8f}, "
.
format
(
self
.
lr_scheduler
())
msg
+=
"data time: {:>.3f}s, "
.
format
(
dataload_time
)
msg
+=
"data time: {:>.3f}s, "
.
format
(
dataload_time
)
...
@@ -198,87 +201,23 @@ class U2Trainer(Trainer):
             self.new_epoch()
 
     def setup_dataloader(self):
         config = self.config.clone()
+        self.use_streamdata = config.get("use_stream_data", False)
         # train/valid dataset, return token ids
         if self.train:
-            self.train_loader = BatchDataLoader(
-                json_file=config.train_manifest,
-                train_mode=True,
-                sortagrad=False,
-                batch_size=config.batch_size,
-                maxlen_in=float('inf'),
-                maxlen_out=float('inf'),
-                minibatches=0,
-                mini_batch_size=self.args.ngpu,
-                batch_count='auto',
-                batch_bins=0,
-                batch_frames_in=0,
-                batch_frames_out=0,
-                batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,
-                n_iter_processes=config.num_workers,
-                subsampling_factor=1,
-                num_encs=1)
-
-            self.valid_loader = BatchDataLoader(
-                json_file=config.dev_manifest,
-                train_mode=False,
-                sortagrad=False,
-                batch_size=config.batch_size,
-                maxlen_in=float('inf'),
-                maxlen_out=float('inf'),
-                minibatches=0,
-                mini_batch_size=self.args.ngpu,
-                batch_count='auto',
-                batch_bins=0,
-                batch_frames_in=0,
-                batch_frames_out=0,
-                batch_frames_inout=0,
-                preprocess_conf=None,
-                n_iter_processes=config.num_workers,
-                subsampling_factor=1,
-                num_encs=1)
+            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
+            config = self.config.clone()
+            config['preprocess_config'] = None
+            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+            logger.info("Setup train/valid Dataloader!")
         else:
-            decode_batch_size = config.get('decode', dict()).get(
-                'decode_batch_size', 1)
-            # test dataset, return raw text
-            self.test_loader = BatchDataLoader(
-                json_file=config.test_manifest,
-                train_mode=False,
-                sortagrad=False,
-                batch_size=decode_batch_size,
-                maxlen_in=float('inf'),
-                maxlen_out=float('inf'),
-                minibatches=0,
-                mini_batch_size=1,
-                batch_count='auto',
-                batch_bins=0,
-                batch_frames_in=0,
-                batch_frames_out=0,
-                batch_frames_inout=0,
-                preprocess_conf=None,
-                n_iter_processes=1,
-                subsampling_factor=1,
-                num_encs=1)
-
-            self.align_loader = BatchDataLoader(
-                json_file=config.test_manifest,
-                train_mode=False,
-                sortagrad=False,
-                batch_size=decode_batch_size,
-                maxlen_in=float('inf'),
-                maxlen_out=float('inf'),
-                minibatches=0,
-                mini_batch_size=1,
-                batch_count='auto',
-                batch_bins=0,
-                batch_frames_in=0,
-                batch_frames_out=0,
-                batch_frames_inout=0,
-                preprocess_conf=None,
-                n_iter_processes=1,
-                subsampling_factor=1,
-                num_encs=1)
-            logger.info("Setup train/valid/test/align Dataloader!")
+            config = self.config.clone()
+            config['preprocess_config'] = None
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
+            config = self.config.clone()
+            config['preprocess_config'] = None
+            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
+            logger.info("Setup test/align Dataloader!")
 
     def setup_model(self):
         config = self.config
...
@@ -406,7 +345,8 @@ class U2Tester(U2Trainer):
     def test(self):
         assert self.args.result_file
         self.model.eval()
-        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
         stride_ms = self.config.stride_ms
         error_rate_type = None
...
paddlespeech/s2t/exps/u2_st/model.py  View file @ d1a25f6c
...
@@ -25,7 +25,7 @@ import paddle
 from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
-from paddlespeech.s2t.io.dataloader import BatchDataLoader
+from paddlespeech.s2t.io.dataloader import DataLoaderFactory
 from paddlespeech.s2t.models.u2_st import U2STModel
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
...
@@ -120,7 +120,8 @@ class U2STTrainer(Trainer):
     @paddle.no_grad()
     def valid(self):
         self.model.eval()
-        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
...
@@ -153,7 +154,8 @@ class U2STTrainer(Trainer):
                 msg = f"Valid: Rank: {dist.get_rank()}, "
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
-                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                if not self.use_streamdata:
+                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in valid_dump.items())
                 logger.info(msg)
...
@@ -172,8 +174,8 @@ class U2STTrainer(Trainer):
             # paddle.jit.save(script_model, script_model_path)
 
         self.before_train()
-
-        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
...
@@ -191,7 +193,8 @@ class U2STTrainer(Trainer):
                     self.train_batch(batch_index, batch, msg)
                     self.after_train_batch()
                     report('iter', batch_index + 1)
-                    report('total', len(self.train_loader))
+                    if not self.use_streamdata:
+                        report('total', len(self.train_loader))
                     report('reader_cost', dataload_time)
                     observation['batch_cost'] = observation[
                         'reader_cost'] + observation['step_cost']
...
@@ -241,79 +244,18 @@ class U2STTrainer(Trainer):
             load_transcript = True if config.model_conf.asr_weight > 0 else False
 
+        config = self.config.clone()
+        config['load_transcript'] = load_transcript
+        self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
             # train/valid dataset, return token ids
-            self.train_loader = BatchDataLoader(
-                json_file=config.train_manifest,
-                train_mode=True,
-                sortagrad=False,
-                batch_size=config.batch_size,
-                maxlen_in=config.maxlen_in,
-                maxlen_out=config.maxlen_out,
-                minibatches=0,
-                mini_batch_size=1,
-                batch_count='auto',
-                batch_bins=0,
-                batch_frames_in=0,
-                batch_frames_out=0,
-                batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,  # aug will be off when train_mode=False
-                n_iter_processes=config.num_workers,
-                subsampling_factor=1,
-                load_aux_output=load_transcript,
-                num_encs=1,
-                dist_sampler=True)
-
-            self.valid_loader = BatchDataLoader(
-                json_file=config.dev_manifest,
-                train_mode=False,
-                sortagrad=False,
-                batch_size=config.batch_size,
-                maxlen_in=float('inf'),
-                maxlen_out=float('inf'),
-                minibatches=0,
-                mini_batch_size=1,
-                batch_count='auto',
-                batch_bins=0,
-                batch_frames_in=0,
-                batch_frames_out=0,
-                batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,  # aug will be off when train_mode=False
-                n_iter_processes=config.num_workers,
-                subsampling_factor=1,
-                load_aux_output=load_transcript,
-                num_encs=1,
-                dist_sampler=False)
+            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
             # test dataset, return raw text
-            decode_batch_size = config.get('decode', dict()).get(
-                'decode_batch_size', 1)
-            self.test_loader = BatchDataLoader(
-                json_file=config.test_manifest,
-                train_mode=False,
-                sortagrad=False,
-                batch_size=decode_batch_size,
-                maxlen_in=float('inf'),
-                maxlen_out=float('inf'),
-                minibatches=0,
-                mini_batch_size=1,
-                batch_count='auto',
-                batch_bins=0,
-                batch_frames_in=0,
-                batch_frames_out=0,
-                batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,  # aug will be off when train_mode=False
-                n_iter_processes=config.num_workers,
-                subsampling_factor=1,
-                num_encs=1,
-                dist_sampler=False)
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
             logger.info("Setup test Dataloader!")
 
     def setup_model(self):
         config = self.config
         model_conf = config
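Worth noting for the ST path: the load_transcript flag no longer goes straight into a BatchDataLoader argument; it is stashed on the cloned config here and read back inside DataLoaderFactory as config.get('load_transcript', None) when the non-streaming branch builds its loader. A hedged sketch of that hand-off, with invented values:

    # Illustrative only; mirrors how the flag travels from U2STTrainer.setup_dataloader
    # to DataLoaderFactory.get_dataloader in this PR.
    from yacs.config import CfgNode

    config = CfgNode(new_allowed=True)
    config.model_conf = CfgNode(new_allowed=True)
    config.model_conf.asr_weight = 0.5          # assumed value; any weight > 0 enables transcripts

    load_transcript = config.model_conf.asr_weight > 0
    config['load_transcript'] = load_transcript             # producer side (trainer)

    load_aux_output = config.get('load_transcript', None)   # consumer side (factory)
    assert load_aux_output is True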
...
@@ -468,7 +410,8 @@ class U2STTester(U2STTrainer):
     def test(self):
         assert self.args.result_file
         self.model.eval()
-        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
         decode_cfg = self.config.decode
         bleu_func = bleu_score.char_bleu if decode_cfg.error_rate_type == 'char-bleu' else bleu_score.bleu
...
paddlespeech/s2t/io/dataloader.py  View file @ d1a25f6c
...
@@ -18,6 +18,7 @@ from typing import Text
 import jsonlines
 import numpy as np
+import paddle
 from paddle.io import BatchSampler
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
...
@@ -28,7 +29,11 @@ from paddlespeech.s2t.io.dataset import TransformDataset
 from paddlespeech.s2t.io.reader import LoadInputsAndTargets
 from paddlespeech.s2t.utils.log import Log
 
-__all__ = ["BatchDataLoader"]
+import paddlespeech.audio.streamdata as streamdata
+from paddlespeech.audio.text.text_featurizer import TextFeaturizer
+from yacs.config import CfgNode
+
+__all__ = ["BatchDataLoader", "StreamDataLoader"]
 
 logger = Log(__name__).getlog()
...
@@ -56,6 +61,136 @@ def batch_collate(x):
     """
     return x[0]
 
+
+def read_preprocess_cfg(preprocess_conf_file):
+    augment_conf = dict()
+    preprocess_cfg = CfgNode(new_allowed=True)
+    preprocess_cfg.merge_from_file(preprocess_conf_file)
+    for idx, process in enumerate(preprocess_cfg["process"]):
+        opts = dict(process)
+        process_type = opts.pop("type")
+        if process_type == 'time_warp':
+            augment_conf['max_w'] = process['max_time_warp']
+            augment_conf['w_inplace'] = process['inplace']
+            augment_conf['w_mode'] = process['mode']
+        if process_type == 'freq_mask':
+            augment_conf['max_f'] = process['F']
+            augment_conf['num_f_mask'] = process['n_mask']
+            augment_conf['f_inplace'] = process['inplace']
+            augment_conf['f_replace_with_zero'] = process['replace_with_zero']
+        if process_type == 'time_mask':
+            augment_conf['max_t'] = process['T']
+            augment_conf['num_t_mask'] = process['n_mask']
+            augment_conf['t_inplace'] = process['inplace']
+            augment_conf['t_replace_with_zero'] = process['replace_with_zero']
+    return augment_conf
+
+
+class StreamDataLoader():
+    def __init__(self,
+                 manifest_file: str,
+                 train_mode: bool,
+                 unit_type: str='char',
+                 batch_size: int=0,
+                 preprocess_conf=None,
+                 num_mel_bins=80,
+                 frame_length=25,
+                 frame_shift=10,
+                 dither=0.0,
+                 minlen_in: float=0.0,
+                 maxlen_in: float=float('inf'),
+                 minlen_out: float=0.0,
+                 maxlen_out: float=float('inf'),
+                 resample_rate: int=16000,
+                 shuffle_size: int=10000,
+                 sort_size: int=1000,
+                 n_iter_processes: int=1,
+                 prefetch_factor: int=2,
+                 dist_sampler: bool=False,
+                 cmvn_file="data/mean_std.json",
+                 vocab_filepath='data/lang_char/vocab.txt'):
+        self.manifest_file = manifest_file
+        self.train_model = train_mode
+        self.batch_size = batch_size
+        self.prefetch_factor = prefetch_factor
+        self.dist_sampler = dist_sampler
+        self.n_iter_processes = n_iter_processes
+
+        text_featurizer = TextFeaturizer(unit_type, vocab_filepath)
+        symbol_table = text_featurizer.vocab_dict
+        self.feat_dim = num_mel_bins
+        self.vocab_size = text_featurizer.vocab_size
+
+        augment_conf = read_preprocess_cfg(preprocess_conf)
+
+        # The list of shards
+        shardlist = []
+        with open(manifest_file, "r") as f:
+            for line in f.readlines():
+                shardlist.append(line.strip())
+
+        world_size = 1
+        try:
+            world_size = paddle.distributed.get_world_size()
+        except Exception as e:
+            logger.warning(e)
+            logger.warning(
+                "can not get world_size using paddle.distributed.get_world_size(), use world_size=1"
+            )
+        assert len(shardlist) >= world_size, \
+            "the length of shard list should >= number of gpus/xpus/..."
+
+        update_n_iter_processes = int(
+            max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0))
+        logger.info(f"update_n_iter_processes {update_n_iter_processes}")
+        if update_n_iter_processes != self.n_iter_processes:
+            self.n_iter_processes = update_n_iter_processes
+            logger.info(f"change num_workers to {self.n_iter_processes}")
+
+        if self.dist_sampler:
+            base_dataset = streamdata.DataPipeline(
+                streamdata.SimpleShardList(shardlist),
+                streamdata.split_by_node
+                if train_mode else streamdata.placeholder(),
+                streamdata.split_by_worker,
+                streamdata.tarfile_to_samples(streamdata.reraise_exception))
+        else:
+            base_dataset = streamdata.DataPipeline(
+                streamdata.SimpleShardList(shardlist),
+                streamdata.split_by_worker,
+                streamdata.tarfile_to_samples(streamdata.reraise_exception))
+
+        self.dataset = base_dataset.append_list(
+            streamdata.audio_tokenize(symbol_table),
+            streamdata.audio_data_filter(
+                frame_shift=frame_shift,
+                max_length=maxlen_in,
+                min_length=minlen_in,
+                token_max_length=maxlen_out,
+                token_min_length=minlen_out),
+            streamdata.audio_resample(resample_rate=resample_rate),
+            streamdata.audio_compute_fbank(
+                num_mel_bins=num_mel_bins,
+                frame_length=frame_length,
+                frame_shift=frame_shift,
+                dither=dither),
+            streamdata.audio_spec_aug(**augment_conf)
+            if train_mode else streamdata.placeholder(),  # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80
+            streamdata.shuffle(shuffle_size),
+            streamdata.sort(sort_size=sort_size),
+            streamdata.batched(batch_size),
+            streamdata.audio_padding(),
+            streamdata.audio_cmvn(cmvn_file))
+
+        if paddle.__version__ >= '2.3.2':
+            self.loader = streamdata.WebLoader(
+                self.dataset,
+                num_workers=self.n_iter_processes,
+                prefetch_factor=self.prefetch_factor,
+                batch_size=None)
+        else:
+            self.loader = streamdata.WebLoader(
+                self.dataset,
+                num_workers=self.n_iter_processes,
+                batch_size=None)
+
+    def __iter__(self):
+        return self.loader.__iter__()
+
+    def __call__(self):
+        return self.__iter__()
+
+    def __len__(self):
+        logger.info(
+            "Stream dataloader does not support calculating the length of the dataset"
+        )
+        return -1
+
+
 class BatchDataLoader():
     def __init__(self,
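For orientation, here is a hedged usage sketch of the new StreamDataLoader outside the trainer. Every path below is an assumption (the manifest is expected to list one tar shard per line, in the WenetSpeech-style shard format this PR's data preparation produces), and the batch structure is whatever audio_padding()/audio_cmvn() emit, so it is only inspected, not unpacked:

    # Illustrative only: construct the streaming loader directly and pull one batch.
    from paddlespeech.s2t.io.dataloader import StreamDataLoader

    loader = StreamDataLoader(
        manifest_file="data/train/data.list",    # assumed: one tar shard path per line
        train_mode=True,
        unit_type='char',
        batch_size=32,
        preprocess_conf="conf/preprocess.yaml",  # parsed by read_preprocess_cfg() above
        num_mel_bins=80,
        frame_length=25,
        frame_shift=10,
        cmvn_file="data/mean_std.json",
        vocab_filepath="data/lang_char/vocab.txt")

    print(loader.feat_dim, loader.vocab_size)    # exposed for model construction
    for batch in loader:                         # batches come from the webdataset pipeline
        print(type(batch))
        break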
...
@@ -199,3 +334,119 @@ class BatchDataLoader():
         echo += f"shortest_first: {self.shortest_first}, "
         echo += f"file: {self.json_file}"
         return echo
+
+
+class DataLoaderFactory():
+    @staticmethod
+    def get_dataloader(mode: str, config, args):
+        config = config.clone()
+        use_streamdata = config.get("use_stream_data", False)
+        if use_streamdata:
+            if mode == 'train':
+                config['manifest'] = config.train_manifest
+                config['train_mode'] = True
+            elif mode == 'valid':
+                config['manifest'] = config.dev_manifest
+                config['train_mode'] = False
+            elif mode == 'test' or mode == 'align':
+                config['manifest'] = config.test_manifest
+                config['train_mode'] = False
+                config['dither'] = 0.0
+                config['minlen_in'] = 0.0
+                config['maxlen_in'] = float('inf')
+                config['minlen_out'] = 0
+                config['maxlen_out'] = float('inf')
+                config['dist_sampler'] = False
+            else:
+                raise KeyError(
+                    "not valid mode type!!, please input one of 'train, valid, test, align'"
+                )
+            return StreamDataLoader(
+                manifest_file=config.manifest,
+                train_mode=config.train_mode,
+                unit_type=config.unit_type,
+                preprocess_conf=config.preprocess_config,
+                batch_size=config.batch_size,
+                num_mel_bins=config.feat_dim,
+                frame_length=config.window_ms,
+                frame_shift=config.stride_ms,
+                dither=config.dither,
+                minlen_in=config.minlen_in,
+                maxlen_in=config.maxlen_in,
+                minlen_out=config.minlen_out,
+                maxlen_out=config.maxlen_out,
+                resample_rate=config.resample_rate,
+                shuffle_size=config.shuffle_size,
+                sort_size=config.sort_size,
+                n_iter_processes=config.num_workers,
+                prefetch_factor=config.prefetch_factor,
+                dist_sampler=config.dist_sampler,
+                cmvn_file=config.cmvn_file,
+                vocab_filepath=config.vocab_filepath, )
+        else:
+            if mode == 'train':
+                config['manifest'] = config.train_manifest
+                config['train_mode'] = True
+                config['mini_batch_size'] = args.ngpu
+                config['subsampling_factor'] = 1
+                config['num_encs'] = 1
+            elif mode == 'valid':
+                config['manifest'] = config.dev_manifest
+                config['train_mode'] = False
+                config['sortagrad'] = False
+                config['maxlen_in'] = float('inf')
+                config['maxlen_out'] = float('inf')
+                config['minibatches'] = 0
+                config['mini_batch_size'] = args.ngpu
+                config['batch_count'] = 'auto'
+                config['batch_bins'] = 0
+                config['batch_frames_in'] = 0
+                config['batch_frames_out'] = 0
+                config['batch_frames_inout'] = 0
+                config['subsampling_factor'] = 1
+                config['num_encs'] = 1
+                config['shortest_first'] = False
+            elif mode == 'test' or mode == 'align':
+                config['manifest'] = config.test_manifest
+                config['train_mode'] = False
+                config['sortagrad'] = False
+                config['batch_size'] = config.get('decode', dict()).get(
+                    'decode_batch_size', 1)
+                config['maxlen_in'] = float('inf')
+                config['maxlen_out'] = float('inf')
+                config['minibatches'] = 0
+                config['mini_batch_size'] = 1
+                config['batch_count'] = 'auto'
+                config['batch_bins'] = 0
+                config['batch_frames_in'] = 0
+                config['batch_frames_out'] = 0
+                config['batch_frames_inout'] = 0
+                config['num_workers'] = 1
+                config['subsampling_factor'] = 1
+                config['num_encs'] = 1
+                config['dist_sampler'] = False
+                config['shortest_first'] = False
+            else:
+                raise KeyError(
+                    "not valid mode type!!, please input one of 'train, valid, test, align'"
+                )
+            return BatchDataLoader(
+                json_file=config.manifest,
+                train_mode=config.train_mode,
+                sortagrad=config.sortagrad,
+                batch_size=config.batch_size,
+                maxlen_in=config.maxlen_in,
+                maxlen_out=config.maxlen_out,
+                minibatches=config.minibatches,
+                mini_batch_size=config.mini_batch_size,
+                batch_count=config.batch_count,
+                batch_bins=config.batch_bins,
+                batch_frames_in=config.batch_frames_in,
+                batch_frames_out=config.batch_frames_out,
+                batch_frames_inout=config.batch_frames_inout,
+                preprocess_conf=config.preprocess_config,
+                n_iter_processes=config.num_workers,
+                subsampling_factor=config.subsampling_factor,
+                load_aux_output=config.get('load_transcript', None),
+                num_encs=config.num_encs,
+                dist_sampler=config.dist_sampler,
+                shortest_first=config.shortest_first)
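A hedged sketch of driving the factory from a script. The args object and the YAML path are invented; the config is assumed to carry the keys the streaming branch above reads (train_manifest/dev_manifest/test_manifest, unit_type, preprocess_config, batch_size, feat_dim, window_ms, stride_ms, dither, the min/max length bounds, resample_rate, shuffle_size, sort_size, num_workers, prefetch_factor, dist_sampler, cmvn_file, vocab_filepath):

    # Illustrative only: one config switch decides which loader the factory returns.
    import argparse
    from yacs.config import CfgNode
    from paddlespeech.s2t.io.dataloader import DataLoaderFactory

    config = CfgNode(new_allowed=True)
    config.merge_from_file("conf/conformer.yaml")   # assumed experiment config
    config['use_stream_data'] = True                # False (or absent) -> BatchDataLoader

    args = argparse.Namespace(ngpu=1)               # the factory only ever reads args.ngpu
    train_loader = DataLoaderFactory.get_dataloader('train', config, args)
    valid_loader = DataLoaderFactory.get_dataloader('valid', config, args)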
paddlespeech/s2t/io/reader.py  View file @ d1a25f6c
...
@@ -19,7 +19,7 @@ import numpy as np
 import soundfile
 
 from .utility import feat_type
-from paddlespeech.s2t.transform.transformation import Transformation
+from paddlespeech.audio.transform.transformation import Transformation
 from paddlespeech.s2t.utils.log import Log
 # from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation
...
paddlespeech/s2t/models/u2/u2.py  View file @ d1a25f6c
...
@@ -48,9 +48,9 @@ from paddlespeech.s2t.utils import checkpoint
 from paddlespeech.s2t.utils import layer_tools
 from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
-from paddlespeech.s2t.utils.tensor_utils import pad_sequence
-from paddlespeech.s2t.utils.tensor_utils import th_accuracy
+from paddlespeech.audio.utils.tensor_utils import add_sos_eos
+from paddlespeech.audio.utils.tensor_utils import pad_sequence
+from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.utils.utility import log_add
 from paddlespeech.s2t.utils.utility import UpdateConfig
...
paddlespeech/s2t/models/u2_st/u2_st.py  View file @ d1a25f6c
...
@@ -38,8 +38,8 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
 from paddlespeech.s2t.utils import checkpoint
 from paddlespeech.s2t.utils import layer_tools
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
-from paddlespeech.s2t.utils.tensor_utils import th_accuracy
+from paddlespeech.audio.utils.tensor_utils import add_sos_eos
+from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.utils.utility import UpdateConfig
 
 __all__ = ["U2STModel", "U2STInferModel"]
...
paddlespeech/server/engine/asr/online/onnx/asr_engine.py  View file @ d1a25f6c
...
@@ -26,7 +26,7 @@ from paddlespeech.cli.log import logger
 from paddlespeech.resource import CommonTaskResource
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.modules.ctc import CTCDecoder
-from paddlespeech.s2t.transform.transformation import Transformation
+from paddlespeech.audio.transform.transformation import Transformation
 from paddlespeech.s2t.utils.utility import UpdateConfig
 from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils import onnx_infer
...
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py  View file @ d1a25f6c
...
@@ -24,9 +24,9 @@ from yacs.config import CfgNode
 from paddlespeech.cli.asr.infer import ASRExecutor
 from paddlespeech.cli.log import logger
 from paddlespeech.resource import CommonTaskResource
+from paddlespeech.audio.transform.transformation import Transformation
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.modules.ctc import CTCDecoder
-from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.utility import UpdateConfig
 from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.paddle_predictor import init_predictor
...
paddlespeech/server/engine/asr/online/python/asr_engine.py  View file @ d1a25f6c
...
@@ -24,9 +24,9 @@ from yacs.config import CfgNode
 from paddlespeech.cli.asr.infer import ASRExecutor
 from paddlespeech.cli.log import logger
 from paddlespeech.resource import CommonTaskResource
+from paddlespeech.audio.transform.transformation import Transformation
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.modules.ctc import CTCDecoder
-from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
 from paddlespeech.s2t.utils.tensor_utils import pad_sequence
 from paddlespeech.s2t.utils.utility import UpdateConfig
...
setup.py  View file @ d1a25f6c
...
@@ -69,7 +69,9 @@ base = [
     "prettytable",
     "zhon",
     "colorlog",
-    "pathos == 0.2.8"
+    "pathos == 0.2.8",
+    "braceexpand",
+    "pyyaml"
 ]
 
 server = [
...