PaddlePaddle / DeepSpeech
Commit e5641ca4
Authored Apr 01, 2021 by Hui Zhang

fix bugs, refactor collator, add pad_sequence, fix ckpt bugs

Parent: 944457d6
Showing 10 changed files with 880 additions and 55 deletions (+880 −55)
Changed files:

deepspeech/__init__.py             +93    −0
deepspeech/io/collator.py          +39   −30
deepspeech/io/utility.py           +82    −0
deepspeech/models/deepspeech2.py    +5    −9
deepspeech/models/u2.py           +638    −0
deepspeech/modules/conv.py          +3    −2
deepspeech/modules/rnn.py           +1    −1
deepspeech/training/trainer.py     +10    −8
deepspeech/utils/checkpoint.py      +8    −5
deepspeech/utils/utility.py         +1    −0
deepspeech/__init__.py (+93 −0)

```python
@@ -13,6 +13,9 @@
# limitations under the License.
import logging
from typing import Union
from typing import Optional
from typing import List
from typing import Tuple
from typing import Any

import paddle

@@ -83,6 +86,20 @@ if not hasattr(paddle.Tensor, 'numel'):
    paddle.Tensor.numel = paddle.numel


def new_full(x: paddle.Tensor,
             size: Union[List[int], Tuple[int], paddle.Tensor],
             fill_value: Union[float, int, bool, paddle.Tensor],
             dtype=None):
    return paddle.full(size, fill_value, dtype=x.dtype)


if not hasattr(paddle.Tensor, 'new_full'):
    logger.warn(
        "override new_full of paddle.Tensor if exists or register, remove this when fixed!"
    )
    paddle.Tensor.new_full = new_full


def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
    return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place))

@@ -279,6 +296,7 @@ if not hasattr(paddle.nn, 'Module'):
    logger.warn("register user Module to paddle.nn, remove this when fixed!")
    setattr(paddle.nn, 'Module', paddle.nn.Layer)

# maybe cause assert isinstance(sublayer, core.Layer)
if not hasattr(paddle.nn, 'ModuleList'):
    logger.warn("register user ModuleList to paddle.nn, remove this when fixed!")

@@ -332,3 +350,78 @@ if not hasattr(paddle.nn, 'ConstantPad2d'):
    logger.warn("register user ConstantPad2d to paddle.nn, remove this when fixed!")
    setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d)


########### hack paddle.jit #############
if not hasattr(paddle.jit, 'export'):
    logger.warn("register user export to paddle.jit, remove this when fixed!")
    setattr(paddle.jit, 'export', paddle.jit.to_static)


########### hack paddle.nn.utils #############
def pad_sequence(sequences: List[paddle.Tensor],
                 batch_first: bool=False,
                 padding_value: float=0.0) -> paddle.Tensor:
    r"""Pad a list of variable length Tensors with ``padding_value``.

    ``pad_sequence`` stacks a list of Tensors along a new dimension and pads
    them to equal length. For example, if the input is a list of sequences
    with size ``L x *``, the output is ``T x B x *`` if ``batch_first`` is
    False, and ``B x T x *`` otherwise.

    `B` is the batch size, equal to the number of elements in ``sequences``.
    `T` is the length of the longest sequence.
    `L` is the length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> from paddle.nn.utils.rnn import pad_sequence
        >>> a = paddle.ones(25, 300)
        >>> b = paddle.ones(22, 300)
        >>> c = paddle.ones(15, 300)
        >>> pad_sequence([a, b, c]).size()
        paddle.Tensor([25, 3, 300])

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True,
            or in ``T x B x *`` otherwise.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise.
    """
    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max([s.size(0) for s in sequences])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    out_tensor = sequences[0].new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

    return out_tensor


if not hasattr(paddle.nn.utils, 'rnn.pad_sequence'):
    logger.warn(
        "register user rnn.pad_sequence to paddle.nn.utils, remove this when fixed!")
    setattr(paddle.nn.utils, 'rnn.pad_sequence', pad_sequence)
```
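An illustrative sketch, outside the commit, of exercising the patched `Tensor.new_full` helper; it assumes `paddle` is installed and that importing the `deepspeech` package runs this `__init__` so the patches above are applied. Note also that `setattr(paddle.nn.utils, 'rnn.pad_sequence', pad_sequence)` registers a single attribute whose name literally contains a dot; it does not place `pad_sequence` on the `paddle.nn.utils.rnn` submodule itself.

```python
# Sketch only: check the monkey-patched new_full on a small tensor.
import paddle
import deepspeech  # noqa: F401 -- importing the package applies the patches above

x = paddle.zeros([2, 3], dtype='float32')
y = x.new_full([4], 1.5)   # keeps x's dtype, shape [4], filled with 1.5
print(y.dtype, y.numpy())
```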
deepspeech/io/collator.py (+39 −30)

```diff
@@ -16,15 +16,15 @@ import logging
 import numpy as np
 from collections import namedtuple
 
+from deepspeech.io.utility import pad_sequence
+
 logger = logging.getLogger(__name__)
 
-__all__ = ["SpeechCollator", ]
+__all__ = ["SpeechCollator"]
 
 
 class SpeechCollator():
-    def __init__(self, padding_to=-1, is_training=True):
+    def __init__(self, is_training=True):
         """
         Padding audio features with zeros to make them have the same shape (or
         a user-defined shape) within one batch.
@@ -32,42 +32,51 @@ class SpeechCollator():
         If ``padding_to`` is -1, the maximum shape in the batch will be used
         as the target shape for padding. Otherwise, `padding_to` will be the
         target shape (only refers to the second axis).
+
+        if ``is_training`` is True, text is token ids else is raw string.
         """
-        self._padding_to = padding_to
         self._is_training = is_training
 
     def __call__(self, batch):
-        new_batch = []
-        # get target shape
-        max_length = max([audio.shape[1] for audio, _ in batch])
-        if self._padding_to != -1:
-            if self._padding_to < max_length:
-                raise ValueError("If padding_to is not -1, it should be larger "
-                                 "than any instance's shape in the batch")
-            max_length = self._padding_to
-        max_text_length = max([len(text) for _, text in batch])
-        # padding
-        padded_audios = []
+        """batch examples
+        Args:
+            batch ([List]): batch is (audio, text)
+                audio (np.ndarray) shape (D, T)
+                text (List[int] or str): shape (U,)
+        Returns:
+            tuple(audio, text, audio_lens, text_lens): batched data.
+                audio : (B, Tmax, D)
+                text : (B, Umax)
+                audio_lens: (B)
+                text_lens: (B)
+        """
+        audios = []
         audio_lens = []
-        texts, text_lens = [], []
+        texts = []
+        text_lens = []
         for audio, text in batch:
             # audio
-            padded_audio = np.zeros([audio.shape[0], max_length])
-            padded_audio[:, :audio.shape[1]] = audio
-            padded_audios.append(padded_audio)
+            audios.append(audio.T)  # [T, D]
             audio_lens.append(audio.shape[1])
             # text
-            padded_text = np.zeros([max_text_length])
+            # for training, text is token ids
+            # else text is string, convert to unicode ord
+            tokens = []
             if self._is_training:
-                padded_text[:len(text)] = text  # token ids
+                tokens = text  # token ids
             else:
-                padded_text[:len(text)] = [ord(t) for t in text]  # string, unicode ord
-            texts.append(padded_text)
+                assert isinstance(text, str)
+                tokens = [ord(t) for t in text]
+            tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
+                tokens, dtype=np.int64)
+            texts.append(tokens)
             text_lens.append(len(text))
 
-        padded_audios = np.array(padded_audios).astype('float32')
-        audio_lens = np.array(audio_lens).astype('int64')
-        texts = np.array(texts).astype('int32')
-        text_lens = np.array(text_lens).astype('int64')
-        return padded_audios, texts, audio_lens, text_lens
+        padded_audios = pad_sequence(
+            audios, padding_value=0.0).astype(np.float32)  # [B, T, D]
+        padded_texts = pad_sequence(texts, padding_value=-1).astype(np.int32)
+        audio_lens = np.array(audio_lens).astype(np.int64)
+        text_lens = np.array(text_lens).astype(np.int64)
+        return padded_audios, padded_texts, audio_lens, text_lens
```
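An illustrative usage sketch, outside the commit, of the refactored collator; module path and shapes follow the diff above, and it assumes the `deepspeech` package and its `paddle` dependency are importable.

```python
# Sketch only: batch two variable-length (D, T) features with the new collator.
import numpy as np
from deepspeech.io.collator import SpeechCollator

collate = SpeechCollator(is_training=True)
batch = [
    (np.random.randn(80, 120).astype('float32'), [3, 7, 9]),  # (D=80, T=120), 3 tokens
    (np.random.randn(80, 95).astype('float32'), [1, 2]),      # (D=80, T=95),  2 tokens
]
audios, texts, audio_lens, text_lens = collate(batch)
print(audios.shape)           # (2, 120, 80) -> (B, Tmax, D), zero padded
print(texts.shape)            # (2, 3)       -> (B, Umax), padded with -1
print(audio_lens, text_lens)  # [120  95] [3 2]
```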
deepspeech/io/utility.py (new file, +82 −0, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from collections import namedtuple
from typing import List

logger = logging.getLogger(__name__)

__all__ = ["pad_sequence"]


def pad_sequence(sequences: List[np.ndarray],
                 batch_first: bool=True,
                 padding_value: float=0.0) -> np.ndarray:
    r"""Pad a list of variable length Tensors with ``padding_value``.

    ``pad_sequence`` stacks a list of Tensors along a new dimension and pads
    them to equal length. For example, if the input is a list of sequences
    with size ``L x *``, the output is ``T x B x *`` if ``batch_first`` is
    False, and ``B x T x *`` otherwise.

    `B` is the batch size, equal to the number of elements in ``sequences``.
    `T` is the length of the longest sequence.
    `L` is the length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> a = np.ones([25, 300])
        >>> b = np.ones([22, 300])
        >>> c = np.ones([15, 300])
        >>> pad_sequence([a, b, c]).shape
        [25, 3, 300]

    Note:
        This function returns a np.ndarray of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[np.ndarray]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True,
            or in ``T x B x *`` otherwise.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        np.ndarray of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        np.ndarray of size ``B x T x *`` otherwise.
    """
    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    max_size = sequences[0].shape
    trailing_dims = max_size[1:]
    max_len = max([s.shape[0] for s in sequences])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    out_tensor = np.full(out_dims, padding_value, dtype=sequences[0].dtype)
    for i, tensor in enumerate(sequences):
        length = tensor.shape[0]
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

    return out_tensor
```
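A quick check of the numpy port, runnable on its own and outside the commit. Note that unlike the paddle version added in `deepspeech/__init__.py` above, this port defaults to ``batch_first=True``, so the docstring example's ``[25, 3, 300]`` shape corresponds to calling it with ``batch_first=False`` explicitly.

```python
import numpy as np
from deepspeech.io.utility import pad_sequence

a, b, c = np.ones([25, 300]), np.ones([22, 300]), np.ones([15, 300])
print(pad_sequence([a, b, c]).shape)                     # (3, 25, 300), batch-major default
print(pad_sequence([a, b, c], batch_first=False).shape)  # (25, 3, 300), time-major
```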
deepspeech/models/deepspeech2.py (+5 −9)

```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Deepspeech2 ASR Model"""
 import math
 import collections
 import numpy as np
@@ -67,23 +67,19 @@ class CRNNEncoder(nn.Layer):
         return self.rnn_size * 2
 
     def forward(self, audio, audio_len):
-        """
-        audio: shape [B, D, T]
-        text: shape [B, T]
-        audio_len: shape [B]
-        text_len: shape [B]
-        """
         """Compute Encoder outputs
 
         Args:
-            audio (Tensor): [B, D, T]
-            text (Tensor): [B, T]
+            audio (Tensor): [B, Tmax, D]
+            text (Tensor): [B, Umax]
             audio_len (Tensor): [B]
             text_len (Tensor): [B]
         Returns:
             x (Tensor): encoder outputs, [B, T, D]
             x_lens (Tensor): encoder length, [B]
         """
+        # [B, T, D] -> [B, D, T]
+        audio = audio.transpose([0, 2, 1])
         # [B, D, T] -> [B, C=1, D, T]
         x = audio.unsqueeze(1)
         x_lens = audio_len
```
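The docstring change tracks the new collator layout: the encoder now receives ``(B, Tmax, D)`` batches and transposes them back to the channel-major layout the conv stack expects. A small numpy stand-in for the added transpose/unsqueeze, outside the commit:

```python
import numpy as np

audio = np.zeros([4, 120, 80])      # (B, Tmax, D) as produced by the new collator
audio = audio.transpose([0, 2, 1])  # -> (B, D, T)
x = audio[:, np.newaxis, :, :]      # unsqueeze(1) -> (B, C=1, D, T)
print(x.shape)                      # (4, 1, 80, 120)
```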
deepspeech/models/u2.py (new file, +638 −0, mode 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""U2 ASR Model
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf)
"""
import math
import collections
from collections import defaultdict
import numpy as np
import logging
from yacs.config import CfgNode
from typing import List, Optional, Tuple

import paddle
from paddle import jit
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from paddle.nn.utils.rnn import pad_sequence

from deepspeech.modules.cmvn import GlobalCMVN
from deepspeech.modules.encoder import ConformerEncoder
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools
from deepspeech.utils.cmvn import load_cmvn
from deepspeech.utils.utility import log_add
from deepspeech.utils.tensor_utils import IGNORE_ID
from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.ctc_utils import remove_duplicates_and_blank

logger = logging.getLogger(__name__)

__all__ = ['U2Model']


class U2Model(nn.Module):
    """CTC-Attention hybrid Encoder-Decoder model"""

    def __init__(self,
                 vocab_size: int,
                 encoder: TransformerEncoder,
                 decoder: TransformerDecoder,
                 ctc: CTCDecoder,
                 ctc_weight: float=0.5,
                 ignore_id: int=IGNORE_ID,
                 lsm_weight: float=0.0,
                 length_normalized_loss: bool=False):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight

        self.encoder = encoder
        self.decoder = decoder
        self.ctc = ctc
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss)

    def forward(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
    ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[
            paddle.Tensor]]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        # 1. Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. Attention-decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                    text, text_lengths)
        else:
            loss_att = None

        # 2b. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                text_lengths)
        else:
            loss_ctc = None

        if loss_ctc is None:
            loss = loss_att
        elif loss_att is None:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (
                1 - self.ctc_weight) * loss_att
        return loss, loss_att, loss_ctc

    def _calc_att_loss(
            self,
            encoder_out: paddle.Tensor,
            encoder_mask: paddle.Tensor,
            ys_pad: paddle.Tensor,
            ys_pad_lens: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, float]:
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
                                      ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id)
        return loss_att, acc_att

    def _forward_encoder(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            simulate_streaming: bool=False,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        # Let's assume B = batch_size
        # 1. Encoder
        if simulate_streaming and decoding_chunk_size > 0:
            encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
                speech,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        else:
            encoder_out, encoder_mask = self.encoder(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        return encoder_out, encoder_mask

    def recognize(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            beam_size: int=10,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            simulate_streaming: bool=False,
    ) -> paddle.Tensor:
        """ Apply beam search on attention decoder
        Args:
            speech (paddle.Tensor): (batch, max_len, feat_dim)
            speech_length (paddle.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            paddle.Tensor: decoding result, (batch, max_result_len)
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]

        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        encoder_dim = encoder_out.size(2)
        running_size = batch_size * beam_size
        encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
            running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
        encoder_mask = encoder_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(running_size, 1,
                                     maxlen)  # (B*N, 1, max_len)

        hyps = torch.ones(
            [running_size, 1], dtype=torch.long,
            device=device).fill_(self.sos)  # (B*N, 1)
        scores = paddle.tensor(
            [0.0] + [-float('inf')] * (beam_size - 1), dtype=torch.float)
        scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to(
            device)  # (B*N, 1)
        end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
        cache: Optional[List[paddle.Tensor]] = None
        # 2. Decoder forward step by step
        for i in range(1, maxlen + 1):
            # Stop if all batch and all beam produce eos
            if end_flag.sum() == running_size:
                break
            # 2.1 Forward decoder step
            hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                running_size, 1, 1).to(device)  # (B*N, i, i)
            # logp: (B*N, vocab)
            logp, cache = self.decoder.forward_one_step(
                encoder_out, encoder_mask, hyps, hyps_mask, cache)
            # 2.2 First beam prune: select topk best prob at current time
            top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
            top_k_logp = mask_finished_scores(top_k_logp, end_flag)
            top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
            # 2.3 Second beam prune: select topk score with history
            scores = scores + top_k_logp  # (B*N, N), broadcast add
            scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
            scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
            scores = scores.view(-1, 1)  # (B*N, 1)
            # 2.4. Compute base index in top_k_index,
            # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
            # then find offset_k_index in top_k_index
            base_k_index = torch.arange(batch_size, device=device).view(
                -1, 1).repeat([1, beam_size])  # (B, N)
            base_k_index = base_k_index * beam_size * beam_size
            best_k_index = base_k_index.view(-1) + offset_k_index.view(
                -1)  # (B*N)
            # 2.5 Update best hyps
            best_k_pred = torch.index_select(
                top_k_index.view(-1), dim=-1, index=best_k_index)  # (B*N)
            best_hyps_index = best_k_index // beam_size
            last_best_k_hyps = torch.index_select(
                hyps, dim=0, index=best_hyps_index)  # (B*N, i)
            hyps = torch.cat(
                (last_best_k_hyps, best_k_pred.view(-1, 1)),
                dim=1)  # (B*N, i+1)
            # 2.6 Update end flag
            end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1)

        # 3. Select best of best
        scores = scores.view(batch_size, beam_size)
        # TODO: length normalization
        best_index = torch.argmax(scores, dim=-1).long()
        best_hyps_index = best_index + torch.arange(
            batch_size, dtype=torch.long, device=device) * beam_size
        best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
        best_hyps = best_hyps[:, 1:]
        return best_hyps

    def ctc_greedy_search(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            simulate_streaming: bool=False,
    ) -> List[List[int]]:
        """ Apply CTC greedy search
        Args:
            speech (paddle.Tensor): (batch, max_len, feat_dim)
            speech_length (paddle.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[List[int]]: best path result
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        # Let's assume B = batch_size
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (B, maxlen, vocab_size)
        topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
        mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
        topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
        hyps = [hyp.tolist() for hyp in topk_index]
        hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
        return hyps

    def _ctc_prefix_beam_search(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            beam_size: int,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            simulate_streaming: bool=False,
    ) -> Tuple[List[List[int]], paddle.Tensor]:
        """ CTC prefix beam search inner implementation
        Args:
            speech (paddle.Tensor): (batch, max_len, feat_dim)
            speech_length (paddle.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[List[int]]: nbest results
            paddle.Tensor: encoder output, (1, max_len, encoder_dim),
                it will be used for rescoring in attention rescoring mode
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        # For CTC prefix beam search, we only support batch_size=1
        assert batch_size == 1
        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder forward and get CTC score
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)
        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
        cur_hyps = [(tuple(), (0.0, -float('inf')))]
        # 2. CTC beam search step by step
        for t in range(0, maxlen):
            logp = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
            # 2.1 First beam prune: select topk best
            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
            for s in top_k_index:
                s = s.item()
                ps = logp[s].item()
                for prefix, (pb, pnb) in cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if s == 0:  # blank
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                    elif s == last:
                        # Update *ss -> *s;
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pnb = log_add([n_pnb, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)
                    else:
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)

            # 2.2 Second beam prune
            next_hyps = sorted(
                next_hyps.items(),
                key=lambda x: log_add(list(x[1])),
                reverse=True)
            cur_hyps = next_hyps[:beam_size]
        hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
        return hyps, encoder_out

    def ctc_prefix_beam_search(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            beam_size: int,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            simulate_streaming: bool=False,
    ) -> List[int]:
        """ Apply CTC prefix beam search
        Args:
            speech (paddle.Tensor): (batch, max_len, feat_dim)
            speech_length (paddle.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[int]: CTC prefix beam search nbest results
        """
        hyps, _ = self._ctc_prefix_beam_search(
            speech, speech_lengths, beam_size, decoding_chunk_size,
            num_decoding_left_chunks, simulate_streaming)
        return hyps[0][0]

    def attention_rescoring(
            self,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            beam_size: int,
            decoding_chunk_size: int=-1,
            num_decoding_left_chunks: int=-1,
            ctc_weight: float=0.0,
            simulate_streaming: bool=False,
    ) -> List[int]:
        """ Apply attention rescoring decoding: CTC prefix beam search
            is applied first to get nbest, then we rescore the nbest on
            the attention decoder with the corresponding encoder out
        Args:
            speech (paddle.Tensor): (batch, max_len, feat_dim)
            speech_length (paddle.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[int]: Attention rescoring result
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]
        # For attention rescoring we only support batch_size=1
        assert batch_size == 1
        # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
        hyps, encoder_out = self._ctc_prefix_beam_search(
            speech, speech_lengths, beam_size, decoding_chunk_size,
            num_decoding_left_chunks, simulate_streaming)

        assert len(hyps) == beam_size
        hyps_pad = pad_sequence([
            paddle.tensor(hyp[0], device=device, dtype=torch.long)
            for hyp in hyps
        ], True, self.ignore_id)  # (beam_size, max_hyps_len)
        hyps_lens = paddle.tensor(
            [len(hyp[0]) for hyp in hyps], device=device,
            dtype=torch.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning
        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = torch.ones(
            beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device)
        decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        decoder_out = decoder_out.cpu().numpy()
        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            score += decoder_out[i][len(hyp[0])][self.eos]
            # add ctc score
            score += hyp[1] * ctc_weight
            if score > best_score:
                best_score = score
                best_index = i
        return hyps[best_index][0]

    @jit.export
    def subsampling_rate(self) -> int:
        """ Export interface for c++ call, return subsampling_rate of the
            model
        """
        return self.encoder.embed.subsampling_rate

    @jit.export
    def right_context(self) -> int:
        """ Export interface for c++ call, return right_context of the model
        """
        return self.encoder.embed.right_context

    @jit.export
    def sos_symbol(self) -> int:
        """ Export interface for c++ call, return sos symbol id of the model
        """
        return self.sos

    @jit.export
    def eos_symbol(self) -> int:
        """ Export interface for c++ call, return eos symbol id of the model
        """
        return self.eos

    @jit.export
    def forward_encoder_chunk(
            self,
            xs: paddle.Tensor,
            offset: int,
            required_cache_size: int,
            subsampling_cache: Optional[paddle.Tensor]=None,
            elayers_output_cache: Optional[List[paddle.Tensor]]=None,
            conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
            paddle.Tensor]]:
        """ Export interface for c++ call, give input chunk xs, and return
            output from time 0 to current chunk.
        Args:
            xs (paddle.Tensor): chunk input
            subsampling_cache (Optional[paddle.Tensor]): subsampling cache
            elayers_output_cache (Optional[List[paddle.Tensor]]):
                transformer/conformer encoder layers output cache
            conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
                cnn cache
        Returns:
            paddle.Tensor: output, it ranges from time 0 to current chunk.
            paddle.Tensor: subsampling cache
            List[paddle.Tensor]: attention cache
            List[paddle.Tensor]: conformer cnn cache
        """
        return self.encoder.forward_chunk(
            xs, offset, required_cache_size, subsampling_cache,
            elayers_output_cache, conformer_cnn_cache)

    @jit.export
    def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
        """ Export interface for c++ call, apply linear transform and log
            softmax before ctc
        Args:
            xs (paddle.Tensor): encoder output
        Returns:
            paddle.Tensor: activation before ctc
        """
        return self.ctc.log_softmax(xs)

    @jit.export
    def forward_attention_decoder(
            self,
            hyps: paddle.Tensor,
            hyps_lens: paddle.Tensor,
            encoder_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """ Export interface for c++ call, forward decoder with multiple
            hypothesis from ctc prefix beam search and one encoder output
        Args:
            hyps (paddle.Tensor): hyps from ctc prefix beam search, already
                pad sos at the beginning
            hyps_lens (paddle.Tensor): length of each hyp in hyps
            encoder_out (paddle.Tensor): corresponding encoder output
        Returns:
            paddle.Tensor: decoder output
        """
        assert encoder_out.size(0) == 1
        num_hyps = hyps.size(0)
        assert hyps_lens.size(0) == num_hyps
        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        encoder_mask = torch.ones(
            num_hyps,
            1,
            encoder_out.size(1),
            dtype=torch.bool,
            device=encoder_out.device)
        decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps,
            hyps_lens)  # (num_hyps, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        return decoder_out


def init_asr_model(configs):
    if configs['cmvn_file'] is not None:
        mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    vocab_size = configs['output_dim']

    encoder_type = configs.get('encoder', 'conformer')
    if encoder_type == 'conformer':
        encoder = ConformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
    else:
        encoder = TransformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])

    decoder = TransformerDecoder(vocab_size,
                                 encoder.output_size(),
                                 **configs['decoder_conf'])
    ctc = CTCDecoder(vocab_size, encoder.output_size())

    model = U2Model(
        vocab_size=vocab_size,
        encoder=encoder,
        decoder=decoder,
        ctc=ctc,
        **configs['model_conf'])
    return model
```
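Two notes and a sketch, outside the commit. First, as committed this file still carries several `torch.*` and `paddle.tensor(...)` calls (in `recognize`, `attention_rescoring`, `forward_attention_decoder`, and `init_asr_model`); those decode and init paths would need to be ported to paddle before they can run. Second, the heart of `_ctc_prefix_beam_search` is framework-independent; below is a self-contained numpy sketch of the same algorithm (blank id 0, single utterance), with `log_add` reimplemented to mirror the helper imported from `deepspeech.utils.utility` — illustrative only, not the committed implementation.

```python
import math
from collections import defaultdict
import numpy as np


def log_add(args):
    """Stable log(exp(a) + exp(b) + ...), mirroring deepspeech.utils.utility.log_add."""
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


def ctc_prefix_beam_search(ctc_log_probs: np.ndarray, beam_size: int=10):
    """ctc_log_probs: (T, vocab) log-softmax scores; returns [(prefix, log_score)]."""
    # prefix -> (blank_ending_score, non_blank_ending_score), in log domain
    cur_hyps = [(tuple(), (0.0, -float('inf')))]
    for t in range(ctc_log_probs.shape[0]):
        logp = ctc_log_probs[t]
        next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
        # 2.1 first prune: keep only the top-k symbols at this frame
        for s in np.argsort(logp)[-beam_size:]:
            s = int(s)
            ps = float(logp[s])
            for prefix, (pb, pnb) in cur_hyps:
                last = prefix[-1] if prefix else None
                if s == 0:  # blank: prefix unchanged
                    n_pb, n_pnb = next_hyps[prefix]
                    n_pb = log_add([n_pb, pb + ps, pnb + ps])
                    next_hyps[prefix] = (n_pb, n_pnb)
                elif s == last:  # repeat: *ss -> *s, and *s-s -> *ss via blank
                    n_pb, n_pnb = next_hyps[prefix]
                    n_pnb = log_add([n_pnb, pnb + ps])
                    next_hyps[prefix] = (n_pb, n_pnb)
                    n_prefix = prefix + (s, )
                    n_pb, n_pnb = next_hyps[n_prefix]
                    n_pnb = log_add([n_pnb, pb + ps])
                    next_hyps[n_prefix] = (n_pb, n_pnb)
                else:  # new symbol extends the prefix
                    n_prefix = prefix + (s, )
                    n_pb, n_pnb = next_hyps[n_prefix]
                    n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                    next_hyps[n_prefix] = (n_pb, n_pnb)
        # 2.2 second prune: keep the best beam_size prefixes overall
        cur_hyps = sorted(
            next_hyps.items(), key=lambda x: log_add(list(x[1])),
            reverse=True)[:beam_size]
    return [(prefix, log_add(list(v))) for prefix, v in cur_hyps]


# toy check on random scores
logits = np.random.randn(50, 30)
log_probs = logits - np.logaddexp.reduce(logits, axis=1, keepdims=True)
nbest = ctc_prefix_beam_search(log_probs, beam_size=5)
print(nbest[0])  # best prefix (tuple of non-blank ids) and its log score
```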
deepspeech/modules/conv.py (+3 −2)

```diff
@@ -145,7 +145,7 @@ class ConvStack(nn.Layer):
             act='brelu')
 
         out_channel = 32
-        self.conv_stack = nn.Sequential([
+        convs = [
             ConvBn(
                 num_channels_in=32,
                 num_channels_out=out_channel,
@@ -153,7 +153,8 @@ class ConvStack(nn.Layer):
                 stride=(2, 1),
                 padding=(10, 5),
                 act='brelu') for i in range(num_stacks - 1)
-        ])
+        ]
+        self.conv_stack = nn.LayerList(convs)
 
         # conv output feat_dim
         output_height = (feat_size - 1) // 2 + 1
```
deepspeech/modules/rnn.py (+1 −1)

```diff
@@ -298,7 +298,7 @@ class RNNStack(nn.Layer):
                     share_weights=share_rnn_weights))
             i_size = h_size * 2
 
-        self.rnn_stacks = nn.Sequential(rnn_stacks)
+        self.rnn_stacks = nn.ModuleList(rnn_stacks)
 
     def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
         """
```
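Both module changes above replace `nn.Sequential` with a layer container: `nn.LayerList` in conv.py, and `nn.ModuleList` in rnn.py, which resolves to the alias registered on `paddle.nn` in `deepspeech/__init__.py`. The likely motivation is that these stacks thread `(x, x_len)` pairs through each sublayer inside `forward`, which Sequential's single-tensor call chain cannot express. A minimal sketch of the pattern, outside the commit and assuming only the public paddle API:

```python
import paddle
from paddle import nn


class AddOne(nn.Layer):
    def forward(self, x, x_len):
        return x + 1, x_len


class Stack(nn.Layer):
    def __init__(self, layers):
        super().__init__()
        # LayerList registers the sublayers (so their parameters are tracked)
        # but imposes no call order; forward() threads (x, x_len) explicitly.
        self.layers = nn.LayerList(layers)

    def forward(self, x, x_len):
        for layer in self.layers:
            x, x_len = layer(x, x_len)
        return x, x_len


stack = Stack([AddOne(), AddOne()])
x, x_len = stack(paddle.zeros([2, 3]), paddle.to_tensor([3, 3]))
print(x.numpy())  # all 2.0
```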
deepspeech/training/trainer.py (+10 −8)

```diff
@@ -128,9 +128,10 @@ class Trainer():
         dist.init_parallel_env()
 
     @mp_tools.rank_zero_only
-    def save(self):
+    def save(self, infos=None):
         """Save checkpoint (model parameters and optimizer states).
         """
+        if infos is None:
             infos = {
                 "step": self.iteration,
                 "epoch": self.epoch,
@@ -151,6 +152,7 @@ class Trainer():
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_path=self.args.checkpoint_path)
+        if infos:
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
```
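A toy, self-contained sketch, outside the commit, of the save/resume bookkeeping this change introduces: the trainer now accepts an explicit metadata payload (defaulting to the current step and epoch) and restores step/epoch only when a checkpoint was actually loaded. In the real code the payload lives in a `.json` file next to the `.pdparams` file (see checkpoint.py below); the standalone file handling here is illustrative only.

```python
import json
import os
import tempfile


def save_infos(path, iteration, epoch, infos=None):
    # mirrors Trainer.save(): default payload is the current step/epoch
    if infos is None:
        infos = {"step": iteration, "epoch": epoch}
    with open(path, "w") as fout:
        json.dump(infos, fout)


def resume_infos(path):
    # mirrors the resume path: an empty dict means "nothing restored"
    if not os.path.exists(path):
        return {}
    with open(path) as fin:
        return json.load(fin)


path = os.path.join(tempfile.mkdtemp(), "1200.json")
save_infos(path, iteration=1200, epoch=3)
infos = resume_infos(path)
if infos:
    print("resume at step", infos["step"], "epoch", infos["epoch"])
```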
deepspeech/utils/checkpoint.py (+8 −5)

```diff
@@ -36,11 +36,11 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
     Args:
         checkpoint_dir (str): the directory where checkpoint is saved.
     Returns:
-        int: the latest iteration number.
+        int: the latest iteration number. -1 for no checkpoint to load.
     """
     checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
     if not os.path.isfile(checkpoint_record):
-        return 0
+        return -1
 
     # Fetch the latest checkpoint index.
     with open(checkpoint_record, "rt") as handle:
@@ -79,11 +79,15 @@ def load_parameters(model,
     Returns:
         configs (dict): epoch or step, lr and other meta info should be saved.
     """
+    configs = {}
+
     if checkpoint_path is not None:
         iteration = int(os.path.basename(checkpoint_path).split(":")[-1])
     elif checkpoint_dir is not None:
         iteration = _load_latest_checkpoint(checkpoint_dir)
-        checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration))
+        if iteration == -1:
+            return configs
+        checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
     else:
         raise ValueError(
             "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
@@ -104,7 +108,6 @@ def load_parameters(model,
             rank, optimizer_path))
 
     info_path = re.sub('.pdparams$', '.json', params_path)
-    configs = {}
     if os.path.exists(info_path):
         with open(info_path, 'r') as fin:
             configs = json.load(fin)
@@ -128,7 +131,7 @@ def save_parameters(checkpoint_dir: str,
     Returns:
         None
     """
-    checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration))
+    checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
     model_dict = model.state_dict()
     params_path = checkpoint_path + ".pdparams"
```
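A standalone sketch, outside the commit, of the checkpoint-record convention these hunks rely on: a text file named `checkpoint` inside `checkpoint_dir` names the latest iteration, `-1` now signals "nothing to resume", and checkpoint files are named `"{iteration}"` rather than `"-{iteration}"`. The exact record format (a single integer line) is an assumption for illustration.

```python
import os


def load_latest_checkpoint(checkpoint_dir: str) -> int:
    record = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(record):
        return -1  # changed from 0 in this commit
    with open(record, "rt") as handle:
        return int(handle.readline().strip())


iteration = load_latest_checkpoint("./exp/checkpoints")
if iteration == -1:
    print("no checkpoint to load, training from scratch")
else:
    print("resuming from", os.path.join("./exp/checkpoints",
                                        "{}.pdparams".format(iteration)))
```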
deepspeech/utils/utility.py (+1 −0)

```diff
@@ -16,6 +16,7 @@
 import math
 import numpy as np
 import distutils.util
+from typing import List
 
 __all__ = ['print_arguments', 'add_arguments', "log_add"]
```