Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
feaf71d4
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
feaf71d4
编写于
10月 19, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
u2 kaldi mutli process test with batchsize one
上级
aaa87698
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
200 addition
and
84 deletion
+200
-84
deepspeech/exps/u2/model.py
deepspeech/exps/u2/model.py
+9
-4
deepspeech/exps/u2_kaldi/model.py
deepspeech/exps/u2_kaldi/model.py
+15
-10
deepspeech/io/collator.py
deepspeech/io/collator.py
+5
-2
deepspeech/io/dataloader.py
deepspeech/io/dataloader.py
+9
-9
deepspeech/models/u2/u2.py
deepspeech/models/u2/u2.py
+2
-1
examples/librispeech/s1/conf/augmentation.json
examples/librispeech/s1/conf/augmentation.json
+8
-8
examples/librispeech/s2/README.md
examples/librispeech/s2/README.md
+5
-37
examples/librispeech/s2/conf/transformer.yaml
examples/librispeech/s2/conf/transformer.yaml
+3
-3
examples/librispeech/s2/local/test.sh
examples/librispeech/s2/local/test.sh
+8
-7
examples/librispeech/s2/run.sh
examples/librispeech/s2/run.sh
+1
-1
tools/Makefile
tools/Makefile
+1
-1
utils/json2trn.py
utils/json2trn.py
+96
-0
utils/score_sclite.sh
utils/score_sclite.sh
+2
-0
utils/utility.py
utils/utility.py
+36
-1
未找到文件。
deepspeech/exps/u2/model.py
浏览文件 @
feaf71d4
...
...
@@ -444,7 +444,7 @@ class U2Tester(U2Trainer):
start_time
=
time
.
time
()
text_feature
=
self
.
test_loader
.
collate_fn
.
text_feature
target_transcripts
=
self
.
ordid2token
(
texts
,
texts_len
)
result_transcripts
=
self
.
model
.
decode
(
result_transcripts
,
result_tokenids
=
self
.
model
.
decode
(
audio
,
audio_len
,
text_feature
=
text_feature
,
...
...
@@ -462,14 +462,19 @@ class U2Tester(U2Trainer):
simulate_streaming
=
cfg
.
simulate_streaming
)
decode_time
=
time
.
time
()
-
start_time
for
utt
,
target
,
result
in
zip
(
utts
,
target_transcripts
,
result_transcript
s
):
for
utt
,
target
,
result
,
rec_tids
in
zip
(
utts
,
target_transcripts
,
result_transcripts
,
result_tokenid
s
):
errors
,
len_ref
=
errors_func
(
target
,
result
)
errors_sum
+=
errors
len_refs
+=
len_ref
num_ins
+=
1
if
fout
:
fout
.
write
({
"utt"
:
utt
,
"ref"
:
target
,
"hyp"
:
result
})
fout
.
write
({
"utt"
:
utt
,
"refs"
:
[
target
],
"hyps"
:
[
result
],
"hyps_tokenid"
:
[
rec_tids
],
})
logger
.
info
(
f
"Utt:
{
utt
}
"
)
logger
.
info
(
f
"Ref:
{
target
}
"
)
logger
.
info
(
f
"Hyp:
{
result
}
"
)
...
...
deepspeech/exps/u2_kaldi/model.py
浏览文件 @
feaf71d4
...
...
@@ -390,6 +390,10 @@ class U2Tester(U2Trainer):
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab_filepath
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
def
id2token
(
self
,
texts
,
texts_len
,
text_feature
):
""" ord() id to chr() chr """
...
...
@@ -413,15 +417,11 @@ class U2Tester(U2Trainer):
error_rate_func
=
error_rate
.
cer
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
wer
start_time
=
time
.
time
()
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab_filepath
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
target_transcripts
=
self
.
id2token
(
texts
,
texts_len
,
text_feature
)
result_transcripts
=
self
.
model
.
decode
(
target_transcripts
=
self
.
id2token
(
texts
,
texts_len
,
self
.
text_feature
)
result_transcripts
,
result_tokenids
=
self
.
model
.
decode
(
audio
,
audio_len
,
text_feature
=
text_feature
,
text_feature
=
self
.
text_feature
,
decoding_method
=
cfg
.
decoding_method
,
lang_model_path
=
cfg
.
lang_model_path
,
beam_alpha
=
cfg
.
alpha
,
...
...
@@ -436,14 +436,19 @@ class U2Tester(U2Trainer):
simulate_streaming
=
cfg
.
simulate_streaming
)
decode_time
=
time
.
time
()
-
start_time
for
utt
,
target
,
result
in
zip
(
utts
,
target_transcripts
,
result_transcripts
):
for
i
,
(
utt
,
target
,
result
,
rec_tids
)
in
enumerate
(
zip
(
utts
,
target_transcripts
,
result_transcripts
,
result_tokenids
)
):
errors
,
len_ref
=
errors_func
(
target
,
result
)
errors_sum
+=
errors
len_refs
+=
len_ref
num_ins
+=
1
if
fout
:
fout
.
write
({
"utt"
:
utt
,
"ref"
:
target
,
"hyp"
:
result
})
fout
.
write
({
"utt"
:
utt
,
"refs"
:
[
target
],
"hyps"
:
[
result
],
"hyps_tokenid"
:
[
rec_tids
],
})
logger
.
info
(
f
"Utt:
{
utt
}
"
)
logger
.
info
(
f
"Ref:
{
target
}
"
)
logger
.
info
(
f
"Hyp:
{
result
}
"
)
...
...
deepspeech/io/collator.py
浏览文件 @
feaf71d4
...
...
@@ -32,7 +32,7 @@ __all__ = ["SpeechCollator", "TripletSpeechCollator"]
logger
=
Log
(
__name__
).
getlog
()
def
tokenids
(
text
,
keep_transcription_text
):
def
_
tokenids
(
text
,
keep_transcription_text
):
# for training text is token ids
tokens
=
text
# token ids
...
...
@@ -93,6 +93,8 @@ class SpeechCollatorBase():
a user-defined shape) within one batch.
"""
self
.
keep_transcription_text
=
keep_transcription_text
self
.
train_mode
=
not
keep_transcription_text
self
.
stride_ms
=
stride_ms
self
.
window_ms
=
window_ms
self
.
feat_dim
=
feat_dim
...
...
@@ -192,6 +194,7 @@ class SpeechCollatorBase():
texts
=
[]
text_lens
=
[]
utts
=
[]
tids
=
[]
# tokenids
for
idx
,
item
in
enumerate
(
batch
):
utts
.
append
(
item
[
'utt'
])
...
...
@@ -203,7 +206,7 @@ class SpeechCollatorBase():
audios
.
append
(
audio
)
# [T, D]
audio_lens
.
append
(
audio
.
shape
[
0
])
tokens
=
tokenids
(
text
,
self
.
keep_transcription_text
)
tokens
=
_
tokenids
(
text
,
self
.
keep_transcription_text
)
texts
.
append
(
tokens
)
text_lens
.
append
(
tokens
.
shape
[
0
])
...
...
deepspeech/io/dataloader.py
浏览文件 @
feaf71d4
...
...
@@ -142,6 +142,15 @@ class BatchDataLoader():
collate_fn
=
batch_collate
,
num_workers
=
self
.
n_iter_processes
,
)
def
__len__
(
self
):
return
len
(
self
.
dataloader
)
def
__iter__
(
self
):
return
self
.
dataloader
.
__iter__
()
def
__call__
(
self
):
return
self
.
__iter__
()
def
__repr__
(
self
):
echo
=
f
"<
{
self
.
__class__
.
__module__
}
.
{
self
.
__class__
.
__name__
}
object at
{
hex
(
id
(
self
))
}
> "
echo
+=
f
"train_mode:
{
self
.
train_mode
}
, "
...
...
@@ -159,12 +168,3 @@ class BatchDataLoader():
echo
+=
f
"num_workers:
{
self
.
n_iter_processes
}
, "
echo
+=
f
"file:
{
self
.
json_file
}
"
return
echo
def
__len__
(
self
):
return
len
(
self
.
dataloader
)
def
__iter__
(
self
):
return
self
.
dataloader
.
__iter__
()
def
__call__
(
self
):
return
self
.
__iter__
()
deepspeech/models/u2/u2.py
浏览文件 @
feaf71d4
...
...
@@ -809,7 +809,8 @@ class U2BaseModel(nn.Layer):
raise
ValueError
(
f
"Not support decoding method:
{
decoding_method
}
"
)
res
=
[
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
hyps
]
return
res
res_tokenids
=
[
hyp
for
hyp
in
hyps
]
return
res
,
res_tokenids
class
U2Model
(
U2BaseModel
):
...
...
examples/librispeech/s1/conf/augmentation.json
浏览文件 @
feaf71d4
[
{
"type"
:
"shift"
,
"params"
:
{
"min_shift_ms"
:
-5
,
"max_shift_ms"
:
5
},
"prob"
:
1.0
},
{
"type"
:
"speed"
,
"params"
:
{
...
...
@@ -16,6 +8,14 @@
},
"prob"
:
0.0
},
{
"type"
:
"shift"
,
"params"
:
{
"min_shift_ms"
:
-5
,
"max_shift_ms"
:
5
},
"prob"
:
1.0
},
{
"type"
:
"specaug"
,
"params"
:
{
...
...
examples/librispeech/s2/README.md
浏览文件 @
feaf71d4
# LibriSpeech
## Data
| Data Subset | Duration in Seconds |
| data/manifest.train | 0.83s ~ 29.735s |
| data/manifest.dev | 1.065 ~ 35.155s |
| data/manifest.test-clean | 1.285s ~ 34.955s |
## Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | - | - |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | |
## Chunk Conformer
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention | 16, -1 | | |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 16, -1 | | |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 16, -1 | | - |
| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 16, -1 | | - |
## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | |
### Test w/o length filter
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| Model | Params | Config | Augmentation| Test Set | Decode Method | Loss | WER % |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.395054340362549 | 4.2 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.395054340362549 | 5.0 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.395054340362549 | |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescore | 6.395054340362549 | |
examples/librispeech/s2/conf/transformer.yaml
浏览文件 @
feaf71d4
...
...
@@ -5,9 +5,9 @@ data:
test_manifest
:
data/manifest.test-clean
collator
:
vocab_filepath
:
data/
train_960_unigram
5000_units.txt
unit_type
:
'
spm'
spm_model_prefix
:
'
data/train_960_unigram5000'
vocab_filepath
:
data/
bpe_unigram_
5000_units.txt
unit_type
:
spm
spm_model_prefix
:
data/bpe_unigram_5000
feat_dim
:
83
stride_ms
:
10.0
window_ms
:
25.0
...
...
examples/librispeech/s2/local/test.sh
浏览文件 @
feaf71d4
...
...
@@ -46,15 +46,17 @@ pids=() # initialize pids
for
dmethd
in
attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring
;
do
(
echo
"
${
dmethd
}
decoding"
for
rtask
in
${
recog_set
}
;
do
(
decode_dir
=
decode_
${
rtask
}
_
${
dmethd
}
_
$(
basename
${
config_path
%.*
}
)
_
${
lmtag
}
echo
"
${
rtask
}
dataset"
decode_dir
=
decode_
${
rtask
/-/_
}
_
${
dmethd
}
_
$(
basename
${
config_path
%.*
}
)
_
${
lmtag
}
feat_recog_dir
=
${
datadir
}
mkdir
-p
${
expdir
}
/
${
decode_dir
}
mkdir
-p
${
feat_recog_dir
}
# split data
split_json.sh
${
feat_recog_dir
}
/
manifest.
${
rtask
}
${
nj
}
split_json.sh manifest.
${
rtask
}
${
nj
}
#### use CPU for decoding
ngpu
=
0
...
...
@@ -74,17 +76,16 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
--opts
decoding.batch_size
${
batch_size
}
\
--opts
data.test_manifest
${
feat_recog_dir
}
/split
${
nj
}
/JOB/manifest.
${
rtask
}
score_sclite.sh
--bpe
${
nbpe
}
--bpemodel
${
bpemodel
}
.model
--wer
tru
e
${
expdir
}
/
${
decode_dir
}
${
dict
}
score_sclite.sh
--bpe
${
nbpe
}
--bpemodel
${
bpemodel
}
--wer
fals
e
${
expdir
}
/
${
decode_dir
}
${
dict
}
)
&
pids+
=(
$!
)
# store background pids
i
=
0
;
for
pid
in
"
${
pids
[@]
}
"
;
do
wait
${
pid
}
||
((
++i
))
;
done
[
${
i
}
-gt
0
]
&&
echo
"
$0
:
${
i
}
background jobs are failed."
&&
false
done
)
&
pids+
=(
$!
)
# store background pids
)
done
i
=
0
;
for
pid
in
"
${
pids
[@]
}
"
;
do
wait
${
pid
}
||
((
++i
))
;
done
[
${
i
}
-gt
0
]
&&
echo
"
$0
:
${
i
}
background jobs are failed."
&&
false
echo
"Finished"
exit
0
examples/librispeech/s2/run.sh
浏览文件 @
feaf71d4
...
...
@@ -32,7 +32,7 @@ fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES
=
0
./local/test.sh
${
conf_path
}
${
dict_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
./local/test.sh
${
conf_path
}
${
dict_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
...
...
tools/Makefile
浏览文件 @
feaf71d4
...
...
@@ -6,7 +6,7 @@ CC ?= gcc # used for sph2pipe
# CXX = clang++ # Uncomment these lines...
# CC = clang # ...to build with Clang.
WGET
?=
wget
WGET
?=
wget
--no-check-certificate
.PHONY
:
all clean
...
...
utils/json2trn.py
0 → 100755
浏览文件 @
feaf71d4
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# 2018 Xuankai Chang (Shanghai Jiao Tong University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import
argparse
import
json
import
logging
import
sys
import
jsonlines
from
utility
import
get_commandline_args
def
get_parser
():
parser
=
argparse
.
ArgumentParser
(
description
=
"convert a json to a transcription file with a token dictionary"
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
,
)
parser
.
add_argument
(
"json"
,
type
=
str
,
help
=
"jsonlines files"
)
parser
.
add_argument
(
"dict"
,
type
=
str
,
help
=
"dict, not used."
)
parser
.
add_argument
(
"--num-spkrs"
,
type
=
int
,
default
=
1
,
help
=
"number of speakers"
)
parser
.
add_argument
(
"--refs"
,
type
=
str
,
nargs
=
"+"
,
help
=
"ref for all speakers"
)
parser
.
add_argument
(
"--hyps"
,
type
=
str
,
nargs
=
"+"
,
help
=
"hyp for all outputs"
)
return
parser
def
main
(
args
):
args
=
get_parser
().
parse_args
(
args
)
convert
(
args
.
json
,
args
.
dict
,
args
.
refs
,
args
.
hyps
,
args
.
num_spkrs
)
def
convert
(
jsonf
,
dic
,
refs
,
hyps
,
num_spkrs
=
1
):
n_ref
=
len
(
refs
)
n_hyp
=
len
(
hyps
)
assert
n_ref
==
n_hyp
assert
n_ref
==
num_spkrs
# logging info
logfmt
=
"%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
logfmt
)
logging
.
info
(
get_commandline_args
())
logging
.
info
(
"reading %s"
,
jsonf
)
with
jsonlines
.
open
(
jsonf
,
"r"
)
as
f
:
j
=
[
item
for
item
in
f
]
logging
.
info
(
"reading %s"
,
dic
)
with
open
(
dic
,
"r"
)
as
f
:
dictionary
=
f
.
readlines
()
char_list
=
[
entry
.
split
(
" "
)[
0
]
for
entry
in
dictionary
]
char_list
.
insert
(
0
,
"<blank>"
)
char_list
.
append
(
"<eos>"
)
for
ns
in
range
(
num_spkrs
):
hyp_file
=
open
(
hyps
[
ns
],
"w"
)
ref_file
=
open
(
refs
[
ns
],
"w"
)
for
x
in
j
:
# recognition hypothesis
if
num_spkrs
==
1
:
#seq = [char_list[int(i)] for i in x['hyps_tokenid'][0]]
seq
=
x
[
'hyps'
][
0
]
else
:
seq
=
[
char_list
[
int
(
i
)]
for
i
in
x
[
'hyps_tokenid'
][
ns
]]
# In the recognition hypothesis,
# the <eos> symbol is usually attached in the last part of the sentence
# and it is removed below.
#hyp_file.write(" ".join(seq).replace("<eos>", ""))
hyp_file
.
write
(
seq
.
replace
(
"<eos>"
,
""
))
# spk-uttid
hyp_file
.
write
(
" ("
+
x
[
"utt"
]
+
")
\n
"
)
# reference
if
num_spkrs
==
1
:
seq
=
x
[
"refs"
][
0
]
else
:
seq
=
x
[
'refs'
][
ns
]
# Unlike the recognition hypothesis,
# the reference is directly generated from a token without dictionary
# to avoid to include <unk> symbols in the reference to make scoring normal.
# The detailed discussion can be found at
# https://github.com/espnet/espnet/issues/993
# ref_file.write(
# seq + " (" + j["utts"][x]["utt2spk"].replace("-", "_") + "-" + x + ")\n"
# )
ref_file
.
write
(
seq
+
" ("
+
x
[
'utt'
]
+
")
\n
"
)
hyp_file
.
close
()
ref_file
.
close
()
if
__name__
==
"__main__"
:
main
(
sys
.
argv
[
1
:])
utils/score_sclite.sh
浏览文件 @
feaf71d4
#!/usr/bin/env bash
set
-e
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
...
...
utils/utility.py
浏览文件 @
feaf71d4
...
...
@@ -14,6 +14,7 @@
import
hashlib
import
json
import
os
import
sys
import
tarfile
import
zipfile
from
typing
import
Text
...
...
@@ -21,7 +22,7 @@ from typing import Text
__all__
=
[
"check_md5sum"
,
"getfile_insensitive"
,
"download_multi"
,
"download"
,
"unpack"
,
"unzip"
,
"md5file"
,
"print_arguments"
,
"add_arguments"
,
"read_manifest"
"read_manifest"
,
"get_commandline_args"
]
...
...
@@ -46,6 +47,40 @@ def read_manifest(manifest_path):
return
manifest
def
get_commandline_args
():
extra_chars
=
[
" "
,
";"
,
"&"
,
"("
,
")"
,
"|"
,
"^"
,
"<"
,
">"
,
"?"
,
"*"
,
"["
,
"]"
,
"$"
,
"`"
,
'"'
,
"
\\
"
,
"!"
,
"{"
,
"}"
,
]
# Escape the extra characters for shell
argv
=
[
arg
.
replace
(
"'"
,
"'
\\
''"
)
if
all
(
char
not
in
arg
for
char
in
extra_chars
)
else
"'"
+
arg
.
replace
(
"'"
,
"'
\\
''"
)
+
"'"
for
arg
in
sys
.
argv
]
return
sys
.
executable
+
" "
+
" "
.
join
(
argv
)
def
print_arguments
(
args
,
info
=
None
):
"""Print argparse's arguments.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录