Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
ab23eb57
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ab23eb57
编写于
8月 18, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix for kaldi
上级
cd34e733
变更
18
展开全部
隐藏空白更改
内联
并排
Showing
18 changed file
with
1009 addition
and
89 deletion
+1009
-89
deepspeech/__init__.py
deepspeech/__init__.py
+0
-39
deepspeech/exps/u2_kaldi/__init__.py
deepspeech/exps/u2_kaldi/__init__.py
+13
-0
deepspeech/exps/u2_kaldi/bin/alignment.py
deepspeech/exps/u2_kaldi/bin/alignment.py
+54
-0
deepspeech/exps/u2_kaldi/bin/export.py
deepspeech/exps/u2_kaldi/bin/export.py
+48
-0
deepspeech/exps/u2_kaldi/bin/test.py
deepspeech/exps/u2_kaldi/bin/test.py
+55
-0
deepspeech/exps/u2_kaldi/bin/train.py
deepspeech/exps/u2_kaldi/bin/train.py
+69
-0
deepspeech/exps/u2_kaldi/model.py
deepspeech/exps/u2_kaldi/model.py
+642
-0
deepspeech/io/dataloader.py
deepspeech/io/dataloader.py
+21
-1
deepspeech/io/reader.py
deepspeech/io/reader.py
+4
-3
deepspeech/io/sampler.py
deepspeech/io/sampler.py
+1
-1
deepspeech/models/u2.py
deepspeech/models/u2.py
+7
-7
deepspeech/modules/activation.py
deepspeech/modules/activation.py
+1
-1
deepspeech/training/optimizer.py
deepspeech/training/optimizer.py
+46
-8
deepspeech/training/scheduler.py
deepspeech/training/scheduler.py
+35
-16
examples/librispeech/s2/conf/transformer.yaml
examples/librispeech/s2/conf/transformer.yaml
+11
-11
examples/librispeech/s2/local/train.sh
examples/librispeech/s2/local/train.sh
+1
-0
examples/librispeech/s2/path.sh
examples/librispeech/s2/path.sh
+1
-1
speechnn/core/transformers/README.md
speechnn/core/transformers/README.md
+0
-1
未找到文件。
deepspeech/__init__.py
浏览文件 @
ab23eb57
...
...
@@ -407,42 +407,3 @@ class GLU(nn.Layer):
# Install the user-defined GLU layer on paddle.nn only when the attribute is
# not already present there.
if not hasattr(paddle.nn, 'GLU'):
    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
    paddle.nn.GLU = GLU
# TODO(Hui Zhang): remove this Layer
class ConstantPad2d(nn.Layer):
    """Pad the boundaries of the input tensor with a constant value.

    For N-dimensional padding, use paddle.nn.functional.pad().
    """

    def __init__(self, padding: Union[tuple, list, int], value: float):
        """
        Args:
            padding (tuple|list|int): the size of the padding.
                If an int, the same padding is used on all four boundaries.
                If a 4-tuple, uses
                (padding_left, padding_right, padding_top, padding_bottom).
            value (float): pad value
        """
        # Initialise the base Layer before storing attributes on the instance.
        super().__init__()
        # isinstance() requires a type or a tuple of types; passing a list
        # (as the previous code did) raises TypeError for every input.
        if isinstance(padding, (tuple, list)):
            self.padding = padding
        else:
            # A scalar is expanded to one entry per boundary.
            self.padding = [padding] * 4
        self.value = value

    def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
        """Return `xs` constant-padded with the stored sizes and value.

        The pad call is issued with data_format='NCHW'.
        """
        return nn.functional.pad(
            xs,
            self.padding,
            mode='constant',
            value=self.value,
            data_format='NCHW')
# Install the user-defined ConstantPad2d on paddle.nn only when the attribute
# is not already present there.
if not hasattr(paddle.nn, 'ConstantPad2d'):
    logger.warn(
        "register user ConstantPad2d to paddle.nn, remove this when fixed!")
    paddle.nn.ConstantPad2d = ConstantPad2d

########### hack paddle.jit #############

