Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
a58b1cb3
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a58b1cb3
编写于
6月 08, 2021
作者:
H
Haoxin Ma
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add result output
上级
f3c9f32c
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
43 addition
and
29 deletion
+43
-29
deepspeech/exps/deepspeech2/model.py
deepspeech/exps/deepspeech2/model.py
+14
-12
deepspeech/exps/u2/model.py
deepspeech/exps/u2/model.py
+11
-6
deepspeech/io/collator.py
deepspeech/io/collator.py
+4
-1
deepspeech/models/u2.py
deepspeech/models/u2.py
+1
-0
deepspeech/modules/conv.py
deepspeech/modules/conv.py
+2
-1
examples/tiny/s0/run.sh
examples/tiny/s0/run.sh
+1
-1
examples/tiny/s1/conf/transformer.yaml
examples/tiny/s1/conf/transformer.yaml
+4
-4
examples/tiny/s1/run.sh
examples/tiny/s1/run.sh
+6
-4
未找到文件。
deepspeech/exps/deepspeech2/model.py
浏览文件 @
a58b1cb3
...
...
@@ -193,7 +193,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
trans
.
append
(
''
.
join
([
chr
(
i
)
for
i
in
ids
]))
return
trans
def
compute_metrics
(
self
,
audio
,
audio_len
,
texts
,
texts_len
):
def
compute_metrics
(
self
,
utts
,
audio
,
audio_len
,
texts
,
texts_len
,
fout
=
None
):
cfg
=
self
.
config
.
decoding
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_func
=
error_rate
.
char_errors
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
word_errors
...
...
@@ -215,11 +215,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cutoff_top_n
=
cfg
.
cutoff_top_n
,
num_processes
=
cfg
.
num_proc_bsearch
)
for
target
,
result
in
zip
(
target_transcripts
,
result_transcripts
):
for
utt
,
target
,
result
in
zip
(
utts
,
target_transcripts
,
result_transcripts
):
errors
,
len_ref
=
errors_func
(
target
,
result
)
errors_sum
+=
errors
len_refs
+=
len_ref
num_ins
+=
1
if
fout
:
fout
.
write
(
utt
+
" "
+
result
+
"
\n
"
)
logger
.
info
(
"
\n
Target Transcription: %s
\n
Output Transcription: %s"
%
(
target
,
result
))
logger
.
info
(
"Current error rate [%s] = %f"
%
...
...
@@ -240,16 +242,16 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cfg
=
self
.
config
error_rate_type
=
None
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
for
i
,
batch
in
enumerate
(
self
.
test_loader
):
utt
,
audio
,
audio_len
,
texts
,
texts_len
=
batch
metrics
=
self
.
compute_metrics
(
audio
,
audio_len
,
texts
,
texts_len
)
errors_sum
+=
metrics
[
'errors_sum'
]
len_refs
+=
metrics
[
'len_refs'
]
num_ins
+=
metrics
[
'num_ins'
]
error_rate_type
=
metrics
[
'error_rate_type'
]
logger
.
info
(
"Error rate [%s] (%d/?) = %f"
%
(
error_rate_type
,
num_ins
,
errors_sum
/
len_refs
))
with
open
(
self
.
args
.
result_file
,
'w'
)
as
fout
:
for
i
,
batch
in
enumerate
(
self
.
test_loader
):
utts
,
audio
,
audio_len
,
texts
,
texts_len
=
batch
metrics
=
self
.
compute_metrics
(
utts
,
audio
,
audio_len
,
texts
,
texts_len
,
fout
)
errors_sum
+=
metrics
[
'errors_sum'
]
len_refs
+=
metrics
[
'len_refs'
]
num_ins
+=
metrics
[
'num_ins'
]
error_rate_type
=
metrics
[
'error_rate_type'
]
logger
.
info
(
"Error rate [%s] (%d/?) = %f"
%
(
error_rate_type
,
num_ins
,
errors_sum
/
len_refs
))
# logging
msg
=
"Test: "
...
...
deepspeech/exps/u2/model.py
浏览文件 @
a58b1cb3
...
...
@@ -76,8 +76,9 @@ class U2Trainer(Trainer):
def
train_batch
(
self
,
batch_index
,
batch_data
,
msg
):
train_conf
=
self
.
config
.
training
start
=
time
.
time
()
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch_data
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
*
batch_data
)
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
text_len
)
# loss div by `batch_size * accum_grad`
loss
/=
train_conf
.
accum_grad
loss
.
backward
()
...
...
@@ -119,9 +120,10 @@ class U2Trainer(Trainer):
num_seen_utts
=
1
total_loss
=
0.0
for
i
,
batch
in
enumerate
(
self
.
valid_loader
):
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
*
batch
)
utt
,
audio
,
audio_len
,
text
,
text_len
=
batch
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
audio
,
audio_len
,
text
,
text_len
)
if
paddle
.
isfinite
(
loss
):
num_utts
=
batch
[
0
].
shape
[
0
]
num_utts
=
batch
[
1
].
shape
[
0
]
num_seen_utts
+=
num_utts
total_loss
+=
float
(
loss
)
*
num_utts
valid_losses
[
'val_loss'
].
append
(
float
(
loss
))
...
...
@@ -366,7 +368,7 @@ class U2Tester(U2Trainer):
trans
.
append
(
''
.
join
([
chr
(
i
)
for
i
in
ids
]))
return
trans
def
compute_metrics
(
self
,
audio
,
audio_len
,
texts
,
texts_len
,
fout
=
None
):
def
compute_metrics
(
self
,
utts
,
audio
,
audio_len
,
texts
,
texts_len
,
fout
=
None
,
fref
=
None
):
cfg
=
self
.
config
.
decoding
errors_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
errors_func
=
error_rate
.
char_errors
if
cfg
.
error_rate_type
==
'cer'
else
error_rate
.
word_errors
...
...
@@ -393,13 +395,15 @@ class U2Tester(U2Trainer):
simulate_streaming
=
cfg
.
simulate_streaming
)
decode_time
=
time
.
time
()
-
start_time
for
target
,
result
in
zip
(
target_transcripts
,
result_transcripts
):
for
utt
,
target
,
result
in
zip
(
utts
,
target_transcripts
,
result_transcripts
):
errors
,
len_ref
=
errors_func
(
target
,
result
)
errors_sum
+=
errors
len_refs
+=
len_ref
num_ins
+=
1
if
fout
:
fout
.
write
(
result
+
"
\n
"
)
fout
.
write
(
utt
+
" "
+
result
+
"
\n
"
)
if
fref
:
fref
.
write
(
utt
+
" "
+
target
+
"
\n
"
)
logger
.
info
(
"
\n
Target Transcription: %s
\n
Output Transcription: %s"
%
(
target
,
result
))
logger
.
info
(
"One example error rate [%s] = %f"
%
...
...
@@ -428,6 +432,7 @@ class U2Tester(U2Trainer):
num_time
=
0.0
with
open
(
self
.
args
.
result_file
,
'w'
)
as
fout
:
for
i
,
batch
in
enumerate
(
self
.
test_loader
):
# utt, audio, audio_len, text, text_len = batch
metrics
=
self
.
compute_metrics
(
*
batch
,
fout
=
fout
)
num_frames
+=
metrics
[
'num_frames'
]
num_time
+=
metrics
[
"decode_time"
]
...
...
deepspeech/io/collator.py
浏览文件 @
a58b1cb3
...
...
@@ -51,7 +51,10 @@ class SpeechCollator():
audio_lens
=
[]
texts
=
[]
text_lens
=
[]
utts
=
[]
for
utt
,
audio
,
text
in
batch
:
#utt
utts
.
append
(
utt
)
# audio
audios
.
append
(
audio
.
T
)
# [T, D]
audio_lens
.
append
(
audio
.
shape
[
1
])
...
...
@@ -75,4 +78,4 @@ class SpeechCollator():
padded_texts
=
pad_sequence
(
texts
,
padding_value
=
IGNORE_ID
).
astype
(
np
.
int64
)
text_lens
=
np
.
array
(
text_lens
).
astype
(
np
.
int64
)
return
utt
,
padded_audios
,
audio_lens
,
padded_texts
,
text_lens
return
utt
s
,
padded_audios
,
audio_lens
,
padded_texts
,
text_lens
deepspeech/models/u2.py
浏览文件 @
a58b1cb3
...
...
@@ -905,6 +905,7 @@ class U2InferModel(U2Model):
def
__init__
(
self
,
configs
:
dict
):
super
().
__init__
(
configs
)
def
forward
(
self
,
feats
,
feats_lengths
,
...
...
deepspeech/modules/conv.py
浏览文件 @
a58b1cb3
...
...
@@ -114,7 +114,8 @@ class ConvBn(nn.Layer):
masks
=
make_non_pad_mask
(
x_len
)
#[B, T]
masks
=
masks
.
unsqueeze
(
1
).
unsqueeze
(
1
)
# [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply
masks
=
masks
.
type_as
(
x
)
# masks = masks.type_as(x)
masks
=
masks
.
astype
(
x
)
x
=
x
.
multiply
(
masks
)
return
x
,
x_len
...
...
examples/tiny/s0/run.sh
浏览文件 @
a58b1cb3
...
...
@@ -26,7 +26,7 @@ fi
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# avg n best model
./local/
avg.sh exp/
${
ckpt
}
/checkpoints
${
avg_num
}
avg.sh exp/
${
ckpt
}
/checkpoints
${
avg_num
}
fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
...
...
examples/tiny/s1/conf/transformer.yaml
浏览文件 @
a58b1cb3
...
...
@@ -8,7 +8,7 @@ data:
spm_model_prefix
:
'
data/bpe_unigram_200'
mean_std_filepath
:
"
"
augmentation_config
:
conf/augmentation.json
batch_size
:
4
batch_size
:
2
#
4
min_input_len
:
0.5
# second
max_input_len
:
20.0
# second
min_output_len
:
0.0
# tokens
...
...
@@ -31,7 +31,7 @@ data:
keep_transcription_text
:
False
sortagrad
:
True
shuffle_method
:
batch_shuffle
num_workers
:
2
num_workers
:
0
#
2
# network architecture
...
...
@@ -70,7 +70,7 @@ model:
training
:
n_epoch
:
2
0
n_epoch
:
2
accum_grad
:
1
global_grad_clip
:
5.0
optim
:
adam
...
...
@@ -85,7 +85,7 @@ training:
decoding
:
batch_size
:
64
batch_size
:
8
#
64
error_rate_type
:
wer
decoding_method
:
attention
# 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path
:
data/lm/common_crawl_00.prune01111.trie.klm
...
...
examples/tiny/s1/run.sh
浏览文件 @
a58b1cb3
...
...
@@ -20,20 +20,22 @@ fi
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES
=
4,5,6,7
./local/train.sh
${
conf_path
}
${
ckpt
}
./local/train.sh
${
conf_path
}
${
ckpt
}
fi
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
# avg n best model
./local/
avg.sh exp/
${
ckpt
}
/checkpoints
${
avg_num
}
avg.sh exp/
${
ckpt
}
/checkpoints
${
avg_num
}
fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES
=
7 ./local/test.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
# CUDA_VISIBLE_DEVICES=7
./local/test.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
||
exit
-1
fi
if
[
${
stage
}
-le
4
]
&&
[
${
stop_stage
}
-ge
4
]
;
then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES
=
./local/export.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
.jit
# CUDA_VISIBLE_DEVICES=
./local/export.sh
${
conf_path
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
exp/
${
ckpt
}
/checkpoints/
${
avg_ckpt
}
.jit
fi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录