PaddlePaddle / DeepSpeech
Commit b355b67f
Authored April 15, 2021 by Hui Zhang
log valid loss, time dataset process
Parent: 926b1876
Showing 6 changed files with 39 additions and 16 deletions.
deepspeech/exps/deepspeech2/model.py   +2  -1
deepspeech/exps/u2/bin/test.py         +6  -1
deepspeech/exps/u2/model.py            +4  -4
deepspeech/io/dataset.py               +15 -0
deepspeech/models/u2.py                +2  -1
deepspeech/training/trainer.py         +10 -9
deepspeech/exps/deepspeech2/model.py

@@ -46,7 +46,6 @@ class DeepSpeech2Trainer(Trainer):
         super().__init__(config, args)

     def train_batch(self, batch_index, batch_data, msg):
-        self.model.train()
         start = time.time()
         loss = self.model(*batch_data)

@@ -100,6 +99,8 @@ class DeepSpeech2Trainer(Trainer):
                 self.visualizer.add_scalar("valid/{}".format(k), v,
                                            self.iteration)
+        return valid_losses

     def setup_model(self):
         config = self.config
         model = DeepSpeech2Model(
deepspeech/exps/u2/bin/test.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Evaluation for U2 model."""
+import os
+import cProfile

 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.utility import print_arguments

@@ -48,4 +50,7 @@ if __name__ == "__main__":
     with open(args.dump_config, 'w') as f:
         print(config, file=f)

-    main(config, args)
+    # Setting for profiling
+    pr = cProfile.Profile()
+    pr.runcall(main, config, args)
+    pr.dump_stats(os.path.join('.', 'test.profile'))
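The profile dumped to ./test.profile can be examined offline. A minimal sketch using only Python's standard pstats module (the file name matches the diff above; the sort key and entry count are arbitrary choices, not part of the commit):

import pstats

# Load the profile written by pr.dump_stats(os.path.join('.', 'test.profile'))
stats = pstats.Stats('test.profile')
# Show the 20 entries with the largest cumulative time
stats.sort_stats('cumulative').print_stats(20)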
deepspeech/exps/u2/model.py

@@ -77,8 +77,6 @@ class U2Trainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
-        self.model.train()
         start = time.time()
         loss, attention_loss, ctc_loss = self.model(*batch_data)

@@ -134,6 +132,7 @@ class U2Trainer(Trainer):
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
+            self.model.train()
             try:
                 data_start_time = time.time()
                 for batch_index, batch in enumerate(self.train_loader):

@@ -149,8 +148,8 @@ class U2Trainer(Trainer):
                     self.logger.error(e)
                     raise e

-            self.valid()
-            self.save()
+            valid_losses = self.valid()
+            self.save(infos=valid_losses)
             self.new_epoch()

     @mp_tools.rank_zero_only

@@ -182,6 +181,7 @@ class U2Trainer(Trainer):
             for k, v in valid_losses.items():
                 self.visualizer.add_scalar("valid/{}".format(k), v,
                                            self.iteration)
+        return valid_losses

     def setup_dataloader(self):
         config = self.config.clone()
deepspeech/io/dataset.py

@@ -290,19 +290,34 @@ class ManifestDataset(Dataset):
                  where transcription part could be token ids or text.
         :rtype: tuple of (2darray, list)
         """
+        start_time = time.time()
         if isinstance(audio_file, str) and audio_file.startswith('tar:'):
             speech_segment = SpeechSegment.from_file(
                 self._subfile_from_tar(audio_file), transcript)
         else:
             speech_segment = SpeechSegment.from_file(audio_file, transcript)
+        load_wav_time = time.time() - start_time
+        logger.debug(f"load wav time: {load_wav_time}")

         # audio augment
+        start_time = time.time()
         self._augmentation_pipeline.transform_audio(speech_segment)
+        audio_aug_time = time.time() - start_time
+        logger.debug(f"audio augmentation time: {audio_aug_time}")

+        start_time = time.time()
         specgram, transcript_part = self._speech_featurizer.featurize(
             speech_segment, self._keep_transcription_text)
         if self._normalizer:
             specgram = self._normalizer.apply(specgram)
+        feature_time = time.time() - start_time
+        logger.debug(f"audio & test feature time: {feature_time}")

         # specgram augment
+        start_time = time.time()
         specgram = self._augmentation_pipeline.transform_feature(specgram)
+        feature_aug_time = time.time() - start_time
+        logger.debug(f"audio feature augmentation time: {feature_aug_time}")
         return specgram, transcript_part

     def _instance_reader_creator(self, manifest):
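All of the new timing lines are emitted with logger.debug, so they only show up when the dataset module's logger is configured for DEBUG output. A minimal sketch, assuming the logger follows Python's standard logging interface and is named after the module path (both are assumptions, not shown in the diff):

import logging

# Route DEBUG records to the console so the new timing messages are visible.
logging.basicConfig(level=logging.DEBUG)
# Hypothetical logger name; adjust to however deepspeech/io/dataset.py names its logger.
logging.getLogger("deepspeech.io.dataset").setLevel(logging.DEBUG)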
deepspeech/models/u2.py

@@ -821,7 +821,8 @@ class U2Model(U2BaseModel):
         mean, istd = load_cmvn(configs['cmvn_file'],
                                configs['cmvn_file_type'])
         global_cmvn = GlobalCMVN(
-            paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
+            paddle.to_tensor(mean, dtype=paddle.float),
+            paddle.to_tensor(istd, dtype=paddle.float))
     else:
         global_cmvn = None
deepspeech/training/trainer.py

@@ -128,15 +128,15 @@ class Trainer():
         dist.init_parallel_env()

     @mp_tools.rank_zero_only
-    def save(self, tag=None, infos=None):
+    def save(self, tag=None, infos: dict=None):
         """Save checkpoint (model parameters and optimizer states).
         """
-        if infos is None:
-            infos = {
-                "step": self.iteration,
-                "epoch": self.epoch,
-                "lr": self.optimizer.get_lr(),
-            }
+        infos = infos if infos else dict()
+        infos.update({
+            "step": self.iteration,
+            "epoch": self.epoch,
+            "lr": self.optimizer.get_lr()
+        })
         checkpoint.save_parameters(self.checkpoint_dir, self.iteration
                                    if tag is None else tag, self.model,
                                    self.optimizer, infos)

@@ -185,6 +185,7 @@ class Trainer():
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
+            self.model.train()
             try:
                 data_start_time = time.time()
                 for batch_index, batch in enumerate(self.train_loader):

@@ -200,8 +201,8 @@ class Trainer():
                     self.logger.error(e)
                     raise e

-            self.valid()
-            self.save()
+            valid_losses = self.valid()
+            self.save(infos=valid_losses)
             self.lr_scheduler.step()
             self.new_epoch()
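Together with the changes above, the checkpoint metadata now carries the validation losses: valid() returns its loss dict, the training loop passes it to save(infos=valid_losses), and save() merges it with the bookkeeping fields. A standalone sketch of that merge as implied by the new save() body (the key name and values below are hypothetical, only to show the resulting dict):

# Hypothetical loss dict as valid() might return it; real keys come from the model.
valid_losses = {"val_loss": 12.34}

# Mirrors the new save() body: keep caller-supplied infos, then add step/epoch/lr.
infos = valid_losses if valid_losses else dict()
infos.update({
    "step": 1000,   # self.iteration in the trainer
    "epoch": 3,     # self.epoch
    "lr": 0.001,    # self.optimizer.get_lr()
})
print(infos)  # {'val_loss': 12.34, 'step': 1000, 'epoch': 3, 'lr': 0.001}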