PaddlePaddle / DeepSpeech

Commit 9e2773df
Authored on Oct 25, 2021 by huangyuxin

    tiny change, not important

Parents: 418d85ef, e4ecfb22

Showing 23 changed files with 671 additions and 52 deletions (+671, -52)
deepspeech/__init__.py                                   +0    -3
deepspeech/decoders/recog.py                             +0    -4
deepspeech/decoders/recog_bin.py                         +0    -2
deepspeech/models/asr_interface.py                       +3    -3
deepspeech/models/lm/__init__.py                         +13   -0
deepspeech/models/lm_interface.py                        +18   -5
deepspeech/models/st_interface.py                        +75   -0
deepspeech/models/u2_st/__init__.py                      +15   -0
deepspeech/models/u2_st/u2_st.py                         +0    -0
deepspeech/modules/ctc.py                                +2    -1
deepspeech/modules/embedding.py                          +4    -1
deepspeech/modules/encoder.py                            +1    -1
deepspeech/modules/subsampling.py                        +4    -2
deepspeech/training/extensions/plot.py                   +418  -0
deepspeech/training/triggers/__init__.py                 +0    -15
deepspeech/training/triggers/compare_value_trigger.py    +61   -0
deepspeech/training/triggers/time_trigger.py             +9    -0
deepspeech/training/triggers/utils.py                    +28   -0
deepspeech/utils/asr_utils.py                            +2    -2
deepspeech/utils/bleu_score.py                           +6    -3
deepspeech/utils/error_rate.py                           +11   -7
setup.py                                                 +1    -2
utils/json2trn.py                                        +0    -1
deepspeech/__init__.py
@@ -355,7 +355,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)

 ########### hack paddle.nn #############
 from paddle.nn import Layer
 from typing import Optional
 ...
@@ -506,5 +505,3 @@ if not hasattr(paddle.nn, 'LayerDict'):
     logger.debug(
         "register user LayerDict to paddle.nn, remove this when fixed!")
     setattr(paddle.nn, 'LayerDict', LayerDict)
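Note: the surrounding file is a compatibility shim that backfills missing paddle APIs at import time. A minimal self-contained sketch of the hasattr/setattr pattern visible in these hunks (Foo stands in for paddle.Tensor; nothing below is part of the commit itself):

    class Foo:
        pass

    def tolist(self):
        return []  # placeholder body for the sketch

    # Register the attribute only when the framework does not provide it yet.
    if not hasattr(Foo, 'tolist'):
        setattr(Foo, 'tolist', tolist)

    assert Foo().tolist() == []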
deepspeech/decoders/recog.py
@@ -12,12 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
-import json
-from pathlib import Path
 import jsonlines
 import paddle
-import yaml
 from yacs.config import CfgNode

 from .beam_search import BatchBeamSearch
 ...
deepspeech/decoders/recog_bin.py
@@ -21,8 +21,6 @@ from distutils.util import strtobool
 import configargparse
 import numpy as np

-from .recog import recog_v2


 def get_parser():
     """Get default arguments."""
 ...
deepspeech/models/asr_interface.py
@@ -18,7 +18,7 @@ from deepspeech.utils.dynamic_import import dynamic_import
 class ASRInterface:
-    """ASR Interface for ESPnet model implementation."""
+    """ASR Interface model implementation."""

     @staticmethod
     def add_arguments(parser):
 ...
@@ -103,14 +103,14 @@ class ASRInterface:
     @property
     def attention_plot_class(self):
         """Get attention plot class."""
-        from espnet.asr.asr_utils import PlotAttentionReport
+        from deepspeech.training.extensions.plot import PlotAttentionReport
         return PlotAttentionReport

     @property
     def ctc_plot_class(self):
         """Get CTC plot class."""
-        from espnet.asr.asr_utils import PlotCTCReport
+        from deepspeech.training.extensions.plot import PlotCTCReport
         return PlotCTCReport
 ...
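Note: the net effect of this file's changes is that both plot classes now resolve to the in-repo deepspeech.training.extensions.plot module (added later in this same commit), removing the runtime dependency on espnet.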
deepspeech/models/lm/__init__.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech/models/lm_interface.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language model interface."""
"""Language model interface."""
import
argparse
import
argparse
from
deepspeech.decoders.scorers.scorer_interface
import
ScorerInterface
from
deepspeech.decoders.scorers.scorer_interface
import
ScorerInterface
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.utils.dynamic_import
import
dynamic_import
class
LMInterface
(
ScorerInterface
):
class
LMInterface
(
ScorerInterface
):
"""LM Interface
for ESPnet
model implementation."""
"""LM Interface model implementation."""
@
staticmethod
@
staticmethod
def
add_arguments
(
parser
):
def
add_arguments
(
parser
):
...
@@ -52,6 +65,7 @@ predefined_lms = {
...
@@ -52,6 +65,7 @@ predefined_lms = {
"transformer"
:
"deepspeech.models.lm.transformer:TransformerLM"
,
"transformer"
:
"deepspeech.models.lm.transformer:TransformerLM"
,
}
}
def
dynamic_import_lm
(
module
):
def
dynamic_import_lm
(
module
):
"""Import LM class dynamically.
"""Import LM class dynamically.
...
@@ -63,7 +77,6 @@ def dynamic_import_lm(module):
...
@@ -63,7 +77,6 @@ def dynamic_import_lm(module):
"""
"""
model_class
=
dynamic_import
(
module
,
predefined_lms
)
model_class
=
dynamic_import
(
module
,
predefined_lms
)
assert
issubclass
(
assert
issubclass
(
model_class
,
model_class
,
LMInterface
LMInterface
),
f
"
{
module
}
does not implement LMInterface"
),
f
"
{
module
}
does not implement LMInterface"
return
model_class
return
model_class
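Note: a sketch of how the registry above is meant to be used ("transformer" is the alias defined in predefined_lms; resolving either form requires the target module to be importable in your environment):

    from deepspeech.models.lm_interface import dynamic_import_lm

    lm_class = dynamic_import_lm("transformer")
    # or a fully qualified "module:Class" path:
    lm_class = dynamic_import_lm("deepspeech.models.lm.transformer:TransformerLM")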
deepspeech/models/st_interface.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module."""
from
.asr_interface
import
ASRInterface
from
deepspeech.utils.dynamic_import
import
dynamic_import
class
STInterface
(
ASRInterface
):
"""ST Interface model implementation.
NOTE: This class is inherited from ASRInterface to enable joint translation
and recognition when performing multi-task learning with the ASR task.
"""
def
translate
(
self
,
x
,
trans_args
,
char_list
=
None
,
rnnlm
=
None
,
ensemble_models
=
[]):
"""Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
:param namespace trans_args: argment namespace contraining options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise
NotImplementedError
(
"translate method is not implemented"
)
def
translate_batch
(
self
,
x
,
trans_args
,
char_list
=
None
,
rnnlm
=
None
):
"""Beam search implementation for batch.
:param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace trans_args: argument namespace containing options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise
NotImplementedError
(
"Batch decoding is not supported yet."
)
predefined_st
=
{
"transformer"
:
"deepspeech.models.u2_st:U2STModel"
,
}
def
dynamic_import_st
(
module
):
"""Import ST models dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_st`
Returns:
type: ST class
"""
model_class
=
dynamic_import
(
module
,
predefined_st
)
assert
issubclass
(
model_class
,
STInterface
),
f
"
{
module
}
does not implement STInterface"
return
model_class
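Note: a minimal sketch of what a concrete ST model has to provide to satisfy this interface (the class name and the n-best dict keys below are illustrative, not part of this commit):

    from deepspeech.models.st_interface import STInterface

    class MyToyST(STInterface):
        def translate(self, x, trans_args, char_list=None, rnnlm=None,
                      ensemble_models=[]):
            # Return an n-best list; the exact result schema is up to the model.
            return [{"yseq": [], "score": 0.0}]

    assert issubclass(MyToyST, STInterface)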
deepspeech/models/u2_st/__init__.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2_st import U2STInferModel
from .u2_st import U2STModel
deepspeech/models/u2_st.py → deepspeech/models/u2_st/u2_st.py
File moved.
deepspeech/modules/ctc.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union
+
 import paddle
 from paddle import nn
-from typing import Union
 from paddle.nn import functional as F
 from typeguard import check_argument_types
 ...
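Note: this hunk, like the import hunks in error_rate.py, bleu_score.py, setup.py, and utils/json2trn.py below, appears to only regroup imports into the standard-library / third-party / local ordering that isort enforces; no behavior changes.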
deepspeech/modules/embedding.py
@@ -22,7 +22,10 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()

 __all__ = ["NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"]


 class NoPositionalEncoding(nn.Layer):
     def __init__(self,
 ...
deepspeech/modules/encoder.py
@@ -103,7 +103,7 @@ class BaseEncoder(nn.Layer):
             pos_enc_class = PositionalEncoding
         elif pos_enc_layer_type == "rel_pos":
             pos_enc_class = RelPositionalEncoding
-        elif pos_enc_layer_type is "no_pos":
+        elif pos_enc_layer_type == "no_pos":
             pos_enc_class = NoPositionalEncoding
         else:
             raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
 ...
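Note: this one-character fix matters because `is` tests object identity, not equality; it only appears to work for short interned string literals, and CPython 3.8+ emits a SyntaxWarning for it. A quick demonstration:

    a = "no_pos"
    b = "".join(["no_", "pos"])   # equal text, built at runtime
    assert a == b                  # equality of contents: always True
    print(a is b)                  # identity: typically False in CPython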
deepspeech/modules/subsampling.py
@@ -60,8 +60,8 @@ class LinearNoSubsampling(BaseSubsampling):
         self.out = nn.Sequential(
             nn.Linear(idim, odim),
             nn.LayerNorm(odim, epsilon=1e-12),
             nn.Dropout(dropout_rate),
-            nn.ReLU(),)
+            nn.ReLU(), )
         self.right_context = 0
         self.subsampling_rate = 1
 ...
@@ -83,10 +83,12 @@ class LinearNoSubsampling(BaseSubsampling):
         x, pos_emb = self.pos_enc(x, offset)
         return x, pos_emb, x_mask


 class Conv2dSubsampling(BaseSubsampling):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)


 class Conv2dSubsampling4(Conv2dSubsampling):
     """Convolutional 2D subsampling (to 1/4 length)."""
 ...
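Note: for the "to 1/4 length" claim in the Conv2dSubsampling4 docstring, a sketch of the arithmetic, assuming the usual two 3x3 stride-2 convolutions (the conv layers themselves are not shown in this hunk):

    def conv_out(t, kernel=3, stride=2):
        # output length of one unpadded strided convolution
        return (t - kernel) // stride + 1

    T = 100
    print(conv_out(conv_out(T)))  # 24, i.e. roughly T // 4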
deepspeech/training/extensions/plot.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os

import numpy as np

from . import extension


class PlotAttentionReport(extension.Extension):
    """Plot attention reporter.

    Args:
        att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
            Function of attention visualization.
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
        outdir (str): Directory to save figures.
        converter (espnet.asr.*_backend.asr.CustomConverter):
            Function to convert data.
        device (int | torch.device): Device.
        reverse (bool): If True, input and output length are reversed.
        ikey (str): Key to access input
            (for ASR/ST ikey="input", for MT ikey="output".)
        iaxis (int): Dimension to access input
            (for ASR/ST iaxis=0, for MT iaxis=1.)
        okey (str): Key to access output
            (for ASR/ST okey="output", for MT okey="output".)
        oaxis (int): Dimension to access output
            (for ASR/ST oaxis=0, for MT oaxis=0.)
        subsampling_factor (int): subsampling factor in encoder

    """

    def __init__(self,
                 att_vis_fn,
                 data,
                 outdir,
                 converter,
                 transform,
                 device,
                 reverse=False,
                 ikey="input",
                 iaxis=0,
                 okey="output",
                 oaxis=0,
                 subsampling_factor=1, ):
        self.att_vis_fn = att_vis_fn
        self.data = copy.deepcopy(data)
        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
        # key is utterance ID
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        self.factor = subsampling_factor
        if not os.path.exists(self.outdir):
            os.makedirs(self.outdir)

    def __call__(self, trainer):
        """Plot and save image file of att_ws matrix."""
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
                        self.outdir, uttid_list[idx], i + 1, )
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
                        self.outdir, uttid_list[idx], i + 1, )
                    np.save(np_filename.format(trainer), att_w)
                    self._plot_and_save_attention(att_w,
                                                  filename.format(trainer))
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
                    self.outdir, uttid_list[idx], )
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
                    self.outdir, uttid_list[idx], )
                np.save(np_filename.format(trainer), att_w)
                self._plot_and_save_attention(
                    att_w, filename.format(trainer), han_mode=True)
        else:
            for idx, att_w in enumerate(att_ws):
                filename = "%s/%s.ep.{.updater.epoch}.png" % (
                    self.outdir, uttid_list[idx], )
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                    self.outdir, uttid_list[idx], )
                np.save(np_filename.format(trainer), att_w)
                self._plot_and_save_attention(att_w, filename.format(trainer))

    def log_attentions(self, logger, step):
        """Add image files of att_ws matrix to the tensorboard."""
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_attention_plot(att_w)
                    logger.add_figure(
                        "%s_att%d" % (uttid_list[idx], i + 1),
                        plot.gcf(), step, )
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_han_plot(att_w)
                logger.add_figure(
                    "%s_han" % (uttid_list[idx]), plot.gcf(), step, )
        else:
            for idx, att_w in enumerate(att_ws):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_attention_plot(att_w)
                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

    def get_attention_weights(self):
        """Return attention weights.

        Returns:
            numpy.ndarray: attention weights. float. Its shape differs
                by backend.
                * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
                  other case => (B, Lmax, Tmax).
                * chainer-> (B, Lmax, Tmax)

        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        else:
            att_ws = self.att_vis_fn(**batch)
        return att_ws, uttid_list

    def trim_attention_weight(self, uttid, att_w):
        """Transform attention matrix with regard to self.reverse."""
        if self.reverse:
            enc_key, enc_axis = self.okey, self.oaxis
            dec_key, dec_axis = self.ikey, self.iaxis
        else:
            enc_key, enc_axis = self.ikey, self.iaxis
            dec_key, dec_axis = self.okey, self.oaxis
        dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
        enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
        if self.factor > 1:
            enc_len //= self.factor
        if len(att_w.shape) == 3:
            att_w = att_w[:, :dec_len, :enc_len]
        else:
            att_w = att_w[:dec_len, :enc_len]
        return att_w

    def draw_attention_plot(self, att_w):
        """Plot the att_w matrix.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.

        """
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        plt.clf()
        att_w = att_w.astype(np.float32)
        if len(att_w.shape) == 3:
            for h, aw in enumerate(att_w, 1):
                plt.subplot(1, len(att_w), h)
                plt.imshow(aw, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
        else:
            plt.imshow(att_w, aspect="auto")
            plt.xlabel("Encoder Index")
            plt.ylabel("Decoder Index")
        plt.tight_layout()
        return plt

    def draw_han_plot(self, att_w):
        """Plot the att_w matrix for hierarchical attention.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.

        """
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        plt.clf()
        if len(att_w.shape) == 3:
            for h, aw in enumerate(att_w, 1):
                legends = []
                plt.subplot(1, len(att_w), h)
                for i in range(aw.shape[1]):
                    plt.plot(aw[:, i])
                    legends.append("Att{}".format(i))
                plt.ylim([0, 1.0])
                plt.xlim([0, aw.shape[0]])
                plt.grid(True)
                plt.ylabel("Attention Weight")
                plt.xlabel("Decoder Index")
                plt.legend(legends)
        else:
            legends = []
            for i in range(att_w.shape[1]):
                plt.plot(att_w[:, i])
                legends.append("Att{}".format(i))
            plt.ylim([0, 1.0])
            plt.xlim([0, att_w.shape[0]])
            plt.grid(True)
            plt.ylabel("Attention Weight")
            plt.xlabel("Decoder Index")
            plt.legend(legends)
        plt.tight_layout()
        return plt

    def _plot_and_save_attention(self, att_w, filename, han_mode=False):
        if han_mode:
            plt = self.draw_han_plot(att_w)
        else:
            plt = self.draw_attention_plot(att_w)
        plt.savefig(filename)
        plt.close()


class PlotCTCReport(extension.Extension):
    """Plot CTC reporter.

    Args:
        ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
            Function of CTC visualization.
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
        outdir (str): Directory to save figures.
        converter (espnet.asr.*_backend.asr.CustomConverter):
            Function to convert data.
        device (int | torch.device): Device.
        reverse (bool): If True, input and output length are reversed.
        ikey (str): Key to access input
            (for ASR/ST ikey="input", for MT ikey="output".)
        iaxis (int): Dimension to access input
            (for ASR/ST iaxis=0, for MT iaxis=1.)
        okey (str): Key to access output
            (for ASR/ST okey="output", for MT okey="output".)
        oaxis (int): Dimension to access output
            (for ASR/ST oaxis=0, for MT oaxis=0.)
        subsampling_factor (int): subsampling factor in encoder

    """

    def __init__(self,
                 ctc_vis_fn,
                 data,
                 outdir,
                 converter,
                 transform,
                 device,
                 reverse=False,
                 ikey="input",
                 iaxis=0,
                 okey="output",
                 oaxis=0,
                 subsampling_factor=1, ):
        self.ctc_vis_fn = ctc_vis_fn
        self.data = copy.deepcopy(data)
        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
        # key is utterance ID
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        self.factor = subsampling_factor
        if not os.path.exists(self.outdir):
            os.makedirs(self.outdir)

    def __call__(self, trainer):
        """Plot and save image file of ctc prob."""
        ctc_probs, uttid_list = self.get_ctc_probs()
        if isinstance(ctc_probs, list):  # multi-encoder case
            num_encs = len(ctc_probs) - 1
            for i in range(num_encs):
                for idx, ctc_prob in enumerate(ctc_probs[i]):
                    filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
                        self.outdir, uttid_list[idx], i + 1, )
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
                        self.outdir, uttid_list[idx], i + 1, )
                    np.save(np_filename.format(trainer), ctc_prob)
                    self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                filename = "%s/%s.ep.{.updater.epoch}.png" % (
                    self.outdir, uttid_list[idx], )
                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                    self.outdir, uttid_list[idx], )
                np.save(np_filename.format(trainer), ctc_prob)
                self._plot_and_save_ctc(ctc_prob, filename.format(trainer))

    def log_ctc_probs(self, logger, step):
        """Add image files of ctc probs to the tensorboard."""
        ctc_probs, uttid_list = self.get_ctc_probs()
        if isinstance(ctc_probs, list):  # multi-encoder case
            num_encs = len(ctc_probs) - 1
            for i in range(num_encs):
                for idx, ctc_prob in enumerate(ctc_probs[i]):
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    plot = self.draw_ctc_plot(ctc_prob)
                    logger.add_figure(
                        "%s_ctc%d" % (uttid_list[idx], i + 1),
                        plot.gcf(), step, )
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                plot = self.draw_ctc_plot(ctc_prob)
                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

    def get_ctc_probs(self):
        """Return CTC probs.

        Returns:
            numpy.ndarray: CTC probs. float. Its shape differs by backend.
                (B, Tmax, vocab).

        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            probs = self.ctc_vis_fn(*batch)
        else:
            probs = self.ctc_vis_fn(**batch)
        return probs, uttid_list

    def trim_ctc_prob(self, uttid, prob):
        """Trim CTC posteriors according to input lengths."""
        enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
        if self.factor > 1:
            enc_len //= self.factor
        prob = prob[:enc_len]
        return prob

    def draw_ctc_plot(self, ctc_prob):
        """Plot the ctc_prob matrix.

        Returns:
            matplotlib.pyplot: pyplot object with CTC prob matrix image.

        """
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        ctc_prob = ctc_prob.astype(np.float32)

        plt.clf()
        topk_ids = np.argsort(ctc_prob, axis=1)
        n_frames, vocab = ctc_prob.shape
        times_probs = np.arange(n_frames)
        plt.figure(figsize=(20, 8))

        # NOTE: index 0 is reserved for blank
        for idx in set(topk_ids.reshape(-1).tolist()):
            if idx == 0:
                plt.plot(times_probs, ctc_prob[:, 0], ":",
                         label="<blank>", color="grey")
            else:
                plt.plot(times_probs, ctc_prob[:, idx])
        plt.xlabel(u"Input [frame]", fontsize=12)
        plt.ylabel("Posteriors", fontsize=12)
        plt.xticks(list(range(0, int(n_frames) + 1, 10)))
        plt.yticks(list(range(0, 2, 1)))
        plt.tight_layout()
        return plt

    def _plot_and_save_ctc(self, ctc_prob, filename):
        plt = self.draw_ctc_plot(ctc_prob)
        plt.savefig(filename)
        plt.close()
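Note: a hypothetical wiring of the new extension; `model`, `dev_json_items`, `converter`, `transform`, and `trainer` are placeholders for objects your training setup provides, not names defined in this commit:

    from deepspeech.training.extensions.plot import PlotAttentionReport

    reporter = PlotAttentionReport(
        att_vis_fn=model.calculate_all_attentions,  # assumed model hook
        data=dev_json_items,       # list of (utt_id, info_dict) pairs
        outdir="exp/att_ws",
        converter=converter,
        transform=transform,
        device="cpu", )
    reporter(trainer)  # writes <utt>.ep.<epoch>.png and .npy under exp/att_ws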
deepspeech/training/triggers/__init__.py
@@ -11,18 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .interval_trigger import IntervalTrigger
-
-
-def never_fail_trigger(trainer):
-    return False
-
-
-def get_trigger(trigger):
-    if trigger is None:
-        return never_fail_trigger
-    if callable(trigger):
-        return trigger
-    else:
-        trigger = IntervalTrigger(*trigger)
-        return trigger
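Note: the helpers deleted here are not lost; the identical never_fail_trigger/get_trigger pair is re-added below as deepspeech/training/triggers/utils.py, so existing callers only need to update their import path.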
deepspeech/training/triggers/compare_value_trigger.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..reporter import DictSummary
from .utils import get_trigger


class CompareValueTrigger():
    """Trigger invoked when a key's value becomes bigger or smaller than before.

    Args:
        key (str) : Key of value.
        compare_fn ((float, float) -> bool) : Function to compare the values.
        trigger (tuple(int, str)) : Trigger that decides the comparison interval.

    """

    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
        self._key = key
        self._best_value = None
        self._interval_trigger = get_trigger(trigger)
        self._init_summary()
        self._compare_fn = compare_fn

    def __call__(self, trainer):
        """Get value related to the key and compare with current value."""
        observation = trainer.observation
        summary = self._summary
        key = self._key
        if key in observation:
            summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = summary.compute_mean()
        value = float(stats[key])  # copy to CPU
        self._init_summary()

        if self._best_value is None:
            # initialize best value
            self._best_value = value
            return False
        elif self._compare_fn(self._best_value, value):
            return True
        else:
            self._best_value = value
            return False

    def _init_summary(self):
        self._summary = DictSummary()
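Note: a hypothetical use, firing when validation loss stops improving. The key "eval/loss" and the trainer object are placeholders; the trigger reads trainer.observation, as the implementation above shows:

    from deepspeech.training.triggers.compare_value_trigger import CompareValueTrigger

    trigger = CompareValueTrigger(
        "eval/loss",
        compare_fn=lambda best, current: current > best,  # worse than best so far
        trigger=(1, "epoch"), )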
deepspeech/training/triggers/time_trigger.py
@@ -30,3 +30,12 @@ class TimeTrigger():
             return True
         else:
             return False
+
+    def state_dict(self):
+        state_dict = {
+            "next_time": self._next_time,
+        }
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        self._next_time = state_dict['next_time']
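Note: a sketch of why state_dict/set_state_dict were added; a checkpoint loop can now persist the trigger across restarts (the constructor call is illustrative, only the two new methods come from this commit):

    state = trigger.state_dict()        # {"next_time": ...}
    restored = TimeTrigger(period=60)   # hypothetical constructor arguments
    restored.set_state_dict(state)      # resume from the saved deadline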
deepspeech/training/triggers/utils.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger


def never_fail_trigger(trainer):
    return False


def get_trigger(trigger):
    if trigger is None:
        return never_fail_trigger
    if callable(trigger):
        return trigger
    else:
        trigger = IntervalTrigger(*trigger)
        return trigger
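Note: the normalization rules above, exercised in a sketch (IntervalTrigger semantics come from .interval_trigger, not this file):

    assert get_trigger(None)(None) is False      # None -> never fires
    f = lambda trainer: True
    assert get_trigger(f) is f                   # callables pass through
    t = get_trigger((1, "epoch"))                # tuple -> IntervalTrigger(1, "epoch")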
deepspeech/utils/asr_utils.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+
 import numpy as np

 __all__ = ["label_smoothing_dist"]
 ...
@@ -33,8 +34,7 @@ def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
     if lsm_type == "unigram":
         assert transcript is not None, (
-            "transcript is required for %s label smoothing" % lsm_type
-        )
+            "transcript is required for %s label smoothing" % lsm_type)
         labelcount = np.zeros(odim)
         for k, v in trans_json.items():
             ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
 ...
deepspeech/utils/bleu_score.py
@@ -14,9 +14,9 @@
 """This module provides functions to calculate bleu score in different level.
 e.g. wer for word-level, cer for char-level.
 """
-import sacrebleu
 import nltk
 import numpy as np
+import sacrebleu

 __all__ = ['bleu', 'char_bleu', "ErrorCalculator"]
 ...
@@ -106,11 +106,14 @@ class ErrorCalculator():
             # NOTE: padding index (-1) in y_true is used to pad y_hat
             # because y_hats is not padded with -1
             seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
             seq_true = [
                 self.char_list[int(idx)] for idx in y_true if int(idx) != -1
             ]
             seq_hat_text = "".join(seq_hat).replace(self.space, " ")
             seq_hat_text = seq_hat_text.replace(self.pad, "")
             seq_true_text = "".join(seq_true).replace(self.space, " ")
             seqs_hat.append(seq_hat_text)
             seqs_true.append(seq_true_text)
         bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], seqs_hat)
         return bleu * 100
deepspeech/utils/error_rate.py
@@ -14,11 +14,10 @@
 """This module provides functions to calculate error rate in different level.
 e.g. wer for word-level, cer for char-level.
 """
+from itertools import groupby
+
 import editdistance
 import numpy as np
-import logging
-import sys
-from itertools import groupby

 __all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"]
 ...
@@ -225,9 +224,12 @@ class ErrorCalculator():
     :return:
     """

     def __init__(self,
-                 char_list, sym_space, sym_blank, report_cer=False, report_wer=False):
+                 char_list,
+                 sym_space,
+                 sym_blank,
+                 report_cer=False,
+                 report_wer=False):
         """Construct an ErrorCalculator object."""
         super().__init__()
 ...
@@ -317,7 +319,9 @@ class ErrorCalculator():
         ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
         # NOTE: padding index (-1) in y_true is used to pad y_hat
         seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
         seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
         seq_hat_text = "".join(seq_hat).replace(self.space, " ")
         seq_hat_text = seq_hat_text.replace(self.blank, "")
         seq_true_text = "".join(seq_true).replace(self.space, " ")
 ...
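Note: a sketch of the word-level metric this module builds on. WER is edit distance over reference length, and editdistance (imported above) accepts token lists directly:

    import editdistance

    ref = "the quick brown fox".split()
    hyp = "the quik brown fox jumps".split()
    # one substitution (quick -> quik) plus one insertion (jumps) over 4 words
    wer = editdistance.eval(ref, hyp) / len(ref)
    print(wer)  # 0.5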
setup.py
@@ -15,7 +15,6 @@ import contextlib
 import inspect
 import io
 import os
-import re
 import subprocess as sp
 import sys
 from pathlib import Path
 ...
@@ -84,7 +83,7 @@ def _post_install(install_lib_dir):
     tools_extrs_dir = HERE / 'tools/extras'
     with pushd(tools_extrs_dir):
         print(os.getcwd())
-        check_call(f"./install_autolog.sh")
+        check_call("./install_autolog.sh")
         print("autolog install.")

     # ctcdecoder
 ...
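Note: the one-character change above removes an f-string prefix from a literal with no placeholders (the pattern flake8 flags as F541); behavior is unchanged.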
utils/json2trn.py
@@ -4,7 +4,6 @@
 #           2018 Xuankai Chang (Shanghai Jiao Tong University)
 #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 import argparse
-import json
 import logging
 import sys
 ...