PaddlePaddle / DeepSpeech — commit b5339633
Commit message: fix decoding
Author: Hui Zhang
Authored on: Apr 23, 2021
Parent: 8c5b8e35

Showing 12 changed files with 171 additions and 143 deletions (+171 −143)
deepspeech/__init__.py                                 +1   -5
deepspeech/exps/u2/model.py                            +10  -8
deepspeech/frontend/featurizer/audio_featurizer.py     +1   -1
deepspeech/frontend/featurizer/speech_featurizer.py    +1   -1
deepspeech/io/dataset.py                               +3   -4
deepspeech/models/u2.py                                +11  -8
deepspeech/utils/tensor_utils.py                       +3   -1
examples/aishell/s0/local/data.sh                      +65  -53
examples/aishell/s1/local/train.sh                     +1   -1
examples/tiny/s0/local/data.sh                         +70  -55
examples/tiny/s1/local/test.sh                         +4   -5
examples/tiny/s1/local/train.sh                        +1   -1

deepspeech/__init__.py
@@ -123,11 +123,7 @@ if not hasattr(paddle, 'cat'):
 ########### hcak paddle.Tensor #############
 def item(x: paddle.Tensor):
-    if x.dtype == paddle.fluid.core_avx.VarDesc.VarType.FP32:
-        return float(x)
-    else:
-        raise ValueError("not support")
+    return x.numpy().item()

 if not hasattr(paddle.Tensor, 'item'):
     logger.warn(
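
Note: `item` is one of several helpers this module attaches to `paddle.Tensor` when the running Paddle build lacks them (the hunk is truncated at `logger.warn(`). A minimal sketch of that guard-and-patch pattern follows; the `setattr` call is an assumption for illustration, not a quote of the elided lines.

    import paddle

    def item(x: paddle.Tensor):
        # Convert a 0-dim tensor to a plain Python scalar, regardless of dtype.
        return x.numpy().item()

    # Patch only when Tensor.item is missing, so Paddle releases that already
    # ship it keep their native implementation (hypothetical guard, not quoted).
    if not hasattr(paddle.Tensor, 'item'):
        setattr(paddle.Tensor, 'item', item)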

deepspeech/exps/u2/model.py
@@ -381,8 +381,8 @@ class U2Tester(U2Trainer):
            decoding_chunk_size=cfg.decoding_chunk_size,
            num_decoding_left_chunks=cfg.num_decoding_left_chunks,
            simulate_streaming=cfg.simulate_streaming)
-        decode_time = time.time()
+        decode_time = time.time() - start_time

        for target, result in zip(target_transcripts, result_transcripts):
            errors, len_ref = errors_func(target, result)
            errors_sum += errors
@@ -392,13 +392,13 @@ class U2Tester(U2Trainer):
            fout.write(result + "\n")
            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
                        (target, result))
-            logger.info("Current error rate [%s] = %f" %
+            logger.info("One example error rate [%s] = %f" %
                        (cfg.error_rate_type, error_rate_func(target, result)))

        return dict(
            errors_sum=errors_sum,
            len_refs=len_refs,
            num_ins=num_ins,  # num examples
            error_rate=errors_sum / len_refs,
            error_rate_type=cfg.error_rate_type,
            num_frames=audio_len.sum().numpy().item(),
@@ -411,6 +411,7 @@ class U2Tester(U2Trainer):
        self.model.eval()
        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
+        stride_ms = self.test_loader.dataset.stride_ms
        error_rate_type = None
        errors_sum, len_refs, num_ins = 0.0, 0, 0
        num_frames = 0.0
@@ -424,11 +425,12 @@ class U2Tester(U2Trainer):
            len_refs += metrics['len_refs']
            num_ins += metrics['num_ins']
            error_rate_type = metrics['error_rate_type']
-            logger.info("Error rate [%s] (%d/?) = %f" %
-                        (error_rate_type, num_ins, errors_sum / len_refs))
-        rtf = num_time / (num_frames * self.test_loader.dataset.stride_ms / 1000.0)
+            rtf = num_time / (num_frames * stride_ms)
+            logger.info("RTF: %f, Error rate [%s] (%d/?) = %f" %
+                        (rtf, error_rate_type, num_ins, errors_sum / len_refs))

        # logging
+        rtf = num_time / (num_frames * stride_ms)
        msg = "Test: "
        msg += "epoch: {}, ".format(self.epoch)
        msg += "step: {}, ".format(self.iteration)
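
Note: the RTF (real-time factor) logged above is decoding time divided by the audio duration implied by the frame count and frame stride. A self-contained sketch of that bookkeeping with hypothetical numbers (units must be kept consistent with however `stride_ms` is defined):

    # Real-time factor: processing time / audio duration. Hypothetical values.
    decode_time_s = 12.5      # seconds spent decoding a batch
    num_frames = 30000        # total feature frames in the batch
    stride_ms = 10.0          # frame shift in milliseconds

    audio_duration_s = num_frames * stride_ms / 1000.0  # frames * shift -> seconds
    rtf = decode_time_s / audio_duration_s
    print(f"RTF: {rtf:.3f}")  # below 1.0 means faster than real time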

deepspeech/frontend/featurizer/audio_featurizer.py
@@ -108,7 +108,7 @@ class AudioFeaturizer(object):
    @property
    def stride_ms(self):
        return self._stride_ms

    @property
    def feature_size(self):
        """audio feature size"""

deepspeech/frontend/featurizer/speech_featurizer.py
@@ -148,7 +148,7 @@ class SpeechFeaturizer(object):
        float: time(ms)/frame
        """
        return self._audio_featurizer.stride_ms

    @property
    def text_feature(self):
        """Return the text feature object.

deepspeech/io/dataset.py
@@ -63,7 +63,7 @@ class ManifestDataset(Dataset):
                 specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
                 feat_dim=0,  # 'mfcc', 'fbank'
                 delta_delta=False,  # 'mfcc', 'fbank'
                 dither=1.0,  # feature dither
                 target_sample_rate=16000,  # target sample rate
                 use_dB_normalization=True,
                 target_dB=-20,
@@ -188,8 +188,7 @@ class ManifestDataset(Dataset):
        super().__init__()
        self._stride_ms = stride_ms
        self._target_sample_rate = target_sample_rate
        self._normalizer = FeatureNormalizer(
            mean_std_filepath) if mean_std_filepath else None
        self._augmentation_pipeline = AugmentationPipeline(
@@ -251,7 +250,7 @@ class ManifestDataset(Dataset):
    @property
    def feature_size(self):
        return self._speech_featurizer.feature_size

    @property
    def stride_ms(self):
        return self._speech_featurizer.stride_ms

deepspeech/models/u2.py
@@ -49,10 +49,10 @@ from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add

-logger = Log(__name__).getlog()
-
__all__ = ["U2Model", "U2InferModel"]

+logger = Log(__name__).getlog()
+

class U2BaseModel(nn.Module):
    """CTC-Attention hybrid Encoder-Decoder model"""
@@ -398,14 +398,17 @@ class U2BaseModel(nn.Module):
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        # Let's assume B = batch_size
        # encoder_out: (B, maxlen, encoder_dim)
        # encoder_mask: (B, 1, Tmax)
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks, simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+        # (TODO Hui Zhang): bool no support reduce_sum
+        # encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+        encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
        ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
-        topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
+        topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
        pad_mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
        topk_index = topk_index.masked_fill_(pad_mask, self.eos)  # (B, maxlen)
@@ -573,11 +576,11 @@ class U2BaseModel(nn.Module):
        hyps_lens = hyps_lens + 1  # Add <sos> at begining
        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = paddle.ones(
-            beam_size, 1, encoder_out.size(1), dtype=paddle.bool)
+            (beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
        decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
-        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()
        # Only use decoder score for rescoring
        best_score = -float('inf')
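
Note: the `dim=` to `axis=` edits above follow Paddle's keyword convention: its `topk` and `log_softmax` take `axis`, where the equivalent PyTorch calls take `dim`. A small runnable sketch of the two calls touched here, with illustrative shapes only:

    import paddle

    probs = paddle.rand([2, 5, 10])                # (B, maxlen, vocab_size)

    # Paddle's topk takes `axis` (PyTorch would use `dim` here).
    topk_prob, topk_index = probs.topk(1, axis=2)  # both (B, maxlen, 1)

    # Same convention for log_softmax over the vocabulary dimension.
    log_probs = paddle.nn.functional.log_softmax(probs, axis=-1)
    print(topk_index.shape, log_probs.shape)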

deepspeech/utils/tensor_utils.py
@@ -66,7 +66,9 @@ def pad_sequence(sequences: List[paddle.Tensor],
    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    max_size = sequences[0].size()
-    trailing_dims = max_size[1:]
+    # (TODO Hui Zhang): slice not supprot `end==start`
+    # trailing_dims = max_size[1:]
+    trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
    max_len = max([s.size(0) for s in sequences])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
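
Note: the added guard keeps `pad_sequence` from taking an empty slice of a 1-D shape when it builds the padded output dimensions. A plain-Python sketch of the same shape arithmetic, using tuples instead of Paddle tensors, so it only illustrates the logic:

    # Shapes of the sequences to pad, as plain tuples (illustration only).
    shapes = [(7, 80), (5, 80), (9, 80)]   # e.g. (time, feat_dim) per utterance

    max_size = shapes[0]
    trailing_dims = max_size[1:] if len(max_size) >= 2 else ()  # () for 1-D inputs
    max_len = max(s[0] for s in shapes)

    batch_first = True
    out_dims = ((len(shapes), max_len) + trailing_dims
                if batch_first else (max_len, len(shapes)) + trailing_dims)
    print(out_dims)  # (3, 9, 80): batch, padded length, trailing feature dims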

examples/aishell/s0/local/data.sh
#! /usr/bin/env bash

+stage=-1
+stop_stage=100
+
+source ${MAIN_ROOT}/utils/parse_options.sh

mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}

-# download data, generate manifests
-python3 ${TARGET_DIR}/aishell/aishell.py \
---manifest_prefix="data/manifest" \
---target_dir="${TARGET_DIR}/aishell"
-
-if [ $? -ne 0 ]; then
-    echo "Prepare Aishell failed. Terminated."
-    exit 1
-fi
-
-for dataset in train dev test; do
-    mv data/manifest.${dataset} data/manifest.${dataset}.raw
-done
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    # download data, generate manifests
+    python3 ${TARGET_DIR}/aishell/aishell.py \
+    --manifest_prefix="data/manifest" \
+    --target_dir="${TARGET_DIR}/aishell"
+
+    if [ $? -ne 0 ]; then
+        echo "Prepare Aishell failed. Terminated."
+        exit 1
+    fi
+
+    for dataset in train dev test; do
+        mv data/manifest.${dataset} data/manifest.${dataset}.raw
+    done
+fi

-# build vocabulary
-python3 ${MAIN_ROOT}/utils/build_vocab.py \
---unit_type="char" \
---count_threshold=0 \
---vocab_path="data/vocab.txt" \
---manifest_paths "data/manifest.train.raw"
-
-if [ $? -ne 0 ]; then
-    echo "Build vocabulary failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # build vocabulary
+    python3 ${MAIN_ROOT}/utils/build_vocab.py \
+    --unit_type="char" \
+    --count_threshold=0 \
+    --vocab_path="data/vocab.txt" \
+    --manifest_paths "data/manifest.train.raw"
+
+    if [ $? -ne 0 ]; then
+        echo "Build vocabulary failed. Terminated."
+        exit 1
+    fi
+fi

-# compute mean and stddev for normalizer
-python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
---manifest_path="data/manifest.train.raw" \
---specgram_type="fbank" \
---feat_dim=80 \
---delta_delta=false \
---stride_ms=10.0 \
---window_ms=25.0 \
---sample_rate=16000 \
---num_samples=2000 \
---num_workers=0 \
---output_path="data/mean_std.json"
-
-if [ $? -ne 0 ]; then
-    echo "Compute mean and stddev failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # compute mean and stddev for normalizer
+    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
+    --manifest_path="data/manifest.train.raw" \
+    --specgram_type="fbank" \
+    --feat_dim=80 \
+    --delta_delta=false \
+    --stride_ms=10.0 \
+    --window_ms=25.0 \
+    --sample_rate=16000 \
+    --num_samples=-1 \
+    --num_workers=16 \
+    --output_path="data/mean_std.json"
+
+    if [ $? -ne 0 ]; then
+        echo "Compute mean and stddev failed. Terminated."
+        exit 1
+    fi
+fi

-# format manifest with tokenids, vocab size
-for dataset in train dev test; do
-    python3 ${MAIN_ROOT}/utils/format_data.py \
-    --feat_type "raw" \
-    --cmvn_path "data/mean_std.npz" \
-    --unit_type "char" \
-    --vocab_path="data/vocab.txt" \
-    --manifest_path="data/manifest.${dataset}.raw" \
-    --output_path="data/manifest.${dataset}"
-done
-
-if [ $? -ne 0 ]; then
-    echo "Formt mnaifest failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # format manifest with tokenids, vocab size
+    for dataset in train dev test; do
+        python3 ${MAIN_ROOT}/utils/format_data.py \
+        --feat_type "raw" \
+        --cmvn_path "data/mean_std.json" \
+        --unit_type "char" \
+        --vocab_path="data/vocab.txt" \
+        --manifest_path="data/manifest.${dataset}.raw" \
+        --output_path="data/manifest.${dataset}"
+    done
+
+    if [ $? -ne 0 ]; then
+        echo "Formt mnaifest failed. Terminated."
+        exit 1
+    fi
+fi

echo "Aishell data preparation done."

examples/aishell/s1/local/train.sh
#! /usr/bin/env bash

-ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

python3 -u ${BIN_DIR}/train.py \
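
Note: both the old and the new one-liner count GPUs by splitting `CUDA_VISIBLE_DEVICES` on commas. A small Python sketch of that counting logic with hypothetical values; the removed `python -c` version reports 1 for an empty variable, because splitting an empty string still yields one element, so the sketch filters empty entries instead:

    def count_gpus(cuda_visible_devices: str) -> int:
        # "0,1,3" -> 3; "" -> 0 ("".split(",") gives [""], so drop empty entries).
        return len([d for d in cuda_visible_devices.split(",") if d.strip()])

    print(count_gpus("0,1,3"))  # 3
    print(count_gpus(""))       # 0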

examples/tiny/s0/local/data.sh
#! /usr/bin/env bash

+stage=-1
+stop_stage=100

# bpemode (unigram or bpe)
nbpe=200
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"

+source ${MAIN_ROOT}/utils/parse_options.sh

mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}

-# download data, generate manifests
-python3 ${TARGET_DIR}/librispeech/librispeech.py \
---manifest_prefix="data/manifest" \
---target_dir="${TARGET_DIR}/librispeech" \
---full_download="False"
-
-if [ $? -ne 0 ]; then
-    echo "Prepare LibriSpeech failed. Terminated."
-    exit 1
-fi
-
-head -n 64 data/manifest.dev-clean > data/manifest.tiny.raw
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    # download data, generate manifests
+    python3 ${TARGET_DIR}/librispeech/librispeech.py \
+    --manifest_prefix="data/manifest" \
+    --target_dir="${TARGET_DIR}/librispeech" \
+    --full_download="False"
+
+    if [ $? -ne 0 ]; then
+        echo "Prepare LibriSpeech failed. Terminated."
+        exit 1
+    fi
+
+    head -n 64 data/manifest.dev-clean > data/manifest.tiny.raw
+fi

-# build vocabulary
-python3 ${MAIN_ROOT}/utils/build_vocab.py \
---unit_type "spm" \
---spm_vocab_size=${nbpe} \
---spm_mode ${bpemode} \
---spm_model_prefix ${bpeprefix} \
---vocab_path="data/vocab.txt" \
---manifest_paths="data/manifest.tiny.raw"
-
-if [ $? -ne 0 ]; then
-    echo "Build vocabulary failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # build vocabulary
+    python3 ${MAIN_ROOT}/utils/build_vocab.py \
+    --unit_type "spm" \
+    --spm_vocab_size=${nbpe} \
+    --spm_mode ${bpemode} \
+    --spm_model_prefix ${bpeprefix} \
+    --vocab_path="data/vocab.txt" \
+    --manifest_paths="data/manifest.tiny.raw"
+
+    if [ $? -ne 0 ]; then
+        echo "Build vocabulary failed. Terminated."
+        exit 1
+    fi
+fi

-# compute mean and stddev for normalizer
-python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
---manifest_path="data/manifest.tiny.raw" \
---num_samples=64 \
---specgram_type="fbank" \
---feat_dim=80 \
---delta_delta=false \
---sample_rate=16000 \
---stride_ms=10.0 \
---window_ms=25.0 \
---num_workers=0 \
---output_path="data/mean_std.json"
-
-if [ $? -ne 0 ]; then
-    echo "Compute mean and stddev failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # compute mean and stddev for normalizer
+    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
+    --manifest_path="data/manifest.tiny.raw" \
+    --num_samples=64 \
+    --specgram_type="fbank" \
+    --feat_dim=80 \
+    --delta_delta=false \
+    --sample_rate=16000 \
+    --stride_ms=10.0 \
+    --window_ms=25.0 \
+    --num_workers=2 \
+    --output_path="data/mean_std.json"
+
+    if [ $? -ne 0 ]; then
+        echo "Compute mean and stddev failed. Terminated."
+        exit 1
+    fi
+fi

-# format manifest with tokenids, vocab size
-python3 ${MAIN_ROOT}/utils/format_data.py \
---feat_type "raw" \
---cmvn_path "data/mean_std.npz" \
---unit_type "spm" \
---spm_model_prefix ${bpeprefix} \
---vocab_path="data/vocab.txt" \
---manifest_path="data/manifest.tiny.raw" \
---output_path="data/manifest.tiny"
-
-if [ $? -ne 0 ]; then
-    echo "Formt mnaifest failed. Terminated."
-    exit 1
-fi
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # format manifest with tokenids, vocab size
+    python3 ${MAIN_ROOT}/utils/format_data.py \
+    --feat_type "raw" \
+    --cmvn_path "data/mean_std.json" \
+    --unit_type "spm" \
+    --spm_model_prefix ${bpeprefix} \
+    --vocab_path="data/vocab.txt" \
+    --manifest_path="data/manifest.tiny.raw" \
+    --output_path="data/manifest.tiny"
+
+    if [ $? -ne 0 ]; then
+        echo "Formt mnaifest failed. Terminated."
+        exit 1
+    fi
+fi

echo "LibriSpeech Data preparation done."

examples/tiny/s1/local/test.sh
#! /usr/bin/env bash

-# download language model
-bash local/download_lm_en.sh
-if [ $? -ne 0 ]; then
-    exit 1
-fi
+# bash local/download_lm_en.sh
+# if [ $? -ne 0 ]; then
+#    exit 1
+# fi

CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/test.py \
--device 'gpu' \
--nproc 1 \

examples/tiny/s1/local/train.sh
#! /usr/bin/env bash

-ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

python3 -u ${BIN_DIR}/train.py \