Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
dee672a7
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
接近 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dee672a7
编写于
4月 12, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
flake8
上级
64f177cc
变更
68
显示空白变更内容
内联
并排
Showing
68 changed file
with
127 addition
and
247 deletion
+127
-247
.flake8
.flake8
+49
-0
deepspeech/__init__.py
deepspeech/__init__.py
+2
-4
deepspeech/decoders/decoders_deprecated.py
deepspeech/decoders/decoders_deprecated.py
+5
-5
deepspeech/decoders/scorer_deprecated.py
deepspeech/decoders/scorer_deprecated.py
+1
-1
deepspeech/decoders/swig/setup.py
deepspeech/decoders/swig/setup.py
+7
-6
deepspeech/exps/deepspeech2/bin/deploy/client.py
deepspeech/exps/deepspeech2/bin/deploy/client.py
+1
-3
deepspeech/exps/deepspeech2/bin/deploy/runtime.py
deepspeech/exps/deepspeech2/bin/deploy/runtime.py
+2
-8
deepspeech/exps/deepspeech2/bin/deploy/send.py
deepspeech/exps/deepspeech2/bin/deploy/send.py
+0
-2
deepspeech/exps/deepspeech2/bin/deploy/server.py
deepspeech/exps/deepspeech2/bin/deploy/server.py
+1
-5
deepspeech/exps/deepspeech2/bin/export.py
deepspeech/exps/deepspeech2/bin/export.py
+0
-9
deepspeech/exps/deepspeech2/bin/infer.py
deepspeech/exps/deepspeech2/bin/infer.py
+0
-9
deepspeech/exps/deepspeech2/bin/test.py
deepspeech/exps/deepspeech2/bin/test.py
+0
-9
deepspeech/exps/deepspeech2/bin/train.py
deepspeech/exps/deepspeech2/bin/train.py
+0
-6
deepspeech/exps/deepspeech2/bin/tune.py
deepspeech/exps/deepspeech2/bin/tune.py
+1
-5
deepspeech/exps/deepspeech2/model.py
deepspeech/exps/deepspeech2/model.py
+0
-4
deepspeech/exps/u2/bin/export.py
deepspeech/exps/u2/bin/export.py
+0
-8
deepspeech/exps/u2/bin/test.py
deepspeech/exps/u2/bin/test.py
+0
-8
deepspeech/exps/u2/bin/train.py
deepspeech/exps/u2/bin/train.py
+0
-5
deepspeech/exps/u2/model.py
deepspeech/exps/u2/model.py
+0
-6
deepspeech/frontend/audio.py
deepspeech/frontend/audio.py
+0
-1
deepspeech/frontend/augmentor/augmentation.py
deepspeech/frontend/augmentor/augmentation.py
+1
-1
deepspeech/frontend/augmentor/base.py
deepspeech/frontend/augmentor/base.py
+1
-1
deepspeech/frontend/featurizer/audio_featurizer.py
deepspeech/frontend/featurizer/audio_featurizer.py
+1
-3
deepspeech/frontend/featurizer/text_featurizer.py
deepspeech/frontend/featurizer/text_featurizer.py
+0
-1
deepspeech/frontend/utility.py
deepspeech/frontend/utility.py
+0
-8
deepspeech/io/__init__.py
deepspeech/io/__init__.py
+2
-3
deepspeech/io/collator.py
deepspeech/io/collator.py
+0
-1
deepspeech/io/dataset.py
deepspeech/io/dataset.py
+0
-3
deepspeech/io/sampler.py
deepspeech/io/sampler.py
+4
-9
deepspeech/io/utility.py
deepspeech/io/utility.py
+0
-1
deepspeech/models/deepspeech2.py
deepspeech/models/deepspeech2.py
+0
-8
deepspeech/models/u2.py
deepspeech/models/u2.py
+2
-6
deepspeech/modules/__init__.py
deepspeech/modules/__init__.py
+1
-1
deepspeech/modules/activation.py
deepspeech/modules/activation.py
+0
-5
deepspeech/modules/attention.py
deepspeech/modules/attention.py
+0
-1
deepspeech/modules/cmvn.py
deepspeech/modules/cmvn.py
+0
-2
deepspeech/modules/conformer_convolution.py
deepspeech/modules/conformer_convolution.py
+0
-2
deepspeech/modules/conv.py
deepspeech/modules/conv.py
+0
-2
deepspeech/modules/ctc.py
deepspeech/modules/ctc.py
+2
-3
deepspeech/modules/decoder.py
deepspeech/modules/decoder.py
+3
-5
deepspeech/modules/decoder_layer.py
deepspeech/modules/decoder_layer.py
+0
-2
deepspeech/modules/embedding.py
deepspeech/modules/embedding.py
+0
-3
deepspeech/modules/encoder.py
deepspeech/modules/encoder.py
+0
-2
deepspeech/modules/encoder_layer.py
deepspeech/modules/encoder_layer.py
+0
-2
deepspeech/modules/loss.py
deepspeech/modules/loss.py
+0
-1
deepspeech/modules/mask.py
deepspeech/modules/mask.py
+0
-3
deepspeech/modules/positionwise_feed_forward.py
deepspeech/modules/positionwise_feed_forward.py
+0
-2
deepspeech/modules/subsampling.py
deepspeech/modules/subsampling.py
+0
-2
deepspeech/training/__init__.py
deepspeech/training/__init__.py
+0
-2
deepspeech/training/cli.py
deepspeech/training/cli.py
+6
-3
deepspeech/training/scheduler.py
deepspeech/training/scheduler.py
+0
-1
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+0
-3
deepspeech/utils/checkpoint.py
deepspeech/utils/checkpoint.py
+0
-3
deepspeech/utils/error_rate.py
deepspeech/utils/error_rate.py
+3
-3
deepspeech/utils/layer_tools.py
deepspeech/utils/layer_tools.py
+1
-1
deepspeech/utils/mp_tools.py
deepspeech/utils/mp_tools.py
+0
-1
deepspeech/utils/tensor_utils.py
deepspeech/utils/tensor_utils.py
+0
-1
deepspeech/utils/utility.py
deepspeech/utils/utility.py
+0
-1
examples/dataset/aishell/aishell.py
examples/dataset/aishell/aishell.py
+3
-2
examples/dataset/chime3_background/chime3_background.py
examples/dataset/chime3_background/chime3_background.py
+5
-5
examples/dataset/librispeech/librispeech.py
examples/dataset/librispeech/librispeech.py
+0
-1
examples/dataset/mini_librispeech/mini_librispeech.py
examples/dataset/mini_librispeech/mini_librispeech.py
+0
-2
examples/dataset/rir_noise/rir_noise.py
examples/dataset/rir_noise/rir_noise.py
+1
-1
examples/dataset/voxforge/voxforge.py
examples/dataset/voxforge/voxforge.py
+1
-1
tests/deepspeech2_model_test.py
tests/deepspeech2_model_test.py
+2
-2
utils/build_vocab.py
utils/build_vocab.py
+12
-15
utils/format_data.py
utils/format_data.py
+1
-8
utils/utility.py
utils/utility.py
+6
-4
未找到文件。
.flake8
0 → 100644
浏览文件 @
dee672a7
[flake8]
########## OPTIONS ##########
# Set the maximum length that any line (with some exceptions) may be.
max-line-length = 120
################### FILE PATTERNS ##########################
# Provide a comma-separated list of glob patterns to exclude from checks.
exclude =
# git folder
.git,
# python cache
__pycache__,
# Provide a comma-separate list of glob patterns to include for checks.
filename =
*.py
########## RULES ##########
# ERROR CODES
#
# E/W - PEP8 errors/warnings (pycodestyle)
# F - linting errors (pyflakes)
# C - McCabe complexity error (mccabe)
#
# W503 - line break before binary operator
# Specify a list of codes to ignore.
ignore =
W503
E252,E262,E127,E265,E126,E266,E241,E261,E128,E125
W291,W293,W605
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
# Specify the list of error codes you wish Flake8 to report.
select =
E,
W,
F,
C
\ No newline at end of file
deepspeech/__init__.py
浏览文件 @
dee672a7
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
logging
import
logging
from
typing
import
Union
from
typing
import
Union
from
typing
import
Optional
from
typing
import
List
from
typing
import
List
from
typing
import
Tuple
from
typing
import
Tuple
from
typing
import
Any
from
typing
import
Any
...
@@ -21,7 +20,6 @@ from typing import Any
...
@@ -21,7 +20,6 @@ from typing import Any
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
#TODO(Hui Zhang): remove fluid import
#TODO(Hui Zhang): remove fluid import
from
paddle.fluid
import
core
from
paddle.fluid
import
core
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -242,7 +240,7 @@ def is_broadcastable(shp1, shp2):
...
@@ -242,7 +240,7 @@ def is_broadcastable(shp1, shp2):
def
masked_fill
(
xs
:
paddle
.
Tensor
,
def
masked_fill
(
xs
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
value
:
Union
[
float
,
int
]):
value
:
Union
[
float
,
int
]):
assert
is_broadcastable
(
xs
.
shape
,
mask
.
shape
)
==
True
assert
is_broadcastable
(
xs
.
shape
,
mask
.
shape
)
is
True
bshape
=
paddle
.
broadcast_shape
(
xs
.
shape
,
mask
.
shape
)
bshape
=
paddle
.
broadcast_shape
(
xs
.
shape
,
mask
.
shape
)
mask
=
mask
.
broadcast_to
(
bshape
)
mask
=
mask
.
broadcast_to
(
bshape
)
trues
=
paddle
.
ones_like
(
xs
)
*
value
trues
=
paddle
.
ones_like
(
xs
)
*
value
...
@@ -259,7 +257,7 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
...
@@ -259,7 +257,7 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
def
masked_fill_
(
xs
:
paddle
.
Tensor
,
def
masked_fill_
(
xs
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
mask
:
paddle
.
Tensor
,
value
:
Union
[
float
,
int
]):
value
:
Union
[
float
,
int
]):
assert
is_broadcastable
(
xs
.
shape
,
mask
.
shape
)
==
True
assert
is_broadcastable
(
xs
.
shape
,
mask
.
shape
)
is
True
bshape
=
paddle
.
broadcast_shape
(
xs
.
shape
,
mask
.
shape
)
bshape
=
paddle
.
broadcast_shape
(
xs
.
shape
,
mask
.
shape
)
mask
=
mask
.
broadcast_to
(
bshape
)
mask
=
mask
.
broadcast_to
(
bshape
)
trues
=
paddle
.
ones_like
(
xs
)
*
value
trues
=
paddle
.
ones_like
(
xs
)
*
value
...
...
deepspeech/decoders/decoders_deprecated.py
浏览文件 @
dee672a7
...
@@ -104,14 +104,14 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -104,14 +104,14 @@ def ctc_beam_search_decoder(probs_seq,
global
ext_nproc_scorer
global
ext_nproc_scorer
ext_scoring_func
=
ext_nproc_scorer
ext_scoring_func
=
ext_nproc_scorer
#
#
initialize
# initialize
# prefix_set_prev: the set containing selected prefixes
# prefix_set_prev: the set containing selected prefixes
# probs_b_prev: prefixes' probability ending with blank in previous step
# probs_b_prev: prefixes' probability ending with blank in previous step
# probs_nb_prev: prefixes' probability ending with non-blank in previous step
# probs_nb_prev: prefixes' probability ending with non-blank in previous step
prefix_set_prev
=
{
'
\t
'
:
1.0
}
prefix_set_prev
=
{
'
\t
'
:
1.0
}
probs_b_prev
,
probs_nb_prev
=
{
'
\t
'
:
1.0
},
{
'
\t
'
:
0.0
}
probs_b_prev
,
probs_nb_prev
=
{
'
\t
'
:
1.0
},
{
'
\t
'
:
0.0
}
#
#
extend prefix in loop
# extend prefix in loop
for
time_step
in
range
(
len
(
probs_seq
)):
for
time_step
in
range
(
len
(
probs_seq
)):
# prefix_set_next: the set containing candidate prefixes
# prefix_set_next: the set containing candidate prefixes
# probs_b_cur: prefixes' probability ending with blank in current step
# probs_b_cur: prefixes' probability ending with blank in current step
...
@@ -120,7 +120,7 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -120,7 +120,7 @@ def ctc_beam_search_decoder(probs_seq,
prob_idx
=
list
(
enumerate
(
probs_seq
[
time_step
]))
prob_idx
=
list
(
enumerate
(
probs_seq
[
time_step
]))
cutoff_len
=
len
(
prob_idx
)
cutoff_len
=
len
(
prob_idx
)
#If pruning is enabled
#
If pruning is enabled
if
cutoff_prob
<
1.0
or
cutoff_top_n
<
cutoff_len
:
if
cutoff_prob
<
1.0
or
cutoff_top_n
<
cutoff_len
:
prob_idx
=
sorted
(
prob_idx
,
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
prob_idx
=
sorted
(
prob_idx
,
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
cutoff_len
,
cum_prob
=
0
,
0.0
cutoff_len
,
cum_prob
=
0
,
0.0
...
@@ -172,7 +172,7 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -172,7 +172,7 @@ def ctc_beam_search_decoder(probs_seq,
# update probs
# update probs
probs_b_prev
,
probs_nb_prev
=
probs_b_cur
,
probs_nb_cur
probs_b_prev
,
probs_nb_prev
=
probs_b_cur
,
probs_nb_cur
#
#
store top beam_size prefixes
# store top beam_size prefixes
prefix_set_prev
=
sorted
(
prefix_set_prev
=
sorted
(
prefix_set_next
.
items
(),
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
prefix_set_next
.
items
(),
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
if
beam_size
<
len
(
prefix_set_prev
):
if
beam_size
<
len
(
prefix_set_prev
):
...
@@ -191,7 +191,7 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -191,7 +191,7 @@ def ctc_beam_search_decoder(probs_seq,
else
:
else
:
beam_result
.
append
((
float
(
'-inf'
),
''
))
beam_result
.
append
((
float
(
'-inf'
),
''
))
#
#
output top beam_size decoding results
# output top beam_size decoding results
beam_result
=
sorted
(
beam_result
,
key
=
lambda
asd
:
asd
[
0
],
reverse
=
True
)
beam_result
=
sorted
(
beam_result
,
key
=
lambda
asd
:
asd
[
0
],
reverse
=
True
)
return
beam_result
return
beam_result
...
...
deepspeech/decoders/scorer_deprecated.py
浏览文件 @
dee672a7
...
@@ -71,7 +71,7 @@ class Scorer(object):
...
@@ -71,7 +71,7 @@ class Scorer(object):
"""
"""
lm
=
self
.
_language_model_score
(
sentence
)
lm
=
self
.
_language_model_score
(
sentence
)
word_cnt
=
self
.
_word_count
(
sentence
)
word_cnt
=
self
.
_word_count
(
sentence
)
if
log
==
False
:
if
log
is
False
:
score
=
np
.
power
(
lm
,
self
.
_alpha
)
*
np
.
power
(
word_cnt
,
self
.
_beta
)
score
=
np
.
power
(
lm
,
self
.
_alpha
)
*
np
.
power
(
word_cnt
,
self
.
_beta
)
else
:
else
:
score
=
self
.
_alpha
*
np
.
log
(
lm
)
+
self
.
_beta
*
np
.
log
(
word_cnt
)
score
=
self
.
_alpha
*
np
.
log
(
lm
)
+
self
.
_beta
*
np
.
log
(
word_cnt
)
...
...
deepspeech/decoders/swig/setup.py
浏览文件 @
dee672a7
...
@@ -16,7 +16,8 @@
...
@@ -16,7 +16,8 @@
from
setuptools
import
setup
,
Extension
,
distutils
from
setuptools
import
setup
,
Extension
,
distutils
import
glob
import
glob
import
platform
import
platform
import
os
,
sys
import
os
import
sys
import
multiprocessing.pool
import
multiprocessing.pool
import
argparse
import
argparse
...
...
deepspeech/exps/deepspeech2/bin/deploy/client.py
浏览文件 @
dee672a7
...
@@ -13,8 +13,6 @@
...
@@ -13,8 +13,6 @@
# limitations under the License.
# limitations under the License.
"""Client-end for the ASR demo."""
"""Client-end for the ASR demo."""
import
keyboard
import
keyboard
import
struct
import
socket
import
sys
import
sys
import
argparse
import
argparse
import
pyaudio
import
pyaudio
...
@@ -49,7 +47,7 @@ def on_press_release(x):
...
@@ -49,7 +47,7 @@ def on_press_release(x):
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
is_recording
=
True
is_recording
=
True
if
x
.
event_type
==
'up'
and
x
.
name
==
release
.
name
:
if
x
.
event_type
==
'up'
and
x
.
name
==
release
.
name
:
if
is_recording
==
True
:
if
is_recording
:
is_recording
=
False
is_recording
=
False
...
...
deepspeech/exps/deepspeech2/bin/deploy/runtime.py
浏览文件 @
dee672a7
...
@@ -12,9 +12,6 @@
...
@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Server-end for the ASR demo."""
"""Server-end for the ASR demo."""
import
os
import
time
import
argparse
import
functools
import
functools
import
paddle
import
paddle
import
numpy
as
np
import
numpy
as
np
...
@@ -26,7 +23,6 @@ from deepspeech.utils.socket_server import AsrRequestHandler
...
@@ -26,7 +23,6 @@ from deepspeech.utils.socket_server import AsrRequestHandler
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.frontend.utility
import
read_manifest
from
deepspeech.utils.utility
import
add_arguments
,
print_arguments
from
deepspeech.utils.utility
import
add_arguments
,
print_arguments
from
deepspeech.models.deepspeech2
import
DeepSpeech2Model
from
deepspeech.models.deepspeech2
import
DeepSpeech2Model
...
@@ -159,15 +155,13 @@ if __name__ == "__main__":
...
@@ -159,15 +155,13 @@ if __name__ == "__main__":
"--params_file"
,
"--params_file"
,
type
=
str
,
type
=
str
,
default
=
""
,
default
=
""
,
help
=
help
=
"Parameter filename, Specify this when your model is a combined model."
"Parameter filename, Specify this when your model is a combined model."
)
)
add_arg
(
add_arg
(
"--model_dir"
,
"--model_dir"
,
type
=
str
,
type
=
str
,
default
=
None
,
default
=
None
,
help
=
help
=
"Model dir, If you load a non-combined model, specify the directory of the model."
"Model dir, If you load a non-combined model, specify the directory of the model."
)
)
add_arg
(
"--use_gpu"
,
add_arg
(
"--use_gpu"
,
type
=
bool
,
type
=
bool
,
...
...
deepspeech/exps/deepspeech2/bin/deploy/send.py
浏览文件 @
dee672a7
...
@@ -12,8 +12,6 @@
...
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Socket client to send wav to ASR server."""
"""Socket client to send wav to ASR server."""
import
struct
import
socket
import
argparse
import
argparse
import
wave
import
wave
...
...
deepspeech/exps/deepspeech2/bin/deploy/server.py
浏览文件 @
dee672a7
...
@@ -12,9 +12,6 @@
...
@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Server-end for the ASR demo."""
"""Server-end for the ASR demo."""
import
os
import
time
import
argparse
import
functools
import
functools
import
paddle
import
paddle
import
numpy
as
np
import
numpy
as
np
...
@@ -26,7 +23,6 @@ from deepspeech.utils.socket_server import AsrRequestHandler
...
@@ -26,7 +23,6 @@ from deepspeech.utils.socket_server import AsrRequestHandler
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.frontend.utility
import
read_manifest
from
deepspeech.utils.utility
import
add_arguments
,
print_arguments
from
deepspeech.utils.utility
import
add_arguments
,
print_arguments
from
deepspeech.models.deepspeech2
import
DeepSpeech2Model
from
deepspeech.models.deepspeech2
import
DeepSpeech2Model
...
...
deepspeech/exps/deepspeech2/bin/export.py
浏览文件 @
dee672a7
...
@@ -12,17 +12,8 @@
...
@@ -12,17 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Export for DeepSpeech2 model."""
"""Export for DeepSpeech2 model."""
import
io
import
logging
import
argparse
import
functools
from
paddle
import
distributed
as
dist
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.error_rate
import
char_errors
,
word_errors
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.exps.deepspeech2.model
import
DeepSpeech2Tester
as
Tester
from
deepspeech.exps.deepspeech2.model
import
DeepSpeech2Tester
as
Tester
...
...
deepspeech/exps/deepspeech2/bin/infer.py
浏览文件 @
dee672a7
...
@@ -12,17 +12,8 @@
...
@@ -12,17 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inferer for DeepSpeech2 model."""
"""Inferer for DeepSpeech2 model."""
import
io
import
logging
import
argparse
import
functools
from
paddle
import
distributed
as
dist
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.error_rate
import
char_errors
,
word_errors
# TODO(hui zhang): dynamic load
# TODO(hui zhang): dynamic load
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
...
...
deepspeech/exps/deepspeech2/bin/test.py
浏览文件 @
dee672a7
...
@@ -12,17 +12,8 @@
...
@@ -12,17 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
"""Evaluation for DeepSpeech2 model."""
import
io
import
logging
import
argparse
import
functools
from
paddle
import
distributed
as
dist
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.error_rate
import
char_errors
,
word_errors
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.exps.deepspeech2.config
import
get_cfg_defaults
from
deepspeech.exps.deepspeech2.model
import
DeepSpeech2Tester
as
Tester
from
deepspeech.exps.deepspeech2.model
import
DeepSpeech2Tester
as
Tester
...
...
deepspeech/exps/deepspeech2/bin/train.py
浏览文件 @
dee672a7
...
@@ -12,12 +12,6 @@
...
@@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Trainer for DeepSpeech2 model."""
"""Trainer for DeepSpeech2 model."""
import
io
import
logging
import
argparse
import
functools
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.utility
import
print_arguments
...
...
deepspeech/exps/deepspeech2/bin/tune.py
浏览文件 @
dee672a7
...
@@ -14,12 +14,8 @@
...
@@ -14,12 +14,8 @@
"""Beam search parameters tuning for DeepSpeech2 model."""
"""Beam search parameters tuning for DeepSpeech2 model."""
import
sys
import
sys
import
os
import
numpy
as
np
import
numpy
as
np
import
argparse
import
functools
import
functools
import
gzip
import
logging
from
paddle.io
import
DataLoader
from
paddle.io
import
DataLoader
...
@@ -122,7 +118,7 @@ def tune(config, args):
...
@@ -122,7 +118,7 @@ def tune(config, args):
if
index
%
2
==
0
:
if
index
%
2
==
0
:
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
print
(
f
"tuneing: one grid done!"
)
print
(
"tuneing: one grid done!"
)
# output on-line tuning result at the end of current batch
# output on-line tuning result at the end of current batch
err_ave_min
=
min
(
err_ave
)
err_ave_min
=
min
(
err_ave
)
...
...
deepspeech/exps/deepspeech2/model.py
浏览文件 @
dee672a7
...
@@ -14,13 +14,10 @@
...
@@ -14,13 +14,10 @@
"""Contains DeepSpeech2 model."""
"""Contains DeepSpeech2 model."""
import
io
import
io
import
sys
import
os
import
time
import
time
import
logging
import
logging
import
numpy
as
np
import
numpy
as
np
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
partial
from
pathlib
import
Path
from
pathlib
import
Path
import
paddle
import
paddle
...
@@ -39,7 +36,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
...
@@ -39,7 +36,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from
deepspeech.io.sampler
import
SortagradBatchSampler
from
deepspeech.io.sampler
import
SortagradBatchSampler
from
deepspeech.io.dataset
import
ManifestDataset
from
deepspeech.io.dataset
import
ManifestDataset
from
deepspeech.modules.loss
import
CTCLoss
from
deepspeech.models.deepspeech2
import
DeepSpeech2Model
from
deepspeech.models.deepspeech2
import
DeepSpeech2Model
from
deepspeech.models.deepspeech2
import
DeepSpeech2InferModel
from
deepspeech.models.deepspeech2
import
DeepSpeech2InferModel
...
...
deepspeech/exps/u2/bin/export.py
浏览文件 @
dee672a7
...
@@ -13,16 +13,8 @@
...
@@ -13,16 +13,8 @@
# limitations under the License.
# limitations under the License.
"""Export for U2 model."""
"""Export for U2 model."""
import
io
import
logging
import
argparse
import
functools
from
paddle
import
distributed
as
dist
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.error_rate
import
char_errors
,
word_errors
from
deepspeech.exps.u2.config
import
get_cfg_defaults
from
deepspeech.exps.u2.config
import
get_cfg_defaults
from
deepspeech.exps.u2.model
import
U2Tester
as
Tester
from
deepspeech.exps.u2.model
import
U2Tester
as
Tester
...
...
deepspeech/exps/u2/bin/test.py
浏览文件 @
dee672a7
...
@@ -13,16 +13,8 @@
...
@@ -13,16 +13,8 @@
# limitations under the License.
# limitations under the License.
"""Evaluation for U2 model."""
"""Evaluation for U2 model."""
import
io
import
logging
import
argparse
import
functools
from
paddle
import
distributed
as
dist
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.training.cli
import
default_argument_parser
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.error_rate
import
char_errors
,
word_errors
# TODO(hui zhang): dynamic load
# TODO(hui zhang): dynamic load
from
deepspeech.exps.u2.config
import
get_cfg_defaults
from
deepspeech.exps.u2.config
import
get_cfg_defaults
...
...
deepspeech/exps/u2/bin/train.py
浏览文件 @
dee672a7
...
@@ -13,11 +13,6 @@
...
@@ -13,11 +13,6 @@
# limitations under the License.
# limitations under the License.
"""Trainer for U2 model."""
"""Trainer for U2 model."""
import
io
import
logging
import
argparse
import
functools
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.utility
import
print_arguments
...
...
deepspeech/exps/u2/model.py
浏览文件 @
dee672a7
...
@@ -13,14 +13,10 @@
...
@@ -13,14 +13,10 @@
# limitations under the License.
# limitations under the License.
"""Contains U2 model."""
"""Contains U2 model."""
import
io
import
sys
import
os
import
time
import
time
import
logging
import
logging
import
numpy
as
np
import
numpy
as
np
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
partial
from
pathlib
import
Path
from
pathlib
import
Path
import
paddle
import
paddle
...
@@ -40,8 +36,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
...
@@ -40,8 +36,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from
deepspeech.io.sampler
import
SortagradBatchSampler
from
deepspeech.io.sampler
import
SortagradBatchSampler
from
deepspeech.io.dataset
import
ManifestDataset
from
deepspeech.io.dataset
import
ManifestDataset
from
deepspeech.modules.loss
import
CTCLoss
from
deepspeech.models.u2
import
U2Model
from
deepspeech.models.u2
import
U2Model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/frontend/audio.py
浏览文件 @
dee672a7
...
@@ -22,7 +22,6 @@ import resampy
...
@@ -22,7 +22,6 @@ import resampy
from
scipy
import
signal
from
scipy
import
signal
import
random
import
random
import
copy
import
copy
import
io
class
AudioSegment
(
object
):
class
AudioSegment
(
object
):
...
...
deepspeech/frontend/augmentor/augmentation.py
浏览文件 @
dee672a7
deepspeech/frontend/augmentor/base.py
浏览文件 @
dee672a7
deepspeech/frontend/featurizer/audio_featurizer.py
浏览文件 @
dee672a7
...
@@ -14,8 +14,6 @@
...
@@ -14,8 +14,6 @@
"""Contains the audio featurizer class."""
"""Contains the audio featurizer class."""
import
numpy
as
np
import
numpy
as
np
from
deepspeech.frontend.utility
import
read_manifest
from
deepspeech.frontend.audio
import
AudioSegment
from
python_speech_features
import
mfcc
from
python_speech_features
import
mfcc
from
python_speech_features
import
logfbank
from
python_speech_features
import
logfbank
from
python_speech_features
import
delta
from
python_speech_features
import
delta
...
@@ -320,7 +318,7 @@ class AudioFeaturizer(object):
...
@@ -320,7 +318,7 @@ class AudioFeaturizer(object):
if
stride_ms
>
window_ms
:
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than "
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
"window size."
)
#(T, D)
#
(T, D)
fbank_feat
=
logfbank
(
fbank_feat
=
logfbank
(
signal
=
samples
,
signal
=
samples
,
samplerate
=
sample_rate
,
samplerate
=
sample_rate
,
...
...
deepspeech/frontend/featurizer/text_featurizer.py
浏览文件 @
dee672a7
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
"""Contains the text featurizer class."""
"""Contains the text featurizer class."""
import
os
import
sentencepiece
as
spm
import
sentencepiece
as
spm
from
deepspeech.frontend.utility
import
UNK
from
deepspeech.frontend.utility
import
UNK
...
...
deepspeech/frontend/utility.py
浏览文件 @
dee672a7
...
@@ -16,15 +16,7 @@ import numpy as np
...
@@ -16,15 +16,7 @@ import numpy as np
import
math
import
math
import
json
import
json
import
codecs
import
codecs
import
os
import
tarfile
import
time
import
logging
import
logging
from
typing
import
List
from
threading
import
Thread
from
multiprocessing
import
Process
,
Manager
,
Value
from
paddle.dataset.common
import
md5file
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/io/__init__.py
浏览文件 @
dee672a7
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
functools
import
numpy
as
np
import
numpy
as
np
from
paddle.io
import
DataLoader
from
paddle.io
import
DataLoader
...
@@ -131,7 +130,7 @@ def create_dataloader(manifest_path,
...
@@ -131,7 +130,7 @@ def create_dataloader(manifest_path,
if
keep_transcription_text
:
if
keep_transcription_text
:
padded_text
[:
len
(
text
)]
=
[
ord
(
t
)
for
t
in
text
]
# string
padded_text
[:
len
(
text
)]
=
[
ord
(
t
)
for
t
in
text
]
# string
else
:
else
:
padded_text
[:
len
(
text
)]
=
text
#ids
padded_text
[:
len
(
text
)]
=
text
#
ids
texts
.
append
(
padded_text
)
texts
.
append
(
padded_text
)
text_lens
.
append
(
len
(
text
))
text_lens
.
append
(
len
(
text
))
...
@@ -141,7 +140,7 @@ def create_dataloader(manifest_path,
...
@@ -141,7 +140,7 @@ def create_dataloader(manifest_path,
text_lens
=
np
.
array
(
text_lens
).
astype
(
'int64'
)
text_lens
=
np
.
array
(
text_lens
).
astype
(
'int64'
)
return
padded_audios
,
audio_lens
,
texts
,
text_lens
return
padded_audios
,
audio_lens
,
texts
,
text_lens
#collate_fn=functools.partial(padding_batch, keep_transcription_text=keep_transcription_text),
#
collate_fn=functools.partial(padding_batch, keep_transcription_text=keep_transcription_text),
collate_fn
=
SpeechCollator
(
keep_transcription_text
=
keep_transcription_text
)
collate_fn
=
SpeechCollator
(
keep_transcription_text
=
keep_transcription_text
)
loader
=
DataLoader
(
loader
=
DataLoader
(
dataset
,
dataset
,
...
...
deepspeech/io/collator.py
浏览文件 @
dee672a7
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
import
logging
import
logging
import
numpy
as
np
import
numpy
as
np
from
collections
import
namedtuple
from
deepspeech.io.utility
import
pad_sequence
from
deepspeech.io.utility
import
pad_sequence
from
deepspeech.frontend.utility
import
IGNORE_ID
from
deepspeech.frontend.utility
import
IGNORE_ID
...
...
deepspeech/io/dataset.py
浏览文件 @
dee672a7
...
@@ -13,13 +13,10 @@
...
@@ -13,13 +13,10 @@
# limitations under the License.
# limitations under the License.
import
io
import
io
import
math
import
random
import
random
import
tarfile
import
tarfile
import
logging
import
logging
import
numpy
as
np
from
collections
import
namedtuple
from
collections
import
namedtuple
from
functools
import
partial
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
...
...
deepspeech/io/sampler.py
浏览文件 @
dee672a7
...
@@ -13,14 +13,9 @@
...
@@ -13,14 +13,9 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
import
random
import
tarfile
import
logging
import
logging
import
numpy
as
np
import
numpy
as
np
from
collections
import
namedtuple
from
functools
import
partial
import
paddle
from
paddle.io
import
BatchSampler
from
paddle.io
import
BatchSampler
from
paddle.io
import
DistributedBatchSampler
from
paddle.io
import
DistributedBatchSampler
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
...
@@ -59,7 +54,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
...
@@ -59,7 +54,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
batch_indices
=
list
(
zip
(
*
[
iter
(
indices
[
shift_len
:])]
*
batch_size
))
batch_indices
=
list
(
zip
(
*
[
iter
(
indices
[
shift_len
:])]
*
batch_size
))
rng
.
shuffle
(
batch_indices
)
rng
.
shuffle
(
batch_indices
)
batch_indices
=
[
item
for
batch
in
batch_indices
for
item
in
batch
]
batch_indices
=
[
item
for
batch
in
batch_indices
for
item
in
batch
]
assert
(
clipped
==
False
)
assert
clipped
is
False
if
not
clipped
:
if
not
clipped
:
res_len
=
len
(
indices
)
-
shift_len
-
len
(
batch_indices
)
res_len
=
len
(
indices
)
-
shift_len
-
len
(
batch_indices
)
# when res_len is 0, will return whole list, len(List[-0:]) = len(List[:])
# when res_len is 0, will return whole list, len(List[-0:]) = len(List[:])
...
...
deepspeech/io/utility.py
浏览文件 @
dee672a7
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
import
logging
import
logging
import
numpy
as
np
import
numpy
as
np
from
collections
import
namedtuple
from
typing
import
List
from
typing
import
List
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/models/deepspeech2.py
浏览文件 @
dee672a7
...
@@ -12,20 +12,12 @@
...
@@ -12,20 +12,12 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Deepspeech2 ASR Model"""
"""Deepspeech2 ASR Model"""
import
math
import
collections
import
numpy
as
np
import
logging
import
logging
from
typing
import
Optional
from
typing
import
Optional
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
deepspeech.modules.mask
import
sequence_mask
from
deepspeech.modules.activation
import
brelu
from
deepspeech.modules.conv
import
ConvStack
from
deepspeech.modules.conv
import
ConvStack
from
deepspeech.modules.rnn
import
RNNStack
from
deepspeech.modules.rnn
import
RNNStack
from
deepspeech.modules.ctc
import
CTCDecoder
from
deepspeech.modules.ctc
import
CTCDecoder
...
...
deepspeech/models/u2.py
浏览文件 @
dee672a7
...
@@ -15,10 +15,8 @@
...
@@ -15,10 +15,8 @@
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf)
(https://arxiv.org/pdf/2012.05481.pdf)
"""
"""
import
math
import
collections
from
collections
import
defaultdict
from
collections
import
defaultdict
import
numpy
as
np
import
logging
import
logging
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
...
@@ -26,8 +24,6 @@ from typing import List, Optional, Tuple
...
@@ -26,8 +24,6 @@ from typing import List, Optional, Tuple
import
paddle
import
paddle
from
paddle
import
jit
from
paddle
import
jit
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
deepspeech.modules.mask
import
make_pad_mask
from
deepspeech.modules.mask
import
make_pad_mask
from
deepspeech.modules.mask
import
mask_finished_preds
from
deepspeech.modules.mask
import
mask_finished_preds
...
@@ -54,7 +50,7 @@ from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
...
@@ -54,7 +50,7 @@ from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'U2TransformerModel'
,
"U2Conform
erModel"
]
__all__
=
[
"U2Model"
,
"U2Inf
erModel"
]
class
U2BaseModel
(
nn
.
Module
):
class
U2BaseModel
(
nn
.
Module
):
...
...
deepspeech/modules/__init__.py
浏览文件 @
dee672a7
deepspeech/modules/activation.py
浏览文件 @
dee672a7
...
@@ -12,16 +12,11 @@
...
@@ -12,16 +12,11 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
Union
import
logging
import
logging
import
numpy
as
np
import
math
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/attention.py
浏览文件 @
dee672a7
...
@@ -18,7 +18,6 @@ from typing import Optional, Tuple
...
@@ -18,7 +18,6 @@ from typing import Optional, Tuple
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/cmvn.py
浏览文件 @
dee672a7
...
@@ -16,8 +16,6 @@ import logging
...
@@ -16,8 +16,6 @@ import logging
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/conformer_convolution.py
浏览文件 @
dee672a7
...
@@ -19,8 +19,6 @@ import logging
...
@@ -19,8 +19,6 @@ import logging
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/conv.py
浏览文件 @
dee672a7
...
@@ -14,10 +14,8 @@
...
@@ -14,10 +14,8 @@
import
logging
import
logging
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
deepspeech.modules.mask
import
sequence_mask
from
deepspeech.modules.mask
import
sequence_mask
from
deepspeech.modules.activation
import
brelu
from
deepspeech.modules.activation
import
brelu
...
...
deepspeech/modules/ctc.py
浏览文件 @
dee672a7
...
@@ -18,7 +18,6 @@ from typeguard import check_argument_types
...
@@ -18,7 +18,6 @@ from typeguard import check_argument_types
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
deepspeech.modules.loss
import
CTCLoss
from
deepspeech.modules.loss
import
CTCLoss
from
deepspeech.utils
import
ctc_utils
from
deepspeech.utils
import
ctc_utils
...
@@ -151,7 +150,7 @@ class CTCDecoder(nn.Layer):
...
@@ -151,7 +150,7 @@ class CTCDecoder(nn.Layer):
:type vocab_list: list
:type vocab_list: list
"""
"""
# init once
# init once
if
self
.
_ext_scorer
!=
None
:
if
self
.
_ext_scorer
is
not
None
:
return
return
if
language_model_path
!=
''
:
if
language_model_path
!=
''
:
...
@@ -199,7 +198,7 @@ class CTCDecoder(nn.Layer):
...
@@ -199,7 +198,7 @@ class CTCDecoder(nn.Layer):
:return: List of transcription texts.
:return: List of transcription texts.
:rtype: List of str
:rtype: List of str
"""
"""
if
self
.
_ext_scorer
!=
None
:
if
self
.
_ext_scorer
is
not
None
:
self
.
_ext_scorer
.
reset_params
(
beam_alpha
,
beam_beta
)
self
.
_ext_scorer
.
reset_params
(
beam_alpha
,
beam_beta
)
# beam search decode
# beam search decode
...
...
deepspeech/modules/decoder.py
浏览文件 @
dee672a7
...
@@ -18,8 +18,6 @@ import logging
...
@@ -18,8 +18,6 @@ import logging
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
deepspeech.modules.attention
import
MultiHeadedAttention
from
deepspeech.modules.attention
import
MultiHeadedAttention
from
deepspeech.modules.decoder_layer
import
DecoderLayer
from
deepspeech.modules.decoder_layer
import
DecoderLayer
...
@@ -125,7 +123,7 @@ class TransformerDecoder(nn.Module):
...
@@ -125,7 +123,7 @@ class TransformerDecoder(nn.Module):
m
=
subsequent_mask
(
tgt_mask
.
size
(
-
1
)).
unsqueeze
(
0
)
m
=
subsequent_mask
(
tgt_mask
.
size
(
-
1
)).
unsqueeze
(
0
)
# tgt_mask: (B, L, L)
# tgt_mask: (B, L, L)
# TODO(Hui Zhang): not support & for tensor
# TODO(Hui Zhang): not support & for tensor
#tgt_mask = tgt_mask & m
#
tgt_mask = tgt_mask & m
tgt_mask
=
tgt_mask
.
logical_and
(
m
)
tgt_mask
=
tgt_mask
.
logical_and
(
m
)
x
,
_
=
self
.
embed
(
tgt
)
x
,
_
=
self
.
embed
(
tgt
)
...
@@ -137,8 +135,8 @@ class TransformerDecoder(nn.Module):
...
@@ -137,8 +135,8 @@ class TransformerDecoder(nn.Module):
if
self
.
use_output_layer
:
if
self
.
use_output_layer
:
x
=
self
.
output_layer
(
x
)
x
=
self
.
output_layer
(
x
)
#TODO(Hui Zhang): reduce_sum not support bool type
#
TODO(Hui Zhang): reduce_sum not support bool type
#olens = tgt_mask.sum(1)
#
olens = tgt_mask.sum(1)
olens
=
tgt_mask
.
astype
(
paddle
.
int
).
sum
(
1
)
olens
=
tgt_mask
.
astype
(
paddle
.
int
).
sum
(
1
)
return
x
,
olens
return
x
,
olens
...
...
deepspeech/modules/decoder_layer.py
浏览文件 @
dee672a7
...
@@ -17,8 +17,6 @@ import logging
...
@@ -17,8 +17,6 @@ import logging
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/embedding.py
浏览文件 @
dee672a7
...
@@ -15,13 +15,10 @@
...
@@ -15,13 +15,10 @@
import
math
import
math
import
logging
import
logging
import
numpy
as
np
from
typing
import
Tuple
from
typing
import
Tuple
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/encoder.py
浏览文件 @
dee672a7
...
@@ -18,8 +18,6 @@ from typeguard import check_argument_types
...
@@ -18,8 +18,6 @@ from typeguard import check_argument_types
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
deepspeech.modules.attention
import
MultiHeadedAttention
from
deepspeech.modules.attention
import
MultiHeadedAttention
from
deepspeech.modules.attention
import
RelPositionMultiHeadedAttention
from
deepspeech.modules.attention
import
RelPositionMultiHeadedAttention
...
...
deepspeech/modules/encoder_layer.py
浏览文件 @
dee672a7
...
@@ -17,8 +17,6 @@ import logging
...
@@ -17,8 +17,6 @@ import logging
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/loss.py
浏览文件 @
dee672a7
...
@@ -17,7 +17,6 @@ import logging
...
@@ -17,7 +17,6 @@ import logging
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/mask.py
浏览文件 @
dee672a7
...
@@ -15,9 +15,6 @@
...
@@ -15,9 +15,6 @@
import
logging
import
logging
import
paddle
import
paddle
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/positionwise_feed_forward.py
浏览文件 @
dee672a7
...
@@ -16,8 +16,6 @@ import logging
...
@@ -16,8 +16,6 @@ import logging
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/modules/subsampling.py
浏览文件 @
dee672a7
...
@@ -18,8 +18,6 @@ import logging
...
@@ -18,8 +18,6 @@ import logging
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
from
paddle.nn
import
initializer
as
I
from
deepspeech.modules.embedding
import
PositionalEncoding
from
deepspeech.modules.embedding
import
PositionalEncoding
...
...
deepspeech/training/__init__.py
浏览文件 @
dee672a7
...
@@ -11,5 +11,3 @@
...
@@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
deepspeech.training.trainer
import
*
deepspeech/training/cli.py
浏览文件 @
dee672a7
...
@@ -58,12 +58,15 @@ def default_argument_parser():
...
@@ -58,12 +58,15 @@ def default_argument_parser():
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
help
=
"path of the jit model to save"
)
# running
# running
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'gpu'
,
choices
=
[
"cpu"
,
"gpu"
],
help
=
"device type to use, cpu and gpu are supported."
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'gpu'
,
choices
=
[
"cpu"
,
"gpu"
],
help
=
"device type to use, cpu and gpu are supported."
)
parser
.
add_argument
(
"--nprocs"
,
type
=
int
,
default
=
1
,
help
=
"number of parallel processes to use."
)
parser
.
add_argument
(
"--nprocs"
,
type
=
int
,
default
=
1
,
help
=
"number of parallel processes to use."
)
# overwrite extra config and default config
# overwrite extra config and default config
#parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
# parser.add_argument("--opts", nargs=argparse.REMAINDER,
parser
.
add_argument
(
"--opts"
,
type
=
str
,
default
=
[],
nargs
=
'+'
,
help
=
"options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
# help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser
.
add_argument
(
"--opts"
,
type
=
str
,
default
=
[],
nargs
=
'+'
,
help
=
"options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
# yapd: enable
# yapd: enable
return
parser
return
parser
deepspeech/training/scheduler.py
浏览文件 @
dee672a7
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
import
logging
import
logging
import
paddle
from
paddle.optimizer.lr
import
LRScheduler
from
paddle.optimizer.lr
import
LRScheduler
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
deepspeech/training/trainer.py
浏览文件 @
dee672a7
...
@@ -16,12 +16,9 @@ import time
...
@@ -16,12 +16,9 @@ import time
import
logging
import
logging
import
logging.handlers
import
logging.handlers
from
pathlib
import
Path
from
pathlib
import
Path
import
numpy
as
np
from
collections
import
defaultdict
import
paddle
import
paddle
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
paddle.distributed.utils
import
get_gpus
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
from
deepspeech.utils
import
checkpoint
from
deepspeech.utils
import
checkpoint
...
...
deepspeech/utils/checkpoint.py
浏览文件 @
dee672a7
...
@@ -13,15 +13,12 @@
...
@@ -13,15 +13,12 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
time
import
logging
import
logging
import
numpy
as
np
import
re
import
re
import
json
import
json
import
paddle
import
paddle
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
paddle.nn
import
Layer
from
paddle.optimizer
import
Optimizer
from
paddle.optimizer
import
Optimizer
from
deepspeech.utils
import
mp_tools
from
deepspeech.utils
import
mp_tools
...
...
deepspeech/utils/error_rate.py
浏览文件 @
dee672a7
...
@@ -81,7 +81,7 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
...
@@ -81,7 +81,7 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
:return: Levenshtein distance and word number of reference sentence.
:return: Levenshtein distance and word number of reference sentence.
:rtype: list
:rtype: list
"""
"""
if
ignore_case
==
True
:
if
ignore_case
:
reference
=
reference
.
lower
()
reference
=
reference
.
lower
()
hypothesis
=
hypothesis
.
lower
()
hypothesis
=
hypothesis
.
lower
()
...
@@ -107,12 +107,12 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
...
@@ -107,12 +107,12 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
:return: Levenshtein distance and length of reference sentence.
:return: Levenshtein distance and length of reference sentence.
:rtype: list
:rtype: list
"""
"""
if
ignore_case
==
True
:
if
ignore_case
:
reference
=
reference
.
lower
()
reference
=
reference
.
lower
()
hypothesis
=
hypothesis
.
lower
()
hypothesis
=
hypothesis
.
lower
()
join_char
=
' '
join_char
=
' '
if
remove_space
==
True
:
if
remove_space
:
join_char
=
''
join_char
=
''
reference
=
join_char
.
join
(
list
(
filter
(
None
,
reference
.
split
(
' '
))))
reference
=
join_char
.
join
(
list
(
filter
(
None
,
reference
.
split
(
' '
))))
...
...
deepspeech/utils/layer_tools.py
浏览文件 @
dee672a7
...
@@ -51,7 +51,7 @@ def recursively_remove_weight_norm(layer: nn.Layer):
...
@@ -51,7 +51,7 @@ def recursively_remove_weight_norm(layer: nn.Layer):
for
layer
in
layer
.
sublayers
():
for
layer
in
layer
.
sublayers
():
try
:
try
:
nn
.
utils
.
remove_weight_norm
(
layer
)
nn
.
utils
.
remove_weight_norm
(
layer
)
except
:
except
ValueError
as
e
:
# ther is not weight norm hoom in this layer
# ther is not weight norm hoom in this layer
pass
pass
...
...
deepspeech/utils/mp_tools.py
浏览文件 @
dee672a7
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
paddle
from
paddle
import
distributed
as
dist
from
paddle
import
distributed
as
dist
from
functools
import
wraps
from
functools
import
wraps
...
...
deepspeech/utils/tensor_utils.py
浏览文件 @
dee672a7
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Unility functions for Transformer."""
"""Unility functions for Transformer."""
import
math
import
logging
import
logging
from
typing
import
Tuple
,
List
from
typing
import
Tuple
,
List
...
...
deepspeech/utils/utility.py
浏览文件 @
dee672a7
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
"""Contains common utility functions."""
"""Contains common utility functions."""
import
math
import
math
import
numpy
as
np
import
distutils.util
import
distutils.util
from
typing
import
List
from
typing
import
List
...
...
examples/dataset/aishell/aishell.py
浏览文件 @
dee672a7
...
@@ -55,7 +55,8 @@ def create_manifest(data_dir, manifest_path_prefix):
...
@@ -55,7 +55,8 @@ def create_manifest(data_dir, manifest_path_prefix):
transcript_dict
=
{}
transcript_dict
=
{}
for
line
in
codecs
.
open
(
transcript_path
,
'r'
,
'utf-8'
):
for
line
in
codecs
.
open
(
transcript_path
,
'r'
,
'utf-8'
):
line
=
line
.
strip
()
line
=
line
.
strip
()
if
line
==
''
:
continue
if
line
==
''
:
continue
audio_id
,
text
=
line
.
split
(
' '
,
1
)
audio_id
,
text
=
line
.
split
(
' '
,
1
)
# remove withespace
# remove withespace
text
=
''
.
join
(
text
.
split
())
text
=
''
.
join
(
text
.
split
())
...
@@ -82,7 +83,7 @@ def create_manifest(data_dir, manifest_path_prefix):
...
@@ -82,7 +83,7 @@ def create_manifest(data_dir, manifest_path_prefix):
os
.
path
.
splitext
(
os
.
path
.
basename
(
audio_path
))[
0
],
os
.
path
.
splitext
(
os
.
path
.
basename
(
audio_path
))[
0
],
'feat'
:
'feat'
:
audio_path
,
audio_path
,
'feat_shape'
:
(
duration
,
),
#second
'feat_shape'
:
(
duration
,
),
#
second
'text'
:
'text'
:
text
text
},
},
...
...
examples/dataset/chime3_background/chime3_background.py
浏览文件 @
dee672a7
...
@@ -19,7 +19,6 @@ meta data (i.e. audio filepath, transcript and audio duration)
...
@@ -19,7 +19,6 @@ meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
of each audio file in the data set.
"""
"""
import
distutils.util
import
os
import
os
import
wget
import
wget
import
zipfile
import
zipfile
...
@@ -29,7 +28,7 @@ import json
...
@@ -29,7 +28,7 @@ import json
import
io
import
io
from
paddle.v2.dataset.common
import
md5file
from
paddle.v2.dataset.common
import
md5file
#DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
#
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
DATA_HOME
=
os
.
path
.
expanduser
(
'.'
)
DATA_HOME
=
os
.
path
.
expanduser
(
'.'
)
URL
=
"https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ"
URL
=
"https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ"
...
@@ -51,9 +50,10 @@ args = parser.parse_args()
...
@@ -51,9 +50,10 @@ args = parser.parse_args()
def
download
(
url
,
md5sum
,
target_dir
,
filename
=
None
):
def
download
(
url
,
md5sum
,
target_dir
,
filename
=
None
):
"""Download file from url to target_dir, and check md5sum."""
"""Download file from url to target_dir, and check md5sum."""
if
filename
==
None
:
if
filename
is
None
:
filename
=
url
.
split
(
"/"
)[
-
1
]
filename
=
url
.
split
(
"/"
)[
-
1
]
if
not
os
.
path
.
exists
(
target_dir
):
os
.
makedirs
(
target_dir
)
if
not
os
.
path
.
exists
(
target_dir
):
os
.
makedirs
(
target_dir
)
filepath
=
os
.
path
.
join
(
target_dir
,
filename
)
filepath
=
os
.
path
.
join
(
target_dir
,
filename
)
if
not
(
os
.
path
.
exists
(
filepath
)
and
md5file
(
filepath
)
==
md5sum
):
if
not
(
os
.
path
.
exists
(
filepath
)
and
md5file
(
filepath
)
==
md5sum
):
print
(
"Downloading %s ..."
%
url
)
print
(
"Downloading %s ..."
%
url
)
...
@@ -100,7 +100,7 @@ def create_manifest(data_dir, manifest_path):
...
@@ -100,7 +100,7 @@ def create_manifest(data_dir, manifest_path):
'utt'
:
os
.
path
.
splitext
(
os
.
path
.
basename
(
filepath
))[
'utt'
:
os
.
path
.
splitext
(
os
.
path
.
basename
(
filepath
))[
0
],
0
],
'feat'
:
filepath
,
'feat'
:
filepath
,
'feat_shape'
:
(
duration
,
),
#second
'feat_shape'
:
(
duration
,
),
#
second
'type'
:
'background'
'type'
:
'background'
}))
}))
with
io
.
open
(
manifest_path
,
mode
=
'w'
,
encoding
=
'utf8'
)
as
out_file
:
with
io
.
open
(
manifest_path
,
mode
=
'w'
,
encoding
=
'utf8'
)
as
out_file
:
...
...
examples/dataset/librispeech/librispeech.py
浏览文件 @
dee672a7
...
@@ -21,7 +21,6 @@ of each audio file in the data set.
...
@@ -21,7 +21,6 @@ of each audio file in the data set.
import
distutils.util
import
distutils.util
import
os
import
os
import
sys
import
argparse
import
argparse
import
soundfile
import
soundfile
import
json
import
json
...
...
examples/dataset/mini_librispeech/mini_librispeech.py
浏览文件 @
dee672a7
...
@@ -19,9 +19,7 @@ meta data (i.e. audio filepath, transcript and audio duration)
...
@@ -19,9 +19,7 @@ meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
of each audio file in the data set.
"""
"""
import
distutils.util
import
os
import
os
import
sys
import
argparse
import
argparse
import
soundfile
import
soundfile
import
json
import
json
...
...
examples/dataset/rir_noise/rir_noise.py
浏览文件 @
dee672a7
...
@@ -27,7 +27,7 @@ import codecs
...
@@ -27,7 +27,7 @@ import codecs
import
soundfile
import
soundfile
import
json
import
json
import
argparse
import
argparse
from
utils.utility
import
download
,
un
pack
,
un
zip
from
utils.utility
import
download
,
unzip
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset/speech'
)
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset/speech'
)
...
...
examples/dataset/voxforge/voxforge.py
浏览文件 @
dee672a7
tests/deepspeech2_model_test.py
浏览文件 @
dee672a7
...
@@ -26,11 +26,11 @@ class TestDeepSpeech2Model(unittest.TestCase):
...
@@ -26,11 +26,11 @@ class TestDeepSpeech2Model(unittest.TestCase):
self
.
feat_dim
=
161
self
.
feat_dim
=
161
max_len
=
64
max_len
=
64
#(B, T, D)
#
(B, T, D)
audio
=
np
.
random
.
randn
(
self
.
batch_size
,
max_len
,
self
.
feat_dim
)
audio
=
np
.
random
.
randn
(
self
.
batch_size
,
max_len
,
self
.
feat_dim
)
audio_len
=
np
.
random
.
randint
(
max_len
,
size
=
self
.
batch_size
)
audio_len
=
np
.
random
.
randint
(
max_len
,
size
=
self
.
batch_size
)
audio_len
[
-
1
]
=
max_len
audio_len
[
-
1
]
=
max_len
#(B, U)
#
(B, U)
text
=
np
.
array
([[
1
,
2
],
[
1
,
2
]])
text
=
np
.
array
([[
1
,
2
],
[
1
,
2
]])
text_len
=
np
.
array
([
2
]
*
self
.
batch_size
)
text_len
=
np
.
array
([
2
]
*
self
.
batch_size
)
...
...
utils/build_vocab.py
浏览文件 @
dee672a7
...
@@ -17,10 +17,8 @@ Each item in vocabulary file is a character.
...
@@ -17,10 +17,8 @@ Each item in vocabulary file is a character.
import
argparse
import
argparse
import
functools
import
functools
import
json
from
collections
import
Counter
from
collections
import
Counter
import
os
import
os
import
copy
import
tempfile
import
tempfile
from
deepspeech.frontend.utility
import
read_manifest
from
deepspeech.frontend.utility
import
read_manifest
...
@@ -48,10 +46,8 @@ add_arg('manifest_paths', str,
...
@@ -48,10 +46,8 @@ add_arg('manifest_paths', str,
required
=
True
)
required
=
True
)
# bpe
# bpe
add_arg
(
'vocab_size'
,
int
,
0
,
"Vocab size for spm."
)
add_arg
(
'vocab_size'
,
int
,
0
,
"Vocab size for spm."
)
add_arg
(
'spm_mode'
,
str
,
'unigram'
,
add_arg
(
'spm_mode'
,
str
,
'unigram'
,
"spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm"
)
"spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm"
)
add_arg
(
'spm_model_prefix'
,
str
,
"spm_model_%(spm_mode)_%(count_threshold)"
,
"spm model prefix, only need when `unit_type` is spm"
)
add_arg
(
'spm_model_prefix'
,
str
,
"spm_model_%(spm_mode)_%(count_threshold)"
,
"spm model prefix, only need when `unit_type` is spm"
)
# yapf: disable
# yapf: disable
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -104,7 +100,8 @@ def main():
...
@@ -104,7 +100,8 @@ def main():
count_sorted
=
sorted
(
counter
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
count_sorted
=
sorted
(
counter
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
tokens
=
[]
tokens
=
[]
for
token
,
count
in
count_sorted
:
for
token
,
count
in
count_sorted
:
if
count
<
args
.
count_threshold
:
break
if
count
<
args
.
count_threshold
:
break
tokens
.
append
(
token
)
tokens
.
append
(
token
)
tokens
=
sorted
(
tokens
)
tokens
=
sorted
(
tokens
)
...
...
utils/format_data.py
浏览文件 @
dee672a7
...
@@ -15,15 +15,8 @@
...
@@ -15,15 +15,8 @@
import
argparse
import
argparse
import
functools
import
functools
import
json
import
json
from
collections
import
Counter
import
os
import
copy
import
tempfile
from
deepspeech.frontend.utility
import
read_manifest
from
deepspeech.frontend.utility
import
read_manifest
from
deepspeech.frontend.utility
import
UNK
from
deepspeech.frontend.utility
import
BLANK
from
deepspeech.frontend.utility
import
SOS
from
deepspeech.frontend.utility
import
load_cmvn
from
deepspeech.frontend.utility
import
load_cmvn
from
deepspeech.utils.utility
import
add_arguments
from
deepspeech.utils.utility
import
add_arguments
from
deepspeech.utils.utility
import
print_arguments
from
deepspeech.utils.utility
import
print_arguments
...
@@ -82,7 +75,7 @@ def main():
...
@@ -82,7 +75,7 @@ def main():
if
args
.
feat_type
==
'raw'
:
if
args
.
feat_type
==
'raw'
:
feat_shape
.
append
(
feat_dim
)
feat_shape
.
append
(
feat_dim
)
else
:
# kaldi
else
:
# kaldi
raise
NotImplemented
(
'no support kaldi feat now!'
)
raise
NotImplemented
Error
(
'no support kaldi feat now!'
)
fout
.
write
(
json
.
dumps
(
line_json
)
+
'
\n
'
)
fout
.
write
(
json
.
dumps
(
line_json
)
+
'
\n
'
)
count
+=
1
count
+=
1
...
...
utils/utility.py
浏览文件 @
dee672a7
...
@@ -30,7 +30,8 @@ def getfile_insensitive(path):
...
@@ -30,7 +30,8 @@ def getfile_insensitive(path):
def
download_multi
(
url
,
target_dir
,
extra_args
):
def
download_multi
(
url
,
target_dir
,
extra_args
):
"""Download multiple files from url to target_dir."""
"""Download multiple files from url to target_dir."""
if
not
os
.
path
.
exists
(
target_dir
):
os
.
makedirs
(
target_dir
)
if
not
os
.
path
.
exists
(
target_dir
):
os
.
makedirs
(
target_dir
)
print
(
"Downloading %s ..."
%
url
)
print
(
"Downloading %s ..."
%
url
)
ret_code
=
os
.
system
(
"wget -c "
+
url
+
' '
+
extra_args
+
" -P "
+
ret_code
=
os
.
system
(
"wget -c "
+
url
+
' '
+
extra_args
+
" -P "
+
target_dir
)
target_dir
)
...
@@ -39,7 +40,8 @@ def download_multi(url, target_dir, extra_args):
...
@@ -39,7 +40,8 @@ def download_multi(url, target_dir, extra_args):
def
download
(
url
,
md5sum
,
target_dir
):
def
download
(
url
,
md5sum
,
target_dir
):
"""Download file from url to target_dir, and check md5sum."""
"""Download file from url to target_dir, and check md5sum."""
if
not
os
.
path
.
exists
(
target_dir
):
os
.
makedirs
(
target_dir
)
if
not
os
.
path
.
exists
(
target_dir
):
os
.
makedirs
(
target_dir
)
filepath
=
os
.
path
.
join
(
target_dir
,
url
.
split
(
"/"
)[
-
1
])
filepath
=
os
.
path
.
join
(
target_dir
,
url
.
split
(
"/"
)[
-
1
])
if
not
(
os
.
path
.
exists
(
filepath
)
and
md5file
(
filepath
)
==
md5sum
):
if
not
(
os
.
path
.
exists
(
filepath
)
and
md5file
(
filepath
)
==
md5sum
):
print
(
"Downloading %s ..."
%
url
)
print
(
"Downloading %s ..."
%
url
)
...
@@ -58,7 +60,7 @@ def unpack(filepath, target_dir, rm_tar=False):
...
@@ -58,7 +60,7 @@ def unpack(filepath, target_dir, rm_tar=False):
tar
=
tarfile
.
open
(
filepath
)
tar
=
tarfile
.
open
(
filepath
)
tar
.
extractall
(
target_dir
)
tar
.
extractall
(
target_dir
)
tar
.
close
()
tar
.
close
()
if
rm_tar
==
True
:
if
rm_tar
:
os
.
remove
(
filepath
)
os
.
remove
(
filepath
)
...
@@ -68,5 +70,5 @@ def unzip(filepath, target_dir, rm_tar=False):
...
@@ -68,5 +70,5 @@ def unzip(filepath, target_dir, rm_tar=False):
tar
=
zipfile
.
ZipFile
(
filepath
,
'r'
)
tar
=
zipfile
.
ZipFile
(
filepath
,
'r'
)
tar
.
extractall
(
target_dir
)
tar
.
extractall
(
target_dir
)
tar
.
close
()
tar
.
close
()
if
rm_tar
==
True
:
if
rm_tar
:
os
.
remove
(
filepath
)
os
.
remove
(
filepath
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录