PaddlePaddle / DeepSpeech

Commit 92d1d08b
Authored Jul 05, 2022 by huangyuxin
Parent: 6ec69212

fix scripts
Showing 16 changed files with 901 additions and 417 deletions (+901 −417).
examples/wenetspeech/asr1/conf/conformer.yaml      +4    -16
examples/wenetspeech/asr1/local/data.sh            +2    -2
paddlespeech/audio/streamdata/__init__.py          +9    -8
paddlespeech/audio/streamdata/autodecode.py        +4    -4
paddlespeech/audio/streamdata/compat.py            +12   -12
paddlespeech/audio/streamdata/filters.py           +14   -14
paddlespeech/audio/streamdata/tariterators.py      +3    -3
paddlespeech/audio/text/text_featurizer.py         +235  -0
paddlespeech/audio/text/utility.py                 +393  -0
paddlespeech/s2t/exps/deepspeech2/model.py         +1    -1
paddlespeech/s2t/exps/u2/model.py                  +7    -178
paddlespeech/s2t/exps/u2_kaldi/model.py            +28   -88
paddlespeech/s2t/exps/u2_st/model.py               +18   -75
paddlespeech/s2t/io/dataloader.py                  +166  -11
paddlespeech/s2t/models/u2/u2.py                   +3    -3
paddlespeech/s2t/models/u2_st/u2_st.py             +2    -2
examples/wenetspeech/asr1/conf/conformer.yaml

@@ -52,6 +52,7 @@ test_manifest: data/test_meeting/data.list
 use_stream_data: True
 unit_type: 'char'
 vocab_filepath: data/lang_char/vocab.txt
+preprocess_config: conf/preprocess.yaml
 cmvn_file: data/mean_std.json
 spm_model_prefix: ''
 feat_dim: 80

@@ -65,30 +66,17 @@ maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automa
 minlen_out: 0
 maxlen_out: 150   # if output length(number of tokens) > maxlen-out, data is automatically removed
 resample_rate: 16000
-shuffle_size: 1500
-sort_size: 1000
+shuffle_size: 1500   # read number of 'shuffle_size' data as a chunk, shuffle the data in the chunk
+sort_size: 1000      # read number of 'sort_size' data as a chunk, sort the data in the chunk
 num_workers: 8
 prefetch_factor: 10
 dist_sampler: True
 num_encs: 1
-augment_conf:
-  max_w: 80
-  w_inplace: True
-  w_mode: "PIL"
-  max_f: 30
-  num_f_mask: 2
-  f_inplace: True
-  f_replace_with_zero: False
-  max_t: 40
-  num_t_mask: 2
-  t_inplace: True
-  t_replace_with_zero: False

 ###########################################
 # Training                                #
 ###########################################
-n_epoch: 30
+n_epoch: 32
 accum_grad: 32
 global_grad_clip: 5.0
 log_interval: 100
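The spec-augment keys removed above do not disappear: the dataloader change later in this commit derives them from the preprocess config instead. A hedged sketch of that mapping, with a hand-written process list standing in for the parsed conf/preprocess.yaml (the list contents are illustrative, not taken from the repository):

    # Illustrative stand-in for the parsed conf/preprocess.yaml "process" entries.
    process_list = [
        {"type": "time_warp", "max_time_warp": 80, "inplace": True, "mode": "PIL"},
        {"type": "freq_mask", "F": 30, "n_mask": 2, "inplace": True, "replace_with_zero": False},
        {"type": "time_mask", "T": 40, "n_mask": 2, "inplace": True, "replace_with_zero": False},
    ]

    augment_conf = {}
    for process in process_list:
        if process["type"] == "time_warp":
            augment_conf.update(max_w=process["max_time_warp"], w_inplace=process["inplace"],
                                w_mode=process["mode"])
        elif process["type"] == "freq_mask":
            augment_conf.update(max_f=process["F"], num_f_mask=process["n_mask"],
                                f_inplace=process["inplace"], f_replace_with_zero=process["replace_with_zero"])
        elif process["type"] == "time_mask":
            augment_conf.update(max_t=process["T"], num_t_mask=process["n_mask"],
                                t_inplace=process["inplace"], t_replace_with_zero=process["replace_with_zero"])

    # Reproduces the keys that used to live under augment_conf: in this YAML.
    print(augment_conf)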
examples/wenetspeech/asr1/local/data.sh

@@ -90,8 +90,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     for x in $dev_set $test_sets ${train_set}; do
       dst=$shards_dir/$x
       mkdir -p $dst
-      utils/make_filted_shard_list.py --resample 16000 --num_utts_per_shard 1000 \
-        --do_filter --num_node 1 --num_gpus_per_node 8 \
+      utils/make_filted_shard_list.py --num_node 1 --num_gpus_per_node 8 --num_utts_per_shard 1000 \
+        --do_filter --resample 16000 \
         --num_threads 32 --segments data/$x/segments \
         data/$x/wav.scp data/$x/text \
         $(realpath $dst) data/$x/data.list
paddlespeech/audio/streamdata/__init__.py

 # Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 # See the LICENSE file for licensing terms (BSD-style).
 # Modified from https://github.com/webdataset/webdataset
 #

@@ -26,7 +27,7 @@ from .filters import (
     pipelinefilter,
     rename,
     rename_keys,
-    rsample,
+    audio_resample,
     select,
     shuffle,
     slice,

@@ -34,14 +35,14 @@ from .filters import (
     transform_with,
     unbatched,
     xdecode,
-    data_filter,
-    tokenize,
-    resample,
-    compute_fbank,
-    spec_aug,
+    audio_data_filter,
+    audio_tokenize,
+    audio_resample,
+    audio_compute_fbank,
+    audio_spec_aug,
     sort,
-    padding,
-    cmvn,
+    audio_padding,
+    audio_cmvn,
     placeholder,
 )
 from .handlers import (
paddlespeech/audio/streamdata/autodecode.py

@@ -291,12 +291,12 @@ def torch_video(key, data):
 ################################################################
-# paddleaudio
+# paddlespeech.audio
 ################################################################


 def paddle_audio(key, data):
-    """Decode audio using the paddleaudio library.
+    """Decode audio using the paddlespeech.audio library.

     :param key: file name extension
     :param data: data to be decoded

@@ -305,13 +305,13 @@ def paddle_audio(key, data):
     if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
         return None

-    import paddleaudio
+    import paddlespeech.audio

     with tempfile.TemporaryDirectory() as dirname:
         fname = os.path.join(dirname, f"file.{extension}")
         with open(fname, "wb") as stream:
             stream.write(data)
-            return paddleaudio.load(fname)
+            return paddlespeech.audio.load(fname)


 ################################################################
paddlespeech/audio/streamdata/compat.py

@@ -78,29 +78,29 @@ class FluidInterface:
     def xdecode(self, *args, **kw):
         return self.compose(filters.xdecode(*args, **kw))

-    def data_filter(self, *args, **kw):
-        return self.compose(filters.data_filter(*args, **kw))
+    def audio_data_filter(self, *args, **kw):
+        return self.compose(filters.audio_data_filter(*args, **kw))

-    def tokenize(self, *args, **kw):
-        return self.compose(filters.tokenize(*args, **kw))
+    def audio_tokenize(self, *args, **kw):
+        return self.compose(filters.audio_tokenize(*args, **kw))

     def resample(self, *args, **kw):
         return self.compose(filters.resample(*args, **kw))

-    def compute_fbank(self, *args, **kw):
-        return self.compose(filters.compute_fbank(*args, **kw))
+    def audio_compute_fbank(self, *args, **kw):
+        return self.compose(filters.audio_compute_fbank(*args, **kw))

-    def spec_aug(self, *args, **kw):
-        return self.compose(filters.spec_aug(*args, **kw))
+    def audio_spec_aug(self, *args, **kw):
+        return self.compose(filters.audio_spec_aug(*args, **kw))

     def sort(self, size=500):
         return self.compose(filters.sort(size))

-    def padding(self):
-        return self.compose(filters.padding())
+    def audio_padding(self):
+        return self.compose(filters.audio_padding())

-    def cmvn(self, cmvn_file):
-        return self.compose(filters.cmvn(cmvn_file))
+    def audio_cmvn(self, cmvn_file):
+        return self.compose(filters.audio_cmvn(cmvn_file))


 class WebDataset(DataPipeline, FluidInterface):
     """Small fluid-interface wrapper for DataPipeline."""
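Each wrapper above only composes the corresponding filter stage onto the pipeline, so the renamed methods can be chained fluently on a dataset object. A hedged sketch of that call style follows; it assumes WebDataset is re-exported at the package root as in upstream webdataset, the shard path and parameter values are placeholders, and this shows calling convention only, not a verified end-to-end pipeline:

    # Illustrative sketch only; the tar shard and parameters are made up.
    import paddlespeech.audio.streamdata as streamdata

    dataset = (
        streamdata.WebDataset("data/train/shard-000000.tar")
        .audio_data_filter(frame_shift=10, max_length=10240, min_length=10)
        .audio_resample(resample_rate=16000)
        .audio_compute_fbank(num_mel_bins=80, frame_length=25, frame_shift=10)
        .audio_padding()
    )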
paddlespeech/audio/streamdata/filters.py

@@ -579,7 +579,7 @@ xdecode = pipelinefilter(_xdecode)

-def _data_filter(source,
+def _audio_data_filter(source,
                  frame_shift=10,
                  max_length=10240,
                  min_length=10,

@@ -629,9 +629,9 @@ def _data_filter(source,
             continue
         yield sample

-data_filter = pipelinefilter(_data_filter)
+audio_data_filter = pipelinefilter(_audio_data_filter)

-def _tokenize(source,
+def _audio_tokenize(source,
               symbol_table,
               bpe_model=None,
               non_lang_syms=None,

@@ -693,9 +693,9 @@ def _tokenize(source,
         sample['label'] = label
         yield sample

-tokenize = pipelinefilter(_tokenize)
+audio_tokenize = pipelinefilter(_audio_tokenize)

-def _resample(source, resample_rate=16000):
+def _audio_resample(source, resample_rate=16000):
     """ Resample data.
         Inplace operation.

@@ -718,9 +718,9 @@ def _resample(source, resample_rate=16000):
             ))
         yield sample

-resample = pipelinefilter(_resample)
+audio_resample = pipelinefilter(_audio_resample)

-def _compute_fbank(source,
+def _audio_compute_fbank(source,
                    num_mel_bins=80,
                    frame_length=25,
                    frame_shift=10,

@@ -756,9 +756,9 @@ def _compute_fbank(source,
         yield dict(fname=sample['fname'], label=sample['label'], feat=mat)

-compute_fbank = pipelinefilter(_compute_fbank)
+audio_compute_fbank = pipelinefilter(_audio_compute_fbank)

-def _spec_aug(source,
+def _audio_spec_aug(source,
               max_w=5,
               w_inplace=True,
               w_mode="PIL",

@@ -799,7 +799,7 @@ def _spec_aug(source,
         sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
         yield sample

-spec_aug = pipelinefilter(_spec_aug)
+audio_spec_aug = pipelinefilter(_audio_spec_aug)

 def _sort(source, sort_size=500):

@@ -881,7 +881,7 @@ def dynamic_batched(source, max_frames_in_batch=12000):
         yield buf

-def _padding(source):
+def _audio_padding(source):
     """ Padding the data into training data

     Args:

@@ -914,9 +914,9 @@ def _padding(source):
         yield (sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths)

-padding = pipelinefilter(_padding)
+audio_padding = pipelinefilter(_audio_padding)

-def _cmvn(source, cmvn_file):
+def _audio_cmvn(source, cmvn_file):
     global_cmvn = GlobalCMVN(cmvn_file)
     for batch in source:
         sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch

@@ -926,7 +926,7 @@ def _cmvn(source, cmvn_file):
         yield (sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths)

-cmvn = pipelinefilter(_cmvn)
+audio_cmvn = pipelinefilter(_audio_cmvn)

 def _placeholder(source):
     for data in source:
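The rename keeps the module's two-level pattern: a private generator transforms a stream of samples, and pipelinefilter wraps it so its keyword arguments can be bound before any data source exists. Below is a self-contained sketch of that pattern; pipelinefilter_sketch and the toy stage are illustrative stand-ins written for this note, not the library's implementation:

    def pipelinefilter_sketch(func):
        # Bind every argument except the sample source now; apply to the source later.
        def bind(*args, **kw):
            def stage(source):
                return func(source, *args, **kw)
            return stage
        return bind

    def _audio_example_stage(source, factor=2.0):
        # Toy stage: scale a numeric field of each sample dict.
        for sample in source:
            sample = dict(sample)
            sample["feat"] = sample["feat"] * factor
            yield sample

    audio_example_stage = pipelinefilter_sketch(_audio_example_stage)

    # The bound stage is then a plain callable over an iterable of samples.
    samples = [{"feat": 1.0}, {"feat": 2.5}]
    print([s["feat"] for s in audio_example_stage(factor=10.0)(samples)])  # [10.0, 25.0]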
paddlespeech/audio/streamdata/tariterators.py

@@ -21,7 +21,7 @@ trace = False
 meta_prefix = "__"
 meta_suffix = "__"

-from ... import audio as paddleaudio
+import paddlespeech
 import paddle
 import numpy as np

@@ -118,7 +118,7 @@ def tar_file_iterator(
             assert pos > 0
             prefix, postfix = name[:pos], name[pos + 1:]
             if postfix == 'wav':
-                waveform, sample_rate = paddleaudio.load(stream.extractfile(tarinfo), normal=False)
+                waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False)
                 result = dict(fname=prefix, wav=waveform, sample_rate=sample_rate)
             else:
                 txt = stream.extractfile(tarinfo).read().decode('utf8').strip()

@@ -167,7 +167,7 @@ def tar_file_and_group_iterator(
                 if postfix == 'txt':
                     example['txt'] = file_obj.read().decode('utf8').strip()
                 elif postfix in AUDIO_FORMAT_SETS:
-                    waveform, sample_rate = paddleaudio.load(file_obj, normal=False)
+                    waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False)
                     waveform = paddle.to_tensor(np.expand_dims(np.array(waveform), 0), dtype=paddle.float32)
                     example['wav'] = waveform
paddlespeech/audio/text/text_featurizer.py  (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from pprint import pformat
from typing import Union

import sentencepiece as spm

from .utility import BLANK
from .utility import EOS
from .utility import load_dict
from .utility import MASKCTC
from .utility import SOS
from .utility import SPACE
from .utility import UNK
from ..utils.log import Logger

logger = Logger(__name__)

__all__ = ["TextFeaturizer"]


class TextFeaturizer():
    def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
        """Text featurizer, for processing or extracting features from text.

        Currently, it supports char/word/sentence-piece level tokenizing and conversion into
        a list of token indices. Note that the token indexing order follows the
        given vocabulary file.

        Args:
            unit_type (str): unit type, e.g. char, word, spm
            vocab Option[str, list]: Filepath to load vocabulary for token indices conversion, or vocab list.
            spm_model_prefix (str, optional): spm model prefix. Defaults to None.
        """
        assert unit_type in ('char', 'spm', 'word')
        self.unit_type = unit_type
        self.unk = UNK
        self.maskctc = maskctc

        if vocab:
            self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = \
                self._load_vocabulary_from_file(vocab, maskctc)
            self.vocab_size = len(self.vocab_list)
        else:
            logger.warning("TextFeaturizer: not have vocab file or vocab list.")

        if unit_type == 'spm':
            spm_model = spm_model_prefix + '.model'
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_model)

    def tokenize(self, text, replace_space=True):
        if self.unit_type == 'char':
            tokens = self.char_tokenize(text, replace_space)
        elif self.unit_type == 'word':
            tokens = self.word_tokenize(text)
        else:  # spm
            tokens = self.spm_tokenize(text)
        return tokens

    def detokenize(self, tokens):
        if self.unit_type == 'char':
            text = self.char_detokenize(tokens)
        elif self.unit_type == 'word':
            text = self.word_detokenize(tokens)
        else:  # spm
            text = self.spm_detokenize(tokens)
        return text

    def featurize(self, text):
        """Convert text string to a list of token indices.

        Args:
            text (str): Text to process.
        Returns:
            List[int]: List of token indices.
        """
        tokens = self.tokenize(text)
        ids = []
        for token in tokens:
            if token not in self.vocab_dict:
                logger.debug(f"Text Token: {token} -> {self.unk}")
                token = self.unk
            ids.append(self.vocab_dict[token])
        return ids

    def defeaturize(self, idxs):
        """Convert a list of token indices to text string,
        ignore index after eos_id.

        Args:
            idxs (List[int]): List of token indices.
        Returns:
            str: Text.
        """
        tokens = []
        for idx in idxs:
            if idx == self.eos_id:
                break
            tokens.append(self._id2token[idx])
        text = self.detokenize(tokens)
        return text

    def char_tokenize(self, text, replace_space=True):
        """Character tokenizer.

        Args:
            text (str): text string.
            replace_space (bool): False only used by build_vocab.py.
        Returns:
            List[str]: tokens.
        """
        text = text.strip()
        if replace_space:
            text_list = [SPACE if item == " " else item for item in list(text)]
        else:
            text_list = list(text)
        return text_list

    def char_detokenize(self, tokens):
        """Character detokenizer.

        Args:
            tokens (List[str]): tokens.
        Returns:
            str: text string.
        """
        tokens = [t.replace(SPACE, " ") for t in tokens]
        return "".join(tokens)

    def word_tokenize(self, text):
        """Word tokenizer, separate by <space>."""
        return text.strip().split()

    def word_detokenize(self, tokens):
        """Word detokenizer, separate by <space>."""
        return " ".join(tokens)

    def spm_tokenize(self, text):
        """spm tokenize.

        Args:
            text (str): text string.
        Returns:
            List[str]: sentence pieces str code
        """
        stats = {"num_empty": 0, "num_filtered": 0}

        def valid(line):
            return True

        def encode(l):
            return self.sp.EncodeAsPieces(l)

        def encode_line(line):
            line = line.strip()
            if len(line) > 0:
                line = encode(line)
                if valid(line):
                    return line
                else:
                    stats["num_filtered"] += 1
            else:
                stats["num_empty"] += 1
            return None

        enc_line = encode_line(text)
        return enc_line

    def spm_detokenize(self, tokens, input_format='piece'):
        """spm detokenize.

        Args:
            ids (List[str]): tokens.
        Returns:
            str: text
        """
        if input_format == "piece":

            def decode(l):
                return "".join(self.sp.DecodePieces(l))
        elif input_format == "id":

            def decode(l):
                return "".join(self.sp.DecodeIds(l))

        return decode(tokens)

    def _load_vocabulary_from_file(self, vocab: Union[str, list], maskctc: bool):
        """Load vocabulary from file."""
        if isinstance(vocab, list):
            vocab_list = vocab
        else:
            vocab_list = load_dict(vocab, maskctc)
        assert vocab_list is not None
        logger.debug(f"Vocab: {pformat(vocab_list)}")

        id2token = dict([(idx, token) for (idx, token) in enumerate(vocab_list)])
        token2id = dict([(token, idx) for (idx, token) in enumerate(vocab_list)])

        blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
        maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
        unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
        eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
        sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
        space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1

        logger.info(f"BLANK id: {blank_id}")
        logger.info(f"UNK id: {unk_id}")
        logger.info(f"EOS id: {eos_id}")
        logger.info(f"SOS id: {sos_id}")
        logger.info(f"SPACE id: {space_id}")
        logger.info(f"MASKCTC id: {maskctc_id}")
        return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
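As a quick orientation to the new class, here is a hedged usage sketch for char-level featurization; the vocabulary list is made up for illustration, whereas real runs load it from vocab_filepath (e.g. data/lang_char/vocab.txt):

    from paddlespeech.audio.text.text_featurizer import TextFeaturizer

    # Toy vocabulary for illustration only.
    vocab = ["<blank>", "<unk>", "a", "b", "<space>", "<eos>"]
    featurizer = TextFeaturizer(unit_type='char', vocab=vocab)

    ids = featurizer.featurize("ab a")   # spaces map to <space>, OOV chars map to <unk>
    text = featurizer.defeaturize(ids)   # stops at <eos>, restores spaces
    print(ids, text)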
paddlespeech/audio/text/utility.py  (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
import json
import math
import tarfile
from collections import namedtuple
from typing import List
from typing import Optional
from typing import Text

import jsonlines
import numpy as np

from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()

__all__ = [
    "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
    "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
    "EOS", "UNK", "BLANK", "MASKCTC", "SPACE", "convert_samples_to_float32",
    "convert_samples_from_float32"
]

IGNORE_ID = -1
# `sos` and `eos` using same token
SOS = "<eos>"
EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
MASKCTC = "<mask>"
SPACE = "<space>"


def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
    if dict_path is None:
        return None

    with open(dict_path, "r") as f:
        dictionary = f.readlines()
    # first token is `<blank>`
    # multi line: `<blank> 0\n`
    # one line: `<blank>`
    # space is replaced with <space>
    char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
    if BLANK not in char_list:
        char_list.insert(0, BLANK)
    if EOS not in char_list:
        char_list.append(EOS)
    # for non-autoregressive maskctc model
    if maskctc and MASKCTC not in char_list:
        char_list.append(MASKCTC)
    return char_list


def read_manifest(
        manifest_path,
        max_input_len=float('inf'),
        min_input_len=0.0,
        max_output_len=float('inf'),
        min_output_len=0.0,
        max_output_input_ratio=float('inf'),
        min_output_input_ratio=0.0, ):
    """Load and parse manifest file.

    Args:
        manifest_path: Manifest file to load and parse.
        max_input_len / min_input_len: input seq length bounds, in seconds for raw wav,
            in frame numbers for feature data.
        max_output_len / min_output_len: output seq length bounds, in modeling units.
        max_output_input_ratio / min_output_input_ratio: output/input length ratio bounds.
    Raises:
        IOError: If failed to parse the manifest.
    Returns:
        List[dict]: Manifest parsing results.
    """
    manifest = []
    with jsonlines.open(manifest_path, 'r') as reader:
        for json_data in reader:
            feat_len = json_data["input"][0]["shape"][0] \
                if "input" in json_data and "shape" in json_data["input"][0] else 1.0
            token_len = json_data["output"][0]["shape"][0] \
                if "output" in json_data and "shape" in json_data["output"][0] else 1.0
            conditions = [
                feat_len >= min_input_len,
                feat_len <= max_input_len,
                token_len >= min_output_len,
                token_len <= max_output_len,
                token_len / feat_len >= min_output_input_ratio,
                token_len / feat_len <= max_output_input_ratio,
            ]
            if all(conditions):
                manifest.append(json_data)
    return manifest


# Tar File read
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])


def parse_tar(file):
    """Parse a tar file to get a tarfile object and a map containing tarinfoes."""
    result = {}
    f = tarfile.open(file)
    for tarinfo in f.getmembers():
        result[tarinfo.name] = tarinfo
    return f, result


def subfile_from_tar(file, local_data=None):
    """Get subfile object from tar.

    tar:tarpath#filename

    It will return a subfile object from tar file
    and cached tar file info for next reading request.
    """
    tarpath, filename = file.split(':', 1)[1].split('#', 1)

    if local_data is None:
        local_data = TarLocalData(tar2info={}, tar2object={})

    assert isinstance(local_data, TarLocalData)

    if 'tar2info' not in local_data.__dict__:
        local_data.tar2info = {}
    if 'tar2object' not in local_data.__dict__:
        local_data.tar2object = {}

    if tarpath not in local_data.tar2info:
        fobj, infos = parse_tar(tarpath)
        local_data.tar2info[tarpath] = infos
        local_data.tar2object[tarpath] = fobj
    else:
        fobj = local_data.tar2object[tarpath]
        infos = local_data.tar2info[tarpath]
    return fobj.extractfile(infos[filename])


def rms_to_db(rms: float):
    """Root Mean Square to dB. Args: rms (float). Returns: float: dB."""
    return 20.0 * math.log10(max(1e-16, rms))


def rms_to_dbfs(rms: float):
    """Root Mean Square to dBFS.
    https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
    Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB.

    dB = dBFS + 3.0103
    dBFS = db - 3.0103
    e.g. 0 dB = -3.0103 dBFS

    Args: rms (float). Returns: float: dBFS.
    """
    return rms_to_db(rms) - 3.0103


def max_dbfs(sample_data: np.ndarray):
    """Peak dBFS based on the maximum energy sample.
    Will prevent overdrive if used for normalization.

    Args: sample_data (np.ndarray): float array, [-1, 1]. Returns: float: dBFS.
    """
    return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))


def mean_dbfs(sample_data):
    """dBFS based on the RMS energy.

    Args: sample_data (np.ndarray): float array, [-1, 1]. Returns: float: dBFS.
    """
    return rms_to_dbfs(math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))


def gain_db_to_ratio(gain_db: float):
    """dB to ratio. Args: gain_db (float): gain in dB. Returns: float: scale in amp."""
    return math.pow(10.0, gain_db / 20.0)


def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
    """Normalize audio to dBFS.

    Args:
        sample_data (np.ndarray): input wave samples, [-1, 1].
        dbfs (float, optional): target dBFS. Defaults to -3.0103.
    Returns:
        np.ndarray: normalized wave
    """
    return np.maximum(
        np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), 1.0), -1.0)


def _load_json_cmvn(json_cmvn_file):
    """Load the json format cmvn stats file and calculate cmvn.

    Args: json_cmvn_file: cmvn stats file in json format.
    Returns: a numpy array of [means, vars].
    """
    with open(json_cmvn_file) as f:
        cmvn_stats = json.load(f)

    means = cmvn_stats['mean_stat']
    variance = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def _load_kaldi_cmvn(kaldi_cmvn_file):
    """Load the kaldi format cmvn stats file and calculate cmvn.

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
    Returns: a numpy array of [means, vars].
    """
    means = []
    variance = []
    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary file start with '\0B'
        if fid.read(2) == '\0B':
            logger.error('kaldi cmvn binary file is not supported, please '
                         'recompute it by: compute-cmvn-stats --binary=false '
                         ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()
        assert (arr[0] == '[')
        assert (arr[-2] == '0')
        assert (arr[-1] == ']')
        feat_dim = int((len(arr) - 2 - 2) / 2)
        for i in range(1, feat_dim + 1):
            means.append(float(arr[i]))
        count = float(arr[feat_dim + 1])
        for i in range(feat_dim + 2, 2 * feat_dim + 2):
            variance.append(float(arr[i]))

    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def load_cmvn(cmvn_file: str, filetype: str):
    """load cmvn from file.

    Args:
        cmvn_file (str): cmvn path.
        filetype (str): file type, optional[npz, json, kaldi].
    Raises:
        ValueError: file type not support.
    Returns:
        Tuple[np.ndarray, np.ndarray]: mean, istd
    """
    assert filetype in ['npz', 'json', 'kaldi'], filetype
    filetype = filetype.lower()
    if filetype == "json":
        cmvn = _load_json_cmvn(cmvn_file)
    elif filetype == "kaldi":
        cmvn = _load_kaldi_cmvn(cmvn_file)
    elif filetype == "npz":
        eps = 1e-14
        npzfile = np.load(cmvn_file)
        mean = np.squeeze(npzfile["mean"])
        std = np.squeeze(npzfile["std"])
        istd = 1 / (std + eps)
        cmvn = [mean, istd]
    else:
        raise ValueError(f"cmvn file type no support: {filetype}")
    return cmvn[0], cmvn[1]


def convert_samples_to_float32(samples):
    """Convert sample type to float32.

    Audio sample type is usually integer or float-point.
    Integers will be scaled to [-1, 1] in float32.

    PCM16 -> PCM32
    """
    float32_samples = samples.astype('float32')
    if samples.dtype in np.sctypes['int']:
        bits = np.iinfo(samples.dtype).bits
        float32_samples *= (1. / 2**(bits - 1))
    elif samples.dtype in np.sctypes['float']:
        pass
    else:
        raise TypeError("Unsupported sample type: %s." % samples.dtype)
    return float32_samples


def convert_samples_from_float32(samples, dtype):
    """Convert sample type from float32 to dtype.

    Audio sample type is usually integer or float-point. For integer
    type, float32 will be rescaled from [-1, 1] to the maximum range
    supported by the integer type.

    PCM32 -> PCM16
    """
    dtype = np.dtype(dtype)
    output_samples = samples.copy()
    if dtype in np.sctypes['int']:
        bits = np.iinfo(dtype).bits
        output_samples *= (2**(bits - 1) / 1.)
        min_val = np.iinfo(dtype).min
        max_val = np.iinfo(dtype).max
        output_samples[output_samples > max_val] = max_val
        output_samples[output_samples < min_val] = min_val
    elif samples.dtype in np.sctypes['float']:
        min_val = np.finfo(dtype).min
        max_val = np.finfo(dtype).max
        output_samples[output_samples > max_val] = max_val
        output_samples[output_samples < min_val] = min_val
    else:
        raise TypeError("Unsupported sample type: %s." % samples.dtype)
    return output_samples.astype(dtype)
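To make the dB helpers concrete, here is a small hedged example of the normalization path; the waveform is synthetic and the target dBFS is simply the module default:

    import numpy as np
    from paddlespeech.audio.text.utility import max_dbfs, normalize_audio

    # Synthetic quiet sine wave in [-1, 1]; illustration only.
    wav = 0.1 * np.sin(2 * np.pi * 440 * np.arange(16000) / 16000).astype('float32')
    print("peak dBFS before:", max_dbfs(wav))

    normalized = normalize_audio(wav, dbfs=-3.0103)
    print("peak dBFS after:", max_dbfs(normalized))  # close to the -3.0103 target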
paddlespeech/s2t/exps/deepspeech2/model.py

@@ -23,7 +23,7 @@ import paddle
 from paddle import distributed as dist
 from paddle import inference
-from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.audio.text.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import BatchDataLoader
 from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel
 from paddlespeech.s2t.models.ds2 import DeepSpeech2Model
paddlespeech/s2t/exps/u2/model.py

@@ -27,6 +27,7 @@ from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import BatchDataLoader
 from paddlespeech.s2t.io.dataloader import StreamDataLoader
+from paddlespeech.s2t.io.dataloader import DataLoaderFactory
 from paddlespeech.s2t.models.u2 import U2Model
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope

@@ -134,7 +135,8 @@ class U2Trainer(Trainer):
                 msg = f"Valid: Rank: {dist.get_rank()}, "
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
-                #msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                if not self.use_streamdata:
+                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items())
                 logger.info(msg)

@@ -195,7 +197,6 @@ class U2Trainer(Trainer):
             except Exception as e:
                 logger.error(e)
                 raise e
-
         with Timer("Eval Time Cost: {}"):
             total_loss, num_seen_utts = self.valid()
             if dist.get_world_size() > 1:

@@ -224,186 +225,14 @@ class U2Trainer(Trainer):
         config = self.config.clone()
         self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
-            # train/valid dataset, return token ids
-            if self.use_streamdata:
-                self.train_loader = StreamDataLoader(
-                    manifest_file=config.train_manifest, train_mode=True,
-                    unit_type=config.unit_type, batch_size=config.batch_size,
-                    num_mel_bins=config.feat_dim, frame_length=config.window_ms,
-                    frame_shift=config.stride_ms, dither=config.dither,
-                    minlen_in=config.minlen_in, maxlen_in=config.maxlen_in,
-                    minlen_out=config.minlen_out, maxlen_out=config.maxlen_out,
-                    resample_rate=config.resample_rate,
-                    augment_conf=config.augment_conf,  # dict
-                    shuffle_size=config.shuffle_size, sort_size=config.sort_size,
-                    n_iter_processes=config.num_workers,
-                    prefetch_factor=config.prefetch_factor,
-                    dist_sampler=config.get('dist_sampler', False),
-                    cmvn_file=config.cmvn_file,
-                    vocab_filepath=config.vocab_filepath)
-                self.valid_loader = StreamDataLoader(
-                    manifest_file=config.dev_manifest, train_mode=False,
-                    unit_type=config.unit_type, batch_size=config.batch_size,
-                    num_mel_bins=config.feat_dim, frame_length=config.window_ms,
-                    frame_shift=config.stride_ms, dither=config.dither,
-                    minlen_in=config.minlen_in, maxlen_in=config.maxlen_in,
-                    minlen_out=config.minlen_out, maxlen_out=config.maxlen_out,
-                    resample_rate=config.resample_rate,
-                    augment_conf=config.augment_conf,  # dict
-                    shuffle_size=config.shuffle_size, sort_size=config.sort_size,
-                    n_iter_processes=config.num_workers,
-                    prefetch_factor=config.prefetch_factor,
-                    dist_sampler=config.get('dist_sampler', False),
-                    cmvn_file=config.cmvn_file,
-                    vocab_filepath=config.vocab_filepath)
-            else:
-                self.train_loader = BatchDataLoader(
-                    json_file=config.train_manifest, train_mode=True,
-                    sortagrad=config.sortagrad, batch_size=config.batch_size,
-                    maxlen_in=config.maxlen_in, maxlen_out=config.maxlen_out,
-                    minibatches=config.minibatches, mini_batch_size=self.args.ngpu,
-                    batch_count=config.batch_count, batch_bins=config.batch_bins,
-                    batch_frames_in=config.batch_frames_in,
-                    batch_frames_out=config.batch_frames_out,
-                    batch_frames_inout=config.batch_frames_inout,
-                    preprocess_conf=config.preprocess_config,
-                    n_iter_processes=config.num_workers,
-                    subsampling_factor=1, num_encs=1,
-                    dist_sampler=config.get('dist_sampler', False),
-                    shortest_first=False)
-                self.valid_loader = BatchDataLoader(
-                    json_file=config.dev_manifest, train_mode=False,
-                    sortagrad=False, batch_size=config.batch_size,
-                    maxlen_in=float('inf'), maxlen_out=float('inf'),
-                    minibatches=0, mini_batch_size=self.args.ngpu,
-                    batch_count='auto', batch_bins=0,
-                    batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-                    preprocess_conf=config.preprocess_config,
-                    n_iter_processes=config.num_workers,
-                    subsampling_factor=1, num_encs=1,
-                    dist_sampler=config.get('dist_sampler', False),
-                    shortest_first=False)
+            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get('decode_batch_size', 1)
-            # test dataset, return raw text
-            if self.use_streamdata:
-                self.test_loader = StreamDataLoader(
-                    manifest_file=config.test_manifest, train_mode=False,
-                    unit_type=config.unit_type, batch_size=config.batch_size,
-                    num_mel_bins=config.feat_dim, frame_length=config.window_ms,
-                    frame_shift=config.stride_ms, dither=0.0,
-                    minlen_in=0.0, maxlen_in=float('inf'),
-                    minlen_out=0, maxlen_out=float('inf'),
-                    resample_rate=config.resample_rate,
-                    augment_conf=config.augment_conf,  # dict
-                    shuffle_size=config.shuffle_size, sort_size=config.sort_size,
-                    n_iter_processes=config.num_workers,
-                    prefetch_factor=config.prefetch_factor,
-                    dist_sampler=config.get('dist_sampler', False),
-                    cmvn_file=config.cmvn_file,
-                    vocab_filepath=config.vocab_filepath)
-                self.align_loader = StreamDataLoader(
-                    manifest_file=config.test_manifest, train_mode=False,
-                    unit_type=config.unit_type, batch_size=config.batch_size,
-                    num_mel_bins=config.feat_dim, frame_length=config.window_ms,
-                    frame_shift=config.stride_ms, dither=0.0,
-                    minlen_in=0.0, maxlen_in=float('inf'),
-                    minlen_out=0, maxlen_out=float('inf'),
-                    resample_rate=config.resample_rate,
-                    augment_conf=config.augment_conf,  # dict
-                    shuffle_size=config.shuffle_size, sort_size=config.sort_size,
-                    n_iter_processes=config.num_workers,
-                    prefetch_factor=config.prefetch_factor,
-                    dist_sampler=config.get('dist_sampler', False),
-                    cmvn_file=config.cmvn_file,
-                    vocab_filepath=config.vocab_filepath)
-            else:
-                self.test_loader = BatchDataLoader(
-                    json_file=config.test_manifest, train_mode=False,
-                    sortagrad=False, batch_size=decode_batch_size,
-                    maxlen_in=float('inf'), maxlen_out=float('inf'),
-                    minibatches=0, mini_batch_size=1,
-                    batch_count='auto', batch_bins=0,
-                    batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-                    preprocess_conf=config.preprocess_config,
-                    n_iter_processes=1, subsampling_factor=1, num_encs=1)
-                self.align_loader = BatchDataLoader(
-                    json_file=config.test_manifest, train_mode=False,
-                    sortagrad=False, batch_size=decode_batch_size,
-                    maxlen_in=float('inf'), maxlen_out=float('inf'),
-                    minibatches=0, mini_batch_size=1,
-                    batch_count='auto', batch_bins=0,
-                    batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-                    preprocess_conf=config.preprocess_config,
-                    n_iter_processes=1, subsampling_factor=1, num_encs=1)
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
+            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
             logger.info("Setup test/align Dataloader!")

     def setup_model(self):
paddlespeech/s2t/exps/u2_kaldi/model.py

@@ -25,7 +25,7 @@ from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.frontend.utility import load_dict
-from paddlespeech.s2t.io.dataloader import BatchDataLoader
+from paddlespeech.s2t.io.dataloader import DataLoaderFactory
 from paddlespeech.s2t.models.u2 import U2Model
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.scheduler import LRSchedulerFactory

@@ -104,7 +104,8 @@ class U2Trainer(Trainer):
     @paddle.no_grad()
     def valid(self):
         self.model.eval()
-        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0

@@ -131,7 +132,8 @@ class U2Trainer(Trainer):
                 msg = f"Valid: Rank: {dist.get_rank()}, "
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
-                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                if not self.use_streamdata:
+                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items())
                 logger.info(msg)

@@ -150,8 +152,8 @@ class U2Trainer(Trainer):
         # paddle.jit.save(script_model, script_model_path)

         self.before_train()
-        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()

@@ -162,7 +164,8 @@ class U2Trainer(Trainer):
                     msg = "Train: Rank: {}, ".format(dist.get_rank())
                     msg += "epoch: {}, ".format(self.epoch)
                     msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1, len(self.train_loader))
+                    if not self.use_streamdata:
+                        msg += "batch : {}/{}, ".format(batch_index + 1, len(self.train_loader))
                     msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                     msg += "data time: {:>.3f}s, ".format(dataload_time)

@@ -198,87 +201,23 @@ class U2Trainer(Trainer):
             self.new_epoch()

     def setup_dataloader(self):
-        config = self.config.clone()
-        # train/valid dataset, return token ids
-        self.train_loader = BatchDataLoader(
-            json_file=config.train_manifest, train_mode=True, sortagrad=False,
-            batch_size=config.batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'),
-            minibatches=0, mini_batch_size=self.args.ngpu, batch_count='auto',
-            batch_bins=0, batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-            preprocess_conf=config.preprocess_config,
-            n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1)
-        self.valid_loader = BatchDataLoader(
-            json_file=config.dev_manifest, train_mode=False, sortagrad=False,
-            batch_size=config.batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'),
-            minibatches=0, mini_batch_size=self.args.ngpu, batch_count='auto',
-            batch_bins=0, batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-            preprocess_conf=None,
-            n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1)
-        decode_batch_size = config.get('decode', dict()).get('decode_batch_size', 1)
-        # test dataset, return raw text
-        self.test_loader = BatchDataLoader(
-            json_file=config.test_manifest, train_mode=False, sortagrad=False,
-            batch_size=decode_batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'),
-            minibatches=0, mini_batch_size=1, batch_count='auto',
-            batch_bins=0, batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-            preprocess_conf=None, n_iter_processes=1, subsampling_factor=1, num_encs=1)
-        self.align_loader = BatchDataLoader(
-            json_file=config.test_manifest, train_mode=False, sortagrad=False,
-            batch_size=decode_batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'),
-            minibatches=0, mini_batch_size=1, batch_count='auto',
-            batch_bins=0, batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-            preprocess_conf=None, n_iter_processes=1, subsampling_factor=1, num_encs=1)
-        logger.info("Setup train/valid/test/align Dataloader!")
+        self.use_streamdata = config.get("use_stream_data", False)
+        if self.train:
+            config = self.config.clone()
+            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
+            config = self.config.clone()
+            config['preprocess_config'] = None
+            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+            logger.info("Setup train/valid Dataloader!")
+        else:
+            config = self.config.clone()
+            config['preprocess_config'] = None
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
+            config = self.config.clone()
+            config['preprocess_config'] = None
+            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
+            logger.info("Setup test/align Dataloader!")

     def setup_model(self):
         config = self.config

@@ -406,7 +345,8 @@ class U2Tester(U2Trainer):
     def test(self):
         assert self.args.result_file
         self.model.eval()
-        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
         stride_ms = self.config.stride_ms
         error_rate_type = None
paddlespeech/s2t/exps/u2_st/model.py

@@ -25,7 +25,7 @@ import paddle
 from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
-from paddlespeech.s2t.io.dataloader import BatchDataLoader
+from paddlespeech.s2t.io.dataloader import DataLoaderFactory
 from paddlespeech.s2t.models.u2_st import U2STModel
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope

@@ -120,7 +120,8 @@ class U2STTrainer(Trainer):
     @paddle.no_grad()
     def valid(self):
         self.model.eval()
-        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0

@@ -153,7 +154,8 @@ class U2STTrainer(Trainer):
                 msg = f"Valid: Rank: {dist.get_rank()}, "
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
-                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                if not self.use_streamdata:
+                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items())
                 logger.info(msg)

@@ -172,8 +174,8 @@ class U2STTrainer(Trainer):
         # paddle.jit.save(script_model, script_model_path)

         self.before_train()
-        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()

@@ -191,7 +193,8 @@ class U2STTrainer(Trainer):
                 self.train_batch(batch_index, batch, msg)
                 self.after_train_batch()
                 report('iter', batch_index + 1)
-                report('total', len(self.train_loader))
+                if not self.use_streamdata:
+                    report('total', len(self.train_loader))
                 report('reader_cost', dataload_time)
                 observation['batch_cost'] = observation['reader_cost'] + observation['step_cost']

@@ -241,79 +244,18 @@ class U2STTrainer(Trainer):
         load_transcript = True if config.model_conf.asr_weight > 0 else False
+        config = self.config.clone()
+        config['load_transcript'] = load_transcript
+        self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
-            # train/valid dataset, return token ids
-            self.train_loader = BatchDataLoader(
-                json_file=config.train_manifest, train_mode=True, sortagrad=False,
-                batch_size=config.batch_size, maxlen_in=config.maxlen_in, maxlen_out=config.maxlen_out,
-                minibatches=0, mini_batch_size=1, batch_count='auto', batch_bins=0,
-                batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,  # aug will be off when train_mode=False
-                n_iter_processes=config.num_workers, subsampling_factor=1,
-                load_aux_output=load_transcript, num_encs=1, dist_sampler=True)
-            self.valid_loader = BatchDataLoader(
-                json_file=config.dev_manifest, train_mode=False, sortagrad=False,
-                batch_size=config.batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'),
-                minibatches=0, mini_batch_size=1, batch_count='auto', batch_bins=0,
-                batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,  # aug will be off when train_mode=False
-                n_iter_processes=config.num_workers, subsampling_factor=1,
-                load_aux_output=load_transcript, num_encs=1, dist_sampler=False)
+            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
             # test dataset, return raw text
-            decode_batch_size = config.get('decode', dict()).get('decode_batch_size', 1)
-            self.test_loader = BatchDataLoader(
-                json_file=config.test_manifest, train_mode=False, sortagrad=False,
-                batch_size=decode_batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'),
-                minibatches=0, mini_batch_size=1, batch_count='auto', batch_bins=0,
-                batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
-                preprocess_conf=config.preprocess_config,  # aug will be off when train_mode=False
-                n_iter_processes=config.num_workers, subsampling_factor=1,
-                num_encs=1, dist_sampler=False)
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
             logger.info("Setup test Dataloader!")

     def setup_model(self):
         config = self.config
         model_conf = config

@@ -468,7 +410,8 @@ class U2STTester(U2STTrainer):
     def test(self):
         assert self.args.result_file
         self.model.eval()
-        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        if not self.use_streamdata:
+            logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
         decode_cfg = self.config.decode
         bleu_func = bleu_score.char_bleu if decode_cfg.error_rate_type == 'char-bleu' else bleu_score.bleu
paddlespeech/s2t/io/dataloader.py

@@ -30,9 +30,10 @@ from paddlespeech.s2t.io.reader import LoadInputsAndTargets
 from paddlespeech.s2t.utils.log import Log
 import paddlespeech.audio.streamdata as streamdata
-from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.audio.text.text_featurizer import TextFeaturizer
+from yacs.config import CfgNode

-__all__ = ["BatchDataLoader"]
+__all__ = ["BatchDataLoader", "StreamDataLoader"]

 logger = Log(__name__).getlog()

@@ -60,12 +61,36 @@ def batch_collate(x):
     """
     return x[0]

+def read_preprocess_cfg(preprocess_conf_file):
+    augment_conf = dict()
+    preprocess_cfg = CfgNode(new_allowed=True)
+    preprocess_cfg.merge_from_file(preprocess_conf_file)
+    for idx, process in enumerate(preprocess_cfg["process"]):
+        opts = dict(process)
+        process_type = opts.pop("type")
+        if process_type == 'time_warp':
+            augment_conf['max_w'] = process['max_time_warp']
+            augment_conf['w_inplace'] = process['inplace']
+            augment_conf['w_mode'] = process['mode']
+        if process_type == 'freq_mask':
+            augment_conf['max_f'] = process['F']
+            augment_conf['num_f_mask'] = process['n_mask']
+            augment_conf['f_inplace'] = process['inplace']
+            augment_conf['f_replace_with_zero'] = process['replace_with_zero']
+        if process_type == 'time_mask':
+            augment_conf['max_t'] = process['T']
+            augment_conf['num_t_mask'] = process['n_mask']
+            augment_conf['t_inplace'] = process['inplace']
+            augment_conf['t_replace_with_zero'] = process['replace_with_zero']
+    return augment_conf

 class StreamDataLoader():
     def __init__(self,
                  manifest_file: str,
                  train_mode: bool,
                  unit_type: str='char',
                  batch_size: int=0,
+                 preprocess_conf=None,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,

@@ -75,7 +100,6 @@ class StreamDataLoader():
                  minlen_out: float=0.0,
                  maxlen_out: float=float('inf'),
                  resample_rate: int=16000,
-                 augment_conf: dict=None,
                  shuffle_size: int=10000,
                  sort_size: int=1000,
                  n_iter_processes: int=1,

@@ -95,12 +119,27 @@ class StreamDataLoader():
         self.feat_dim = num_mel_bins
         self.vocab_size = text_featurizer.vocab_size

+        augment_conf = read_preprocess_cfg(preprocess_conf)
+
         # The list of shard
         shardlist = []
         with open(manifest_file, "r") as f:
             for line in f.readlines():
                 shardlist.append(line.strip())
+        world_size = 1
+        try:
+            world_size = paddle.distributed.get_world_size()
+        except Exception as e:
+            logger.warninig(e)
+            logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1")
+        assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...")
+
+        update_n_iter_processes = int(max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0))
+        logger.info(f"update_n_iter_processes {update_n_iter_processes}")
+        if update_n_iter_processes != self.n_iter_processes:
+            self.n_iter_processes = update_n_iter_processes
+            logger.info(f"change nun_workers to {self.n_iter_processes}")

         if self.dist_sampler:
             base_dataset = streamdata.DataPipeline(
                 streamdata.SimpleShardList(shardlist),

@@ -116,16 +155,16 @@ class StreamDataLoader():
             )
         self.dataset = base_dataset.append_list(
-            streamdata.tokenize(symbol_table),
-            streamdata.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in),
-            streamdata.resample(resample_rate=resample_rate),
-            streamdata.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
-            streamdata.spec_aug(**augment_conf) if train_mode else streamdata.placeholder(),  # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
+            streamdata.audio_tokenize(symbol_table),
+            streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out),
+            streamdata.audio_resample(resample_rate=resample_rate),
+            streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
+            streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(),  # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
             streamdata.shuffle(shuffle_size),
             streamdata.sort(sort_size=sort_size),
             streamdata.batched(batch_size),
-            streamdata.padding(),
-            streamdata.cmvn(cmvn_file)
+            streamdata.audio_padding(),
+            streamdata.audio_cmvn(cmvn_file)
         )
         if paddle.__version__ >= '2.3.2':

@@ -295,3 +334,119 @@ class BatchDataLoader():
         echo += f"shortest_first: {self.shortest_first}, "
         echo += f"file: {self.json_file}"
         return echo

+class DataLoaderFactory():
+    @staticmethod
+    def get_dataloader(mode: str, config, args):
+        config = config.clone()
+        use_streamdata = config.get("use_stream_data", False)
+        if use_streamdata:
+            if mode == 'train':
+                config['manifest'] = config.train_manifest
+                config['train_mode'] = True
+            elif mode == 'valid':
+                config['manifest'] = config.dev_manifest
+                config['train_mode'] = False
+            elif model == 'test' or mode == 'align':
+                config['manifest'] = config.test_manifest
+                config['train_mode'] = False
+                config['dither'] = 0.0
+                config['minlen_in'] = 0.0
+                config['maxlen_in'] = float('inf')
+                config['minlen_out'] = 0
+                config['maxlen_out'] = float('inf')
+                config['dist_sampler'] = False
+            else:
+                raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
+            return StreamDataLoader(
+                manifest_file=config.manifest, train_mode=config.train_mode,
+                unit_type=config.unit_type, preprocess_conf=config.preprocess_config,
+                batch_size=config.batch_size, num_mel_bins=config.feat_dim,
+                frame_length=config.window_ms, frame_shift=config.stride_ms,
+                dither=config.dither, minlen_in=config.minlen_in, maxlen_in=config.maxlen_in,
+                minlen_out=config.minlen_out, maxlen_out=config.maxlen_out,
+                resample_rate=config.resample_rate, shuffle_size=config.shuffle_size,
+                sort_size=config.sort_size, n_iter_processes=config.num_workers,
+                prefetch_factor=config.prefetch_factor, dist_sampler=config.dist_sampler,
+                cmvn_file=config.cmvn_file, vocab_filepath=config.vocab_filepath)
+        else:
+            if mode == 'train':
+                config['manifest'] = config.train_manifest
+                config['train_mode'] = True
+                config['mini_batch_size'] = args.ngpu
+                config['subsampling_factor'] = 1
+                config['num_encs'] = 1
+            elif mode == 'valid':
+                config['manifest'] = config.dev_manifest
+                config['train_mode'] = False
+                config['sortagrad'] = False
+                config['maxlen_in'] = float('inf')
+                config['maxlen_out'] = float('inf')
+                config['minibatches'] = 0
+                config['mini_batch_size'] = args.ngpu
+                config['batch_count'] = 'auto'
+                config['batch_bins'] = 0
+                config['batch_frames_in'] = 0
+                config['batch_frames_out'] = 0
+                config['batch_frames_inout'] = 0
+                config['subsampling_factor'] = 1
+                config['num_encs'] = 1
+                config['shortest_first'] = False
+            elif mode == 'test' or mode == 'align':
+                config['manifest'] = config.test_manifest
+                config['train_mode'] = False
+                config['sortagrad'] = False
+                config['batch_size'] = config.get('decode', dict()).get('decode_batch_size', 1)
+                config['maxlen_in'] = float('inf')
+                config['maxlen_out'] = float('inf')
+                config['minibatches'] = 0
+                config['mini_batch_size'] = 1
+                config['batch_count'] = 'auto'
+                config['batch_bins'] = 0
+                config['batch_frames_in'] = 0
+                config['batch_frames_out'] = 0
+                config['batch_frames_inout'] = 0
+                config['num_workers'] = 1
+                config['subsampling_factor'] = 1
+                config['num_encs'] = 1
+                config['dist_sampler'] = False
+                config['shortest_first'] = False
+            else:
+                raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
+            return BatchDataLoader(
+                json_file=config.manifest, train_mode=config.train_mode,
+                sortagrad=config.sortagrad, batch_size=config.batch_size,
+                maxlen_in=config.maxlen_in, maxlen_out=config.maxlen_out,
+                minibatches=config.minibatches, mini_batch_size=config.mini_batch_size,
+                batch_count=config.batch_count, batch_bins=config.batch_bins,
+                batch_frames_in=config.batch_frames_in, batch_frames_out=config.batch_frames_out,
+                batch_frames_inout=config.batch_frames_inout,
+                preprocess_conf=config.preprocess_config,
+                n_iter_processes=config.num_workers,
+                subsampling_factor=config.subsampling_factor,
+                load_aux_output=config.get('load_transcript', None),
+                num_encs=config.num_encs, dist_sampler=config.dist_sampler,
+                shortest_first=config.shortest_first)
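The factory centralizes the per-mode defaults that the experiment classes above used to spell out inline. A hedged usage sketch follows; the YAML path is illustrative, and the batch unpacking order is assumed to follow the audio_padding stage shown earlier (keys, padded feats, feat lengths, padded labels, label lengths):

    import argparse
    from yacs.config import CfgNode
    from paddlespeech.s2t.io.dataloader import DataLoaderFactory

    config = CfgNode(new_allowed=True)
    config.merge_from_file("examples/wenetspeech/asr1/conf/conformer.yaml")  # illustrative path

    args = argparse.Namespace(ngpu=1)
    train_loader = DataLoaderFactory.get_dataloader('train', config, args)

    for utts, xs, xlens, ys, ylens in train_loader:
        break  # inspect one padded batch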
paddlespeech/s2t/models/u2/u2.py

@@ -48,9 +48,9 @@ from paddlespeech.s2t.utils import checkpoint
 from paddlespeech.s2t.utils import layer_tools
 from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
-from paddlespeech.s2t.utils.tensor_utils import pad_sequence
-from paddlespeech.s2t.utils.tensor_utils import th_accuracy
+from paddlespeech.audio.utils.tensor_utils import add_sos_eos
+from paddlespeech.audio.utils.tensor_utils import pad_sequence
+from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.utils.utility import log_add
 from paddlespeech.s2t.utils.utility import UpdateConfig
paddlespeech/s2t/models/u2_st/u2_st.py

@@ -38,8 +38,8 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
 from paddlespeech.s2t.utils import checkpoint
 from paddlespeech.s2t.utils import layer_tools
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
-from paddlespeech.s2t.utils.tensor_utils import th_accuracy
+from paddlespeech.audio.utils.tensor_utils import add_sos_eos
+from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.utils.utility import UpdateConfig

 __all__ = ["U2STModel", "U2STInferModel"]
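The import move does not change what these helpers do. As a quick reminder of the idea behind add_sos_eos, here is a hedged, self-contained re-implementation written for this note (not the library code; token ids are made up) that shows what the attention decoder consumes and predicts:

    def add_sos_eos_sketch(ys, sos, eos):
        # Illustrative only: prepend <sos> for the decoder input, append <eos> for the target.
        ys_in = [[sos] + y for y in ys]
        ys_out = [y + [eos] for y in ys]
        return ys_in, ys_out

    ys = [[7, 9, 11], [3, 5]]  # toy label sequences
    print(add_sos_eos_sketch(ys, sos=1, eos=1))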