Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
9e2773df
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9e2773df
编写于
10月 25, 2021
作者:
H
huangyuxin
浏览文件
操作
浏览文件
下载
差异文件
tiny change, not important
上级
418d85ef
e4ecfb22
变更
23
显示空白变更内容
内联
并排
Showing
23 changed file
with
671 addition
and
52 deletion
+671
-52
deepspeech/__init__.py
deepspeech/__init__.py
+0
-3
deepspeech/decoders/recog.py
deepspeech/decoders/recog.py
+0
-4
deepspeech/decoders/recog_bin.py
deepspeech/decoders/recog_bin.py
+0
-2
deepspeech/models/asr_interface.py
deepspeech/models/asr_interface.py
+3
-3
deepspeech/models/lm/__init__.py
deepspeech/models/lm/__init__.py
+13
-0
deepspeech/models/lm_interface.py
deepspeech/models/lm_interface.py
+18
-5
deepspeech/models/st_interface.py
deepspeech/models/st_interface.py
+75
-0
deepspeech/models/u2_st/__init__.py
deepspeech/models/u2_st/__init__.py
+15
-0
deepspeech/models/u2_st/u2_st.py
deepspeech/models/u2_st/u2_st.py
+0
-0
deepspeech/modules/ctc.py
deepspeech/modules/ctc.py
+2
-1
deepspeech/modules/embedding.py
deepspeech/modules/embedding.py
+4
-1
deepspeech/modules/encoder.py
deepspeech/modules/encoder.py
+1
-1
deepspeech/modules/subsampling.py
deepspeech/modules/subsampling.py
+4
-2
deepspeech/training/extensions/plot.py
deepspeech/training/extensions/plot.py
+418
-0
deepspeech/training/triggers/__init__.py
deepspeech/training/triggers/__init__.py
+0
-15
deepspeech/training/triggers/compare_value_trigger.py
deepspeech/training/triggers/compare_value_trigger.py
+61
-0
deepspeech/training/triggers/time_trigger.py
deepspeech/training/triggers/time_trigger.py
+9
-0
deepspeech/training/triggers/utils.py
deepspeech/training/triggers/utils.py
+28
-0
deepspeech/utils/asr_utils.py
deepspeech/utils/asr_utils.py
+2
-2
deepspeech/utils/bleu_score.py
deepspeech/utils/bleu_score.py
+6
-3
deepspeech/utils/error_rate.py
deepspeech/utils/error_rate.py
+11
-7
setup.py
setup.py
+1
-2
utils/json2trn.py
utils/json2trn.py
+0
-1
未找到文件。
deepspeech/__init__.py
浏览文件 @
9e2773df
...
...
@@ -355,7 +355,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!"
)
setattr
(
paddle
.
Tensor
,
'tolist'
,
tolist
)
########### hack paddle.nn #############
from
paddle.nn
import
Layer
from
typing
import
Optional
...
...
@@ -506,5 +505,3 @@ if not hasattr(paddle.nn, 'LayerDict'):
logger
.
debug
(
"register user LayerDict to paddle.nn, remove this when fixed!"
)
setattr
(
paddle
.
nn
,
'LayerDict'
,
LayerDict
)
deepspeech/decoders/recog.py
浏览文件 @
9e2773df
...
...
@@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
import
json
from
pathlib
import
Path
import
jsonlines
import
paddle
import
yaml
from
yacs.config
import
CfgNode
from
.beam_search
import
BatchBeamSearch
...
...
deepspeech/decoders/recog_bin.py
浏览文件 @
9e2773df
...
...
@@ -21,8 +21,6 @@ from distutils.util import strtobool
import
configargparse
import
numpy
as
np
from
.recog
import
recog_v2
def
get_parser
():
"""Get default arguments."""
...
...
deepspeech/models/asr_interface.py
浏览文件 @
9e2773df
...
...
@@ -18,7 +18,7 @@ from deepspeech.utils.dynamic_import import dynamic_import
class
ASRInterface
:
"""ASR Interface
for ESPnet
model implementation."""
"""ASR Interface model implementation."""
@
staticmethod
def
add_arguments
(
parser
):
...
...
@@ -103,14 +103,14 @@ class ASRInterface:
@
property
def
attention_plot_class
(
self
):
"""Get attention plot class."""
from
espnet.asr.asr_utils
import
PlotAttentionReport
from
deepspeech.training.extensions.plot
import
PlotAttentionReport
return
PlotAttentionReport
@
property
def
ctc_plot_class
(
self
):
"""Get CTC plot class."""
from
espnet.asr.asr_utils
import
PlotCTCReport
from
deepspeech.training.extensions.plot
import
PlotCTCReport
return
PlotCTCReport
...
...
deepspeech/models/lm/__init__.py
浏览文件 @
9e2773df
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech/models/lm_interface.py
浏览文件 @
9e2773df
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language model interface."""
import
argparse
from
deepspeech.decoders.scorers.scorer_interface
import
ScorerInterface
from
deepspeech.utils.dynamic_import
import
dynamic_import
class
LMInterface
(
ScorerInterface
):
"""LM Interface
for ESPnet
model implementation."""
"""LM Interface model implementation."""
@
staticmethod
def
add_arguments
(
parser
):
...
...
@@ -52,6 +65,7 @@ predefined_lms = {
"transformer"
:
"deepspeech.models.lm.transformer:TransformerLM"
,
}
def
dynamic_import_lm
(
module
):
"""Import LM class dynamically.
...
...
@@ -63,7 +77,6 @@ def dynamic_import_lm(module):
"""
model_class
=
dynamic_import
(
module
,
predefined_lms
)
assert
issubclass
(
model_class
,
LMInterface
),
f
"
{
module
}
does not implement LMInterface"
assert
issubclass
(
model_class
,
LMInterface
),
f
"
{
module
}
does not implement LMInterface"
return
model_class
deepspeech/models/st_interface.py
0 → 100644
浏览文件 @
9e2773df
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module."""
from
.asr_interface
import
ASRInterface
from
deepspeech.utils.dynamic_import
import
dynamic_import
class STInterface(ASRInterface):
    """ST Interface model implementation.

    NOTE: This class is inherited from ASRInterface to enable joint translation
    and recognition when performing multi-task learning with the ASR task.
    """

    def translate(self,
                  x,
                  trans_args,
                  char_list=None,
                  rnnlm=None,
                  ensemble_models=None):
        """Translate x for evaluation.

        This base implementation only defines the contract; concrete models
        are expected to override it.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :param list ensemble_models: additional models for ensemble decoding;
            ``None`` (the default) means no extra models. A ``None`` sentinel is
            used instead of ``[]`` so that no list is shared between calls.
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("translate method is not implemented")

    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        This base implementation only defines the contract; concrete models
        are expected to override it.

        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Batch decoding is not supported yet.")
# Aliases accepted by `dynamic_import_st`, each mapped to the
# "module_name:class_name" path of the ST model it stands for.
predefined_st = {
    "transformer": "deepspeech.models.u2_st:U2STModel",
}
def dynamic_import_st(module):
    """Resolve an ST model class from an import path or a predefined alias.

    Args:
        module (str): module_name:class_name or alias in `predefined_st`

    Returns:
        type: ST class
    """
    st_class = dynamic_import(module, predefined_st)
    # Reject anything that does not derive from the ST interface.
    implements_st = issubclass(st_class, STInterface)
    assert implements_st, f"{module} does not implement STInterface"
    return st_class
deepspeech/models/u2_st/__init__.py
0 → 100644
浏览文件 @
9e2773df
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.u2_st
import
U2STInferModel
from
.u2_st
import
U2STModel
deepspeech/models/u2_st.py
→
deepspeech/models/u2_st
/u2_st
.py
浏览文件 @
9e2773df
文件已移动
deepspeech/modules/ctc.py
浏览文件 @
9e2773df
...
...
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Union
import
paddle
from
paddle
import
nn
from
typing
import
Union
from
paddle.nn
import
functional
as
F
from
typeguard
import
check_argument_types
...
...
deepspeech/modules/embedding.py
浏览文件 @
9e2773df
...
...
@@ -22,7 +22,10 @@ from deepspeech.utils.log import Log
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
"NoPositionalEncoding"
,
"PositionalEncoding"
,
"RelPositionalEncoding"
]
__all__
=
[
"NoPositionalEncoding"
,
"PositionalEncoding"
,
"RelPositionalEncoding"
]
class
NoPositionalEncoding
(
nn
.
Layer
):
def
__init__
(
self
,
...
...
deepspeech/modules/encoder.py
浏览文件 @
9e2773df
...
...
@@ -103,7 +103,7 @@ class BaseEncoder(nn.Layer):
pos_enc_class
=
PositionalEncoding
elif
pos_enc_layer_type
==
"rel_pos"
:
pos_enc_class
=
RelPositionalEncoding
elif
pos_enc_layer_type
is
"no_pos"
:
elif
pos_enc_layer_type
==
"no_pos"
:
pos_enc_class
=
NoPositionalEncoding
else
:
raise
ValueError
(
"unknown pos_enc_layer: "
+
pos_enc_layer_type
)
...
...
deepspeech/modules/subsampling.py
浏览文件 @
9e2773df
...
...
@@ -61,7 +61,7 @@ class LinearNoSubsampling(BaseSubsampling):
nn
.
Linear
(
idim
,
odim
),
nn
.
LayerNorm
(
odim
,
epsilon
=
1e-12
),
nn
.
Dropout
(
dropout_rate
),
nn
.
ReLU
(),)
nn
.
ReLU
(),
)
self
.
right_context
=
0
self
.
subsampling_rate
=
1
...
...
@@ -83,10 +83,12 @@ class LinearNoSubsampling(BaseSubsampling):
x
,
pos_emb
=
self
.
pos_enc
(
x
,
offset
)
return
x
,
pos_emb
,
x_mask
class
Conv2dSubsampling
(
BaseSubsampling
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
class
Conv2dSubsampling4
(
Conv2dSubsampling
):
"""Convolutional 2D subsampling (to 1/4 length)."""
...
...
deepspeech/training/extensions/plot.py
0 → 100644
浏览文件 @
9e2773df
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
os
import
numpy
as
np
from
.
import
extension
class PlotAttentionReport(extension.Extension):
    """Plot attention reporter.

    Args:
        att_vis_fn (callable): Function of attention visualization. It is
            called with the converted batch (``*batch`` when the converter
            returns a tuple, ``**batch`` otherwise) and returns the attention
            weights.
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
            It is deep-copied, so later changes by the caller are not seen.
        outdir (str): Directory to save figures. Created if missing.
        converter (callable): Function to convert data; called as
            ``converter([batch], device)``.
        transform (callable): Called as ``transform(data, return_uttid=True)``
            and must return ``(batch, uttid_list)``.
        device: Device object handed to ``converter`` unchanged.
        reverse (bool): If True, input and output length are reversed.
        ikey (str): Key to access input
            (for ASR/ST ikey="input", for MT ikey="output".)
        iaxis (int): Dimension to access input
            (for ASR/ST iaxis=0, for MT iaxis=1.)
        okey (str): Key to access output (default "output").
            NOTE(review): the inherited note claimed okey="input" for ASR/ST,
            which contradicts the default -- confirm the intended values.
        oaxis (int): Dimension to access output
            (for ASR/ST oaxis=0, for MT oaxis=0.)
        subsampling_factor (int): subsampling factor in encoder
    """

    def __init__(
            self,
            att_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1, ):
        self.att_vis_fn = att_vis_fn
        self.data = copy.deepcopy(data)
        # key is utterance ID
        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        self.factor = subsampling_factor
        # exist_ok avoids the check-then-create race when several processes
        # construct the reporter with the same output directory.
        os.makedirs(self.outdir, exist_ok=True)

    def __call__(self, trainer):
        """Plot and save image file of att_ws matrix.

        For every utterance both a ``.png`` figure and a ``.npy`` dump are
        written under ``outdir``; the ``{.updater.epoch}`` placeholder in the
        file name is filled in from ``trainer``.
        """
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            # The last list entry holds the hierarchical attention ("han").
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
                        self.outdir,
                        uttid_list[idx],
                        i + 1, )
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
                        self.outdir,
                        uttid_list[idx],
                        i + 1, )
                    np.save(np_filename.format(trainer), att_w)
                    self._plot_and_save_attention(att_w,
                                                  filename.format(trainer))
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
                    self.outdir,
                    uttid_list[idx], )
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
                    self.outdir,
                    uttid_list[idx], )
                np.save(np_filename.format(trainer), att_w)
                self._plot_and_save_attention(
                    att_w, filename.format(trainer), han_mode=True)
        else:
            for idx, att_w in enumerate(att_ws):
                filename = "%s/%s.ep.{.updater.epoch}.png" % (
                    self.outdir,
                    uttid_list[idx], )
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                    self.outdir,
                    uttid_list[idx], )
                np.save(np_filename.format(trainer), att_w)
                self._plot_and_save_attention(att_w, filename.format(trainer))

    def log_attentions(self, logger, step):
        """Add image files of att_ws matrix to the tensorboard.

        Args:
            logger: Object exposing ``add_figure(tag, figure, step)``.
            step: Step value forwarded to ``add_figure``.
        """
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_attention_plot(att_w)
                    logger.add_figure(
                        "%s_att%d" % (uttid_list[idx], i + 1),
                        plot.gcf(),
                        step, )
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_han_plot(att_w)
                logger.add_figure(
                    "%s_han" % (uttid_list[idx]),
                    plot.gcf(),
                    step, )
        else:
            for idx, att_w in enumerate(att_ws):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_attention_plot(att_w)
                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

    def get_attention_weights(self):
        """Return attention weights.

        Returns:
            tuple: ``(att_ws, uttid_list)``. ``att_ws`` is whatever
            ``att_vis_fn`` returns (a list in the multi-encoder case);
            ``uttid_list`` comes from ``transform``.
            NOTE(review): the inherited note gave shapes (B, H, Lmax, Tmax) for
            multi-head and (B, Lmax, Tmax) otherwise -- confirm for the models
            used here.
        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        else:
            att_ws = self.att_vis_fn(**batch)
        return att_ws, uttid_list

    def trim_attention_weight(self, uttid, att_w):
        """Transform attention matrix with regard to self.reverse.

        The decoder/encoder lengths are read from the ``"shape"`` entries of
        ``data_dict[uttid]`` and the trailing two axes of ``att_w`` are cut
        down to them.
        """
        if self.reverse:
            enc_key, enc_axis = self.okey, self.oaxis
            dec_key, dec_axis = self.ikey, self.iaxis
        else:
            enc_key, enc_axis = self.ikey, self.iaxis
            dec_key, dec_axis = self.okey, self.oaxis
        dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
        enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
        if self.factor > 1:
            # The encoder output is shorter than the input by this factor.
            enc_len //= self.factor
        if len(att_w.shape) == 3:
            att_w = att_w[:, :dec_len, :enc_len]
        else:
            att_w = att_w[:dec_len, :enc_len]
        return att_w

    def draw_attention_plot(self, att_w):
        """Plot the att_w matrix.

        A 3-D ``att_w`` is drawn as one subplot per leading index; anything
        else is drawn as a single image.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.
        """
        import matplotlib

        # Select the non-interactive backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        plt.clf()
        att_w = att_w.astype(np.float32)
        if len(att_w.shape) == 3:
            for h, aw in enumerate(att_w, 1):
                plt.subplot(1, len(att_w), h)
                plt.imshow(aw, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
        else:
            plt.imshow(att_w, aspect="auto")
            plt.xlabel("Encoder Index")
            plt.ylabel("Decoder Index")
        plt.tight_layout()
        return plt

    def draw_han_plot(self, att_w):
        """Plot the att_w matrix for hierarchical attention.

        Each column of the matrix is drawn as one line labelled ``Att<i>``.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.
        """
        import matplotlib

        # Select the non-interactive backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        plt.clf()
        if len(att_w.shape) == 3:
            for h, aw in enumerate(att_w, 1):
                legends = []
                plt.subplot(1, len(att_w), h)
                for i in range(aw.shape[1]):
                    plt.plot(aw[:, i])
                    legends.append("Att{}".format(i))
                plt.ylim([0, 1.0])
                plt.xlim([0, aw.shape[0]])
                plt.grid(True)
                plt.ylabel("Attention Weight")
                plt.xlabel("Decoder Index")
                plt.legend(legends)
        else:
            legends = []
            for i in range(att_w.shape[1]):
                plt.plot(att_w[:, i])
                legends.append("Att{}".format(i))
            plt.ylim([0, 1.0])
            plt.xlim([0, att_w.shape[0]])
            plt.grid(True)
            plt.ylabel("Attention Weight")
            plt.xlabel("Decoder Index")
            plt.legend(legends)
        plt.tight_layout()
        return plt

    def _plot_and_save_attention(self, att_w, filename, han_mode=False):
        """Draw ``att_w`` (line plot if ``han_mode``) and save it to ``filename``."""
        if han_mode:
            plt = self.draw_han_plot(att_w)
        else:
            plt = self.draw_attention_plot(att_w)
        plt.savefig(filename)
        plt.close()
class PlotCTCReport(extension.Extension):
    """Plot CTC reporter.

    Args:
        ctc_vis_fn (callable): Function of CTC visualization. It is called
            with the converted batch (``*batch`` when the converter returns a
            tuple, ``**batch`` otherwise) and returns the CTC probabilities.
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
            It is deep-copied, so later changes by the caller are not seen.
        outdir (str): Directory to save figures. Created if missing.
        converter (callable): Function to convert data; called as
            ``converter([batch], device)``.
        transform (callable): Called as ``transform(data, return_uttid=True)``
            and must return ``(batch, uttid_list)``.
        device: Device object handed to ``converter`` unchanged.
        reverse (bool): Stored on the instance; not read by this class.
        ikey (str): Key to access input
            (for ASR/ST ikey="input", for MT ikey="output".)
        iaxis (int): Dimension to access input
            (for ASR/ST iaxis=0, for MT iaxis=1.)
        okey (str): Key to access output (default "output"). Stored on the
            instance; not read by this class.
        oaxis (int): Dimension to access output. Stored on the instance; not
            read by this class.
        subsampling_factor (int): subsampling factor in encoder
    """

    def __init__(
            self,
            ctc_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1, ):
        self.ctc_vis_fn = ctc_vis_fn
        self.data = copy.deepcopy(data)
        # key is utterance ID
        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        self.factor = subsampling_factor
        # exist_ok avoids the check-then-create race when several processes
        # construct the reporter with the same output directory.
        os.makedirs(self.outdir, exist_ok=True)

    def __call__(self, trainer):
        """Plot and save image file of ctc prob.

        For every utterance both a ``.png`` figure and a ``.npy`` dump are
        written under ``outdir``; the ``{.updater.epoch}`` placeholder in the
        file name is filled in from ``trainer``.
        """
        ctc_probs, uttid_list = self.get_ctc_probs()
        if isinstance(ctc_probs, list):  # multi-encoder case
            # The last list entry is deliberately not plotted here.
            num_encs = len(ctc_probs) - 1
            for i in range(num_encs):
                for idx, ctc_prob in enumerate(ctc_probs[i]):
                    filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
                        self.outdir,
                        uttid_list[idx],
                        i + 1, )
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
                        self.outdir,
                        uttid_list[idx],
                        i + 1, )
                    np.save(np_filename.format(trainer), ctc_prob)
                    self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                filename = "%s/%s.ep.{.updater.epoch}.png" % (
                    self.outdir,
                    uttid_list[idx], )
                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                    self.outdir,
                    uttid_list[idx], )
                np.save(np_filename.format(trainer), ctc_prob)
                self._plot_and_save_ctc(ctc_prob, filename.format(trainer))

    def log_ctc_probs(self, logger, step):
        """Add image files of ctc probs to the tensorboard.

        Args:
            logger: Object exposing ``add_figure(tag, figure, step)``.
            step: Step value forwarded to ``add_figure``.
        """
        ctc_probs, uttid_list = self.get_ctc_probs()
        if isinstance(ctc_probs, list):  # multi-encoder case
            num_encs = len(ctc_probs) - 1
            for i in range(num_encs):
                for idx, ctc_prob in enumerate(ctc_probs[i]):
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    plot = self.draw_ctc_plot(ctc_prob)
                    logger.add_figure(
                        "%s_ctc%d" % (uttid_list[idx], i + 1),
                        plot.gcf(),
                        step, )
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                plot = self.draw_ctc_plot(ctc_prob)
                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

    def get_ctc_probs(self):
        """Return CTC probs.

        Returns:
            tuple: ``(probs, uttid_list)``. ``probs`` is whatever
            ``ctc_vis_fn`` returns (a list in the multi-encoder case);
            ``uttid_list`` comes from ``transform``.
            NOTE(review): the inherited note gave the shape as
            (B, Tmax, vocab) -- confirm for the models used here.
        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            probs = self.ctc_vis_fn(*batch)
        else:
            probs = self.ctc_vis_fn(**batch)
        return probs, uttid_list

    def trim_ctc_prob(self, uttid, prob):
        """Trim CTC posteriors according to input lengths.

        The length is read from the ``"shape"`` entry of the utterance's input
        and divided by the subsampling factor when that factor exceeds 1.
        """
        enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
        if self.factor > 1:
            enc_len //= self.factor
        prob = prob[:enc_len]
        return prob

    def draw_ctc_plot(self, ctc_prob):
        """Plot the ctc_prob matrix.

        ``ctc_prob`` must be 2-D (frames x vocabulary); one line is drawn per
        vocabulary index, with index 0 drawn dotted grey as ``<blank>``.

        Returns:
            matplotlib.pyplot: pyplot object with CTC prob matrix image.
        """
        import matplotlib

        # Select the non-interactive backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        ctc_prob = ctc_prob.astype(np.float32)

        plt.clf()
        topk_ids = np.argsort(ctc_prob, axis=1)
        # Unpacking also enforces that the matrix is exactly 2-D.
        n_frames, _ = ctc_prob.shape
        times_probs = np.arange(n_frames)

        plt.figure(figsize=(20, 8))

        # NOTE: index 0 is reserved for blank
        # NOTE(review): argsort keeps every column, so this set covers the
        # whole vocabulary rather than a top-k subset -- confirm intent.
        for idx in set(topk_ids.reshape(-1).tolist()):
            if idx == 0:
                plt.plot(
                    times_probs,
                    ctc_prob[:, 0],
                    ":",
                    label="<blank>",
                    color="grey")
            else:
                plt.plot(times_probs, ctc_prob[:, idx])
        plt.xlabel(u"Input [frame]", fontsize=12)
        plt.ylabel("Posteriors", fontsize=12)
        plt.xticks(list(range(0, int(n_frames) + 1, 10)))
        plt.yticks(list(range(0, 2, 1)))
        plt.tight_layout()
        return plt

    def _plot_and_save_ctc(self, ctc_prob, filename):
        """Draw ``ctc_prob`` and save the figure to ``filename``."""
        plt = self.draw_ctc_plot(ctc_prob)
        plt.savefig(filename)
        plt.close()
deepspeech/training/triggers/__init__.py
浏览文件 @
9e2773df
...
...
@@ -11,18 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.interval_trigger
import
IntervalTrigger
def
never_fail_trigger
(
trainer
):
return
False
def
get_trigger
(
trigger
):
if
trigger
is
None
:
return
never_fail_trigger
if
callable
(
trigger
):
return
trigger
else
:
trigger
=
IntervalTrigger
(
*
trigger
)
return
trigger
deepspeech/training/triggers/compare_value_trigger.py
0 → 100644
浏览文件 @
9e2773df
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
..reporter
import
DictSummary
from
.utils
import
get_trigger
class CompareValueTrigger:
    """Trigger invoked when key value getting bigger or lower than before.

    Observed values of `key` are accumulated between interval firings; at each
    firing their mean is compared against the stored best value.

    Args:
        key (str) : Key of value.
        compare_fn ((float, float) -> bool) : Function to compare the values.
            It is called as ``compare_fn(best_value, new_value)``.
        trigger (tuple(int, str)) : Trigger that decide the comparison interval.
    """

    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
        self._key = key
        self._best_value = None
        self._interval_trigger = get_trigger(trigger)
        self._init_summary()
        self._compare_fn = compare_fn

    def __call__(self, trainer):
        """Get value related to the key and compare with current value.

        Returns:
            bool: True only when the interval trigger fired and
            ``compare_fn(best_value, mean_value)`` is truthy; False otherwise.
        """
        observed = trainer.observation
        accumulator = self._summary
        name = self._key

        if name in observed:
            accumulator.add({name: observed[name]})

        # Outside the comparison interval: keep accumulating only.
        if not self._interval_trigger(trainer):
            return False

        averaged = accumulator.compute_mean()
        current = float(averaged[name])  # copy to CPU
        self._init_summary()

        if self._best_value is None:
            # initialize best value
            self._best_value = current
            return False

        if self._compare_fn(self._best_value, current):
            return True

        self._best_value = current
        return False

    def _init_summary(self):
        """Start a fresh accumulator for the next interval."""
        self._summary = DictSummary()
deepspeech/training/triggers/time_trigger.py
浏览文件 @
9e2773df
...
...
@@ -30,3 +30,12 @@ class TimeTrigger():
return
True
else
:
return
False
def state_dict(self):
    """Return the trigger state as a dict with the single key "next_time"."""
    return {"next_time": self._next_time}
def set_state_dict(self, state_dict):
    """Restore the trigger state from a dict carrying a "next_time" entry."""
    restored_time = state_dict['next_time']
    self._next_time = restored_time
deepspeech/training/triggers/utils.py
0 → 100644
浏览文件 @
9e2773df
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.interval_trigger
import
IntervalTrigger
def never_fail_trigger(trainer):
    """Trigger that never fires: `trainer` is ignored and False is returned."""
    return False
def get_trigger(trigger):
    """Normalize a trigger specification into a callable trigger.

    Args:
        trigger: None, a callable, or an iterable of positional arguments
            for `IntervalTrigger`.

    Returns:
        `never_fail_trigger` for None, the callable itself when one is given,
        otherwise ``IntervalTrigger(*trigger)``.
    """
    if trigger is None:
        return never_fail_trigger
    if not callable(trigger):
        # Anything else is unpacked as IntervalTrigger's arguments.
        return IntervalTrigger(*trigger)
    return trigger
deepspeech/utils/asr_utils.py
浏览文件 @
9e2773df
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
numpy
as
np
__all__
=
[
"label_smoothing_dist"
]
...
...
@@ -33,8 +34,7 @@ def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
if
lsm_type
==
"unigram"
:
assert
transcript
is
not
None
,
(
"transcript is required for %s label smoothing"
%
lsm_type
)
"transcript is required for %s label smoothing"
%
lsm_type
)
labelcount
=
np
.
zeros
(
odim
)
for
k
,
v
in
trans_json
.
items
():
ids
=
np
.
array
([
int
(
n
)
for
n
in
v
[
"output"
][
0
][
"tokenid"
].
split
()])
...
...
deepspeech/utils/bleu_score.py
浏览文件 @
9e2773df
...
...
@@ -14,9 +14,9 @@
"""This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level.
"""
import
sacrebleu
import
nltk
import
numpy
as
np
import
sacrebleu
__all__
=
[
'bleu'
,
'char_bleu'
,
"ErrorCalculator"
]
...
...
@@ -106,11 +106,14 @@ class ErrorCalculator():
# NOTE: padding index (-1) in y_true is used to pad y_hat
# because y_hats is not padded with -1
seq_hat
=
[
self
.
char_list
[
int
(
idx
)]
for
idx
in
y_hat
[:
ymax
]]
seq_true
=
[
self
.
char_list
[
int
(
idx
)]
for
idx
in
y_true
if
int
(
idx
)
!=
-
1
]
seq_true
=
[
self
.
char_list
[
int
(
idx
)]
for
idx
in
y_true
if
int
(
idx
)
!=
-
1
]
seq_hat_text
=
""
.
join
(
seq_hat
).
replace
(
self
.
space
,
" "
)
seq_hat_text
=
seq_hat_text
.
replace
(
self
.
pad
,
""
)
seq_true_text
=
""
.
join
(
seq_true
).
replace
(
self
.
space
,
" "
)
seqs_hat
.
append
(
seq_hat_text
)
seqs_true
.
append
(
seq_true_text
)
bleu
=
nltk
.
bleu_score
.
corpus_bleu
([[
ref
]
for
ref
in
seqs_true
],
seqs_hat
)
bleu
=
nltk
.
bleu_score
.
corpus_bleu
([[
ref
]
for
ref
in
seqs_true
],
seqs_hat
)
return
bleu
*
100
deepspeech/utils/error_rate.py
浏览文件 @
9e2773df
...
...
@@ -14,11 +14,10 @@
"""This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level.
"""
from
itertools
import
groupby
import
editdistance
import
numpy
as
np
import
logging
import
sys
from
itertools
import
groupby
__all__
=
[
'word_errors'
,
'char_errors'
,
'wer'
,
'cer'
,
"ErrorCalculator"
]
...
...
@@ -225,9 +224,12 @@ class ErrorCalculator():
:return:
"""
def
__init__
(
self
,
char_list
,
sym_space
,
sym_blank
,
report_cer
=
False
,
report_wer
=
False
):
def
__init__
(
self
,
char_list
,
sym_space
,
sym_blank
,
report_cer
=
False
,
report_wer
=
False
):
"""Construct an ErrorCalculator object."""
super
().
__init__
()
...
...
@@ -317,7 +319,9 @@ class ErrorCalculator():
ymax
=
eos_true
[
0
]
if
len
(
eos_true
)
>
0
else
len
(
y_true
)
# NOTE: padding index (-1) in y_true is used to pad y_hat
seq_hat
=
[
self
.
char_list
[
int
(
idx
)]
for
idx
in
y_hat
[:
ymax
]]
seq_true
=
[
self
.
char_list
[
int
(
idx
)]
for
idx
in
y_true
if
int
(
idx
)
!=
-
1
]
seq_true
=
[
self
.
char_list
[
int
(
idx
)]
for
idx
in
y_true
if
int
(
idx
)
!=
-
1
]
seq_hat_text
=
""
.
join
(
seq_hat
).
replace
(
self
.
space
,
" "
)
seq_hat_text
=
seq_hat_text
.
replace
(
self
.
blank
,
""
)
seq_true_text
=
""
.
join
(
seq_true
).
replace
(
self
.
space
,
" "
)
...
...
setup.py
浏览文件 @
9e2773df
...
...
@@ -15,7 +15,6 @@ import contextlib
import
inspect
import
io
import
os
import
re
import
subprocess
as
sp
import
sys
from
pathlib
import
Path
...
...
@@ -84,7 +83,7 @@ def _post_install(install_lib_dir):
tools_extrs_dir
=
HERE
/
'tools/extras'
with
pushd
(
tools_extrs_dir
):
print
(
os
.
getcwd
())
check_call
(
f
"./install_autolog.sh"
)
check_call
(
"./install_autolog.sh"
)
print
(
"autolog install."
)
# ctcdecoder
...
...
utils/json2trn.py
浏览文件 @
9e2773df
...
...
@@ -4,7 +4,6 @@
# 2018 Xuankai Chang (Shanghai Jiao Tong University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import
argparse
import
json
import
logging
import
sys
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录