# When paddle.jit has no `export`, make it an alias of paddle.jit.to_static.
if not hasattr(paddle.jit, 'export'):
    logger.warn("register user export to paddle.jit, remove this when fixed!")
    paddle.jit.export = paddle.jit.to_static
deepspeech/exps/u2_kaldi/__init__.py
0 → 100644
浏览文件 @
ab23eb57
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech/exps/u2_kaldi/bin/alignment.py
0 → 100644
浏览文件 @
ab23eb57
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alignment for U2 model."""
from
deepspeech.exps.u2.model
import
get_cfg_defaults
from
deepspeech.exps.u2.model
import
U2Tester
as
Tester
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.utils.utility
import
print_arguments
def main_sp(config, args):
    """Build a Tester from `config` and `args`, set it up and run alignment."""
    tester = Tester(config, args)
    tester.setup()
    tester.run_align()
def main(config, args):
    """Entry point: forward `config` and `args` to `main_sp`."""
    main_sp(config=config, args=args)
if __name__ == "__main__":
    parser = default_argument_parser()
    # The registration method is `add_argument`; the previous call to
    # `add_arguments` would fail with AttributeError on an
    # argparse.ArgumentParser.
    # NOTE(review): assumes default_argument_parser() returns an argparse
    # parser -- confirm.
    parser.add_argument(
        '--model-name',
        type=str,
        default='u2',
        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    # Start from the config returned by get_cfg_defaults(), then apply the
    # optional config file and the optional command-line overrides.
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    # Optionally write the final config to the path given by `dump_config`.
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
deepspeech/exps/u2_kaldi/bin/export.py
0 → 100644
浏览文件 @
ab23eb57
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export for U2 model."""
from
deepspeech.exps.u2.model
import
get_cfg_defaults
from
deepspeech.exps.u2.model
import
U2Tester
as
Tester
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.utility
import
print_arguments
def main_sp(config, args):
    """Build a Tester from `config` and `args`, set it up and run export."""
    tester = Tester(config, args)
    tester.setup()
    tester.run_export()
def main(config, args):
    """Entry point: forward `config` and `args` to `main_sp`."""
    main_sp(config=config, args=args)
if __name__ == "__main__":
    # Build the argument parser and read the command line.
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    # Start from the config returned by get_cfg_defaults(), then apply the
    # optional config file and the optional command-line overrides.
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    # Optionally write the final config to the path given by `dump_config`.
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
deepspeech/exps/u2_kaldi/bin/test.py
0 → 100644
浏览文件 @
ab23eb57
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for U2 model."""
import
cProfile
from
deepspeech.exps.u2.model
import
get_cfg_defaults
from
deepspeech.exps.u2.model
import
U2Tester
as
Tester
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.utility
import
print_arguments
# TODO(hui zhang): dynamic load
def main_sp(config, args):
    """Build a Tester from `config` and `args`, set it up and run the test."""
    tester = Tester(config, args)
    tester.setup()
    tester.run_test()
def main(config, args):
    """Entry point: forward `config` and `args` to `main_sp`."""
    main_sp(config=config, args=args)
if __name__ == "__main__":
    # Build the argument parser and read the command line.
    parser = default_argument_parser()
    args = parser.parse_args()
    print_arguments(args, globals())

    # https://yaml.org/type/float.html
    # Start from the config returned by get_cfg_defaults(), then apply the
    # optional config file and the optional command-line overrides.
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    # Optionally write the final config to the path given by `dump_config`.
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    # Setting for profiling
    # Run main() under cProfile and write the stats to 'test.profile' in the
    # current working directory.
    pr = cProfile.Profile()
    pr.runcall(main, config, args)
    pr.dump_stats('test.profile')
deepspeech/exps/u2_kaldi/bin/train.py
0 → 100644
浏览文件 @
ab23eb57
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for U2 model."""
import
cProfile
import
os
from
paddle
import
distributed
as
dist
from
yacs.config
import
CfgNode
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.utils.utility
import
print_arguments
# Maps the short names accepted by --model-name to "module:class" import
# paths; passed to dynamic_import() in main_sp.
model_alias = dict(
    u2="deepspeech.exps.u2.model:U2Trainer",
    u2_kaldi="deepspeech.exps.u2_kaldi.model:U2Trainer",
)
def main_sp(config, args):
    """Resolve the trainer class for `args.model_name` and run it.

    The class is looked up through `dynamic_import` using `model_alias`,
    instantiated with `config` and `args`, then set up and run.
    """
    trainer = dynamic_import(args.model_name, model_alias)(config, args)
    trainer.setup()
    trainer.run()
