Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
b31a1f46
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b31a1f46
编写于
3月 31, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor ctc, add ctc align, refactor ckpt, add warmup lr scheduler, cmvn utils
上级
fcd91c62
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
410 addition
and
71 deletion
+410
-71
.notebook/jit_infer.ipynb
.notebook/jit_infer.ipynb
+2
-2
deepspeech/exps/deepspeech2/bin/tune.py
deepspeech/exps/deepspeech2/bin/tune.py
+1
-1
deepspeech/models/deepspeech2.py
deepspeech/models/deepspeech2.py
+5
-3
deepspeech/modules/ctc.py
deepspeech/modules/ctc.py
+24
-15
deepspeech/training/scheduler.py
deepspeech/training/scheduler.py
+56
-0
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+9
-3
deepspeech/utils/checkpoint.py
deepspeech/utils/checkpoint.py
+35
-24
deepspeech/utils/cmvn.py
deepspeech/utils/cmvn.py
+93
-0
deepspeech/utils/ctc_utils.py
deepspeech/utils/ctc_utils.py
+128
-0
deepspeech/utils/tensor_utils.py
deepspeech/utils/tensor_utils.py
+18
-21
deepspeech/utils/utility.py
deepspeech/utils/utility.py
+39
-2
未找到文件。
.notebook/jit_infer.ipynb
浏览文件 @
b31a1f46
...
@@ -509,7 +509,7 @@
...
@@ -509,7 +509,7 @@
" print(audio_len.shape)\n",
" print(audio_len.shape)\n",
" \n",
" \n",
" #eouts, eouts_len = model.encoder(audio, audio_len)\n",
" #eouts, eouts_len = model.encoder(audio, audio_len)\n",
" #probs = model.decoder.
probs
(eouts)\n",
" #probs = model.decoder.
softmax
(eouts)\n",
" probs = model.forward(audio, audio_len)\n",
" probs = model.forward(audio, audio_len)\n",
" print('paddle:', probs.numpy())\n",
" print('paddle:', probs.numpy())\n",
" \n",
" \n",
...
@@ -666,4 +666,4 @@
...
@@ -666,4 +666,4 @@
},
},
"nbformat": 4,
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 2
}
}
\ No newline at end of file
deepspeech/exps/deepspeech2/bin/tune.py
浏览文件 @
b31a1f46
...
@@ -109,7 +109,7 @@ def tune(config, args):
...
@@ -109,7 +109,7 @@ def tune(config, args):
# model infer
# model infer
eouts
,
eouts_len
=
model
.
encoder
(
audio
,
audio_len
)
eouts
,
eouts_len
=
model
.
encoder
(
audio
,
audio_len
)
probs
=
model
.
decoder
.
probs
(
eouts
)
probs
=
model
.
decoder
.
softmax
(
eouts
)
# grid search
# grid search
for
index
,
(
alpha
,
beta
)
in
enumerate
(
params_grid
):
for
index
,
(
alpha
,
beta
)
in
enumerate
(
params_grid
):
...
...
deepspeech/models/deepspeech2.py
浏览文件 @
b31a1f46
...
@@ -203,7 +203,7 @@ class DeepSpeech2Model(nn.Layer):
...
@@ -203,7 +203,7 @@ class DeepSpeech2Model(nn.Layer):
decoding_method
=
decoding_method
)
decoding_method
=
decoding_method
)
eouts
,
eouts_len
=
self
.
encoder
(
audio
,
audio_len
)
eouts
,
eouts_len
=
self
.
encoder
(
audio
,
audio_len
)
probs
=
self
.
decoder
.
probs
(
eouts
)
probs
=
self
.
decoder
.
softmax
(
eouts
)
return
self
.
decoder
.
decode_probs
(
return
self
.
decoder
.
decode_probs
(
probs
.
numpy
(),
eouts_len
,
vocab_list
,
decoding_method
,
probs
.
numpy
(),
eouts_len
,
vocab_list
,
decoding_method
,
lang_model_path
,
beam_alpha
,
beam_beta
,
beam_size
,
cutoff_prob
,
lang_model_path
,
beam_alpha
,
beam_beta
,
beam_size
,
cutoff_prob
,
...
@@ -234,7 +234,9 @@ class DeepSpeech2Model(nn.Layer):
...
@@ -234,7 +234,9 @@ class DeepSpeech2Model(nn.Layer):
rnn_size
=
config
.
model
.
rnn_layer_size
,
rnn_size
=
config
.
model
.
rnn_layer_size
,
use_gru
=
config
.
model
.
use_gru
,
use_gru
=
config
.
model
.
use_gru
,
share_rnn_weights
=
config
.
model
.
share_rnn_weights
)
share_rnn_weights
=
config
.
model
.
share_rnn_weights
)
checkpoint
.
load_parameters
(
model
,
checkpoint_path
=
checkpoint_path
)
infos
=
checkpoint
.
load_parameters
(
model
,
checkpoint_path
=
checkpoint_path
)
logger
.
info
(
f
"checkpoint info:
{
infos
}
"
)
layer_tools
.
summary
(
model
)
layer_tools
.
summary
(
model
)
return
model
return
model
...
@@ -268,5 +270,5 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
...
@@ -268,5 +270,5 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
probs: probs after softmax
probs: probs after softmax
"""
"""
eouts
,
eouts_len
=
self
.
encoder
(
audio
,
audio_len
)
eouts
,
eouts_len
=
self
.
encoder
(
audio
,
audio_len
)
probs
=
self
.
decoder
.
probs
(
eouts
)
probs
=
self
.
decoder
.
softmax
(
eouts
)
return
probs
return
probs
deepspeech/modules/ctc.py
浏览文件 @
b31a1f46
...
@@ -20,10 +20,12 @@ from paddle import nn
...
@@ -20,10 +20,12 @@ from paddle import nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
paddle.nn
import
initializer
as
I
from
deepspeech.modules.loss
import
CTCLoss
from
deepspeech.utils
import
ctc_utils
from
deepspeech.decoders.swig_wrapper
import
Scorer
from
deepspeech.decoders.swig_wrapper
import
Scorer
from
deepspeech.decoders.swig_wrapper
import
ctc_greedy_decoder
from
deepspeech.decoders.swig_wrapper
import
ctc_greedy_decoder
from
deepspeech.decoders.swig_wrapper
import
ctc_beam_search_decoder_batch
from
deepspeech.decoders.swig_wrapper
import
ctc_beam_search_decoder_batch
from
deepspeech.modules.loss
import
CTCLoss
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -67,38 +69,31 @@ class CTCDecoder(nn.Layer):
...
@@ -67,38 +69,31 @@ class CTCDecoder(nn.Layer):
ys_pad (Tenosr): batch of padded character id sequence tensor (B, Lmax)
ys_pad (Tenosr): batch of padded character id sequence tensor (B, Lmax)
ys_lens (Tensor): batch of lengths of character sequence (B)
ys_lens (Tensor): batch of lengths of character sequence (B)
Returns:
Returns:
loss (Tenosr): scalar.
loss (Tenosr):
ctc loss value,
scalar.
"""
"""
logits
=
self
.
ctc_lo
(
F
.
dropout
(
hs_pad
,
p
=
self
.
dropout_rate
))
logits
=
self
.
ctc_lo
(
F
.
dropout
(
hs_pad
,
p
=
self
.
dropout_rate
))
loss
=
self
.
criterion
(
logits
,
ys_pad
,
hlens
,
ys_lens
)
loss
=
self
.
criterion
(
logits
,
ys_pad
,
hlens
,
ys_lens
)
return
loss
return
loss
def
probs
(
self
,
eouts
:
paddle
.
Tensor
,
temperature
:
float
=
1.0
):
def
softmax
(
self
,
eouts
:
paddle
.
Tensor
,
temperature
:
float
=
1.0
):
"""Get CTC probabilities.
"""Get CTC probabilities.
Args:
Args:
eouts (FloatTensor): `[B, T, enc_units]`
eouts (FloatTensor): `[B, T, enc_units]`
Returns:
Returns:
probs (FloatTensor): `[B, T, odim]`
probs (FloatTensor): `[B, T, odim]`
"""
"""
return
F
.
softmax
(
self
.
ctc_lo
(
eouts
)
/
temperature
,
axis
=-
1
)
self
.
probs
=
F
.
softmax
(
self
.
ctc_lo
(
eouts
)
/
temperature
,
axis
=
2
)
return
self
.
probs
def
scores
(
self
,
eouts
:
paddle
.
Tensor
,
temperature
:
float
=
1.0
):
"""Get log-scale CTC probabilities.
Args:
eouts (FloatTensor): `[B, T, enc_units]`
Returns:
log_probs (FloatTensor): `[B, T, odim]`
"""
return
F
.
log_softmax
(
self
.
ctc_lo
(
eouts
)
/
temperature
,
axis
=-
1
)
def
log_softmax
(
self
,
hs_pad
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
def
log_softmax
(
self
,
hs_pad
:
paddle
.
Tensor
,
temperature
:
float
=
1.0
)
->
paddle
.
Tensor
:
"""log_softmax of frame activations
"""log_softmax of frame activations
Args:
Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
Returns:
paddle.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
paddle.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
"""
"""
return
self
.
scores
(
hs_pad
)
return
F
.
log_softmax
(
self
.
ctc_lo
(
hs_pad
)
/
temperature
,
axis
=
2
)
def
argmax
(
self
,
hs_pad
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
def
argmax
(
self
,
hs_pad
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
"""argmax of frame activations
"""argmax of frame activations
...
@@ -109,6 +104,20 @@ class CTCDecoder(nn.Layer):
...
@@ -109,6 +104,20 @@ class CTCDecoder(nn.Layer):
"""
"""
return
paddle
.
argmax
(
self
.
ctc_lo
(
hs_pad
),
dim
=
2
)
return
paddle
.
argmax
(
self
.
ctc_lo
(
hs_pad
),
dim
=
2
)
def
forced_align
(
self
,
ctc_probs
:
paddle
.
Tensor
,
y
:
paddle
.
Tensor
,
blank_id
=
0
)
->
list
:
"""ctc forced alignment.
Args:
ctc_probs (paddle.Tensor): hidden state sequence, 2d tensor (T, D)
y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
blank_id (int): blank symbol index
Returns:
paddle.Tensor: best alignment result, (T).
"""
return
ctc_utils
.
forced_align
(
ctc_probs
,
y
,
blank_id
)
def
_decode_batch_greedy
(
self
,
probs_split
,
vocab_list
):
def
_decode_batch_greedy
(
self
,
probs_split
,
vocab_list
):
"""Decode by best path for a batch of probs matrix input.
"""Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists
:param probs_split: List of 2-D probability matrix, and each consists
...
...
deepspeech/
utils/metric
.py
→
deepspeech/
training/scheduler
.py
浏览文件 @
b31a1f46
...
@@ -12,32 +12,45 @@
...
@@ -12,32 +12,45 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
math
import
logging
import
logging
from
typing
import
Tuple
,
List
import
paddle
import
paddle
from
paddle.optimizer.lr
import
LRScheduler
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"
th_accuracy
"
]
__all__
=
[
"
WarmupLR
"
]
def
th_accuracy
(
pad_outputs
:
paddle
.
Tensor
,
class
WarmupLR
(
LRScheduler
):
pad_targets
:
paddle
.
Tensor
,
"""The WarmupLR scheduler
ignore_label
:
int
)
->
float
:
This scheduler is almost same as NoamLR Scheduler except for following
"""Calculate accuracy.
difference:
Args:
NoamLR:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
lr = optimizer.lr * model_size ** -0.5
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
* min(step ** -0.5, step * warmup_step ** -1.5)
ignore_label (int): Ignore label id.
WarmupLR:
Returns:
lr = optimizer.lr * warmup_step ** 0.5
float: Accuracy value (0.0 - 1.0).
* min(step ** -0.5, step * warmup_step ** -1.5)
Note that the maximum lr equals to optimizer.lr in this scheduler.
"""
"""
pad_pred
=
pad_outputs
.
view
(
pad_targets
.
size
(
0
),
pad_targets
.
size
(
1
),
pad_outputs
.
size
(
1
)).
argmax
(
2
)
def
__init__
(
self
,
mask
=
pad_targets
!=
ignore_label
warmup_steps
:
Union
[
int
,
float
]
=
25000
,
numerator
=
paddle
.
sum
(
learning_rate
=
1.0
,
pad_pred
.
masked_select
(
mask
)
==
pad_targets
.
masked_select
(
mask
))
last_epoch
=-
1
,
denominator
=
paddle
.
sum
(
mask
)
verbose
=
False
):
return
float
(
numerator
)
/
float
(
denominator
)
assert
check_argument_types
()
self
.
warmup_steps
=
warmup_steps
super
().
__init__
(
learning_rate
,
last_epoch
,
verbose
)
def
__repr__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
(warmup_steps=
{
self
.
warmup_steps
}
)"
def
get_lr
(
self
):
step_num
=
self
.
last_epoch
+
1
return
self
.
base_lr
*
self
.
warmup_steps
**
0.5
*
min
(
step_num
**-
0.5
,
step_num
*
self
.
warmup_steps
**-
1.5
)
def
set_step
(
self
,
step
:
int
):
self
.
last_epoch
=
step
deepspeech/training/trainer.py
浏览文件 @
b31a1f46
...
@@ -131,8 +131,13 @@ class Trainer():
...
@@ -131,8 +131,13 @@ class Trainer():
def
save
(
self
):
def
save
(
self
):
"""Save checkpoint (model parameters and optimizer states).
"""Save checkpoint (model parameters and optimizer states).
"""
"""
infos
=
{
"step"
:
self
.
iteration
,
"epoch"
:
self
.
epoch
,
"lr"
:
self
.
optimizer
.
get_lr
(),
}
checkpoint
.
save_parameters
(
self
.
checkpoint_dir
,
self
.
iteration
,
checkpoint
.
save_parameters
(
self
.
checkpoint_dir
,
self
.
iteration
,
self
.
model
,
self
.
optimizer
)
self
.
model
,
self
.
optimizer
,
infos
)
def
resume_or_load
(
self
):
def
resume_or_load
(
self
):
"""Resume from latest checkpoint at checkpoints in the output
"""Resume from latest checkpoint at checkpoints in the output
...
@@ -141,12 +146,13 @@ class Trainer():
...
@@ -141,12 +146,13 @@ class Trainer():
If ``args.checkpoint_path`` is not None, load the checkpoint, else
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
resume training.
"""
"""
i
teration
=
checkpoint
.
load_parameters
(
i
nfos
=
checkpoint
.
load_parameters
(
self
.
model
,
self
.
model
,
self
.
optimizer
,
self
.
optimizer
,
checkpoint_dir
=
self
.
checkpoint_dir
,
checkpoint_dir
=
self
.
checkpoint_dir
,
checkpoint_path
=
self
.
args
.
checkpoint_path
)
checkpoint_path
=
self
.
args
.
checkpoint_path
)
self
.
iteration
=
iteration
self
.
iteration
=
infos
[
"step"
]
self
.
epoch
=
infos
[
"epoch"
]
def
new_epoch
(
self
):
def
new_epoch
(
self
):
"""Reset the train loader and increment ``epoch``.
"""Reset the train loader and increment ``epoch``.
...
...
deepspeech/utils/checkpoint.py
浏览文件 @
b31a1f46
...
@@ -16,6 +16,8 @@ import os
...
@@ -16,6 +16,8 @@ import os
import
time
import
time
import
logging
import
logging
import
numpy
as
np
import
numpy
as
np
import
re
import
json
import
paddle
import
paddle
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
...
@@ -37,15 +39,13 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
...
@@ -37,15 +39,13 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
int: the latest iteration number.
int: the latest iteration number.
"""
"""
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
if
(
not
os
.
path
.
isfile
(
checkpoint_record
)
):
if
not
os
.
path
.
isfile
(
checkpoint_record
):
return
0
return
0
# Fetch the latest checkpoint index.
# Fetch the latest checkpoint index.
with
open
(
checkpoint_record
,
"rt"
)
as
handle
:
with
open
(
checkpoint_record
,
"rt"
)
as
handle
:
latest_checkpoint
=
handle
.
readlines
()[
-
1
].
strip
()
latest_checkpoint
=
handle
.
readlines
()[
-
1
].
strip
()
step
=
latest_checkpoint
.
split
(
":"
)[
-
1
]
iteration
=
int
(
latest_checkpoint
.
split
(
":"
)[
-
1
])
iteration
=
int
(
step
.
split
(
"-"
)[
-
1
])
return
iteration
return
iteration
...
@@ -60,7 +60,7 @@ def _save_checkpoint(checkpoint_dir: str, iteration: int):
...
@@ -60,7 +60,7 @@ def _save_checkpoint(checkpoint_dir: str, iteration: int):
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
checkpoint_record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Update the latest checkpoint index.
# Update the latest checkpoint index.
with
open
(
checkpoint_record
,
"a+"
)
as
handle
:
with
open
(
checkpoint_record
,
"a+"
)
as
handle
:
handle
.
write
(
"model_checkpoint_path:
step-
{}
\n
"
.
format
(
iteration
))
handle
.
write
(
"model_checkpoint_path:{}
\n
"
.
format
(
iteration
))
def
load_parameters
(
model
,
def
load_parameters
(
model
,
...
@@ -74,20 +74,16 @@ def load_parameters(model,
...
@@ -74,20 +74,16 @@ def load_parameters(model,
Defaults to None.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path and the argument 'checkpoint_dir' will
stored in the checkpoint_path
(prefix)
and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
be ignored. Defaults to None.
Returns:
Returns:
iteration (int): number of iterations that the loaded checkpoint has
configs (dict): epoch or step, lr and other meta info should be saved.
been trained.
"""
"""
if
checkpoint_path
is
not
None
:
if
checkpoint_path
is
not
None
:
iteration
=
int
(
os
.
path
.
basename
(
checkpoint_path
).
split
(
"
-
"
)[
-
1
])
iteration
=
int
(
os
.
path
.
basename
(
checkpoint_path
).
split
(
"
:
"
)[
-
1
])
elif
checkpoint_dir
is
not
None
:
elif
checkpoint_dir
is
not
None
:
iteration
=
_load_latest_checkpoint
(
checkpoint_dir
)
iteration
=
_load_latest_checkpoint
(
checkpoint_dir
)
if
iteration
==
0
:
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"-{}"
.
format
(
iteration
))
return
iteration
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"step-{}"
.
format
(
iteration
))
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
...
@@ -98,43 +94,58 @@ def load_parameters(model,
...
@@ -98,43 +94,58 @@ def load_parameters(model,
params_path
=
checkpoint_path
+
".pdparams"
params_path
=
checkpoint_path
+
".pdparams"
model_dict
=
paddle
.
load
(
params_path
)
model_dict
=
paddle
.
load
(
params_path
)
model
.
set_state_dict
(
model_dict
)
model
.
set_state_dict
(
model_dict
)
logger
.
info
(
logger
.
info
(
"Rank {}: loaded model from {}"
.
format
(
rank
,
params_path
))
"[checkpoint] Rank {}: loaded model from {}"
.
format
(
rank
,
params_path
))
optimizer_path
=
checkpoint_path
+
".pdopt"
optimizer_path
=
checkpoint_path
+
".pdopt"
if
optimizer
and
os
.
path
.
isfile
(
optimizer_path
):
if
optimizer
and
os
.
path
.
isfile
(
optimizer_path
):
optimizer_dict
=
paddle
.
load
(
optimizer_path
)
optimizer_dict
=
paddle
.
load
(
optimizer_path
)
optimizer
.
set_state_dict
(
optimizer_dict
)
optimizer
.
set_state_dict
(
optimizer_dict
)
logger
.
info
(
"
[checkpoint] Rank {}: loaded optimizer state from {}"
.
logger
.
info
(
"
Rank {}: loaded optimizer state from {}"
.
format
(
format
(
rank
,
optimizer_path
))
rank
,
optimizer_path
))
return
iteration
info_path
=
re
.
sub
(
'.pdparams$'
,
'.json'
,
params_path
)
configs
=
{}
if
os
.
path
.
exists
(
info_path
):
with
open
(
info_path
,
'r'
)
as
fin
:
configs
=
json
.
load
(
fin
)
return
configs
@
mp_tools
.
rank_zero_only
@
mp_tools
.
rank_zero_only
def
save_parameters
(
checkpoint_dir
,
iteration
,
model
,
optimizer
=
None
):
def
save_parameters
(
checkpoint_dir
:
str
,
iteration
:
int
,
model
:
paddle
.
nn
.
Layer
,
optimizer
:
Optimizer
=
None
,
infos
:
dict
=
None
):
"""Checkpoint the latest trained model parameters.
"""Checkpoint the latest trained model parameters.
Args:
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
iteration (int): the latest iteration
(step or epoch)
number.
model (Layer): model to be checkpointed.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
Defaults to None.
Defaults to None.
infos (dict or None): any info you want to save.
Returns:
Returns:
None
None
"""
"""
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"
step
-{}"
.
format
(
iteration
))
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"-{}"
.
format
(
iteration
))
model_dict
=
model
.
state_dict
()
model_dict
=
model
.
state_dict
()
params_path
=
checkpoint_path
+
".pdparams"
params_path
=
checkpoint_path
+
".pdparams"
paddle
.
save
(
model_dict
,
params_path
)
paddle
.
save
(
model_dict
,
params_path
)
logger
.
info
(
"
[checkpoint]
Saved model to {}"
.
format
(
params_path
))
logger
.
info
(
"Saved model to {}"
.
format
(
params_path
))
if
optimizer
:
if
optimizer
:
opt_dict
=
optimizer
.
state_dict
()
opt_dict
=
optimizer
.
state_dict
()
optimizer_path
=
checkpoint_path
+
".pdopt"
optimizer_path
=
checkpoint_path
+
".pdopt"
paddle
.
save
(
opt_dict
,
optimizer_path
)
paddle
.
save
(
opt_dict
,
optimizer_path
)
logger
.
info
(
logger
.
info
(
"Saved optimzier state to {}"
.
format
(
optimizer_path
))
"[checkpoint] Saved optimzier state to {}"
.
format
(
optimizer_path
))
info_path
=
re
.
sub
(
'.pdparams$'
,
'.json'
,
params_path
)
if
infos
is
None
:
infos
=
{}
with
open
(
info_path
,
'w'
)
as
fout
:
data
=
json
.
dumps
(
infos
)
fout
.
write
(
data
)
_save_checkpoint
(
checkpoint_dir
,
iteration
)
_save_checkpoint
(
checkpoint_dir
,
iteration
)
deepspeech/utils/cmvn.py
0 → 100644
浏览文件 @
b31a1f46
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
math
import
logging
import
numpy
as
np
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'load_cmvn'
]
def _load_json_cmvn(json_cmvn_file):
    """Load the json format cmvn stats file and calculate cmvn.

    The json object is read through three keys: ``mean_stat`` and
    ``var_stat`` (per-dimension accumulators, treated here as the sum and the
    sum of squares of the features) and ``frame_num`` (the number of frames
    they were accumulated over).

    Args:
        json_cmvn_file: path of the cmvn stats file in json format.

    Returns:
        np.ndarray: two rows, ``[means, inverse_std]``. Note that the second
        row is ``1 / sqrt(variance)`` with the variance floored at 1.0e-20,
        i.e. a scale factor, not the variance itself.

    Raises:
        ValueError: if ``mean_stat`` and ``var_stat`` differ in length.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(json_cmvn_file, encoding='utf-8') as f:
        cmvn_stats = json.load(f)

    feat_sums = cmvn_stats['mean_stat']
    feat_square_sums = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']

    if len(feat_sums) != len(feat_square_sums):
        raise ValueError(
            "cmvn stats are inconsistent: mean_stat has {} entries but "
            "var_stat has {}".format(len(feat_sums), len(feat_square_sums)))

    means = []
    inverse_stds = []
    for feat_sum, feat_square_sum in zip(feat_sums, feat_square_sums):
        mean = feat_sum / count
        # Var[x] = E[x^2] - E[x]^2
        variance = feat_square_sum / count - mean * mean
        # Floor tiny/negative variances so the inverse below stays finite.
        if variance < 1.0e-20:
            variance = 1.0e-20
        means.append(mean)
        inverse_stds.append(1.0 / math.sqrt(variance))

    return np.array([means, inverse_stds])
def _load_kaldi_cmvn(kaldi_cmvn_file):
    """Load the kaldi format cmvn stats file and calculate cmvn.

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        np.ndarray: two rows, ``[means, inverse_std]``. Note that the second
        row is ``1 / sqrt(variance)`` with the variance floored at 1.0e-20,
        i.e. a scale factor, not the variance itself.

    Raises:
        SystemExit: (exit code 1) if the file is a kaldi binary file.
        AssertionError: if the token layout is not ``[ ... 0 ]``.
    """
    # `sys` is not among this module's top-level imports, so bring it in here;
    # otherwise the binary-file branch below dies with a NameError.
    import sys

    # Read raw bytes: a binary kaldi file is not valid text, and probing its
    # header through a text-mode handle can fail while decoding before the
    # header comparison is ever reached.
    with open(kaldi_cmvn_file, 'rb') as fid:
        raw = fid.read()

    # kaldi binary file start with '\0B'
    if raw[:2] == b'\0B':
        logger.error('kaldi cmvn binary file is not supported, please '
                     'recompute it by: compute-cmvn-stats --binary=false '
                     ' scp:feats.scp global_cmvn')
        sys.exit(1)

    arr = raw.decode('utf-8').split()
    # Expected tokens: '[' sum_1..sum_D count sqsum_1..sqsum_D '0' ']'
    assert arr[0] == '[', "kaldi cmvn file must start with '['"
    assert arr[-2] == '0', "kaldi cmvn file must end with '0 ]'"
    assert arr[-1] == ']', "kaldi cmvn file must end with ']'"

    # Drop '[' and ']' plus the count and the trailing '0'; half of the
    # remaining tokens are sums, the other half sums of squares.
    feat_dim = int((len(arr) - 2 - 2) / 2)
    feat_sums = [float(tok) for tok in arr[1:feat_dim + 1]]
    count = float(arr[feat_dim + 1])
    feat_square_sums = [
        float(tok) for tok in arr[feat_dim + 2:2 * feat_dim + 2]
    ]

    means = []
    inverse_stds = []
    for feat_sum, feat_square_sum in zip(feat_sums, feat_square_sums):
        mean = feat_sum / count
        # Var[x] = E[x^2] - E[x]^2
        variance = feat_square_sum / count - mean * mean
        # Floor tiny/negative variances so the inverse below stays finite.
        if variance < 1.0e-20:
            variance = 1.0e-20
        means.append(mean)
        inverse_stds.append(1.0 / math.sqrt(variance))

    return np.array([means, inverse_stds])
def load_cmvn(cmvn_file, is_json):
    """Load cmvn statistics from a json or a kaldi text file.

    Args:
        cmvn_file: path of the cmvn stats file.
        is_json: truthy to parse ``cmvn_file`` as json, otherwise it is
            parsed as a kaldi text style cmvn file.

    Returns:
        tuple: row 0 and row 1 of the 2-row array built by the selected
        loader (the means and the ``1 / sqrt(variance)`` scale factors).
    """
    loader = _load_json_cmvn if is_json else _load_kaldi_cmvn
    stats = loader(cmvn_file)
    return stats[0], stats[1]
deepspeech/utils/ctc_utils.py
0 → 100644
浏览文件 @
b31a1f46
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
import
numpy
as
np
from
typing
import
List
import
paddle
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"forced_align"
,
"remove_duplicates_and_blank"
,
"insert_blank"
]
def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
    """ctc alignment to ctc label ids.

    "abaa-acee-" -> "abaace"

    Args:
        hyp (List[int]): hypotheses ids, (L)
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        List[int]: remove duplicate ids, then remove blank id.
    """
    from itertools import groupby

    # groupby collapses each run of consecutive equal ids into one key; blanks
    # are dropped only afterwards, so "a-a" still yields two separate labels.
    return [token for token, _ in groupby(hyp) if token != blank_id]
def insert_blank(label: np.ndarray, blank_id: int=0):
    """Insert blank token between every two label token.

    "abcdefg" -> "-a-b-c-d-e-f-g-"

    Args:
        label ([np.ndarray]): label ids, (L).
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        [np.ndarray]: (2L+1). An empty label gives a single blank.
    """
    label = np.expand_dims(label, 1)  #[L, 1]
    blanks = np.full((label.shape[0], 1), blank_id, dtype=np.int64)  #[L, 1]
    # Pair every label with a leading blank, then flatten row by row.
    interleaved = np.concatenate([blanks, label], axis=1).reshape(-1)  #[2L]
    # Close the sequence with one more blank. Append blank_id itself rather
    # than interleaved[0], which does not exist when the label is empty.
    closing_blank = np.asarray(blank_id, dtype=interleaved.dtype)
    return np.append(interleaved, closing_blank)  #[2L + 1]
def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
                 blank_id=0) -> list:
    """ctc forced alignment.

    Searches the best single path through the blank-augmented label sequence
    (max over predecessors, then backtracking), see
    https://distill.pub/2017/ctc/

    Args:
        ctc_probs (paddle.Tensor): per-frame scores, 2d tensor (T, D). Scores
            are summed along a path, so they are presumably log-probabilities
            (NOTE(review): confirm against the callers).
        y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
        blank_id (int): blank symbol index

    Returns:
        list: best alignment result, one id (a label or blank_id) per frame,
            length T.
    """
    num_frames = ctc_probs.size(0)
    if len(y) == 0:
        # Nothing to align: the only valid path emits blank on every frame.
        return [blank_id] * num_frames

    y_insert_blank = insert_blank(y, blank_id)
    num_states = len(y_insert_blank)

    log_alpha = paddle.zeros((num_frames, num_states))  #(T, 2L+1)
    log_alpha = log_alpha - float('inf')  # log of zero
    state_path = (paddle.zeros(
        (num_frames, num_states), dtype=paddle.int16) - 1)  # state path

    # init start state
    log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]  # Sb
    log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]  # Snb

    for t in range(1, num_frames):
        for s in range(num_states):
            if s == 0:
                # The first state can only be reached by staying in it.
                # Looking at s - 1 here would index column -1, i.e. wrap
                # around to the final state and allow an illegal jump from
                # the end of the label sequence back to its start.
                prev_state = [s]
            elif (y_insert_blank[s] == blank_id or s < 2 or
                  y_insert_blank[s] == y_insert_blank[s - 2]):
                # Blank, or same label as two states back: no skip allowed.
                prev_state = [s, s - 1]
            else:
                prev_state = [s, s - 1, s - 2]
            candidates = paddle.to_tensor(
                [log_alpha[t - 1, p] for p in prev_state])
            log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
                y_insert_blank[s]]
            state_path[t, s] = prev_state[paddle.argmax(candidates)]

    state_seq = -1 * paddle.ones((num_frames, 1), dtype=paddle.int16)

    # A complete path ends in the trailing blank or in the last label.
    candidates = paddle.to_tensor([
        log_alpha[-1, num_states - 1],  # Sb
        log_alpha[-1, num_states - 2]  # Snb
    ])
    prev_state = [num_states - 1, num_states - 2]
    state_seq[-1] = prev_state[paddle.argmax(candidates)]
    # Backtrack from the last frame to the first.
    for t in range(num_frames - 2, -1, -1):
        state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

    output_alignment = []
    for t in range(0, num_frames):
        output_alignment.append(y_insert_blank[state_seq[t, 0]])

    return output_alignment
deepspeech/utils/
common
.py
→
deepspeech/utils/
tensor_utils
.py
浏览文件 @
b31a1f46
...
@@ -20,7 +20,7 @@ import paddle
...
@@ -20,7 +20,7 @@ import paddle
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"pad_list"
,
"add_sos_eos"
,
"
remove_duplicates_and_blank"
,
"log_add
"
]
__all__
=
[
"pad_list"
,
"add_sos_eos"
,
"
th_accuracy
"
]
IGNORE_ID
=
-
1
IGNORE_ID
=
-
1
...
@@ -90,24 +90,21 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
...
@@ -90,24 +90,21 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
return
pad_list
(
ys_in
,
eos
),
pad_list
(
ys_out
,
ignore_id
)
return
pad_list
(
ys_in
,
eos
),
pad_list
(
ys_out
,
ignore_id
)
def
remove_duplicates_and_blank
(
hyp
:
List
[
int
])
->
List
[
int
]:
def
th_accuracy
(
pad_outputs
:
paddle
.
Tensor
,
new_hyp
:
List
[
int
]
=
[]
pad_targets
:
paddle
.
Tensor
,
cur
=
0
ignore_label
:
int
)
->
float
:
while
cur
<
len
(
hyp
):
"""Calculate accuracy.
if
hyp
[
cur
]
!=
0
:
Args:
new_hyp
.
append
(
hyp
[
cur
])
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
prev
=
cur
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
while
cur
<
len
(
hyp
)
and
hyp
[
cur
]
==
hyp
[
prev
]:
ignore_label (int): Ignore label id.
cur
+=
1
Returns:
return
new_hyp
float: Accuracy value (0.0 - 1.0).
def
log_add
(
args
:
List
[
int
])
->
float
:
"""
Stable log add
"""
"""
if
all
(
a
==
-
float
(
'inf'
)
for
a
in
args
):
pad_pred
=
pad_outputs
.
view
(
return
-
float
(
'inf'
)
pad_targets
.
size
(
0
),
pad_targets
.
size
(
1
),
pad_outputs
.
size
(
1
)).
argmax
(
2
)
a_max
=
max
(
args
)
mask
=
pad_targets
!=
ignore_label
lsp
=
math
.
log
(
sum
(
math
.
exp
(
a
-
a_max
)
for
a
in
args
))
numerator
=
paddle
.
sum
(
return
a_max
+
lsp
pad_pred
.
masked_select
(
mask
)
==
pad_targets
.
masked_select
(
mask
))
\ No newline at end of file
denominator
=
paddle
.
sum
(
mask
)
return
float
(
numerator
)
/
float
(
denominator
)
deepspeech/utils/utility.py
浏览文件 @
b31a1f46
...
@@ -13,10 +13,13 @@
...
@@ -13,10 +13,13 @@
# limitations under the License.
# limitations under the License.
"""Contains common utility functions."""
"""Contains common utility functions."""
import
math
import
numpy
as
np
import
numpy
as
np
import
distutils.util
import
distutils.util
__all__
=
[
'print_arguments'
,
'add_arguments'
]
__all__
=
[
'print_arguments'
,
'add_arguments'
,
"log_add"
,
"remove_duplicates_and_blank"
]
def
print_arguments
(
args
):
def
print_arguments
(
args
):
...
@@ -57,4 +60,38 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
...
@@ -57,4 +60,38 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
default
=
default
,
default
=
default
,
type
=
type
,
type
=
type
,
help
=
help
+
' Default: %(default)s.'
,
help
=
help
+
' Default: %(default)s.'
,
**
kwargs
)
**
kwargs
)
\ No newline at end of file
def log_add(args: "List[float]") -> float:
    """Stable log add.

    Computes ``log(sum(exp(a) for a in args))`` by factoring out the largest
    term, so that the exponentials cannot overflow.

    Args:
        args: values in the log domain; any iterable is accepted. The
            annotation is kept as a string because ``List`` is not among the
            imports visible at the top of this module.

    Returns:
        float: the log of the summed exponentials. ``-inf`` for empty or
        all ``-inf`` input, ``+inf`` if any term is ``+inf``.
    """
    values = list(args)  # materialize once: the input is traversed twice
    if not values:
        return -float('inf')
    a_max = max(values)
    if math.isinf(a_max):
        # -inf: every term is log(0). +inf: the sum is infinite. Subtracting
        # a_max below would produce inf - inf = nan in both cases.
        return a_max
    lsp = math.log(sum(math.exp(a - a_max) for a in values))
    return a_max + lsp
def remove_duplicates_and_blank(hyp: "List[int]",
                                blank_id=0) -> "List[int]":
    """ctc alignment to ctc label ids.

    "abaa-acee-" -> "abaace"

    The annotations are kept as strings because ``List`` is not among the
    imports visible at the top of this module.

    Args:
        hyp (List[int]): hypotheses ids, (L)
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        List[int]: remove duplicate ids, then remove blank id.
    """
    from itertools import groupby

    # groupby collapses each run of consecutive equal ids into one key; blanks
    # are dropped only afterwards, so "a-a" still yields two separate labels.
    return [token for token, _ in groupby(hyp) if token != blank_id]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录