Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
756be8fb
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
756be8fb
编写于
10月 08, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix dist batch sampler set_epcoh call
上级
12f540cd
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
55 addition
and
58 deletion
+55
-58
deepspeech/exps/u2/model.py
deepspeech/exps/u2/model.py
+6
-44
deepspeech/frontend/augmentor/spec_augment.py
deepspeech/frontend/augmentor/spec_augment.py
+5
-4
deepspeech/frontend/featurizer/text_featurizer.py
deepspeech/frontend/featurizer/text_featurizer.py
+3
-2
deepspeech/models/u2/u2.py
deepspeech/models/u2/u2.py
+4
-2
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+9
-4
examples/tiny/s0/conf/augmentation.json
examples/tiny/s0/conf/augmentation.json
+26
-0
examples/tiny/s0/local/data.sh
examples/tiny/s0/local/data.sh
+1
-1
examples/tiny/s1/path.sh
examples/tiny/s1/path.sh
+1
-1
examples/tiny/s1/test.profile
examples/tiny/s1/test.profile
+0
-0
未找到文件。
deepspeech/exps/u2/model.py
浏览文件 @
756be8fb
...
@@ -24,6 +24,7 @@ import numpy as np
...
@@ -24,6 +24,7 @@ import numpy as np
import
paddle
import
paddle
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
paddle.io
import
DataLoader
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
deepspeech.io.collator
import
SpeechCollator
from
deepspeech.io.collator
import
SpeechCollator
...
@@ -162,8 +163,10 @@ class U2Trainer(Trainer):
...
@@ -162,8 +163,10 @@ class U2Trainer(Trainer):
self
.
save
(
tag
=
'init'
)
self
.
save
(
tag
=
'init'
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
if
self
.
parallel
:
if
hasattr
(
self
.
train_loader
,
"batch_sampler"
):
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
batch_sampler
=
self
.
train_loader
.
batch_sampler
if
isinstance
(
batch_sampler
,
DistributedBatchSampler
):
batch_sampler
.
set_epoch
(
self
.
epoch
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
...
@@ -476,13 +479,6 @@ class U2Tester(U2Trainer):
...
@@ -476,13 +479,6 @@ class U2Tester(U2Trainer):
})
})
f
.
write
(
data
+
'
\n
'
)
f
.
write
(
data
+
'
\n
'
)
# def run_test(self):
# self.resume_or_scratch()
# try:
# self.test()
# except KeyboardInterrupt:
# sys.exit(-1)
def
load_inferspec
(
self
):
def
load_inferspec
(
self
):
"""infer model and input spec.
"""infer model and input spec.
...
@@ -491,7 +487,7 @@ class U2Tester(U2Trainer):
...
@@ -491,7 +487,7 @@ class U2Tester(U2Trainer):
List[paddle.static.InputSpec]: input spec.
List[paddle.static.InputSpec]: input spec.
"""
"""
from
deepspeech.models.u2
import
U2InferModel
from
deepspeech.models.u2
import
U2InferModel
infer_model
=
U2InferModel
.
from_pretrained
(
self
.
test_loader
.
dataset
,
infer_model
=
U2InferModel
.
from_pretrained
(
self
.
test_loader
,
self
.
config
.
model
.
clone
(),
self
.
config
.
model
.
clone
(),
self
.
args
.
checkpoint_path
)
self
.
args
.
checkpoint_path
)
feat_dim
=
self
.
test_loader
.
dataset
.
feature_size
feat_dim
=
self
.
test_loader
.
dataset
.
feature_size
...
@@ -511,37 +507,3 @@ class U2Tester(U2Trainer):
...
@@ -511,37 +507,3 @@ class U2Tester(U2Trainer):
static_model
=
paddle
.
jit
.
to_static
(
infer_model
,
input_spec
=
input_spec
)
static_model
=
paddle
.
jit
.
to_static
(
infer_model
,
input_spec
=
input_spec
)
logger
.
info
(
f
"Export code:
{
static_model
.
forward
.
code
}
"
)
logger
.
info
(
f
"Export code:
{
static_model
.
forward
.
code
}
"
)
paddle
.
jit
.
save
(
static_model
,
self
.
args
.
export_path
)
paddle
.
jit
.
save
(
static_model
,
self
.
args
.
export_path
)
# def run_export(self):
# try:
# self.export()
# except KeyboardInterrupt:
# sys.exit(-1)
# def setup(self):
# """Setup the experiment.
# """
# paddle.set_device(self.args.device)
# self.setup_output_dir()
# self.setup_checkpointer()
# self.setup_dataloader()
# self.setup_model()
# self.iteration = 0
# self.epoch = 0
# def setup_output_dir(self):
# """Create a directory used for output.
# """
# # output dir
# if self.args.output:
# output_dir = Path(self.args.output).expanduser()
# output_dir.mkdir(parents=True, exist_ok=True)
# else:
# output_dir = Path(
# self.args.checkpoint_path).expanduser().parent.parent
# output_dir.mkdir(parents=True, exist_ok=True)
# self.output_dir = output_dir
deepspeech/frontend/augmentor/spec_augment.py
浏览文件 @
756be8fb
...
@@ -41,7 +41,8 @@ class SpecAugmentor(AugmentorBase):
...
@@ -41,7 +41,8 @@ class SpecAugmentor(AugmentorBase):
W
=
40
,
W
=
40
,
adaptive_number_ratio
=
0
,
adaptive_number_ratio
=
0
,
adaptive_size_ratio
=
0
,
adaptive_size_ratio
=
0
,
max_n_time_masks
=
20
):
max_n_time_masks
=
20
,
**
kwargs
):
"""SpecAugment class.
"""SpecAugment class.
Args:
Args:
rng (random.Random): random generator object.
rng (random.Random): random generator object.
...
@@ -121,7 +122,7 @@ class SpecAugmentor(AugmentorBase):
...
@@ -121,7 +122,7 @@ class SpecAugmentor(AugmentorBase):
def
time_mask
(
self
):
def
time_mask
(
self
):
return
self
.
_time_mask
return
self
.
_time_mask
def
time_warp
(
xs
,
W
=
40
):
def
time_warp
(
self
,
xs
,
W
=
40
):
raise
NotImplementedError
raise
NotImplementedError
def
mask_freq
(
self
,
xs
,
replace_with_zero
=
False
):
def
mask_freq
(
self
,
xs
,
replace_with_zero
=
False
):
...
...
deepspeech/frontend/featurizer/text_featurizer.py
浏览文件 @
756be8fb
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
"""Contains the text featurizer class."""
"""Contains the text featurizer class."""
import
sentencepiece
as
spm
import
sentencepiece
as
spm
from
pprint
import
pformat
from
..utility
import
EOS
from
..utility
import
EOS
from
..utility
import
SPACE
from
..utility
import
SPACE
...
@@ -206,7 +207,7 @@ class TextFeaturizer():
...
@@ -206,7 +207,7 @@ class TextFeaturizer():
"""Load vocabulary from file."""
"""Load vocabulary from file."""
vocab_list
=
load_dict
(
vocab_filepath
,
maskctc
)
vocab_list
=
load_dict
(
vocab_filepath
,
maskctc
)
assert
vocab_list
is
not
None
assert
vocab_list
is
not
None
logger
.
info
(
f
"Vocab:
{
vocab_list
}
"
)
logger
.
info
(
f
"Vocab:
{
pformat
(
vocab_list
)
}
"
)
id2token
=
dict
(
id2token
=
dict
(
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
...
@@ -220,10 +221,10 @@ class TextFeaturizer():
...
@@ -220,10 +221,10 @@ class TextFeaturizer():
sos_id
=
vocab_list
.
index
(
SOS
)
if
SOS
in
vocab_list
else
-
1
sos_id
=
vocab_list
.
index
(
SOS
)
if
SOS
in
vocab_list
else
-
1
space_id
=
vocab_list
.
index
(
SPACE
)
if
SPACE
in
vocab_list
else
-
1
space_id
=
vocab_list
.
index
(
SPACE
)
if
SPACE
in
vocab_list
else
-
1
logger
.
info
(
f
"BLANK id:
{
blank_id
}
"
)
logger
.
info
(
f
"UNK id:
{
unk_id
}
"
)
logger
.
info
(
f
"UNK id:
{
unk_id
}
"
)
logger
.
info
(
f
"EOS id:
{
eos_id
}
"
)
logger
.
info
(
f
"EOS id:
{
eos_id
}
"
)
logger
.
info
(
f
"SOS id:
{
sos_id
}
"
)
logger
.
info
(
f
"SOS id:
{
sos_id
}
"
)
logger
.
info
(
f
"SPACE id:
{
space_id
}
"
)
logger
.
info
(
f
"SPACE id:
{
space_id
}
"
)
logger
.
info
(
f
"BLANK id:
{
blank_id
}
"
)
logger
.
info
(
f
"MASKCTC id:
{
maskctc_id
}
"
)
logger
.
info
(
f
"MASKCTC id:
{
maskctc_id
}
"
)
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
deepspeech/models/u2/u2.py
浏览文件 @
756be8fb
...
@@ -911,8 +911,10 @@ class U2Model(U2BaseModel):
...
@@ -911,8 +911,10 @@ class U2Model(U2BaseModel):
DeepSpeech2Model: The model built from pretrained result.
DeepSpeech2Model: The model built from pretrained result.
"""
"""
with
UpdateConfig
(
config
):
with
UpdateConfig
(
config
):
config
.
input_dim
=
dataloader
.
collate_fn
.
feature_size
#config.input_dim = dataloader.collate_fn.feature_size
config
.
output_dim
=
dataloader
.
collate_fn
.
vocab_size
#config.output_dim = dataloader.collate_fn.vocab_size
config
.
input_dim
=
dataloader
.
dataset
.
feature_size
config
.
output_dim
=
dataloader
.
dataset
.
vocab_size
model
=
cls
.
from_config
(
config
)
model
=
cls
.
from_config
(
config
)
...
...
deepspeech/training/trainer.py
浏览文件 @
756be8fb
...
@@ -17,6 +17,7 @@ from pathlib import Path
...
@@ -17,6 +17,7 @@ from pathlib import Path
import
paddle
import
paddle
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
paddle.io
import
DistributedBatchSampler
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
from
deepspeech.utils
import
mp_tools
from
deepspeech.utils
import
mp_tools
...
@@ -179,8 +180,10 @@ class Trainer():
...
@@ -179,8 +180,10 @@ class Trainer():
"""Reset the train loader seed and increment `epoch`.
"""Reset the train loader seed and increment `epoch`.
"""
"""
self
.
epoch
+=
1
self
.
epoch
+=
1
if
self
.
parallel
:
if
hasattr
(
self
.
train_loader
,
"batch_sampler"
):
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
batch_sampler
=
self
.
train_loader
.
batch_sampler
if
isinstance
(
batch_sampler
,
DistributedBatchSampler
):
batch_sampler
.
set_epoch
(
self
.
epoch
)
def
train
(
self
):
def
train
(
self
):
"""The training process control by epoch."""
"""The training process control by epoch."""
...
@@ -190,8 +193,10 @@ class Trainer():
...
@@ -190,8 +193,10 @@ class Trainer():
self
.
save
(
tag
=
'init'
)
self
.
save
(
tag
=
'init'
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
if
self
.
parallel
:
if
hasattr
(
self
.
train_loader
,
"batch_sampler"
):
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
batch_sampler
=
self
.
train_loader
.
batch_sampler
if
isinstance
(
batch_sampler
,
DistributedBatchSampler
):
batch_sampler
.
set_epoch
(
self
.
epoch
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
...
...
examples/tiny/s0/conf/augmentation.json
浏览文件 @
756be8fb
[
[
{
"type"
:
"speed"
,
"params"
:
{
"min_speed_rate"
:
0.9
,
"max_speed_rate"
:
1.1
,
"num_rates"
:
3
},
"prob"
:
0.0
},
{
{
"type"
:
"shift"
,
"type"
:
"shift"
,
"params"
:
{
"params"
:
{
...
@@ -6,5 +15,22 @@
...
@@ -6,5 +15,22 @@
"max_shift_ms"
:
5
"max_shift_ms"
:
5
},
},
"prob"
:
1.0
"prob"
:
1.0
},
{
"type"
:
"specaug"
,
"params"
:
{
"W"
:
5
,
"warp_mode"
:
"PIL"
,
"F"
:
30
,
"n_freq_masks"
:
2
,
"T"
:
40
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
"prob"
:
1.0
}
}
]
]
examples/tiny/s0/local/data.sh
浏览文件 @
756be8fb
#!
/usr/bin/env
bash
#!
/bin/
bash
stage
=
-1
stage
=
-1
stop_stage
=
100
stop_stage
=
100
...
...
examples/tiny/s1/path.sh
浏览文件 @
756be8fb
export
MAIN_ROOT
=
${
PWD
}
/../../../
export
MAIN_ROOT
=
`
realpath
${
PWD
}
/../../../
`
export
PATH
=
${
MAIN_ROOT
}
:
${
MAIN_ROOT
}
/utils:
${
PATH
}
export
PATH
=
${
MAIN_ROOT
}
:
${
MAIN_ROOT
}
/utils:
${
PATH
}
export
LC_ALL
=
C
export
LC_ALL
=
C
...
...
examples/tiny/s1/test.profile
已删除
100644 → 0
浏览文件 @
12f540cd
文件已删除
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录