Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
69055698
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
12 个月 前同步成功
通知
205
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
69055698
编写于
11月 05, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
transformer using batch data loader
上级
3f611c75
变更
27
显示空白变更内容
内联
并排
Showing
27 changed file
with
328 addition
and
172 deletion
+328
-172
examples/aishell/s0/local/data.sh
examples/aishell/s0/local/data.sh
+0
-1
examples/aishell/s1/local/data.sh
examples/aishell/s1/local/data.sh
+0
-1
examples/callcenter/s1/local/data.sh
examples/callcenter/s1/local/data.sh
+0
-1
examples/dataset/librispeech/librispeech.py
examples/dataset/librispeech/librispeech.py
+11
-8
examples/librispeech/s0/local/data.sh
examples/librispeech/s0/local/data.sh
+0
-1
examples/librispeech/s1/local/data.sh
examples/librispeech/s1/local/data.sh
+0
-1
examples/other/1xt2x/aishell/local/data.sh
examples/other/1xt2x/aishell/local/data.sh
+0
-1
examples/other/1xt2x/baidu_en8k/local/data.sh
examples/other/1xt2x/baidu_en8k/local/data.sh
+0
-1
examples/other/1xt2x/librispeech/local/data.sh
examples/other/1xt2x/librispeech/local/data.sh
+0
-1
examples/ted_en_zh/t0/local/data.sh
examples/ted_en_zh/t0/local/data.sh
+0
-1
examples/timit/s1/local/data.sh
examples/timit/s1/local/data.sh
+0
-1
examples/tiny/s0/local/data.sh
examples/tiny/s0/local/data.sh
+0
-1
examples/tiny/s1/conf/chunk_confermer.yaml
examples/tiny/s1/conf/chunk_confermer.yaml
+1
-1
examples/tiny/s1/conf/chunk_transformer.yaml
examples/tiny/s1/conf/chunk_transformer.yaml
+1
-1
examples/tiny/s1/conf/conformer.yaml
examples/tiny/s1/conf/conformer.yaml
+1
-1
examples/tiny/s1/conf/preprocess.yaml
examples/tiny/s1/conf/preprocess.yaml
+27
-0
examples/tiny/s1/conf/transformer.yaml
examples/tiny/s1/conf/transformer.yaml
+1
-1
examples/tiny/s1/local/data.sh
examples/tiny/s1/local/data.sh
+0
-1
paddlespeech/s2t/exps/u2/model.py
paddlespeech/s2t/exps/u2/model.py
+108
-87
paddlespeech/s2t/exps/u2_kaldi/model.py
paddlespeech/s2t/exps/u2_kaldi/model.py
+3
-2
paddlespeech/s2t/io/dataset.py
paddlespeech/s2t/io/dataset.py
+11
-27
paddlespeech/s2t/io/reader.py
paddlespeech/s2t/io/reader.py
+4
-15
paddlespeech/s2t/io/utility.py
paddlespeech/s2t/io/utility.py
+18
-1
paddlespeech/s2t/transform/spectrogram.py
paddlespeech/s2t/transform/spectrogram.py
+83
-0
paddlespeech/s2t/transform/transformation.py
paddlespeech/s2t/transform/transformation.py
+1
-0
utils/format_data.py
utils/format_data.py
+54
-14
utils/format_triplet_data.py
utils/format_triplet_data.py
+4
-2
未找到文件。
examples/aishell/s0/local/data.sh
浏览文件 @
69055698
...
@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
dataset
in
train dev
test
;
do
for
dataset
in
train dev
test
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.json"
\
--cmvn_path
"data/mean_std.json"
\
--unit_type
"char"
\
--unit_type
"char"
\
--vocab_path
=
"data/vocab.txt"
\
--vocab_path
=
"data/vocab.txt"
\
...
...
examples/aishell/s1/local/data.sh
浏览文件 @
69055698
...
@@ -67,7 +67,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -67,7 +67,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
dataset
in
train dev
test
;
do
for
dataset
in
train dev
test
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.json"
\
--cmvn_path
"data/mean_std.json"
\
--unit_type
"char"
\
--unit_type
"char"
\
--vocab_path
=
"data/vocab.txt"
\
--vocab_path
=
"data/vocab.txt"
\
...
...
examples/callcenter/s1/local/data.sh
浏览文件 @
69055698
...
@@ -55,7 +55,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -55,7 +55,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
dataset
in
train dev
test
;
do
for
dataset
in
train dev
test
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.json"
\
--cmvn_path
"data/mean_std.json"
\
--unit_type
"char"
\
--unit_type
"char"
\
--vocab_path
=
"data/vocab.txt"
\
--vocab_path
=
"data/vocab.txt"
\
...
...
examples/dataset/librispeech/librispeech.py
浏览文件 @
69055698
...
@@ -89,25 +89,28 @@ def create_manifest(data_dir, manifest_path):
...
@@ -89,25 +89,28 @@ def create_manifest(data_dir, manifest_path):
text_filepath
=
os
.
path
.
join
(
subfolder
,
text_filelist
[
0
])
text_filepath
=
os
.
path
.
join
(
subfolder
,
text_filelist
[
0
])
for
line
in
io
.
open
(
text_filepath
,
encoding
=
"utf8"
):
for
line
in
io
.
open
(
text_filepath
,
encoding
=
"utf8"
):
segments
=
line
.
strip
().
split
()
segments
=
line
.
strip
().
split
()
n_token
=
len
(
segments
[
1
:])
text
=
' '
.
join
(
segments
[
1
:]).
lower
()
text
=
' '
.
join
(
segments
[
1
:]).
lower
()
audio_filepath
=
os
.
path
.
abspath
(
audio_filepath
=
os
.
path
.
abspath
(
os
.
path
.
join
(
subfolder
,
segments
[
0
]
+
'.flac'
))
os
.
path
.
join
(
subfolder
,
segments
[
0
]
+
'.flac'
))
audio_data
,
samplerate
=
soundfile
.
read
(
audio_filepath
)
audio_data
,
samplerate
=
soundfile
.
read
(
audio_filepath
)
duration
=
float
(
len
(
audio_data
))
/
samplerate
duration
=
float
(
len
(
audio_data
))
/
samplerate
utt
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
audio_filepath
))[
0
]
utt2spk
=
'-'
.
join
(
utt
.
split
(
'-'
)[:
2
])
json_lines
.
append
(
json_lines
.
append
(
json
.
dumps
({
json
.
dumps
({
'utt'
:
'utt'
:
utt
,
os
.
path
.
splitext
(
os
.
path
.
basename
(
audio_filepath
))[
0
],
'utt2spk'
:
utt2spk
,
'feat'
:
'feat'
:
audio_filepath
,
audio_filepath
,
'feat_shape'
:
(
duration
,
),
# second
'feat_shape'
:
(
duration
,
),
#second
'text'
:
text
,
'text'
:
text
}))
}))
total_sec
+=
duration
total_sec
+=
duration
total_text
+=
len
(
text
)
total_text
+=
n_token
total_num
+=
1
total_num
+=
1
with
codecs
.
open
(
manifest_path
,
'w'
,
'utf-8'
)
as
out_file
:
with
codecs
.
open
(
manifest_path
,
'w'
,
'utf-8'
)
as
out_file
:
...
...
examples/librispeech/s0/local/data.sh
浏览文件 @
69055698
...
@@ -81,7 +81,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -81,7 +81,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
set
in
train dev
test
dev-clean dev-other test-clean test-other
;
do
for
set
in
train dev
test
dev-clean dev-other test-clean test-other
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.json"
\
--cmvn_path
"data/mean_std.json"
\
--unit_type
${
unit_type
}
\
--unit_type
${
unit_type
}
\
--vocab_path
=
"data/vocab.txt"
\
--vocab_path
=
"data/vocab.txt"
\
...
...
examples/librispeech/s1/local/data.sh
浏览文件 @
69055698
...
@@ -88,7 +88,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -88,7 +88,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
set
in
train dev
test
dev-clean dev-other test-clean test-other
;
do
for
set
in
train dev
test
dev-clean dev-other test-clean test-other
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.json"
\
--cmvn_path
"data/mean_std.json"
\
--unit_type
"spm"
\
--unit_type
"spm"
\
--spm_model_prefix
${
bpeprefix
}
\
--spm_model_prefix
${
bpeprefix
}
\
...
...
examples/other/1xt2x/aishell/local/data.sh
浏览文件 @
69055698
...
@@ -50,7 +50,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -50,7 +50,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
dataset
in
train dev
test
;
do
for
dataset
in
train dev
test
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.npz"
\
--cmvn_path
"data/mean_std.npz"
\
--unit_type
"char"
\
--unit_type
"char"
\
--vocab_path
=
"data/vocab.txt"
\
--vocab_path
=
"data/vocab.txt"
\
...
...
examples/other/1xt2x/baidu_en8k/local/data.sh
浏览文件 @
69055698
...
@@ -65,7 +65,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -65,7 +65,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
set
in
train dev
test
dev-clean dev-other test-clean test-other
;
do
for
set
in
train dev
test
dev-clean dev-other test-clean test-other
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.npz"
\
--cmvn_path
"data/mean_std.npz"
\
--unit_type
${
unit_type
}
\
--unit_type
${
unit_type
}
\
--vocab_path
=
"data/vocab.txt"
\
--vocab_path
=
"data/vocab.txt"
\
...
...
examples/other/1xt2x/librispeech/local/data.sh
浏览文件 @
69055698
...
@@ -63,7 +63,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -63,7 +63,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
set
in
train dev
test
dev-clean dev-other test-clean test-other
;
do
for
set
in
train dev
test
dev-clean dev-other test-clean test-other
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.npz"
\
--cmvn_path
"data/mean_std.npz"
\
--unit_type
${
unit_type
}
\
--unit_type
${
unit_type
}
\
--vocab_path
=
"data/vocab.txt"
\
--vocab_path
=
"data/vocab.txt"
\
...
...
examples/ted_en_zh/t0/local/data.sh
浏览文件 @
69055698
...
@@ -89,7 +89,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -89,7 +89,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
set
in
train dev
test
;
do
for
set
in
train dev
test
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_triplet_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_triplet_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.json"
\
--cmvn_path
"data/mean_std.json"
\
--unit_type
"spm"
\
--unit_type
"spm"
\
--spm_model_prefix
${
bpeprefix
}
\
--spm_model_prefix
${
bpeprefix
}
\
...
...
examples/timit/s1/local/data.sh
浏览文件 @
69055698
...
@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for
set
in
train dev
test
;
do
for
set
in
train dev
test
;
do
{
{
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.json"
\
--cmvn_path
"data/mean_std.json"
\
--unit_type
${
unit_type
}
\
--unit_type
${
unit_type
}
\
--vocab_path
=
"data/vocab.txt"
\
--vocab_path
=
"data/vocab.txt"
\
...
...
examples/tiny/s0/local/data.sh
浏览文件 @
69055698
...
@@ -63,7 +63,6 @@ fi
...
@@ -63,7 +63,6 @@ fi
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# format manifest with tokenids, vocab size
# format manifest with tokenids, vocab size
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.json"
\
--cmvn_path
"data/mean_std.json"
\
--unit_type
${
unit_type
}
\
--unit_type
${
unit_type
}
\
--vocab_path
=
"data/vocab.txt"
\
--vocab_path
=
"data/vocab.txt"
\
...
...
examples/tiny/s1/conf/chunk_confermer.yaml
浏览文件 @
69055698
...
@@ -15,7 +15,7 @@ collator:
...
@@ -15,7 +15,7 @@ collator:
vocab_filepath
:
data/vocab.txt
vocab_filepath
:
data/vocab.txt
unit_type
:
'
spm'
unit_type
:
'
spm'
spm_model_prefix
:
'
data/bpe_unigram_200'
spm_model_prefix
:
'
data/bpe_unigram_200'
augmentation_config
:
conf/
augmentation.json
augmentation_config
:
conf/
preprocess.yaml
batch_size
:
4
batch_size
:
4
raw_wav
:
True
# use raw_wav or kaldi feature
raw_wav
:
True
# use raw_wav or kaldi feature
spectrum_type
:
fbank
#linear, mfcc, fbank
spectrum_type
:
fbank
#linear, mfcc, fbank
...
...
examples/tiny/s1/conf/chunk_transformer.yaml
浏览文件 @
69055698
...
@@ -15,7 +15,7 @@ collator:
...
@@ -15,7 +15,7 @@ collator:
vocab_filepath
:
data/vocab.txt
vocab_filepath
:
data/vocab.txt
unit_type
:
'
spm'
unit_type
:
'
spm'
spm_model_prefix
:
'
data/bpe_unigram_200'
spm_model_prefix
:
'
data/bpe_unigram_200'
augmentation_config
:
conf/
augmentation.json
augmentation_config
:
conf/
preprocess.yaml
batch_size
:
4
batch_size
:
4
raw_wav
:
True
# use raw_wav or kaldi feature
raw_wav
:
True
# use raw_wav or kaldi feature
spectrum_type
:
fbank
#linear, mfcc, fbank
spectrum_type
:
fbank
#linear, mfcc, fbank
...
...
examples/tiny/s1/conf/conformer.yaml
浏览文件 @
69055698
...
@@ -15,7 +15,7 @@ collator:
...
@@ -15,7 +15,7 @@ collator:
vocab_filepath
:
data/vocab.txt
vocab_filepath
:
data/vocab.txt
unit_type
:
'
spm'
unit_type
:
'
spm'
spm_model_prefix
:
'
data/bpe_unigram_200'
spm_model_prefix
:
'
data/bpe_unigram_200'
augmentation_config
:
conf/
augmentation.json
augmentation_config
:
conf/
preprocess.yaml
batch_size
:
4
batch_size
:
4
raw_wav
:
True
# use raw_wav or kaldi feature
raw_wav
:
True
# use raw_wav or kaldi feature
spectrum_type
:
fbank
#linear, mfcc, fbank
spectrum_type
:
fbank
#linear, mfcc, fbank
...
...
examples/tiny/s1/conf/preprocess.yaml
0 → 100644
浏览文件 @
69055698
process
:
# extract kaldi fbank from PCM
-
type
:
"
fbank_kaldi"
fs
:
16000
n_mels
:
80
n_shift
:
160
win_length
:
400
dither
:
true
# these three processes are a.k.a. SpecAugument
-
type
:
"
time_warp"
max_time_warp
:
5
inplace
:
true
mode
:
"
PIL"
-
type
:
"
freq_mask"
F
:
30
n_mask
:
2
inplace
:
true
replace_with_zero
:
false
-
type
:
"
time_mask"
T
:
40
n_mask
:
2
inplace
:
true
replace_with_zero
:
false
examples/tiny/s1/conf/transformer.yaml
浏览文件 @
69055698
...
@@ -15,7 +15,7 @@ collator:
...
@@ -15,7 +15,7 @@ collator:
vocab_filepath
:
data/vocab.txt
vocab_filepath
:
data/vocab.txt
unit_type
:
'
spm'
unit_type
:
'
spm'
spm_model_prefix
:
'
data/bpe_unigram_200'
spm_model_prefix
:
'
data/bpe_unigram_200'
augmentation_config
:
conf/
augmentation.json
augmentation_config
:
conf/
preprocess.yaml
batch_size
:
4
batch_size
:
4
raw_wav
:
True
# use raw_wav or kaldi feature
raw_wav
:
True
# use raw_wav or kaldi feature
spectrum_type
:
fbank
#linear, mfcc, fbank
spectrum_type
:
fbank
#linear, mfcc, fbank
...
...
examples/tiny/s1/local/data.sh
浏览文件 @
69055698
...
@@ -69,7 +69,6 @@ fi
...
@@ -69,7 +69,6 @@ fi
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# format manifest with tokenids, vocab size
# format manifest with tokenids, vocab size
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
python3
${
MAIN_ROOT
}
/utils/format_data.py
\
--feat_type
"raw"
\
--cmvn_path
"data/mean_std.json"
\
--cmvn_path
"data/mean_std.json"
\
--unit_type
"spm"
\
--unit_type
"spm"
\
--spm_model_prefix
${
bpeprefix
}
\
--spm_model_prefix
${
bpeprefix
}
\
...
...
paddlespeech/s2t/exps/u2/model.py
浏览文件 @
69055698
...
@@ -27,7 +27,9 @@ from paddle import distributed as dist
...
@@ -27,7 +27,9 @@ from paddle import distributed as dist
from
paddle.io
import
DataLoader
from
paddle.io
import
DataLoader
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
paddlespeech.s2t.frontend.featurizer
import
TextFeaturizer
from
paddlespeech.s2t.io.collator
import
SpeechCollator
from
paddlespeech.s2t.io.collator
import
SpeechCollator
from
paddlespeech.s2t.io.dataloader
import
BatchDataLoader
from
paddlespeech.s2t.io.dataset
import
ManifestDataset
from
paddlespeech.s2t.io.dataset
import
ManifestDataset
from
paddlespeech.s2t.io.sampler
import
SortagradBatchSampler
from
paddlespeech.s2t.io.sampler
import
SortagradBatchSampler
from
paddlespeech.s2t.io.sampler
import
SortagradDistributedBatchSampler
from
paddlespeech.s2t.io.sampler
import
SortagradDistributedBatchSampler
...
@@ -247,92 +249,103 @@ class U2Trainer(Trainer):
...
@@ -247,92 +249,103 @@ class U2Trainer(Trainer):
def
setup_dataloader
(
self
):
def
setup_dataloader
(
self
):
config
=
self
.
config
.
clone
()
config
=
self
.
config
.
clone
()
config
.
defrost
()
config
.
collator
.
keep_transcription_text
=
False
if
self
.
train
:
# train/valid dataset, return token ids
# train/valid dataset, return token ids
config
.
data
.
manifest
=
config
.
data
.
train_manifest
self
.
train_loader
=
BatchDataLoader
(
train_dataset
=
ManifestDataset
.
from_config
(
config
)
json_file
=
config
.
data
.
train_manifest
,
train_mode
=
True
,
config
.
data
.
manifest
=
config
.
data
.
dev_manifest
sortagrad
=
False
,
dev_dataset
=
ManifestDataset
.
from_config
(
config
)
collate_fn_train
=
SpeechCollator
.
from_config
(
config
)
config
.
collator
.
augmentation_config
=
""
collate_fn_dev
=
SpeechCollator
.
from_config
(
config
)
if
self
.
parallel
:
batch_sampler
=
SortagradDistributedBatchSampler
(
train_dataset
,
batch_size
=
config
.
collator
.
batch_size
,
num_replicas
=
None
,
rank
=
None
,
shuffle
=
True
,
drop_last
=
True
,
sortagrad
=
config
.
collator
.
sortagrad
,
shuffle_method
=
config
.
collator
.
shuffle_method
)
else
:
batch_sampler
=
SortagradBatchSampler
(
train_dataset
,
shuffle
=
True
,
batch_size
=
config
.
collator
.
batch_size
,
batch_size
=
config
.
collator
.
batch_size
,
drop_last
=
True
,
maxlen_in
=
float
(
'inf'
),
sortagrad
=
config
.
collator
.
sortagrad
,
maxlen_out
=
float
(
'inf'
),
shuffle_method
=
config
.
collator
.
shuffle_method
)
minibatches
=
0
,
self
.
train_loader
=
DataLoader
(
mini_batch_size
=
self
.
args
.
nprocs
,
train_dataset
,
batch_count
=
'auto'
,
batch_sampler
=
batch_sampler
,
batch_bins
=
0
,
collate_fn
=
collate_fn_train
,
batch_frames_in
=
0
,
num_workers
=
config
.
collator
.
num_workers
,
)
batch_frames_out
=
0
,
self
.
valid_loader
=
DataLoader
(
batch_frames_inout
=
0
,
dev_dataset
,
preprocess_conf
=
config
.
collator
.
augmentation_config
,
# aug will be off when train_mode=False
n_iter_processes
=
config
.
collator
.
num_workers
,
subsampling_factor
=
1
,
num_encs
=
1
)
self
.
valid_loader
=
BatchDataLoader
(
json_file
=
config
.
data
.
dev_manifest
,
train_mode
=
False
,
sortagrad
=
False
,
batch_size
=
config
.
collator
.
batch_size
,
batch_size
=
config
.
collator
.
batch_size
,
shuffle
=
False
,
maxlen_in
=
float
(
'inf'
),
drop_last
=
False
,
maxlen_out
=
float
(
'inf'
),
collate_fn
=
collate_fn_dev
,
minibatches
=
0
,
num_workers
=
config
.
collator
.
num_workers
,
)
mini_batch_size
=
self
.
args
.
nprocs
,
batch_count
=
'auto'
,
batch_bins
=
0
,
batch_frames_in
=
0
,
batch_frames_out
=
0
,
batch_frames_inout
=
0
,
preprocess_conf
=
config
.
collator
.
augmentation_config
,
# aug will be off when train_mode=False
n_iter_processes
=
config
.
collator
.
num_workers
,
subsampling_factor
=
1
,
num_encs
=
1
)
logger
.
info
(
"Setup train/valid Dataloader!"
)
else
:
# test dataset, return raw text
# test dataset, return raw text
config
.
data
.
manifest
=
config
.
data
.
test_manifest
self
.
test_loader
=
BatchDataLoader
(
# filter test examples, will cause less examples, but no mismatch with training
json_file
=
config
.
data
.
test_manifest
,
# and can use large batch size , save training time, so filter test egs now.
train_mode
=
False
,
config
.
data
.
min_input_len
=
0.0
# second
sortagrad
=
False
,
config
.
data
.
max_input_len
=
float
(
'inf'
)
# second
config
.
data
.
min_output_len
=
0.0
# tokens
config
.
data
.
max_output_len
=
float
(
'inf'
)
# tokens
config
.
data
.
min_output_input_ratio
=
0.00
config
.
data
.
max_output_input_ratio
=
float
(
'inf'
)
test_dataset
=
ManifestDataset
.
from_config
(
config
)
# return text ord id
config
.
collator
.
keep_transcription_text
=
True
config
.
collator
.
augmentation_config
=
""
self
.
test_loader
=
DataLoader
(
test_dataset
,
batch_size
=
config
.
decoding
.
batch_size
,
batch_size
=
config
.
decoding
.
batch_size
,
shuffle
=
False
,
maxlen_in
=
float
(
'inf'
),
drop_last
=
False
,
maxlen_out
=
float
(
'inf'
),
collate_fn
=
SpeechCollator
.
from_config
(
config
),
minibatches
=
0
,
num_workers
=
config
.
collator
.
num_workers
,
)
mini_batch_size
=
1
,
# return text token id
batch_count
=
'auto'
,
config
.
collator
.
keep_transcription_text
=
False
batch_bins
=
0
,
self
.
align_loader
=
DataLoader
(
batch_frames_in
=
0
,
test_dataset
,
batch_frames_out
=
0
,
batch_frames_inout
=
0
,
preprocess_conf
=
config
.
collator
.
augmentation_config
,
# aug will be off when train_mode=False
n_iter_processes
=
1
,
subsampling_factor
=
1
,
num_encs
=
1
)
self
.
align_loader
=
BatchDataLoader
(
json_file
=
config
.
data
.
test_manifest
,
train_mode
=
False
,
sortagrad
=
False
,
batch_size
=
config
.
decoding
.
batch_size
,
batch_size
=
config
.
decoding
.
batch_size
,
shuffle
=
False
,
maxlen_in
=
float
(
'inf'
),
drop_last
=
False
,
maxlen_out
=
float
(
'inf'
),
collate_fn
=
SpeechCollator
.
from_config
(
config
),
minibatches
=
0
,
num_workers
=
config
.
collator
.
num_workers
,
)
mini_batch_size
=
1
,
logger
.
info
(
"Setup train/valid/test/align Dataloader!"
)
batch_count
=
'auto'
,
batch_bins
=
0
,
batch_frames_in
=
0
,
batch_frames_out
=
0
,
batch_frames_inout
=
0
,
preprocess_conf
=
config
.
collator
.
augmentation_config
,
# aug will be off when train_mode=False
n_iter_processes
=
1
,
subsampling_factor
=
1
,
num_encs
=
1
)
logger
.
info
(
"Setup test/align Dataloader!"
)
def
setup_model
(
self
):
def
setup_model
(
self
):
config
=
self
.
config
config
=
self
.
config
model_conf
=
config
.
model
model_conf
=
config
.
model
with
UpdateConfig
(
model_conf
):
with
UpdateConfig
(
model_conf
):
model_conf
.
input_dim
=
self
.
train_loader
.
collate_fn
.
feature_size
if
self
.
train
:
model_conf
.
output_dim
=
self
.
train_loader
.
collate_fn
.
vocab_size
model_conf
.
input_dim
=
self
.
train_loader
.
feat_dim
model_conf
.
output_dim
=
self
.
train_loader
.
vocab_size
else
:
model_conf
.
input_dim
=
self
.
test_loader
.
feat_dim
model_conf
.
output_dim
=
self
.
test_loader
.
vocab_size
model
=
U2Model
.
from_config
(
model_conf
)
model
=
U2Model
.
from_config
(
model_conf
)
...
@@ -341,6 +354,11 @@ class U2Trainer(Trainer):
...
@@ -341,6 +354,11 @@ class U2Trainer(Trainer):
logger
.
info
(
f
"
{
model
}
"
)
logger
.
info
(
f
"
{
model
}
"
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
self
.
model
=
model
logger
.
info
(
"Setup model!"
)
if
not
self
.
train
:
return
train_config
=
config
.
training
train_config
=
config
.
training
optim_type
=
train_config
.
optim
optim_type
=
train_config
.
optim
...
@@ -381,10 +399,9 @@ class U2Trainer(Trainer):
...
@@ -381,10 +399,9 @@ class U2Trainer(Trainer):
optimzer_args
=
optimizer_args
(
config
,
model
.
parameters
(),
lr_scheduler
)
optimzer_args
=
optimizer_args
(
config
,
model
.
parameters
(),
lr_scheduler
)
optimizer
=
OptimizerFactory
.
from_args
(
optim_type
,
optimzer_args
)
optimizer
=
OptimizerFactory
.
from_args
(
optim_type
,
optimzer_args
)
self
.
model
=
model
self
.
optimizer
=
optimizer
self
.
optimizer
=
optimizer
self
.
lr_scheduler
=
lr_scheduler
self
.
lr_scheduler
=
lr_scheduler
logger
.
info
(
"Setup
model/
optimizer/lr_scheduler!"
)
logger
.
info
(
"Setup optimizer/lr_scheduler!"
)
class
U2Tester
(
U2Trainer
):
class
U2Tester
(
U2Trainer
):
...
@@ -419,14 +436,19 @@ class U2Tester(U2Trainer):
...
@@ -419,14 +436,19 @@ class U2Tester(U2Trainer):
def
__init__
(
self
,
config
,
args
):
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
super
().
__init__
(
config
,
args
)
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab_filepath
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
self
.
vocab_list
=
self
.
text_feature
.
vocab_list
def
ordid2token
(
self
,
texts
,
texts_len
):
def
id2token
(
self
,
texts
,
texts_len
,
text_feature
):
""" ord() id to chr() chr """
""" ord() id to chr() chr """
trans
=
[]
trans
=
[]
for
text
,
n
in
zip
(
texts
,
texts_len
):
for
text
,
n
in
zip
(
texts
,
texts_len
):
n
=
n
.
numpy
().
item
()
n
=
n
.
numpy
().
item
()
ids
=
text
[:
n
]
ids
=
text
[:
n
]
trans
.
append
(
''
.
join
([
chr
(
i
)
for
i
in
ids
]
))
trans
.
append
(
text_feature
.
defeaturize
(
ids
.
numpy
().
tolist
()
))
return
trans
return
trans
def
compute_metrics
(
self
,
def
compute_metrics
(
self
,
...
@@ -442,12 +464,11 @@ class U2Tester(U2Trainer):
...
@@ -442,12 +464,11 @@ class U2Tester(U2Trainer):
error_rate_func
=
error_rate
.
cer
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
wer
error_rate_func
=
error_rate
.
cer
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
wer
start_time
=
time
.
time
()
start_time
=
time
.
time
()
text_feature
=
self
.
test_loader
.
collate_fn
.
text_feature
target_transcripts
=
self
.
id2token
(
texts
,
texts_len
,
self
.
text_feature
)
target_transcripts
=
self
.
ordid2token
(
texts
,
texts_len
)
result_transcripts
,
result_tokenids
=
self
.
model
.
decode
(
result_transcripts
,
result_tokenids
=
self
.
model
.
decode
(
audio
,
audio
,
audio_len
,
audio_len
,
text_feature
=
text_feature
,
text_feature
=
self
.
text_feature
,
decoding_method
=
cfg
.
decoding_method
,
decoding_method
=
cfg
.
decoding_method
,
lang_model_path
=
cfg
.
lang_model_path
,
lang_model_path
=
cfg
.
lang_model_path
,
beam_alpha
=
cfg
.
alpha
,
beam_alpha
=
cfg
.
alpha
,
...
@@ -497,7 +518,7 @@ class U2Tester(U2Trainer):
...
@@ -497,7 +518,7 @@ class U2Tester(U2Trainer):
self
.
model
.
eval
()
self
.
model
.
eval
()
logger
.
info
(
f
"Test Total Examples:
{
len
(
self
.
test_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Test Total Examples:
{
len
(
self
.
test_loader
.
dataset
)
}
"
)
stride_ms
=
self
.
test_loader
.
collate_fn
.
stride_ms
stride_ms
=
self
.
config
.
collator
.
stride_ms
error_rate_type
=
None
error_rate_type
=
None
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
num_frames
=
0.0
num_frames
=
0.0
...
@@ -556,8 +577,8 @@ class U2Tester(U2Trainer):
...
@@ -556,8 +577,8 @@ class U2Tester(U2Trainer):
def
align
(
self
):
def
align
(
self
):
ctc_utils
.
ctc_align
(
ctc_utils
.
ctc_align
(
self
.
model
,
self
.
align_loader
,
self
.
config
.
decoding
.
batch_size
,
self
.
model
,
self
.
align_loader
,
self
.
config
.
decoding
.
batch_size
,
self
.
align_loader
.
collate_fn
.
stride_ms
,
self
.
config
.
collator
.
stride_ms
,
self
.
align_loader
.
collate_fn
.
vocab_list
,
self
.
args
.
result_file
)
self
.
vocab_list
,
self
.
args
.
result_file
)
def
load_inferspec
(
self
):
def
load_inferspec
(
self
):
"""infer model and input spec.
"""infer model and input spec.
...
...
paddlespeech/s2t/exps/u2_kaldi/model.py
浏览文件 @
69055698
...
@@ -392,6 +392,7 @@ class U2Tester(U2Trainer):
...
@@ -392,6 +392,7 @@ class U2Tester(U2Trainer):
unit_type
=
self
.
config
.
collator
.
unit_type
,
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab_filepath
=
self
.
config
.
collator
.
vocab_filepath
,
vocab_filepath
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
self
.
vocab_list
=
self
.
text_feature
.
vocab_list
def
id2token
(
self
,
texts
,
texts_len
,
text_feature
):
def
id2token
(
self
,
texts
,
texts_len
,
text_feature
):
""" ord() id to chr() chr """
""" ord() id to chr() chr """
...
@@ -529,8 +530,8 @@ class U2Tester(U2Trainer):
...
@@ -529,8 +530,8 @@ class U2Tester(U2Trainer):
def
align
(
self
):
def
align
(
self
):
ctc_utils
.
ctc_align
(
ctc_utils
.
ctc_align
(
self
.
model
,
self
.
align_loader
,
self
.
config
.
decoding
.
batch_size
,
self
.
model
,
self
.
align_loader
,
self
.
config
.
decoding
.
batch_size
,
self
.
align_loader
.
collate_fn
.
stride_ms
,
self
.
config
.
collator
.
stride_ms
,
self
.
align_loader
.
collate_fn
.
vocab_list
,
self
.
args
.
result_file
)
self
.
vocab_list
,
self
.
args
.
result_file
)
def
load_inferspec
(
self
):
def
load_inferspec
(
self
):
"""infer model and input spec.
"""infer model and input spec.
...
...
paddlespeech/s2t/io/dataset.py
浏览文件 @
69055698
...
@@ -207,34 +207,16 @@ class AudioDataset(Dataset):
...
@@ -207,34 +207,16 @@ class AudioDataset(Dataset):
if
sort
:
if
sort
:
data
=
sorted
(
data
,
key
=
lambda
x
:
x
[
"feat_shape"
][
0
])
data
=
sorted
(
data
,
key
=
lambda
x
:
x
[
"feat_shape"
][
0
])
if
raw_wav
:
if
raw_wav
:
assert
data
[
0
][
'feat'
].
split
(
':'
)[
0
].
splitext
()[
-
1
]
not
in
(
'.ark'
,
path_suffix
=
data
[
0
][
'feat'
].
split
(
':'
)[
0
].
splitext
()[
-
1
]
'.scp'
)
assert
path_suffix
not
in
(
'.ark'
,
'.scp'
)
data
=
map
(
lambda
x
:
(
float
(
x
[
'feat_shape'
][
0
])
*
1000
/
stride_ms
))
# m second to n frame
data
=
list
(
map
(
lambda
x
:
(
float
(
x
[
'feat_shape'
][
0
])
*
1000
/
stride_ms
),
data
))
self
.
input_dim
=
data
[
0
][
'feat_shape'
][
1
]
self
.
input_dim
=
data
[
0
][
'feat_shape'
][
1
]
self
.
output_dim
=
data
[
0
][
'token_shape'
][
1
]
self
.
output_dim
=
data
[
0
][
'token_shape'
][
1
]
# with open(data_file, 'r') as f:
# for line in f:
# arr = line.strip().split('\t')
# if len(arr) != 7:
# continue
# key = arr[0].split(':')[1]
# tokenid = arr[5].split(':')[1]
# output_dim = int(arr[6].split(':')[1].split(',')[1])
# if raw_wav:
# wav_path = ':'.join(arr[1].split(':')[1:])
# duration = int(float(arr[2].split(':')[1]) * 1000 / 10)
# data.append((key, wav_path, duration, tokenid))
# else:
# feat_ark = ':'.join(arr[1].split(':')[1:])
# feat_info = arr[2].split(':')[1].split(',')
# feat_dim = int(feat_info[1].strip())
# num_frames = int(feat_info[0].strip())
# data.append((key, feat_ark, num_frames, tokenid))
# self.input_dim = feat_dim
# self.output_dim = output_dim
valid_data
=
[]
valid_data
=
[]
for
i
in
range
(
len
(
data
)):
for
i
in
range
(
len
(
data
)):
length
=
data
[
i
][
'feat_shape'
][
0
]
length
=
data
[
i
][
'feat_shape'
][
0
]
...
@@ -242,17 +224,17 @@ class AudioDataset(Dataset):
...
@@ -242,17 +224,17 @@ class AudioDataset(Dataset):
# remove too lang or too short utt for both input and output
# remove too lang or too short utt for both input and output
# to prevent from out of memory
# to prevent from out of memory
if
length
>
max_length
or
length
<
min_length
:
if
length
>
max_length
or
length
<
min_length
:
# logging.warn('ignore utterance {} feature {}'.format(
# data[i][0], length))
pass
pass
elif
token_length
>
token_max_length
or
token_length
<
token_min_length
:
elif
token_length
>
token_max_length
or
token_length
<
token_min_length
:
pass
pass
else
:
else
:
valid_data
.
append
(
data
[
i
])
valid_data
.
append
(
data
[
i
])
logger
.
info
(
f
"raw dataset len:
{
len
(
data
)
}
"
)
data
=
valid_data
data
=
valid_data
num_data
=
len
(
data
)
logger
.
info
(
f
"dataset len after filter:
{
num_data
}
"
)
self
.
minibatch
=
[]
self
.
minibatch
=
[]
num_data
=
len
(
data
)
# Dynamic batch size
# Dynamic batch size
if
batch_type
==
'dynamic'
:
if
batch_type
==
'dynamic'
:
assert
(
max_frames_in_batch
>
0
)
assert
(
max_frames_in_batch
>
0
)
...
@@ -277,7 +259,9 @@ class AudioDataset(Dataset):
...
@@ -277,7 +259,9 @@ class AudioDataset(Dataset):
cur
=
end
cur
=
end
def
__len__
(
self
):
def
__len__
(
self
):
"""number of example(batch)"""
return
len
(
self
.
minibatch
)
return
len
(
self
.
minibatch
)
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
"""batch example of idx"""
return
self
.
minibatch
[
idx
]
return
self
.
minibatch
[
idx
]
paddlespeech/s2t/io/reader.py
浏览文件 @
69055698
...
@@ -18,8 +18,10 @@ import kaldiio
...
@@ -18,8 +18,10 @@ import kaldiio
import
numpy
as
np
import
numpy
as
np
import
soundfile
import
soundfile
from
paddlespeech.s2t.frontend.augmentor.augmentation
import
AugmentationPipeline
as
Transformation
from
.utility
import
feat_type
from
paddlespeech.s2t.transform.transformation
import
Transformation
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
# from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation
__all__
=
[
"LoadInputsAndTargets"
]
__all__
=
[
"LoadInputsAndTargets"
]
...
@@ -322,20 +324,7 @@ class LoadInputsAndTargets():
...
@@ -322,20 +324,7 @@ class LoadInputsAndTargets():
"Not supported: loader_type={}"
.
format
(
filetype
))
"Not supported: loader_type={}"
.
format
(
filetype
))
def
file_type
(
self
,
filepath
):
def
file_type
(
self
,
filepath
):
suffix
=
filepath
.
split
(
":"
)[
0
].
split
(
'.'
)[
-
1
].
lower
()
return
feat_type
(
filepath
)
if
suffix
==
'ark'
:
return
'mat'
elif
suffix
==
'scp'
:
return
'scp'
elif
suffix
==
'npy'
:
return
'npy'
elif
suffix
==
'npz'
:
return
'npz'
elif
suffix
in
[
'wav'
,
'flac'
]:
# PCM16
return
'sound'
else
:
raise
ValueError
(
f
"Not support filetype:
{
suffix
}
"
)
class
SoundHDF5File
():
class
SoundHDF5File
():
...
...
paddlespeech/s2t/io/utility.py
浏览文件 @
69055698
...
@@ -17,7 +17,7 @@ import numpy as np
...
@@ -17,7 +17,7 @@ import numpy as np
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
__all__
=
[
"pad_list"
,
"pad_sequence"
]
__all__
=
[
"pad_list"
,
"pad_sequence"
,
"feat_type"
]
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
...
@@ -85,3 +85,20 @@ def pad_sequence(sequences: List[np.ndarray],
...
@@ -85,3 +85,20 @@ def pad_sequence(sequences: List[np.ndarray],
out_tensor
[:
length
,
i
,
...]
=
tensor
out_tensor
[:
length
,
i
,
...]
=
tensor
return
out_tensor
return
out_tensor
def
feat_type
(
filepath
):
suffix
=
filepath
.
split
(
":"
)[
0
].
split
(
'.'
)[
-
1
].
lower
()
if
suffix
==
'ark'
:
return
'mat'
elif
suffix
==
'scp'
:
return
'scp'
elif
suffix
==
'npy'
:
return
'npy'
elif
suffix
==
'npz'
:
return
'npz'
elif
suffix
in
[
'wav'
,
'flac'
]:
# PCM16
return
'sound'
else
:
raise
ValueError
(
f
"Not support filetype:
{
suffix
}
"
)
paddlespeech/s2t/transform/spectrogram.py
浏览文件 @
69055698
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
# Modified from espnet(https://github.com/espnet/espnet)
import
librosa
import
librosa
import
numpy
as
np
import
numpy
as
np
from
python_speech_features
import
logfbank
def
stft
(
x
,
def
stft
(
x
,
...
@@ -304,3 +305,85 @@ class IStft():
...
@@ -304,3 +305,85 @@ class IStft():
win_length
=
self
.
win_length
,
win_length
=
self
.
win_length
,
window
=
self
.
window
,
window
=
self
.
window
,
center
=
self
.
center
,
)
center
=
self
.
center
,
)
class
LogMelSpectrogramKaldi
():
def
__init__
(
self
,
fs
=
16000
,
n_mels
=
80
,
n_fft
=
512
,
# fft point
n_shift
=
160
,
# unit:sample, 10ms
win_length
=
400
,
# unit:sample, 25ms
window
=
"povey"
,
fmin
=
20
,
fmax
=
None
,
eps
=
1e-10
,
dither
=
False
):
self
.
fs
=
fs
self
.
n_mels
=
n_mels
self
.
n_fft
=
n_fft
if
n_shift
>
win_length
:
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
self
.
n_shift
=
n_shift
/
fs
# unit: ms
self
.
win_length
=
win_length
/
fs
# unit: ms
self
.
window
=
window
self
.
fmin
=
fmin
if
fmax
is
None
:
fmax_
=
fmax
if
fmax
else
self
.
fs
/
2
elif
fmax
>
int
(
self
.
fs
/
2
):
raise
ValueError
(
"fmax must not be greater than half of "
"sample rate."
)
self
.
fmax
=
fmax_
self
.
eps
=
eps
self
.
remove_dc_offset
=
True
self
.
preemph
=
0.97
self
.
dither
=
dither
def
__repr__
(
self
):
return
(
"{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))"
.
format
(
name
=
self
.
__class__
.
__name__
,
fs
=
self
.
fs
,
n_mels
=
self
.
n_mels
,
n_fft
=
self
.
n_fft
,
n_shift
=
self
.
n_shift
,
win_length
=
self
.
win_length
,
window
=
self
.
window
,
fmin
=
self
.
fmin
,
fmax
=
self
.
fmax
,
eps
=
self
.
eps
,
))
def
__call__
(
self
,
x
):
"""
Args:
x (np.ndarray): shape (Ti,)
Raises:
ValueError: not support (Ti, C)
Returns:
np.ndarray: (T, D)
"""
if
x
.
ndim
!=
1
:
raise
ValueError
(
"Not support x: [Time, Channel]"
)
if
x
.
dtype
==
np
.
int16
:
x
=
x
/
2
**
(
16
-
1
)
return
logfbank
(
signal
=
x
,
samplerate
=
self
.
fs
,
winlen
=
self
.
win_length
,
# unit ms
winstep
=
self
.
n_shift
,
# unit ms
nfilt
=
self
.
n_mels
,
nfft
=
self
.
n_fft
,
lowfreq
=
self
.
fmin
,
highfreq
=
self
.
fmax
,
dither
=
self
.
dither
,
remove_dc_offset
=
self
.
remove_dc_offset
,
preemph
=
self
.
preemph
,
wintype
=
self
.
window
)
paddlespeech/s2t/transform/transformation.py
浏览文件 @
69055698
...
@@ -45,6 +45,7 @@ import_alias = dict(
...
@@ -45,6 +45,7 @@ import_alias = dict(
stft2fbank
=
"paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram"
,
stft2fbank
=
"paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram"
,
wpe
=
"paddlespeech.s2t.transform.wpe:WPE"
,
wpe
=
"paddlespeech.s2t.transform.wpe:WPE"
,
channel_selector
=
"paddlespeech.s2t.transform.channel_selector:ChannelSelector"
,
channel_selector
=
"paddlespeech.s2t.transform.channel_selector:ChannelSelector"
,
fbank_kaldi
=
"paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi"
,
)
)
...
...
utils/format_data.py
浏览文件 @
69055698
...
@@ -20,13 +20,13 @@ import json
...
@@ -20,13 +20,13 @@ import json
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.utility
import
load_cmvn
from
paddlespeech.s2t.frontend.utility
import
load_cmvn
from
paddlespeech.s2t.frontend.utility
import
read_manifest
from
paddlespeech.s2t.frontend.utility
import
read_manifest
from
paddlespeech.s2t.io.utility
import
feat_type
from
paddlespeech.s2t.utils.utility
import
add_arguments
from
paddlespeech.s2t.utils.utility
import
add_arguments
from
paddlespeech.s2t.utils.utility
import
print_arguments
from
paddlespeech.s2t.utils.utility
import
print_arguments
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
# yapf: disable
# yapf: disable
add_arg
(
'feat_type'
,
str
,
"raw"
,
"speech feature type, e.g. raw(wav, flac), mat(ark), scp"
)
add_arg
(
'cmvn_path'
,
str
,
add_arg
(
'cmvn_path'
,
str
,
'examples/librispeech/data/mean_std.json'
,
'examples/librispeech/data/mean_std.json'
,
"Filepath of cmvn."
)
"Filepath of cmvn."
)
...
@@ -62,24 +62,64 @@ def main():
...
@@ -62,24 +62,64 @@ def main():
vocab_size
=
text_feature
.
vocab_size
vocab_size
=
text_feature
.
vocab_size
print
(
f
"Vocab size:
{
vocab_size
}
"
)
print
(
f
"Vocab size:
{
vocab_size
}
"
)
# josnline like this
# {
# "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}],
# "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}],
# "utt2spk": "111-2222",
# "utt": "111-2222-333"
# }
count
=
0
count
=
0
for
manifest_path
in
args
.
manifest_paths
:
for
manifest_path
in
args
.
manifest_paths
:
manifest_jsons
=
read_manifest
(
manifest_path
)
manifest_jsons
=
read_manifest
(
manifest_path
)
for
line_json
in
manifest_jsons
:
for
line_json
in
manifest_jsons
:
output_json
=
{
"input"
:
[],
"output"
:
[],
'utt'
:
line_json
[
'utt'
],
'utt2spk'
:
line_json
.
get
(
'utt2spk'
,
'global'
),
}
# output
line
=
line_json
[
'text'
]
line
=
line_json
[
'text'
]
if
isinstance
(
line
,
str
):
# only one target
tokens
=
text_feature
.
tokenize
(
line
)
tokens
=
text_feature
.
tokenize
(
line
)
tokenids
=
text_feature
.
featurize
(
line
)
tokenids
=
text_feature
.
featurize
(
line
)
line_json
[
'token'
]
=
tokens
output_json
[
'output'
].
append
({
line_json
[
'token_id'
]
=
tokenids
'name'
:
'traget1'
,
line_json
[
'token_shape'
]
=
(
len
(
tokenids
),
vocab_size
)
'shape'
:
(
len
(
tokenids
),
vocab_size
),
'text'
:
line
,
'token'
:
' '
.
join
(
tokens
),
'tokenid'
:
' '
.
join
(
map
(
str
,
tokenids
)),
})
else
:
# isinstance(line, list), multi target
raise
NotImplementedError
(
"not support multi output now!"
)
# input
line
=
line_json
[
'feat'
]
if
isinstance
(
line
,
str
):
# only one input
feat_shape
=
line_json
[
'feat_shape'
]
feat_shape
=
line_json
[
'feat_shape'
]
assert
isinstance
(
feat_shape
,
(
list
,
tuple
)),
type
(
feat_shape
)
assert
isinstance
(
feat_shape
,
(
list
,
tuple
)),
type
(
feat_shape
)
if
args
.
feat_type
==
'raw'
:
filetype
=
feat_type
(
line
)
if
filetype
==
'sound'
:
feat_shape
.
append
(
feat_dim
)
feat_shape
.
append
(
feat_dim
)
line_json
[
'filetype'
]
=
'sound'
else
:
# kaldi
else
:
# kaldi
raise
NotImplementedError
(
'no support kaldi feat now!'
)
raise
NotImplementedError
(
'no support kaldi feat now!'
)
fout
.
write
(
json
.
dumps
(
line_json
)
+
'
\n
'
)
output_json
[
'input'
].
append
({
"name"
:
"input1"
,
"shape"
:
feat_shape
,
"feat"
:
line
,
"filetype"
:
filetype
,
})
else
:
# isinstance(line, list), multi input
raise
NotImplementedError
(
"not support multi input now!"
)
fout
.
write
(
json
.
dumps
(
output_json
)
+
'
\n
'
)
count
+=
1
count
+=
1
print
(
f
"Examples number:
{
count
}
"
)
print
(
f
"Examples number:
{
count
}
"
)
...
...
utils/format_triplet_data.py
浏览文件 @
69055698
...
@@ -20,13 +20,13 @@ import json
...
@@ -20,13 +20,13 @@ import json
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.utility
import
load_cmvn
from
paddlespeech.s2t.frontend.utility
import
load_cmvn
from
paddlespeech.s2t.frontend.utility
import
read_manifest
from
paddlespeech.s2t.frontend.utility
import
read_manifest
from
paddlespeech.s2t.io.utility
import
feat_type
from
paddlespeech.s2t.utils.utility
import
add_arguments
from
paddlespeech.s2t.utils.utility
import
add_arguments
from
paddlespeech.s2t.utils.utility
import
print_arguments
from
paddlespeech.s2t.utils.utility
import
print_arguments
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
# yapf: disable
# yapf: disable
add_arg
(
'feat_type'
,
str
,
"raw"
,
"speech feature type, e.g. raw(wav, flac), kaldi"
)
add_arg
(
'cmvn_path'
,
str
,
add_arg
(
'cmvn_path'
,
str
,
'examples/librispeech/data/mean_std.json'
,
'examples/librispeech/data/mean_std.json'
,
"Filepath of cmvn."
)
"Filepath of cmvn."
)
...
@@ -79,9 +79,11 @@ def main():
...
@@ -79,9 +79,11 @@ def main():
line_json
[
'token1'
]
=
tokens
line_json
[
'token1'
]
=
tokens
line_json
[
'token_id1'
]
=
tokenids
line_json
[
'token_id1'
]
=
tokenids
line_json
[
'token_shape1'
]
=
(
len
(
tokenids
),
vocab_size
)
line_json
[
'token_shape1'
]
=
(
len
(
tokenids
),
vocab_size
)
feat_shape
=
line_json
[
'feat_shape'
]
feat_shape
=
line_json
[
'feat_shape'
]
assert
isinstance
(
feat_shape
,
(
list
,
tuple
)),
type
(
feat_shape
)
assert
isinstance
(
feat_shape
,
(
list
,
tuple
)),
type
(
feat_shape
)
if
args
.
feat_type
==
'raw'
:
filetype
=
feat_type
(
line_json
[
'feat'
])
if
filetype
==
'sound'
:
feat_shape
.
append
(
feat_dim
)
feat_shape
.
append
(
feat_dim
)
else
:
# kaldi
else
:
# kaldi
raise
NotImplementedError
(
'no support kaldi feat now!'
)
raise
NotImplementedError
(
'no support kaldi feat now!'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录