PaddlePaddle / DeepSpeech
Commit e4ecfb22
Authored on Oct 25, 2021 by Hui Zhang

format code

Parent 3fa2e44e

Showing 24 changed files with 229 additions and 180 deletions (+229 / -180)
Changed files:

  deepspeech/__init__.py                                  +0   -3
  deepspeech/decoders/recog.py                            +9   -16
  deepspeech/decoders/recog_bin.py                        +0   -2
  deepspeech/models/asr_interface.py                      +1   -1
  deepspeech/models/lm/__init__.py                        +13  -0
  deepspeech/models/lm/transformer.py                     +32  -32
  deepspeech/models/lm_interface.py                       +17  -4
  deepspeech/models/st_interface.py                       +24  -8
  deepspeech/models/u2_st/__init__.py                     +14  -1
  deepspeech/modules/ctc.py                               +2   -1
  deepspeech/modules/embedding.py                         +4   -1
  deepspeech/modules/encoder.py                           +5   -5
  deepspeech/modules/subsampling.py                       +4   -2
  deepspeech/training/extensions/plot.py                  +64  -83
  deepspeech/training/extensions/visualizer.py            +1   -1
  deepspeech/training/triggers/compare_value_trigger.py   +4   -3
  deepspeech/training/triggers/limit_trigger.py           +1   -1
  deepspeech/training/triggers/time_trigger.py            +1   -1
  deepspeech/training/triggers/utils.py                   +13  -0
  deepspeech/utils/asr_utils.py                           +2   -2
  deepspeech/utils/bleu_score.py                          +6   -3
  deepspeech/utils/error_rate.py                          +11  -7
  setup.py                                                +1   -2
  utils/json2trn.py                                       +0   -1

deepspeech/__init__.py

@@ -355,7 +355,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
        "register user tolist to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'tolist', tolist)

########### hack paddle.nn #############
from paddle.nn import Layer
from typing import Optional

@@ -506,5 +505,3 @@ if not hasattr(paddle.nn, 'LayerDict'):
    logger.debug(
        "register user LayerDict to paddle.nn, remove this when fixed!")
    setattr(paddle.nn, 'LayerDict', LayerDict)

deepspeech/decoders/recog.py

@@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
import json
from pathlib import Path

import jsonlines
import paddle
import yaml
from yacs.config import CfgNode

from .beam_search import BatchBeamSearch

@@ -79,8 +75,7 @@ def recog_v2(args):
        sort_in_input_length=False,
        preprocess_conf=confs.collator.augmentation_config
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={"train": False}, )

    if args.rnnlm:
        lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)

@@ -113,8 +108,7 @@ def recog_v2(args):
        ctc=args.ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty, )
    beam_search = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(char_list),

@@ -123,8 +117,7 @@ def recog_v2(args):
        sos=model.sos,
        eos=model.eos,
        token_list=char_list,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", )
    # TODO(karita): make all scorers batchfied
    if args.batchsize == 1:

@@ -171,9 +164,10 @@ def recog_v2(args):
            logger.info(f'feat: {feat.shape}')
            enc = model.encode(paddle.to_tensor(feat).to(dtype))
            logger.info(f'eout: {enc.shape}')
            nbest_hyps = beam_search(
                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio)
            nbest_hyps = [
                h.asdict() for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)]
            ]

@@ -183,9 +177,8 @@ def recog_v2(args):
            item = new_js[name]['output'][0]  # 1-best
            ref = item['text']
            rec_text = item['rec_text'].replace('▁', ' ').replace('<eos>', '').strip()
            rec_tokenid = list(map(int, item['rec_tokenid'].split()))
            f.write({
                "utt": name,

deepspeech/decoders/recog_bin.py

@@ -21,8 +21,6 @@ from distutils.util import strtobool
import configargparse
import numpy as np

from .recog import recog_v2


def get_parser():
    """Get default arguments."""

deepspeech/models/asr_interface.py

@@ -110,7 +110,7 @@ class ASRInterface:
    @property
    def ctc_plot_class(self):
        """Get CTC plot class."""
        from deepspeech.training.extensions.plot import PlotCTCReport
        return PlotCTCReport

deepspeech/models/lm/__init__.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech/models/lm/transformer.py

@@ -20,11 +20,11 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.models.lm_interface import LMInterface
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.mask import subsequent_mask


class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
    def __init__(

@@ -36,10 +36,10 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
            head: int=2,
            unit: int=1024,
            layer: int=4,
            dropout_rate: float=0.5,
            emb_dropout_rate: float=0.0,
            att_dropout_rate: float=0.0,
            tie_weights: bool=False, ):
        nn.Layer.__init__(self)

        if pos_enc == "sinusoidal":

@@ -89,9 +89,8 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
        m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, x: paddle.Tensor, t: paddle.Tensor
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:

@@ -117,7 +116,8 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
        emb = self.embed(x)
        h, _ = self.encoder(emb, xlen)
        y = self.decoder(h)
        loss = F.cross_entropy(
            y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        mask = xm.to(dtype=loss.dtype)
        logp = loss * mask.view(-1)
        logp = logp.sum()
@@ -148,16 +148,16 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
        emb = self.embed(y)
        h, _, cache = self.encoder.forward_one_step(
            emb, self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(axis=-1).squeeze(0)
        return logp, cache

    # batch beam search API (see BatchScorerInterface)
    def batch_score(self, ys: paddle.Tensor, states: List[Any],
                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:

@@ -191,13 +191,13 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            emb, self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(axi=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)]
                      for b in range(n_batch)]
        return logp, state_list
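
The list comprehension in that hunk transposes the per-layer decoder caches into per-batch caches. A minimal sketch with plain lists (sizes and labels are illustrative, not from the commit):

    n_layers, n_batch = 2, 3
    states = [[f"layer{i}-batch{b}" for b in range(n_batch)] for i in range(n_layers)]
    state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
    print(state_list[0])  # ['layer0-batch0', 'layer1-batch0']
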
@@ -212,17 +212,17 @@ if __name__ == "__main__":
        layer=16,
        dropout_rate=0.5, )

    # n_vocab: int,
    # pos_enc: str=None,
    # embed_unit: int=128,
    # att_unit: int=256,
    # head: int=2,
    # unit: int=1024,
    # layer: int=4,
    # dropout_rate: float=0.5,
    # emb_dropout_rate: float = 0.0,
    # att_dropout_rate: float = 0.0,
    # tie_weights: bool = False,):
    paddle.set_device("cpu")
    model_dict = paddle.load("transformerLM.pdparams")
    tlm.set_state_dict(model_dict)

@@ -256,4 +256,4 @@ if __name__ == "__main__":
    print("output", output)
    #print("cache", cache)
    #np.save("output_pd.npy", output)
-"""
\ No newline at end of file
+"""

deepspeech/models/lm_interface.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language model interface."""
import
argparse
from
deepspeech.decoders.scorers.scorer_interface
import
ScorerInterface
from
deepspeech.utils.dynamic_import
import
dynamic_import
class
LMInterface
(
ScorerInterface
):
"""LM Interface model implementation."""
...
...
@@ -52,6 +65,7 @@ predefined_lms = {
"transformer"
:
"deepspeech.models.lm.transformer:TransformerLM"
,
}
def
dynamic_import_lm
(
module
):
"""Import LM class dynamically.
...
...
@@ -63,7 +77,6 @@ def dynamic_import_lm(module):
"""
model_class
=
dynamic_import
(
module
,
predefined_lms
)
assert
issubclass
(
model_class
,
LMInterface
),
f
"
{
module
}
does not implement LMInterface"
assert
issubclass
(
model_class
,
LMInterface
),
f
"
{
module
}
does not implement LMInterface"
return
model_class
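
A hedged usage sketch of dynamic_import_lm as shown above: the "transformer" key resolves through predefined_lms to deepspeech.models.lm.transformer:TransformerLM, and the assert guarantees the returned class implements LMInterface. The n_vocab value below is illustrative only; the remaining constructor arguments fall back to the defaults visible in the transformer.py hunk.

    from deepspeech.models.lm_interface import dynamic_import_lm

    lm_class = dynamic_import_lm("transformer")  # -> TransformerLM
    lm = lm_class(n_vocab=5000)
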
deepspeech/models/st_interface.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module."""
import
argparse
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
.asr_interface
import
ASRInterface
from
deepspeech.utils.dynamic_import
import
dynamic_import
class
STInterface
(
ASRInterface
):
"""ST Interface model implementation.
...
...
@@ -13,7 +24,12 @@ class STInterface(ASRInterface):
"""
def
translate
(
self
,
x
,
trans_args
,
char_list
=
None
,
rnnlm
=
None
,
ensemble_models
=
[]):
def
translate
(
self
,
x
,
trans_args
,
char_list
=
None
,
rnnlm
=
None
,
ensemble_models
=
[]):
"""Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
...
...
@@ -42,6 +58,7 @@ predefined_st = {
"transformer"
:
"deepspeech.models.u2_st:U2STModel"
,
}
def
dynamic_import_st
(
module
):
"""Import ST models dynamically.
...
...
@@ -53,7 +70,6 @@ def dynamic_import_st(module):
"""
model_class
=
dynamic_import
(
module
,
predefined_st
)
assert
issubclass
(
model_class
,
STInterface
),
f
"
{
module
}
does not implement STInterface"
assert
issubclass
(
model_class
,
STInterface
),
f
"
{
module
}
does not implement STInterface"
return
model_class
deepspeech/models/u2_st/__init__.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2_st import U2STInferModel
from .u2_st import U2STModel

deepspeech/modules/ctc.py

@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import paddle
from paddle import nn
from paddle.nn import functional as F
from typeguard import check_argument_types
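
The import reshuffle above follows the three-group layout applied throughout this commit: standard library first, then third-party packages, then first-party deepspeech modules, with a blank line between groups and alphabetical order inside each (the isort-style convention is an inference from the diff, not something this page states). Schematically, with module names taken from elsewhere in this diff:

    from typing import Union                      # standard library

    import paddle                                 # third-party
    from paddle import nn
    from typeguard import check_argument_types

    from deepspeech.utils.log import Log          # first-party
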
deepspeech/modules/embedding.py

@@ -22,7 +22,10 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()

__all__ = [
    "NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"
]


class NoPositionalEncoding(nn.Layer):
    def __init__(self,

deepspeech/modules/encoder.py

@@ -24,9 +24,9 @@ from deepspeech.modules.activation import get_activation
from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.attention import RelPositionMultiHeadedAttention
from deepspeech.modules.conformer_convolution import ConvolutionModule
from deepspeech.modules.embedding import NoPositionalEncoding
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.embedding import RelPositionalEncoding
from deepspeech.modules.encoder_layer import ConformerEncoderLayer
from deepspeech.modules.encoder_layer import TransformerEncoderLayer
from deepspeech.modules.mask import add_optional_chunk_mask
@@ -103,7 +103,7 @@ class BaseEncoder(nn.Layer):
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            pos_enc_class = RelPositionalEncoding
-        elif pos_enc_layer_type is "no_pos":
+        elif pos_enc_layer_type == "no_pos":
            pos_enc_class = NoPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
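
The change from "is" to "==" on the "no_pos" branch is a real fix rather than pure formatting: "is" compares object identity, while "==" compares string values, and two equal strings are not guaranteed to be the same object. A small illustration:

    a = "no_pos"
    b = "".join(["no_", "pos"])   # equal value, built at runtime
    print(a == b)                 # True
    print(a is b)                 # typically False in CPython: identity, not equality
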
@@ -378,8 +378,7 @@ class TransformerEncoder(BaseEncoder):
            self,
            xs: paddle.Tensor,
            masks: paddle.Tensor,
            cache=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Encode input frame.

        Args:

@@ -397,7 +396,8 @@ class TransformerEncoder(BaseEncoder):
        if isinstance(self.embed, Conv2dSubsampling):
            #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
            xs, pos_emb, masks = self.embed(
                xs, masks.astype(xs.dtype), offset=0)
        else:
            xs = self.embed(xs)
        #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor

deepspeech/modules/subsampling.py

@@ -60,8 +60,8 @@ class LinearNoSubsampling(BaseSubsampling):
        self.out = nn.Sequential(
            nn.Linear(idim, odim),
            nn.LayerNorm(odim, epsilon=1e-12),
            nn.Dropout(dropout_rate),
            nn.ReLU(), )
        self.right_context = 0
        self.subsampling_rate = 1

@@ -83,10 +83,12 @@ class LinearNoSubsampling(BaseSubsampling):
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask


class Conv2dSubsampling(BaseSubsampling):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class Conv2dSubsampling4(Conv2dSubsampling):
    """Convolutional 2D subsampling (to 1/4 length)."""

deepspeech/training/extensions/plot.py

-import argparse
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import shutil
import tempfile

import numpy as np

from . import extension
from ..updaters.trainer import Trainer


class PlotAttentionReport(extension.Extension):
@@ -37,20 +44,19 @@ class PlotAttentionReport(extension.Extension):
    """

    def __init__(
            self,
            att_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1, ):
        self.att_vis_fn = att_vis_fn
        self.data = copy.deepcopy(data)
        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
@@ -77,44 +83,30 @@ class PlotAttentionReport(extension.Extension):
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
                        self.outdir, uttid_list[idx], i + 1, )
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
                        self.outdir, uttid_list[idx], i + 1, )
                    np.save(np_filename.format(trainer), att_w)
                    self._plot_and_save_attention(att_w, filename.format(trainer))
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
                    self.outdir, uttid_list[idx], )
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
                    self.outdir, uttid_list[idx], )
                np.save(np_filename.format(trainer), att_w)
                self._plot_and_save_attention(
                    att_w, filename.format(trainer), han_mode=True)
        else:
            for idx, att_w in enumerate(att_ws):
                filename = "%s/%s.ep.{.updater.epoch}.png" % (
                    self.outdir, uttid_list[idx], )
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                    self.outdir, uttid_list[idx], )
                np.save(np_filename.format(trainer), att_w)
                self._plot_and_save_attention(att_w, filename.format(trainer))
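
The filename templates above lean on str.format attribute access: "{.updater.epoch}".format(trainer) expands to str(trainer.updater.epoch), so the epoch number is filled in only when the plot is finally written. A minimal sketch with stand-in objects (names are illustrative):

    from types import SimpleNamespace

    trainer = SimpleNamespace(updater=SimpleNamespace(epoch=3))
    template = "utt1.ep.{.updater.epoch}.att1.png"
    print(template.format(trainer))   # utt1.ep.3.att1.png
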
@@ -131,8 +123,7 @@ class PlotAttentionReport(extension.Extension):
                    logger.add_figure(
                        "%s_att%d" % (uttid_list[idx], i + 1),
                        plot.gcf(),
                        step, )
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)

@@ -140,8 +131,7 @@ class PlotAttentionReport(extension.Extension):
                logger.add_figure(
                    "%s_han" % (uttid_list[idx]),
                    plot.gcf(),
                    step, )
        else:
            for idx, att_w in enumerate(att_ws):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
@@ -286,20 +276,19 @@ class PlotCTCReport(extension.Extension):
    """

    def __init__(
            self,
            ctc_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1, ):
        self.ctc_vis_fn = ctc_vis_fn
        self.data = copy.deepcopy(data)
        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
@@ -325,29 +314,19 @@ class PlotCTCReport(extension.Extension):
            for i in range(num_encs):
                for idx, ctc_prob in enumerate(ctc_probs[i]):
                    filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
                        self.outdir, uttid_list[idx], i + 1, )
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
                        self.outdir, uttid_list[idx], i + 1, )
                    np.save(np_filename.format(trainer), ctc_prob)
                    self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                filename = "%s/%s.ep.{.updater.epoch}.png" % (
                    self.outdir, uttid_list[idx], )
                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                    self.outdir, uttid_list[idx], )
                np.save(np_filename.format(trainer), ctc_prob)
                self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
@@ -363,8 +342,7 @@ class PlotCTCReport(extension.Extension):
                    logger.add_figure(
                        "%s_ctc%d" % (uttid_list[idx], i + 1),
                        plot.gcf(),
                        step, )
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
@@ -420,8 +398,11 @@ class PlotCTCReport(extension.Extension):
        for idx in set(topk_ids.reshape(-1).tolist()):
            if idx == 0:
                plt.plot(
                    times_probs,
                    ctc_prob[:, 0],
                    ":",
                    label="<blank>",
                    color="grey")
            else:
                plt.plot(times_probs, ctc_prob[:, idx])
        plt.xlabel(u"Input [frame]", fontsize=12)
@@ -434,4 +415,4 @@ class PlotCTCReport(extension.Extension):
    def _plot_and_save_ctc(self, ctc_prob, filename):
        plt = self.draw_ctc_plot(ctc_prob)
        plt.savefig(filename)
-        plt.close()
\ No newline at end of file
+        plt.close()

deepspeech/training/extensions/visualizer.py

@@ -36,4 +36,4 @@ class VisualDL(extension.Extension):
            self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)

    def finalize(self, trainer):
-        self.writer.close()
\ No newline at end of file
+        self.writer.close()

deepspeech/training/triggers/compare_value_trigger.py

@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..reporter import DictSummary
from .utils import get_trigger


class CompareValueTrigger():
    """Trigger invoked when key value getting bigger or lower than before.

@@ -24,6 +24,7 @@ class CompareValueTrigger():
        trigger (tuple(int, str)) : Trigger that decide the comparison interval.
    """

    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
        self._key = key
        self._best_value = None

@@ -57,4 +58,4 @@ class CompareValueTrigger():
            return False

    def _init_summary(self):
-        self._summary = DictSummary()
\ No newline at end of file
+        self._summary = DictSummary()

deepspeech/training/triggers/limit_trigger.py

@@ -28,4 +28,4 @@ class LimitTrigger():
        state = trainer.updater.state
        index = getattr(state, self.unit)
        fire = index >= self.limit
-        return fire
\ No newline at end of file
+        return fire

deepspeech/training/triggers/time_trigger.py

@@ -38,4 +38,4 @@ class TimeTrigger():
        return state_dict

    def set_state_dict(self, state_dict):
-        self._next_time = state_dict['next_time']
\ No newline at end of file
+        self._next_time = state_dict['next_time']

deepspeech/training/triggers/utils.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger

deepspeech/utils/asr_utils.py

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json

import numpy as np

__all__ = ["label_smoothing_dist"]

@@ -33,8 +34,7 @@ def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    if lsm_type == "unigram":
        assert transcript is not None, (
            "transcript is required for %s label smoothing" % lsm_type)
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])

deepspeech/utils/bleu_score.py

@@ -14,9 +14,9 @@
"""This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level.
"""
import nltk
import numpy as np
import sacrebleu

__all__ = ['bleu', 'char_bleu', "ErrorCalculator"]

@@ -106,11 +106,14 @@ class ErrorCalculator():
            # NOTE: padding index (-1) in y_true is used to pad y_hat
            # because y_hats is not padded with -1
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
            seq_true = [
                self.char_list[int(idx)] for idx in y_true if int(idx) != -1
            ]
            seq_hat_text = "".join(seq_hat).replace(self.space, " ")
            seq_hat_text = seq_hat_text.replace(self.pad, "")
            seq_true_text = "".join(seq_true).replace(self.space, " ")
            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)
        bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true],
                                           seqs_hat)
        return bleu * 100

deepspeech/utils/error_rate.py

@@ -14,11 +14,10 @@
"""This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level.
"""
from itertools import groupby

import editdistance
import numpy as np
-import logging
-import sys
-from itertools import groupby

__all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"]

@@ -225,9 +224,12 @@ class ErrorCalculator():
    :return:
    """

    def __init__(self,
                 char_list,
                 sym_space,
                 sym_blank,
                 report_cer=False,
                 report_wer=False):
        """Construct an ErrorCalculator object."""
        super().__init__()

@@ -317,7 +319,9 @@ class ErrorCalculator():
            ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
            # NOTE: padding index (-1) in y_true is used to pad y_hat
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
            seq_true = [
                self.char_list[int(idx)] for idx in y_true if int(idx) != -1
            ]
            seq_hat_text = "".join(seq_hat).replace(self.space, " ")
            seq_hat_text = seq_hat_text.replace(self.blank, "")
            seq_true_text = "".join(seq_true).replace(self.space, " ")
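
A small sketch of the id-to-text conversion in the hunk above (char_list and the id sequences are made up for illustration; the real lists come from the model vocabulary):

    char_list = ["<blank>", "a", "b", "<space>", "c"]
    y_hat = [1, 2, 3, 4]
    y_true = [1, 2, 3, 4, -1, -1]            # -1 is the padding index
    seq_hat = [char_list[int(idx)] for idx in y_hat]
    seq_true = [char_list[int(idx)] for idx in y_true if int(idx) != -1]
    print("".join(seq_hat).replace("<space>", " "))   # "ab c"
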
setup.py

@@ -15,7 +15,6 @@ import contextlib
import inspect
import io
import os
import re
import subprocess as sp
import sys
from pathlib import Path

@@ -84,7 +83,7 @@ def _post_install(install_lib_dir):
    tools_extrs_dir = HERE / 'tools/extras'
    with pushd(tools_extrs_dir):
        print(os.getcwd())
-        check_call(f"./install_autolog.sh")
+        check_call("./install_autolog.sh")
        print("autolog install.")

    # ctcdecoder

utils/json2trn.py

@@ -4,7 +4,6 @@
#           2018 Xuankai Chang (Shanghai Jiao Tong University)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import json
import logging
import sys