PaddlePaddle / DeepSpeech
Commit 3fa2e44e
Authored Oct 25, 2021 by Hui Zhang

    more interface, trigger, extension

Parent: c0295aa1

Showing 12 changed files with 588 additions and 21 deletions (+588 -21)
deepspeech/models/asr_interface.py                       +3    -3
deepspeech/models/lm_interface.py                        +1    -1
deepspeech/models/st_interface.py                        +59   -0
deepspeech/models/u2_st/__init__.py                      +2    -0
deepspeech/models/u2_st/u2_st.py                         +0    -0
deepspeech/training/extensions/plot.py                   +437  -0
deepspeech/training/extensions/visualizer.py             +1    -1
deepspeech/training/triggers/__init__.py                 +0    -15
deepspeech/training/triggers/compare_value_trigger.py    +60   -0
deepspeech/training/triggers/limit_trigger.py            +1    -1
deepspeech/training/triggers/time_trigger.py             +9    -0
deepspeech/training/triggers/utils.py                    +15   -0
deepspeech/models/asr_interface.py
@@ -18,7 +18,7 @@ from deepspeech.utils.dynamic_import import dynamic_import

 class ASRInterface:
-    """ASR Interface for ESPnet model implementation."""
+    """ASR Interface model implementation."""

     @staticmethod
     def add_arguments(parser):
@@ -103,14 +103,14 @@ class ASRInterface:
     @property
     def attention_plot_class(self):
         """Get attention plot class."""
-        from espnet.asr.asr_utils import PlotAttentionReport
+        from deepspeech.training.extensions.plot import PlotAttentionReport
         return PlotAttentionReport

     @property
     def ctc_plot_class(self):
         """Get CTC plot class."""
-        from espnet.asr.asr_utils import PlotCTCReport
+        from deepspeech.training.extensions.plot import PlotCTCReport
         return PlotCTCReport
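With this hunk the plot classes resolve from deepspeech's own extensions package rather than ESPnet, so attention/CTC visualization no longer needs an espnet installation. A minimal sketch of what a subclass now picks up; `MyASRModel` is a hypothetical model for illustration only:

    from deepspeech.models.asr_interface import ASRInterface

    class MyASRModel(ASRInterface):
        # hypothetical subclass; real models implement the interface methods
        pass

    # Resolves to deepspeech.training.extensions.plot.PlotAttentionReport
    # via the property rewritten above, with no espnet import involved.
    print(MyASRModel().attention_plot_class)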
deepspeech/models/lm_interface.py
@@ -6,7 +6,7 @@ from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
 from deepspeech.utils.dynamic_import import dynamic_import

 class LMInterface(ScorerInterface):
-    """LM Interface for ESPnet model implementation."""
+    """LM Interface model implementation."""

     @staticmethod
     def add_arguments(parser):
deepspeech/models/st_interface.py  0 → 100644

"""ST Interface module."""
import argparse

from deepspeech.utils.dynamic_import import dynamic_import

from .asr_interface import ASRInterface


class STInterface(ASRInterface):
    """ST Interface model implementation.

    NOTE: This class is inherited from ASRInterface to enable joint translation
    and recognition when performing multi-task learning with the ASR task.
    """

    def translate(self,
                  x,
                  trans_args,
                  char_list=None,
                  rnnlm=None,
                  ensemble_models=[]):
        """Recognize x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("translate method is not implemented")

    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("Batch decoding is not supported yet.")


predefined_st = {
    "transformer": "deepspeech.models.u2_st:U2STModel",
}


def dynamic_import_st(module):
    """Import ST models dynamically.

    Args:
        module (str): module_name:class_name or alias in `predefined_st`

    Returns:
        type: ST class
    """
    model_class = dynamic_import(module, predefined_st)
    assert issubclass(model_class,
                      STInterface), f"{module} does not implement STInterface"
    return model_class
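A short usage sketch for `dynamic_import_st`: it accepts either the `"transformer"` alias from `predefined_st` or a fully qualified `module:class` path, and the trailing assert guarantees the resolved class implements `STInterface` (this assumes `U2STModel` does subclass `STInterface`, which the assert itself enforces at resolution time):

    from deepspeech.models.st_interface import dynamic_import_st

    # Alias form: looked up in predefined_st, then imported dynamically.
    st_class = dynamic_import_st("transformer")

    # Fully qualified form: equivalent, bypassing the alias table.
    st_class = dynamic_import_st("deepspeech.models.u2_st:U2STModel")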
deepspeech/models/u2_st/__init__.py  0 → 100644

from .u2_st import U2STModel
from .u2_st import U2STInferModel
\ No newline at end of file
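Because u2_st.py moves into a package of the same name (next hunk) and the new `__init__.py` re-exports both classes, existing import sites keep resolving unchanged:

    # Both forms resolve to the same class after the move:
    from deepspeech.models.u2_st import U2STModel        # package re-export
    from deepspeech.models.u2_st.u2_st import U2STModel  # explicit module path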
deepspeech/models/u2_st.py → deepspeech/models/u2_st/u2_st.py
File moved.
deepspeech/training/extensions/plot.py  0 → 100644

import argparse
import copy
import json
import os
import shutil
import tempfile

import numpy as np

from . import extension
from ..updaters.trainer import Trainer


class PlotAttentionReport(extension.Extension):
    """Plot attention reporter.

    Args:
        att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
            Function of attention visualization.
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
        outdir (str): Directory to save figures.
        converter (espnet.asr.*_backend.asr.CustomConverter):
            Function to convert data.
        device (int | torch.device): Device.
        reverse (bool): If True, input and output length are reversed.
        ikey (str): Key to access input
            (for ASR/ST ikey="input", for MT ikey="output".)
        iaxis (int): Dimension to access input
            (for ASR/ST iaxis=0, for MT iaxis=1.)
        okey (str): Key to access output
            (for ASR/ST okey="output", for MT okey="output".)
        oaxis (int): Dimension to access output
            (for ASR/ST oaxis=0, for MT oaxis=0.)
        subsampling_factor (int): subsampling factor in encoder

    """

    def __init__(self,
                 att_vis_fn,
                 data,
                 outdir,
                 converter,
                 transform,
                 device,
                 reverse=False,
                 ikey="input",
                 iaxis=0,
                 okey="output",
                 oaxis=0,
                 subsampling_factor=1):
        self.att_vis_fn = att_vis_fn
        self.data = copy.deepcopy(data)
        # key is utterance ID
        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        self.factor = subsampling_factor
        if not os.path.exists(self.outdir):
            os.makedirs(self.outdir)

    def __call__(self, trainer):
        """Plot and save image file of att_ws matrix."""
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
                        self.outdir, uttid_list[idx], i + 1)
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
                        self.outdir, uttid_list[idx], i + 1)
                    np.save(np_filename.format(trainer), att_w)
                    self._plot_and_save_attention(att_w,
                                                  filename.format(trainer))
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
                    self.outdir, uttid_list[idx])
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
                    self.outdir, uttid_list[idx])
                np.save(np_filename.format(trainer), att_w)
                self._plot_and_save_attention(
                    att_w, filename.format(trainer), han_mode=True)
        else:
            for idx, att_w in enumerate(att_ws):
                filename = "%s/%s.ep.{.updater.epoch}.png" % (
                    self.outdir, uttid_list[idx])
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                    self.outdir, uttid_list[idx])
                np.save(np_filename.format(trainer), att_w)
                self._plot_and_save_attention(att_w, filename.format(trainer))

    def log_attentions(self, logger, step):
        """Add image files of att_ws matrix to the tensorboard."""
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_attention_plot(att_w)
                    logger.add_figure("%s_att%d" % (uttid_list[idx], i + 1),
                                      plot.gcf(), step)
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_han_plot(att_w)
                logger.add_figure("%s_han" % (uttid_list[idx]),
                                  plot.gcf(), step)
        else:
            for idx, att_w in enumerate(att_ws):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_attention_plot(att_w)
                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

    def get_attention_weights(self):
        """Return attention weights.

        Returns:
            numpy.ndarray: attention weights (float). The shape depends on
                the backend:
                * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax),
                  2) other case => (B, Lmax, Tmax).
                * chainer-> (B, Lmax, Tmax)

        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        else:
            att_ws = self.att_vis_fn(**batch)
        return att_ws, uttid_list

    def trim_attention_weight(self, uttid, att_w):
        """Transform attention matrix with regard to self.reverse."""
        if self.reverse:
            enc_key, enc_axis = self.okey, self.oaxis
            dec_key, dec_axis = self.ikey, self.iaxis
        else:
            enc_key, enc_axis = self.ikey, self.iaxis
            dec_key, dec_axis = self.okey, self.oaxis
        dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
        enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
        if self.factor > 1:
            enc_len //= self.factor
        if len(att_w.shape) == 3:
            att_w = att_w[:, :dec_len, :enc_len]
        else:
            att_w = att_w[:dec_len, :enc_len]
        return att_w

    def draw_attention_plot(self, att_w):
        """Plot the att_w matrix.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.

        """
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        plt.clf()
        att_w = att_w.astype(np.float32)
        if len(att_w.shape) == 3:
            for h, aw in enumerate(att_w, 1):
                plt.subplot(1, len(att_w), h)
                plt.imshow(aw, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
        else:
            plt.imshow(att_w, aspect="auto")
            plt.xlabel("Encoder Index")
            plt.ylabel("Decoder Index")
        plt.tight_layout()
        return plt

    def draw_han_plot(self, att_w):
        """Plot the att_w matrix for hierarchical attention.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.

        """
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        plt.clf()
        if len(att_w.shape) == 3:
            for h, aw in enumerate(att_w, 1):
                legends = []
                plt.subplot(1, len(att_w), h)
                for i in range(aw.shape[1]):
                    plt.plot(aw[:, i])
                    legends.append("Att{}".format(i))
                plt.ylim([0, 1.0])
                plt.xlim([0, aw.shape[0]])
                plt.grid(True)
                plt.ylabel("Attention Weight")
                plt.xlabel("Decoder Index")
                plt.legend(legends)
        else:
            legends = []
            for i in range(att_w.shape[1]):
                plt.plot(att_w[:, i])
                legends.append("Att{}".format(i))
            plt.ylim([0, 1.0])
            plt.xlim([0, att_w.shape[0]])
            plt.grid(True)
            plt.ylabel("Attention Weight")
            plt.xlabel("Decoder Index")
            plt.legend(legends)
        plt.tight_layout()
        return plt

    def _plot_and_save_attention(self, att_w, filename, han_mode=False):
        if han_mode:
            plt = self.draw_han_plot(att_w)
        else:
            plt = self.draw_attention_plot(att_w)
        plt.savefig(filename)
        plt.close()


class PlotCTCReport(extension.Extension):
    """Plot CTC reporter.

    Args:
        ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
            Function of CTC visualization.
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
        outdir (str): Directory to save figures.
        converter (espnet.asr.*_backend.asr.CustomConverter):
            Function to convert data.
        device (int | torch.device): Device.
        reverse (bool): If True, input and output length are reversed.
        ikey (str): Key to access input
            (for ASR/ST ikey="input", for MT ikey="output".)
        iaxis (int): Dimension to access input
            (for ASR/ST iaxis=0, for MT iaxis=1.)
        okey (str): Key to access output
            (for ASR/ST okey="output", for MT okey="output".)
        oaxis (int): Dimension to access output
            (for ASR/ST oaxis=0, for MT oaxis=0.)
        subsampling_factor (int): subsampling factor in encoder

    """

    def __init__(self,
                 ctc_vis_fn,
                 data,
                 outdir,
                 converter,
                 transform,
                 device,
                 reverse=False,
                 ikey="input",
                 iaxis=0,
                 okey="output",
                 oaxis=0,
                 subsampling_factor=1):
        self.ctc_vis_fn = ctc_vis_fn
        self.data = copy.deepcopy(data)
        # key is utterance ID
        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        self.factor = subsampling_factor
        if not os.path.exists(self.outdir):
            os.makedirs(self.outdir)

    def __call__(self, trainer):
        """Plot and save image file of ctc prob."""
        ctc_probs, uttid_list = self.get_ctc_probs()
        if isinstance(ctc_probs, list):  # multi-encoder case
            num_encs = len(ctc_probs) - 1
            for i in range(num_encs):
                for idx, ctc_prob in enumerate(ctc_probs[i]):
                    filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
                        self.outdir, uttid_list[idx], i + 1)
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
                        self.outdir, uttid_list[idx], i + 1)
                    np.save(np_filename.format(trainer), ctc_prob)
                    self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                filename = "%s/%s.ep.{.updater.epoch}.png" % (
                    self.outdir, uttid_list[idx])
                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                    self.outdir, uttid_list[idx])
                np.save(np_filename.format(trainer), ctc_prob)
                self._plot_and_save_ctc(ctc_prob, filename.format(trainer))

    def log_ctc_probs(self, logger, step):
        """Add image files of ctc probs to the tensorboard."""
        ctc_probs, uttid_list = self.get_ctc_probs()
        if isinstance(ctc_probs, list):  # multi-encoder case
            num_encs = len(ctc_probs) - 1
            for i in range(num_encs):
                for idx, ctc_prob in enumerate(ctc_probs[i]):
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    plot = self.draw_ctc_plot(ctc_prob)
                    logger.add_figure("%s_ctc%d" % (uttid_list[idx], i + 1),
                                      plot.gcf(), step)
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                plot = self.draw_ctc_plot(ctc_prob)
                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

    def get_ctc_probs(self):
        """Return CTC probs.

        Returns:
            numpy.ndarray: CTC probs (float), of shape (B, Tmax, vocab).

        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            probs = self.ctc_vis_fn(*batch)
        else:
            probs = self.ctc_vis_fn(**batch)
        return probs, uttid_list

    def trim_ctc_prob(self, uttid, prob):
        """Trim CTC posteriors according to input lengths."""
        enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
        if self.factor > 1:
            enc_len //= self.factor
        prob = prob[:enc_len]
        return prob

    def draw_ctc_plot(self, ctc_prob):
        """Plot the ctc_prob matrix.

        Returns:
            matplotlib.pyplot: pyplot object with CTC prob matrix image.

        """
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        ctc_prob = ctc_prob.astype(np.float32)
        plt.clf()
        topk_ids = np.argsort(ctc_prob, axis=1)
        n_frames, vocab = ctc_prob.shape
        times_probs = np.arange(n_frames)
        plt.figure(figsize=(20, 8))
        # NOTE: index 0 is reserved for blank
        for idx in set(topk_ids.reshape(-1).tolist()):
            if idx == 0:
                plt.plot(times_probs, ctc_prob[:, 0], ":", label="<blank>",
                         color="grey")
            else:
                plt.plot(times_probs, ctc_prob[:, idx])
        plt.xlabel(u"Input [frame]", fontsize=12)
        plt.ylabel("Posteriors", fontsize=12)
        plt.xticks(list(range(0, int(n_frames) + 1, 10)))
        plt.yticks(list(range(0, 2, 1)))
        plt.tight_layout()
        return plt

    def _plot_and_save_ctc(self, ctc_prob, filename):
        plt = self.draw_ctc_plot(ctc_prob)
        plt.savefig(filename)
        plt.close()
\ No newline at end of file
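A hedged sketch of wiring PlotAttentionReport into a training run. Every argument below is a placeholder, since the concrete visualization hook, converter, and transform depend on the model and data pipeline in use, and the `trainer.extend(...)` call assumes the chainer-style extension API that the `extension.Extension` base class implies:

    from deepspeech.training.extensions.plot import PlotAttentionReport

    def attach_attention_report(trainer, model, dev_data, converter, transform):
        """Hypothetical helper: dump attention plots for a held-out batch."""
        reporter = PlotAttentionReport(
            att_vis_fn=model.calculate_all_attentions,  # assumed model hook
            data=dev_data,        # list of (uttid, info_dict) pairs
            outdir="exp/att_ws",  # figures land here as .png/.npy
            converter=converter,
            transform=transform,  # must accept return_uttid=True
            device="cpu")
        # Fire once per epoch; reporter(trainer) can also be called directly.
        trainer.extend(reporter, trigger=(1, "epoch"))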
deepspeech/training/extensions/visualizer.py
@@ -36,4 +36,4 @@ class VisualDL(extension.Extension):
         self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)

     def finalize(self, trainer):
-        self.writer.close()
+        self.writer.close()
\ No newline at end of file
deepspeech/training/triggers/__init__.py
@@ -11,18 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .interval_trigger import IntervalTrigger
-
-
-def never_fail_trigger(trainer):
-    return False
-
-
-def get_trigger(trigger):
-    if trigger is None:
-        return never_fail_trigger
-    if callable(trigger):
-        return trigger
-    else:
-        trigger = IntervalTrigger(*trigger)
-        return trigger
deepspeech/training/triggers/compare_value_trigger.py  0 → 100644

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import get_trigger
from ..reporter import DictSummary


class CompareValueTrigger():
    """Trigger invoked when a key's value becomes bigger or smaller than before.

    Args:
        key (str) : Key of value.
        compare_fn ((float, float) -> bool) : Function to compare the values.
        trigger (tuple(int, str)) : Trigger that decides the comparison interval.

    """

    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
        self._key = key
        self._best_value = None
        self._interval_trigger = get_trigger(trigger)
        self._init_summary()
        self._compare_fn = compare_fn

    def __call__(self, trainer):
        """Get value related to the key and compare with current value."""
        observation = trainer.observation
        summary = self._summary
        key = self._key
        if key in observation:
            summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = summary.compute_mean()
        value = float(stats[key])  # copy to CPU
        self._init_summary()

        if self._best_value is None:
            # initialize best value
            self._best_value = value
            return False
        elif self._compare_fn(self._best_value, value):
            return True
        else:
            self._best_value = value
            return False

    def _init_summary(self):
        self._summary = DictSummary()
\ No newline at end of file
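The typical use of this trigger is learning-rate decay once a validation metric stops improving, checked on an epoch interval. A minimal sketch, assuming the trainer exposes an `observation` dict containing an "eval/loss" entry (key name and optimizer API are illustrative):

    from deepspeech.training.triggers.compare_value_trigger import CompareValueTrigger

    # Fires when the freshly averaged validation loss is worse (higher)
    # than the best value recorded so far.
    lr_decay_trigger = CompareValueTrigger(
        "eval/loss",
        lambda best_value, current_value: best_value < current_value,
        trigger=(1, "epoch"))

    # In the training loop:
    # if lr_decay_trigger(trainer):
    #     optimizer.set_lr(optimizer.get_lr() * 0.5)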
deepspeech/training/triggers/limit_trigger.py
@@ -28,4 +28,4 @@ class LimitTrigger():
         state = trainer.updater.state
         index = getattr(state, self.unit)
         fire = index >= self.limit
-        return fire
+        return fire
\ No newline at end of file
deepspeech/training/triggers/time_trigger.py
@@ -30,3 +30,12 @@ class TimeTrigger():
             return True
         else:
             return False
+
+    def state_dict(self):
+        state_dict = {
+            "next_time": self._next_time,
+        }
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        self._next_time = state_dict['next_time']
\ No newline at end of file
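The new state_dict/set_state_dict pair makes the timer resumable: a restored run carries its `_next_time` forward instead of restarting the wall-clock schedule. A small sketch of persisting it alongside a checkpoint (the file layout is illustrative):

    import json

    def save_trigger_state(trigger, path):
        # Persist {"next_time": ...} next to the model checkpoint.
        with open(path, "w") as f:
            json.dump(trigger.state_dict(), f)

    def load_trigger_state(trigger, path):
        with open(path) as f:
            trigger.set_state_dict(json.load(f))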
deepspeech/training/triggers/utils.py  0 → 100644

from .interval_trigger import IntervalTrigger


def never_fail_trigger(trainer):
    return False


def get_trigger(trigger):
    if trigger is None:
        return never_fail_trigger
    if callable(trigger):
        return trigger
    else:
        trigger = IntervalTrigger(*trigger)
        return trigger
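`get_trigger` (moved here from triggers/__init__.py) normalizes the three accepted trigger specifications: a `(period, unit)` tuple, a callable, or None. A short sketch (the warmup threshold is illustrative):

    from deepspeech.training.triggers.utils import get_trigger

    # A (period, unit) tuple becomes an IntervalTrigger.
    every_epoch = get_trigger((1, "epoch"))

    # A callable passes through unchanged.
    def warmup_done(trainer):
        return trainer.updater.state.iteration >= 25000
    assert get_trigger(warmup_done) is warmup_done

    # None yields a trigger that never fires.
    never = get_trigger(None)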