def main(config, args):
    """Run `main_sp` directly, or spawn it when several GPU processes are asked for.

    `dist.spawn` is used only when `args.device` is "gpu" and
    `args.nprocs` is greater than 1.
    """
    if args.device != "gpu" or not args.nprocs > 1:
        main_sp(config, args)
        return
    dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
if __name__ == "__main__":
    # Build the argument parser, add the model selector and read the command
    # line.
    parser = default_argument_parser()
    parser.add_argument(
        '--model-name',
        type=str,
        default='u2_kaldi',
        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
    args = parser.parse_args()
    print_arguments(args, globals())

    # Start from an empty CfgNode that accepts new keys, so the whole
    # configuration comes from the config file plus the optional overrides.
    config = CfgNode()
    config.set_new_allowed(True)
    config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    # Optionally write the final config to the path given by `dump_config`.
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    # Setting for profiling
    # Run main() under cProfile and write the stats to 'train.profile' inside
    # the directory given by `args.output`.
    pr = cProfile.Profile()
    pr.runcall(main, config, args)
    pr.dump_stats(os.path.join(args.output, 'train.profile'))
deepspeech/exps/u2_kaldi/model.py
0 → 100644
浏览文件 @
ab23eb57
此差异已折叠。
点击以展开。
deepspeech/io/dataloader.py
浏览文件 @
ab23eb57
...
...
@@ -11,6 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Any
from
typing
import
Dict
from
typing
import
List
from
typing
import
Text
import
numpy
as
np
from
paddle.io
import
DataLoader
from
deepspeech.frontend.utility
import
read_manifest
...
...
@@ -25,6 +31,18 @@ __all__ = ["BatchDataLoader"]
logger
=
Log
(
__name__
).
getlog
()
def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
                            mode: Text="asr",
                            iaxis=0,
                            oaxis=0):
    """Read the feature dimension and vocabulary size from the first entry.

    Args:
        data_json (List[Dict[Text, Any]]): manifest entries; the first one
            must carry 'input' and 'output' lists whose items have a 'shape'.
        mode (Text): only 'asr' is handled.
        iaxis (int): index into the 'input' list of the entry.
        oaxis (int): index into the 'output' list of the entry.

    Returns:
        tuple: (feat_dim, vocab_size), each taken from position 1 of the
        selected item's 'shape'.

    Raises:
        ValueError: if `mode` is not 'asr'.
    """
    if mode != 'asr':
        raise ValueError(f"{mode} mode not support!")

    first = data_json[0]
    # The input stream is selected with `iaxis`; it used to be indexed with
    # `oaxis`, which left `iaxis` unused and picked the wrong input whenever
    # the two indices differed.
    feat_dim = first['input'][iaxis]['shape'][1]
    vocab_size = first['output'][oaxis]['shape'][1]
    return feat_dim, vocab_size
