Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
03a50d7b
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
03a50d7b
编写于
8月 26, 2021
作者:
H
huangyuxin
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/DeepSpeech
into fix_bug
上级
8d506270
794294e9
变更
83
隐藏空白更改
内联
并排
Showing
83 changed file
with
1776 addition
and
438 deletion
+1776
-438
.flake8
.flake8
+4
-0
deepspeech/__init__.py
deepspeech/__init__.py
+1
-40
deepspeech/decoders/swig/setup.py
deepspeech/decoders/swig/setup.py
+5
-2
deepspeech/exps/deepspeech2/bin/export.py
deepspeech/exps/deepspeech2/bin/export.py
+3
-0
deepspeech/exps/deepspeech2/bin/test.py
deepspeech/exps/deepspeech2/bin/test.py
+3
-0
deepspeech/exps/u2/bin/alignment.py
deepspeech/exps/u2/bin/alignment.py
+3
-0
deepspeech/exps/u2/bin/export.py
deepspeech/exps/u2/bin/export.py
+3
-0
deepspeech/exps/u2/bin/test.py
deepspeech/exps/u2/bin/test.py
+3
-0
deepspeech/exps/u2_kaldi/bin/test.py
deepspeech/exps/u2_kaldi/bin/test.py
+10
-0
deepspeech/exps/u2_kaldi/model.py
deepspeech/exps/u2_kaldi/model.py
+26
-17
deepspeech/exps/u2_st/bin/export.py
deepspeech/exps/u2_st/bin/export.py
+3
-0
deepspeech/exps/u2_st/bin/test.py
deepspeech/exps/u2_st/bin/test.py
+3
-0
deepspeech/frontend/augmentor/impulse_response.py
deepspeech/frontend/augmentor/impulse_response.py
+1
-1
deepspeech/frontend/augmentor/noise_perturb.py
deepspeech/frontend/augmentor/noise_perturb.py
+1
-1
deepspeech/frontend/augmentor/online_bayesian_normalization.py
...peech/frontend/augmentor/online_bayesian_normalization.py
+1
-1
deepspeech/frontend/augmentor/resample.py
deepspeech/frontend/augmentor/resample.py
+1
-1
deepspeech/frontend/augmentor/shift_perturb.py
deepspeech/frontend/augmentor/shift_perturb.py
+1
-1
deepspeech/frontend/augmentor/spec_augment.py
deepspeech/frontend/augmentor/spec_augment.py
+98
-19
deepspeech/frontend/augmentor/speed_perturb.py
deepspeech/frontend/augmentor/speed_perturb.py
+1
-1
deepspeech/frontend/augmentor/volume_perturb.py
deepspeech/frontend/augmentor/volume_perturb.py
+1
-1
deepspeech/frontend/featurizer/__init__.py
deepspeech/frontend/featurizer/__init__.py
+3
-0
deepspeech/frontend/featurizer/audio_featurizer.py
deepspeech/frontend/featurizer/audio_featurizer.py
+48
-38
deepspeech/frontend/featurizer/speech_featurizer.py
deepspeech/frontend/featurizer/speech_featurizer.py
+1
-1
deepspeech/frontend/featurizer/text_featurizer.py
deepspeech/frontend/featurizer/text_featurizer.py
+28
-45
deepspeech/frontend/normalizer.py
deepspeech/frontend/normalizer.py
+7
-7
deepspeech/frontend/utility.py
deepspeech/frontend/utility.py
+40
-10
deepspeech/io/collator.py
deepspeech/io/collator.py
+1
-2
deepspeech/io/collator_st.py
deepspeech/io/collator_st.py
+35
-69
deepspeech/io/converter.py
deepspeech/io/converter.py
+1
-1
deepspeech/io/dataloader.py
deepspeech/io/dataloader.py
+18
-6
deepspeech/io/dataset.py
deepspeech/io/dataset.py
+6
-5
deepspeech/models/ds2_online/deepspeech2.py
deepspeech/models/ds2_online/deepspeech2.py
+16
-14
deepspeech/models/u2_st.py
deepspeech/models/u2_st.py
+7
-7
deepspeech/training/cli.py
deepspeech/training/cli.py
+16
-20
deepspeech/training/extensions/__init__.py
deepspeech/training/extensions/__init__.py
+41
-0
deepspeech/training/extensions/evaluator.py
deepspeech/training/extensions/evaluator.py
+71
-0
deepspeech/training/extensions/extension.py
deepspeech/training/extensions/extension.py
+52
-0
deepspeech/training/extensions/snapshot.py
deepspeech/training/extensions/snapshot.py
+114
-0
deepspeech/training/extensions/visualizer.py
deepspeech/training/extensions/visualizer.py
+37
-0
deepspeech/training/reporter.py
deepspeech/training/reporter.py
+144
-0
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+10
-3
deepspeech/training/triggers/__init__.py
deepspeech/training/triggers/__init__.py
+28
-0
deepspeech/training/triggers/interval_trigger.py
deepspeech/training/triggers/interval_trigger.py
+38
-0
deepspeech/training/triggers/limit_trigger.py
deepspeech/training/triggers/limit_trigger.py
+31
-0
deepspeech/training/triggers/time_trigger.py
deepspeech/training/triggers/time_trigger.py
+32
-0
deepspeech/training/updaters/__init__.py
deepspeech/training/updaters/__init__.py
+13
-0
deepspeech/training/updaters/standard_updater.py
deepspeech/training/updaters/standard_updater.py
+192
-0
deepspeech/training/updaters/trainer.py
deepspeech/training/updaters/trainer.py
+184
-0
deepspeech/training/updaters/updater.py
deepspeech/training/updaters/updater.py
+83
-0
deepspeech/utils/utility.py
deepspeech/utils/utility.py
+11
-1
examples/aishell/s0/README.md
examples/aishell/s0/README.md
+7
-1
examples/aishell/s0/conf/augmentation.json
examples/aishell/s0/conf/augmentation.json
+5
-3
examples/aishell/s0/local/train.sh
examples/aishell/s0/local/train.sh
+11
-1
examples/aishell/s1/conf/augmentation.json
examples/aishell/s1/conf/augmentation.json
+3
-1
examples/aishell/s1/local/train.sh
examples/aishell/s1/local/train.sh
+11
-1
examples/aug_conf/augmentation.json
examples/aug_conf/augmentation.json
+0
-10
examples/augmentation/augmentation.json
examples/augmentation/augmentation.json
+6
-4
examples/callcenter/s1/conf/augmentation.json
examples/callcenter/s1/conf/augmentation.json
+2
-1
examples/callcenter/s1/local/train.sh
examples/callcenter/s1/local/train.sh
+11
-1
examples/librispeech/s0/conf/augmentation.json
examples/librispeech/s0/conf/augmentation.json
+5
-3
examples/librispeech/s0/local/train.sh
examples/librispeech/s0/local/train.sh
+11
-1
examples/librispeech/s1/conf/augmentation.json
examples/librispeech/s1/conf/augmentation.json
+3
-1
examples/librispeech/s1/local/train.sh
examples/librispeech/s1/local/train.sh
+11
-1
examples/librispeech/s2/conf/augmentation.json
examples/librispeech/s2/conf/augmentation.json
+6
-4
examples/librispeech/s2/conf/transformer.yaml
examples/librispeech/s2/conf/transformer.yaml
+15
-24
examples/librispeech/s2/local/align.sh
examples/librispeech/s2/local/align.sh
+8
-5
examples/librispeech/s2/local/export.sh
examples/librispeech/s2/local/export.sh
+2
-1
examples/librispeech/s2/local/test.sh
examples/librispeech/s2/local/test.sh
+12
-7
examples/librispeech/s2/local/train.sh
examples/librispeech/s2/local/train.sh
+11
-1
examples/librispeech/s2/run.sh
examples/librispeech/s2/run.sh
+3
-2
examples/ted_en_zh/t0/local/train.sh
examples/ted_en_zh/t0/local/train.sh
+11
-1
examples/thchs30/a0/local/data.sh
examples/thchs30/a0/local/data.sh
+22
-16
examples/thchs30/a0/local/gen_word2phone.py
examples/thchs30/a0/local/gen_word2phone.py
+38
-18
examples/thchs30/a0/local/reorganize_thchs30.py
examples/thchs30/a0/local/reorganize_thchs30.py
+5
-4
examples/thchs30/a0/run.sh
examples/thchs30/a0/run.sh
+9
-6
examples/timit/s1/conf/augmentation.json
examples/timit/s1/conf/augmentation.json
+3
-1
examples/timit/s1/local/train.sh
examples/timit/s1/local/train.sh
+11
-1
examples/tiny/s0/conf/augmentation.json
examples/tiny/s0/conf/augmentation.json
+26
-0
examples/tiny/s0/local/train.sh
examples/tiny/s0/local/train.sh
+11
-1
examples/tiny/s1/conf/augmentation.json
examples/tiny/s1/conf/augmentation.json
+3
-1
examples/tiny/s1/local/train.sh
examples/tiny/s1/local/train.sh
+11
-1
requirements.txt
requirements.txt
+2
-0
tools/extras/install_mfa.sh
tools/extras/install_mfa.sh
+1
-1
未找到文件。
.flake8
浏览文件 @
03a50d7b
...
...
@@ -42,6 +42,10 @@ ignore =
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
per-file-ignores =
*/__init__.py: F401
# Specify the list of error codes you wish Flake8 to report.
select =
E,
...
...
deepspeech/__init__.py
浏览文件 @
03a50d7b
...
...
@@ -352,45 +352,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!"
)
setattr
(
paddle
.
Tensor
,
'tolist'
,
tolist
)
########### hcak paddle.nn.functional #############
def
glu
(
x
:
paddle
.
Tensor
,
axis
=-
1
)
->
paddle
.
Tensor
:
"""The gated linear unit (GLU) activation."""
a
,
b
=
x
.
split
(
2
,
axis
=
axis
)
act_b
=
F
.
sigmoid
(
b
)
return
a
*
act_b
if
not
hasattr
(
paddle
.
nn
.
functional
,
'glu'
):
logger
.
warn
(
"register user glu to paddle.nn.functional, remove this when fixed!"
)
setattr
(
paddle
.
nn
.
functional
,
'glu'
,
glu
)
# def softplus(x):
# """Softplus function."""
# if hasattr(paddle.nn.functional, 'softplus'):
# #return paddle.nn.functional.softplus(x.float()).type_as(x)
# return paddle.nn.functional.softplus(x)
# else:
# raise NotImplementedError
# def gelu_accurate(x):
# """Gaussian Error Linear Units (GELU) activation."""
# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
# if not hasattr(gelu_accurate, "_a"):
# gelu_accurate._a = math.sqrt(2 / math.pi)
# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
# (x + 0.044715 * paddle.pow(x, 3))))
# def gelu(x):
# """Gaussian Error Linear Units (GELU) activation."""
# if hasattr(nn.functional, 'gelu'):
# #return nn.functional.gelu(x.float()).type_as(x)
# return nn.functional.gelu(x)
# else:
# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
########### hcak paddle.nn #############
class
GLU
(
nn
.
Layer
):
...
...
@@ -401,7 +362,7 @@ class GLU(nn.Layer):
self
.
dim
=
dim
def
forward
(
self
,
xs
):
return
glu
(
xs
,
dim
=
self
.
dim
)
return
F
.
glu
(
xs
,
dim
=
self
.
dim
)
if
not
hasattr
(
paddle
.
nn
,
'GLU'
):
...
...
deepspeech/decoders/swig/setup.py
浏览文件 @
03a50d7b
...
...
@@ -83,10 +83,13 @@ FILES = glob.glob('kenlm/util/*.cc') \
FILES
+=
glob
.
glob
(
'openfst-1.6.3/src/lib/*.cc'
)
# yapf: disable
FILES
=
[
fn
for
fn
in
FILES
if
not
(
fn
.
endswith
(
'main.cc'
)
or
fn
.
endswith
(
'test.cc'
)
or
fn
.
endswith
(
'unittest.cc'
))
fn
for
fn
in
FILES
if
not
(
fn
.
endswith
(
'main.cc'
)
or
fn
.
endswith
(
'test.cc'
)
or
fn
.
endswith
(
'unittest.cc'
))
]
# yapf: enable
LIBS
=
[
'stdc++'
]
if
platform
.
system
()
!=
'Darwin'
:
...
...
deepspeech/exps/deepspeech2/bin/export.py
浏览文件 @
03a50d7b
...
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
parser
.
add_argument
(
"--model_type"
)
args
=
parser
.
parse_args
()
if
args
.
model_type
is
None
:
...
...
deepspeech/exps/deepspeech2/bin/test.py
浏览文件 @
03a50d7b
...
...
@@ -31,6 +31,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
.
add_argument
(
"--model_type"
)
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
if
args
.
model_type
is
None
:
...
...
deepspeech/exps/u2/bin/alignment.py
浏览文件 @
03a50d7b
...
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2/bin/export.py
浏览文件 @
03a50d7b
...
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2/bin/test.py
浏览文件 @
03a50d7b
...
...
@@ -34,6 +34,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2_kaldi/bin/test.py
浏览文件 @
03a50d7b
...
...
@@ -14,6 +14,8 @@
"""Evaluation for U2 model."""
import
cProfile
from
yacs.config
import
CfgNode
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.utils.utility
import
print_arguments
...
...
@@ -53,6 +55,14 @@ if __name__ == "__main__":
type
=
str
,
default
=
'test'
,
help
=
'run mode, e.g. test, align, export'
)
parser
.
add_argument
(
'--dict-path'
,
type
=
str
,
default
=
None
,
help
=
'dict path.'
)
# save asr result to
parser
.
add_argument
(
"--result-file"
,
type
=
str
,
help
=
"path of save the asr result"
)
# save jit model to
parser
.
add_argument
(
"--export-path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2_kaldi/model.py
浏览文件 @
03a50d7b
...
...
@@ -25,6 +25,8 @@ import paddle
from
paddle
import
distributed
as
dist
from
yacs.config
import
CfgNode
from
deepspeech.frontend.featurizer
import
TextFeaturizer
from
deepspeech.frontend.utility
import
load_dict
from
deepspeech.io.dataloader
import
BatchDataLoader
from
deepspeech.models.u2
import
U2Model
from
deepspeech.training.optimizer
import
OptimizerFactory
...
...
@@ -80,8 +82,8 @@ class U2Trainer(Trainer):
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
train_conf
=
self
.
config
.
training
start
=
time
.
time
()
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
text_len
)
# loss div by `batch_size * accum_grad`
...
...
@@ -124,6 +126,7 @@ class U2Trainer(Trainer):
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
total_loss
=
0.0
for
i
,
batch
in
enumerate
(
self
.
valid_loader
):
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
...
...
@@ -168,10 +171,7 @@ class U2Trainer(Trainer):
if
from_scratch
:
# save init model, i.e. 0 epoch
self
.
save
(
tag
=
'init'
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
if
self
.
parallel
:
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
...
...
@@ -225,7 +225,7 @@ class U2Trainer(Trainer):
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
mini_batch_size
=
1
,
mini_batch_size
=
self
.
args
.
nprocs
,
batch_count
=
'auto'
,
batch_bins
=
0
,
batch_frames_in
=
0
,
...
...
@@ -244,7 +244,7 @@ class U2Trainer(Trainer):
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
mini_batch_size
=
1
,
mini_batch_size
=
self
.
args
.
nprocs
,
batch_count
=
'auto'
,
batch_bins
=
0
,
batch_frames_in
=
0
,
...
...
@@ -260,7 +260,7 @@ class U2Trainer(Trainer):
json_file
=
config
.
data
.
test_manifest
,
train_mode
=
False
,
sortagrad
=
False
,
batch_size
=
config
.
collator
.
batch_size
,
batch_size
=
config
.
decoding
.
batch_size
,
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
...
...
@@ -279,7 +279,7 @@ class U2Trainer(Trainer):
json_file
=
config
.
data
.
test_manifest
,
train_mode
=
False
,
sortagrad
=
False
,
batch_size
=
config
.
collator
.
batch_size
,
batch_size
=
config
.
decoding
.
batch_size
,
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
...
...
@@ -305,10 +305,8 @@ class U2Trainer(Trainer):
model_conf
.
output_dim
=
self
.
train_loader
.
vocab_size
model_conf
.
freeze
()
model
=
U2Model
.
from_config
(
model_conf
)
if
self
.
parallel
:
model
=
paddle
.
DataParallel
(
model
)
logger
.
info
(
f
"
{
model
}
"
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
...
...
@@ -379,13 +377,13 @@ class U2Tester(U2Trainer):
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
def
ordid2token
(
self
,
texts
,
texts_len
):
def
id2token
(
self
,
texts
,
texts_len
,
text_feature
):
""" ord() id to chr() chr """
trans
=
[]
for
text
,
n
in
zip
(
texts
,
texts_len
):
n
=
n
.
numpy
().
item
()
ids
=
text
[:
n
]
trans
.
append
(
''
.
join
([
chr
(
i
)
for
i
in
ids
]
))
trans
.
append
(
text_feature
.
defeaturize
(
ids
.
numpy
().
tolist
()
))
return
trans
def
compute_metrics
(
self
,
...
...
@@ -401,8 +399,11 @@ class U2Tester(U2Trainer):
error_rate_func
=
error_rate
.
cer
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
wer
start_time
=
time
.
time
()
text_feature
=
self
.
test_loader
.
collate_fn
.
text_feature
target_transcripts
=
self
.
ordid2token
(
texts
,
texts_len
)
text_feature
=
TextFeaturizer
(
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab_filepath
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
target_transcripts
=
self
.
id2token
(
texts
,
texts_len
,
text_feature
)
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio_len
,
...
...
@@ -450,7 +451,7 @@ class U2Tester(U2Trainer):
self
.
model
.
eval
()
logger
.
info
(
f
"Test Total Examples:
{
len
(
self
.
test_loader
.
dataset
)
}
"
)
stride_ms
=
self
.
test_loader
.
collate_fn
.
stride_ms
stride_ms
=
self
.
config
.
collator
.
stride_ms
error_rate_type
=
None
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
num_frames
=
0.0
...
...
@@ -525,8 +526,9 @@ class U2Tester(U2Trainer):
self
.
model
.
eval
()
logger
.
info
(
f
"Align Total Examples:
{
len
(
self
.
align_loader
.
dataset
)
}
"
)
stride_ms
=
self
.
config
.
collate
.
stride_ms
token_dict
=
self
.
align_loader
.
collate_fn
.
vocab_list
stride_ms
=
self
.
config
.
collater
.
stride_ms
token_dict
=
self
.
args
.
char_list
with
open
(
self
.
args
.
result_file
,
'w'
)
as
fout
:
# one example in batch
for
i
,
batch
in
enumerate
(
self
.
align_loader
):
...
...
@@ -613,6 +615,11 @@ class U2Tester(U2Trainer):
except
KeyboardInterrupt
:
sys
.
exit
(
-
1
)
def
setup_dict
(
self
):
# load dictionary for debug log
self
.
args
.
char_list
=
load_dict
(
self
.
args
.
dict_path
,
"maskctc"
in
self
.
args
.
model_name
)
def
setup
(
self
):
"""Setup the experiment.
"""
...
...
@@ -624,6 +631,8 @@ class U2Tester(U2Trainer):
self
.
setup_dataloader
()
self
.
setup_model
()
self
.
setup_dict
()
self
.
iteration
=
0
self
.
epoch
=
0
...
...
deepspeech/exps/u2_st/bin/export.py
浏览文件 @
03a50d7b
...
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2_st/bin/test.py
浏览文件 @
03a50d7b
...
...
@@ -34,6 +34,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
...
...
deepspeech/frontend/augmentor/impulse_response.py
浏览文件 @
03a50d7b
...
...
@@ -32,7 +32,7 @@ class ImpulseResponseAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
return
x
...
...
deepspeech/frontend/augmentor/noise_perturb.py
浏览文件 @
03a50d7b
...
...
@@ -38,7 +38,7 @@ class NoisePerturbAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
return
x
...
...
deepspeech/frontend/augmentor/online_bayesian_normalization.py
浏览文件 @
03a50d7b
...
...
@@ -46,7 +46,7 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
return
x
...
...
deepspeech/frontend/augmentor/resample.py
浏览文件 @
03a50d7b
...
...
@@ -33,7 +33,7 @@ class ResampleAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
return
x
...
...
deepspeech/frontend/augmentor/shift_perturb.py
浏览文件 @
03a50d7b
...
...
@@ -33,7 +33,7 @@ class ShiftPerturbAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
return
x
...
...
deepspeech/frontend/augmentor/spec_augment.py
浏览文件 @
03a50d7b
...
...
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the volume perturb augmentation model."""
import
random
import
numpy
as
np
from
PIL
import
Image
from
PIL.Image
import
BICUBIC
from
deepspeech.frontend.augmentor.base
import
AugmentorBase
from
deepspeech.utils.log
import
Log
...
...
@@ -41,7 +45,9 @@ class SpecAugmentor(AugmentorBase):
W
=
40
,
adaptive_number_ratio
=
0
,
adaptive_size_ratio
=
0
,
max_n_time_masks
=
20
):
max_n_time_masks
=
20
,
replace_with_zero
=
True
,
warp_mode
=
'PIL'
):
"""SpecAugment class.
Args:
rng (random.Random): random generator object.
...
...
@@ -54,10 +60,16 @@ class SpecAugmentor(AugmentorBase):
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
adaptive_size_ratio (float): adaptive size ratio for time masking
max_n_time_masks (int): maximum number of time masking
replace_with_zero (bool): pad zero on mask if true else use mean
warp_mode (str): "PIL" (default, fast, not differentiable)
or "sparse_image_warp" (slow, differentiable)
"""
super
().
__init__
()
self
.
_rng
=
rng
self
.
inplace
=
True
self
.
replace_with_zero
=
replace_with_zero
self
.
mode
=
warp_mode
self
.
W
=
W
self
.
F
=
F
self
.
T
=
T
...
...
@@ -123,21 +135,83 @@ class SpecAugmentor(AugmentorBase):
def
__repr__
(
self
):
return
f
"specaug: F-
{
F
}
, T-
{
T
}
, F-n-
{
n_freq_masks
}
, T-n-
{
n_time_masks
}
"
def
time_warp
(
xs
,
W
=
40
):
raise
NotImplementedError
def
time_warp
(
self
,
x
,
mode
=
'PIL'
):
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
Args:
x (np.ndarray): spectrogram (time, freq)
mode (str): PIL or sparse_image_warp
Raises:
NotImplementedError: [description]
NotImplementedError: [description]
Returns:
np.ndarray: time warped spectrogram (time, freq)
"""
window
=
max_time_warp
=
self
.
W
if
window
==
0
:
return
x
if
mode
==
"PIL"
:
t
=
x
.
shape
[
0
]
if
t
-
window
<=
window
:
return
x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center
=
random
.
randrange
(
window
,
t
-
window
)
warped
=
random
.
randrange
(
center
-
window
,
center
+
window
)
+
1
# 1 ... t - 1
left
=
Image
.
fromarray
(
x
[:
center
]).
resize
((
x
.
shape
[
1
],
warped
),
BICUBIC
)
right
=
Image
.
fromarray
(
x
[
center
:]).
resize
((
x
.
shape
[
1
],
t
-
warped
),
BICUBIC
)
if
self
.
inplace
:
x
[:
warped
]
=
left
x
[
warped
:]
=
right
return
x
return
np
.
concatenate
((
left
,
right
),
0
)
elif
mode
==
"sparse_image_warp"
:
raise
NotImplementedError
(
'sparse_image_warp'
)
else
:
raise
NotImplementedError
(
"unknown resize mode: "
+
mode
+
", choose one from (PIL, sparse_image_warp)."
)
def
mask_freq
(
self
,
x
,
replace_with_zero
=
False
):
"""freq mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
def
mask_freq
(
self
,
xs
,
replace_with_zero
=
False
):
n_bins
=
xs
.
shape
[
0
]
Returns:
np.ndarray: freq mask spectrogram (time, freq)
"""
n_bins
=
x
.
shape
[
1
]
for
i
in
range
(
0
,
self
.
n_freq_masks
):
f
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
self
.
F
))
f_0
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
n_bins
-
f
))
xs
[
f_0
:
f_0
+
f
,
:]
=
0
assert
f_0
<=
f_0
+
f
if
replace_with_zero
:
x
[:,
f_0
:
f_0
+
f
]
=
0
else
:
x
[:,
f_0
:
f_0
+
f
]
=
x
.
mean
()
self
.
_freq_mask
=
(
f_0
,
f_0
+
f
)
return
x
s
return
x
def
mask_time
(
self
,
xs
,
replace_with_zero
=
False
):
n_frames
=
xs
.
shape
[
1
]
def
mask_time
(
self
,
x
,
replace_with_zero
=
False
):
"""time mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
Returns:
np.ndarray: time mask spectrogram (time, freq)
"""
n_frames
=
x
.
shape
[
0
]
if
self
.
adaptive_number_ratio
>
0
:
n_masks
=
int
(
n_frames
*
self
.
adaptive_number_ratio
)
...
...
@@ -154,24 +228,29 @@ class SpecAugmentor(AugmentorBase):
t
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
T
))
t
=
min
(
t
,
int
(
n_frames
*
self
.
p
))
t_0
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
n_frames
-
t
))
xs
[:,
t_0
:
t_0
+
t
]
=
0
assert
t_0
<=
t_0
+
t
if
replace_with_zero
:
x
[
t_0
:
t_0
+
t
,
:]
=
0
else
:
x
[
t_0
:
t_0
+
t
,
:]
=
x
.
mean
()
self
.
_time_mask
=
(
t_0
,
t_0
+
t
)
return
x
s
return
x
def
__call__
(
self
,
x
,
train
=
True
):
if
not
train
:
return
return
x
return
self
.
transform_feature
(
x
)
def
transform_feature
(
self
,
x
s
:
np
.
ndarray
):
def
transform_feature
(
self
,
x
:
np
.
ndarray
):
"""
Args:
x
s (FloatTensor): `[F, T
]`
x
(np.ndarray): `[T, F
]`
Returns:
x
s (FloatTensor): `[F, T
]`
x
(np.ndarray): `[T, F
]`
"""
# xs = self.time_warp(xs)
xs
=
self
.
mask_freq
(
xs
)
xs
=
self
.
mask_time
(
xs
)
return
xs
assert
isinstance
(
x
,
np
.
ndarray
)
assert
x
.
ndim
==
2
x
=
self
.
time_warp
(
x
,
self
.
mode
)
x
=
self
.
mask_freq
(
x
,
self
.
replace_with_zero
)
x
=
self
.
mask_time
(
x
,
self
.
replace_with_zero
)
return
x
deepspeech/frontend/augmentor/speed_perturb.py
浏览文件 @
03a50d7b
...
...
@@ -81,7 +81,7 @@ class SpeedPerturbAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
return
x
...
...
deepspeech/frontend/augmentor/volume_perturb.py
浏览文件 @
03a50d7b
...
...
@@ -39,7 +39,7 @@ class VolumePerturbAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
return
x
...
...
deepspeech/frontend/featurizer/__init__.py
浏览文件 @
03a50d7b
...
...
@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.audio_featurizer
import
AudioFeaturizer
#noqa: F401
from
.speech_featurizer
import
SpeechFeaturizer
from
.text_featurizer
import
TextFeaturizer
deepspeech/frontend/featurizer/audio_featurizer.py
浏览文件 @
03a50d7b
...
...
@@ -18,7 +18,7 @@ from python_speech_features import logfbank
from
python_speech_features
import
mfcc
class
AudioFeaturizer
(
object
):
class
AudioFeaturizer
():
"""Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment.
...
...
@@ -167,32 +167,6 @@ class AudioFeaturizer(object):
raise
ValueError
(
"Unknown specgram_type %s. "
"Supported values: linear."
%
self
.
_specgram_type
)
def
_compute_linear_specgram
(
self
,
samples
,
sample_rate
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
eps
=
1e-14
):
"""Compute the linear spectrogram from FFT energy."""
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
raise
ValueError
(
"max_freq must not be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
specgram
,
freqs
=
self
.
_specgram_real
(
samples
,
window_size
=
window_size
,
stride_size
=
stride_size
,
sample_rate
=
sample_rate
)
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
return
np
.
log
(
specgram
[:
ind
,
:]
+
eps
)
def
_specgram_real
(
self
,
samples
,
window_size
,
stride_size
,
sample_rate
):
"""Compute the spectrogram for samples from a real signal."""
# extract strided windows
...
...
@@ -217,26 +191,65 @@ class AudioFeaturizer(object):
freqs
=
float
(
sample_rate
)
/
window_size
*
np
.
arange
(
fft
.
shape
[
0
])
return
fft
,
freqs
def
_compute_linear_specgram
(
self
,
samples
,
sample_rate
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
eps
=
1e-14
):
"""Compute the linear spectrogram from FFT energy.
Args:
samples ([type]): [description]
sample_rate ([type]): [description]
stride_ms (float, optional): [description]. Defaults to 10.0.
window_ms (float, optional): [description]. Defaults to 20.0.
max_freq ([type], optional): [description]. Defaults to None.
eps ([type], optional): [description]. Defaults to 1e-14.
Raises:
ValueError: [description]
ValueError: [description]
Returns:
np.ndarray: log spectrogram, (time, freq)
"""
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
raise
ValueError
(
"max_freq must not be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
specgram
,
freqs
=
self
.
_specgram_real
(
samples
,
window_size
=
window_size
,
stride_size
=
stride_size
,
sample_rate
=
sample_rate
)
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
# (freq, time)
spec
=
np
.
log
(
specgram
[:
ind
,
:]
+
eps
)
return
np
.
transpose
(
spec
)
def
_concat_delta_delta
(
self
,
feat
):
"""append delat, delta-delta feature.
Args:
feat (np.ndarray): (
D, T
)
feat (np.ndarray): (
T, D
)
Returns:
np.ndarray: feat with delta-delta, (
3*D, T
)
np.ndarray: feat with delta-delta, (
T, 3*D
)
"""
feat
=
np
.
transpose
(
feat
)
# Deltas
d_feat
=
delta
(
feat
,
2
)
# Deltas-Deltas
dd_feat
=
delta
(
feat
,
2
)
# transpose
feat
=
np
.
transpose
(
feat
)
d_feat
=
np
.
transpose
(
d_feat
)
dd_feat
=
np
.
transpose
(
dd_feat
)
# concat above three features
concat_feat
=
np
.
concatenate
((
feat
,
d_feat
,
dd_feat
))
concat_feat
=
np
.
concatenate
((
feat
,
d_feat
,
dd_feat
)
,
axis
=
1
)
return
concat_feat
def
_compute_mfcc
(
self
,
...
...
@@ -292,7 +305,6 @@ class AudioFeaturizer(object):
ceplifter
=
22
,
useEnergy
=
True
,
winfunc
=
'povey'
)
mfcc_feat
=
np
.
transpose
(
mfcc_feat
)
if
delta_delta
:
mfcc_feat
=
self
.
_concat_delta_delta
(
mfcc_feat
)
return
mfcc_feat
...
...
@@ -346,8 +358,6 @@ class AudioFeaturizer(object):
remove_dc_offset
=
True
,
preemph
=
0.97
,
wintype
=
'povey'
)
fbank_feat
=
np
.
transpose
(
fbank_feat
)
if
delta_delta
:
fbank_feat
=
self
.
_concat_delta_delta
(
fbank_feat
)
return
fbank_feat
deepspeech/frontend/featurizer/speech_featurizer.py
浏览文件 @
03a50d7b
...
...
@@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
from
deepspeech.frontend.featurizer.text_featurizer
import
TextFeaturizer
class
SpeechFeaturizer
(
object
):
class
SpeechFeaturizer
():
"""Speech featurizer, for extracting features from both audio and transcript
contents of SpeechSegment.
...
...
deepspeech/frontend/featurizer/text_featurizer.py
浏览文件 @
03a50d7b
...
...
@@ -14,12 +14,19 @@
"""Contains the text featurizer class."""
import
sentencepiece
as
spm
from
deepspeech.frontend.utility
import
EOS
from
deepspeech.frontend.utility
import
UNK
from
..utility
import
EOS
from
..utility
import
load_dict
from
..utility
import
UNK
__all__
=
[
"TextFeaturizer"
]
class
TextFeaturizer
(
object
):
def
__init__
(
self
,
unit_type
,
vocab_filepath
,
spm_model_prefix
=
None
):
class
TextFeaturizer
():
def
__init__
(
self
,
unit_type
,
vocab_filepath
,
spm_model_prefix
=
None
,
maskctc
=
False
):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
...
...
@@ -34,11 +41,12 @@ class TextFeaturizer(object):
assert
unit_type
in
(
'char'
,
'spm'
,
'word'
)
self
.
unit_type
=
unit_type
self
.
unk
=
UNK
self
.
maskctc
=
maskctc
if
vocab_filepath
:
self
.
_vocab_dict
,
self
.
_id2token
,
self
.
_vocab_list
=
self
.
_load_vocabulary_from_file
(
vocab_filepath
)
self
.
unk_id
=
self
.
_vocab_list
.
index
(
self
.
unk
)
self
.
eos_id
=
self
.
_vocab_list
.
index
(
EOS
)
self
.
vocab_dict
,
self
.
_id2token
,
self
.
vocab_list
,
self
.
unk_id
,
self
.
eos_id
=
self
.
_load_vocabulary_from_file
(
vocab_filepath
,
maskctc
)
self
.
vocab_size
=
len
(
self
.
vocab_list
)
if
unit_type
==
'spm'
:
spm_model
=
spm_model_prefix
+
'.model'
...
...
@@ -67,7 +75,7 @@ class TextFeaturizer(object):
"""Convert text string to a list of token indices.
Args:
text (str): Text
to process
.
text (str): Text.
Returns:
List[int]: List of token indices.
...
...
@@ -75,8 +83,8 @@ class TextFeaturizer(object):
tokens
=
self
.
tokenize
(
text
)
ids
=
[]
for
token
in
tokens
:
token
=
token
if
token
in
self
.
_
vocab_dict
else
self
.
unk
ids
.
append
(
self
.
_
vocab_dict
[
token
])
token
=
token
if
token
in
self
.
vocab_dict
else
self
.
unk
ids
.
append
(
self
.
vocab_dict
[
token
])
return
ids
def
defeaturize
(
self
,
idxs
):
...
...
@@ -87,7 +95,7 @@ class TextFeaturizer(object):
idxs (List[int]): List of token indices.
Returns:
str: Text
to process
.
str: Text.
"""
tokens
=
[]
for
idx
in
idxs
:
...
...
@@ -97,33 +105,6 @@ class TextFeaturizer(object):
text
=
self
.
detokenize
(
tokens
)
return
text
@
property
def
vocab_size
(
self
):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return
len
(
self
.
_vocab_list
)
@
property
def
vocab_list
(
self
):
"""Return the vocabulary in list.
Returns:
List[str]: tokens.
"""
return
self
.
_vocab_list
@
property
def
vocab_dict
(
self
):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]: token str -> int
"""
return
self
.
_vocab_dict
def
char_tokenize
(
self
,
text
):
"""Character tokenizer.
...
...
@@ -206,14 +187,16 @@ class TextFeaturizer(object):
return
decode
(
tokens
)
def
_load_vocabulary_from_file
(
self
,
vocab_filepath
):
def
_load_vocabulary_from_file
(
self
,
vocab_filepath
:
str
,
maskctc
:
bool
):
"""Load vocabulary from file."""
vocab_lines
=
[]
with
open
(
vocab_filepath
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
vocab_lines
.
extend
(
file
.
readlines
())
vocab_list
=
[
line
[:
-
1
]
for
line
in
vocab_lines
]
vocab_list
=
load_dict
(
vocab_filepath
,
maskctc
)
assert
vocab_list
is
not
None
id2token
=
dict
(
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
token2id
=
dict
(
[(
token
,
idx
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
return
token2id
,
id2token
,
vocab_list
unk_id
=
vocab_list
.
index
(
UNK
)
eos_id
=
vocab_list
.
index
(
EOS
)
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
deepspeech/frontend/normalizer.py
浏览文件 @
03a50d7b
...
...
@@ -40,21 +40,21 @@ class CollateFunc(object):
number
=
0
for
item
in
batch
:
audioseg
=
AudioSegment
.
from_file
(
item
[
'feat'
])
feat
=
self
.
feature_func
(
audioseg
)
#(
D, T
)
feat
=
self
.
feature_func
(
audioseg
)
#(
T, D
)
sums
=
np
.
sum
(
feat
,
axis
=
1
)
sums
=
np
.
sum
(
feat
,
axis
=
0
)
if
mean_stat
is
None
:
mean_stat
=
sums
else
:
mean_stat
+=
sums
square_sums
=
np
.
sum
(
np
.
square
(
feat
),
axis
=
1
)
square_sums
=
np
.
sum
(
np
.
square
(
feat
),
axis
=
0
)
if
var_stat
is
None
:
var_stat
=
square_sums
else
:
var_stat
+=
square_sums
number
+=
feat
.
shape
[
1
]
number
+=
feat
.
shape
[
0
]
return
number
,
mean_stat
,
var_stat
...
...
@@ -120,7 +120,7 @@ class FeatureNormalizer(object):
"""Normalize features to be of zero mean and unit stddev.
:param features: Input features to be normalized.
:type features: ndarray, shape (
D, T
)
:type features: ndarray, shape (
T, D
)
:param eps: added to stddev to provide numerical stablibity.
:type eps: float
:return: Normalized features.
...
...
@@ -131,8 +131,8 @@ class FeatureNormalizer(object):
def
_read_mean_std_from_file
(
self
,
filepath
,
eps
=
1e-20
):
"""Load mean and std from file."""
mean
,
istd
=
load_cmvn
(
filepath
,
filetype
=
'json'
)
self
.
_mean
=
np
.
expand_dims
(
mean
,
axis
=
-
1
)
self
.
_istd
=
np
.
expand_dims
(
istd
,
axis
=
-
1
)
self
.
_mean
=
np
.
expand_dims
(
mean
,
axis
=
0
)
self
.
_istd
=
np
.
expand_dims
(
istd
,
axis
=
0
)
def
write_to_file
(
self
,
filepath
):
"""Write the mean and stddev to the file.
...
...
deepspeech/frontend/utility.py
浏览文件 @
03a50d7b
...
...
@@ -15,6 +15,9 @@
import
codecs
import
json
import
math
from
typing
import
List
from
typing
import
Optional
from
typing
import
Text
import
numpy
as
np
...
...
@@ -23,16 +26,35 @@ from deepspeech.utils.log import Log
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
"load_
cmvn"
,
"read_manifest"
,
"rms_to_db"
,
"rms_to_dbfs"
,
"max
_dbfs"
,
"m
ean_dbfs"
,
"gain_db_to_ratio"
,
"normalize_audio"
,
"SOS"
,
"EOS"
,
"UNK
"
,
"
BLANK
"
"load_
dict"
,
"load_cmvn"
,
"read_manifest"
,
"rms_to_db"
,
"rms_to
_dbfs"
,
"m
ax_dbfs"
,
"mean_dbfs"
,
"gain_db_to_ratio"
,
"normalize_audio"
,
"SOS
"
,
"
EOS"
,
"UNK"
,
"BLANK"
,
"MASKCTC
"
]
IGNORE_ID
=
-
1
SOS
=
"<sos/eos>"
# `sos` and `eos` using same token
SOS
=
"<eos>"
EOS
=
SOS
UNK
=
"<unk>"
BLANK
=
"<blank>"
MASKCTC
=
"<mask>"
def
load_dict
(
dict_path
:
Optional
[
Text
],
maskctc
=
False
)
->
Optional
[
List
[
Text
]]:
if
dict_path
is
None
:
return
None
with
open
(
dict_path
,
"r"
)
as
f
:
dictionary
=
f
.
readlines
()
char_list
=
[
entry
.
strip
().
split
(
" "
)[
0
]
for
entry
in
dictionary
]
if
BLANK
not
in
char_list
:
char_list
.
insert
(
0
,
BLANK
)
if
EOS
not
in
char_list
:
char_list
.
append
(
EOS
)
# for non-autoregressive maskctc model
if
maskctc
and
MASKCTC
not
in
char_list
:
char_list
.
append
(
MASKCTC
)
return
char_list
def
read_manifest
(
...
...
@@ -47,12 +69,20 @@ def read_manifest(
Args:
manifest_path ([type]): Manifest file to load and parse.
max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
max_input_len ([type], optional): maximum output seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to float('inf').
min_input_len (float, optional): minimum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to 0.0.
max_output_len (float, optional): maximum input seq length,
in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional):
maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional):
minimum output seq length/output seq length ratio. Defaults to 0.05.
Raises:
IOError: If failed to parse the manifest.
...
...
deepspeech/io/collator.py
浏览文件 @
03a50d7b
...
...
@@ -242,7 +242,6 @@ class SpeechCollator():
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
return
specgram
,
transcript_part
def
__call__
(
self
,
batch
):
...
...
@@ -250,7 +249,7 @@ class SpeechCollator():
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (
D, T
)
audio (np.ndarray) shape (
T, D
)
text (List[int] or str): shape (U,)
Returns:
...
...
deepspeech/io/collator_st.py
浏览文件 @
03a50d7b
...
...
@@ -217,6 +217,34 @@ class SpeechCollator():
return
self
.
_local_data
.
tar2object
[
tarpath
].
extractfile
(
self
.
_local_data
.
tar2info
[
tarpath
][
filename
])
@
property
def
manifest
(
self
):
return
self
.
_manifest
@
property
def
vocab_size
(
self
):
return
self
.
_speech_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
return
self
.
_speech_featurizer
.
vocab_list
@
property
def
vocab_dict
(
self
):
return
self
.
_speech_featurizer
.
vocab_dict
@
property
def
text_feature
(
self
):
return
self
.
_speech_featurizer
.
text_feature
@
property
def
feature_size
(
self
):
return
self
.
_speech_featurizer
.
feature_size
@
property
def
stride_ms
(
self
):
return
self
.
_speech_featurizer
.
stride_ms
def
process_utterance
(
self
,
audio_file
,
translation
):
"""Load, augment, featurize and normalize for speech data.
...
...
@@ -244,7 +272,6 @@ class SpeechCollator():
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
return
specgram
,
translation_part
def
__call__
(
self
,
batch
):
...
...
@@ -252,7 +279,7 @@ class SpeechCollator():
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (
D, T
)
audio (np.ndarray) shape (
T, D
)
text (List[int] or str): shape (U,)
Returns:
...
...
@@ -296,34 +323,6 @@ class SpeechCollator():
text_lens
=
np
.
array
(
text_lens
).
astype
(
np
.
int64
)
return
utts
,
padded_audios
,
audio_lens
,
padded_texts
,
text_lens
@
property
def
manifest
(
self
):
return
self
.
_manifest
@
property
def
vocab_size
(
self
):
return
self
.
_speech_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
return
self
.
_speech_featurizer
.
vocab_list
@
property
def
vocab_dict
(
self
):
return
self
.
_speech_featurizer
.
vocab_dict
@
property
def
text_feature
(
self
):
return
self
.
_speech_featurizer
.
text_feature
@
property
def
feature_size
(
self
):
return
self
.
_speech_featurizer
.
feature_size
@
property
def
stride_ms
(
self
):
return
self
.
_speech_featurizer
.
stride_ms
class
TripletSpeechCollator
(
SpeechCollator
):
def
process_utterance
(
self
,
audio_file
,
translation
,
transcript
):
...
...
@@ -355,7 +354,6 @@ class TripletSpeechCollator(SpeechCollator):
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
return
specgram
,
translation_part
,
transcript_part
def
__call__
(
self
,
batch
):
...
...
@@ -363,7 +361,7 @@ class TripletSpeechCollator(SpeechCollator):
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (
D, T
)
audio (np.ndarray) shape (
T, D
)
text (List[int] or str): shape (U,)
Returns:
...
...
@@ -524,49 +522,19 @@ class KaldiPrePorocessedCollator(SpeechCollator):
:rtype: tuple of (2darray, list)
"""
specgram
=
kaldiio
.
load_mat
(
audio_file
)
specgram
=
specgram
.
transpose
([
1
,
0
])
assert
specgram
.
shape
[
0
]
==
self
.
_feat_dim
,
'expect feat dim {}, but got {}'
.
format
(
self
.
_feat_dim
,
specgram
.
shape
[
0
])
1
]
==
self
.
_feat_dim
,
'expect feat dim {}, but got {}'
.
format
(
self
.
_feat_dim
,
specgram
.
shape
[
1
])
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
if
self
.
_keep_transcription_text
:
return
specgram
,
translation
else
:
text_ids
=
self
.
_text_featurizer
.
featurize
(
translation
)
return
specgram
,
text_ids
@
property
def
manifest
(
self
):
return
self
.
_manifest
@
property
def
vocab_size
(
self
):
return
self
.
_text_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
return
self
.
_text_featurizer
.
vocab_list
@
property
def
vocab_dict
(
self
):
return
self
.
_text_featurizer
.
vocab_dict
@
property
def
text_feature
(
self
):
return
self
.
_text_featurizer
@
property
def
feature_size
(
self
):
return
self
.
_feat_dim
@
property
def
stride_ms
(
self
):
return
self
.
_stride_ms
class
TripletKaldiPrePorocessedCollator
(
KaldiPrePorocessedCollator
):
def
process_utterance
(
self
,
audio_file
,
translation
,
transcript
):
...
...
@@ -583,15 +551,13 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
:rtype: tuple of (2darray, (list, list))
"""
specgram
=
kaldiio
.
load_mat
(
audio_file
)
specgram
=
specgram
.
transpose
([
1
,
0
])
assert
specgram
.
shape
[
0
]
==
self
.
_feat_dim
,
'expect feat dim {}, but got {}'
.
format
(
self
.
_feat_dim
,
specgram
.
shape
[
0
])
1
]
==
self
.
_feat_dim
,
'expect feat dim {}, but got {}'
.
format
(
self
.
_feat_dim
,
specgram
.
shape
[
1
])
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
if
self
.
_keep_transcription_text
:
return
specgram
,
translation
,
transcript
else
:
...
...
@@ -604,7 +570,7 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (
D, T
)
audio (np.ndarray) shape (
T, D
)
translation (List[int] or str): shape (U,)
transcription (List[int] or str): shape (V,)
...
...
deepspeech/io/converter.py
浏览文件 @
03a50d7b
...
...
@@ -43,7 +43,7 @@ class CustomConverter():
batch (list): The batch to transform.
Returns:
tuple(
paddle.Tensor, paddle.Tensor, paddle.Tensor
)
tuple(
np.ndarray, nn.ndarray, nn.ndarray
)
"""
# batch should be located in list
...
...
deepspeech/io/dataloader.py
浏览文件 @
03a50d7b
...
...
@@ -43,6 +43,18 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
return
feat_dim
,
vocab_size
def
batch_collate
(
x
):
"""de-tuple.
Args:
x (List[Tuple]): [(utts, xs, ilens, ys, olens)]
Returns:
Tuple: (utts, xs, ilens, ys, olens)
"""
return
x
[
0
]
class
BatchDataLoader
():
def
__init__
(
self
,
json_file
:
str
,
...
...
@@ -120,15 +132,15 @@ class BatchDataLoader():
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
self
.
dataset
=
TransformDataset
(
self
.
minibaches
,
lambda
data
:
self
.
converter
([
self
.
reader
(
data
,
return_uttid
=
True
)]))
self
.
dataset
=
TransformDataset
(
self
.
minibaches
,
self
.
converter
,
self
.
reader
)
self
.
dataloader
=
DataLoader
(
dataset
=
self
.
dataset
,
batch_size
=
1
,
shuffle
=
not
self
.
use_sortagrad
if
train_mode
else
False
,
collate_fn
=
lambda
x
:
x
[
0
]
,
num_workers
=
n_iter_processes
,
)
shuffle
=
not
self
.
use_sortagrad
if
self
.
train_mode
else
False
,
collate_fn
=
batch_collate
,
num_workers
=
self
.
n_iter_processes
,
)
def
__repr__
(
self
):
echo
=
f
"<
{
self
.
__class__
.
__module__
}
.
{
self
.
__class__
.
__name__
}
object at
{
hex
(
id
(
self
))
}
> "
...
...
deepspeech/io/dataset.py
浏览文件 @
03a50d7b
...
...
@@ -129,15 +129,16 @@ class TransformDataset(Dataset):
Args:
data: list object from make_batchset
transfrom: transform
function
converter: batch
function
reader: read data
"""
def
__init__
(
self
,
data
,
transform
):
def
__init__
(
self
,
data
,
converter
,
reader
):
"""Init function."""
super
().
__init__
()
self
.
data
=
data
self
.
transform
=
transform
self
.
converter
=
converter
self
.
reader
=
reader
def
__len__
(
self
):
"""Len function."""
...
...
@@ -145,4 +146,4 @@ class TransformDataset(Dataset):
def
__getitem__
(
self
,
idx
):
"""[] operator."""
return
self
.
transform
(
self
.
data
[
idx
])
return
self
.
converter
([
self
.
reader
(
self
.
data
[
idx
],
return_uttid
=
True
)
])
deepspeech/models/ds2_online/deepspeech2.py
浏览文件 @
03a50d7b
...
...
@@ -102,13 +102,13 @@ class CRNNEncoder(nn.Layer):
Args:
x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
init_state_c_box(Tensor): init_states c for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
Return
s
:
init_state_h_box(Tensor): init_states h for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
Return:
x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_h_box(Tensor): final_states h for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
if
init_state_h_box
is
not
None
:
init_state_list
=
None
...
...
@@ -142,7 +142,7 @@ class CRNNEncoder(nn.Layer):
if
self
.
use_gru
is
True
:
final_chunk_state_h_box
=
paddle
.
concat
(
final_chunk_state_list
,
axis
=
0
)
final_chunk_state_c_box
=
init_state_c_box
#paddle.zeros_like(final_chunk_state_h_box)
final_chunk_state_c_box
=
init_state_c_box
else
:
final_chunk_state_h_list
=
[
final_chunk_state_list
[
i
][
0
]
for
i
in
range
(
self
.
num_rnn_layers
)
...
...
@@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer):
x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder
Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size
,
[B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size
,
[B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
eouts_list (List of Tensor): The list of encoder outputs in chunk_size
:
[B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size
:
[B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
subsampling_rate
=
self
.
conv
.
subsampling_rate
receptive_field_length
=
self
.
conv
.
receptive_field_length
...
...
@@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer):
class
DeepSpeech2ModelOnline
(
nn
.
Layer
):
"""The DeepSpeech2 network structure for online.
:param audio
_data
: Audio spectrogram data layer.
:type audio
_data
: Variable
:param text
_data
: Transcription text data layer.
:type text
_data
: Variable
:param audio: Audio spectrogram data layer.
:type audio: Variable
:param text: Transcription text data layer.
:type text: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param feat_size: feature size for audio.
:type feat_size: int
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
...
...
deepspeech/models/u2_st.py
浏览文件 @
03a50d7b
...
...
@@ -417,32 +417,32 @@ class U2STBaseModel(nn.Layer):
best_hyps
=
best_hyps
[:,
1
:]
return
best_hyps
@
jit
.
export
@
jit
.
to_static
def
subsampling_rate
(
self
)
->
int
:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return
self
.
encoder
.
embed
.
subsampling_rate
@
jit
.
export
@
jit
.
to_static
def
right_context
(
self
)
->
int
:
""" Export interface for c++ call, return right_context of the model
"""
return
self
.
encoder
.
embed
.
right_context
@
jit
.
export
@
jit
.
to_static
def
sos_symbol
(
self
)
->
int
:
""" Export interface for c++ call, return sos symbol id of the model
"""
return
self
.
sos
@
jit
.
export
@
jit
.
to_static
def
eos_symbol
(
self
)
->
int
:
""" Export interface for c++ call, return eos symbol id of the model
"""
return
self
.
eos
@
jit
.
export
@
jit
.
to_static
def
forward_encoder_chunk
(
self
,
xs
:
paddle
.
Tensor
,
...
...
@@ -472,7 +472,7 @@ class U2STBaseModel(nn.Layer):
xs
,
offset
,
required_cache_size
,
subsampling_cache
,
elayers_output_cache
,
conformer_cnn_cache
)
@
jit
.
export
@
jit
.
to_static
def
ctc_activation
(
self
,
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
...
...
@@ -483,7 +483,7 @@ class U2STBaseModel(nn.Layer):
"""
return
self
.
ctc
.
log_softmax
(
xs
)
@
jit
.
export
@
jit
.
to_static
def
forward_attention_decoder
(
self
,
hyps
:
paddle
.
Tensor
,
...
...
deepspeech/training/cli.py
浏览文件 @
03a50d7b
...
...
@@ -16,23 +16,23 @@ import argparse
def
default_argument_parser
():
r
"""A simple yet genral argument parser for experiments with parakeet.
This is used in examples with parakeet. And it is intended to be used by
other experiments with parakeet. It requires a minimal set of command line
This is used in examples with parakeet. And it is intended to be used by
other experiments with parakeet. It requires a minimal set of command line
arguments to start a training script.
The ``--config`` and ``--opts`` are used for overwrite the deault
The ``--config`` and ``--opts`` are used for overwrite the deault
configuration.
The ``--data`` and ``--output`` specifies the data path and output path.
Resuming training from existing progress at the output directory is the
The ``--data`` and ``--output`` specifies the data path and output path.
Resuming training from existing progress at the output directory is the
intended default behavior.
The ``--checkpoint_path`` specifies the checkpoint to load from.
The ``--device`` and ``--nprocs`` specifies how to run the training.
See Also
--------
parakeet.training.experiment
...
...
@@ -47,28 +47,24 @@ def default_argument_parser():
# data and output
parser
.
add_argument
(
"--config"
,
metavar
=
"FILE"
,
help
=
"path of the config file to overwrite to default config with."
)
parser
.
add_argument
(
"--dump-config"
,
metavar
=
"FILE"
,
help
=
"dump config to yaml file."
)
# parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
parser
.
add_argument
(
"--output"
,
metavar
=
"OUTPUT_DIR"
,
help
=
"path to save checkpoint and logs."
)
# load from saved checkpoint
parser
.
add_argument
(
"--checkpoint_path"
,
type
=
str
,
help
=
"path of the checkpoint to load"
)
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
# running
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'gpu'
,
choices
=
[
"cpu"
,
"gpu"
],
help
=
"device type to use, cpu and gpu are supported."
)
parser
.
add_argument
(
"--nprocs"
,
type
=
int
,
default
=
1
,
help
=
"number of parallel processes to use."
)
# overwrite extra config and default config
# parser.add_argument("--opts", nargs=argparse.REMAINDER,
# parser.add_argument("--opts", nargs=argparse.REMAINDER,
# help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser
.
add_argument
(
"--opts"
,
type
=
str
,
default
=
[],
nargs
=
'+'
,
help
=
"options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
None
,
help
=
"seed to use for paddle, np and random. None or 0 for random, else set seed."
)
# yapd: enable
return
parser
deepspeech/training/extensions/__init__.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Callable
from
.extension
import
Extension
def
make_extension
(
trigger
:
Callable
=
None
,
default_name
:
str
=
None
,
priority
:
int
=
None
,
finalizer
:
Callable
=
None
,
initializer
:
Callable
=
None
,
on_error
:
Callable
=
None
):
"""Make an Extension-like object by injecting required attributes to it.
"""
if
trigger
is
None
:
trigger
=
Extension
.
trigger
if
priority
is
None
:
priority
=
Extension
.
priority
def
decorator
(
ext
):
ext
.
trigger
=
trigger
ext
.
default_name
=
default_name
or
ext
.
__name__
ext
.
priority
=
priority
ext
.
finalize
=
finalizer
ext
.
on_error
=
on_error
ext
.
initialize
=
initializer
return
ext
return
decorator
deepspeech/training/extensions/evaluator.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Dict
import
extension
import
paddle
from
paddle.io
import
DataLoader
from
paddle.nn
import
Layer
from
..reporter
import
DictSummary
from
..reporter
import
report
from
..reporter
import
scope
class
StandardEvaluator
(
extension
.
Extension
):
trigger
=
(
1
,
'epoch'
)
default_name
=
'validation'
priority
=
extension
.
PRIORITY_WRITER
name
=
None
def
__init__
(
self
,
model
:
Layer
,
dataloader
:
DataLoader
):
# it is designed to hold multiple models
models
=
{
"main"
:
model
}
self
.
models
:
Dict
[
str
,
Layer
]
=
models
self
.
model
=
model
# dataloaders
self
.
dataloader
=
dataloader
def
evaluate_core
(
self
,
batch
):
# compute
self
.
model
(
batch
)
# you may report here
def
evaluate
(
self
):
# switch to eval mode
for
model
in
self
.
models
.
values
():
model
.
eval
()
# to average evaluation metrics
summary
=
DictSummary
()
for
batch
in
self
.
dataloader
:
observation
=
{}
with
scope
(
observation
):
# main evaluation computation here.
with
paddle
.
no_grad
():
self
.
evaluate_core
(
batch
)
summary
.
add
(
observation
)
summary
=
summary
.
compute_mean
()
return
summary
def
__call__
(
self
,
trainer
=
None
):
# evaluate and report the averaged metric to current observation
# if it is used to extend a trainer, the metrics is reported to
# to observation of the trainer
# or otherwise, you can use your own observation
summary
=
self
.
evaluate
()
for
k
,
v
in
summary
.
items
():
report
(
k
,
v
)
deepspeech/training/extensions/extension.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PRIORITY_WRITER
=
300
PRIORITY_EDITOR
=
200
PRIORITY_READER
=
100
class
Extension
():
"""Extension to customize the behavior of Trainer."""
trigger
=
(
1
,
'iteration'
)
priority
=
PRIORITY_READER
name
=
None
@
property
def
default_name
(
self
):
"""Default name of the extension, class name by default."""
return
type
(
self
).
__name__
def
__call__
(
self
,
trainer
):
"""Main action of the extention. After each update, it is executed
when the trigger fires."""
raise
NotImplementedError
(
'Extension implementation must override __call__.'
)
def
initialize
(
self
,
trainer
):
"""Action that is executed once to get the corect trainer state.
It is called before training normally, but if the trainer restores
states with an Snapshot extension, this method should also be called.
"""
pass
def
on_error
(
self
,
trainer
,
exc
,
tb
):
"""Handles the error raised during training before finalization.
"""
pass
def
finalize
(
self
,
trainer
):
"""Action that is executed when training is done.
For example, visualizers would need to be closed.
"""
pass
deepspeech/training/extensions/snapshot.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
datetime
import
datetime
from
pathlib
import
Path
from
typing
import
Any
from
typing
import
Dict
from
typing
import
List
import
jsonlines
from
deepspeech.training.extensions
import
extension
from
deepspeech.training.updaters.trainer
import
Trainer
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.mp_tools
import
rank_zero_only
logger
=
Log
(
__name__
).
getlog
()
def
load_records
(
records_fp
):
"""Load record files (json lines.)"""
with
jsonlines
.
open
(
records_fp
,
'r'
)
as
reader
:
records
=
list
(
reader
)
return
records
class
Snapshot
(
extension
.
Extension
):
"""An extension to make snapshot of the updater object inside
the trainer. It is done by calling the updater's `save` method.
An Updater save its state_dict by default, which contains the
updater state, (i.e. epoch and iteration) and all the model
parameters and optimizer states. If the updater inside the trainer
subclasses StandardUpdater, everything is good to go.
Parameters
----------
checkpoint_dir : Union[str, Path]
The directory to save checkpoints into.
"""
trigger
=
(
1
,
'epoch'
)
priority
=
-
100
default_name
=
"snapshot"
def
__init__
(
self
,
max_size
:
int
=
5
,
snapshot_on_error
:
bool
=
False
):
self
.
records
:
List
[
Dict
[
str
,
Any
]]
=
[]
self
.
max_size
=
max_size
self
.
_snapshot_on_error
=
snapshot_on_error
self
.
_save_all
=
(
max_size
==
-
1
)
self
.
checkpoint_dir
=
None
def
initialize
(
self
,
trainer
:
Trainer
):
"""Setting up this extention."""
self
.
checkpoint_dir
=
trainer
.
out
/
"checkpoints"
# load existing records
record_path
:
Path
=
self
.
checkpoint_dir
/
"records.jsonl"
if
record_path
.
exists
():
logger
.
debug
(
"Loading from an existing checkpoint dir"
)
self
.
records
=
load_records
(
record_path
)
trainer
.
updater
.
load
(
self
.
records
[
-
1
][
'path'
])
def
on_error
(
self
,
trainer
,
exc
,
tb
):
if
self
.
_snapshot_on_error
:
self
.
save_checkpoint_and_update
(
trainer
)
def
__call__
(
self
,
trainer
:
Trainer
):
self
.
save_checkpoint_and_update
(
trainer
)
def
full
(
self
):
"""Whether the number of snapshots it keeps track of is greater
than the max_size."""
return
(
not
self
.
_save_all
)
and
len
(
self
.
records
)
>
self
.
max_size
@
rank_zero_only
def
save_checkpoint_and_update
(
self
,
trainer
:
Trainer
):
"""Saving new snapshot and remove the oldest snapshot if needed."""
iteration
=
trainer
.
updater
.
state
.
iteration
epoch
=
trainer
.
updater
.
state
.
epoch
num
=
epoch
if
self
.
trigger
[
1
]
==
'epoch'
else
iteration
path
=
self
.
checkpoint_dir
/
f
"
{
num
}
.pdz"
# add the new one
trainer
.
updater
.
save
(
path
)
record
=
{
"time"
:
str
(
datetime
.
now
()),
'path'
:
str
(
path
.
resolve
()),
# use absolute path
'iteration'
:
iteration
,
'epoch'
:
epoch
,
}
self
.
records
.
append
(
record
)
# remove the earist
if
self
.
full
():
eariest_record
=
self
.
records
[
0
]
os
.
remove
(
eariest_record
[
"path"
])
self
.
records
.
pop
(
0
)
# update the record file
record_path
=
self
.
checkpoint_dir
/
"records.jsonl"
with
jsonlines
.
open
(
record_path
,
'w'
)
as
writer
:
for
record
in
self
.
records
:
# jsonlines.open may return a Writer or a Reader
writer
.
write
(
record
)
# pylint: disable=no-member
deepspeech/training/extensions/visualizer.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
deepspeech.training.extensions
import
extension
from
deepspeech.training.updaters.trainer
import
Trainer
class
VisualDL
(
extension
.
Extension
):
"""A wrapper of visualdl log writer. It assumes that the metrics to be visualized
are all scalars which are recorded into the `.observation` dictionary of the
trainer object. The dictionary is created for each step, thus the visualdl log
writer uses the iteration from the updater's `iteration` as the global step to
add records.
"""
trigger
=
(
1
,
'iteration'
)
default_name
=
'visualdl'
priority
=
extension
.
PRIORITY_READER
def
__init__
(
self
,
writer
):
self
.
writer
=
writer
def
__call__
(
self
,
trainer
:
Trainer
):
for
k
,
v
in
trainer
.
observation
.
items
():
self
.
writer
.
add_scalar
(
k
,
v
,
step
=
trainer
.
updater
.
state
.
iteration
)
def
finalize
(
self
,
trainer
):
self
.
writer
.
close
()
deepspeech/training/reporter.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
contextlib
import
math
from
collections
import
defaultdict
OBSERVATIONS
=
None
@
contextlib
.
contextmanager
def
scope
(
observations
):
# make `observation` the target to report to.
# it is basically a dictionary that stores temporary observations
global
OBSERVATIONS
old
=
OBSERVATIONS
OBSERVATIONS
=
observations
try
:
yield
finally
:
OBSERVATIONS
=
old
def
get_observations
():
global
OBSERVATIONS
return
OBSERVATIONS
def
report
(
name
,
value
):
# a simple function to report named value
# you can use it everywhere, it will get the default target and writ to it
# you can think of it as std.out
observations
=
get_observations
()
if
observations
is
None
:
return
else
:
observations
[
name
]
=
value
class
Summary
():
"""Online summarization of a sequence of scalars.
Summary computes the statistics of given scalars online.
"""
def
__init__
(
self
):
self
.
_x
=
0.0
self
.
_x2
=
0.0
self
.
_n
=
0
def
add
(
self
,
value
,
weight
=
1
):
"""Adds a scalar value.
Args:
value: Scalar value to accumulate. It is either a NumPy scalar or
a zero-dimensional array (on CPU or GPU).
weight: An optional weight for the value. It is a NumPy scalar or
a zero-dimensional array (on CPU or GPU).
Default is 1 (integer).
"""
self
.
_x
+=
weight
*
value
self
.
_x2
+=
weight
*
value
*
value
self
.
_n
+=
weight
def
compute_mean
(
self
):
"""Computes the mean."""
x
,
n
=
self
.
_x
,
self
.
_n
return
x
/
n
def
make_statistics
(
self
):
"""Computes and returns the mean and standard deviation values.
Returns:
tuple: Mean and standard deviation values.
"""
x
,
n
=
self
.
_x
,
self
.
_n
mean
=
x
/
n
var
=
self
.
_x2
/
n
-
mean
*
mean
std
=
math
.
sqrt
(
var
)
return
mean
,
std
class
DictSummary
():
"""Online summarization of a sequence of dictionaries.
``DictSummary`` computes the statistics of a given set of scalars online.
It only computes the statistics for scalar values and variables of scalar
values in the dictionaries.
"""
def
__init__
(
self
):
self
.
_summaries
=
defaultdict
(
Summary
)
def
add
(
self
,
d
):
"""Adds a dictionary of scalars.
Args:
d (dict): Dictionary of scalars to accumulate. Only elements of
scalars, zero-dimensional arrays, and variables of
zero-dimensional arrays are accumulated. When the value
is a tuple, the second element is interpreted as a weight.
"""
summaries
=
self
.
_summaries
for
k
,
v
in
d
.
items
():
w
=
1
if
isinstance
(
v
,
tuple
):
v
=
v
[
0
]
w
=
v
[
1
]
summaries
[
k
].
add
(
v
,
weight
=
w
)
def
compute_mean
(
self
):
"""Creates a dictionary of mean values.
It returns a single dictionary that holds a mean value for each entry
added to the summary.
Returns:
dict: Dictionary of mean values.
"""
return
{
name
:
summary
.
compute_mean
()
for
name
,
summary
in
self
.
_summaries
.
items
()
}
def
make_statistics
(
self
):
"""Creates a dictionary of statistics.
It returns a single dictionary that holds mean and standard deviation
values for every entry added to the summary. For an entry of name
``'key'``, these values are added to the dictionary by names ``'key'``
and ``'key.std'``, respectively.
Returns:
dict: Dictionary of statistics of all entries.
"""
stats
=
{}
for
name
,
summary
in
self
.
_summaries
.
items
():
mean
,
std
=
summary
.
make_statistics
()
stats
[
name
]
=
mean
stats
[
name
+
'.std'
]
=
std
return
stats
deepspeech/training/trainer.py
浏览文件 @
03a50d7b
...
...
@@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter
from
deepspeech.utils
import
mp_tools
from
deepspeech.utils.checkpoint
import
Checkpoint
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.utility
import
seed_all
__all__
=
[
"Trainer"
]
...
...
@@ -94,6 +95,10 @@ class Trainer():
self
.
iteration
=
0
self
.
epoch
=
0
if
args
.
seed
:
seed_all
(
args
.
seed
)
logger
.
info
(
f
"Set seed
{
args
.
seed
}
"
)
def
setup
(
self
):
"""Setup the experiment.
"""
...
...
@@ -172,8 +177,10 @@ class Trainer():
"""Reset the train loader seed and increment `epoch`.
"""
self
.
epoch
+=
1
if
self
.
parallel
:
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
if
self
.
parallel
and
hasattr
(
self
.
train_loader
,
"batch_sampler"
):
batch_sampler
=
self
.
train_loader
.
batch_sampler
if
isinstance
(
batch_sampler
,
paddle
.
io
.
DistributedBatchSampler
):
batch_sampler
.
set_epoch
(
self
.
epoch
)
def
train
(
self
):
"""The training process control by epoch."""
...
...
@@ -182,7 +189,7 @@ class Trainer():
# save init model, i.e. 0 epoch
self
.
save
(
tag
=
'init'
,
infos
=
None
)
self
.
lr_scheduler
.
step
(
self
.
epoch
)
if
self
.
parallel
:
if
self
.
parallel
and
hasattr
(
self
.
train_loader
,
"batch_sampler"
)
:
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
...
...
deepspeech/training/triggers/__init__.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.interval_trigger
import
IntervalTrigger
def
never_fail_trigger
(
trainer
):
return
False
def
get_trigger
(
trigger
):
if
trigger
is
None
:
return
never_fail_trigger
if
callable
(
trigger
):
return
trigger
else
:
trigger
=
IntervalTrigger
(
*
trigger
)
return
trigger
deepspeech/training/triggers/interval_trigger.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
IntervalTrigger
():
"""A Predicate to do something every N cycle."""
def
__init__
(
self
,
period
:
int
,
unit
:
str
):
if
unit
not
in
(
"iteration"
,
"epoch"
):
raise
ValueError
(
"unit should be 'iteration' or 'epoch'"
)
if
period
<=
0
:
raise
ValueError
(
"period should be a positive integer."
)
self
.
period
=
period
self
.
unit
=
unit
self
.
last_index
=
None
def
__call__
(
self
,
trainer
):
if
self
.
last_index
is
None
:
last_index
=
getattr
(
trainer
.
updater
.
state
,
self
.
unit
)
self
.
last_index
=
last_index
last_index
=
self
.
last_index
index
=
getattr
(
trainer
.
updater
.
state
,
self
.
unit
)
fire
=
index
//
self
.
period
!=
last_index
//
self
.
period
self
.
last_index
=
index
return
fire
deepspeech/training/triggers/limit_trigger.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
LimitTrigger
():
"""A Predicate to decide whether to stop."""
def
__init__
(
self
,
limit
:
int
,
unit
:
str
):
if
unit
not
in
(
"iteration"
,
"epoch"
):
raise
ValueError
(
"unit should be 'iteration' or 'epoch'"
)
if
limit
<=
0
:
raise
ValueError
(
"limit should be a positive integer."
)
self
.
limit
=
limit
self
.
unit
=
unit
def
__call__
(
self
,
trainer
):
state
=
trainer
.
updater
.
state
index
=
getattr
(
state
,
self
.
unit
)
fire
=
index
>=
self
.
limit
return
fire
deepspeech/training/triggers/time_trigger.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
TimeTrigger
():
"""Trigger based on a fixed time interval.
This trigger accepts iterations with a given interval time.
Args:
period (float): Interval time. It is given in seconds.
"""
def
__init__
(
self
,
period
):
self
.
_period
=
period
self
.
_next_time
=
self
.
_period
def
__call__
(
self
,
trainer
):
if
self
.
_next_time
<
trainer
.
elapsed_time
:
self
.
_next_time
+=
self
.
_period
return
True
else
:
return
False
deepspeech/training/updaters/__init__.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech/training/updaters/standard_updater.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Dict
from
typing
import
Optional
from
paddle
import
Tensor
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
from
paddle.nn
import
Layer
from
paddle.optimizer
import
Optimizer
from
timer
import
timer
from
deepspeech.training.reporter
import
report
from
deepspeech.training.updaters.updater
import
UpdaterBase
from
deepspeech.training.updaters.updater
import
UpdaterState
from
deepspeech.utils.log
import
Log
__all__
=
[
"StandardUpdater"
]
logger
=
Log
(
__name__
).
getlog
()
class
StandardUpdater
(
UpdaterBase
):
"""An example of over-simplification. Things may not be that simple, but
you can subclass it to fit your need.
"""
def
__init__
(
self
,
model
:
Layer
,
optimizer
:
Optimizer
,
dataloader
:
DataLoader
,
init_state
:
Optional
[
UpdaterState
]
=
None
):
# it is designed to hold multiple models
models
=
{
"main"
:
model
}
self
.
models
:
Dict
[
str
,
Layer
]
=
models
self
.
model
=
model
# it is designed to hold multiple optimizers
optimizers
=
{
"main"
:
optimizer
}
self
.
optimizer
=
optimizer
self
.
optimizers
:
Dict
[
str
,
Optimizer
]
=
optimizers
# dataloaders
self
.
dataloader
=
dataloader
# init state
if
init_state
is
None
:
self
.
state
=
UpdaterState
()
else
:
self
.
state
=
init_state
self
.
train_iterator
=
iter
(
dataloader
)
def
update
(
self
):
# We increase the iteration index after updating and before extension.
# Here are the reasons.
# 0. Snapshotting(as well as other extensions, like visualizer) is
# executed after a step of updating;
# 1. We decide to increase the iteration index after updating and
# before any all extension is executed.
# 3. We do not increase the iteration after extension because we
# prefer a consistent resume behavior, when load from a
# `snapshot_iter_100.pdz` then the next step to train is `101`,
# naturally. But if iteration is increased increased after
# extension(including snapshot), then, a `snapshot_iter_99` is
# loaded. You would need a extra increasing of the iteration idex
# before training to avoid another iteration `99`, which has been
# done before snapshotting.
# 4. Thus iteration index represrnts "currently how mant epochs has
# been done."
# NOTE: use report to capture the correctly value. If you want to
# report the learning rate used for a step, you must report it before
# the learning rate scheduler's step() has been called. In paddle's
# convention, we do not use an extension to change the learning rate.
# so if you want to report it, do it in the updater.
# Then here comes the next question. When is the proper time to
# increase the epoch index? Since all extensions are executed after
# updating, it is the time that after updating is the proper time to
# increase epoch index.
# 1. If we increase the epoch index before updating, then an extension
# based ot epoch would miss the correct timing. It could only be
# triggerd after an extra updating.
# 2. Theoretically, when an epoch is done, the epoch index should be
# increased. So it would be increase after updating.
# 3. Thus, eppoch index represents "currently how many epochs has been
# done." So it starts from 0.
# switch to training mode
for
model
in
self
.
models
.
values
():
model
.
train
()
# training for a step is implemented here
batch
=
self
.
read_batch
()
self
.
update_core
(
batch
)
self
.
state
.
iteration
+=
1
if
self
.
updates_per_epoch
is
not
None
:
if
self
.
state
.
iteration
%
self
.
updates_per_epoch
==
0
:
self
.
state
.
epoch
+=
1
def
update_core
(
self
,
batch
):
"""A simple case for a training step. Basic assumptions are:
Single model;
Single optimizer;
A batch from the dataloader is just the input of the model;
The model return a single loss, or a dict containing serval losses.
Parameters updates at every batch, no gradient accumulation.
"""
loss
=
self
.
model
(
*
batch
)
if
isinstance
(
loss
,
Tensor
):
loss_dict
=
{
"main"
:
loss
}
else
:
# Dict[str, Tensor]
loss_dict
=
loss
if
"main"
not
in
loss_dict
:
main_loss
=
0
for
loss_item
in
loss
.
values
():
main_loss
+=
loss_item
loss_dict
[
"main"
]
=
main_loss
for
name
,
loss_item
in
loss_dict
.
items
():
report
(
name
,
float
(
loss_item
))
self
.
optimizer
.
clear_gradient
()
loss_dict
[
"main"
].
backward
()
self
.
optimizer
.
update
()
@
property
def
updates_per_epoch
(
self
):
"""Number of updater per epoch, determined by the length of the
dataloader."""
length_of_dataloader
=
None
try
:
length_of_dataloader
=
len
(
self
.
dataloader
)
except
TypeError
:
logger
.
debug
(
"This dataloader has no __len__."
)
finally
:
return
length_of_dataloader
def
new_epoch
(
self
):
"""Start a new epoch."""
# NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method
if
hasattr
(
self
.
dataloader
,
"batch_sampler"
):
batch_sampler
=
self
.
dataloader
.
batch_sampler
if
isinstance
(
batch_sampler
,
DistributedBatchSampler
):
batch_sampler
.
set_epoch
(
self
.
state
.
epoch
)
self
.
train_iterator
=
iter
(
self
.
dataloader
)
def
read_batch
(
self
):
"""Read a batch from the data loader, auto renew when data is exhausted."""
with
timer
()
as
t
:
try
:
batch
=
next
(
self
.
train_iterator
)
except
StopIteration
:
self
.
new_epoch
()
batch
=
next
(
self
.
train_iterator
)
logger
.
debug
(
f
"Read a batch takes
{
t
.
elapse
}
s."
)
# replace it with logger
return
batch
def
state_dict
(
self
):
"""State dict of a Updater, model, optimizer and updater state are included."""
state_dict
=
super
().
state_dict
()
for
name
,
model
in
self
.
models
.
items
():
state_dict
[
f
"
{
name
}
_params"
]
=
model
.
state_dict
()
for
name
,
optim
in
self
.
optimizers
.
items
():
state_dict
[
f
"
{
name
}
_optimizer"
]
=
optim
.
state_dict
()
return
state_dict
def
set_state_dict
(
self
,
state_dict
):
"""Set state dict for a Updater. Parameters of models, states for
optimizers and UpdaterState are restored."""
for
name
,
model
in
self
.
models
.
items
():
model
.
set_state_dict
(
state_dict
[
f
"
{
name
}
_params"
])
for
name
,
optim
in
self
.
optimizers
.
items
():
optim
.
set_state_dict
(
state_dict
[
f
"
{
name
}
_optimizer"
])
super
().
set_state_dict
(
state_dict
)
deepspeech/training/updaters/trainer.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
traceback
from
collections
import
OrderedDict
from
pathlib
import
Path
from
typing
import
Callable
from
typing
import
List
from
typing
import
Union
import
six
import
tqdm
from
deepspeech.training.extensions.extension
import
Extension
from
deepspeech.training.extensions.extension
import
PRIORITY_READER
from
deepspeech.training.reporter
import
scope
from
deepspeech.training.triggers
import
get_trigger
from
deepspeech.training.triggers.limit_trigger
import
LimitTrigger
from
deepspeech.training.updaters.updater
import
UpdaterBase
class
_ExtensionEntry
():
def
__init__
(
self
,
extension
,
trigger
,
priority
):
self
.
extension
=
extension
self
.
trigger
=
trigger
self
.
priority
=
priority
class
Trainer
():
def
__init__
(
self
,
updater
:
UpdaterBase
,
stop_trigger
:
Callable
=
None
,
out
:
Union
[
str
,
Path
]
=
'result'
,
extensions
:
List
[
Extension
]
=
None
):
self
.
updater
=
updater
self
.
extensions
=
OrderedDict
()
self
.
stop_trigger
=
LimitTrigger
(
*
stop_trigger
)
self
.
out
=
Path
(
out
)
self
.
observation
=
None
self
.
_done
=
False
if
extensions
:
for
ext
in
extensions
:
self
.
extend
(
ext
)
@
property
def
is_before_training
(
self
):
return
self
.
updater
.
state
.
iteration
==
0
def
extend
(
self
,
extension
,
name
=
None
,
trigger
=
None
,
priority
=
None
):
# get name for the extension
# argument \
# -> extention's name \
# -> default_name (class name, when it is an object) \
# -> function name when it is a function \
# -> error
if
name
is
None
:
name
=
getattr
(
extension
,
'name'
,
None
)
if
name
is
None
:
name
=
getattr
(
extension
,
'default_name'
,
None
)
if
name
is
None
:
name
=
getattr
(
extension
,
'__name__'
,
None
)
if
name
is
None
:
raise
ValueError
(
"Name is not given for the extension."
)
if
name
==
'training'
:
raise
ValueError
(
"training is a reserved name."
)
if
trigger
is
None
:
trigger
=
getattr
(
extension
,
'trigger'
,
(
1
,
'iteration'
))
trigger
=
get_trigger
(
trigger
)
if
priority
is
None
:
priority
=
getattr
(
extension
,
'priority'
,
PRIORITY_READER
)
# add suffix to avoid nameing conflict
ordinal
=
0
modified_name
=
name
while
modified_name
in
self
.
extensions
:
ordinal
+=
1
modified_name
=
f
"
{
name
}
_
{
ordinal
}
"
extension
.
name
=
modified_name
self
.
extensions
[
modified_name
]
=
_ExtensionEntry
(
extension
,
trigger
,
priority
)
def
get_extension
(
self
,
name
):
"""get extension by name."""
extensions
=
self
.
extensions
if
name
in
extensions
:
return
extensions
[
name
].
extension
else
:
raise
ValueError
(
f
'extension
{
name
}
not found'
)
def
run
(
self
):
if
self
.
_done
:
raise
RuntimeError
(
"Training is already done!."
)
self
.
out
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
# sort extensions by priorities once
extension_order
=
sorted
(
self
.
extensions
.
keys
(),
key
=
lambda
name
:
self
.
extensions
[
name
].
priority
,
reverse
=
True
)
extensions
=
[(
name
,
self
.
extensions
[
name
])
for
name
in
extension_order
]
# initializing all extensions
for
name
,
entry
in
extensions
:
if
hasattr
(
entry
.
extension
,
"initialize"
):
entry
.
extension
.
initialize
(
self
)
update
=
self
.
updater
.
update
# training step
stop_trigger
=
self
.
stop_trigger
# display only one progress bar
max_iteration
=
None
if
isinstance
(
stop_trigger
,
LimitTrigger
):
if
stop_trigger
.
unit
==
'epoch'
:
max_epoch
=
self
.
stop_trigger
.
limit
updates_per_epoch
=
getattr
(
self
.
updater
,
"updates_per_epoch"
,
None
)
max_iteration
=
max_epoch
*
updates_per_epoch
if
updates_per_epoch
else
None
else
:
max_iteration
=
self
.
stop_trigger
.
limit
p
=
tqdm
.
tqdm
(
initial
=
self
.
updater
.
state
.
iteration
,
total
=
max_iteration
)
try
:
while
not
stop_trigger
(
self
):
self
.
observation
=
{}
# set observation as the report target
# you can use report freely in Updater.update()
# updating parameters and state
with
scope
(
self
.
observation
):
update
()
p
.
update
()
# execute extension when necessary
for
name
,
entry
in
extensions
:
if
entry
.
trigger
(
self
):
entry
.
extension
(
self
)
# print("###", self.observation)
except
Exception
as
e
:
f
=
sys
.
stderr
f
.
write
(
f
"Exception in main training loop:
{
e
}
\n
"
)
f
.
write
(
"Traceback (most recent call last):
\n
"
)
traceback
.
print_tb
(
sys
.
exc_info
()[
2
])
f
.
write
(
"Trainer extensions will try to handle the extension. Then all extensions will finalize."
)
# capture the exception in the mian training loop
exc_info
=
sys
.
exc_info
()
# try to handle it
for
name
,
entry
in
extensions
:
if
hasattr
(
entry
.
extension
,
"on_error"
):
try
:
entry
.
extension
.
on_error
(
self
,
e
,
sys
.
exc_info
()[
2
])
except
Exception
as
ee
:
f
.
write
(
f
"Exception in error handler:
{
ee
}
\n
"
)
f
.
write
(
'Traceback (most recent call last):
\n
'
)
traceback
.
print_tb
(
sys
.
exc_info
()[
2
])
# raise exception in main training loop
six
.
reraise
(
*
exc_info
)
finally
:
for
name
,
entry
in
extensions
:
if
hasattr
(
entry
.
extension
,
"finalize"
):
entry
.
extension
.
finalize
(
self
)
deepspeech/training/updaters/updater.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
import
paddle
from
deepspeech.utils.log
import
Log
__all__
=
[
"UpdaterBase"
,
"UpdaterState"
]
logger
=
Log
(
__name__
).
getlog
()
@
dataclass
class
UpdaterState
:
iteration
:
int
=
0
epoch
:
int
=
0
class
UpdaterBase
():
"""An updater is the abstraction of how a model is trained given the
dataloader and the optimizer.
The `update_core` method is a step in the training loop with only necessary
operations (get a batch, forward and backward, update the parameters).
Other stuffs are made extensions. Visualization, saving, loading and
periodical validation and evaluation are not considered here.
But even in such simplist case, things are not that simple. There is an
attempt to standardize this process and requires only the model and
dataset and do all the stuffs automatically. But this may hurt flexibility.
If we assume a batch yield from the dataloader is just the input to the
model, we will find that some model requires more arguments, or just some
keyword arguments. But this prevents us from over-simplifying it.
From another perspective, the batch may includes not just the input, but
also the target. But the model's forward method may just need the input.
We can pass a dict or a super-long tuple to the model and let it pick what
it really needs. But this is an abuse of lazy interface.
After all, we care about how a model is trained. But just how the model is
used for inference. We want to control how a model is trained. We just
don't want to be messed up with other auxiliary code.
So the best practice is to define a model and define a updater for it.
"""
def
__init__
(
self
,
init_state
=
None
):
if
init_state
is
None
:
self
.
state
=
UpdaterState
()
else
:
self
.
state
=
init_state
def
update
(
self
,
batch
):
raise
NotImplementedError
(
"Implement your own `update` method for training a step."
)
def
state_dict
(
self
):
state_dict
=
{
"epoch"
:
self
.
state
.
epoch
,
"iteration"
:
self
.
state
.
iteration
,
}
return
state_dict
def
set_state_dict
(
self
,
state_dict
):
self
.
state
.
epoch
=
state_dict
[
"epoch"
]
self
.
state
.
iteration
=
state_dict
[
"iteration"
]
def
save
(
self
,
path
):
logger
.
debug
(
f
"Saving to
{
path
}
."
)
archive
=
self
.
state_dict
()
paddle
.
save
(
archive
,
str
(
path
))
def
load
(
self
,
path
):
logger
.
debug
(
f
"Loading from
{
path
}
."
)
archive
=
paddle
.
load
(
str
(
path
))
self
.
set_state_dict
(
archive
)
deepspeech/utils/utility.py
浏览文件 @
03a50d7b
...
...
@@ -15,9 +15,19 @@
import
distutils.util
import
math
import
os
import
random
from
typing
import
List
__all__
=
[
'print_arguments'
,
'add_arguments'
,
"log_add"
]
import
numpy
as
np
import
paddle
__all__
=
[
"seed_all"
,
'print_arguments'
,
'add_arguments'
,
"log_add"
]
def
seed_all
(
seed
:
int
=
210329
):
np
.
random
.
seed
(
seed
)
random
.
seed
(
seed
)
paddle
.
seed
(
seed
)
def
print_arguments
(
args
,
info
=
None
):
...
...
examples/aishell/s0/README.md
浏览文件 @
03a50d7b
# Aishell-1
## Data
| Data Subset | Duration in Seconds |
| data/manifest.train | 1.23 ~ 14.53125 |
| data/manifest.dev | 1.645 ~ 12.533 |
| data/manifest.test | 1.859125 ~ 14.6999375 |
## Deepspeech2
| Model | Params | Release | Config | Test set | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382
,0.073507
|
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
...
...
examples/aishell/s0/conf/augmentation.json
浏览文件 @
03a50d7b
...
...
@@ -19,15 +19,17 @@
{
"type"
:
"specaug"
,
"params"
:
{
"W"
:
0
,
"warp_mode"
:
"PIL"
,
"F"
:
10
,
"T"
:
50
,
"n_freq_masks"
:
2
,
"T"
:
50
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
"prob"
:
1.0
}
...
...
examples/aishell/s0/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -19,12 +19,22 @@ fi
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
\
--model_type
${
model_type
}
--model_type
${
model_type
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/aishell/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
...
@@ -27,7 +27,9 @@
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
,
"warp_mode"
:
"PIL"
},
"prob"
:
1.0
}
...
...
examples/aishell/s1/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/aug_conf/augmentation.json
已删除
100644 → 0
浏览文件 @
8d506270
[
{
"type"
:
"shift"
,
"params"
:
{
"min_shift_ms"
:
-5
,
"max_shift_ms"
:
5
},
"prob"
:
1.0
}
]
examples/aug
_conf/augmentation.example
.json
→
examples/aug
mentation/augmentation
.json
浏览文件 @
03a50d7b
...
...
@@ -52,16 +52,18 @@
{
"type"
:
"specaug"
,
"params"
:
{
"W"
:
80
,
"warp_mode"
:
"PIL"
,
"F"
:
10
,
"T"
:
50
,
"n_freq_masks"
:
2
,
"T"
:
50
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
false
},
"prob"
:
0
.0
"prob"
:
1
.0
}
]
examples/callcenter/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
...
@@ -27,7 +27,8 @@
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
"prob"
:
1.0
}
...
...
examples/callcenter/s1/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/librispeech/s0/conf/augmentation.json
浏览文件 @
03a50d7b
...
...
@@ -19,15 +19,17 @@
{
"type"
:
"specaug"
,
"params"
:
{
"W"
:
0
,
"warp_mode"
:
"PIL"
,
"F"
:
10
,
"T"
:
50
,
"n_freq_masks"
:
2
,
"T"
:
50
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
"prob"
:
1.0
}
...
...
examples/librispeech/s0/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -20,12 +20,22 @@ echo "using ${device}..."
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
\
--model_type
${
model_type
}
--model_type
${
model_type
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/librispeech/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
...
@@ -27,7 +27,9 @@
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
,
"warp_mode"
:
"PIL"
},
"prob"
:
1.0
}
...
...
examples/librispeech/s1/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/librispeech/s2/conf/augmentation.json
浏览文件 @
03a50d7b
...
...
@@ -2,15 +2,17 @@
{
"type"
:
"specaug"
,
"params"
:
{
"F"
:
10
,
"T"
:
50
,
"W"
:
5
,
"warp_mode"
:
"PIL"
,
"F"
:
30
,
"n_freq_masks"
:
2
,
"T"
:
40
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
false
},
"prob"
:
1.0
}
...
...
examples/librispeech/s2/conf/transformer.yaml
浏览文件 @
03a50d7b
...
...
@@ -3,37 +3,28 @@ data:
train_manifest
:
data/manifest.train
dev_manifest
:
data/manifest.dev
test_manifest
:
data/manifest.test-clean
min_input_len
:
0.5
# second
max_input_len
:
20.0
# second
min_output_len
:
0.0
# tokens
max_output_len
:
400.0
# tokens
min_output_input_ratio
:
0.05
max_output_input_ratio
:
10.0
collator
:
vocab_filepath
:
data/
vocab
.txt
vocab_filepath
:
data/
train_960_unigram5000_units
.txt
unit_type
:
'
spm'
spm_model_prefix
:
'
data/bpe_unigram_5000'
mean_std_filepath
:
"
"
augmentation_config
:
conf/augmentation.json
batch_size
:
64
raw_wav
:
True
# use raw_wav or kaldi feature
specgram_type
:
fbank
#linear, mfcc, fbank
spm_model_prefix
:
'
data/train_960_unigram5000'
feat_dim
:
83
delta_delta
:
False
dither
:
1.0
target_sample_rate
:
16000
max_freq
:
None
n_fft
:
None
stride_ms
:
10.0
window_ms
:
25.0
use_dB_normalization
:
True
target_dB
:
-20
random_seed
:
0
keep_transcription_text
:
False
sortagrad
:
True
shuffle_method
:
batch_shuffle
sortagrad
:
0
# Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size
:
32
maxlen_in
:
512
# if input length > maxlen-in, batchsize is automatically reduced
maxlen_out
:
150
# if output length > maxlen-out, batchsize is automatically reduced
minibatches
:
0
# for debug
batch_count
:
auto
batch_bins
:
0
batch_frames_in
:
0
batch_frames_out
:
0
batch_frames_inout
:
0
augmentation_config
:
conf/augmentation.json
num_workers
:
2
subsampling_factor
:
1
num_encs
:
1
# network architecture
...
...
examples/librispeech/s2/local/align.sh
浏览文件 @
03a50d7b
#!/bin/bash
if
[
$#
!=
2
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix"
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path
dict_path
ckpt_path_prefix"
exit
-1
fi
...
...
@@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then
device
=
cpu
fi
config_path
=
$1
ckpt_prefix
=
$2
dict_path
=
$2
ckpt_prefix
=
$3
batch_size
=
1
output_dir
=
${
ckpt_prefix
}
...
...
@@ -22,11 +23,13 @@ mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
'align'
\
--model-name
'u2_kaldi'
\
--run-mode
'align'
\
--dict-path
${
dict_path
}
\
--device
${
device
}
\
--nproc
1
\
--config
${
config_path
}
\
--result
_
file
${
output_dir
}
/
${
type
}
.align
\
--result
-
file
${
output_dir
}
/
${
type
}
.align
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.batch_size
${
batch_size
}
...
...
examples/librispeech/s2/local/export.sh
浏览文件 @
03a50d7b
...
...
@@ -18,7 +18,8 @@ if [ ${ngpu} == 0 ];then
fi
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
'export'
\
--model-name
'u2_kaldi'
\
--run-mode
'export'
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
...
...
examples/librispeech/s2/local/test.sh
浏览文件 @
03a50d7b
#!/bin/bash
if
[
$#
!=
2
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix"
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path
dict_path
ckpt_path_prefix"
exit
-1
fi
...
...
@@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then
fi
config_path
=
$1
ckpt_prefix
=
$2
dict_path
=
$2
ckpt_prefix
=
$3
chunk_mode
=
false
if
[[
${
config_path
}
=
~ ^.
*
chunk_.
*
yaml
$
]]
;
then
...
...
@@ -38,11 +39,13 @@ for type in attention ctc_greedy_search; do
batch_size
=
64
fi
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
test
\
--model-name
u2_kaldi
\
--run-mode
test
\
--dict-path
${
dict_path
}
\
--device
${
device
}
\
--nproc
1
\
--config
${
config_path
}
\
--result
_
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--result
-
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
decoding.batch_size
${
batch_size
}
...
...
@@ -56,11 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo
"decoding
${
type
}
"
batch_size
=
1
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
test
\
--model-name
u2_kaldi
\
--run-mode
test
\
--dict-path
${
dict_path
}
\
--device
${
device
}
\
--nproc
1
\
--config
${
config_path
}
\
--result
_
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--result
-
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
decoding.batch_size
${
batch_size
}
...
...
examples/librispeech/s2/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -19,12 +19,22 @@ echo "using ${device}..."
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--model-name
u2_kaldi
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/librispeech/s2/run.sh
浏览文件 @
03a50d7b
...
...
@@ -5,6 +5,7 @@ source path.sh
stage
=
0
stop_stage
=
100
conf_path
=
conf/transformer.yaml
dict_path
=
data/train_960_unigram5000_units.txt
avg_num
=
5
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
...
...
@@ -29,12 +30,12 @@ fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES
=
0 ./local/test.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
CUDA_VISIBLE_DEVICES
=
0 ./local/test.sh
${
conf_path
}
${
dict_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES
=
0 ./local/align.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
CUDA_VISIBLE_DEVICES
=
0 ./local/align.sh
${
conf_path
}
${
dict_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
if
[
${
stage
}
-le
5
]
&&
[
${
stop_stage
}
-ge
5
]
;
then
...
...
examples/ted_en_zh/t0/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/thchs30/a0/local/data.sh
浏览文件 @
03a50d7b
...
...
@@ -20,27 +20,33 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo
"Prepare THCHS-30 failed. Terminated."
exit
1
fi
fi
# dump manifest to data/
python3
${
MAIN_ROOT
}
/utils/dump_manifest.py
--manifest-path
=
data/manifest.train
--output-dir
=
data
# copy files to data/dict to gen word.lexicon
cp
${
TARGET_DIR
}
/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
cp
${
TARGET_DIR
}
/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
# dump manifest to data/
python3
${
MAIN_ROOT
}
/utils/dump_manifest.py
--manifest-path
=
data/manifest.train
--output-dir
=
data
fi
# copy phone.lexicon to data/dict
cp
${
TARGET_DIR
}
/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# copy files to data/dict to gen word.lexicon
cp
${
TARGET_DIR
}
/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
cp
${
TARGET_DIR
}
/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
# copy phone.lexicon to data/dict
cp
${
TARGET_DIR
}
/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
fi
# gen word.lexicon
python
local
/gen_word2phone.py
--root-dir
=
data/dict
--output-dir
=
data/dict
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# gen word.lexicon
python
local
/gen_word2phone.py
--lexicon-files
=
"data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2"
--output-path
=
data/dict/word.lexicon
fi
# reorganize dataset for MFA
if
[
!
-d
$EXP_DIR
/thchs30_corpus
]
;
then
echo
"reorganizing thchs30 corpus..."
python
local
/reorganize_thchs30.py
--root-dir
=
data
--output-dir
=
data/thchs30_corpus
--script-type
=
$LEXICON_NAME
echo
"reorganization done."
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# reorganize dataset for MFA
if
[
!
-d
$EXP_DIR
/thchs30_corpus
]
;
then
echo
"reorganizing thchs30 corpus..."
python
local
/reorganize_thchs30.py
--root-dir
=
data
--output-dir
=
data/thchs30_corpus
--script-type
=
$LEXICON_NAME
echo
"reorganization done."
fi
fi
echo
"THCHS-30 data preparation done."
...
...
examples/thchs30/a0/local/gen_word2phone.py
浏览文件 @
03a50d7b
...
...
@@ -18,6 +18,7 @@ file2: THCHS-30/resource/dict/lexicon.txt
import
argparse
from
collections
import
defaultdict
from
pathlib
import
Path
from
typing
import
List
from
typing
import
Union
# key: (cn, ('ee', 'er4')),value: count
...
...
@@ -34,7 +35,7 @@ def is_Chinese(ch):
return
False
def
proc_line
(
line
):
def
proc_line
(
line
:
str
):
line
=
line
.
strip
()
if
is_Chinese
(
line
[
0
]):
line_list
=
line
.
split
()
...
...
@@ -49,20 +50,25 @@ def proc_line(line):
cn_phones_counter
[(
cn
,
phones
)]
+=
1
def
gen_lexicon
(
root_dir
:
Union
[
str
,
Path
],
output_dir
:
Union
[
str
,
Path
]):
root_dir
=
Path
(
root_dir
).
expanduser
()
output_dir
=
Path
(
output_dir
).
expanduser
()
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
file1
=
root_dir
/
"lm_word_lexicon_1"
file2
=
root_dir
/
"lm_word_lexicon_2"
write_file
=
output_dir
/
"word.lexicon"
"""
example lines of output
the first column is a Chinese character
the second is the probability of this pronunciation
and the rest are the phones of this pronunciation
一 0.22 ii i1↩
一 0.45 ii i4↩
一 0.32 ii i2↩
一 0.01 ii i5
"""
def
gen_lexicon
(
lexicon_files
:
List
[
Union
[
str
,
Path
]],
output_path
:
Union
[
str
,
Path
]):
for
file_path
in
lexicon_files
:
with
open
(
file_path
,
"r"
)
as
f1
:
for
line
in
f1
:
proc_line
(
line
)
with
open
(
file1
,
"r"
)
as
f1
:
for
line
in
f1
:
proc_line
(
line
)
with
open
(
file2
,
"r"
)
as
f2
:
for
line
in
f2
:
proc_line
(
line
)
for
key
in
cn_phones_counter
:
cn
=
key
[
0
]
cn_counter
[
cn
].
append
((
key
[
1
],
cn_phones_counter
[
key
]))
...
...
@@ -75,7 +81,8 @@ def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]):
p
=
round
(
p
,
2
)
if
p
>
0
:
cn_counter_p
[
key
].
append
((
item
[
0
],
p
))
with
open
(
write_file
,
"w"
)
as
wf
:
with
open
(
output_path
,
"w"
)
as
wf
:
for
key
in
cn_counter_p
:
phone_p_list
=
cn_counter_p
[
key
]
for
item
in
phone_p_list
:
...
...
@@ -87,8 +94,21 @@ if __name__ == "__main__":
parser
=
argparse
.
ArgumentParser
(
description
=
"Gen Chinese characters to phone lexicon for THCHS-30 dataset"
)
# A line of word_lexicon:
# 一丁点 ii i4 d ing1 d ian3
# the first is word, and the rest are the phones of the word, and the len of phones is twice of the word's len
parser
.
add_argument
(
"--lexicon-files"
,
type
=
str
,
default
=
"data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2"
,
help
=
"lm_word_lexicon files"
)
parser
.
add_argument
(
"--root-dir"
,
type
=
str
,
help
=
"dir to thchs30 lm_word_lexicons"
)
parser
.
add_argument
(
"--output-dir"
,
type
=
str
,
help
=
"path to save outputs"
)
"--output-path"
,
type
=
str
,
default
=
"data/dict/word.lexicon"
,
help
=
"path to save output word2phone lexicon"
)
args
=
parser
.
parse_args
()
gen_lexicon
(
args
.
root_dir
,
args
.
output_dir
)
lexicon_files
=
args
.
lexicon_files
.
split
(
" "
)
output_path
=
Path
(
args
.
output_path
).
expanduser
()
gen_lexicon
(
lexicon_files
,
output_path
)
examples/thchs30/a0/local/reorganize_thchs30.py
浏览文件 @
03a50d7b
...
...
@@ -58,8 +58,6 @@ def write_lab(root_dir: Union[str, Path],
def
reorganize_thchs30
(
root_dir
:
Union
[
str
,
Path
],
output_dir
:
Union
[
str
,
Path
]
=
None
,
script_type
=
'phone'
):
root_dir
=
Path
(
root_dir
).
expanduser
()
output_dir
=
Path
(
output_dir
).
expanduser
()
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
link_wav
(
root_dir
,
output_dir
)
write_lab
(
root_dir
,
output_dir
,
script_type
)
...
...
@@ -72,12 +70,15 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--output-dir"
,
type
=
str
,
help
=
"path to save outputs(audio and transcriptions)"
)
help
=
"path to save outputs
(audio and transcriptions)"
)
parser
.
add_argument
(
"--script-type"
,
type
=
str
,
default
=
"phone"
,
help
=
"type of lab ('word'/'syllable'/'phone')"
)
args
=
parser
.
parse_args
()
reorganize_thchs30
(
args
.
root_dir
,
args
.
output_dir
,
args
.
script_type
)
root_dir
=
Path
(
args
.
root_dir
).
expanduser
()
output_dir
=
Path
(
args
.
output_dir
).
expanduser
()
reorganize_thchs30
(
root_dir
,
output_dir
,
args
.
script_type
)
examples/thchs30/a0/run.sh
浏览文件 @
03a50d7b
...
...
@@ -14,14 +14,17 @@ source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
# gen lexicon relink gen dump
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
# prepare data
bash ./local/data.sh
$LEXICON_NAME
||
exit
-1
echo
"Start prepare thchs30 data for MFA ..."
bash ./local/data.sh
$LEXICON_NAME
||
exit
-1
fi
# run MFA
if
[
!
-d
"
$EXP_DIR
/thchs30_alignment"
]
;
then
echo
"Start MFA training..."
mfa_train_and_align data/thchs30_corpus data/dict/
$LEXICON_NAME
.lexicon
$EXP_DIR
/thchs30_alignment
-o
$EXP_DIR
/thchs30_model
--clean
--verbose
--temp_directory
exp/.mfa_train_and_align
--num_jobs
$NUM_JOBS
echo
"training done!
\n
results:
$EXP_DIR
/thchs30_alignment
\n
model:
$EXP_DIR
/thchs30_model
\n
"
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# run MFA
if
[
!
-d
"
$EXP_DIR
/thchs30_alignment"
]
;
then
echo
"Start MFA training ..."
mfa_train_and_align data/thchs30_corpus data/dict/
$LEXICON_NAME
.lexicon
$EXP_DIR
/thchs30_alignment
-o
$EXP_DIR
/thchs30_model
--clean
--verbose
--temp_directory
exp/.mfa_train_and_align
--num_jobs
$NUM_JOBS
echo
"MFA training done!
\n
results:
$EXP_DIR
/thchs30_alignment
\n
model:
$EXP_DIR
/thchs30_model
\n
"
fi
fi
...
...
examples/timit/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
...
@@ -27,7 +27,9 @@
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
,
"warp_mode"
:
"PIL"
},
"prob"
:
1.0
}
...
...
examples/timit/s1/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/tiny/s0/conf/augmentation.json
浏览文件 @
03a50d7b
[
{
"type"
:
"speed"
,
"params"
:
{
"min_speed_rate"
:
0.9
,
"max_speed_rate"
:
1.1
,
"num_rates"
:
3
},
"prob"
:
0.0
},
{
"type"
:
"shift"
,
"params"
:
{
...
...
@@ -6,5 +15,22 @@
"max_shift_ms"
:
5
},
"prob"
:
1.0
},
{
"type"
:
"specaug"
,
"params"
:
{
"W"
:
5
,
"warp_mode"
:
"PIL"
,
"F"
:
30
,
"n_freq_masks"
:
2
,
"T"
:
40
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
"prob"
:
1.0
}
]
examples/tiny/s0/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -19,12 +19,22 @@ fi
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
\
--model_type
${
model_type
}
--model_type
${
model_type
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/tiny/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
...
@@ -27,7 +27,9 @@
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
,
"warp_mode"
:
"PIL"
},
"prob"
:
1.0
}
...
...
examples/tiny/s1/local/train.sh
浏览文件 @
03a50d7b
...
...
@@ -18,11 +18,21 @@ fi
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
requirements.txt
浏览文件 @
03a50d7b
coverage
gpustat
jsonlines
kaldiio
Pillow
pre-commit
pybind11
resampy
==0.2.2
...
...
tools/extras/install_mfa.sh
浏览文件 @
03a50d7b
...
...
@@ -4,7 +4,7 @@
test
-d
Montreal-Forced-Aligner
||
git clone https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner.git
pushd
Montreal-Forced-Aligner
&&
git checkout v2.0.0a7
&&
python setup.py
install
pushd
Montreal-Forced-Aligner
&&
python setup.py
install
&&
popd
test
-d
kaldi
||
{
echo
"need install kaldi first"
;
exit
1
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录