Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
a6091008
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
接近 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a6091008
编写于
10月 22, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
can run recog
上级
190f4cc4
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
646 addition
and
81 deletion
+646
-81
deepspeech/__init__.py
deepspeech/__init__.py
+10
-10
deepspeech/decoders/beam_search.py
deepspeech/decoders/beam_search.py
+13
-9
deepspeech/decoders/recog.py
deepspeech/decoders/recog.py
+75
-36
deepspeech/decoders/scorers/ctc.py
deepspeech/decoders/scorers/ctc.py
+2
-2
deepspeech/decoders/scorers/ctc_prefix_score.py
deepspeech/decoders/scorers/ctc_prefix_score.py
+2
-2
deepspeech/decoders/scorers/scorer_interface.py
deepspeech/decoders/scorers/scorer_interface.py
+0
-0
deepspeech/decoders/utils.py
deepspeech/decoders/utils.py
+7
-3
deepspeech/exps/u2_kaldi/bin/recog.py
deepspeech/exps/u2_kaldi/bin/recog.py
+379
-0
deepspeech/exps/u2_kaldi/model.py
deepspeech/exps/u2_kaldi/model.py
+0
-2
deepspeech/frontend/featurizer/text_featurizer.py
deepspeech/frontend/featurizer/text_featurizer.py
+1
-1
deepspeech/models/u2/u2.py
deepspeech/models/u2/u2.py
+1
-1
deepspeech/modules/decoder.py
deepspeech/modules/decoder.py
+5
-6
deepspeech/modules/mask.py
deepspeech/modules/mask.py
+4
-3
deepspeech/training/cli.py
deepspeech/training/cli.py
+4
-2
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+17
-3
examples/librispeech/s2/conf/decode/decode.yaml
examples/librispeech/s2/conf/decode/decode.yaml
+7
-0
examples/librispeech/s2/conf/decode/decode_all.yaml
examples/librispeech/s2/conf/decode/decode_all.yaml
+7
-0
examples/librispeech/s2/conf/decode/decode_wo_lm.yaml
examples/librispeech/s2/conf/decode/decode_wo_lm.yaml
+7
-0
examples/librispeech/s2/local/recog.sh
examples/librispeech/s2/local/recog.sh
+103
-0
examples/librispeech/s2/local/test.sh
examples/librispeech/s2/local/test.sh
+1
-1
requirements.txt
requirements.txt
+1
-0
未找到文件。
deepspeech/__init__.py
浏览文件 @
a6091008
...
@@ -233,7 +233,7 @@ def is_broadcastable(shp1, shp2):
...
@@ -233,7 +233,7 @@ def is_broadcastable(shp1, shp2):
def
masked_fill
(
xs
:
paddle
.
Tensor
,
def
masked_fill
(
xs
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
value
:
Union
[
float
,
int
]):
value
:
Union
[
float
,
int
]):
assert
is_broadcastable
(
xs
.
shape
,
mask
.
shape
)
is
True
assert
is_broadcastable
(
xs
.
shape
,
mask
.
shape
)
is
True
,
(
xs
.
shape
,
mask
.
shape
)
bshape
=
paddle
.
broadcast_shape
(
xs
.
shape
,
mask
.
shape
)
bshape
=
paddle
.
broadcast_shape
(
xs
.
shape
,
mask
.
shape
)
mask
=
mask
.
broadcast_to
(
bshape
)
mask
=
mask
.
broadcast_to
(
bshape
)
trues
=
paddle
.
ones_like
(
xs
)
*
value
trues
=
paddle
.
ones_like
(
xs
)
*
value
...
@@ -312,18 +312,18 @@ if not hasattr(paddle.Tensor, 'type_as'):
...
@@ -312,18 +312,18 @@ if not hasattr(paddle.Tensor, 'type_as'):
def
to
(
x
:
paddle
.
Tensor
,
*
args
,
**
kwargs
)
->
paddle
.
Tensor
:
def
to
(
x
:
paddle
.
Tensor
,
*
args
,
**
kwargs
)
->
paddle
.
Tensor
:
assert
len
(
args
)
==
1
assert
len
(
args
)
==
1
if
isinstance
(
args
[
0
],
str
):
# dtype
if
isinstance
(
args
[
0
],
str
):
# dtype
return
x
.
astype
(
args
[
0
])
return
x
.
astype
(
args
[
0
])
elif
isinstance
(
args
[
0
],
paddle
.
Tensor
):
#
Tensor
elif
isinstance
(
args
[
0
],
paddle
.
Tensor
):
#
Tensor
return
x
.
astype
(
args
[
0
].
dtype
)
return
x
.
astype
(
args
[
0
].
dtype
)
else
:
# Device
else
:
# Device
return
x
return
x
if
not
hasattr
(
paddle
.
Tensor
,
'to'
):
if
not
hasattr
(
paddle
.
Tensor
,
'to'
):
logger
.
debug
(
"register user to to paddle.Tensor, remove this when fixed!"
)
logger
.
debug
(
"register user to to paddle.Tensor, remove this when fixed!"
)
setattr
(
paddle
.
Tensor
,
'to'
,
to
)
setattr
(
paddle
.
Tensor
,
'to'
,
to
)
def
func_float
(
x
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
def
func_float
(
x
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
...
...
deepspeech/decoders/beam_search.py
浏览文件 @
a6091008
"""Beam search module."""
"""Beam search module."""
from
itertools
import
chain
from
itertools
import
chain
import
logger
from
typing
import
Any
from
typing
import
Any
from
typing
import
Dict
from
typing
import
Dict
from
typing
import
List
from
typing
import
List
...
@@ -141,7 +140,7 @@ class BeamSearch(paddle.nn.Layer):
...
@@ -141,7 +140,7 @@ class BeamSearch(paddle.nn.Layer):
]
]
@
staticmethod
@
staticmethod
def
append_token
(
xs
:
paddle
.
Tensor
,
x
:
int
)
->
paddle
.
Tensor
:
def
append_token
(
xs
:
paddle
.
Tensor
,
x
:
Union
[
int
,
paddle
.
Tensor
]
)
->
paddle
.
Tensor
:
"""Append new token to prefix tokens.
"""Append new token to prefix tokens.
Args:
Args:
...
@@ -152,8 +151,8 @@ class BeamSearch(paddle.nn.Layer):
...
@@ -152,8 +151,8 @@ class BeamSearch(paddle.nn.Layer):
paddle.Tensor: (T+1,), New tensor contains: xs + [x] with xs.dtype and xs.device
paddle.Tensor: (T+1,), New tensor contains: xs + [x] with xs.dtype and xs.device
"""
"""
x
=
paddle
.
to_tensor
([
x
],
dtype
=
xs
.
dtype
,
place
=
xs
.
place
)
x
=
paddle
.
to_tensor
([
x
],
dtype
=
xs
.
dtype
)
if
isinstance
(
x
,
int
)
else
x
return
paddle
.
cat
((
xs
,
x
))
return
paddle
.
c
onc
at
((
xs
,
x
))
def
score_full
(
def
score_full
(
self
,
hyp
:
Hypothesis
,
x
:
paddle
.
Tensor
self
,
hyp
:
Hypothesis
,
x
:
paddle
.
Tensor
...
@@ -306,7 +305,7 @@ class BeamSearch(paddle.nn.Layer):
...
@@ -306,7 +305,7 @@ class BeamSearch(paddle.nn.Layer):
part_ids
=
paddle
.
arange
(
self
.
n_vocab
)
# no pre-beam
part_ids
=
paddle
.
arange
(
self
.
n_vocab
)
# no pre-beam
for
hyp
in
running_hyps
:
for
hyp
in
running_hyps
:
# scoring
# scoring
weighted_scores
=
paddle
.
zeros
(
self
.
n_vocab
,
dtype
=
x
.
dtype
)
weighted_scores
=
paddle
.
zeros
(
[
self
.
n_vocab
]
,
dtype
=
x
.
dtype
)
scores
,
states
=
self
.
score_full
(
hyp
,
x
)
scores
,
states
=
self
.
score_full
(
hyp
,
x
)
for
k
in
self
.
full_scorers
:
for
k
in
self
.
full_scorers
:
weighted_scores
+=
self
.
weights
[
k
]
*
scores
[
k
]
weighted_scores
+=
self
.
weights
[
k
]
*
scores
[
k
]
...
@@ -410,15 +409,20 @@ class BeamSearch(paddle.nn.Layer):
...
@@ -410,15 +409,20 @@ class BeamSearch(paddle.nn.Layer):
best
=
nbest_hyps
[
0
]
best
=
nbest_hyps
[
0
]
for
k
,
v
in
best
.
scores
.
items
():
for
k
,
v
in
best
.
scores
.
items
():
logger
.
info
(
logger
.
info
(
f
"
{
v
:
6.2
f
}
*
{
self
.
weights
[
k
]:
3
}
=
{
v
*
self
.
weights
[
k
]:
6.2
f
}
for
{
k
}
"
f
"
{
float
(
v
):
6.2
f
}
*
{
self
.
weights
[
k
]:
3
}
=
{
float
(
v
)
*
self
.
weights
[
k
]:
6.2
f
}
for
{
k
}
"
)
)
logger
.
info
(
f
"total log probability:
{
best
.
score
:.
2
f
}
"
)
logger
.
info
(
f
"total log probability:
{
float
(
best
.
score
)
:.
2
f
}
"
)
logger
.
info
(
f
"normalized log probability:
{
best
.
score
/
len
(
best
.
yseq
):.
2
f
}
"
)
logger
.
info
(
f
"normalized log probability:
{
float
(
best
.
score
)
/
len
(
best
.
yseq
):.
2
f
}
"
)
logger
.
info
(
f
"total number of ended hypotheses:
{
len
(
nbest_hyps
)
}
"
)
logger
.
info
(
f
"total number of ended hypotheses:
{
len
(
nbest_hyps
)
}
"
)
if
self
.
token_list
is
not
None
:
if
self
.
token_list
is
not
None
:
# logger.info(
# "best hypo: "
# + "".join([self.token_list[x] for x in best.yseq[1:-1]])
# + "\n"
# )
logger
.
info
(
logger
.
info
(
"best hypo: "
"best hypo: "
+
""
.
join
([
self
.
token_list
[
x
]
for
x
in
best
.
yseq
[
1
:
-
1
]])
+
""
.
join
([
self
.
token_list
[
x
]
for
x
in
best
.
yseq
[
1
:]])
+
"
\n
"
+
"
\n
"
)
)
return
nbest_hyps
return
nbest_hyps
...
...
deepspeech/decoders/recog.py
浏览文件 @
a6091008
...
@@ -2,18 +2,22 @@
...
@@ -2,18 +2,22 @@
import
json
import
json
import
paddle
import
paddle
import
yaml
from
yacs.config
import
CfgNode
from
pathlib
import
Path
import
jsonlines
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.asr.asr_utils import torch_load
# from espnet.asr.pytorch_backend.asr import load_trained_model
# from espnet.asr.pytorch_backend.asr import load_trained_model
# from espnet.nets.lm_interface import dynamic_import_lm
# from espnet.nets.lm_interface import dynamic_import_lm
# from espnet.net
s.asr_interface import ASRInterface
from
deepspeech.model
s.asr_interface
import
ASRInterface
from
.utils
import
add_results_to_json
from
.utils
import
add_results_to_json
# from .batch_beam_search import BatchBeamSearch
# from .batch_beam_search import BatchBeamSearch
from
.beam_search
import
BeamSearch
from
.beam_search
import
BeamSearch
from
.scorer_interface
import
BatchScorerInterface
from
.scorer
s.scorer
_interface
import
BatchScorerInterface
from
.scorers.length_bonus
import
LengthBonus
from
.scorers.length_bonus
import
LengthBonus
from
deepspeech.io.reader
import
LoadInputsAndTargets
from
deepspeech.io.reader
import
LoadInputsAndTargets
...
@@ -21,6 +25,14 @@ from deepspeech.utils.log import Log
...
@@ -21,6 +25,14 @@ from deepspeech.utils.log import Log
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.utils.utility
import
print_arguments
model_test_alias
=
{
"u2"
:
"deepspeech.exps.u2.model:U2Tester"
,
"u2_kaldi"
:
"deepspeech.exps.u2_kaldi.model:U2Tester"
,
}
def
recog_v2
(
args
):
def
recog_v2
(
args
):
"""Decode with custom models that implements ScorerInterface.
"""Decode with custom models that implements ScorerInterface.
...
@@ -36,16 +48,31 @@ def recog_v2(args):
...
@@ -36,16 +48,31 @@ def recog_v2(args):
raise
NotImplementedError
(
"streaming mode is not implemented"
)
raise
NotImplementedError
(
"streaming mode is not implemented"
)
if
args
.
word_rnnlm
:
if
args
.
word_rnnlm
:
raise
NotImplementedError
(
"word LM is not implemented"
)
raise
NotImplementedError
(
"word LM is not implemented"
)
args
.
nprocs
=
args
.
ngpu
# set_deterministic(args)
# set_deterministic(args)
model
,
train_args
=
load_trained_model
(
args
.
model
)
# assert isinstance(model, ASRInterface)
#model, train_args = load_trained_model(args.model)
model
.
eval
()
model_path
=
Path
(
args
.
model
)
ckpt_dir
=
model_path
.
parent
.
parent
confs
=
CfgNode
()
confs
.
set_new_allowed
(
True
)
confs
.
merge_from_file
(
args
.
model_conf
)
class_obj
=
dynamic_import
(
args
.
model_name
,
model_test_alias
)
exp
=
class_obj
(
confs
,
args
)
with
exp
.
eval
():
exp
.
setup
()
exp
.
restore
()
char_list
=
exp
.
args
.
char_list
model
=
exp
.
model
assert
isinstance
(
model
,
ASRInterface
)
load_inputs_and_targets
=
LoadInputsAndTargets
(
load_inputs_and_targets
=
LoadInputsAndTargets
(
mode
=
"asr"
,
mode
=
"asr"
,
load_output
=
False
,
load_output
=
False
,
sort_in_input_length
=
False
,
sort_in_input_length
=
False
,
preprocess_conf
=
train_args
.
preprocess_conf
preprocess_conf
=
confs
.
collator
.
augmentation_config
if
args
.
preprocess_conf
is
None
if
args
.
preprocess_conf
is
None
else
args
.
preprocess_conf
,
else
args
.
preprocess_conf
,
preprocess_args
=
{
"train"
:
False
},
preprocess_args
=
{
"train"
:
False
},
...
@@ -56,7 +83,7 @@ def recog_v2(args):
...
@@ -56,7 +83,7 @@ def recog_v2(args):
# NOTE: for a compatibility with less than 0.5.0 version models
# NOTE: for a compatibility with less than 0.5.0 version models
lm_model_module
=
getattr
(
lm_args
,
"model_module"
,
"default"
)
lm_model_module
=
getattr
(
lm_args
,
"model_module"
,
"default"
)
lm_class
=
dynamic_import_lm
(
lm_model_module
,
lm_args
.
backend
)
lm_class
=
dynamic_import_lm
(
lm_model_module
,
lm_args
.
backend
)
lm
=
lm_class
(
len
(
train_args
.
char_list
),
lm_args
)
lm
=
lm_class
(
len
(
char_list
),
lm_args
)
torch_load
(
args
.
rnnlm
,
lm
)
torch_load
(
args
.
rnnlm
,
lm
)
lm
.
eval
()
lm
.
eval
()
else
:
else
:
...
@@ -67,16 +94,16 @@ def recog_v2(args):
...
@@ -67,16 +94,16 @@ def recog_v2(args):
from
.scorers.ngram
import
NgramPartScorer
from
.scorers.ngram
import
NgramPartScorer
if
args
.
ngram_scorer
==
"full"
:
if
args
.
ngram_scorer
==
"full"
:
ngram
=
NgramFullScorer
(
args
.
ngram_model
,
train_args
.
char_list
)
ngram
=
NgramFullScorer
(
args
.
ngram_model
,
char_list
)
else
:
else
:
ngram
=
NgramPartScorer
(
args
.
ngram_model
,
train_args
.
char_list
)
ngram
=
NgramPartScorer
(
args
.
ngram_model
,
char_list
)
else
:
else
:
ngram
=
None
ngram
=
None
scorers
=
model
.
scorers
()
scorers
=
model
.
scorers
()
scorers
[
"lm"
]
=
lm
scorers
[
"lm"
]
=
lm
scorers
[
"ngram"
]
=
ngram
scorers
[
"ngram"
]
=
ngram
scorers
[
"length_bonus"
]
=
LengthBonus
(
len
(
train_args
.
char_list
))
scorers
[
"length_bonus"
]
=
LengthBonus
(
len
(
char_list
))
weights
=
dict
(
weights
=
dict
(
decoder
=
1.0
-
args
.
ctc_weight
,
decoder
=
1.0
-
args
.
ctc_weight
,
ctc
=
args
.
ctc_weight
,
ctc
=
args
.
ctc_weight
,
...
@@ -86,14 +113,15 @@ def recog_v2(args):
...
@@ -86,14 +113,15 @@ def recog_v2(args):
)
)
beam_search
=
BeamSearch
(
beam_search
=
BeamSearch
(
beam_size
=
args
.
beam_size
,
beam_size
=
args
.
beam_size
,
vocab_size
=
len
(
train_args
.
char_list
),
vocab_size
=
len
(
char_list
),
weights
=
weights
,
weights
=
weights
,
scorers
=
scorers
,
scorers
=
scorers
,
sos
=
model
.
sos
,
sos
=
model
.
sos
,
eos
=
model
.
eos
,
eos
=
model
.
eos
,
token_list
=
train_args
.
char_list
,
token_list
=
char_list
,
pre_beam_score_key
=
None
if
args
.
ctc_weight
==
1.0
else
"full"
,
pre_beam_score_key
=
None
if
args
.
ctc_weight
==
1.0
else
"full"
,
)
)
# TODO(karita): make all scorers batchfied
# TODO(karita): make all scorers batchfied
if
args
.
batchsize
==
1
:
if
args
.
batchsize
==
1
:
non_batch
=
[
non_batch
=
[
...
@@ -116,6 +144,7 @@ def recog_v2(args):
...
@@ -116,6 +144,7 @@ def recog_v2(args):
device
=
"gpu:0"
device
=
"gpu:0"
else
:
else
:
device
=
"cpu"
device
=
"cpu"
paddle
.
set_device
(
device
)
dtype
=
getattr
(
paddle
,
args
.
dtype
)
dtype
=
getattr
(
paddle
,
args
.
dtype
)
logger
.
info
(
f
"Decoding device=
{
device
}
, dtype=
{
dtype
}
"
)
logger
.
info
(
f
"Decoding device=
{
device
}
, dtype=
{
dtype
}
"
)
model
.
to
(
device
=
device
,
dtype
=
dtype
)
model
.
to
(
device
=
device
,
dtype
=
dtype
)
...
@@ -124,31 +153,41 @@ def recog_v2(args):
...
@@ -124,31 +153,41 @@ def recog_v2(args):
beam_search
.
eval
()
beam_search
.
eval
()
# read json data
# read json data
with
open
(
args
.
recog_json
,
"rb"
)
as
f
:
js
=
[]
js
=
json
.
load
(
f
)
with
jsonlines
.
open
(
args
.
recog_json
,
"r"
)
as
reader
:
for
item
in
reader
:
js
.
append
(
item
)
# josnlines to dict, key by 'utt'
# josnlines to dict, key by 'utt'
js
=
{
item
[
'utt'
]:
item
for
item
in
js
}
js
=
{
item
[
'utt'
]:
item
for
item
in
js
}
new_js
=
{}
new_js
=
{}
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
for
idx
,
name
in
enumerate
(
js
.
keys
(),
1
):
with
jsonlines
.
open
(
args
.
result_label
,
"w"
)
as
f
:
logger
.
info
(
"(%d/%d) decoding "
+
name
,
idx
,
len
(
js
.
keys
()))
for
idx
,
name
in
enumerate
(
js
.
keys
(),
1
):
batch
=
[(
name
,
js
[
name
])]
logger
.
info
(
f
"(
{
idx
}
/
{
len
(
js
.
keys
())
}
) decoding "
+
name
)
feat
=
load_inputs_and_targets
(
batch
)[
0
][
0
]
batch
=
[(
name
,
js
[
name
])]
enc
=
model
.
encode
(
paddle
.
to_tensor
(
feat
).
to
(
device
=
device
,
dtype
=
dtype
))
feat
=
load_inputs_and_targets
(
batch
)[
0
][
0
]
nbest_hyps
=
beam_search
(
logger
.
info
(
f
'feat:
{
feat
.
shape
}
'
)
x
=
enc
,
maxlenratio
=
args
.
maxlenratio
,
minlenratio
=
args
.
minlenratio
enc
=
model
.
encode
(
paddle
.
to_tensor
(
feat
).
to
(
dtype
))
)
logger
.
info
(
f
'eouts:
{
enc
.
shape
}
'
)
nbest_hyps
=
[
nbest_hyps
=
beam_search
(
h
.
asdict
()
for
h
in
nbest_hyps
[:
min
(
len
(
nbest_hyps
),
args
.
nbest
)]
x
=
enc
,
maxlenratio
=
args
.
maxlenratio
,
minlenratio
=
args
.
minlenratio
]
)
new_js
[
name
]
=
add_results_to_json
(
nbest_hyps
=
[
js
[
name
],
nbest_hyps
,
train_args
.
char_list
h
.
asdict
()
for
h
in
nbest_hyps
[:
min
(
len
(
nbest_hyps
),
args
.
nbest
)]
)
]
new_js
[
name
]
=
add_results_to_json
(
with
open
(
args
.
result_label
,
"wb"
)
as
f
:
js
[
name
],
nbest_hyps
,
char_list
f
.
write
(
)
json
.
dumps
(
{
"utts"
:
new_js
},
indent
=
4
,
ensure_ascii
=
False
,
sort_keys
=
True
item
=
new_js
[
name
][
'output'
][
0
]
# 1-best
).
encode
(
"utf_8"
)
utt
=
name
)
ref
=
item
[
'text'
]
rec_text
=
item
[
'rec_text'
].
replace
(
'▁'
,
' '
).
strip
()
rec_tokenid
=
item
[
'rec_tokenid'
].
split
()
f
.
write
({
"utt"
:
utt
,
"refs"
:
[
ref
],
"hyps"
:
[
rec_text
],
"hyps_tokenid"
:
[
rec_tokenid
],
})
\ No newline at end of file
deepspeech/decoders/scorers/ctc.py
浏览文件 @
a6091008
...
@@ -15,8 +15,8 @@
...
@@ -15,8 +15,8 @@
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
from
.ctc_prefix_score
import
CTCPrefixScore
r
from
.ctc_prefix_score
import
CTCPrefixScore
from
.ctc_prefix_score
import
CTCPrefixScore
r
PD
from
.ctc_prefix_score
import
CTCPrefixScorePD
from
.scorer_interface
import
BatchPartialScorerInterface
from
.scorer_interface
import
BatchPartialScorerInterface
...
...
deepspeech/decoders/scorers/ctc_prefix_score.py
浏览文件 @
a6091008
...
@@ -6,7 +6,7 @@ import paddle
...
@@ -6,7 +6,7 @@ import paddle
import
six
import
six
class
CTCPrefixScore
r
PD
():
class
CTCPrefixScorePD
():
"""Batch processing of CTCPrefixScore
"""Batch processing of CTCPrefixScore
which is based on Algorithm 2 in WATANABE et al.
which is based on Algorithm 2 in WATANABE et al.
...
@@ -267,7 +267,7 @@ class CTCPrefixScorerPD():
...
@@ -267,7 +267,7 @@ class CTCPrefixScorerPD():
return
(
r_prev_new
,
s_prev
,
f_min_prev
,
f_max_prev
)
return
(
r_prev_new
,
s_prev
,
f_min_prev
,
f_max_prev
)
class
CTCPrefixScore
r
():
class
CTCPrefixScore
():
"""Compute CTC label sequence scores
"""Compute CTC label sequence scores
which is based on Algorithm 2 in WATANABE et al.
which is based on Algorithm 2 in WATANABE et al.
...
...
deepspeech/decoders/scorers/score_interface.py
→
deepspeech/decoders/scorers/score
r
_interface.py
浏览文件 @
a6091008
文件已移动
deepspeech/decoders/utils.py
浏览文件 @
a6091008
...
@@ -12,7 +12,11 @@
...
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
__all__
=
[
"end_detect"
]
import
numpy
as
np
from
deepspeech.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
"end_detect"
,
"parse_hypothesis"
,
"add_results_to_json"
]
def
end_detect
(
ended_hyps
,
i
,
M
=
3
,
D_end
=
np
.
log
(
1
*
np
.
exp
(
-
10
))):
def
end_detect
(
ended_hyps
,
i
,
M
=
3
,
D_end
=
np
.
log
(
1
*
np
.
exp
(
-
10
))):
...
@@ -118,7 +122,7 @@ def add_results_to_json(js, nbest_hyps, char_list):
...
@@ -118,7 +122,7 @@ def add_results_to_json(js, nbest_hyps, char_list):
# show 1-best result
# show 1-best result
if
n
==
1
:
if
n
==
1
:
if
"text"
in
out_dic
.
keys
():
if
"text"
in
out_dic
.
keys
():
logg
ing
.
info
(
"groundtruth: %s"
%
out_dic
[
"text"
])
logg
er
.
info
(
"groundtruth: %s"
%
out_dic
[
"text"
])
logg
ing
.
info
(
"prediction : %s"
%
out_dic
[
"rec_text"
])
logg
er
.
info
(
"prediction : %s"
%
out_dic
[
"rec_text"
])
return
new_js
return
new_js
\ No newline at end of file
deepspeech/exps/u2_kaldi/bin/recog.py
0 → 100644
浏览文件 @
a6091008
"""End-to-end speech recognition model decoding script."""
import
configargparse
import
logging
import
os
import
random
import
sys
import
numpy
as
np
from
distutils.util
import
strtobool
from
deepspeech.training.cli
import
default_argument_parser
# NOTE: you need this func to generate our sphinx doc
def
get_parser
():
"""Get default arguments."""
parser
=
configargparse
.
ArgumentParser
(
description
=
"Transcribe text from speech using "
"a speech recognition model on one CPU or GPU"
,
config_file_parser_class
=
configargparse
.
YAMLConfigFileParser
,
formatter_class
=
configargparse
.
ArgumentDefaultsHelpFormatter
,
)
parser
.
add
(
'--model-name'
,
type
=
str
,
default
=
'u2_kaldi'
,
help
=
'model name, e.g: deepspeech2, u2, u2_kaldi, u2_st'
)
# general configuration
parser
.
add
(
"--config"
,
is_config_file
=
True
,
help
=
"Config file path"
)
parser
.
add
(
"--config2"
,
is_config_file
=
True
,
help
=
"Second config file path that overwrites the settings in `--config`"
,
)
parser
.
add
(
"--config3"
,
is_config_file
=
True
,
help
=
"Third config file path that overwrites the settings "
"in `--config` and `--config2`"
,
)
parser
.
add_argument
(
"--ngpu"
,
type
=
int
,
default
=
0
,
help
=
"Number of GPUs"
)
parser
.
add_argument
(
"--dtype"
,
choices
=
(
"float16"
,
"float32"
,
"float64"
),
default
=
"float32"
,
help
=
"Float precision (only available in --api v2)"
,
)
parser
.
add_argument
(
"--debugmode"
,
type
=
int
,
default
=
1
,
help
=
"Debugmode"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
1
,
help
=
"Random seed"
)
parser
.
add_argument
(
"--verbose"
,
"-V"
,
type
=
int
,
default
=
2
,
help
=
"Verbose option"
)
parser
.
add_argument
(
"--batchsize"
,
type
=
int
,
default
=
1
,
help
=
"Batch size for beam search (0: means no batch processing)"
,
)
parser
.
add_argument
(
"--preprocess-conf"
,
type
=
str
,
default
=
None
,
help
=
"The configuration file for the pre-processing"
,
)
parser
.
add_argument
(
"--api"
,
default
=
"v2"
,
choices
=
[
"v2"
],
help
=
"Beam search APIs "
"v2: Experimental API. It supports any models that implements ScorerInterface."
,
)
# task related
parser
.
add_argument
(
"--recog-json"
,
type
=
str
,
help
=
"Filename of recognition data (json)"
)
parser
.
add_argument
(
"--result-label"
,
type
=
str
,
required
=
True
,
help
=
"Filename of result label data (json)"
,
)
# model (parameter) related
parser
.
add_argument
(
"--model"
,
type
=
str
,
required
=
True
,
help
=
"Model file parameters to read"
)
parser
.
add_argument
(
"--model-conf"
,
type
=
str
,
default
=
None
,
help
=
"Model config file"
)
parser
.
add_argument
(
"--num-spkrs"
,
type
=
int
,
default
=
1
,
choices
=
[
1
,
2
],
help
=
"Number of speakers in the speech"
,
)
parser
.
add_argument
(
"--num-encs"
,
default
=
1
,
type
=
int
,
help
=
"Number of encoders in the model."
)
# search related
parser
.
add_argument
(
"--nbest"
,
type
=
int
,
default
=
1
,
help
=
"Output N-best hypotheses"
)
parser
.
add_argument
(
"--beam-size"
,
type
=
int
,
default
=
1
,
help
=
"Beam size"
)
parser
.
add_argument
(
"--penalty"
,
type
=
float
,
default
=
0.0
,
help
=
"Incertion penalty"
)
parser
.
add_argument
(
"--maxlenratio"
,
type
=
float
,
default
=
0.0
,
help
=
"""Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths.
If maxlenratio<0.0, its absolute value is interpreted
as a constant max output length"""
,
)
parser
.
add_argument
(
"--minlenratio"
,
type
=
float
,
default
=
0.0
,
help
=
"Input length ratio to obtain min output length"
,
)
parser
.
add_argument
(
"--ctc-weight"
,
type
=
float
,
default
=
0.0
,
help
=
"CTC weight in joint decoding"
)
parser
.
add_argument
(
"--weights-ctc-dec"
,
type
=
float
,
action
=
"append"
,
help
=
"ctc weight assigned to each encoder during decoding."
"[in multi-encoder mode only]"
,
)
parser
.
add_argument
(
"--ctc-window-margin"
,
type
=
int
,
default
=
0
,
help
=
"""Use CTC window with margin parameter to accelerate
CTC/attention decoding especially on GPU. Smaller magin
makes decoding faster, but may increase search errors.
If margin=0 (default), this function is disabled"""
,
)
# transducer related
parser
.
add_argument
(
"--search-type"
,
type
=
str
,
default
=
"default"
,
choices
=
[
"default"
,
"nsc"
,
"tsd"
,
"alsd"
,
"maes"
],
help
=
"""Type of beam search implementation to use during inference.
Can be either: default beam search ("default"),
N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
Alignment-Length Synchronous Decoding ("alsd") or
modified Adaptive Expansion Search ("maes")."""
,
)
parser
.
add_argument
(
"--nstep"
,
type
=
int
,
default
=
1
,
help
=
"""Number of expansion steps allowed in NSC beam search or mAES
(nstep > 0 for NSC and nstep > 1 for mAES)."""
,
)
parser
.
add_argument
(
"--prefix-alpha"
,
type
=
int
,
default
=
2
,
help
=
"Length prefix difference allowed in NSC beam search or mAES."
,
)
parser
.
add_argument
(
"--max-sym-exp"
,
type
=
int
,
default
=
2
,
help
=
"Number of symbol expansions allowed in TSD."
,
)
parser
.
add_argument
(
"--u-max"
,
type
=
int
,
default
=
400
,
help
=
"Length prefix difference allowed in ALSD."
,
)
parser
.
add_argument
(
"--expansion-gamma"
,
type
=
float
,
default
=
2.3
,
help
=
"Allowed logp difference for prune-by-value method in mAES."
,
)
parser
.
add_argument
(
"--expansion-beta"
,
type
=
int
,
default
=
2
,
help
=
"""Number of additional candidates for expanded hypotheses
selection in mAES."""
,
)
parser
.
add_argument
(
"--score-norm"
,
type
=
strtobool
,
nargs
=
"?"
,
default
=
True
,
help
=
"Normalize final hypotheses' score by length"
,
)
parser
.
add_argument
(
"--softmax-temperature"
,
type
=
float
,
default
=
1.0
,
help
=
"Penalization term for softmax function."
,
)
# rnnlm related
parser
.
add_argument
(
"--rnnlm"
,
type
=
str
,
default
=
None
,
help
=
"RNNLM model file to read"
)
parser
.
add_argument
(
"--rnnlm-conf"
,
type
=
str
,
default
=
None
,
help
=
"RNNLM model config file to read"
)
parser
.
add_argument
(
"--word-rnnlm"
,
type
=
str
,
default
=
None
,
help
=
"Word RNNLM model file to read"
)
parser
.
add_argument
(
"--word-rnnlm-conf"
,
type
=
str
,
default
=
None
,
help
=
"Word RNNLM model config file to read"
,
)
parser
.
add_argument
(
"--word-dict"
,
type
=
str
,
default
=
None
,
help
=
"Word list to read"
)
parser
.
add_argument
(
"--lm-weight"
,
type
=
float
,
default
=
0.1
,
help
=
"RNNLM weight"
)
# ngram related
parser
.
add_argument
(
"--ngram-model"
,
type
=
str
,
default
=
None
,
help
=
"ngram model file to read"
)
parser
.
add_argument
(
"--ngram-weight"
,
type
=
float
,
default
=
0.1
,
help
=
"ngram weight"
)
parser
.
add_argument
(
"--ngram-scorer"
,
type
=
str
,
default
=
"part"
,
choices
=
(
"full"
,
"part"
),
help
=
"""if the ngram is set as a part scorer, similar with CTC scorer,
ngram scorer only scores topK hypethesis.
if the ngram is set as full scorer, ngram scorer scores all hypthesis
the decoding speed of part scorer is musch faster than full one"""
,
)
# streaming related
parser
.
add_argument
(
"--streaming-mode"
,
type
=
str
,
default
=
None
,
choices
=
[
"window"
,
"segment"
],
help
=
"""Use streaming recognizer for inference.
`--batchsize` must be set to 0 to enable this mode"""
,
)
parser
.
add_argument
(
"--streaming-window"
,
type
=
int
,
default
=
10
,
help
=
"Window size"
)
parser
.
add_argument
(
"--streaming-min-blank-dur"
,
type
=
int
,
default
=
10
,
help
=
"Minimum blank duration threshold"
,
)
parser
.
add_argument
(
"--streaming-onset-margin"
,
type
=
int
,
default
=
1
,
help
=
"Onset margin"
)
parser
.
add_argument
(
"--streaming-offset-margin"
,
type
=
int
,
default
=
1
,
help
=
"Offset margin"
)
# non-autoregressive related
# Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
parser
.
add_argument
(
"--maskctc-n-iterations"
,
type
=
int
,
default
=
10
,
help
=
"Number of decoding iterations."
"For Mask CTC, set 0 to predict 1 mask/iter."
,
)
parser
.
add_argument
(
"--maskctc-probability-threshold"
,
type
=
float
,
default
=
0.999
,
help
=
"Threshold probability for CTC output"
,
)
# quantize model related
parser
.
add_argument
(
"--quantize-config"
,
nargs
=
"*"
,
help
=
"Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]"
,
)
parser
.
add_argument
(
"--quantize-dtype"
,
type
=
str
,
default
=
"qint8"
,
help
=
"Dtype dynamic quantize"
)
parser
.
add_argument
(
"--quantize-asr-model"
,
type
=
bool
,
default
=
False
,
help
=
"Quantize asr model"
,
)
parser
.
add_argument
(
"--quantize-lm-model"
,
type
=
bool
,
default
=
False
,
help
=
"Quantize lm model"
,
)
return
parser
def
main
(
args
):
"""Run the main decoding function."""
parser
=
get_parser
()
parser
.
add_argument
(
"--output"
,
metavar
=
"CKPT_DIR"
,
help
=
"path to save checkpoint."
)
parser
.
add_argument
(
"--checkpoint_path"
,
type
=
str
,
help
=
"path to load checkpoint"
)
parser
.
add_argument
(
"--dict-path"
,
type
=
str
,
help
=
"path to load checkpoint"
)
# parser = default_argument_parser(parser)
args
=
parser
.
parse_args
(
args
)
if
args
.
ngpu
==
0
and
args
.
dtype
==
"float16"
:
raise
ValueError
(
f
"--dtype
{
args
.
dtype
}
does not support the CPU backend."
)
# logging info
if
args
.
verbose
==
1
:
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
"%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
,
)
elif
args
.
verbose
==
2
:
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
format
=
"%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
,
)
else
:
logging
.
basicConfig
(
level
=
logging
.
WARN
,
format
=
"%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
,
)
logging
.
warning
(
"Skip DEBUG/INFO messages"
)
logging
.
info
(
args
)
# check CUDA_VISIBLE_DEVICES
if
args
.
ngpu
>
0
:
cvd
=
os
.
environ
.
get
(
"CUDA_VISIBLE_DEVICES"
)
if
cvd
is
None
:
logging
.
warning
(
"CUDA_VISIBLE_DEVICES is not set."
)
elif
args
.
ngpu
!=
len
(
cvd
.
split
(
","
)):
logging
.
error
(
"#gpus is not matched with CUDA_VISIBLE_DEVICES."
)
sys
.
exit
(
1
)
# TODO(mn5k): support of multiple GPUs
if
args
.
ngpu
>
1
:
logging
.
error
(
"The program only supports ngpu=1."
)
sys
.
exit
(
1
)
# display PYTHONPATH
logging
.
info
(
"python path = "
+
os
.
environ
.
get
(
"PYTHONPATH"
,
"(None)"
))
# seed setting
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
logging
.
info
(
"set random seed = %d"
%
args
.
seed
)
# validate rnn options
if
args
.
rnnlm
is
not
None
and
args
.
word_rnnlm
is
not
None
:
logging
.
error
(
"It seems that both --rnnlm and --word-rnnlm are specified. "
"Please use either option."
)
sys
.
exit
(
1
)
# recog
if
args
.
num_spkrs
==
1
:
if
args
.
num_encs
==
1
:
# Experimental API that supports custom LMs
if
args
.
api
==
"v2"
:
from
deepspeech.decoders.recog
import
recog_v2
recog_v2
(
args
)
else
:
raise
ValueError
(
"Only support --api v2"
)
else
:
if
args
.
api
==
"v2"
:
raise
NotImplementedError
(
f
"--num-encs
{
args
.
num_encs
}
> 1 is not supported in --api v2"
)
elif
args
.
num_spkrs
==
2
:
raise
ValueError
(
"asr_mix not supported."
)
if
__name__
==
"__main__"
:
main
(
sys
.
argv
[
1
:])
deepspeech/exps/u2_kaldi/model.py
浏览文件 @
a6091008
...
@@ -317,11 +317,9 @@ class U2Trainer(Trainer):
...
@@ -317,11 +317,9 @@ class U2Trainer(Trainer):
with
UpdateConfig
(
model_conf
):
with
UpdateConfig
(
model_conf
):
model_conf
.
input_dim
=
self
.
train_loader
.
feat_dim
model_conf
.
input_dim
=
self
.
train_loader
.
feat_dim
model_conf
.
output_dim
=
self
.
train_loader
.
vocab_size
model_conf
.
output_dim
=
self
.
train_loader
.
vocab_size
model
=
U2Model
.
from_config
(
model_conf
)
model
=
U2Model
.
from_config
(
model_conf
)
if
self
.
parallel
:
if
self
.
parallel
:
model
=
paddle
.
DataParallel
(
model
)
model
=
paddle
.
DataParallel
(
model
)
logger
.
info
(
f
"
{
model
}
"
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
# lr
# lr
...
...
deepspeech/frontend/featurizer/text_featurizer.py
浏览文件 @
a6091008
...
@@ -207,7 +207,7 @@ class TextFeaturizer():
...
@@ -207,7 +207,7 @@ class TextFeaturizer():
"""Load vocabulary from file."""
"""Load vocabulary from file."""
vocab_list
=
load_dict
(
vocab_filepath
,
maskctc
)
vocab_list
=
load_dict
(
vocab_filepath
,
maskctc
)
assert
vocab_list
is
not
None
assert
vocab_list
is
not
None
logger
.
info
(
f
"Vocab:
{
pformat
(
vocab_list
)
}
"
)
logger
.
debug
(
f
"Vocab:
{
pformat
(
vocab_list
)
}
"
)
id2token
=
dict
(
id2token
=
dict
(
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
...
...
deepspeech/models/u2/u2.py
浏览文件 @
a6091008
...
@@ -50,7 +50,7 @@ from deepspeech.utils.tensor_utils import th_accuracy
...
@@ -50,7 +50,7 @@ from deepspeech.utils.tensor_utils import th_accuracy
from
deepspeech.utils.utility
import
log_add
from
deepspeech.utils.utility
import
log_add
from
deepspeech.utils.utility
import
UpdateConfig
from
deepspeech.utils.utility
import
UpdateConfig
from
deepspeech.models.asr_interface
import
ASRInterface
from
deepspeech.models.asr_interface
import
ASRInterface
from
deepspeech.decoders.scorers.ctc
_prefix_score
import
CTCPrefixScorer
from
deepspeech.decoders.scorers.ctc
import
CTCPrefixScorer
__all__
=
[
"U2Model"
,
"U2InferModel"
]
__all__
=
[
"U2Model"
,
"U2InferModel"
]
...
...
deepspeech/modules/decoder.py
浏览文件 @
a6091008
...
@@ -28,7 +28,7 @@ from deepspeech.modules.mask import make_non_pad_mask
...
@@ -28,7 +28,7 @@ from deepspeech.modules.mask import make_non_pad_mask
from
deepspeech.modules.mask
import
subsequent_mask
from
deepspeech.modules.mask
import
subsequent_mask
from
deepspeech.modules.mask
import
make_xs_mask
from
deepspeech.modules.mask
import
make_xs_mask
from
deepspeech.modules.positionwise_feed_forward
import
PositionwiseFeedForward
from
deepspeech.modules.positionwise_feed_forward
import
PositionwiseFeedForward
from
deepspeech.decoders.scorers.score_interface
import
BatchScorerInterface
from
deepspeech.decoders.scorers.score
r
_interface
import
BatchScorerInterface
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
...
@@ -191,8 +191,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
...
@@ -191,8 +191,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
ys: (ylen,)
ys: (ylen,)
x: (xlen, n_feat)
x: (xlen, n_feat)
"""
"""
ys_mask
=
subsequent_mask
(
len
(
ys
)).
unsqueeze
(
0
)
ys_mask
=
subsequent_mask
(
len
(
ys
)).
unsqueeze
(
0
)
# (B,L,L)
x_mask
=
make_xs_mask
(
x
.
unsqueeze
(
0
))
x_mask
=
make_xs_mask
(
x
.
unsqueeze
(
0
))
.
unsqueeze
(
1
)
# (B,1,T)
if
self
.
selfattention_layer_type
!=
"selfattn"
:
if
self
.
selfattention_layer_type
!=
"selfattn"
:
# TODO(karita): implement cache
# TODO(karita): implement cache
logging
.
warning
(
logging
.
warning
(
...
@@ -237,9 +237,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
...
@@ -237,9 +237,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
]
]
# batch decoding
# batch decoding
ys_mask
=
subsequent_mask
(
ys
.
size
(
-
1
)).
unsqueeze
(
0
)
ys_mask
=
subsequent_mask
(
ys
.
size
(
-
1
)).
unsqueeze
(
0
)
# (B,L,L)
xs_mask
=
make_xs_mask
(
xs
).
unsqueeze
(
1
)
# (B,1,T)
xs_mask
=
make_xs_mask
(
xs
)
logp
,
states
=
self
.
forward_one_step
(
xs
,
xs_mask
,
ys
,
ys_mask
,
cache
=
batch_state
)
logp
,
states
=
self
.
forward_one_step
(
xs
,
xs_mask
,
ys
,
ys_mask
,
cache
=
batch_state
)
# transpose state of [layer, batch] into [batch, layer]
# transpose state of [layer, batch] into [batch, layer]
...
...
deepspeech/modules/mask.py
浏览文件 @
a6091008
...
@@ -24,15 +24,16 @@ __all__ = [
...
@@ -24,15 +24,16 @@ __all__ = [
]
]
def
make_xs_mask
(
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
def
make_xs_mask
(
xs
:
paddle
.
Tensor
,
pad_value
=
0.0
)
->
paddle
.
Tensor
:
"""Maks mask tensor containing indices of non-padded part.
"""Maks mask tensor containing indices of non-padded part.
Args:
Args:
xs (paddle.Tensor): (B, T, D), zeros for pad.
xs (paddle.Tensor): (B, T, D), zeros for pad.
Returns:
Returns:
paddle.Tensor: Mask Tensor indices of non-padded part. (B, T
, D
)
paddle.Tensor: Mask Tensor indices of non-padded part. (B, T)
"""
"""
pad_frame
=
paddle
.
zeros
([
1
,
1
,
xs
.
shape
[
-
1
]]
,
dtype
=
xs
.
dtype
)
pad_frame
=
paddle
.
full
([
1
,
1
,
xs
.
shape
[
-
1
]],
pad_value
,
dtype
=
xs
.
dtype
)
mask
=
xs
!=
pad_frame
mask
=
xs
!=
pad_frame
mask
=
mask
.
all
(
axis
=-
1
)
return
mask
return
mask
...
...
deepspeech/training/cli.py
浏览文件 @
a6091008
...
@@ -35,7 +35,7 @@ class LoadFromFile(argparse.Action):
...
@@ -35,7 +35,7 @@ class LoadFromFile(argparse.Action):
parser
.
parse_args
(
f
.
read
().
split
(),
namespace
)
parser
.
parse_args
(
f
.
read
().
split
(),
namespace
)
def
default_argument_parser
():
def
default_argument_parser
(
parser
=
None
):
r
"""A simple yet genral argument parser for experiments with parakeet.
r
"""A simple yet genral argument parser for experiments with parakeet.
This is used in examples with parakeet. And it is intended to be used by
This is used in examples with parakeet. And it is intended to be used by
...
@@ -62,7 +62,9 @@ def default_argument_parser():
...
@@ -62,7 +62,9 @@ def default_argument_parser():
argparse.ArgumentParser
argparse.ArgumentParser
the parser
the parser
"""
"""
parser
=
argparse
.
ArgumentParser
()
if
parser
is
None
:
parser
=
argparse
.
ArgumentParser
()
parser
.
register
(
'action'
,
'extend'
,
ExtendAction
)
parser
.
register
(
'action'
,
'extend'
,
ExtendAction
)
parser
.
add_argument
(
parser
.
add_argument
(
'--conf'
,
type
=
open
,
action
=
LoadFromFile
,
help
=
"config file."
)
'--conf'
,
type
=
open
,
action
=
LoadFromFile
,
help
=
"config file."
)
...
...
deepspeech/training/trainer.py
浏览文件 @
a6091008
...
@@ -126,7 +126,7 @@ class Trainer():
...
@@ -126,7 +126,7 @@ class Trainer():
logger
.
info
(
f
"Set seed
{
args
.
seed
}
"
)
logger
.
info
(
f
"Set seed
{
args
.
seed
}
"
)
# profiler and benchmark options
# profiler and benchmark options
if
self
.
args
.
benchmark_batch_size
:
if
hasattr
(
self
.
args
,
"benchmark_batch_size"
)
and
self
.
args
.
benchmark_batch_size
:
with
UpdateConfig
(
self
.
config
):
with
UpdateConfig
(
self
.
config
):
self
.
config
.
collator
.
batch_size
=
self
.
args
.
benchmark_batch_size
self
.
config
.
collator
.
batch_size
=
self
.
args
.
benchmark_batch_size
self
.
config
.
training
.
log_interval
=
1
self
.
config
.
training
.
log_interval
=
1
...
@@ -326,12 +326,25 @@ class Trainer():
...
@@ -326,12 +326,25 @@ class Trainer():
finally
:
finally
:
self
.
destory
()
self
.
destory
()
def
restore
(
self
):
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
assert
self
.
args
.
checkpoint_path
infos
=
self
.
checkpoint
.
load_latest_parameters
(
self
.
model
,
checkpoint_path
=
self
.
args
.
checkpoint_path
)
return
infos
def
run_test
(
self
):
def
run_test
(
self
):
"""Do Test/Decode"""
"""Do Test/Decode"""
try
:
try
:
with
Timer
(
"Test/Decode Done: {}"
):
with
Timer
(
"Test/Decode Done: {}"
):
with
self
.
eval
():
with
self
.
eval
():
self
.
res
ume_or_scratch
()
self
.
res
tore
()
self
.
test
()
self
.
test
()
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
exit
(
-
1
)
exit
(
-
1
)
...
@@ -341,6 +354,7 @@ class Trainer():
...
@@ -341,6 +354,7 @@ class Trainer():
try
:
try
:
with
Timer
(
"Export Done: {}"
):
with
Timer
(
"Export Done: {}"
):
with
self
.
eval
():
with
self
.
eval
():
self
.
restore
()
self
.
export
()
self
.
export
()
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
exit
(
-
1
)
exit
(
-
1
)
...
@@ -350,7 +364,7 @@ class Trainer():
...
@@ -350,7 +364,7 @@ class Trainer():
try
:
try
:
with
Timer
(
"Align Done: {}"
):
with
Timer
(
"Align Done: {}"
):
with
self
.
eval
():
with
self
.
eval
():
self
.
res
ume_or_scratch
()
self
.
res
tore
()
self
.
align
()
self
.
align
()
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
...
...
examples/librispeech/s2/conf/decode/decode.yaml
0 → 100644
浏览文件 @
a6091008
batchsize
:
0
beam-size
:
60
ctc-weight
:
0.4
lm-weight
:
0.0
maxlenratio
:
0.0
minlenratio
:
0.0
penalty
:
0.0
examples/librispeech/s2/conf/decode/decode_all.yaml
0 → 100644
浏览文件 @
a6091008
batchsize
:
0
beam-size
:
60
ctc-weight
:
0.4
lm-weight
:
0.6
maxlenratio
:
0.0
minlenratio
:
0.0
penalty
:
0.0
\ No newline at end of file
examples/librispeech/s2/conf/decode/decode_wo_lm.yaml
0 → 100644
浏览文件 @
a6091008
batchsize
:
0
beam-size
:
60
ctc-weight
:
0.4
lm-weight
:
0.0
maxlenratio
:
0.0
minlenratio
:
0.0
penalty
:
0.0
\ No newline at end of file
examples/librispeech/s2/local/recog.sh
0 → 100755
浏览文件 @
a6091008
#!/bin/bash
set
-e
expdir
=
exp
datadir
=
data
nj
=
32
decode_config
=
conf/decode/decode.yaml
lang_model
=
rnnlm.model.best
lmexpdir
=
exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
lmtag
=
'nolm'
recog_set
=
"test-clean test-other dev-clean dev-other"
recog_set
=
"test-clean"
# bpemode (unigram or bpe)
nbpe
=
5000
bpemode
=
unigram
bpeprefix
=
"data/bpe_
${
bpemode
}
_
${
nbpe
}
"
bpemodel
=
${
bpeprefix
}
.model
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path dict_path ckpt_path_prefix"
exit
-1
fi
ngpu
=
$(
echo
$CUDA_VISIBLE_DEVICES
|
awk
-F
","
'{print NF}'
)
echo
"using
$ngpu
gpus..."
config_path
=
$1
dict
=
$2
ckpt_prefix
=
$3
ckpt_dir
=
$(
dirname
`
dirname
${
ckpt_prefix
}
`
)
echo
"ckpt dir:
${
ckpt_dir
}
"
ckpt_tag
=
$(
basename
${
ckpt_prefix
}
)
echo
"ckpt tag:
${
ckpt_tag
}
"
chunk_mode
=
false
if
[[
${
config_path
}
=
~ ^.
*
chunk_.
*
yaml
$
]]
;
then
chunk_mode
=
true
fi
echo
"chunk mode:
${
chunk_mode
}
"
echo
"decode conf:
${
decode_config
}
"
# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
pids
=()
# initialize pids
for
dmethd
in
join_ctc
;
do
(
echo
"
${
dmethd
}
decoding"
for
rtask
in
${
recog_set
}
;
do
(
echo
"
${
rtask
}
dataset"
decode_dir
=
${
ckpt_dir
}
/decode/decode_
${
rtask
/-/_
}
_
${
dmethd
}
_
$(
basename
${
config_path
%.*
}
)
_
${
lmtag
}
_
${
ckpt_tag
}
feat_recog_dir
=
${
datadir
}
mkdir
-p
${
decode_dir
}
mkdir
-p
${
feat_recog_dir
}
# split data
split_json.sh manifest.
${
rtask
}
${
nj
}
#### use CPU for decoding
ngpu
=
0
# set batchsize 0 to disable batch decoding
${
decode_cmd
}
JOB
=
1:
${
nj
}
${
decode_dir
}
/log/decode.JOB.log
\
python3
-u
${
BIN_DIR
}
/recog.py
\
--api
v2
\
--config
${
decode_config
}
\
--ngpu
${
ngpu
}
\
--batchsize
0
\
--checkpoint_path
${
ckpt_prefix
}
\
--dict-path
${
dict
}
\
--recog-json
${
feat_recog_dir
}
/split
${
nj
}
/JOB/manifest.
${
rtask
}
\
--result-label
${
decode_dir
}
/data.JOB.json
\
--model-conf
${
config_path
}
\
--model
${
ckpt_prefix
}
.pdparams
#--rnnlm ${lmexpdir}/${lang_model} \
score_sclite.sh
--bpe
${
nbpe
}
--bpemodel
${
bpemodel
}
--wer
false
${
decode_dir
}
${
dict
}
)
&
pids+
=(
$!
)
# store background pids
i
=
0
;
for
pid
in
"
${
pids
[@]
}
"
;
do
wait
${
pid
}
||
((
++i
))
;
done
[
${
i
}
-gt
0
]
&&
echo
"
$0
:
${
i
}
background jobs are failed."
||
true
done
)
done
echo
"Finished"
exit
0
examples/librispeech/s2/local/test.sh
浏览文件 @
a6091008
...
@@ -83,7 +83,7 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
...
@@ -83,7 +83,7 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
--opts
decoding.batch_size
${
batch_size
}
\
--opts
decoding.batch_size
${
batch_size
}
\
--opts
data.test_manifest
${
feat_recog_dir
}
/split
${
nj
}
/JOB/manifest.
${
rtask
}
--opts
data.test_manifest
${
feat_recog_dir
}
/split
${
nj
}
/JOB/manifest.
${
rtask
}
score_sclite.sh
--bpe
${
nbpe
}
--bpemodel
${
bpemodel
}
--wer
false
${
expdir
}
/
${
decode_dir
}
${
dict
}
score_sclite.sh
--bpe
${
nbpe
}
--bpemodel
${
bpemodel
}
--wer
false
${
decode_dir
}
${
dict
}
)
&
)
&
pids+
=(
$!
)
# store background pids
pids+
=(
$!
)
# store background pids
...
...
requirements.txt
浏览文件 @
a6091008
...
@@ -40,3 +40,4 @@ pyworld
...
@@ -40,3 +40,4 @@ pyworld
jieba
jieba
phkit
phkit
yq
yq
ConfigArgParse
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录