# PaddlePaddle/DeepSpeech · Commit 385bdf5c

**Merge pull request #907 from PaddlePaddle/u2**

    u2 kaldi mutli process test with batchsize one

Authored by Jackwaterveg on Oct 19, 2021; committed via GitHub (unverified signature) on Oct 19, 2021.

Parents: aaa87698, 50b2114b
Showing 14 changed files with 200 additions and 84 deletions (+200 −84):
- deepspeech/exps/u2/model.py (+9 −4)
- deepspeech/exps/u2_kaldi/model.py (+15 −10)
- deepspeech/io/collator.py (+5 −2)
- deepspeech/io/dataloader.py (+9 −9)
- deepspeech/models/u2/u2.py (+2 −1)
- examples/librispeech/s1/conf/augmentation.json (+8 −8)
- examples/librispeech/s2/README.md (+5 −37)
- examples/librispeech/s2/conf/transformer.yaml (+3 −3)
- examples/librispeech/s2/local/test.sh (+8 −7)
- examples/librispeech/s2/run.sh (+1 −1)
- tools/Makefile (+1 −1)
- utils/json2trn.py (+96 −0, new file)
- utils/score_sclite.sh (+2 −0)
- utils/utility.py (+36 −1)
## deepspeech/exps/u2/model.py

```diff
@@ -444,7 +444,7 @@ class U2Tester(U2Trainer):
         start_time = time.time()
         text_feature = self.test_loader.collate_fn.text_feature
         target_transcripts = self.ordid2token(texts, texts_len)
-        result_transcripts = self.model.decode(
+        result_transcripts, result_tokenids = self.model.decode(
             audio,
             audio_len,
             text_feature=text_feature,
@@ -462,14 +462,19 @@ class U2Tester(U2Trainer):
             simulate_streaming=cfg.simulate_streaming)
         decode_time = time.time() - start_time

-        for utt, target, result in zip(utts, target_transcripts, result_transcripts):
+        for utt, target, result, rec_tids in zip(
+                utts, target_transcripts, result_transcripts, result_tokenids):
             errors, len_ref = errors_func(target, result)
             errors_sum += errors
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write({"utt": utt, "ref": target, "hyp": result})
+                fout.write({
+                    "utt": utt,
+                    "refs": [target],
+                    "hyps": [result],
+                    "hyps_tokenid": [rec_tids],
+                })
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")
```
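With this change each test utterance is written as a jsonlines record carrying lists of references, hypotheses, and raw token ids, which is the shape the new `utils/json2trn.py` below consumes. A hypothetical record (field values invented for illustration):

```python
# One record as written by fout above; values are made up for illustration.
record = {
    "utt": "1089-134686-0000",           # utterance id
    "refs": ["the quick brown fox"],     # reference transcripts
    "hyps": ["the quick brown fox"],     # decoded transcripts
    "hyps_tokenid": [[12, 407, 88, 9]],  # raw token ids from the decoder
}
```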
## deepspeech/exps/u2_kaldi/model.py

```diff
@@ -390,6 +390,10 @@ class U2Tester(U2Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
+        self.text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)

     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
@@ -413,15 +417,11 @@ class U2Tester(U2Trainer):
         error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer

         start_time = time.time()
-        text_feature = TextFeaturizer(
-            unit_type=self.config.collator.unit_type,
-            vocab_filepath=self.config.collator.vocab_filepath,
-            spm_model_prefix=self.config.collator.spm_model_prefix)
-        target_transcripts = self.id2token(texts, texts_len, text_feature)
-        result_transcripts = self.model.decode(
+        target_transcripts = self.id2token(texts, texts_len, self.text_feature)
+        result_transcripts, result_tokenids = self.model.decode(
             audio,
             audio_len,
-            text_feature=text_feature,
+            text_feature=self.text_feature,
             decoding_method=cfg.decoding_method,
             lang_model_path=cfg.lang_model_path,
             beam_alpha=cfg.alpha,
@@ -436,14 +436,19 @@ class U2Tester(U2Trainer):
             simulate_streaming=cfg.simulate_streaming)
         decode_time = time.time() - start_time

-        for utt, target, result in zip(utts, target_transcripts, result_transcripts):
+        for i, (utt, target, result, rec_tids) in enumerate(
+                zip(utts, target_transcripts, result_transcripts,
+                    result_tokenids)):
             errors, len_ref = errors_func(target, result)
             errors_sum += errors
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write({"utt": utt, "ref": target, "hyp": result})
+                fout.write({
+                    "utt": utt,
+                    "refs": [target],
+                    "hyps": [result],
+                    "hyps_tokenid": [rec_tids],
+                })
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")
```
## deepspeech/io/collator.py

```diff
@@ -32,7 +32,7 @@ __all__ = ["SpeechCollator", "TripletSpeechCollator"]
 logger = Log(__name__).getlog()

-def tokenids(text, keep_transcription_text):
+def _tokenids(text, keep_transcription_text):
     # for training text is token ids
     tokens = text  # token ids
@@ -93,6 +93,8 @@ class SpeechCollatorBase():
             a user-defined shape) within one batch.
         """
         self.keep_transcription_text = keep_transcription_text
+        self.train_mode = not keep_transcription_text
+
         self.stride_ms = stride_ms
         self.window_ms = window_ms
         self.feat_dim = feat_dim
@@ -192,6 +194,7 @@ class SpeechCollatorBase():
         texts = []
         text_lens = []
         utts = []
+        tids = []  # tokenids

         for idx, item in enumerate(batch):
             utts.append(item['utt'])
@@ -203,7 +206,7 @@ class SpeechCollatorBase():
             audios.append(audio)  # [T, D]
             audio_lens.append(audio.shape[0])

-            tokens = tokenids(text, self.keep_transcription_text)
+            tokens = _tokenids(text, self.keep_transcription_text)
             texts.append(tokens)
             text_lens.append(tokens.shape[0])
```
## deepspeech/io/dataloader.py

```diff
@@ -142,6 +142,15 @@ class BatchDataLoader():
             collate_fn=batch_collate,
             num_workers=self.n_iter_processes, )

+    def __len__(self):
+        return len(self.dataloader)
+
+    def __iter__(self):
+        return self.dataloader.__iter__()
+
+    def __call__(self):
+        return self.__iter__()
+
     def __repr__(self):
         echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}>"
         echo += f"train_mode: {self.train_mode}, "
@@ -159,12 +168,3 @@ class BatchDataLoader():
         echo += f"num_workers: {self.n_iter_processes}, "
         echo += f"file: {self.json_file}"
         return echo
-
-    def __len__(self):
-        return len(self.dataloader)
-
-    def __iter__(self):
-        return self.dataloader.__iter__()
-
-    def __call__(self):
-        return self.__iter__()
```
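The dunder methods themselves are unchanged; they only move above `__repr__`. They make the wrapper behave like the paddle `DataLoader` it holds, as in this minimal sketch (the constructor argument shown is hypothetical):

```python
# A minimal sketch, assuming `loader` is a constructed BatchDataLoader.
loader = BatchDataLoader(json_file="data/manifest.test-clean")  # hypothetical kwargs
print(len(loader))    # __len__ delegates to the wrapped paddle DataLoader
for batch in loader:  # __iter__ delegates as well
    pass
batches = loader()    # __call__ is an alias for __iter__
```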
## deepspeech/models/u2/u2.py

```diff
@@ -809,7 +809,8 @@ class U2BaseModel(nn.Layer):
             raise ValueError(f"Not support decoding method: {decoding_method}")

         res = [text_feature.defeaturize(hyp) for hyp in hyps]
-        return res
+        res_tokenids = [hyp for hyp in hyps]
+        return res, res_tokenids


 class U2Model(U2BaseModel):
```
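`decode` now returns the detokenized transcripts together with the raw hypothesis token ids, which is why both testers above unpack two values. A self-contained stub of the new return contract (not the real model; the tiny vocabulary is invented):

```python
# Stub mimicking U2BaseModel.decode's new two-value return (illustration only).
class _StubTextFeature:
    vocab = {1: "hello", 2: "world"}

    def defeaturize(self, ids):
        return " ".join(self.vocab[i] for i in ids)

def decode_stub(hyps, text_feature):
    res = [text_feature.defeaturize(hyp) for hyp in hyps]
    res_tokenids = [hyp for hyp in hyps]  # same ids, left undetokenized
    return res, res_tokenids

texts, tokenids = decode_stub([[1, 2]], _StubTextFeature())
assert texts == ["hello world"] and tokenids == [[1, 2]]
```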
## examples/librispeech/s1/conf/augmentation.json

```diff
 [
-    {
-        "type": "shift",
-        "params": {
-            "min_shift_ms": -5,
-            "max_shift_ms": 5
-        },
-        "prob": 1.0
-    },
     {
         "type": "speed",
         "params": {
@@ -16,6 +8,14 @@
         },
         "prob": 0.0
     },
+    {
+        "type": "shift",
+        "params": {
+            "min_shift_ms": -5,
+            "max_shift_ms": 5
+        },
+        "prob": 1.0
+    },
     {
         "type": "specaug",
         "params": {
```
## examples/librispeech/s2/README.md

```diff
 # LibriSpeech

-## Data
-| Data Subset | Duration in Seconds |
-| data/manifest.train | 0.83s ~ 29.735s |
-| data/manifest.dev | 1.065 ~ 35.155s |
-| data/manifest.test-clean | 1.285s ~ 34.955s |
-
-## Conformer
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | - | - |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | | |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
-
-### Test w/o length filter
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | |
-
-## Chunk Conformer
-| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention | 16, -1 | | |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 16, -1 | | |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 16, -1 | | - |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 16, -1 | | - |
-
 ## Transformer
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | |
-
-### Test w/o length filter
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
+| Model | Params | Config | Augmentation| Test Set | Decode Method | Loss | WER % |
 | --- | --- | --- | --- | --- | --- | --- | --- |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | |
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.395054340362549 | 4.2 |
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.395054340362549 | 5.0 |
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.395054340362549 | |
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescore | 6.395054340362549 | |
```
## examples/librispeech/s2/conf/transformer.yaml

```diff
@@ -5,9 +5,9 @@ data:
   test_manifest: data/manifest.test-clean

 collator:
-  vocab_filepath: data/train_960_unigram5000_units.txt
-  unit_type: 'spm'
-  spm_model_prefix: 'data/train_960_unigram5000'
+  vocab_filepath: data/bpe_unigram_5000_units.txt
+  unit_type: spm
+  spm_model_prefix: data/bpe_unigram_5000
   feat_dim: 83
   stride_ms: 10.0
   window_ms: 25.0
```
## examples/librispeech/s2/local/test.sh

```diff
@@ -46,15 +46,17 @@ pids=() # initialize pids
 for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
 (
     echo "${dmethd} decoding"
     for rtask in ${recog_set}; do
     (
-        decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
+        echo "${rtask} dataset"
+        decode_dir=decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
         feat_recog_dir=${datadir}
         mkdir -p ${expdir}/${decode_dir}
         mkdir -p ${feat_recog_dir}

         # split data
-        split_json.sh ${feat_recog_dir}/manifest.${rtask} ${nj}
+        split_json.sh manifest.${rtask} ${nj}

         #### use CPU for decoding
         ngpu=0
@@ -74,17 +76,16 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
             --opts decoding.batch_size ${batch_size} \
             --opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}

-        score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true ${expdir}/${decode_dir} ${dict}
+        score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${expdir}/${decode_dir} ${dict}

-    ) &
-    pids+=($!) # store background pids
-    i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
-    [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." || true
-    done
+    )
+    done
 ) &
 pids+=($!) # store background pids
 done
 i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
 [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
 echo "Finished"
 exit 0
```
## examples/librispeech/s2/run.sh

```diff
@@ -32,7 +32,7 @@ fi

 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
```
## tools/Makefile

```diff
@@ -6,7 +6,7 @@ CC ?= gcc # used for sph2pipe
 # CXX = clang++ # Uncomment these lines...
 # CC = clang    # ...to build with Clang.

-WGET ?= wget
+WGET ?= wget --no-check-certificate

 .PHONY: all clean
```
## utils/json2trn.py (new file, mode 100755)

```python
#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#           2018 Xuankai Chang (Shanghai Jiao Tong University)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import json
import logging
import sys

import jsonlines

from utility import get_commandline_args


def get_parser():
    parser = argparse.ArgumentParser(
        description="convert a json to a transcription file with a token dictionary",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument("json", type=str, help="jsonlines files")
    parser.add_argument("dict", type=str, help="dict, not used.")
    parser.add_argument(
        "--num-spkrs", type=int, default=1, help="number of speakers")
    parser.add_argument(
        "--refs", type=str, nargs="+", help="ref for all speakers")
    parser.add_argument(
        "--hyps", type=str, nargs="+", help="hyp for all outputs")
    return parser


def main(args):
    args = get_parser().parse_args(args)
    convert(args.json, args.dict, args.refs, args.hyps, args.num_spkrs)


def convert(jsonf, dic, refs, hyps, num_spkrs=1):
    n_ref = len(refs)
    n_hyp = len(hyps)
    assert n_ref == n_hyp
    assert n_ref == num_spkrs

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    logging.info("reading %s", jsonf)
    with jsonlines.open(jsonf, "r") as f:
        j = [item for item in f]

    logging.info("reading %s", dic)
    with open(dic, "r") as f:
        dictionary = f.readlines()
    char_list = [entry.split(" ")[0] for entry in dictionary]
    char_list.insert(0, "<blank>")
    char_list.append("<eos>")

    for ns in range(num_spkrs):
        hyp_file = open(hyps[ns], "w")
        ref_file = open(refs[ns], "w")

        for x in j:
            # recognition hypothesis
            if num_spkrs == 1:
                #seq = [char_list[int(i)] for i in x['hyps_tokenid'][0]]
                seq = x['hyps'][0]
            else:
                seq = [char_list[int(i)] for i in x['hyps_tokenid'][ns]]
            # In the recognition hypothesis,
            # the <eos> symbol is usually attached in the last part of the sentence
            # and it is removed below.
            #hyp_file.write(" ".join(seq).replace("<eos>", ""))
            hyp_file.write(seq.replace("<eos>", ""))
            # spk-uttid
            hyp_file.write(" (" + x["utt"] + ")\n")

            # reference
            if num_spkrs == 1:
                seq = x["refs"][0]
            else:
                seq = x['refs'][ns]
            # Unlike the recognition hypothesis,
            # the reference is directly generated from a token without dictionary
            # to avoid to include <unk> symbols in the reference to make scoring normal.
            # The detailed discussion can be found at
            # https://github.com/espnet/espnet/issues/993
            # ref_file.write(
            #     seq + " (" + j["utts"][x]["utt2spk"].replace("-", "_") + "-" + x + ")\n"
            # )
            ref_file.write(seq + " (" + x['utt'] + ")\n")

        hyp_file.close()
        ref_file.close()


if __name__ == "__main__":
    main(sys.argv[1:])
```
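A hypothetical invocation (file names invented; run from `utils/` so that `utility.py` is importable). It turns the jsonlines test output into sclite-style `.trn` files, one `text (utt-id)` line per utterance:

```python
# Equivalent to:
#   python3 utils/json2trn.py data.json units.txt --refs ref.trn --hyps hyp.trn
from json2trn import main

main(["data.json", "units.txt", "--refs", "ref.trn", "--hyps", "hyp.trn"])
# hyp.trn then contains lines such as:
#   the quick brown fox (1089-134686-0000)
```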
## utils/score_sclite.sh

```diff
 #!/usr/bin/env bash
+set -e
+
 # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
 #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
```
## utils/utility.py

```diff
@@ -14,6 +14,7 @@
 import hashlib
 import json
 import os
+import sys
 import tarfile
 import zipfile
 from typing import Text
@@ -21,7 +22,7 @@ from typing import Text
 __all__ = [
     "check_md5sum", "getfile_insensitive", "download_multi", "download",
     "unpack", "unzip", "md5file", "print_arguments", "add_arguments",
-    "read_manifest"
+    "read_manifest", "get_commandline_args"
 ]
@@ -46,6 +47,40 @@ def read_manifest(manifest_path):
     return manifest


+def get_commandline_args():
+    extra_chars = [
+        " ",
+        ";",
+        "&",
+        "(",
+        ")",
+        "|",
+        "^",
+        "<",
+        ">",
+        "?",
+        "*",
+        "[",
+        "]",
+        "$",
+        "`",
+        '"',
+        "\\",
+        "!",
+        "{",
+        "}",
+    ]
+
+    # Escape the extra characters for shell
+    argv = [
+        arg.replace("'", "'\\''") if all(char not in arg
+                                         for char in extra_chars) else
+        "'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv
+    ]
+
+    return sys.executable + " " + " ".join(argv)
+
+
 def print_arguments(args, info=None):
     """Print argparse's arguments.
```
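`get_commandline_args` reconstructs a copy-pasteable command line for logging: arguments containing shell metacharacters are wrapped in single quotes, and embedded single quotes are escaped as `'\''`. A self-contained replica of the quoting rule (the example argv is invented):

```python
# Replicates the quoting rule from get_commandline_args above (illustration).
EXTRA_CHARS = [" ", ";", "&", "(", ")", "|", "^", "<", ">", "?", "*",
               "[", "]", "$", "`", '"', "\\", "!", "{", "}"]

def quote_for_shell(args):
    return " ".join(
        arg.replace("'", "'\\''") if all(c not in arg for c in EXTRA_CHARS)
        else "'" + arg.replace("'", "'\\''") + "'"
        for arg in args)

print(quote_for_shell(["json2trn.py", "--refs", "ref 1.trn"]))
# -> json2trn.py --refs 'ref 1.trn'
```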