Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
756be8fb
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
756be8fb
编写于
10月 08, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix dist batch sampler set_epcoh call
上级
12f540cd
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
55 addition
and
58 deletion
+55
-58
deepspeech/exps/u2/model.py
deepspeech/exps/u2/model.py
+6
-44
deepspeech/frontend/augmentor/spec_augment.py
deepspeech/frontend/augmentor/spec_augment.py
+5
-4
deepspeech/frontend/featurizer/text_featurizer.py
deepspeech/frontend/featurizer/text_featurizer.py
+3
-2
deepspeech/models/u2/u2.py
deepspeech/models/u2/u2.py
+4
-2
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+9
-4
examples/tiny/s0/conf/augmentation.json
examples/tiny/s0/conf/augmentation.json
+26
-0
examples/tiny/s0/local/data.sh
examples/tiny/s0/local/data.sh
+1
-1
examples/tiny/s1/path.sh
examples/tiny/s1/path.sh
+1
-1
examples/tiny/s1/test.profile
examples/tiny/s1/test.profile
+0
-0
未找到文件。
deepspeech/exps/u2/model.py
浏览文件 @
756be8fb
...
...
@@ -24,6 +24,7 @@ import numpy as np
import
paddle
from
paddle
import
distributed
as
dist
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
from
yacs.config
import
CfgNode
from
deepspeech.io.collator
import
SpeechCollator
...
...
@@ -162,8 +163,10 @@ class U2Trainer(Trainer):
self
.
save
(
tag
=
'init'
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
if
self
.
parallel
:
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
if
hasattr
(
self
.
train_loader
,
"batch_sampler"
):
batch_sampler
=
self
.
train_loader
.
batch_sampler
if
isinstance
(
batch_sampler
,
DistributedBatchSampler
):
batch_sampler
.
set_epoch
(
self
.
epoch
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
...
...
@@ -476,13 +479,6 @@ class U2Tester(U2Trainer):
})
f
.
write
(
data
+
'
\n
'
)
# def run_test(self):
# self.resume_or_scratch()
# try:
# self.test()
# except KeyboardInterrupt:
# sys.exit(-1)
def
load_inferspec
(
self
):
"""infer model and input spec.
...
...
@@ -491,7 +487,7 @@ class U2Tester(U2Trainer):
List[paddle.static.InputSpec]: input spec.
"""
from
deepspeech.models.u2
import
U2InferModel
infer_model
=
U2InferModel
.
from_pretrained
(
self
.
test_loader
.
dataset
,
infer_model
=
U2InferModel
.
from_pretrained
(
self
.
test_loader
,
self
.
config
.
model
.
clone
(),
self
.
args
.
checkpoint_path
)
feat_dim
=
self
.
test_loader
.
dataset
.
feature_size
...
...
@@ -511,37 +507,3 @@ class U2Tester(U2Trainer):
static_model
=
paddle
.
jit
.
to_static
(
infer_model
,
input_spec
=
input_spec
)
logger
.
info
(
f
"Export code:
{
static_model
.
forward
.
code
}
"
)
paddle
.
jit
.
save
(
static_model
,
self
.
args
.
export_path
)
# def run_export(self):
# try:
# self.export()
# except KeyboardInterrupt:
# sys.exit(-1)
# def setup(self):
# """Setup the experiment.
# """
# paddle.set_device(self.args.device)
# self.setup_output_dir()
# self.setup_checkpointer()
# self.setup_dataloader()
# self.setup_model()
# self.iteration = 0
# self.epoch = 0
# def setup_output_dir(self):
# """Create a directory used for output.
# """
# # output dir
# if self.args.output:
# output_dir = Path(self.args.output).expanduser()
# output_dir.mkdir(parents=True, exist_ok=True)
# else:
# output_dir = Path(
# self.args.checkpoint_path).expanduser().parent.parent
# output_dir.mkdir(parents=True, exist_ok=True)
# self.output_dir = output_dir
deepspeech/frontend/augmentor/spec_augment.py
浏览文件 @
756be8fb
...
...
@@ -25,10 +25,10 @@ class SpecAugmentor(AugmentorBase):
SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
https://arxiv.org/abs/1904.08779
SpecAugment on Large Scale Datasets
https://arxiv.org/abs/1912.05533
"""
def
__init__
(
self
,
...
...
@@ -41,7 +41,8 @@ class SpecAugmentor(AugmentorBase):
W
=
40
,
adaptive_number_ratio
=
0
,
adaptive_size_ratio
=
0
,
max_n_time_masks
=
20
):
max_n_time_masks
=
20
,
**
kwargs
):
"""SpecAugment class.
Args:
rng (random.Random): random generator object.
...
...
@@ -121,7 +122,7 @@ class SpecAugmentor(AugmentorBase):
def
time_mask
(
self
):
return
self
.
_time_mask
def
time_warp
(
xs
,
W
=
40
):
def
time_warp
(
self
,
xs
,
W
=
40
):
raise
NotImplementedError
def
mask_freq
(
self
,
xs
,
replace_with_zero
=
False
):
...
...
deepspeech/frontend/featurizer/text_featurizer.py
浏览文件 @
756be8fb
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the text featurizer class."""
import
sentencepiece
as
spm
from
pprint
import
pformat
from
..utility
import
EOS
from
..utility
import
SPACE
...
...
@@ -206,7 +207,7 @@ class TextFeaturizer():
"""Load vocabulary from file."""
vocab_list
=
load_dict
(
vocab_filepath
,
maskctc
)
assert
vocab_list
is
not
None
logger
.
info
(
f
"Vocab:
{
vocab_list
}
"
)
logger
.
info
(
f
"Vocab:
{
pformat
(
vocab_list
)
}
"
)
id2token
=
dict
(
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
...
...
@@ -220,10 +221,10 @@ class TextFeaturizer():
sos_id
=
vocab_list
.
index
(
SOS
)
if
SOS
in
vocab_list
else
-
1
space_id
=
vocab_list
.
index
(
SPACE
)
if
SPACE
in
vocab_list
else
-
1
logger
.
info
(
f
"BLANK id:
{
blank_id
}
"
)
logger
.
info
(
f
"UNK id:
{
unk_id
}
"
)
logger
.
info
(
f
"EOS id:
{
eos_id
}
"
)
logger
.
info
(
f
"SOS id:
{
sos_id
}
"
)
logger
.
info
(
f
"SPACE id:
{
space_id
}
"
)
logger
.
info
(
f
"BLANK id:
{
blank_id
}
"
)
logger
.
info
(
f
"MASKCTC id:
{
maskctc_id
}
"
)
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
deepspeech/models/u2/u2.py
浏览文件 @
756be8fb
...
...
@@ -911,8 +911,10 @@ class U2Model(U2BaseModel):
DeepSpeech2Model: The model built from pretrained result.
"""
with
UpdateConfig
(
config
):
config
.
input_dim
=
dataloader
.
collate_fn
.
feature_size
config
.
output_dim
=
dataloader
.
collate_fn
.
vocab_size
#config.input_dim = dataloader.collate_fn.feature_size
#config.output_dim = dataloader.collate_fn.vocab_size
config
.
input_dim
=
dataloader
.
dataset
.
feature_size
config
.
output_dim
=
dataloader
.
dataset
.
vocab_size
model
=
cls
.
from_config
(
config
)
...
...
deepspeech/training/trainer.py
浏览文件 @
756be8fb
...
...
@@ -17,6 +17,7 @@ from pathlib import Path
import
paddle
from
paddle
import
distributed
as
dist
from
paddle.io
import
DistributedBatchSampler
from
tensorboardX
import
SummaryWriter
from
deepspeech.utils
import
mp_tools
...
...
@@ -179,8 +180,10 @@ class Trainer():
"""Reset the train loader seed and increment `epoch`.
"""
self
.
epoch
+=
1
if
self
.
parallel
:
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
if
hasattr
(
self
.
train_loader
,
"batch_sampler"
):
batch_sampler
=
self
.
train_loader
.
batch_sampler
if
isinstance
(
batch_sampler
,
DistributedBatchSampler
):
batch_sampler
.
set_epoch
(
self
.
epoch
)
def
train
(
self
):
"""The training process control by epoch."""
...
...
@@ -190,8 +193,10 @@ class Trainer():
self
.
save
(
tag
=
'init'
)
self
.
lr_scheduler
.
step
(
self
.
iteration
)
if
self
.
parallel
:
self
.
train_loader
.
batch_sampler
.
set_epoch
(
self
.
epoch
)
if
hasattr
(
self
.
train_loader
,
"batch_sampler"
):
batch_sampler
=
self
.
train_loader
.
batch_sampler
if
isinstance
(
batch_sampler
,
DistributedBatchSampler
):
batch_sampler
.
set_epoch
(
self
.
epoch
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
training
.
n_epoch
:
...
...
examples/tiny/s0/conf/augmentation.json
浏览文件 @
756be8fb
[
{
"type"
:
"speed"
,
"params"
:
{
"min_speed_rate"
:
0.9
,
"max_speed_rate"
:
1.1
,
"num_rates"
:
3
},
"prob"
:
0.0
},
{
"type"
:
"shift"
,
"params"
:
{
...
...
@@ -6,5 +15,22 @@
"max_shift_ms"
:
5
},
"prob"
:
1.0
},
{
"type"
:
"specaug"
,
"params"
:
{
"W"
:
5
,
"warp_mode"
:
"PIL"
,
"F"
:
30
,
"n_freq_masks"
:
2
,
"T"
:
40
,
"n_time_masks"
:
2
,
"p"
:
1.0
,
"adaptive_number_ratio"
:
0
,
"adaptive_size_ratio"
:
0
,
"max_n_time_masks"
:
20
,
"replace_with_zero"
:
true
},
"prob"
:
1.0
}
]
examples/tiny/s0/local/data.sh
浏览文件 @
756be8fb
#!
/usr/bin/env
bash
#!
/bin/
bash
stage
=
-1
stop_stage
=
100
...
...
examples/tiny/s1/path.sh
浏览文件 @
756be8fb
export
MAIN_ROOT
=
${
PWD
}
/../../../
export
MAIN_ROOT
=
`
realpath
${
PWD
}
/../../../
`
export
PATH
=
${
MAIN_ROOT
}
:
${
MAIN_ROOT
}
/utils:
${
PATH
}
export
LC_ALL
=
C
...
...
examples/tiny/s1/test.profile
已删除
100644 → 0
浏览文件 @
12f540cd
文件已删除
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录