class
BatchDataLoader
():
def
__init__
(
self
,
json_file
:
str
,
...
...
@@ -62,6 +80,8 @@ class BatchDataLoader():
# read json data
self
.
data_json
=
read_manifest
(
json_file
)
self
.
feat_dim
,
self
.
vocab_size
=
feat_dim_and_vocab_size
(
self
.
data_json
,
mode
=
'asr'
)
# make minibatch list (variable length)
self
.
minibaches
=
make_batchset
(
...
...
@@ -106,7 +126,7 @@ class BatchDataLoader():
self
.
dataloader
=
DataLoader
(
dataset
=
self
.
dataset
,
batch_size
=
1
,
shuffle
=
not
use_sortagrad
if
train_mode
else
False
,
shuffle
=
not
self
.
use_sortagrad
if
train_mode
else
False
,
collate_fn
=
lambda
x
:
x
[
0
],
num_workers
=
n_iter_processes
,
)
...
...
deepspeech/io/reader.py
浏览文件 @
ab23eb57
...
...
@@ -66,8 +66,9 @@ class LoadInputsAndTargets():
raise
ValueError
(
"Only asr are allowed: mode={}"
.
format
(
mode
))
if
preprocess_conf
is
not
None
:
self
.
preprocessing
=
AugmentationPipeline
(
preprocess_conf
)
logging
.
warning
(
with
open
(
preprocess_conf
,
'r'
)
as
fin
:
self
.
preprocessing
=
AugmentationPipeline
(
fin
.
read
())
logger
.
warning
(
"[Experimental feature] Some preprocessing will be done "
"for the mini-batch creation using {}"
.
format
(
self
.
preprocessing
))
...
...
@@ -197,7 +198,7 @@ class LoadInputsAndTargets():
nonzero_sorted_idx
=
nonzero_idx
if
len
(
nonzero_sorted_idx
)
!=
len
(
xs
[
0
]):
logg
ing
.
warning
(
logg
er
.
warning
(
"Target sequences include empty tokenid (batch {} -> {})."
.
format
(
len
(
xs
[
0
]),
len
(
nonzero_sorted_idx
)))
...
...
deepspeech/io/sampler.py
浏览文件 @
ab23eb57
...
...
@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng
=
np
.
random
.
RandomState
(
epoch
)
shift_len
=
rng
.
randint
(
0
,
batch_size
-
1
)
batch_indices
=
list
(
zip
(
*
[
iter
(
indices
[
shift_len
:])]
*
batch_size
))
batch_indices
=
list
(
zip
(
*
[
iter
(
indices
[
shift_len
:])]
*
batch_size
))
rng
.
shuffle
(
batch_indices
)
batch_indices
=
[
item
for
batch
in
batch_indices
for
item
in
batch
]
assert
clipped
is
False
...
...
deepspeech/models/u2.py
浏览文件 @
ab23eb57
...
...
@@ -612,32 +612,32 @@ class U2BaseModel(nn.Layer):
best_index
=
i
return
hyps
[
best_index
][
0
]
#@jit.
export
#@jit.
to_static
def subsampling_rate(self) -> int:
    """Return `subsampling_rate` of the encoder's `embed` module.

    Export interface for the C++ caller.
    """
    embed = self.encoder.embed
    return embed.subsampling_rate
#@jit.
export
#@jit.
to_static
def right_context(self) -> int:
    """Return `right_context` of the encoder's `embed` module.

    Export interface for the C++ caller.
    """
    embed = self.encoder.embed
    return embed.right_context
#@jit.
export
#@jit.
to_static
def sos_symbol(self) -> int:
    """Return the model's sos symbol id (`self.sos`).

    Export interface for the C++ caller.
    """
    sos_id = self.sos
    return sos_id
#@jit.
export
#@jit.
to_static
def eos_symbol(self) -> int:
    """Return the model's eos symbol id (`self.eos`).

    Export interface for the C++ caller.
    """
    eos_id = self.eos
    return eos_id
@
jit
.
export
@
jit
.
to_static
def
forward_encoder_chunk
(
self
,
xs
:
paddle
.
Tensor
,
...
...
@@ -667,7 +667,7 @@ class U2BaseModel(nn.Layer):
xs
,
offset
,
required_cache_size
,
subsampling_cache
,
elayers_output_cache
,
conformer_cnn_cache
)
# @jit.
export
([
# @jit.
to_static
([
# paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
# ])
def
ctc_activation
(
self
,
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
...
...
@@ -680,7 +680,7 @@ class U2BaseModel(nn.Layer):
"""
return
self
.
ctc
.
log_softmax
(
xs
)
@
jit
.
export
@
jit
.
to_static
def
forward_attention_decoder
(
self
,
hyps
:
paddle
.
Tensor
,
...
...
deepspeech/modules/activation.py
浏览文件 @
ab23eb57
...
...
@@ -69,7 +69,7 @@ class ConvGLUBlock(nn.Layer):
dim
=
0
)
self
.
dropout_residual
=
nn
.
Dropout
(
p
=
dropout
)
self
.
pad_left
=
Constant
Pad2d
((
0
,
0
,
kernel_size
-
1
,
0
),
0
)
self
.
pad_left
=
nn
.
Pad2d
((
0
,
0
,
kernel_size
-
1
,
0
),
0
)
layers
=
OrderedDict
()
if
bottlececk_dim
==
0
:
...
...
deepspeech/training/optimizer.py
浏览文件 @
ab23eb57
...
...
@@ -15,6 +15,7 @@ from typing import Any
from
typing
import
Dict
from
typing
import
Text
import
paddle
from
paddle.optimizer
import
Optimizer
from
paddle.regularizer
import
L2Decay
...
...
@@ -43,6 +44,40 @@ def register_optimizer(cls):
return
cls
@register_optimizer
class Noam(paddle.optimizer.Adam):
    """Adam subclass registered under the name `Noam`.

    See: espnet/nets/pytorch_backend/transformer/optimizer.py
    """

    def __init__(self,
                 learning_rate=0,
                 beta1=0.9,
                 beta2=0.98,
                 epsilon=1e-9,
                 parameters=None,
                 weight_decay=None,
                 grad_clip=None,
                 lazy_mode=False,
                 multi_precision=False,
                 name=None):
        """Forward every argument unchanged to `paddle.optimizer.Adam`.

        Only the defaults differ from a plain call: learning_rate=0,
        beta1=0.9, beta2=0.98 and epsilon=1e-9.
        """
        super().__init__(
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            lazy_mode=lazy_mode,
            multi_precision=multi_precision,
            name=name)

    def __repr__(self):
        """Return a one-line summary of the optimizer's hyper-parameters."""
        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
        echo += f"learning_rate: {self._learning_rate}, "
        echo += f"(beta1: {self._beta1} beta2: {self._beta2}), "
        echo += f"epsilon: {self._epsilon}"
        # The string was previously built but never returned, so __repr__
        # yielded None and repr() raised TypeError.
        return echo
def
dynamic_import_optimizer
(
module
):
"""Import Optimizer class dynamically.
...
...
@@ -69,15 +104,18 @@ class OptimizerFactory():
args
[
'grad_clip'
])
if
"grad_clip"
in
args
else
None
weight_decay
=
L2Decay
(
args
[
'weight_decay'
])
if
"weight_decay"
in
args
else
None
module_class
=
dynamic_import_optimizer
(
name
.
lower
())
if
weight_decay
:
logger
.
info
(
f
'
WeightDecay:
{
weight_decay
}
'
)
logger
.
info
(
f
'
<WeightDecay -
{
weight_decay
}
>
'
)
if
grad_clip
:
logger
.
info
(
f
'GradClip:
{
grad_clip
}
'
)
logger
.
info
(
f
"Optimizer:
{
module_class
.
__name__
}
{
args
[
'learning_rate'
]
}
"
)
logger
.
info
(
f
'<GradClip -
{
grad_clip
}
>'
)
module_class
=
dynamic_import_optimizer
(
name
.
lower
())
args
.
update
({
"grad_clip"
:
grad_clip
,
"weight_decay"
:
weight_decay
})
return
instance_class
(
module_class
,
args
)
opt
=
instance_class
(
module_class
,
args
)
if
"__repr__"
in
vars
(
opt
):
logger
.
info
(
f
"
{
opt
}
"
)
else
:
logger
.
info
(
f
"<Optimizer
{
module_class
.
__module__
}
.
{
module_class
.
__name__
}
> LR:
{
args
[
'learning_rate'
]
}
"
)
return
opt
deepspeech/training/scheduler.py
浏览文件 @
ab23eb57
...
...
@@ -41,22 +41,6 @@ def register_scheduler(cls):
return
cls
def dynamic_import_scheduler(module):
    """Import Scheduler class dynamically.

    Args:
        module (str): module_name:class_name or alias in `SCHEDULER_DICT`

    Returns:
        type: Scheduler class

    An assertion fails when the resolved class is not a subclass of
    `LRScheduler`.
    """
    scheduler_cls = dynamic_import(module, SCHEDULER_DICT)
    is_scheduler = issubclass(scheduler_cls, LRScheduler)
    assert is_scheduler, f"{module} does not implement LRScheduler"
    return scheduler_cls
@
register_scheduler
class
WarmupLR
(
LRScheduler
):
"""The WarmupLR scheduler
...
...
@@ -102,6 +86,41 @@ class WarmupLR(LRScheduler):
self
.
step
(
epoch
=
step
)
@register_scheduler
class ConstantLR(LRScheduler):
    """Scheduler whose `get_lr` always returns the base learning rate.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``ConstantLR`` instance to schedule learning rate.
    """

    def __init__(self, learning_rate, last_epoch=-1, verbose=False):
        # Nothing extra to store: all three values go straight to the base
        # class, in the same positional order.
        base_args = (learning_rate, last_epoch, verbose)
        super().__init__(*base_args)

    def get_lr(self):
        """Return `self.base_lr` unchanged."""
        constant_lr = self.base_lr
        return constant_lr
def dynamic_import_scheduler(module):
    """Import Scheduler class dynamically.

    Args:
        module (str): module_name:class_name or alias in `SCHEDULER_DICT`

    Returns:
        type: Scheduler class

    An assertion fails when the resolved class is not a subclass of
    `LRScheduler`.
    """
    scheduler_cls = dynamic_import(module, SCHEDULER_DICT)
    is_scheduler = issubclass(scheduler_cls, LRScheduler)
    assert is_scheduler, f"{module} does not implement LRScheduler"
    return scheduler_cls
class
LRSchedulerFactory
():
@
classmethod
def
from_args
(
cls
,
name
:
str
,
args
:
Dict
[
Text
,
Any
]):
...
...
examples/librispeech/s2/conf/transformer.yaml
浏览文件 @
ab23eb57
...
...
@@ -19,7 +19,7 @@ collator:
batch_size
:
64
raw_wav
:
True
# use raw_wav or kaldi feature
specgram_type
:
fbank
#linear, mfcc, fbank
feat_dim
:
8
0
feat_dim
:
8
3
delta_delta
:
False
dither
:
1.0
target_sample_rate
:
16000
...
...
@@ -38,7 +38,7 @@ collator:
# network architecture
model
:
cmvn_file
:
"
data/mean_std.json"
cmvn_file
:
cmvn_file_type
:
"
json"
# encoder related
encoder
:
transformer
...
...
@@ -74,20 +74,20 @@ model:
training
:
n_epoch
:
120
accum_grad
:
2
global_grad_clip
:
5.0
optim
:
adam
optim_conf
:
lr
:
0.004
weight_decay
:
1e-06
scheduler
:
warmuplr
# pytorch v1.1.0+ required
scheduler_conf
:
warmup_steps
:
25000
lr_decay
:
1.0
log_interval
:
100
checkpoint
:
kbest_n
:
50
latest_n
:
5
optim
:
adam
optim_conf
:
global_grad_clip
:
5.0
weight_decay
:
1.0e-06
scheduler
:
warmuplr
# pytorch v1.1.0+ required
scheduler_conf
:
lr
:
0.004
warmup_steps
:
25000
lr_decay
:
1.0
decoding
:
batch_size
:
64
...
...
examples/librispeech/s2/local/train.sh
浏览文件 @
ab23eb57
...
...
@@ -20,6 +20,7 @@ echo "using ${device}..."
mkdir
-p
exp
python3
-u
${
BIN_DIR
}
/train.py
\
--model-name
u2_kaldi
\
--device
${
device
}
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
...
...
examples/librispeech/s2/path.sh
浏览文件 @
ab23eb57
...
...
@@ -10,5 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export
LD_LIBRARY_PATH
=
${
LD_LIBRARY_PATH
}
:/usr/local/lib/
MODEL
=
u2
MODEL
=
u2
_kaldi
export
BIN_DIR
=
${
MAIN_ROOT
}
/deepspeech/exps/
${
MODEL
}
/bin
speechnn/core/transformers/README.md
浏览文件 @
ab23eb57
...
...
@@ -7,4 +7,3 @@
*
https://github.com/NVIDIA/FasterTransformer.git
*
https://github.com/idiap/fast-transformers
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录