Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
03a50d7b
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
03a50d7b
编写于
8月 26, 2021
作者:
H
huangyuxin
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/DeepSpeech
into fix_bug
上级
8d506270
794294e9
变更
83
隐藏空白更改
内联
并排
Showing
83 changed file
with
1776 addition
and
438 deletion
+1776
-438
.flake8
.flake8
+4
-0
deepspeech/__init__.py
deepspeech/__init__.py
+1
-40
deepspeech/decoders/swig/setup.py
deepspeech/decoders/swig/setup.py
+5
-2
deepspeech/exps/deepspeech2/bin/export.py
deepspeech/exps/deepspeech2/bin/export.py
+3
-0
deepspeech/exps/deepspeech2/bin/test.py
deepspeech/exps/deepspeech2/bin/test.py
+3
-0
deepspeech/exps/u2/bin/alignment.py
deepspeech/exps/u2/bin/alignment.py
+3
-0
deepspeech/exps/u2/bin/export.py
deepspeech/exps/u2/bin/export.py
+3
-0
deepspeech/exps/u2/bin/test.py
deepspeech/exps/u2/bin/test.py
+3
-0
deepspeech/exps/u2_kaldi/bin/test.py
deepspeech/exps/u2_kaldi/bin/test.py
+10
-0
deepspeech/exps/u2_kaldi/model.py
deepspeech/exps/u2_kaldi/model.py
+26
-17
deepspeech/exps/u2_st/bin/export.py
deepspeech/exps/u2_st/bin/export.py
+3
-0
deepspeech/exps/u2_st/bin/test.py
deepspeech/exps/u2_st/bin/test.py
+3
-0
deepspeech/frontend/augmentor/impulse_response.py
deepspeech/frontend/augmentor/impulse_response.py
+1
-1
deepspeech/frontend/augmentor/noise_perturb.py
deepspeech/frontend/augmentor/noise_perturb.py
+1
-1
deepspeech/frontend/augmentor/online_bayesian_normalization.py
...peech/frontend/augmentor/online_bayesian_normalization.py
+1
-1
deepspeech/frontend/augmentor/resample.py
deepspeech/frontend/augmentor/resample.py
+1
-1
deepspeech/frontend/augmentor/shift_perturb.py
deepspeech/frontend/augmentor/shift_perturb.py
+1
-1
deepspeech/frontend/augmentor/spec_augment.py
deepspeech/frontend/augmentor/spec_augment.py
+98
-19
deepspeech/frontend/augmentor/speed_perturb.py
deepspeech/frontend/augmentor/speed_perturb.py
+1
-1
deepspeech/frontend/augmentor/volume_perturb.py
deepspeech/frontend/augmentor/volume_perturb.py
+1
-1
deepspeech/frontend/featurizer/__init__.py
deepspeech/frontend/featurizer/__init__.py
+3
-0
deepspeech/frontend/featurizer/audio_featurizer.py
deepspeech/frontend/featurizer/audio_featurizer.py
+48
-38
deepspeech/frontend/featurizer/speech_featurizer.py
deepspeech/frontend/featurizer/speech_featurizer.py
+1
-1
deepspeech/frontend/featurizer/text_featurizer.py
deepspeech/frontend/featurizer/text_featurizer.py
+28
-45
deepspeech/frontend/normalizer.py
deepspeech/frontend/normalizer.py
+7
-7
deepspeech/frontend/utility.py
deepspeech/frontend/utility.py
+40
-10
deepspeech/io/collator.py
deepspeech/io/collator.py
+1
-2
deepspeech/io/collator_st.py
deepspeech/io/collator_st.py
+35
-69
deepspeech/io/converter.py
deepspeech/io/converter.py
+1
-1
deepspeech/io/dataloader.py
deepspeech/io/dataloader.py
+18
-6
deepspeech/io/dataset.py
deepspeech/io/dataset.py
+6
-5
deepspeech/models/ds2_online/deepspeech2.py
deepspeech/models/ds2_online/deepspeech2.py
+16
-14
deepspeech/models/u2_st.py
deepspeech/models/u2_st.py
+7
-7
deepspeech/training/cli.py
deepspeech/training/cli.py
+16
-20
deepspeech/training/extensions/__init__.py
deepspeech/training/extensions/__init__.py
+41
-0
deepspeech/training/extensions/evaluator.py
deepspeech/training/extensions/evaluator.py
+71
-0
deepspeech/training/extensions/extension.py
deepspeech/training/extensions/extension.py
+52
-0
deepspeech/training/extensions/snapshot.py
deepspeech/training/extensions/snapshot.py
+114
-0
deepspeech/training/extensions/visualizer.py
deepspeech/training/extensions/visualizer.py
+37
-0
deepspeech/training/reporter.py
deepspeech/training/reporter.py
+144
-0
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+10
-3
deepspeech/training/triggers/__init__.py
deepspeech/training/triggers/__init__.py
+28
-0
deepspeech/training/triggers/interval_trigger.py
deepspeech/training/triggers/interval_trigger.py
+38
-0
deepspeech/training/triggers/limit_trigger.py
deepspeech/training/triggers/limit_trigger.py
+31
-0
deepspeech/training/triggers/time_trigger.py
deepspeech/training/triggers/time_trigger.py
+32
-0
deepspeech/training/updaters/__init__.py
deepspeech/training/updaters/__init__.py
+13
-0
deepspeech/training/updaters/standard_updater.py
deepspeech/training/updaters/standard_updater.py
+192
-0
deepspeech/training/updaters/trainer.py
deepspeech/training/updaters/trainer.py
+184
-0
deepspeech/training/updaters/updater.py
deepspeech/training/updaters/updater.py
+83
-0
deepspeech/utils/utility.py
deepspeech/utils/utility.py
+11
-1
examples/aishell/s0/README.md
examples/aishell/s0/README.md
+7
-1
examples/aishell/s0/conf/augmentation.json
examples/aishell/s0/conf/augmentation.json
+5
-3
examples/aishell/s0/local/train.sh
examples/aishell/s0/local/train.sh
+11
-1
examples/aishell/s1/conf/augmentation.json
examples/aishell/s1/conf/augmentation.json
+3
-1
examples/aishell/s1/local/train.sh
examples/aishell/s1/local/train.sh
+11
-1
examples/aug_conf/augmentation.json
examples/aug_conf/augmentation.json
+0
-10
examples/augmentation/augmentation.json
examples/augmentation/augmentation.json
+6
-4
examples/callcenter/s1/conf/augmentation.json
examples/callcenter/s1/conf/augmentation.json
+2
-1
examples/callcenter/s1/local/train.sh
examples/callcenter/s1/local/train.sh
+11
-1
examples/librispeech/s0/conf/augmentation.json
examples/librispeech/s0/conf/augmentation.json
+5
-3
examples/librispeech/s0/local/train.sh
examples/librispeech/s0/local/train.sh
+11
-1
examples/librispeech/s1/conf/augmentation.json
examples/librispeech/s1/conf/augmentation.json
+3
-1
examples/librispeech/s1/local/train.sh
examples/librispeech/s1/local/train.sh
+11
-1
examples/librispeech/s2/conf/augmentation.json
examples/librispeech/s2/conf/augmentation.json
+6
-4
examples/librispeech/s2/conf/transformer.yaml
examples/librispeech/s2/conf/transformer.yaml
+15
-24
examples/librispeech/s2/local/align.sh
examples/librispeech/s2/local/align.sh
+8
-5
examples/librispeech/s2/local/export.sh
examples/librispeech/s2/local/export.sh
+2
-1
examples/librispeech/s2/local/test.sh
examples/librispeech/s2/local/test.sh
+12
-7
examples/librispeech/s2/local/train.sh
examples/librispeech/s2/local/train.sh
+11
-1
examples/librispeech/s2/run.sh
examples/librispeech/s2/run.sh
+3
-2
examples/ted_en_zh/t0/local/train.sh
examples/ted_en_zh/t0/local/train.sh
+11
-1
examples/thchs30/a0/local/data.sh
examples/thchs30/a0/local/data.sh
+22
-16
examples/thchs30/a0/local/gen_word2phone.py
examples/thchs30/a0/local/gen_word2phone.py
+38
-18
examples/thchs30/a0/local/reorganize_thchs30.py
examples/thchs30/a0/local/reorganize_thchs30.py
+5
-4
examples/thchs30/a0/run.sh
examples/thchs30/a0/run.sh
+9
-6
examples/timit/s1/conf/augmentation.json
examples/timit/s1/conf/augmentation.json
+3
-1
examples/timit/s1/local/train.sh
examples/timit/s1/local/train.sh
+11
-1
examples/tiny/s0/conf/augmentation.json
examples/tiny/s0/conf/augmentation.json
+26
-0
examples/tiny/s0/local/train.sh
examples/tiny/s0/local/train.sh
+11
-1
examples/tiny/s1/conf/augmentation.json
examples/tiny/s1/conf/augmentation.json
+3
-1
examples/tiny/s1/local/train.sh
examples/tiny/s1/local/train.sh
+11
-1
requirements.txt
requirements.txt
+2
-0
tools/extras/install_mfa.sh
tools/extras/install_mfa.sh
+1
-1
未找到文件。
.flake8
浏览文件 @
03a50d7b
...
@@ -42,6 +42,10 @@ ignore =
...
@@ -42,6 +42,10 @@ ignore =
# these ignores are from flake8-comprehensions; please fix!
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
per-file-ignores =
*/__init__.py: F401
# Specify the list of error codes you wish Flake8 to report.
# Specify the list of error codes you wish Flake8 to report.
select =
select =
E,
E,
...
...
deepspeech/__init__.py
浏览文件 @
03a50d7b
...
@@ -352,45 +352,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
...
@@ -352,45 +352,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!"
)
"register user tolist to paddle.Tensor, remove this when fixed!"
)
setattr
(
paddle
.
Tensor
,
'tolist'
,
tolist
)
setattr
(
paddle
.
Tensor
,
'tolist'
,
tolist
)
########### hcak paddle.nn.functional #############
def
glu
(
x
:
paddle
.
Tensor
,
axis
=-
1
)
->
paddle
.
Tensor
:
"""The gated linear unit (GLU) activation."""
a
,
b
=
x
.
split
(
2
,
axis
=
axis
)
act_b
=
F
.
sigmoid
(
b
)
return
a
*
act_b
if
not
hasattr
(
paddle
.
nn
.
functional
,
'glu'
):
logger
.
warn
(
"register user glu to paddle.nn.functional, remove this when fixed!"
)
setattr
(
paddle
.
nn
.
functional
,
'glu'
,
glu
)
# def softplus(x):
# """Softplus function."""
# if hasattr(paddle.nn.functional, 'softplus'):
# #return paddle.nn.functional.softplus(x.float()).type_as(x)
# return paddle.nn.functional.softplus(x)
# else:
# raise NotImplementedError
# def gelu_accurate(x):
# """Gaussian Error Linear Units (GELU) activation."""
# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
# if not hasattr(gelu_accurate, "_a"):
# gelu_accurate._a = math.sqrt(2 / math.pi)
# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
# (x + 0.044715 * paddle.pow(x, 3))))
# def gelu(x):
# """Gaussian Error Linear Units (GELU) activation."""
# if hasattr(nn.functional, 'gelu'):
# #return nn.functional.gelu(x.float()).type_as(x)
# return nn.functional.gelu(x)
# else:
# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
########### hcak paddle.nn #############
########### hcak paddle.nn #############
class
GLU
(
nn
.
Layer
):
class
GLU
(
nn
.
Layer
):
...
@@ -401,7 +362,7 @@ class GLU(nn.Layer):
...
@@ -401,7 +362,7 @@ class GLU(nn.Layer):
self
.
dim
=
dim
self
.
dim
=
dim
def
forward
(
self
,
xs
):
def
forward
(
self
,
xs
):
return
glu
(
xs
,
dim
=
self
.
dim
)
return
F
.
glu
(
xs
,
dim
=
self
.
dim
)
if
not
hasattr
(
paddle
.
nn
,
'GLU'
):
if
not
hasattr
(
paddle
.
nn
,
'GLU'
):
...
...
deepspeech/decoders/swig/setup.py
浏览文件 @
03a50d7b
...
@@ -83,10 +83,13 @@ FILES = glob.glob('kenlm/util/*.cc') \
...
@@ -83,10 +83,13 @@ FILES = glob.glob('kenlm/util/*.cc') \
FILES
+=
glob
.
glob
(
'openfst-1.6.3/src/lib/*.cc'
)
FILES
+=
glob
.
glob
(
'openfst-1.6.3/src/lib/*.cc'
)
# yapf: disable
FILES
=
[
FILES
=
[
fn
for
fn
in
FILES
if
not
(
fn
.
endswith
(
'main.cc'
)
or
fn
.
endswith
(
'test.cc'
)
fn
for
fn
in
FILES
or
fn
.
endswith
(
'unittest.cc'
))
if
not
(
fn
.
endswith
(
'main.cc'
)
or
fn
.
endswith
(
'test.cc'
)
or
fn
.
endswith
(
'unittest.cc'
))
]
]
# yapf: enable
LIBS
=
[
'stdc++'
]
LIBS
=
[
'stdc++'
]
if
platform
.
system
()
!=
'Darwin'
:
if
platform
.
system
()
!=
'Darwin'
:
...
...
deepspeech/exps/deepspeech2/bin/export.py
浏览文件 @
03a50d7b
...
@@ -30,6 +30,9 @@ def main(config, args):
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
=
default_argument_parser
()
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
parser
.
add_argument
(
"--model_type"
)
parser
.
add_argument
(
"--model_type"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
model_type
is
None
:
if
args
.
model_type
is
None
:
...
...
deepspeech/exps/deepspeech2/bin/test.py
浏览文件 @
03a50d7b
...
@@ -31,6 +31,9 @@ def main(config, args):
...
@@ -31,6 +31,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
=
default_argument_parser
()
parser
.
add_argument
(
"--model_type"
)
parser
.
add_argument
(
"--model_type"
)
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
print_arguments
(
args
,
globals
())
if
args
.
model_type
is
None
:
if
args
.
model_type
is
None
:
...
...
deepspeech/exps/u2/bin/alignment.py
浏览文件 @
03a50d7b
...
@@ -30,6 +30,9 @@ def main(config, args):
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
=
default_argument_parser
()
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2/bin/export.py
浏览文件 @
03a50d7b
...
@@ -30,6 +30,9 @@ def main(config, args):
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
=
default_argument_parser
()
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2/bin/test.py
浏览文件 @
03a50d7b
...
@@ -34,6 +34,9 @@ def main(config, args):
...
@@ -34,6 +34,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
=
default_argument_parser
()
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2_kaldi/bin/test.py
浏览文件 @
03a50d7b
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
"""Evaluation for U2 model."""
"""Evaluation for U2 model."""
import
cProfile
import
cProfile
from
yacs.config
import
CfgNode
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.utility
import
print_arguments
...
@@ -53,6 +55,14 @@ if __name__ == "__main__":
...
@@ -53,6 +55,14 @@ if __name__ == "__main__":
type
=
str
,
type
=
str
,
default
=
'test'
,
default
=
'test'
,
help
=
'run mode, e.g. test, align, export'
)
help
=
'run mode, e.g. test, align, export'
)
parser
.
add_argument
(
'--dict-path'
,
type
=
str
,
default
=
None
,
help
=
'dict path.'
)
# save asr result to
parser
.
add_argument
(
"--result-file"
,
type
=
str
,
help
=
"path of save the asr result"
)
# save jit model to
parser
.
add_argument
(
"--export-path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2_kaldi/model.py
浏览文件 @
03a50d7b
...
@@ -25,6 +25,8 @@ import paddle
...
@@ -25,6 +25,8 @@ import paddle
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
deepspeech.frontend.featurizer
import
TextFeaturizer
from
deepspeech.frontend.utility
import
load_dict
from
deepspeech.io.dataloader
import
BatchDataLoader
from
deepspeech.io.dataloader
import
BatchDataLoader
from
deepspeech.models.u2
import
U2Model
from
deepspeech.models.u2
import
U2Model
from
deepspeech.training.optimizer
import
OptimizerFactory
from
deepspeech.training.optimizer
import
OptimizerFactory
...
@@ -80,8 +82,8 @@ class U2Trainer(Trainer):
...
@@ -80,8 +82,8 @@ class U2Trainer(Trainer):
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
train_conf
=
self
.
config
.
training
train_conf
=
self
.
config
.
training
start
=
time
.
time
()
start
=
time
.
time
()
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
text_len
)
text_len
)
# loss div by `batch_size * accum_grad`
# loss div by `batch_size * accum_grad`
...
@@ -124,6 +126,7 @@ class U2Trainer(Trainer):
...
@@ -124,6 +126,7 @@ class U2Trainer(Trainer):
valid_losses
=
defaultdict
(
list
)
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
num_seen_utts
=
1
total_loss
=
0.0
total_loss
=
0.0
for
i
,
batch
in
enumerate
(
self
.
valid_loader
):
for
i
,
batch
in
enumerate
(
self
.
valid_loader
):
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
...
@@ -168,10 +171,7 @@ class U2Trainer(Trainer):
...
@@ -168,10 +171,7 @@ class U2Trainer(Trainer):
if
from_scratch
:
if
from_scratch
:
# save init model, i.e. 0 epoch
# save init model, i.e. 0 epoch
self
.
save
(
tag
=
'init'
)
self
.
save
(
tag
=
'init'
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
if
self
.
parallel
:
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
...
@@ -225,7 +225,7 @@ class U2Trainer(Trainer):
...
@@ -225,7 +225,7 @@ class U2Trainer(Trainer):
maxlen_in
=
float
(
'inf'
),
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
minibatches
=
0
,
mini_batch_size
=
1
,
mini_batch_size
=
self
.
args
.
nprocs
,
batch_count
=
'auto'
,
batch_count
=
'auto'
,
batch_bins
=
0
,
batch_bins
=
0
,
batch_frames_in
=
0
,
batch_frames_in
=
0
,
...
@@ -244,7 +244,7 @@ class U2Trainer(Trainer):
...
@@ -244,7 +244,7 @@ class U2Trainer(Trainer):
maxlen_in
=
float
(
'inf'
),
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
minibatches
=
0
,
mini_batch_size
=
1
,
mini_batch_size
=
self
.
args
.
nprocs
,
batch_count
=
'auto'
,
batch_count
=
'auto'
,
batch_bins
=
0
,
batch_bins
=
0
,
batch_frames_in
=
0
,
batch_frames_in
=
0
,
...
@@ -260,7 +260,7 @@ class U2Trainer(Trainer):
...
@@ -260,7 +260,7 @@ class U2Trainer(Trainer):
json_file
=
config
.
data
.
test_manifest
,
json_file
=
config
.
data
.
test_manifest
,
train_mode
=
False
,
train_mode
=
False
,
sortagrad
=
False
,
sortagrad
=
False
,
batch_size
=
config
.
collator
.
batch_size
,
batch_size
=
config
.
decoding
.
batch_size
,
maxlen_in
=
float
(
'inf'
),
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
minibatches
=
0
,
...
@@ -279,7 +279,7 @@ class U2Trainer(Trainer):
...
@@ -279,7 +279,7 @@ class U2Trainer(Trainer):
json_file
=
config
.
data
.
test_manifest
,
json_file
=
config
.
data
.
test_manifest
,
train_mode
=
False
,
train_mode
=
False
,
sortagrad
=
False
,
sortagrad
=
False
,
batch_size
=
config
.
collator
.
batch_size
,
batch_size
=
config
.
decoding
.
batch_size
,
maxlen_in
=
float
(
'inf'
),
maxlen_in
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
maxlen_out
=
float
(
'inf'
),
minibatches
=
0
,
minibatches
=
0
,
...
@@ -305,10 +305,8 @@ class U2Trainer(Trainer):
...
@@ -305,10 +305,8 @@ class U2Trainer(Trainer):
model_conf
.
output_dim
=
self
.
train_loader
.
vocab_size
model_conf
.
output_dim
=
self
.
train_loader
.
vocab_size
model_conf
.
freeze
()
model_conf
.
freeze
()
model
=
U2Model
.
from_config
(
model_conf
)
model
=
U2Model
.
from_config
(
model_conf
)
if
self
.
parallel
:
if
self
.
parallel
:
model
=
paddle
.
DataParallel
(
model
)
model
=
paddle
.
DataParallel
(
model
)
logger
.
info
(
f
"
{
model
}
"
)
logger
.
info
(
f
"
{
model
}
"
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
...
@@ -379,13 +377,13 @@ class U2Tester(U2Trainer):
...
@@ -379,13 +377,13 @@ class U2Tester(U2Trainer):
def
__init__
(
self
,
config
,
args
):
def
__init__
(
self
,
config
,
args
):
super
().
__init__
(
config
,
args
)
super
().
__init__
(
config
,
args
)
def
ordid2token
(
self
,
texts
,
texts_len
):
def
id2token
(
self
,
texts
,
texts_len
,
text_feature
):
""" ord() id to chr() chr """
""" ord() id to chr() chr """
trans
=
[]
trans
=
[]
for
text
,
n
in
zip
(
texts
,
texts_len
):
for
text
,
n
in
zip
(
texts
,
texts_len
):
n
=
n
.
numpy
().
item
()
n
=
n
.
numpy
().
item
()
ids
=
text
[:
n
]
ids
=
text
[:
n
]
trans
.
append
(
''
.
join
([
chr
(
i
)
for
i
in
ids
]
))
trans
.
append
(
text_feature
.
defeaturize
(
ids
.
numpy
().
tolist
()
))
return
trans
return
trans
def
compute_metrics
(
self
,
def
compute_metrics
(
self
,
...
@@ -401,8 +399,11 @@ class U2Tester(U2Trainer):
...
@@ -401,8 +399,11 @@ class U2Tester(U2Trainer):
error_rate_func
=
error_rate
.
cer
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
wer
error_rate_func
=
error_rate
.
cer
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
wer
start_time
=
time
.
time
()
start_time
=
time
.
time
()
text_feature
=
self
.
test_loader
.
collate_fn
.
text_feature
text_feature
=
TextFeaturizer
(
target_transcripts
=
self
.
ordid2token
(
texts
,
texts_len
)
unit_type
=
self
.
config
.
collator
.
unit_type
,
vocab_filepath
=
self
.
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
self
.
config
.
collator
.
spm_model_prefix
)
target_transcripts
=
self
.
id2token
(
texts
,
texts_len
,
text_feature
)
result_transcripts
=
self
.
model
.
decode
(
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio
,
audio_len
,
audio_len
,
...
@@ -450,7 +451,7 @@ class U2Tester(U2Trainer):
...
@@ -450,7 +451,7 @@ class U2Tester(U2Trainer):
self
.
model
.
eval
()
self
.
model
.
eval
()
logger
.
info
(
f
"Test Total Examples:
{
len
(
self
.
test_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Test Total Examples:
{
len
(
self
.
test_loader
.
dataset
)
}
"
)
stride_ms
=
self
.
test_loader
.
collate_fn
.
stride_ms
stride_ms
=
self
.
config
.
collator
.
stride_ms
error_rate_type
=
None
error_rate_type
=
None
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
num_frames
=
0.0
num_frames
=
0.0
...
@@ -525,8 +526,9 @@ class U2Tester(U2Trainer):
...
@@ -525,8 +526,9 @@ class U2Tester(U2Trainer):
self
.
model
.
eval
()
self
.
model
.
eval
()
logger
.
info
(
f
"Align Total Examples:
{
len
(
self
.
align_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Align Total Examples:
{
len
(
self
.
align_loader
.
dataset
)
}
"
)
stride_ms
=
self
.
config
.
collate
.
stride_ms
stride_ms
=
self
.
config
.
collater
.
stride_ms
token_dict
=
self
.
align_loader
.
collate_fn
.
vocab_list
token_dict
=
self
.
args
.
char_list
with
open
(
self
.
args
.
result_file
,
'w'
)
as
fout
:
with
open
(
self
.
args
.
result_file
,
'w'
)
as
fout
:
# one example in batch
# one example in batch
for
i
,
batch
in
enumerate
(
self
.
align_loader
):
for
i
,
batch
in
enumerate
(
self
.
align_loader
):
...
@@ -613,6 +615,11 @@ class U2Tester(U2Trainer):
...
@@ -613,6 +615,11 @@ class U2Tester(U2Trainer):
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
def
setup_dict
(
self
):
# load dictionary for debug log
self
.
args
.
char_list
=
load_dict
(
self
.
args
.
dict_path
,
"maskctc"
in
self
.
args
.
model_name
)
def
setup
(
self
):
def
setup
(
self
):
"""Setup the experiment.
"""Setup the experiment.
"""
"""
...
@@ -624,6 +631,8 @@ class U2Tester(U2Trainer):
...
@@ -624,6 +631,8 @@ class U2Tester(U2Trainer):
self
.
setup_dataloader
()
self
.
setup_dataloader
()
self
.
setup_model
()
self
.
setup_model
()
self
.
setup_dict
()
self
.
iteration
=
0
self
.
iteration
=
0
self
.
epoch
=
0
self
.
epoch
=
0
...
...
deepspeech/exps/u2_st/bin/export.py
浏览文件 @
03a50d7b
...
@@ -30,6 +30,9 @@ def main(config, args):
...
@@ -30,6 +30,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
=
default_argument_parser
()
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
print_arguments
(
args
,
globals
())
...
...
deepspeech/exps/u2_st/bin/test.py
浏览文件 @
03a50d7b
...
@@ -34,6 +34,9 @@ def main(config, args):
...
@@ -34,6 +34,9 @@ def main(config, args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
default_argument_parser
()
parser
=
default_argument_parser
()
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
print_arguments
(
args
,
globals
())
...
...
deepspeech/frontend/augmentor/impulse_response.py
浏览文件 @
03a50d7b
...
@@ -32,7 +32,7 @@ class ImpulseResponseAugmentor(AugmentorBase):
...
@@ -32,7 +32,7 @@ class ImpulseResponseAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
self
.
transform_audio
(
x
)
return
x
return
x
...
...
deepspeech/frontend/augmentor/noise_perturb.py
浏览文件 @
03a50d7b
...
@@ -38,7 +38,7 @@ class NoisePerturbAugmentor(AugmentorBase):
...
@@ -38,7 +38,7 @@ class NoisePerturbAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
self
.
transform_audio
(
x
)
return
x
return
x
...
...
deepspeech/frontend/augmentor/online_bayesian_normalization.py
浏览文件 @
03a50d7b
...
@@ -46,7 +46,7 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
...
@@ -46,7 +46,7 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
self
.
transform_audio
(
x
)
return
x
return
x
...
...
deepspeech/frontend/augmentor/resample.py
浏览文件 @
03a50d7b
...
@@ -33,7 +33,7 @@ class ResampleAugmentor(AugmentorBase):
...
@@ -33,7 +33,7 @@ class ResampleAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
self
.
transform_audio
(
x
)
return
x
return
x
...
...
deepspeech/frontend/augmentor/shift_perturb.py
浏览文件 @
03a50d7b
...
@@ -33,7 +33,7 @@ class ShiftPerturbAugmentor(AugmentorBase):
...
@@ -33,7 +33,7 @@ class ShiftPerturbAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
self
.
transform_audio
(
x
)
return
x
return
x
...
...
deepspeech/frontend/augmentor/spec_augment.py
浏览文件 @
03a50d7b
...
@@ -12,7 +12,11 @@
...
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Contains the volume perturb augmentation model."""
"""Contains the volume perturb augmentation model."""
import
random
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL.Image
import
BICUBIC
from
deepspeech.frontend.augmentor.base
import
AugmentorBase
from
deepspeech.frontend.augmentor.base
import
AugmentorBase
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.log
import
Log
...
@@ -41,7 +45,9 @@ class SpecAugmentor(AugmentorBase):
...
@@ -41,7 +45,9 @@ class SpecAugmentor(AugmentorBase):
W
=
40
,
W
=
40
,
adaptive_number_ratio
=
0
,
adaptive_number_ratio
=
0
,
adaptive_size_ratio
=
0
,
adaptive_size_ratio
=
0
,
max_n_time_masks
=
20
):
max_n_time_masks
=
20
,
replace_with_zero
=
True
,
warp_mode
=
'PIL'
):
"""SpecAugment class.
"""SpecAugment class.
Args:
Args:
rng (random.Random): random generator object.
rng (random.Random): random generator object.
...
@@ -54,10 +60,16 @@ class SpecAugmentor(AugmentorBase):
...
@@ -54,10 +60,16 @@ class SpecAugmentor(AugmentorBase):
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
adaptive_size_ratio (float): adaptive size ratio for time masking
adaptive_size_ratio (float): adaptive size ratio for time masking
max_n_time_masks (int): maximum number of time masking
max_n_time_masks (int): maximum number of time masking
replace_with_zero (bool): pad zero on mask if true else use mean
warp_mode (str): "PIL" (default, fast, not differentiable)
or "sparse_image_warp" (slow, differentiable)
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
_rng
=
rng
self
.
_rng
=
rng
self
.
inplace
=
True
self
.
replace_with_zero
=
replace_with_zero
self
.
mode
=
warp_mode
self
.
W
=
W
self
.
W
=
W
self
.
F
=
F
self
.
F
=
F
self
.
T
=
T
self
.
T
=
T
...
@@ -123,21 +135,83 @@ class SpecAugmentor(AugmentorBase):
...
@@ -123,21 +135,83 @@ class SpecAugmentor(AugmentorBase):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
"specaug: F-
{
F
}
, T-
{
T
}
, F-n-
{
n_freq_masks
}
, T-n-
{
n_time_masks
}
"
return
f
"specaug: F-
{
F
}
, T-
{
T
}
, F-n-
{
n_freq_masks
}
, T-n-
{
n_time_masks
}
"
def
time_warp
(
xs
,
W
=
40
):
def
time_warp
(
self
,
x
,
mode
=
'PIL'
):
raise
NotImplementedError
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
Args:
x (np.ndarray): spectrogram (time, freq)
mode (str): PIL or sparse_image_warp
Raises:
NotImplementedError: [description]
NotImplementedError: [description]
Returns:
np.ndarray: time warped spectrogram (time, freq)
"""
window
=
max_time_warp
=
self
.
W
if
window
==
0
:
return
x
if
mode
==
"PIL"
:
t
=
x
.
shape
[
0
]
if
t
-
window
<=
window
:
return
x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center
=
random
.
randrange
(
window
,
t
-
window
)
warped
=
random
.
randrange
(
center
-
window
,
center
+
window
)
+
1
# 1 ... t - 1
left
=
Image
.
fromarray
(
x
[:
center
]).
resize
((
x
.
shape
[
1
],
warped
),
BICUBIC
)
right
=
Image
.
fromarray
(
x
[
center
:]).
resize
((
x
.
shape
[
1
],
t
-
warped
),
BICUBIC
)
if
self
.
inplace
:
x
[:
warped
]
=
left
x
[
warped
:]
=
right
return
x
return
np
.
concatenate
((
left
,
right
),
0
)
elif
mode
==
"sparse_image_warp"
:
raise
NotImplementedError
(
'sparse_image_warp'
)
else
:
raise
NotImplementedError
(
"unknown resize mode: "
+
mode
+
", choose one from (PIL, sparse_image_warp)."
)
def
mask_freq
(
self
,
x
,
replace_with_zero
=
False
):
"""freq mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
def
mask_freq
(
self
,
xs
,
replace_with_zero
=
False
):
Returns:
n_bins
=
xs
.
shape
[
0
]
np.ndarray: freq mask spectrogram (time, freq)
"""
n_bins
=
x
.
shape
[
1
]
for
i
in
range
(
0
,
self
.
n_freq_masks
):
for
i
in
range
(
0
,
self
.
n_freq_masks
):
f
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
self
.
F
))
f
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
self
.
F
))
f_0
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
n_bins
-
f
))
f_0
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
n_bins
-
f
))
xs
[
f_0
:
f_0
+
f
,
:]
=
0
assert
f_0
<=
f_0
+
f
assert
f_0
<=
f_0
+
f
if
replace_with_zero
:
x
[:,
f_0
:
f_0
+
f
]
=
0
else
:
x
[:,
f_0
:
f_0
+
f
]
=
x
.
mean
()
self
.
_freq_mask
=
(
f_0
,
f_0
+
f
)
self
.
_freq_mask
=
(
f_0
,
f_0
+
f
)
return
x
s
return
x
def
mask_time
(
self
,
xs
,
replace_with_zero
=
False
):
def
mask_time
(
self
,
x
,
replace_with_zero
=
False
):
n_frames
=
xs
.
shape
[
1
]
"""time mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
Returns:
np.ndarray: time mask spectrogram (time, freq)
"""
n_frames
=
x
.
shape
[
0
]
if
self
.
adaptive_number_ratio
>
0
:
if
self
.
adaptive_number_ratio
>
0
:
n_masks
=
int
(
n_frames
*
self
.
adaptive_number_ratio
)
n_masks
=
int
(
n_frames
*
self
.
adaptive_number_ratio
)
...
@@ -154,24 +228,29 @@ class SpecAugmentor(AugmentorBase):
...
@@ -154,24 +228,29 @@ class SpecAugmentor(AugmentorBase):
t
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
T
))
t
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
T
))
t
=
min
(
t
,
int
(
n_frames
*
self
.
p
))
t
=
min
(
t
,
int
(
n_frames
*
self
.
p
))
t_0
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
n_frames
-
t
))
t_0
=
int
(
self
.
_rng
.
uniform
(
low
=
0
,
high
=
n_frames
-
t
))
xs
[:,
t_0
:
t_0
+
t
]
=
0
assert
t_0
<=
t_0
+
t
assert
t_0
<=
t_0
+
t
if
replace_with_zero
:
x
[
t_0
:
t_0
+
t
,
:]
=
0
else
:
x
[
t_0
:
t_0
+
t
,
:]
=
x
.
mean
()
self
.
_time_mask
=
(
t_0
,
t_0
+
t
)
self
.
_time_mask
=
(
t_0
,
t_0
+
t
)
return
x
s
return
x
def
__call__
(
self
,
x
,
train
=
True
):
def
__call__
(
self
,
x
,
train
=
True
):
if
not
train
:
if
not
train
:
return
return
x
return
self
.
transform_feature
(
x
)
return
self
.
transform_feature
(
x
)
def
transform_feature
(
self
,
x
s
:
np
.
ndarray
):
def
transform_feature
(
self
,
x
:
np
.
ndarray
):
"""
"""
Args:
Args:
x
s (FloatTensor): `[F, T
]`
x
(np.ndarray): `[T, F
]`
Returns:
Returns:
x
s (FloatTensor): `[F, T
]`
x
(np.ndarray): `[T, F
]`
"""
"""
# xs = self.time_warp(xs)
assert
isinstance
(
x
,
np
.
ndarray
)
xs
=
self
.
mask_freq
(
xs
)
assert
x
.
ndim
==
2
xs
=
self
.
mask_time
(
xs
)
x
=
self
.
time_warp
(
x
,
self
.
mode
)
return
xs
x
=
self
.
mask_freq
(
x
,
self
.
replace_with_zero
)
x
=
self
.
mask_time
(
x
,
self
.
replace_with_zero
)
return
x
deepspeech/frontend/augmentor/speed_perturb.py
浏览文件 @
03a50d7b
...
@@ -81,7 +81,7 @@ class SpeedPerturbAugmentor(AugmentorBase):
...
@@ -81,7 +81,7 @@ class SpeedPerturbAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
self
.
transform_audio
(
x
)
return
x
return
x
...
...
deepspeech/frontend/augmentor/volume_perturb.py
浏览文件 @
03a50d7b
...
@@ -39,7 +39,7 @@ class VolumePerturbAugmentor(AugmentorBase):
...
@@ -39,7 +39,7 @@ class VolumePerturbAugmentor(AugmentorBase):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
if
not
train
:
return
return
x
self
.
transform_audio
(
x
)
self
.
transform_audio
(
x
)
return
x
return
x
...
...
deepspeech/frontend/featurizer/__init__.py
浏览文件 @
03a50d7b
...
@@ -11,3 +11,6 @@
...
@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.audio_featurizer
import
AudioFeaturizer
#noqa: F401
from
.speech_featurizer
import
SpeechFeaturizer
from
.text_featurizer
import
TextFeaturizer
deepspeech/frontend/featurizer/audio_featurizer.py
浏览文件 @
03a50d7b
...
@@ -18,7 +18,7 @@ from python_speech_features import logfbank
...
@@ -18,7 +18,7 @@ from python_speech_features import logfbank
from
python_speech_features
import
mfcc
from
python_speech_features
import
mfcc
class
AudioFeaturizer
(
object
):
class
AudioFeaturizer
():
"""Audio featurizer, for extracting features from audio contents of
"""Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment.
AudioSegment or SpeechSegment.
...
@@ -167,32 +167,6 @@ class AudioFeaturizer(object):
...
@@ -167,32 +167,6 @@ class AudioFeaturizer(object):
raise
ValueError
(
"Unknown specgram_type %s. "
raise
ValueError
(
"Unknown specgram_type %s. "
"Supported values: linear."
%
self
.
_specgram_type
)
"Supported values: linear."
%
self
.
_specgram_type
)
def
_compute_linear_specgram
(
self
,
samples
,
sample_rate
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
eps
=
1e-14
):
"""Compute the linear spectrogram from FFT energy."""
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
raise
ValueError
(
"max_freq must not be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
specgram
,
freqs
=
self
.
_specgram_real
(
samples
,
window_size
=
window_size
,
stride_size
=
stride_size
,
sample_rate
=
sample_rate
)
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
return
np
.
log
(
specgram
[:
ind
,
:]
+
eps
)
def
_specgram_real
(
self
,
samples
,
window_size
,
stride_size
,
sample_rate
):
def
_specgram_real
(
self
,
samples
,
window_size
,
stride_size
,
sample_rate
):
"""Compute the spectrogram for samples from a real signal."""
"""Compute the spectrogram for samples from a real signal."""
# extract strided windows
# extract strided windows
...
@@ -217,26 +191,65 @@ class AudioFeaturizer(object):
...
@@ -217,26 +191,65 @@ class AudioFeaturizer(object):
freqs
=
float
(
sample_rate
)
/
window_size
*
np
.
arange
(
fft
.
shape
[
0
])
freqs
=
float
(
sample_rate
)
/
window_size
*
np
.
arange
(
fft
.
shape
[
0
])
return
fft
,
freqs
return
fft
,
freqs
def
_compute_linear_specgram
(
self
,
samples
,
sample_rate
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
eps
=
1e-14
):
"""Compute the linear spectrogram from FFT energy.
Args:
samples ([type]): [description]
sample_rate ([type]): [description]
stride_ms (float, optional): [description]. Defaults to 10.0.
window_ms (float, optional): [description]. Defaults to 20.0.
max_freq ([type], optional): [description]. Defaults to None.
eps ([type], optional): [description]. Defaults to 1e-14.
Raises:
ValueError: [description]
ValueError: [description]
Returns:
np.ndarray: log spectrogram, (time, freq)
"""
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
raise
ValueError
(
"max_freq must not be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
specgram
,
freqs
=
self
.
_specgram_real
(
samples
,
window_size
=
window_size
,
stride_size
=
stride_size
,
sample_rate
=
sample_rate
)
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
# (freq, time)
spec
=
np
.
log
(
specgram
[:
ind
,
:]
+
eps
)
return
np
.
transpose
(
spec
)
def
_concat_delta_delta
(
self
,
feat
):
def
_concat_delta_delta
(
self
,
feat
):
"""append delat, delta-delta feature.
"""append delat, delta-delta feature.
Args:
Args:
feat (np.ndarray): (
D, T
)
feat (np.ndarray): (
T, D
)
Returns:
Returns:
np.ndarray: feat with delta-delta, (
3*D, T
)
np.ndarray: feat with delta-delta, (
T, 3*D
)
"""
"""
feat
=
np
.
transpose
(
feat
)
# Deltas
# Deltas
d_feat
=
delta
(
feat
,
2
)
d_feat
=
delta
(
feat
,
2
)
# Deltas-Deltas
# Deltas-Deltas
dd_feat
=
delta
(
feat
,
2
)
dd_feat
=
delta
(
feat
,
2
)
# transpose
feat
=
np
.
transpose
(
feat
)
d_feat
=
np
.
transpose
(
d_feat
)
dd_feat
=
np
.
transpose
(
dd_feat
)
# concat above three features
# concat above three features
concat_feat
=
np
.
concatenate
((
feat
,
d_feat
,
dd_feat
))
concat_feat
=
np
.
concatenate
((
feat
,
d_feat
,
dd_feat
)
,
axis
=
1
)
return
concat_feat
return
concat_feat
def
_compute_mfcc
(
self
,
def
_compute_mfcc
(
self
,
...
@@ -292,7 +305,6 @@ class AudioFeaturizer(object):
...
@@ -292,7 +305,6 @@ class AudioFeaturizer(object):
ceplifter
=
22
,
ceplifter
=
22
,
useEnergy
=
True
,
useEnergy
=
True
,
winfunc
=
'povey'
)
winfunc
=
'povey'
)
mfcc_feat
=
np
.
transpose
(
mfcc_feat
)
if
delta_delta
:
if
delta_delta
:
mfcc_feat
=
self
.
_concat_delta_delta
(
mfcc_feat
)
mfcc_feat
=
self
.
_concat_delta_delta
(
mfcc_feat
)
return
mfcc_feat
return
mfcc_feat
...
@@ -346,8 +358,6 @@ class AudioFeaturizer(object):
...
@@ -346,8 +358,6 @@ class AudioFeaturizer(object):
remove_dc_offset
=
True
,
remove_dc_offset
=
True
,
preemph
=
0.97
,
preemph
=
0.97
,
wintype
=
'povey'
)
wintype
=
'povey'
)
fbank_feat
=
np
.
transpose
(
fbank_feat
)
if
delta_delta
:
if
delta_delta
:
fbank_feat
=
self
.
_concat_delta_delta
(
fbank_feat
)
fbank_feat
=
self
.
_concat_delta_delta
(
fbank_feat
)
return
fbank_feat
return
fbank_feat
deepspeech/frontend/featurizer/speech_featurizer.py
浏览文件 @
03a50d7b
...
@@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
...
@@ -16,7 +16,7 @@ from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
from
deepspeech.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
deepspeech.frontend.featurizer.text_featurizer
import
TextFeaturizer
class
SpeechFeaturizer
(
object
):
class
SpeechFeaturizer
():
"""Speech featurizer, for extracting features from both audio and transcript
"""Speech featurizer, for extracting features from both audio and transcript
contents of SpeechSegment.
contents of SpeechSegment.
...
...
deepspeech/frontend/featurizer/text_featurizer.py
浏览文件 @
03a50d7b
...
@@ -14,12 +14,19 @@
...
@@ -14,12 +14,19 @@
"""Contains the text featurizer class."""
"""Contains the text featurizer class."""
import
sentencepiece
as
spm
import
sentencepiece
as
spm
from
deepspeech.frontend.utility
import
EOS
from
..utility
import
EOS
from
deepspeech.frontend.utility
import
UNK
from
..utility
import
load_dict
from
..utility
import
UNK
__all__
=
[
"TextFeaturizer"
]
class
TextFeaturizer
(
object
):
def
__init__
(
self
,
unit_type
,
vocab_filepath
,
spm_model_prefix
=
None
):
class
TextFeaturizer
():
def
__init__
(
self
,
unit_type
,
vocab_filepath
,
spm_model_prefix
=
None
,
maskctc
=
False
):
"""Text featurizer, for processing or extracting features from text.
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
...
@@ -34,11 +41,12 @@ class TextFeaturizer(object):
...
@@ -34,11 +41,12 @@ class TextFeaturizer(object):
assert
unit_type
in
(
'char'
,
'spm'
,
'word'
)
assert
unit_type
in
(
'char'
,
'spm'
,
'word'
)
self
.
unit_type
=
unit_type
self
.
unit_type
=
unit_type
self
.
unk
=
UNK
self
.
unk
=
UNK
self
.
maskctc
=
maskctc
if
vocab_filepath
:
if
vocab_filepath
:
self
.
_vocab_dict
,
self
.
_id2token
,
self
.
_vocab_list
=
self
.
_load_vocabulary_from_file
(
self
.
vocab_dict
,
self
.
_id2token
,
self
.
vocab_list
,
self
.
unk_id
,
self
.
eos_id
=
self
.
_load_vocabulary_from_file
(
vocab_filepath
)
vocab_filepath
,
maskctc
)
self
.
unk_id
=
self
.
_vocab_list
.
index
(
self
.
unk
)
self
.
vocab_size
=
len
(
self
.
vocab_list
)
self
.
eos_id
=
self
.
_vocab_list
.
index
(
EOS
)
if
unit_type
==
'spm'
:
if
unit_type
==
'spm'
:
spm_model
=
spm_model_prefix
+
'.model'
spm_model
=
spm_model_prefix
+
'.model'
...
@@ -67,7 +75,7 @@ class TextFeaturizer(object):
...
@@ -67,7 +75,7 @@ class TextFeaturizer(object):
"""Convert text string to a list of token indices.
"""Convert text string to a list of token indices.
Args:
Args:
text (str): Text
to process
.
text (str): Text.
Returns:
Returns:
List[int]: List of token indices.
List[int]: List of token indices.
...
@@ -75,8 +83,8 @@ class TextFeaturizer(object):
...
@@ -75,8 +83,8 @@ class TextFeaturizer(object):
tokens
=
self
.
tokenize
(
text
)
tokens
=
self
.
tokenize
(
text
)
ids
=
[]
ids
=
[]
for
token
in
tokens
:
for
token
in
tokens
:
token
=
token
if
token
in
self
.
_
vocab_dict
else
self
.
unk
token
=
token
if
token
in
self
.
vocab_dict
else
self
.
unk
ids
.
append
(
self
.
_
vocab_dict
[
token
])
ids
.
append
(
self
.
vocab_dict
[
token
])
return
ids
return
ids
def
defeaturize
(
self
,
idxs
):
def
defeaturize
(
self
,
idxs
):
...
@@ -87,7 +95,7 @@ class TextFeaturizer(object):
...
@@ -87,7 +95,7 @@ class TextFeaturizer(object):
idxs (List[int]): List of token indices.
idxs (List[int]): List of token indices.
Returns:
Returns:
str: Text
to process
.
str: Text.
"""
"""
tokens
=
[]
tokens
=
[]
for
idx
in
idxs
:
for
idx
in
idxs
:
...
@@ -97,33 +105,6 @@ class TextFeaturizer(object):
...
@@ -97,33 +105,6 @@ class TextFeaturizer(object):
text
=
self
.
detokenize
(
tokens
)
text
=
self
.
detokenize
(
tokens
)
return
text
return
text
@
property
def
vocab_size
(
self
):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return
len
(
self
.
_vocab_list
)
@
property
def
vocab_list
(
self
):
"""Return the vocabulary in list.
Returns:
List[str]: tokens.
"""
return
self
.
_vocab_list
@
property
def
vocab_dict
(
self
):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]: token str -> int
"""
return
self
.
_vocab_dict
def
char_tokenize
(
self
,
text
):
def
char_tokenize
(
self
,
text
):
"""Character tokenizer.
"""Character tokenizer.
...
@@ -206,14 +187,16 @@ class TextFeaturizer(object):
...
@@ -206,14 +187,16 @@ class TextFeaturizer(object):
return
decode
(
tokens
)
return
decode
(
tokens
)
def
_load_vocabulary_from_file
(
self
,
vocab_filepath
):
def
_load_vocabulary_from_file
(
self
,
vocab_filepath
:
str
,
maskctc
:
bool
):
"""Load vocabulary from file."""
"""Load vocabulary from file."""
vocab_lines
=
[]
vocab_list
=
load_dict
(
vocab_filepath
,
maskctc
)
with
open
(
vocab_filepath
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
assert
vocab_list
is
not
None
vocab_lines
.
extend
(
file
.
readlines
())
vocab_list
=
[
line
[:
-
1
]
for
line
in
vocab_lines
]
id2token
=
dict
(
id2token
=
dict
(
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
token2id
=
dict
(
token2id
=
dict
(
[(
token
,
idx
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
[(
token
,
idx
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
return
token2id
,
id2token
,
vocab_list
unk_id
=
vocab_list
.
index
(
UNK
)
eos_id
=
vocab_list
.
index
(
EOS
)
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
deepspeech/frontend/normalizer.py
浏览文件 @
03a50d7b
...
@@ -40,21 +40,21 @@ class CollateFunc(object):
...
@@ -40,21 +40,21 @@ class CollateFunc(object):
number
=
0
number
=
0
for
item
in
batch
:
for
item
in
batch
:
audioseg
=
AudioSegment
.
from_file
(
item
[
'feat'
])
audioseg
=
AudioSegment
.
from_file
(
item
[
'feat'
])
feat
=
self
.
feature_func
(
audioseg
)
#(
D, T
)
feat
=
self
.
feature_func
(
audioseg
)
#(
T, D
)
sums
=
np
.
sum
(
feat
,
axis
=
1
)
sums
=
np
.
sum
(
feat
,
axis
=
0
)
if
mean_stat
is
None
:
if
mean_stat
is
None
:
mean_stat
=
sums
mean_stat
=
sums
else
:
else
:
mean_stat
+=
sums
mean_stat
+=
sums
square_sums
=
np
.
sum
(
np
.
square
(
feat
),
axis
=
1
)
square_sums
=
np
.
sum
(
np
.
square
(
feat
),
axis
=
0
)
if
var_stat
is
None
:
if
var_stat
is
None
:
var_stat
=
square_sums
var_stat
=
square_sums
else
:
else
:
var_stat
+=
square_sums
var_stat
+=
square_sums
number
+=
feat
.
shape
[
1
]
number
+=
feat
.
shape
[
0
]
return
number
,
mean_stat
,
var_stat
return
number
,
mean_stat
,
var_stat
...
@@ -120,7 +120,7 @@ class FeatureNormalizer(object):
...
@@ -120,7 +120,7 @@ class FeatureNormalizer(object):
"""Normalize features to be of zero mean and unit stddev.
"""Normalize features to be of zero mean and unit stddev.
:param features: Input features to be normalized.
:param features: Input features to be normalized.
:type features: ndarray, shape (
D, T
)
:type features: ndarray, shape (
T, D
)
:param eps: added to stddev to provide numerical stablibity.
:param eps: added to stddev to provide numerical stablibity.
:type eps: float
:type eps: float
:return: Normalized features.
:return: Normalized features.
...
@@ -131,8 +131,8 @@ class FeatureNormalizer(object):
...
@@ -131,8 +131,8 @@ class FeatureNormalizer(object):
def
_read_mean_std_from_file
(
self
,
filepath
,
eps
=
1e-20
):
def
_read_mean_std_from_file
(
self
,
filepath
,
eps
=
1e-20
):
"""Load mean and std from file."""
"""Load mean and std from file."""
mean
,
istd
=
load_cmvn
(
filepath
,
filetype
=
'json'
)
mean
,
istd
=
load_cmvn
(
filepath
,
filetype
=
'json'
)
self
.
_mean
=
np
.
expand_dims
(
mean
,
axis
=
-
1
)
self
.
_mean
=
np
.
expand_dims
(
mean
,
axis
=
0
)
self
.
_istd
=
np
.
expand_dims
(
istd
,
axis
=
-
1
)
self
.
_istd
=
np
.
expand_dims
(
istd
,
axis
=
0
)
def
write_to_file
(
self
,
filepath
):
def
write_to_file
(
self
,
filepath
):
"""Write the mean and stddev to the file.
"""Write the mean and stddev to the file.
...
...
deepspeech/frontend/utility.py
浏览文件 @
03a50d7b
...
@@ -15,6 +15,9 @@
...
@@ -15,6 +15,9 @@
import
codecs
import
codecs
import
json
import
json
import
math
import
math
from
typing
import
List
from
typing
import
Optional
from
typing
import
Text
import
numpy
as
np
import
numpy
as
np
...
@@ -23,16 +26,35 @@ from deepspeech.utils.log import Log
...
@@ -23,16 +26,35 @@ from deepspeech.utils.log import Log
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
__all__
=
[
"load_
cmvn"
,
"read_manifest"
,
"rms_to_db"
,
"rms_to_dbfs"
,
"max
_dbfs"
,
"load_
dict"
,
"load_cmvn"
,
"read_manifest"
,
"rms_to_db"
,
"rms_to
_dbfs"
,
"m
ean_dbfs"
,
"gain_db_to_ratio"
,
"normalize_audio"
,
"SOS"
,
"EOS"
,
"UNK
"
,
"m
ax_dbfs"
,
"mean_dbfs"
,
"gain_db_to_ratio"
,
"normalize_audio"
,
"SOS
"
,
"
BLANK
"
"
EOS"
,
"UNK"
,
"BLANK"
,
"MASKCTC
"
]
]
IGNORE_ID
=
-
1
IGNORE_ID
=
-
1
SOS
=
"<sos/eos>"
# `sos` and `eos` using same token
SOS
=
"<eos>"
EOS
=
SOS
EOS
=
SOS
UNK
=
"<unk>"
UNK
=
"<unk>"
BLANK
=
"<blank>"
BLANK
=
"<blank>"
MASKCTC
=
"<mask>"
def
load_dict
(
dict_path
:
Optional
[
Text
],
maskctc
=
False
)
->
Optional
[
List
[
Text
]]:
if
dict_path
is
None
:
return
None
with
open
(
dict_path
,
"r"
)
as
f
:
dictionary
=
f
.
readlines
()
char_list
=
[
entry
.
strip
().
split
(
" "
)[
0
]
for
entry
in
dictionary
]
if
BLANK
not
in
char_list
:
char_list
.
insert
(
0
,
BLANK
)
if
EOS
not
in
char_list
:
char_list
.
append
(
EOS
)
# for non-autoregressive maskctc model
if
maskctc
and
MASKCTC
not
in
char_list
:
char_list
.
append
(
MASKCTC
)
return
char_list
def
read_manifest
(
def
read_manifest
(
...
@@ -47,12 +69,20 @@ def read_manifest(
...
@@ -47,12 +69,20 @@ def read_manifest(
Args:
Args:
manifest_path ([type]): Manifest file to load and parse.
manifest_path ([type]): Manifest file to load and parse.
max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
max_input_len ([type], optional): maximum output seq length,
min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
in seconds for raw wav, in frame numbers for feature data.
max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
Defaults to float('inf').
min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
min_input_len (float, optional): minimum input seq length,
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
in seconds for raw wav, in frame numbers for feature data.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
Defaults to 0.0.
max_output_len (float, optional): maximum input seq length,
in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional):
maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional):
minimum output seq length/output seq length ratio. Defaults to 0.05.
Raises:
Raises:
IOError: If failed to parse the manifest.
IOError: If failed to parse the manifest.
...
...
deepspeech/io/collator.py
浏览文件 @
03a50d7b
...
@@ -242,7 +242,6 @@ class SpeechCollator():
...
@@ -242,7 +242,6 @@ class SpeechCollator():
# specgram augment
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
return
specgram
,
transcript_part
return
specgram
,
transcript_part
def
__call__
(
self
,
batch
):
def
__call__
(
self
,
batch
):
...
@@ -250,7 +249,7 @@ class SpeechCollator():
...
@@ -250,7 +249,7 @@ class SpeechCollator():
Args:
Args:
batch ([List]): batch is (audio, text)
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (
D, T
)
audio (np.ndarray) shape (
T, D
)
text (List[int] or str): shape (U,)
text (List[int] or str): shape (U,)
Returns:
Returns:
...
...
deepspeech/io/collator_st.py
浏览文件 @
03a50d7b
...
@@ -217,6 +217,34 @@ class SpeechCollator():
...
@@ -217,6 +217,34 @@ class SpeechCollator():
return
self
.
_local_data
.
tar2object
[
tarpath
].
extractfile
(
return
self
.
_local_data
.
tar2object
[
tarpath
].
extractfile
(
self
.
_local_data
.
tar2info
[
tarpath
][
filename
])
self
.
_local_data
.
tar2info
[
tarpath
][
filename
])
@
property
def
manifest
(
self
):
return
self
.
_manifest
@
property
def
vocab_size
(
self
):
return
self
.
_speech_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
return
self
.
_speech_featurizer
.
vocab_list
@
property
def
vocab_dict
(
self
):
return
self
.
_speech_featurizer
.
vocab_dict
@
property
def
text_feature
(
self
):
return
self
.
_speech_featurizer
.
text_feature
@
property
def
feature_size
(
self
):
return
self
.
_speech_featurizer
.
feature_size
@
property
def
stride_ms
(
self
):
return
self
.
_speech_featurizer
.
stride_ms
def
process_utterance
(
self
,
audio_file
,
translation
):
def
process_utterance
(
self
,
audio_file
,
translation
):
"""Load, augment, featurize and normalize for speech data.
"""Load, augment, featurize and normalize for speech data.
...
@@ -244,7 +272,6 @@ class SpeechCollator():
...
@@ -244,7 +272,6 @@ class SpeechCollator():
# specgram augment
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
return
specgram
,
translation_part
return
specgram
,
translation_part
def
__call__
(
self
,
batch
):
def
__call__
(
self
,
batch
):
...
@@ -252,7 +279,7 @@ class SpeechCollator():
...
@@ -252,7 +279,7 @@ class SpeechCollator():
Args:
Args:
batch ([List]): batch is (audio, text)
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (
D, T
)
audio (np.ndarray) shape (
T, D
)
text (List[int] or str): shape (U,)
text (List[int] or str): shape (U,)
Returns:
Returns:
...
@@ -296,34 +323,6 @@ class SpeechCollator():
...
@@ -296,34 +323,6 @@ class SpeechCollator():
text_lens
=
np
.
array
(
text_lens
).
astype
(
np
.
int64
)
text_lens
=
np
.
array
(
text_lens
).
astype
(
np
.
int64
)
return
utts
,
padded_audios
,
audio_lens
,
padded_texts
,
text_lens
return
utts
,
padded_audios
,
audio_lens
,
padded_texts
,
text_lens
@
property
def
manifest
(
self
):
return
self
.
_manifest
@
property
def
vocab_size
(
self
):
return
self
.
_speech_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
return
self
.
_speech_featurizer
.
vocab_list
@
property
def
vocab_dict
(
self
):
return
self
.
_speech_featurizer
.
vocab_dict
@
property
def
text_feature
(
self
):
return
self
.
_speech_featurizer
.
text_feature
@
property
def
feature_size
(
self
):
return
self
.
_speech_featurizer
.
feature_size
@
property
def
stride_ms
(
self
):
return
self
.
_speech_featurizer
.
stride_ms
class
TripletSpeechCollator
(
SpeechCollator
):
class
TripletSpeechCollator
(
SpeechCollator
):
def
process_utterance
(
self
,
audio_file
,
translation
,
transcript
):
def
process_utterance
(
self
,
audio_file
,
translation
,
transcript
):
...
@@ -355,7 +354,6 @@ class TripletSpeechCollator(SpeechCollator):
...
@@ -355,7 +354,6 @@ class TripletSpeechCollator(SpeechCollator):
# specgram augment
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
return
specgram
,
translation_part
,
transcript_part
return
specgram
,
translation_part
,
transcript_part
def
__call__
(
self
,
batch
):
def
__call__
(
self
,
batch
):
...
@@ -363,7 +361,7 @@ class TripletSpeechCollator(SpeechCollator):
...
@@ -363,7 +361,7 @@ class TripletSpeechCollator(SpeechCollator):
Args:
Args:
batch ([List]): batch is (audio, text)
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (
D, T
)
audio (np.ndarray) shape (
T, D
)
text (List[int] or str): shape (U,)
text (List[int] or str): shape (U,)
Returns:
Returns:
...
@@ -524,49 +522,19 @@ class KaldiPrePorocessedCollator(SpeechCollator):
...
@@ -524,49 +522,19 @@ class KaldiPrePorocessedCollator(SpeechCollator):
:rtype: tuple of (2darray, list)
:rtype: tuple of (2darray, list)
"""
"""
specgram
=
kaldiio
.
load_mat
(
audio_file
)
specgram
=
kaldiio
.
load_mat
(
audio_file
)
specgram
=
specgram
.
transpose
([
1
,
0
])
assert
specgram
.
shape
[
assert
specgram
.
shape
[
0
]
==
self
.
_feat_dim
,
'expect feat dim {}, but got {}'
.
format
(
1
]
==
self
.
_feat_dim
,
'expect feat dim {}, but got {}'
.
format
(
self
.
_feat_dim
,
specgram
.
shape
[
0
])
self
.
_feat_dim
,
specgram
.
shape
[
1
])
# specgram augment
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
if
self
.
_keep_transcription_text
:
if
self
.
_keep_transcription_text
:
return
specgram
,
translation
return
specgram
,
translation
else
:
else
:
text_ids
=
self
.
_text_featurizer
.
featurize
(
translation
)
text_ids
=
self
.
_text_featurizer
.
featurize
(
translation
)
return
specgram
,
text_ids
return
specgram
,
text_ids
@
property
def
manifest
(
self
):
return
self
.
_manifest
@
property
def
vocab_size
(
self
):
return
self
.
_text_featurizer
.
vocab_size
@
property
def
vocab_list
(
self
):
return
self
.
_text_featurizer
.
vocab_list
@
property
def
vocab_dict
(
self
):
return
self
.
_text_featurizer
.
vocab_dict
@
property
def
text_feature
(
self
):
return
self
.
_text_featurizer
@
property
def
feature_size
(
self
):
return
self
.
_feat_dim
@
property
def
stride_ms
(
self
):
return
self
.
_stride_ms
class
TripletKaldiPrePorocessedCollator
(
KaldiPrePorocessedCollator
):
class
TripletKaldiPrePorocessedCollator
(
KaldiPrePorocessedCollator
):
def
process_utterance
(
self
,
audio_file
,
translation
,
transcript
):
def
process_utterance
(
self
,
audio_file
,
translation
,
transcript
):
...
@@ -583,15 +551,13 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
...
@@ -583,15 +551,13 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
:rtype: tuple of (2darray, (list, list))
:rtype: tuple of (2darray, (list, list))
"""
"""
specgram
=
kaldiio
.
load_mat
(
audio_file
)
specgram
=
kaldiio
.
load_mat
(
audio_file
)
specgram
=
specgram
.
transpose
([
1
,
0
])
assert
specgram
.
shape
[
assert
specgram
.
shape
[
0
]
==
self
.
_feat_dim
,
'expect feat dim {}, but got {}'
.
format
(
1
]
==
self
.
_feat_dim
,
'expect feat dim {}, but got {}'
.
format
(
self
.
_feat_dim
,
specgram
.
shape
[
0
])
self
.
_feat_dim
,
specgram
.
shape
[
1
])
# specgram augment
# specgram augment
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
specgram
=
specgram
.
transpose
([
1
,
0
])
if
self
.
_keep_transcription_text
:
if
self
.
_keep_transcription_text
:
return
specgram
,
translation
,
transcript
return
specgram
,
translation
,
transcript
else
:
else
:
...
@@ -604,7 +570,7 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
...
@@ -604,7 +570,7 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
Args:
Args:
batch ([List]): batch is (audio, text)
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (
D, T
)
audio (np.ndarray) shape (
T, D
)
translation (List[int] or str): shape (U,)
translation (List[int] or str): shape (U,)
transcription (List[int] or str): shape (V,)
transcription (List[int] or str): shape (V,)
...
...
deepspeech/io/converter.py
浏览文件 @
03a50d7b
...
@@ -43,7 +43,7 @@ class CustomConverter():
...
@@ -43,7 +43,7 @@ class CustomConverter():
batch (list): The batch to transform.
batch (list): The batch to transform.
Returns:
Returns:
tuple(
paddle.Tensor, paddle.Tensor, paddle.Tensor
)
tuple(
np.ndarray, nn.ndarray, nn.ndarray
)
"""
"""
# batch should be located in list
# batch should be located in list
...
...
deepspeech/io/dataloader.py
浏览文件 @
03a50d7b
...
@@ -43,6 +43,18 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
...
@@ -43,6 +43,18 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
return
feat_dim
,
vocab_size
return
feat_dim
,
vocab_size
def
batch_collate
(
x
):
"""de-tuple.
Args:
x (List[Tuple]): [(utts, xs, ilens, ys, olens)]
Returns:
Tuple: (utts, xs, ilens, ys, olens)
"""
return
x
[
0
]
class
BatchDataLoader
():
class
BatchDataLoader
():
def
__init__
(
self
,
def
__init__
(
self
,
json_file
:
str
,
json_file
:
str
,
...
@@ -120,15 +132,15 @@ class BatchDataLoader():
...
@@ -120,15 +132,15 @@ class BatchDataLoader():
# actual bathsize is included in a list
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
# we used an empty collate function instead which returns list
self
.
dataset
=
TransformDataset
(
self
.
dataset
=
TransformDataset
(
self
.
minibaches
,
self
.
converter
,
self
.
minibaches
,
self
.
reader
)
lambda
data
:
self
.
converter
([
self
.
reader
(
data
,
return_uttid
=
True
)]))
self
.
dataloader
=
DataLoader
(
self
.
dataloader
=
DataLoader
(
dataset
=
self
.
dataset
,
dataset
=
self
.
dataset
,
batch_size
=
1
,
batch_size
=
1
,
shuffle
=
not
self
.
use_sortagrad
if
train_mode
else
False
,
shuffle
=
not
self
.
use_sortagrad
if
self
.
train_mode
else
False
,
collate_fn
=
lambda
x
:
x
[
0
]
,
collate_fn
=
batch_collate
,
num_workers
=
n_iter_processes
,
)
num_workers
=
self
.
n_iter_processes
,
)
def
__repr__
(
self
):
def
__repr__
(
self
):
echo
=
f
"<
{
self
.
__class__
.
__module__
}
.
{
self
.
__class__
.
__name__
}
object at
{
hex
(
id
(
self
))
}
> "
echo
=
f
"<
{
self
.
__class__
.
__module__
}
.
{
self
.
__class__
.
__name__
}
object at
{
hex
(
id
(
self
))
}
> "
...
...
deepspeech/io/dataset.py
浏览文件 @
03a50d7b
...
@@ -129,15 +129,16 @@ class TransformDataset(Dataset):
...
@@ -129,15 +129,16 @@ class TransformDataset(Dataset):
Args:
Args:
data: list object from make_batchset
data: list object from make_batchset
transfrom: transform
function
converter: batch
function
reader: read data
"""
"""
def
__init__
(
self
,
data
,
transform
):
def
__init__
(
self
,
data
,
converter
,
reader
):
"""Init function."""
"""Init function."""
super
().
__init__
()
super
().
__init__
()
self
.
data
=
data
self
.
data
=
data
self
.
transform
=
transform
self
.
converter
=
converter
self
.
reader
=
reader
def
__len__
(
self
):
def
__len__
(
self
):
"""Len function."""
"""Len function."""
...
@@ -145,4 +146,4 @@ class TransformDataset(Dataset):
...
@@ -145,4 +146,4 @@ class TransformDataset(Dataset):
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
"""[] operator."""
"""[] operator."""
return
self
.
transform
(
self
.
data
[
idx
])
return
self
.
converter
([
self
.
reader
(
self
.
data
[
idx
],
return_uttid
=
True
)
])
deepspeech/models/ds2_online/deepspeech2.py
浏览文件 @
03a50d7b
...
@@ -102,13 +102,13 @@ class CRNNEncoder(nn.Layer):
...
@@ -102,13 +102,13 @@ class CRNNEncoder(nn.Layer):
Args:
Args:
x (Tensor): [B, feature_size, D]
x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B]
x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
init_state_h_box(Tensor): init_states h for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
init_state_c_box(Tensor): init_states c for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
Return
s
:
Return:
x (Tensor): encoder outputs, [B, size, D]
x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B]
x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_h_box(Tensor): final_states h for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
"""
if
init_state_h_box
is
not
None
:
if
init_state_h_box
is
not
None
:
init_state_list
=
None
init_state_list
=
None
...
@@ -142,7 +142,7 @@ class CRNNEncoder(nn.Layer):
...
@@ -142,7 +142,7 @@ class CRNNEncoder(nn.Layer):
if
self
.
use_gru
is
True
:
if
self
.
use_gru
is
True
:
final_chunk_state_h_box
=
paddle
.
concat
(
final_chunk_state_h_box
=
paddle
.
concat
(
final_chunk_state_list
,
axis
=
0
)
final_chunk_state_list
,
axis
=
0
)
final_chunk_state_c_box
=
init_state_c_box
#paddle.zeros_like(final_chunk_state_h_box)
final_chunk_state_c_box
=
init_state_c_box
else
:
else
:
final_chunk_state_h_list
=
[
final_chunk_state_h_list
=
[
final_chunk_state_list
[
i
][
0
]
for
i
in
range
(
self
.
num_rnn_layers
)
final_chunk_state_list
[
i
][
0
]
for
i
in
range
(
self
.
num_rnn_layers
)
...
@@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer):
...
@@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer):
x_lens (Tensor): [B]
x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder
decoder_chunk_size: The chunk size of decoder
Returns:
Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size
,
[B, chunk_size, D] * num_chunks
eouts_list (List of Tensor): The list of encoder outputs in chunk_size
:
[B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size
,
[B] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size
:
[B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_h_box(Tensor): final_states h for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers
, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers
: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
"""
subsampling_rate
=
self
.
conv
.
subsampling_rate
subsampling_rate
=
self
.
conv
.
subsampling_rate
receptive_field_length
=
self
.
conv
.
receptive_field_length
receptive_field_length
=
self
.
conv
.
receptive_field_length
...
@@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer):
...
@@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer):
class
DeepSpeech2ModelOnline
(
nn
.
Layer
):
class
DeepSpeech2ModelOnline
(
nn
.
Layer
):
"""The DeepSpeech2 network structure for online.
"""The DeepSpeech2 network structure for online.
:param audio
_data
: Audio spectrogram data layer.
:param audio: Audio spectrogram data layer.
:type audio
_data
: Variable
:type audio: Variable
:param text
_data
: Transcription text data layer.
:param text: Transcription text data layer.
:type text
_data
: Variable
:type text: Variable
:param audio_len: Valid sequence length data layer.
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:type audio_len: Variable
:param feat_size: feature size for audio.
:type feat_size: int
:param dict_size: Dictionary size for tokenized transcription.
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
:param num_conv_layers: Number of stacking convolution layers.
...
...
deepspeech/models/u2_st.py
浏览文件 @
03a50d7b
...
@@ -417,32 +417,32 @@ class U2STBaseModel(nn.Layer):
...
@@ -417,32 +417,32 @@ class U2STBaseModel(nn.Layer):
best_hyps
=
best_hyps
[:,
1
:]
best_hyps
=
best_hyps
[:,
1
:]
return
best_hyps
return
best_hyps
@
jit
.
export
@
jit
.
to_static
def
subsampling_rate
(
self
)
->
int
:
def
subsampling_rate
(
self
)
->
int
:
""" Export interface for c++ call, return subsampling_rate of the
""" Export interface for c++ call, return subsampling_rate of the
model
model
"""
"""
return
self
.
encoder
.
embed
.
subsampling_rate
return
self
.
encoder
.
embed
.
subsampling_rate
@
jit
.
export
@
jit
.
to_static
def
right_context
(
self
)
->
int
:
def
right_context
(
self
)
->
int
:
""" Export interface for c++ call, return right_context of the model
""" Export interface for c++ call, return right_context of the model
"""
"""
return
self
.
encoder
.
embed
.
right_context
return
self
.
encoder
.
embed
.
right_context
@
jit
.
export
@
jit
.
to_static
def
sos_symbol
(
self
)
->
int
:
def
sos_symbol
(
self
)
->
int
:
""" Export interface for c++ call, return sos symbol id of the model
""" Export interface for c++ call, return sos symbol id of the model
"""
"""
return
self
.
sos
return
self
.
sos
@
jit
.
export
@
jit
.
to_static
def
eos_symbol
(
self
)
->
int
:
def
eos_symbol
(
self
)
->
int
:
""" Export interface for c++ call, return eos symbol id of the model
""" Export interface for c++ call, return eos symbol id of the model
"""
"""
return
self
.
eos
return
self
.
eos
@
jit
.
export
@
jit
.
to_static
def
forward_encoder_chunk
(
def
forward_encoder_chunk
(
self
,
self
,
xs
:
paddle
.
Tensor
,
xs
:
paddle
.
Tensor
,
...
@@ -472,7 +472,7 @@ class U2STBaseModel(nn.Layer):
...
@@ -472,7 +472,7 @@ class U2STBaseModel(nn.Layer):
xs
,
offset
,
required_cache_size
,
subsampling_cache
,
xs
,
offset
,
required_cache_size
,
subsampling_cache
,
elayers_output_cache
,
conformer_cnn_cache
)
elayers_output_cache
,
conformer_cnn_cache
)
@
jit
.
export
@
jit
.
to_static
def
ctc_activation
(
self
,
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
def
ctc_activation
(
self
,
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
""" Export interface for c++ call, apply linear transform and log
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
softmax before ctc
...
@@ -483,7 +483,7 @@ class U2STBaseModel(nn.Layer):
...
@@ -483,7 +483,7 @@ class U2STBaseModel(nn.Layer):
"""
"""
return
self
.
ctc
.
log_softmax
(
xs
)
return
self
.
ctc
.
log_softmax
(
xs
)
@
jit
.
export
@
jit
.
to_static
def
forward_attention_decoder
(
def
forward_attention_decoder
(
self
,
self
,
hyps
:
paddle
.
Tensor
,
hyps
:
paddle
.
Tensor
,
...
...
deepspeech/training/cli.py
浏览文件 @
03a50d7b
...
@@ -16,23 +16,23 @@ import argparse
...
@@ -16,23 +16,23 @@ import argparse
def
default_argument_parser
():
def
default_argument_parser
():
r
"""A simple yet genral argument parser for experiments with parakeet.
r
"""A simple yet genral argument parser for experiments with parakeet.
This is used in examples with parakeet. And it is intended to be used by
This is used in examples with parakeet. And it is intended to be used by
other experiments with parakeet. It requires a minimal set of command line
other experiments with parakeet. It requires a minimal set of command line
arguments to start a training script.
arguments to start a training script.
The ``--config`` and ``--opts`` are used for overwrite the deault
The ``--config`` and ``--opts`` are used for overwrite the deault
configuration.
configuration.
The ``--data`` and ``--output`` specifies the data path and output path.
The ``--data`` and ``--output`` specifies the data path and output path.
Resuming training from existing progress at the output directory is the
Resuming training from existing progress at the output directory is the
intended default behavior.
intended default behavior.
The ``--checkpoint_path`` specifies the checkpoint to load from.
The ``--checkpoint_path`` specifies the checkpoint to load from.
The ``--device`` and ``--nprocs`` specifies how to run the training.
The ``--device`` and ``--nprocs`` specifies how to run the training.
See Also
See Also
--------
--------
parakeet.training.experiment
parakeet.training.experiment
...
@@ -47,28 +47,24 @@ def default_argument_parser():
...
@@ -47,28 +47,24 @@ def default_argument_parser():
# data and output
# data and output
parser
.
add_argument
(
"--config"
,
metavar
=
"FILE"
,
help
=
"path of the config file to overwrite to default config with."
)
parser
.
add_argument
(
"--config"
,
metavar
=
"FILE"
,
help
=
"path of the config file to overwrite to default config with."
)
parser
.
add_argument
(
"--dump-config"
,
metavar
=
"FILE"
,
help
=
"dump config to yaml file."
)
parser
.
add_argument
(
"--dump-config"
,
metavar
=
"FILE"
,
help
=
"dump config to yaml file."
)
# parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
parser
.
add_argument
(
"--output"
,
metavar
=
"OUTPUT_DIR"
,
help
=
"path to save checkpoint and logs."
)
parser
.
add_argument
(
"--output"
,
metavar
=
"OUTPUT_DIR"
,
help
=
"path to save checkpoint and logs."
)
# load from saved checkpoint
# load from saved checkpoint
parser
.
add_argument
(
"--checkpoint_path"
,
type
=
str
,
help
=
"path of the checkpoint to load"
)
parser
.
add_argument
(
"--checkpoint_path"
,
type
=
str
,
help
=
"path of the checkpoint to load"
)
# save jit model to
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
# save asr result to
parser
.
add_argument
(
"--result_file"
,
type
=
str
,
help
=
"path of save the asr result"
)
# running
# running
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'gpu'
,
choices
=
[
"cpu"
,
"gpu"
],
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'gpu'
,
choices
=
[
"cpu"
,
"gpu"
],
help
=
"device type to use, cpu and gpu are supported."
)
help
=
"device type to use, cpu and gpu are supported."
)
parser
.
add_argument
(
"--nprocs"
,
type
=
int
,
default
=
1
,
help
=
"number of parallel processes to use."
)
parser
.
add_argument
(
"--nprocs"
,
type
=
int
,
default
=
1
,
help
=
"number of parallel processes to use."
)
# overwrite extra config and default config
# overwrite extra config and default config
# parser.add_argument("--opts", nargs=argparse.REMAINDER,
# parser.add_argument("--opts", nargs=argparse.REMAINDER,
# help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
# help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser
.
add_argument
(
"--opts"
,
type
=
str
,
default
=
[],
nargs
=
'+'
,
parser
.
add_argument
(
"--opts"
,
type
=
str
,
default
=
[],
nargs
=
'+'
,
help
=
"options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
help
=
"options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
None
,
help
=
"seed to use for paddle, np and random. None or 0 for random, else set seed."
)
# yapd: enable
# yapd: enable
return
parser
return
parser
deepspeech/training/extensions/__init__.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Callable
from
.extension
import
Extension
def make_extension(trigger: Callable = None,
                   default_name: str = None,
                   priority: int = None,
                   finalizer: Callable = None,
                   initializer: Callable = None,
                   on_error: Callable = None):
    """Build a decorator that turns a callable into an Extension-like object.

    The returned decorator injects the attributes below onto the object it
    decorates and hands the very same object back.

    Args:
        trigger: Value stored as ``ext.trigger``; falls back to
            ``Extension.trigger`` when ``None``.
        default_name: Value stored as ``ext.default_name``; falls back to
            ``ext.__name__`` when falsy.
        priority: Value stored as ``ext.priority``; falls back to
            ``Extension.priority`` when ``None``.
        finalizer: Stored as ``ext.finalize`` (may be ``None``).
        initializer: Stored as ``ext.initialize`` (may be ``None``).
        on_error: Stored as ``ext.on_error`` (may be ``None``).

    Returns:
        A decorator taking the object to extend and returning it.
    """
    # Resolve the class-level fallbacks once, when the decorator is created.
    resolved_trigger = Extension.trigger if trigger is None else trigger
    resolved_priority = Extension.priority if priority is None else priority

    def decorator(ext):
        injected = {
            "trigger": resolved_trigger,
            "default_name": default_name or ext.__name__,
            "priority": resolved_priority,
            "finalize": finalizer,
            "on_error": on_error,
            "initialize": initializer,
        }
        for attribute, value in injected.items():
            setattr(ext, attribute, value)
        return ext

    return decorator
deepspeech/training/extensions/evaluator.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Dict
import
extension
import
paddle
from
paddle.io
import
DataLoader
from
paddle.nn
import
Layer
from
..reporter
import
DictSummary
from
..reporter
import
report
from
..reporter
import
scope
class StandardEvaluator(extension.Extension):
    """Extension that runs a model over a dataloader and reports mean metrics.

    Metrics are whatever gets reported (via ``report``) while
    ``evaluate_core`` runs; they are averaged over all batches.
    """

    trigger = (1, 'epoch')
    default_name = 'validation'
    priority = extension.PRIORITY_WRITER

    name = None

    def __init__(self, model: Layer, dataloader: DataLoader):
        """Store the model and the dataloader to evaluate on.

        Args:
            model: The layer to evaluate.
            dataloader: Iterable yielding evaluation batches.
        """
        # Kept as a dict so that several models could be held at once.
        self.models: Dict[str, Layer] = {"main": model}
        self.model = model

        self.dataloader = dataloader

    def evaluate_core(self, batch):
        """Run the model on one batch; override to report metrics."""
        self.model(batch)

    def evaluate(self):
        """Evaluate on every batch and return the per-key mean of the reports.

        Returns:
            dict: Mean value for each reported name.
        """
        # Put every held model into eval mode.
        for held_model in self.models.values():
            held_model.eval()

        # Accumulates per-batch observations so they can be averaged.
        accumulator = DictSummary()
        for batch in self.dataloader:
            batch_observation = {}
            # Reports made inside land in `batch_observation`; gradients off.
            with scope(batch_observation), paddle.no_grad():
                self.evaluate_core(batch)
            accumulator.add(batch_observation)
        return accumulator.compute_mean()

    def __call__(self, trainer=None):
        """Evaluate and report each averaged metric to the current observation.

        When used as a trainer extension the metrics go to whatever
        observation is active at call time; otherwise the caller can open
        its own ``scope`` around this call.
        """
        averaged = self.evaluate()
        for metric_name, metric_value in averaged.items():
            report(metric_name, metric_value)
deepspeech/training/extensions/extension.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Priority levels an extension can declare via its `priority` attribute.
PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100


class Extension:
    """Extension to customize the behavior of Trainer."""

    trigger = (1, 'iteration')
    priority = PRIORITY_READER
    name = None

    @property
    def default_name(self):
        """Default name of the extension: the name of its class."""
        return self.__class__.__name__

    def __call__(self, trainer):
        """Main action of the extension.

        After each update, it is executed when the trigger fires.
        Subclasses must override it.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            'Extension implementation must override __call__.')

    def initialize(self, trainer):
        """Action that is executed once to get the correct trainer state.

        It is called before training normally, but if the trainer restores
        states with a Snapshot extension, this method should also be called.
        The base implementation does nothing.
        """

    def on_error(self, trainer, exc, tb):
        """Handle an error raised during training, before finalization.

        The base implementation does nothing.
        """

    def finalize(self, trainer):
        """Action that is executed when training is done.

        For example, visualizers would need to be closed. The base
        implementation does nothing.
        """
deepspeech/training/extensions/snapshot.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
datetime
import
datetime
from
pathlib
import
Path
from
typing
import
Any
from
typing
import
Dict
from
typing
import
List
import
jsonlines
from
deepspeech.training.extensions
import
extension
from
deepspeech.training.updaters.trainer
import
Trainer
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.mp_tools
import
rank_zero_only
logger
=
Log
(
__name__
).
getlog
()
def load_records(records_fp):
    """Read a JSON-lines record file and return its entries as a list.

    Args:
        records_fp: Path of the ``.jsonl`` file to read.

    Returns:
        list: One item per line of the file, in file order.
    """
    with jsonlines.open(records_fp, 'r') as reader:
        return list(reader)
class Snapshot(extension.Extension):
    """An extension to make snapshots of the updater object inside the trainer.

    It is done by calling the updater's `save` method. An Updater saves its
    state_dict by default, which contains the updater state (i.e. epoch and
    iteration) and all the model parameters and optimizer states. If the
    updater inside the trainer subclasses StandardUpdater, everything is
    good to go.

    Snapshots are written to ``trainer.out / "checkpoints"`` and tracked in
    a ``records.jsonl`` file in that directory.

    Parameters
    ----------
    max_size : int
        Maximum number of snapshots to keep; ``-1`` keeps all of them.
    snapshot_on_error : bool
        Whether ``on_error`` should also take a snapshot.
    """

    trigger = (1, 'epoch')
    priority = -100
    default_name = "snapshot"

    def __init__(self, max_size: int = 5, snapshot_on_error: bool = False):
        self.records: List[Dict[str, Any]] = []
        self.max_size = max_size
        self._snapshot_on_error = snapshot_on_error
        self._save_all = (max_size == -1)
        # Set in `initialize`, once the trainer's output directory is known.
        self.checkpoint_dir = None

    def initialize(self, trainer: Trainer):
        """Set up the checkpoint directory and resume from existing records."""
        self.checkpoint_dir = trainer.out / "checkpoints"

        # load existing records
        record_path: Path = self.checkpoint_dir / "records.jsonl"
        if record_path.exists():
            logger.debug("Loading from an existing checkpoint dir")
            self.records = load_records(record_path)
            # The record file may exist but be empty (it is truncated before
            # being rewritten), in which case there is nothing to restore.
            if self.records:
                trainer.updater.load(self.records[-1]['path'])

    def on_error(self, trainer, exc, tb):
        """Take a snapshot on error when `snapshot_on_error` was requested."""
        if self._snapshot_on_error:
            self.save_checkpoint_and_update(trainer)

    def __call__(self, trainer: Trainer):
        """Take a snapshot; run whenever the trigger fires."""
        self.save_checkpoint_and_update(trainer)

    def full(self):
        """Whether the number of snapshots it keeps track of is greater
        than the max_size. Always False when all snapshots are kept."""
        return (not self._save_all) and len(self.records) > self.max_size

    @rank_zero_only
    def save_checkpoint_and_update(self, trainer: Trainer):
        """Save a new snapshot and remove the oldest snapshot if needed."""
        iteration = trainer.updater.state.iteration
        epoch = trainer.updater.state.epoch
        # Snapshot files are numbered by the unit the trigger counts in.
        num = epoch if self.trigger[1] == 'epoch' else iteration
        path = self.checkpoint_dir / f"{num}.pdz"

        # add the new one
        trainer.updater.save(path)
        resolved_path = str(path.resolve())  # use absolute path
        # A snapshot for the same file may already be recorded (e.g. an
        # on-error snapshot in the same epoch). Replace it instead of
        # duplicating it, so rotation never deletes a file that a newer
        # record still points to.
        self.records = [
            kept for kept in self.records
            if kept.get('path') != resolved_path
        ]
        self.records.append({
            "time": str(datetime.now()),
            'path': resolved_path,
            'iteration': iteration,
            'epoch': epoch,
        })

        # remove the earliest
        if self.full():
            earliest_record = self.records.pop(0)
            try:
                os.remove(earliest_record["path"])
            except FileNotFoundError:
                # Already gone (e.g. removed by hand); not worth aborting
                # training over.
                logger.warning(
                    f"Snapshot to remove not found: {earliest_record['path']}"
                )

        # update the record file
        record_path = self.checkpoint_dir / "records.jsonl"
        with jsonlines.open(record_path, 'w') as writer:
            for entry in self.records:
                # jsonlines.open may return a Writer or a Reader
                writer.write(entry)  # pylint: disable=no-member
deepspeech/training/extensions/visualizer.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
deepspeech.training.extensions
import
extension
from
deepspeech.training.updaters.trainer
import
Trainer
class VisualDL(extension.Extension):
    """A wrapper of a visualdl log writer.

    It assumes that the metrics to be visualized are all scalars which are
    recorded into the `.observation` dictionary of the trainer object. The
    updater's `iteration` is used as the global step when adding records.
    """

    trigger = (1, 'iteration')
    default_name = 'visualdl'
    priority = extension.PRIORITY_READER

    def __init__(self, writer):
        """Keep the log writer that scalars are sent to."""
        self.writer = writer

    def __call__(self, trainer: Trainer):
        """Write every entry of `trainer.observation` as a scalar."""
        for tag, scalar in trainer.observation.items():
            global_step = trainer.updater.state.iteration
            self.writer.add_scalar(tag, scalar, step=global_step)

    def finalize(self, trainer):
        """Close the underlying writer."""
        self.writer.close()
deepspeech/training/reporter.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
contextlib
import
math
from
collections
import
defaultdict
# Dictionary that `report` currently writes into; None disables reporting.
OBSERVATIONS = None


@contextlib.contextmanager
def scope(observations):
    """Make `observations` the target of `report` inside the `with` block.

    `observations` is basically a dictionary that stores temporary
    observations. The previous target is restored on exit, even when the
    body raises, so scopes can be nested.
    """
    global OBSERVATIONS
    previous = OBSERVATIONS
    OBSERVATIONS = observations
    try:
        yield
    finally:
        OBSERVATIONS = previous


def get_observations():
    """Return the current report target, or None when no scope is active."""
    return OBSERVATIONS


def report(name, value):
    """Store `value` under `name` in the current report target.

    It can be called from anywhere, much like writing to standard output;
    it does nothing when no target is active.
    """
    target = get_observations()
    if target is not None:
        target[name] = value
class Summary():
    """Online summarization of a sequence of scalars.

    Summary keeps running weighted sums of the values, their squares and
    the weights, from which the mean and standard deviation are derived.
    """

    def __init__(self):
        self._x = 0.0  # weighted sum of values
        self._x2 = 0.0  # weighted sum of squared values
        self._n = 0  # sum of weights

    def add(self, value, weight=1):
        """Adds a scalar value.

        Args:
            value: Scalar value to accumulate. It is either a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
            weight: An optional weight for the value. It is a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
                Default is 1 (integer).
        """
        self._x += weight * value
        self._x2 += weight * value * value
        self._n += weight

    def compute_mean(self):
        """Computes the mean.

        Raises:
            ZeroDivisionError: If nothing has been added yet (for plain
                Python numbers).
        """
        return self._x / self._n

    def make_statistics(self):
        """Computes and returns the mean and standard deviation values.

        Returns:
            tuple: Mean and standard deviation values.

        Raises:
            ZeroDivisionError: If nothing has been added yet (for plain
                Python numbers).
        """
        total_weight = self._n
        mean = self._x / total_weight
        var = self._x2 / total_weight - mean * mean
        # E[x^2] - mean^2 can come out slightly negative through floating
        # point round-off (e.g. identical values); clamp it so sqrt does not
        # raise ValueError.
        std = math.sqrt(max(var, 0.0))
        return mean, std
class DictSummary():
    """Online summarization of a sequence of dictionaries.

    ``DictSummary`` computes the statistics of a given set of scalars online.
    It keeps one ``Summary`` per dictionary key, created on first use.
    """

    def __init__(self):
        self._summaries = defaultdict(Summary)

    def add(self, d):
        """Adds a dictionary of scalars.

        Args:
            d (dict): Dictionary of scalars to accumulate. When a value is a
                tuple, its first element is the scalar and its second
                element is interpreted as a weight; otherwise the weight
                is 1.
        """
        summaries = self._summaries
        for key, entry in d.items():
            if isinstance(entry, tuple):
                # Read both fields from the tuple itself; unwrapping the
                # value first and then indexing it again loses the weight.
                value, weight = entry[0], entry[1]
            else:
                value, weight = entry, 1
            summaries[key].add(value, weight=weight)

    def compute_mean(self):
        """Creates a dictionary of mean values.

        It returns a single dictionary that holds a mean value for each entry
        added to the summary.

        Returns:
            dict: Dictionary of mean values.
        """
        return {
            name: summary.compute_mean()
            for name, summary in self._summaries.items()
        }

    def make_statistics(self):
        """Creates a dictionary of statistics.

        It returns a single dictionary that holds mean and standard deviation
        values for every entry added to the summary. For an entry of name
        ``'key'``, these values are added to the dictionary by names ``'key'``
        and ``'key.std'``, respectively.

        Returns:
            dict: Dictionary of statistics of all entries.
        """
        stats = {}
        for name, summary in self._summaries.items():
            mean, std = summary.make_statistics()
            stats[name] = mean
            stats[name + '.std'] = std
        return stats
deepspeech/training/trainer.py
浏览文件 @
03a50d7b
...
@@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter
...
@@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter
from
deepspeech.utils
import
mp_tools
from
deepspeech.utils
import
mp_tools
from
deepspeech.utils.checkpoint
import
Checkpoint
from
deepspeech.utils.checkpoint
import
Checkpoint
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.utility
import
seed_all
__all__
=
[
"Trainer"
]
__all__
=
[
"Trainer"
]
...
@@ -94,6 +95,10 @@ class Trainer():
...
@@ -94,6 +95,10 @@ class Trainer():
self
.
iteration
=
0
self
.
iteration
=
0
self
.
epoch
=
0
self
.
epoch
=
0
if
args
.
seed
:
seed_all
(
args
.
seed
)
logger
.
info
(
f
"Set seed
{
args
.
seed
}
"
)
def
setup
(
self
):
def
setup
(
self
):
"""Setup the experiment.
"""Setup the experiment.
"""
"""
...
@@ -172,8 +177,10 @@ class Trainer():
...
@@ -172,8 +177,10 @@ class Trainer():
"""Reset the train loader seed and increment `epoch`.
"""Reset the train loader seed and increment `epoch`.
"""
"""
self
.
epoch
+=
1
self
.
epoch
+=
1
if
self
.
parallel
:
if
self
.
parallel
and
hasattr
(
self
.
train_loader
,
"batch_sampler"
):
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
batch_sampler
=
self
.
train_loader
.
batch_sampler
if
isinstance
(
batch_sampler
,
paddle
.
io
.
DistributedBatchSampler
):
batch_sampler
.
set_epoch
(
self
.
epoch
)
def
train
(
self
):
def
train
(
self
):
"""The training process control by epoch."""
"""The training process control by epoch."""
...
@@ -182,7 +189,7 @@ class Trainer():
...
@@ -182,7 +189,7 @@ class Trainer():
# save init model, i.e. 0 epoch
# save init model, i.e. 0 epoch
self
.
save
(
tag
=
'init'
,
infos
=
None
)
self
.
save
(
tag
=
'init'
,
infos
=
None
)
self
.
lr_scheduler
.
step
(
self
.
epoch
)
self
.
lr_scheduler
.
step
(
self
.
epoch
)
if
self
.
parallel
:
if
self
.
parallel
and
hasattr
(
self
.
train_loader
,
"batch_sampler"
)
:
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
...
...
deepspeech/training/triggers/__init__.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.interval_trigger
import
IntervalTrigger
def never_fail_trigger(trainer):
    """Trigger predicate that ignores the trainer and always returns False."""
    return False
def get_trigger(trigger):
    """Normalize a trigger specification into a callable predicate.

    Args:
        trigger: ``None``, a callable, or an iterable of arguments for
            ``IntervalTrigger`` (e.g. ``(1, 'epoch')``).

    Returns:
        ``never_fail_trigger`` for ``None``, the callable itself when one is
        given, otherwise ``IntervalTrigger(*trigger)``.
    """
    if trigger is None:
        return never_fail_trigger
    if callable(trigger):
        return trigger
    # Anything else is unpacked as the arguments of an interval trigger.
    return IntervalTrigger(*trigger)
deepspeech/training/triggers/interval_trigger.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class IntervalTrigger:
    """A predicate to do something every N cycles.

    It fires when the watched counter (``iteration`` or ``epoch`` of
    ``trainer.updater.state``) has moved into a different multiple of
    ``period`` since the previous call. The very first call never fires.
    """

    def __init__(self, period: int, unit: str):
        """Validate and store the interval.

        Args:
            period: Positive number of units between firings.
            unit: Either ``"iteration"`` or ``"epoch"``.

        Raises:
            ValueError: If `unit` is not supported or `period` is not
                positive.
        """
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if period <= 0:
            raise ValueError("period should be a positive integer.")
        self.period = period
        self.unit = unit
        # Counter value seen on the previous call; None before the first one.
        self.last_index = None

    def __call__(self, trainer):
        """Return True when the counter crossed a period boundary."""
        current = getattr(trainer.updater.state, self.unit)
        # On the first call compare the counter with itself: no firing.
        previous = current if self.last_index is None else self.last_index
        self.last_index = current
        return current // self.period != previous // self.period
deepspeech/training/triggers/limit_trigger.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class LimitTrigger():
    """A predicate that turns true once a counter reaches ``limit``.

    Used to decide whether training should stop.
    """

    def __init__(self, limit: int, unit: str):
        """Validate and store the limit.

        Args:
            limit: Positive threshold, counted in ``unit``.
            unit: Either ``"iteration"`` or ``"epoch"``.

        Raises:
            ValueError: If ``unit`` is not one of the two accepted names or
                ``limit`` is not positive.
        """
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if limit <= 0:
            raise ValueError("limit should be a positive integer.")
        self.limit, self.unit = limit, unit

    def __call__(self, trainer):
        """Return whether ``trainer.updater.state.<unit>`` has reached the limit."""
        return getattr(trainer.updater.state, self.unit) >= self.limit
deepspeech/training/triggers/time_trigger.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TimeTrigger():
    """Trigger that fires at a fixed wall-clock interval.

    Args:
        period (float): Interval between firings, in seconds.
    """

    def __init__(self, period):
        self._period = period
        # Elapsed-time threshold that must be exceeded before the next firing.
        self._next_time = self._period

    def __call__(self, trainer):
        """Fire when ``trainer.elapsed_time`` has passed the next deadline.

        On firing, the deadline is pushed forward by one period.
        """
        if not self._next_time < trainer.elapsed_time:
            return False
        self._next_time += self._period
        return True
deepspeech/training/updaters/__init__.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech/training/updaters/standard_updater.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Dict
from
typing
import
Optional
from
paddle
import
Tensor
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
from
paddle.nn
import
Layer
from
paddle.optimizer
import
Optimizer
from
timer
import
timer
from
deepspeech.training.reporter
import
report
from
deepspeech.training.updaters.updater
import
UpdaterBase
from
deepspeech.training.updaters.updater
import
UpdaterState
from
deepspeech.utils.log
import
Log
__all__
=
[
"StandardUpdater"
]
logger
=
Log
(
__name__
).
getlog
()
class StandardUpdater(UpdaterBase):
    """A deliberately simple updater: one model, one optimizer, one dataloader.

    Things may not be that simple in practice, but you can subclass it to fit
    your need.
    """

    def __init__(self,
                 model: Layer,
                 optimizer: Optimizer,
                 dataloader: DataLoader,
                 init_state: Optional[UpdaterState] = None):
        """Store the training components and create the first data iterator.

        Args:
            model: The model to train; also registered as ``models["main"]``.
            optimizer: Its optimizer; also registered as ``optimizers["main"]``.
            dataloader: Source of training batches.
            init_state: Counters to resume from; a fresh ``UpdaterState`` is
                created by the base class when omitted.
        """
        # The base class installs ``self.state`` from ``init_state``.
        super().__init__(init_state)

        # it is designed to hold multiple models
        self.models: Dict[str, Layer] = {"main": model}
        self.model = model

        # it is designed to hold multiple optimizers
        self.optimizers: Dict[str, Optimizer] = {"main": optimizer}
        self.optimizer = optimizer

        # dataloaders
        self.dataloader = dataloader
        self.train_iterator = iter(dataloader)

    def update(self):
        """Train for one step, then advance the iteration and epoch counters."""
        # The iteration index is increased after updating and before any
        # extension is executed. Here are the reasons.
        # 1. Snapshotting (as well as other extensions, like a visualizer) is
        #    executed after a step of updating.
        # 2. We do not increase the iteration after the extensions because we
        #    prefer a consistent resume behavior: when loading from
        #    `snapshot_iter_100.pdz`, the next step to train is `101`,
        #    naturally. If the iteration were increased after the extensions
        #    (including snapshot), a `snapshot_iter_99` would be loaded and an
        #    extra increment of the iteration index would be needed before
        #    training to avoid repeating iteration `99`, which had already
        #    been done before snapshotting.
        # 3. Thus the iteration index represents "how many iterations have
        #    been done so far."
        # NOTE: use report to capture the correct value. If you want to
        # report the learning rate used for a step, you must report it before
        # the learning rate scheduler's step() has been called. In paddle's
        # convention, we do not use an extension to change the learning rate,
        # so if you want to report it, do it in the updater.

        # Then comes the next question: when is the proper time to increase
        # the epoch index? Since all extensions are executed after updating,
        # right after updating is the proper time.
        # 1. If we increased the epoch index before updating, an extension
        #    based on epoch would miss the correct timing. It could only be
        #    triggered after an extra update.
        # 2. Theoretically, when an epoch is done, the epoch index should be
        #    increased, so it is increased after updating.
        # 3. Thus the epoch index represents "how many epochs have been done
        #    so far." So it starts from 0.

        # switch to training mode
        for model in self.models.values():
            model.train()

        # training for a step is implemented here
        batch = self.read_batch()
        self.update_core(batch)

        self.state.iteration += 1
        # Read the property once; it is None when the dataloader has no length.
        updates_per_epoch = self.updates_per_epoch
        if updates_per_epoch is not None:
            if self.state.iteration % updates_per_epoch == 0:
                self.state.epoch += 1

    def update_core(self, batch):
        """A simple case for a training step. Basic assumptions are:
        Single model;
        Single optimizer;
        A batch from the dataloader is just the input of the model;
        The model returns a single loss, or a dict containing several losses.
        Parameters update at every batch, no gradient accumulation.
        """
        loss = self.model(*batch)

        if isinstance(loss, Tensor):
            loss_dict = {"main": loss}
        else:
            # Dict[str, Tensor]; copy it so that adding "main" below does not
            # mutate the mapping owned by the model.
            loss_dict = dict(loss)
            if "main" not in loss_dict:
                # The main loss defaults to the sum of all reported losses.
                loss_dict["main"] = sum(loss.values())

        for name, loss_item in loss_dict.items():
            report(name, float(loss_item))

        # NOTE(review): paddle.optimizer.Optimizer documents `clear_grad()`
        # and `step()`; the previous `clear_gradient()` / `update()` calls are
        # not part of that interface.
        self.optimizer.clear_grad()
        loss_dict["main"].backward()
        self.optimizer.step()

    @property
    def updates_per_epoch(self):
        """Number of updates per epoch, i.e. ``len(self.dataloader)``.

        Returns ``None`` when the dataloader cannot report a length.
        """
        try:
            return len(self.dataloader)
        except Exception:
            # Best effort: any failure to size the dataloader means "unknown".
            # Unlike a `return` inside `finally`, this does not swallow
            # KeyboardInterrupt or SystemExit.
            logger.debug("This dataloader has no __len__.")
            return None

    def new_epoch(self):
        """Start a new epoch."""
        # NOTE: all batch sampler for distributed training should
        # subclass DistributedBatchSampler and implement `set_epoch` method
        batch_sampler = getattr(self.dataloader, "batch_sampler", None)
        if isinstance(batch_sampler, DistributedBatchSampler):
            batch_sampler.set_epoch(self.state.epoch)
        self.train_iterator = iter(self.dataloader)

    def read_batch(self):
        """Read a batch from the data loader, auto renew when data is exhausted."""
        with timer() as t:
            try:
                batch = next(self.train_iterator)
            except StopIteration:
                # The iterator is exhausted: restart it and read again.
                self.new_epoch()
                batch = next(self.train_iterator)
            logger.debug(
                f"Read a batch takes {t.elapse}s.")  # replace it with logger
        return batch

    def state_dict(self):
        """State dict of a Updater, model, optimizer and updater state are included."""
        state_dict = super().state_dict()
        for name, model in self.models.items():
            state_dict[f"{name}_params"] = model.state_dict()
        for name, optim in self.optimizers.items():
            state_dict[f"{name}_optimizer"] = optim.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """Set state dict for a Updater. Parameters of models, states for
        optimizers and UpdaterState are restored."""
        for name, model in self.models.items():
            model.set_state_dict(state_dict[f"{name}_params"])
        for name, optim in self.optimizers.items():
            optim.set_state_dict(state_dict[f"{name}_optimizer"])
        super().set_state_dict(state_dict)
deepspeech/training/updaters/trainer.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
traceback
from
collections
import
OrderedDict
from
pathlib
import
Path
from
typing
import
Callable
from
typing
import
List
from
typing
import
Union
import
six
import
tqdm
from
deepspeech.training.extensions.extension
import
Extension
from
deepspeech.training.extensions.extension
import
PRIORITY_READER
from
deepspeech.training.reporter
import
scope
from
deepspeech.training.triggers
import
get_trigger
from
deepspeech.training.triggers.limit_trigger
import
LimitTrigger
from
deepspeech.training.updaters.updater
import
UpdaterBase
class _ExtensionEntry():
    """Record that bundles an extension with its trigger and priority."""

    def __init__(self, extension, trigger, priority):
        self.extension, self.trigger, self.priority = (extension, trigger,
                                                       priority)
class Trainer():
    """Drive an updater in a loop and run registered extensions after each step."""

    def __init__(self,
                 updater: UpdaterBase,
                 stop_trigger: Callable = None,
                 out: Union[str, Path] = 'result',
                 extensions: List[Extension] = None):
        """Set up the training loop.

        Args:
            updater: Object whose ``update()`` performs one training step and
                whose ``state`` holds the iteration/epoch counters.
            stop_trigger: When to stop. A callable taking the trainer is used
                as is; ``None`` is resolved through ``get_trigger``; any other
                value is unpacked as ``LimitTrigger(*stop_trigger)``, e.g.
                ``(10, 'epoch')``.
            out: Output directory; created by ``run``.
            extensions: Extensions to register through ``extend``.
        """
        self.updater = updater
        self.extensions = OrderedDict()
        # The default (None) and callables used to be unpacked into
        # LimitTrigger and raised TypeError; resolve them with get_trigger.
        if stop_trigger is None or callable(stop_trigger):
            self.stop_trigger = get_trigger(stop_trigger)
        else:
            self.stop_trigger = LimitTrigger(*stop_trigger)
        self.out = Path(out)
        self.observation = None

        self._done = False
        if extensions:
            for ext in extensions:
                self.extend(ext)

    @property
    def is_before_training(self):
        """Whether no update has been performed yet (iteration counter is 0)."""
        return self.updater.state.iteration == 0

    def extend(self, extension, name=None, trigger=None, priority=None):
        """Register an extension.

        Args:
            extension: Callable invoked as ``extension(trainer)``.
            name: Registration name; looked up on the extension when omitted.
            trigger: Trigger specification passed to ``get_trigger``; defaults
                to the extension's ``trigger`` attribute or every iteration.
            priority: Ordering key (higher runs first); defaults to the
                extension's ``priority`` attribute or ``PRIORITY_READER``.

        Raises:
            ValueError: If no name can be determined or the name is
                ``'training'``.
        """
        # get name for the extension
        # argument \
        # -> extension's name \
        # -> default_name (class name, when it is an object) \
        # -> function name when it is a function \
        # -> error
        if name is None:
            for attribute in ('name', 'default_name', '__name__'):
                name = getattr(extension, attribute, None)
                if name is not None:
                    break
            else:
                raise ValueError("Name is not given for the extension.")
        if name == 'training':
            raise ValueError("training is a reserved name.")

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority', PRIORITY_READER)

        # add suffix to avoid naming conflict
        ordinal = 0
        modified_name = name
        while modified_name in self.extensions:
            ordinal += 1
            modified_name = f"{name}_{ordinal}"
        extension.name = modified_name

        self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
                                                         priority)

    def get_extension(self, name):
        """Get extension by name.

        Raises:
            ValueError: If no extension is registered under ``name``.
        """
        extensions = self.extensions
        if name in extensions:
            return extensions[name].extension
        else:
            raise ValueError(f'extension {name} not found')

    def run(self):
        """Run the training loop until the stop trigger fires.

        Extensions are initialized once, executed after every update when
        their trigger fires, given a chance to handle an exception through
        ``on_error``, and always finalized.

        Raises:
            RuntimeError: If a previous call already completed training.
        """
        if self._done:
            raise RuntimeError("Training is already done!.")

        self.out.mkdir(parents=True, exist_ok=True)

        # sort extensions by priorities once
        extension_order = sorted(
            self.extensions.keys(),
            key=lambda name: self.extensions[name].priority,
            reverse=True)
        extensions = [(name, self.extensions[name])
                      for name in extension_order]

        # initializing all extensions
        for name, entry in extensions:
            if hasattr(entry.extension, "initialize"):
                entry.extension.initialize(self)

        update = self.updater.update  # training step
        stop_trigger = self.stop_trigger

        # display only one progress bar
        max_iteration = None
        if isinstance(stop_trigger, LimitTrigger):
            if stop_trigger.unit == 'epoch':
                max_epoch = self.stop_trigger.limit
                updates_per_epoch = getattr(self.updater, "updates_per_epoch",
                                            None)
                max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
            else:
                max_iteration = self.stop_trigger.limit

        p = tqdm.tqdm(
            initial=self.updater.state.iteration, total=max_iteration)

        try:
            while not stop_trigger(self):
                self.observation = {}
                # set observation as the report target
                # you can use report freely in Updater.update()

                # updating parameters and state
                # NOTE(review): the extension loop is kept inside the scope so
                # that values reported by extensions also land in
                # `self.observation`; the nesting was reconstructed because
                # the original indentation was lost - confirm against the
                # upstream file.
                with scope(self.observation):
                    update()
                    p.update()

                    # execute extension when necessary
                    for name, entry in extensions:
                        if entry.trigger(self):
                            entry.extension(self)

        except Exception as e:
            stderr = sys.stderr
            stderr.write(f"Exception in main training loop: {e}\n")
            stderr.write("Traceback (most recent call last):\n")
            traceback.print_tb(e.__traceback__)
            stderr.write(
                "Trainer extensions will try to handle the exception. Then all extensions will finalize.\n"
            )

            # try to handle it
            for name, entry in extensions:
                if hasattr(entry.extension, "on_error"):
                    try:
                        entry.extension.on_error(self, e, e.__traceback__)
                    except Exception as ee:
                        stderr.write(f"Exception in error handler: {ee}\n")
                        stderr.write('Traceback (most recent call last):\n')
                        traceback.print_tb(ee.__traceback__)

            # re-raise the exception from the main training loop
            raise
        finally:
            try:
                for name, entry in extensions:
                    if hasattr(entry.extension, "finalize"):
                        entry.extension.finalize(self)
            finally:
                # Release the progress bar even when a finalizer fails.
                p.close()

        # Reached only on normal completion; makes the guard at the top of
        # this method effective.
        self._done = True
deepspeech/training/updaters/updater.py
0 → 100644
浏览文件 @
03a50d7b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
import
paddle
from
deepspeech.utils.log
import
Log
__all__
=
[
"UpdaterBase"
,
"UpdaterState"
]
logger
=
Log
(
__name__
).
getlog
()
@dataclass
class UpdaterState:
    # Training counters kept by an updater; both start from zero.
    iteration: int = 0
    epoch: int = 0
class UpdaterBase():
    """An updater is the abstraction of how a model is trained given the
    dataloader and the optimizer.

    The `update_core` method is a step in the training loop with only necessary
    operations (get a batch, forward and backward, update the parameters).

    Other things are made extensions. Visualization, saving, loading and
    periodical validation and evaluation are not considered here.

    But even in such a simple case, things are not that simple. There is an
    attempt to standardize this process that requires only the model and the
    dataset and does everything else automatically. But this may hurt
    flexibility.

    If we assume a batch yielded from the dataloader is just the input to the
    model, we will find that some models require more arguments, or just some
    keyword arguments. This prevents us from over-simplifying it.

    From another perspective, the batch may include not just the input, but
    also the target. But the model's forward method may just need the input.
    We can pass a dict or a super-long tuple to the model and let it pick what
    it really needs. But this is an abuse of lazy interface.

    After all, we care about how a model is trained, not just how the model is
    used for inference. We want to control how a model is trained. We just
    don't want to be messed up with other auxiliary code.

    So the best practice is to define a model and define an updater for it.
    """

    def __init__(self, init_state=None):
        """Install ``init_state`` as ``self.state``, or a fresh ``UpdaterState``."""
        self.state = UpdaterState() if init_state is None else init_state

    def update(self, batch=None):
        """Train for one step; must be overridden by subclasses.

        ``batch`` is optional so that a call without arguments reaches the
        ``NotImplementedError`` below instead of failing on the signature.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Implement your own `update` method for training a step.")

    def state_dict(self):
        """Return the epoch and iteration counters as a dict."""
        return {
            "epoch": self.state.epoch,
            "iteration": self.state.iteration,
        }

    def set_state_dict(self, state_dict):
        """Restore the epoch and iteration counters from ``state_dict``."""
        self.state.epoch = state_dict["epoch"]
        self.state.iteration = state_dict["iteration"]

    def save(self, path):
        """Serialize ``self.state_dict()`` to ``path`` with ``paddle.save``."""
        logger.debug(f"Saving to {path}.")
        archive = self.state_dict()
        paddle.save(archive, str(path))

    def load(self, path):
        """Load an archive from ``path`` with ``paddle.load`` and restore from it."""
        logger.debug(f"Loading from {path}.")
        archive = paddle.load(str(path))
        self.set_state_dict(archive)
deepspeech/utils/utility.py
浏览文件 @
03a50d7b
...
@@ -15,9 +15,19 @@
...
@@ -15,9 +15,19 @@
import
distutils.util
import
distutils.util
import
math
import
math
import
os
import
os
import
random
from
typing
import
List
from
typing
import
List
__all__
=
[
'print_arguments'
,
'add_arguments'
,
"log_add"
]
import
numpy
as
np
import
paddle
__all__
=
[
"seed_all"
,
'print_arguments'
,
'add_arguments'
,
"log_add"
]
def seed_all(seed: int = 210329):
    """Seed NumPy, Python's ``random`` module and Paddle with the same value."""
    for seeder in (np.random.seed, random.seed, paddle.seed):
        seeder(seed)
def
print_arguments
(
args
,
info
=
None
):
def
print_arguments
(
args
,
info
=
None
):
...
...
examples/aishell/s0/README.md
浏览文件 @
03a50d7b
# Aishell-1
# Aishell-1
## Data
| Data Subset | Duration in Seconds |
| data/manifest.train | 1.23 ~ 14.53125 |
| data/manifest.dev | 1.645 ~ 12.533 |
| data/manifest.test | 1.859125 ~ 14.6999375 |
## Deepspeech2
## Deepspeech2
| Model | Params | Release | Config | Test set | Loss | CER |
| Model | Params | Release | Config | Test set | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382
,0.073507
|
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
...
...
examples/aishell/s0/conf/augmentation.json
浏览文件 @
03a50d7b
...
@@ -19,15 +19,17 @@
...
@@ -19,15 +19,17 @@
{
{
"type"
:
"specaug"
,
"type"
:
"specaug"
,
"params"
:
{
"params"
:
{
"W"
:
0
,
"warp_mode"
:
"PIL"
,
"F"
:
10
,
"F"
:
10
,
"T"
:
50
,
"n_freq_masks"
:
2
,
"n_freq_masks"
:
2
,
"T"
:
50
,
"n_time_masks"
:
2
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"p"
:
1.0
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
},
"prob"
:
1.0
"prob"
:
1.0
}
}
...
...
examples/aishell/s0/local/train.sh
浏览文件 @
03a50d7b
...
@@ -19,12 +19,22 @@ fi
...
@@ -19,12 +19,22 @@ fi
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
\
--output
exp/
${
ckpt_name
}
\
--model_type
${
model_type
}
--model_type
${
model_type
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/aishell/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
@@ -27,7 +27,9 @@
...
@@ -27,7 +27,9 @@
"W"
:
80
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
,
"warp_mode"
:
"PIL"
},
},
"prob"
:
1.0
"prob"
:
1.0
}
}
...
...
examples/aishell/s1/local/train.sh
浏览文件 @
03a50d7b
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/aug_conf/augmentation.json
已删除
100644 → 0
浏览文件 @
8d506270
[
{
"type"
:
"shift"
,
"params"
:
{
"min_shift_ms"
:
-5
,
"max_shift_ms"
:
5
},
"prob"
:
1.0
}
]
examples/aug
_conf/augmentation.example
.json
→
examples/aug
mentation/augmentation
.json
浏览文件 @
03a50d7b
...
@@ -52,16 +52,18 @@
...
@@ -52,16 +52,18 @@
{
{
"type"
:
"specaug"
,
"type"
:
"specaug"
,
"params"
:
{
"params"
:
{
"W"
:
80
,
"warp_mode"
:
"PIL"
,
"F"
:
10
,
"F"
:
10
,
"T"
:
50
,
"n_freq_masks"
:
2
,
"n_freq_masks"
:
2
,
"T"
:
50
,
"n_time_masks"
:
2
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"p"
:
1.0
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
false
},
},
"prob"
:
0
.0
"prob"
:
1
.0
}
}
]
]
examples/callcenter/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
@@ -27,7 +27,8 @@
...
@@ -27,7 +27,8 @@
"W"
:
80
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
},
"prob"
:
1.0
"prob"
:
1.0
}
}
...
...
examples/callcenter/s1/local/train.sh
浏览文件 @
03a50d7b
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/librispeech/s0/conf/augmentation.json
浏览文件 @
03a50d7b
...
@@ -19,15 +19,17 @@
...
@@ -19,15 +19,17 @@
{
{
"type"
:
"specaug"
,
"type"
:
"specaug"
,
"params"
:
{
"params"
:
{
"W"
:
0
,
"warp_mode"
:
"PIL"
,
"F"
:
10
,
"F"
:
10
,
"T"
:
50
,
"n_freq_masks"
:
2
,
"n_freq_masks"
:
2
,
"T"
:
50
,
"n_time_masks"
:
2
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"p"
:
1.0
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
},
"prob"
:
1.0
"prob"
:
1.0
}
}
...
...
examples/librispeech/s0/local/train.sh
浏览文件 @
03a50d7b
...
@@ -20,12 +20,22 @@ echo "using ${device}..."
...
@@ -20,12 +20,22 @@ echo "using ${device}..."
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
\
--output
exp/
${
ckpt_name
}
\
--model_type
${
model_type
}
--model_type
${
model_type
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/librispeech/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
@@ -27,7 +27,9 @@
...
@@ -27,7 +27,9 @@
"W"
:
80
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
,
"warp_mode"
:
"PIL"
},
},
"prob"
:
1.0
"prob"
:
1.0
}
}
...
...
examples/librispeech/s1/local/train.sh
浏览文件 @
03a50d7b
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/librispeech/s2/conf/augmentation.json
浏览文件 @
03a50d7b
...
@@ -2,15 +2,17 @@
...
@@ -2,15 +2,17 @@
{
{
"type"
:
"specaug"
,
"type"
:
"specaug"
,
"params"
:
{
"params"
:
{
"F"
:
10
,
"W"
:
5
,
"T"
:
50
,
"warp_mode"
:
"PIL"
,
"F"
:
30
,
"n_freq_masks"
:
2
,
"n_freq_masks"
:
2
,
"T"
:
40
,
"n_time_masks"
:
2
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"p"
:
1.0
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
false
},
},
"prob"
:
1.0
"prob"
:
1.0
}
}
...
...
examples/librispeech/s2/conf/transformer.yaml
浏览文件 @
03a50d7b
...
@@ -3,37 +3,28 @@ data:
...
@@ -3,37 +3,28 @@ data:
train_manifest
:
data/manifest.train
train_manifest
:
data/manifest.train
dev_manifest
:
data/manifest.dev
dev_manifest
:
data/manifest.dev
test_manifest
:
data/manifest.test-clean
test_manifest
:
data/manifest.test-clean
min_input_len
:
0.5
# second
max_input_len
:
20.0
# second
min_output_len
:
0.0
# tokens
max_output_len
:
400.0
# tokens
min_output_input_ratio
:
0.05
max_output_input_ratio
:
10.0
collator
:
collator
:
vocab_filepath
:
data/
vocab
.txt
vocab_filepath
:
data/
train_960_unigram5000_units
.txt
unit_type
:
'
spm'
unit_type
:
'
spm'
spm_model_prefix
:
'
data/bpe_unigram_5000'
spm_model_prefix
:
'
data/train_960_unigram5000'
mean_std_filepath
:
"
"
augmentation_config
:
conf/augmentation.json
batch_size
:
64
raw_wav
:
True
# use raw_wav or kaldi feature
specgram_type
:
fbank
#linear, mfcc, fbank
feat_dim
:
83
feat_dim
:
83
delta_delta
:
False
dither
:
1.0
target_sample_rate
:
16000
max_freq
:
None
n_fft
:
None
stride_ms
:
10.0
stride_ms
:
10.0
window_ms
:
25.0
window_ms
:
25.0
use_dB_normalization
:
True
sortagrad
:
0
# Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
target_dB
:
-20
batch_size
:
32
random_seed
:
0
maxlen_in
:
512
# if input length > maxlen-in, batchsize is automatically reduced
keep_transcription_text
:
False
maxlen_out
:
150
# if output length > maxlen-out, batchsize is automatically reduced
sortagrad
:
True
minibatches
:
0
# for debug
shuffle_method
:
batch_shuffle
batch_count
:
auto
batch_bins
:
0
batch_frames_in
:
0
batch_frames_out
:
0
batch_frames_inout
:
0
augmentation_config
:
conf/augmentation.json
num_workers
:
2
num_workers
:
2
subsampling_factor
:
1
num_encs
:
1
# network architecture
# network architecture
...
...
examples/librispeech/s2/local/align.sh
浏览文件 @
03a50d7b
#!/bin/bash
#!/bin/bash
if
[
$#
!=
2
]
;
then
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix"
echo
"usage:
${
0
}
config_path
dict_path
ckpt_path_prefix"
exit
-1
exit
-1
fi
fi
...
@@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then
...
@@ -13,7 +13,8 @@ if [ ${ngpu} == 0 ];then
device
=
cpu
device
=
cpu
fi
fi
config_path
=
$1
config_path
=
$1
ckpt_prefix
=
$2
dict_path
=
$2
ckpt_prefix
=
$3
batch_size
=
1
batch_size
=
1
output_dir
=
${
ckpt_prefix
}
output_dir
=
${
ckpt_prefix
}
...
@@ -22,11 +23,13 @@ mkdir -p ${output_dir}
...
@@ -22,11 +23,13 @@ mkdir -p ${output_dir}
# align dump in `result_file`
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3
-u
${
BIN_DIR
}
/test.py
\
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
'align'
\
--model-name
'u2_kaldi'
\
--run-mode
'align'
\
--dict-path
${
dict_path
}
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
1
\
--nproc
1
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--result
_
file
${
output_dir
}
/
${
type
}
.align
\
--result
-
file
${
output_dir
}
/
${
type
}
.align
\
--checkpoint_path
${
ckpt_prefix
}
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.batch_size
${
batch_size
}
--opts
decoding.batch_size
${
batch_size
}
...
...
examples/librispeech/s2/local/export.sh
浏览文件 @
03a50d7b
...
@@ -18,7 +18,8 @@ if [ ${ngpu} == 0 ];then
...
@@ -18,7 +18,8 @@ if [ ${ngpu} == 0 ];then
fi
fi
python3
-u
${
BIN_DIR
}
/test.py
\
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
'export'
\
--model-name
'u2_kaldi'
\
--run-mode
'export'
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
...
...
examples/librispeech/s2/local/test.sh
浏览文件 @
03a50d7b
#!/bin/bash
#!/bin/bash
if
[
$#
!=
2
]
;
then
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix"
echo
"usage:
${
0
}
config_path
dict_path
ckpt_path_prefix"
exit
-1
exit
-1
fi
fi
...
@@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then
...
@@ -14,7 +14,8 @@ if [ ${ngpu} == 0 ];then
fi
fi
config_path
=
$1
config_path
=
$1
ckpt_prefix
=
$2
dict_path
=
$2
ckpt_prefix
=
$3
chunk_mode
=
false
chunk_mode
=
false
if
[[
${
config_path
}
=
~ ^.
*
chunk_.
*
yaml
$
]]
;
then
if
[[
${
config_path
}
=
~ ^.
*
chunk_.
*
yaml
$
]]
;
then
...
@@ -38,11 +39,13 @@ for type in attention ctc_greedy_search; do
...
@@ -38,11 +39,13 @@ for type in attention ctc_greedy_search; do
batch_size
=
64
batch_size
=
64
fi
fi
python3
-u
${
BIN_DIR
}
/test.py
\
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
test
\
--model-name
u2_kaldi
\
--run-mode
test
\
--dict-path
${
dict_path
}
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
1
\
--nproc
1
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--result
_
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--result
-
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
decoding.batch_size
${
batch_size
}
--opts
decoding.decoding_method
${
type
}
decoding.batch_size
${
batch_size
}
...
@@ -56,11 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do
...
@@ -56,11 +59,13 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo
"decoding
${
type
}
"
echo
"decoding
${
type
}
"
batch_size
=
1
batch_size
=
1
python3
-u
${
BIN_DIR
}
/test.py
\
python3
-u
${
BIN_DIR
}
/test.py
\
--run_mode
test
\
--model-name
u2_kaldi
\
--run-mode
test
\
--dict-path
${
dict_path
}
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
1
\
--nproc
1
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--result
_
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--result
-
file
${
ckpt_prefix
}
.
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
decoding.batch_size
${
batch_size
}
--opts
decoding.decoding_method
${
type
}
decoding.batch_size
${
batch_size
}
...
...
examples/librispeech/s2/local/train.sh
浏览文件 @
03a50d7b
...
@@ -19,12 +19,22 @@ echo "using ${device}..."
...
@@ -19,12 +19,22 @@ echo "using ${device}..."
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--model-name
u2_kaldi
\
--model-name
u2_kaldi
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/librispeech/s2/run.sh
浏览文件 @
03a50d7b
...
@@ -5,6 +5,7 @@ source path.sh
...
@@ -5,6 +5,7 @@ source path.sh
stage
=
0
stage
=
0
stop_stage
=
100
stop_stage
=
100
conf_path
=
conf/transformer.yaml
conf_path
=
conf/transformer.yaml
dict_path
=
data/train_960_unigram5000_units.txt
avg_num
=
5
avg_num
=
5
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
source
${
MAIN_ROOT
}
/utils/parse_options.sh
||
exit
1
;
...
@@ -29,12 +30,12 @@ fi
...
@@ -29,12 +30,12 @@ fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# test ckpt avg_n
# test ckpt avg_n
CUDA_VISIBLE_DEVICES
=
0 ./local/test.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
CUDA_VISIBLE_DEVICES
=
0 ./local/test.sh
${
conf_path
}
${
dict_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
fi
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
# ctc alignment of test data
# ctc alignment of test data
CUDA_VISIBLE_DEVICES
=
0 ./local/align.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
CUDA_VISIBLE_DEVICES
=
0 ./local/align.sh
${
conf_path
}
${
dict_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
fi
if
[
${
stage
}
-le
5
]
&&
[
${
stop_stage
}
-ge
5
]
;
then
if
[
${
stage
}
-le
5
]
&&
[
${
stop_stage
}
-ge
5
]
;
then
...
...
examples/ted_en_zh/t0/local/train.sh
浏览文件 @
03a50d7b
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/thchs30/a0/local/data.sh
浏览文件 @
03a50d7b
...
@@ -20,27 +20,33 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
...
@@ -20,27 +20,33 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo
"Prepare THCHS-30 failed. Terminated."
echo
"Prepare THCHS-30 failed. Terminated."
exit
1
exit
1
fi
fi
fi
fi
# dump manifest to data/
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
python3
${
MAIN_ROOT
}
/utils/dump_manifest.py
--manifest-path
=
data/manifest.train
--output-dir
=
data
# dump manifest to data/
python3
${
MAIN_ROOT
}
/utils/dump_manifest.py
--manifest-path
=
data/manifest.train
--output-dir
=
data
# copy files to data/dict to gen word.lexicon
fi
cp
${
TARGET_DIR
}
/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
cp
${
TARGET_DIR
}
/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
# copy phone.lexicon to data/dict
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
cp
${
TARGET_DIR
}
/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
# copy files to data/dict to gen word.lexicon
cp
${
TARGET_DIR
}
/thchs30/data_thchs30/lm_word/lexicon.txt data/dict/lm_word_lexicon_1
cp
${
TARGET_DIR
}
/thchs30/resource/dict/lexicon.txt data/dict/lm_word_lexicon_2
# copy phone.lexicon to data/dict
cp
${
TARGET_DIR
}
/thchs30/data_thchs30/lm_phone/lexicon.txt data/dict/phone.lexicon
fi
# gen word.lexicon
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
python
local
/gen_word2phone.py
--root-dir
=
data/dict
--output-dir
=
data/dict
# gen word.lexicon
python
local
/gen_word2phone.py
--lexicon-files
=
"data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2"
--output-path
=
data/dict/word.lexicon
fi
# reorganize dataset for MFA
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
if
[
!
-d
$EXP_DIR
/thchs30_corpus
]
;
then
# reorganize dataset for MFA
echo
"reorganizing thchs30 corpus..."
if
[
!
-d
$EXP_DIR
/thchs30_corpus
]
;
then
python
local
/reorganize_thchs30.py
--root-dir
=
data
--output-dir
=
data/thchs30_corpus
--script-type
=
$LEXICON_NAME
echo
"reorganizing thchs30 corpus..."
echo
"reorganization done."
python
local
/reorganize_thchs30.py
--root-dir
=
data
--output-dir
=
data/thchs30_corpus
--script-type
=
$LEXICON_NAME
echo
"reorganization done."
fi
fi
fi
echo
"THCHS-30 data preparation done."
echo
"THCHS-30 data preparation done."
...
...
examples/thchs30/a0/local/gen_word2phone.py
浏览文件 @
03a50d7b
...
@@ -18,6 +18,7 @@ file2: THCHS-30/resource/dict/lexicon.txt
...
@@ -18,6 +18,7 @@ file2: THCHS-30/resource/dict/lexicon.txt
import
argparse
import
argparse
from
collections
import
defaultdict
from
collections
import
defaultdict
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
List
from
typing
import
Union
from
typing
import
Union
# key: (cn, ('ee', 'er4')),value: count
# key: (cn, ('ee', 'er4')),value: count
...
@@ -34,7 +35,7 @@ def is_Chinese(ch):
...
@@ -34,7 +35,7 @@ def is_Chinese(ch):
return
False
return
False
def
proc_line
(
line
):
def
proc_line
(
line
:
str
):
line
=
line
.
strip
()
line
=
line
.
strip
()
if
is_Chinese
(
line
[
0
]):
if
is_Chinese
(
line
[
0
]):
line_list
=
line
.
split
()
line_list
=
line
.
split
()
...
@@ -49,20 +50,25 @@ def proc_line(line):
...
@@ -49,20 +50,25 @@ def proc_line(line):
cn_phones_counter
[(
cn
,
phones
)]
+=
1
cn_phones_counter
[(
cn
,
phones
)]
+=
1
def
gen_lexicon
(
root_dir
:
Union
[
str
,
Path
],
output_dir
:
Union
[
str
,
Path
]):
"""
root_dir
=
Path
(
root_dir
).
expanduser
()
example lines of output
output_dir
=
Path
(
output_dir
).
expanduser
()
the first column is a Chinese character
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
the second is the probability of this pronunciation
file1
=
root_dir
/
"lm_word_lexicon_1"
and the rest are the phones of this pronunciation
file2
=
root_dir
/
"lm_word_lexicon_2"
一 0.22 ii i1↩
write_file
=
output_dir
/
"word.lexicon"
一 0.45 ii i4↩
一 0.32 ii i2↩
一 0.01 ii i5
"""
def
gen_lexicon
(
lexicon_files
:
List
[
Union
[
str
,
Path
]],
output_path
:
Union
[
str
,
Path
]):
for
file_path
in
lexicon_files
:
with
open
(
file_path
,
"r"
)
as
f1
:
for
line
in
f1
:
proc_line
(
line
)
with
open
(
file1
,
"r"
)
as
f1
:
for
line
in
f1
:
proc_line
(
line
)
with
open
(
file2
,
"r"
)
as
f2
:
for
line
in
f2
:
proc_line
(
line
)
for
key
in
cn_phones_counter
:
for
key
in
cn_phones_counter
:
cn
=
key
[
0
]
cn
=
key
[
0
]
cn_counter
[
cn
].
append
((
key
[
1
],
cn_phones_counter
[
key
]))
cn_counter
[
cn
].
append
((
key
[
1
],
cn_phones_counter
[
key
]))
...
@@ -75,7 +81,8 @@ def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]):
...
@@ -75,7 +81,8 @@ def gen_lexicon(root_dir: Union[str, Path], output_dir: Union[str, Path]):
p
=
round
(
p
,
2
)
p
=
round
(
p
,
2
)
if
p
>
0
:
if
p
>
0
:
cn_counter_p
[
key
].
append
((
item
[
0
],
p
))
cn_counter_p
[
key
].
append
((
item
[
0
],
p
))
with
open
(
write_file
,
"w"
)
as
wf
:
with
open
(
output_path
,
"w"
)
as
wf
:
for
key
in
cn_counter_p
:
for
key
in
cn_counter_p
:
phone_p_list
=
cn_counter_p
[
key
]
phone_p_list
=
cn_counter_p
[
key
]
for
item
in
phone_p_list
:
for
item
in
phone_p_list
:
...
@@ -87,8 +94,21 @@ if __name__ == "__main__":
...
@@ -87,8 +94,21 @@ if __name__ == "__main__":
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
"Gen Chinese characters to phone lexicon for THCHS-30 dataset"
description
=
"Gen Chinese characters to phone lexicon for THCHS-30 dataset"
)
)
# A line of word_lexicon:
# 一丁点 ii i4 d ing1 d ian3
# the first is word, and the rest are the phones of the word, and the len of phones is twice of the word's len
parser
.
add_argument
(
"--lexicon-files"
,
type
=
str
,
default
=
"data/dict/lm_word_lexicon_1 data/dict/lm_word_lexicon_2"
,
help
=
"lm_word_lexicon files"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--root-dir"
,
type
=
str
,
help
=
"dir to thchs30 lm_word_lexicons"
)
"--output-path"
,
parser
.
add_argument
(
"--output-dir"
,
type
=
str
,
help
=
"path to save outputs"
)
type
=
str
,
default
=
"data/dict/word.lexicon"
,
help
=
"path to save output word2phone lexicon"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
gen_lexicon
(
args
.
root_dir
,
args
.
output_dir
)
lexicon_files
=
args
.
lexicon_files
.
split
(
" "
)
output_path
=
Path
(
args
.
output_path
).
expanduser
()
gen_lexicon
(
lexicon_files
,
output_path
)
examples/thchs30/a0/local/reorganize_thchs30.py
浏览文件 @
03a50d7b
...
@@ -58,8 +58,6 @@ def write_lab(root_dir: Union[str, Path],
...
@@ -58,8 +58,6 @@ def write_lab(root_dir: Union[str, Path],
def
reorganize_thchs30
(
root_dir
:
Union
[
str
,
Path
],
def
reorganize_thchs30
(
root_dir
:
Union
[
str
,
Path
],
output_dir
:
Union
[
str
,
Path
]
=
None
,
output_dir
:
Union
[
str
,
Path
]
=
None
,
script_type
=
'phone'
):
script_type
=
'phone'
):
root_dir
=
Path
(
root_dir
).
expanduser
()
output_dir
=
Path
(
output_dir
).
expanduser
()
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
link_wav
(
root_dir
,
output_dir
)
link_wav
(
root_dir
,
output_dir
)
write_lab
(
root_dir
,
output_dir
,
script_type
)
write_lab
(
root_dir
,
output_dir
,
script_type
)
...
@@ -72,12 +70,15 @@ if __name__ == "__main__":
...
@@ -72,12 +70,15 @@ if __name__ == "__main__":
parser
.
add_argument
(
parser
.
add_argument
(
"--output-dir"
,
"--output-dir"
,
type
=
str
,
type
=
str
,
help
=
"path to save outputs(audio and transcriptions)"
)
help
=
"path to save outputs
(audio and transcriptions)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--script-type"
,
"--script-type"
,
type
=
str
,
type
=
str
,
default
=
"phone"
,
default
=
"phone"
,
help
=
"type of lab ('word'/'syllable'/'phone')"
)
help
=
"type of lab ('word'/'syllable'/'phone')"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
reorganize_thchs30
(
args
.
root_dir
,
args
.
output_dir
,
args
.
script_type
)
root_dir
=
Path
(
args
.
root_dir
).
expanduser
()
output_dir
=
Path
(
args
.
output_dir
).
expanduser
()
reorganize_thchs30
(
root_dir
,
output_dir
,
args
.
script_type
)
examples/thchs30/a0/run.sh
浏览文件 @
03a50d7b
...
@@ -14,14 +14,17 @@ source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
...
@@ -14,14 +14,17 @@ source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
# gen lexicon relink gen dump
# gen lexicon relink gen dump
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
# prepare data
# prepare data
bash ./local/data.sh
$LEXICON_NAME
||
exit
-1
echo
"Start prepare thchs30 data for MFA ..."
bash ./local/data.sh
$LEXICON_NAME
||
exit
-1
fi
fi
# run MFA
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
if
[
!
-d
"
$EXP_DIR
/thchs30_alignment"
]
;
then
# run MFA
echo
"Start MFA training..."
if
[
!
-d
"
$EXP_DIR
/thchs30_alignment"
]
;
then
mfa_train_and_align data/thchs30_corpus data/dict/
$LEXICON_NAME
.lexicon
$EXP_DIR
/thchs30_alignment
-o
$EXP_DIR
/thchs30_model
--clean
--verbose
--temp_directory
exp/.mfa_train_and_align
--num_jobs
$NUM_JOBS
echo
"Start MFA training ..."
echo
"training done!
\n
results:
$EXP_DIR
/thchs30_alignment
\n
model:
$EXP_DIR
/thchs30_model
\n
"
mfa_train_and_align data/thchs30_corpus data/dict/
$LEXICON_NAME
.lexicon
$EXP_DIR
/thchs30_alignment
-o
$EXP_DIR
/thchs30_model
--clean
--verbose
--temp_directory
exp/.mfa_train_and_align
--num_jobs
$NUM_JOBS
echo
"MFA training done!
\n
results:
$EXP_DIR
/thchs30_alignment
\n
model:
$EXP_DIR
/thchs30_model
\n
"
fi
fi
fi
...
...
examples/timit/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
@@ -27,7 +27,9 @@
...
@@ -27,7 +27,9 @@
"W"
:
80
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
,
"warp_mode"
:
"PIL"
},
},
"prob"
:
1.0
"prob"
:
1.0
}
}
...
...
examples/timit/s1/local/train.sh
浏览文件 @
03a50d7b
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
...
@@ -19,11 +19,21 @@ echo "using ${device}..."
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/tiny/s0/conf/augmentation.json
浏览文件 @
03a50d7b
[
[
{
"type"
:
"speed"
,
"params"
:
{
"min_speed_rate"
:
0.9
,
"max_speed_rate"
:
1.1
,
"num_rates"
:
3
},
"prob"
:
0.0
},
{
{
"type"
:
"shift"
,
"type"
:
"shift"
,
"params"
:
{
"params"
:
{
...
@@ -6,5 +15,22 @@
...
@@ -6,5 +15,22 @@
"max_shift_ms"
:
5
"max_shift_ms"
:
5
},
},
"prob"
:
1.0
"prob"
:
1.0
},
{
"type"
:
"specaug"
,
"params"
:
{
"W"
:
5
,
"warp_mode"
:
"PIL"
,
"F"
:
30
,
"n_freq_masks"
:
2
,
"T"
:
40
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
"prob"
:
1.0
}
}
]
]
examples/tiny/s0/local/train.sh
浏览文件 @
03a50d7b
...
@@ -19,12 +19,22 @@ fi
...
@@ -19,12 +19,22 @@ fi
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
\
--output
exp/
${
ckpt_name
}
\
--model_type
${
model_type
}
--model_type
${
model_type
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
examples/tiny/s1/conf/augmentation.json
浏览文件 @
03a50d7b
...
@@ -27,7 +27,9 @@
...
@@ -27,7 +27,9 @@
"W"
:
80
,
"W"
:
80
,
"adaptive_number_ratio"
:
0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
,
"warp_mode"
:
"PIL"
},
},
"prob"
:
1.0
"prob"
:
1.0
}
}
...
...
examples/tiny/s1/local/train.sh
浏览文件 @
03a50d7b
...
@@ -18,11 +18,21 @@ fi
...
@@ -18,11 +18,21 @@ fi
mkdir
-p
exp
mkdir
-p
exp
seed
=
1024
if
[
${
seed
}
]
;
then
export
FLAGS_cudnn_deterministic
=
True
fi
python3
-u
${
BIN_DIR
}
/train.py
\
python3
-u
${
BIN_DIR
}
/train.py
\
--device
${
device
}
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--config
${
config_path
}
\
--output
exp/
${
ckpt_name
}
--output
exp/
${
ckpt_name
}
\
--seed
${
seed
}
if
[
${
seed
}
]
;
then
unset
FLAGS_cudnn_deterministic
fi
if
[
$?
-ne
0
]
;
then
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
echo
"Failed in training!"
...
...
requirements.txt
浏览文件 @
03a50d7b
coverage
coverage
gpustat
gpustat
jsonlines
kaldiio
kaldiio
Pillow
pre-commit
pre-commit
pybind11
pybind11
resampy
==0.2.2
resampy
==0.2.2
...
...
tools/extras/install_mfa.sh
浏览文件 @
03a50d7b
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
test
-d
Montreal-Forced-Aligner
||
git clone https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner.git
test
-d
Montreal-Forced-Aligner
||
git clone https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner.git
pushd
Montreal-Forced-Aligner
&&
git checkout v2.0.0a7
&&
python setup.py
install
pushd
Montreal-Forced-Aligner
&&
python setup.py
install
&&
popd
test
-d
kaldi
||
{
echo
"need install kaldi first"
;
exit
1
;
}
test
-d
kaldi
||
{
echo
"need install kaldi first"
;
exit
1
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录