PaddlePaddle / DeepSpeech

Commit eb4b3892
Authored on Oct 09, 2021 by Hui Zhang
more log; refactor ctc decoders; rm useless code
Parent: 26910132
Showing 13 changed files with 110 additions and 103 deletions (+110, -103)
deepspeech/decoders/swig/ctc_beam_search_decoder.cpp (+0, -1)
deepspeech/decoders/swig/decoder_utils.h (+2, -0)
deepspeech/decoders/swig/scorer.cpp (+0, -1)
deepspeech/exps/deepspeech2/bin/export.py (+1, -1)
deepspeech/exps/deepspeech2/bin/test_hub.py (+1, -1)
deepspeech/frontend/augmentor/augmentation.py (+8, -3)
deepspeech/frontend/augmentor/spec_augment.py (+4, -4)
deepspeech/frontend/featurizer/speech_featurizer.py (+2, -0)
deepspeech/frontend/featurizer/text_featurizer.py (+21, -2)
deepspeech/frontend/utility.py (+5, -1)
deepspeech/models/ds2/deepspeech2.py (+24, -34)
deepspeech/models/ds2_online/deepspeech2.py (+41, -55)
deepspeech/utils/log.py (+1, -0)
deepspeech/decoders/swig/ctc_beam_search_decoder.cpp

@@ -28,7 +28,6 @@
 #include "path_trie.h"
 using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
-const std::string kSPACE = "<space>";
 std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     const std::vector<std::vector<double>> &probs_seq,
deepspeech/decoders/swig/decoder_utils.h

@@ -15,10 +15,12 @@
 #ifndef DECODER_UTILS_H_
 #define DECODER_UTILS_H_
+#include <string>
 #include <utility>
 #include "fst/log.h"
 #include "path_trie.h"
+const std::string kSPACE = "<space>";
 const float NUM_FLT_INF = std::numeric_limits<float>::max();
 const float NUM_FLT_MIN = std::numeric_limits<float>::min();
deepspeech/decoders/swig/scorer.cpp

@@ -26,7 +26,6 @@
 #include "decoder_utils.h"
 using namespace lm::ngram;
-const std::string kSPACE = "<space>";
 Scorer::Scorer(double alpha,
                double beta,
deepspeech/exps/deepspeech2/bin/export.py

@@ -34,7 +34,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--export_path", type=str, help="path of the jit model to save")
     parser.add_argument(
-        "--model_type", type=str, default='offline', help='offline/online')
+        "--model_type", type=str, default='offline', help="offline/online")
     args = parser.parse_args()
     print("model_type:{}".format(args.model_type))
     print_arguments(args)
deepspeech/exps/deepspeech2/bin/test_hub.py

@@ -179,7 +179,7 @@ if __name__ == "__main__":
     parser = default_argument_parser()
     parser.add_argument(
         "--model_type", type=str, default='offline', help='offline/online')
-    parser.add_argument("--audio_file", type=str, help='audio file path.')
+    parser.add_argument("--audio_file", type=str, help='audio file path')
     # save asr result to
     parser.add_argument(
         "--result_file", type=str, help="path of save the asr result")
deepspeech/frontend/augmentor/augmentation.py

@@ -15,6 +15,7 @@
 import json
 from collections.abc import Sequence
 from inspect import signature
+from pprint import pformat

 import numpy as np

@@ -22,10 +23,10 @@ from deepspeech.frontend.augmentor.base import AugmentorBase
 from deepspeech.utils.dynamic_import import dynamic_import
 from deepspeech.utils.log import Log

-__all__ = ["AugmentationPipeline"]
-
 logger = Log(__name__).getlog()

+__all__ = ["AugmentationPipeline"]
+
 import_alias = dict(
     volume="deepspeech.frontend.augmentor.impulse_response:VolumePerturbAugmentor",
     shift="deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor",

@@ -111,6 +112,8 @@ class AugmentationPipeline():
             'audio')
         self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
             'feature')
+        logger.info(
+            f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")

     def __call__(self, xs, uttid_list=None, **kwargs):
         if not isinstance(xs, Sequence):

@@ -197,8 +200,10 @@ class AugmentationPipeline():
             aug_confs = audio_confs
         elif aug_type == 'feature':
             aug_confs = feature_confs
-        else:
+        elif aug_type == 'all':
             aug_confs = all_confs
+        else:
+            raise ValueError(f"Not support: {aug_type}")
         augmentors = [
             self._get_augmentor(config["type"], config["params"])
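For reference, a small stand-alone sketch (not taken from the repository) of the behaviour the last hunk changes in the pipeline parsing: an explicit 'all' branch is added, so an unrecognised aug_type now raises instead of silently selecting every config.

```python
# Stand-alone sketch; function name and arguments are placeholders, only the
# branch structure mirrors the hunk above.
def select_confs(aug_type, audio_confs, feature_confs, all_confs):
    if aug_type == 'audio':
        return audio_confs
    elif aug_type == 'feature':
        return feature_confs
    elif aug_type == 'all':
        return all_confs
    else:
        raise ValueError(f"Not support: {aug_type}")


print(select_confs('all', [], [], [{"type": "specaug"}]))  # [{'type': 'specaug'}]
print(select_confs('audio', [{"type": "shift"}], [], []))  # [{'type': 'shift'}]
```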
deepspeech/frontend/augmentor/spec_augment.py

@@ -133,7 +133,7 @@ class SpecAugmentor(AugmentorBase):
         return self._time_mask

     def __repr__(self):
-        return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}"
+        return f"specaug: F-{self.F}, T-{self.T}, F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}"

     def time_warp(self, x, mode='PIL'):
         """time warp for spec augment
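A tiny illustration (placeholder class, not the project's SpecAugmentor) of the bug this hunk fixes: inside __repr__, bare names such as F are not in scope, so the old f-string raised NameError at repr() time; qualifying them with self. works.

```python
# Placeholder class; only the __repr__ pattern mirrors the fixed line above.
class SpecAugSketch:
    def __init__(self):
        self.F, self.T = 10, 50
        self.n_freq_masks, self.n_time_masks = 2, 2

    def __repr__(self):
        # the old version referenced bare F, T, ... and raised NameError here
        return (f"specaug: F-{self.F}, T-{self.T}, "
                f"F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}")


print(repr(SpecAugSketch()))  # specaug: F-10, T-50, F-n-2, T-n-2
```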
deepspeech/frontend/featurizer/speech_featurizer.py

@@ -51,12 +51,14 @@ class SpeechFeaturizer():
             use_dB_normalization=use_dB_normalization,
             target_dB=target_dB,
             dither=dither)
+        self.feature_size = self.audio_feature.feature_size

         self.text_feature = TextFeaturizer(
             unit_type=unit_type,
             vocab_filepath=vocab_filepath,
             spm_model_prefix=spm_model_prefix,
             maskctc=maskctc)
+        self.vocab_size = self.text_feature.vocab_size

     def featurize(self, speech_segment, keep_transcription_text):
         """Extract features for speech segment.
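A minimal sketch (placeholder classes, not the real featurizers) of the convenience attributes this hunk adds: the composite featurizer re-exports its children's sizes so callers can read feature_size and vocab_size directly instead of reaching into audio_feature and text_feature.

```python
# Placeholder sub-featurizers; the attribute forwarding mirrors the hunk above.
class _AudioFeaturizerStub:
    feature_size = 161  # illustrative feature dimension


class _TextFeaturizerStub:
    vocab_size = 29  # illustrative vocabulary size


class SpeechFeaturizerSketch:
    def __init__(self):
        self.audio_feature = _AudioFeaturizerStub()
        self.text_feature = _TextFeaturizerStub()
        # new in this commit: expose the sizes at the top level
        self.feature_size = self.audio_feature.feature_size
        self.vocab_size = self.text_feature.vocab_size


sf = SpeechFeaturizerSketch()
print(sf.feature_size, sf.vocab_size)  # 161 29
```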
deepspeech/frontend/featurizer/text_featurizer.py

@@ -12,12 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains the text featurizer class."""
+from pprint import pformat
+
 import sentencepiece as spm

+from ..utility import BLANK
 from ..utility import EOS
 from ..utility import load_dict
+from ..utility import MASKCTC
+from ..utility import SOS
 from ..utility import SPACE
 from ..utility import UNK
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()

 __all__ = ["TextFeaturizer"]

@@ -76,7 +84,7 @@ class TextFeaturizer():
         """Convert text string to a list of token indices.

         Args:
-            text (str): Text.
+            text (str): Text to process.

         Returns:
             List[int]: List of token indices.

@@ -199,13 +207,24 @@ class TextFeaturizer():
         """Load vocabulary from file."""
         vocab_list = load_dict(vocab_filepath, maskctc)
         assert vocab_list is not None
+        logger.info(f"Vocab: {pformat(vocab_list)}")

         id2token = dict(
             [(idx, token) for (idx, token) in enumerate(vocab_list)])
         token2id = dict(
             [(token, idx) for (idx, token) in enumerate(vocab_list)])

+        blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
+        maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
         unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
         eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
+        sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
+        space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
+        logger.info(f"BLANK id: {blank_id}")
+        logger.info(f"UNK id: {unk_id}")
+        logger.info(f"EOS id: {eos_id}")
+        logger.info(f"SOS id: {sos_id}")
+        logger.info(f"SPACE id: {space_id}")
+        logger.info(f"MASKCTC id: {maskctc_id}")

         return token2id, id2token, vocab_list, unk_id, eos_id
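The id-resolution pattern added in the last hunk is the same for every special token: take its index in the vocab list, or -1 when it is absent. A hedged sketch with placeholder token strings (the real constants are imported from deepspeech.frontend.utility):

```python
# Placeholder token values for illustration only; the project imports BLANK,
# UNK, SOS, etc. from deepspeech.frontend.utility.
BLANK, UNK, SOS = "<blank>", "<unk>", "<sos>"

vocab_list = [BLANK, "<space>", "a", "b", UNK]

blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1  # missing -> -1

print(blank_id, unk_id, sos_id)  # 0 4 -1
```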
deepspeech/frontend/utility.py

@@ -49,7 +49,11 @@ def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
     with open(dict_path, "r") as f:
         dictionary = f.readlines()
-    char_list = [entry.strip().split(" ")[0] for entry in dictionary]
+    # first token is `<blank>`
+    # multi line: `<blank> 0\n`
+    # one line: `<blank>`
+    # space is relpace with <space>
+    char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
     if BLANK not in char_list:
         char_list.insert(0, BLANK)
     if EOS not in char_list:
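A minimal sketch of how the new parsing line handles the two dictionary layouts described in the added comments (bare token per line, or "token id" pairs); the sample data is illustrative only.

```python
# Illustrative vocab-file contents; entry[:-1] drops the trailing newline and
# split(" ")[0] keeps only the token, so both layouts yield the same list.
lines_with_ids = ["<blank> 0\n", "<space> 1\n", "a 2\n"]   # "multi line" layout
lines_token_only = ["<blank>\n", "<space>\n", "a\n"]       # "one line" layout

for dictionary in (lines_with_ids, lines_token_only):
    char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
    print(char_list)  # ['<blank>', '<space>', 'a'] for both layouts
```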
deepspeech/models/ds2/deepspeech2.py

@@ -218,14 +218,18 @@ class DeepSpeech2Model(nn.Layer):
         DeepSpeech2Model
             The model built from pretrained result.
         """
-        model = cls(feat_size=dataloader.collate_fn.feature_size,
-                    dict_size=dataloader.collate_fn.vocab_size,
+        model = cls(
+            #feat_size=dataloader.collate_fn.feature_size,
+            feat_size=dataloader.dataset.feature_size,
+            #dict_size=dataloader.collate_fn.vocab_size,
+            dict_size=dataloader.dataset.vocab_size,
             num_conv_layers=config.model.num_conv_layers,
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
             use_gru=config.model.use_gru,
             share_rnn_weights=config.model.share_rnn_weights,
-            blank_id=config.model.blank_id)
+            blank_id=config.model.blank_id,
+            ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
         infos = Checkpoint().load_parameters(
             model, checkpoint_path=checkpoint_path)
         logger.info(f"checkpoint info: {infos}")

@@ -244,36 +248,22 @@ class DeepSpeech2Model(nn.Layer):
         DeepSpeech2Model
             The model built from config.
         """
         model = cls(feat_size=config.feat_size,
                     dict_size=config.dict_size,
                     num_conv_layers=config.num_conv_layers,
                     num_rnn_layers=config.num_rnn_layers,
                     rnn_size=config.rnn_layer_size,
                     use_gru=config.use_gru,
                     share_rnn_weights=config.share_rnn_weights,
-                    blank_id=config.blank_id)
+                    blank_id=config.blank_id,
+                    ctc_grad_norm_type=config.ctc_grad_norm_type, )
         return model


 class DeepSpeech2InferModel(DeepSpeech2Model):
-    def __init__(self,
-                 feat_size,
-                 dict_size,
-                 num_conv_layers=2,
-                 num_rnn_layers=3,
-                 rnn_size=1024,
-                 use_gru=False,
-                 share_rnn_weights=True,
-                 blank_id=0):
-        super().__init__(feat_size=feat_size,
-                         dict_size=dict_size,
-                         num_conv_layers=num_conv_layers,
-                         num_rnn_layers=num_rnn_layers,
-                         rnn_size=rnn_size,
-                         use_gru=use_gru,
-                         share_rnn_weights=share_rnn_weights,
-                         blank_id=blank_id)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

     def forward(self, audio, audio_len):
         """export model function
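A hedged sketch (placeholder classes) of why the infer-model __init__ is collapsed to *args, **kwargs: the subclass no longer repeats the parent's parameter list, so a newly added parent argument such as ctc_grad_norm_type is forwarded automatically. The same refactor is applied to the online model below.

```python
# Placeholder classes; only the forwarding pattern mirrors the diff above.
class BaseModelSketch:
    def __init__(self, feat_size, dict_size, blank_id=0,
                 ctc_grad_norm_type='instance'):
        self.feat_size = feat_size
        self.dict_size = dict_size
        self.blank_id = blank_id
        self.ctc_grad_norm_type = ctc_grad_norm_type


class InferModelSketch(BaseModelSketch):
    def __init__(self, *args, **kwargs):
        # any parent keyword, present or future, passes straight through
        super().__init__(*args, **kwargs)


m = InferModelSketch(feat_size=161, dict_size=29, ctc_grad_norm_type='batch')
print(m.ctc_grad_norm_type)  # batch
```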
deepspeech/models/ds2_online/deepspeech2.py

@@ -255,12 +255,13 @@ class DeepSpeech2ModelOnline(nn.Layer):
                 fc_layers_size_list=[512, 256],
                 use_gru=True,  #Use gru if set True. Use simple rnn if set False.
                 blank_id=0,  # index of blank in vocob.txt
+                ctc_grad_norm_type='instance',
             ))
         if config is not None:
             config.merge_from_other_cfg(default)
         return default

     def __init__(self,
                  feat_size,
                  dict_size,
                  num_conv_layers=2,

@@ -270,7 +271,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
                  num_fc_layers=2,
                  fc_layers_size_list=[512, 256],
                  use_gru=False,
-                 blank_id=0):
+                 blank_id=0,
+                 ctc_grad_norm_type='instance', ):
         super().__init__()
         self.encoder = CRNNEncoder(
             feat_size=feat_size,

@@ -290,7 +292,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
             dropout_rate=0.0,
             reduction=True,  # sum
             batch_average=True,  # sum / batch_size
-            grad_norm_type='instance')
+            grad_norm_type=ctc_grad_norm_type)

     def forward(self, audio, audio_len, text, text_len):
         """Compute Model loss

@@ -348,7 +350,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
         DeepSpeech2ModelOnline
             The model built from pretrained result.
         """
         model = cls(feat_size=dataloader.collate_fn.feature_size,
                     dict_size=dataloader.collate_fn.vocab_size,
                     num_conv_layers=config.model.num_conv_layers,
                     num_rnn_layers=config.model.num_rnn_layers,

@@ -357,7 +360,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
                     num_fc_layers=config.model.num_fc_layers,
                     fc_layers_size_list=config.model.fc_layers_size_list,
                     use_gru=config.model.use_gru,
-                    blank_id=config.model.blank_id)
+                    blank_id=config.model.blank_id,
+                    ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
         infos = Checkpoint().load_parameters(
             model, checkpoint_path=checkpoint_path)
         logger.info(f"checkpoint info: {infos}")

@@ -376,7 +380,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
         DeepSpeech2ModelOnline
             The model built from config.
         """
         model = cls(feat_size=config.feat_size,
                     dict_size=config.dict_size,
                     num_conv_layers=config.num_conv_layers,
                     num_rnn_layers=config.num_rnn_layers,

@@ -385,33 +390,14 @@ class DeepSpeech2ModelOnline(nn.Layer):
                     num_fc_layers=config.num_fc_layers,
                     fc_layers_size_list=config.fc_layers_size_list,
                     use_gru=config.use_gru,
-                    blank_id=config.blank_id)
+                    blank_id=config.blank_id,
+                    ctc_grad_norm_type=config.ctc_grad_norm_type, )
         return model


 class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
-    def __init__(self,
-                 feat_size,
-                 dict_size,
-                 num_conv_layers=2,
-                 num_rnn_layers=4,
-                 rnn_size=1024,
-                 rnn_direction='forward',
-                 num_fc_layers=2,
-                 fc_layers_size_list=[512, 256],
-                 use_gru=False,
-                 blank_id=0):
-        super().__init__(
-            feat_size=feat_size,
-            dict_size=dict_size,
-            num_conv_layers=num_conv_layers,
-            num_rnn_layers=num_rnn_layers,
-            rnn_size=rnn_size,
-            rnn_direction=rnn_direction,
-            num_fc_layers=num_fc_layers,
-            fc_layers_size_list=fc_layers_size_list,
-            use_gru=use_gru,
-            blank_id=blank_id)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

     def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
                 chunk_state_c_box):
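A short sketch (stub loss class, illustrative names) of how the new ctc_grad_norm_type option threads through the online model: it enters as a constructor argument whose default is the previously hard-coded 'instance' and ends up as the CTC loss's grad_norm_type.

```python
# Stub standing in for the CTC loss; only the parameter threading is the point.
class _CTCLossStub:
    def __init__(self, grad_norm_type):
        self.grad_norm_type = grad_norm_type


class OnlineModelSketch:
    def __init__(self, blank_id=0, ctc_grad_norm_type='instance'):
        # previously grad_norm_type was hard-coded to 'instance'
        self.loss = _CTCLossStub(grad_norm_type=ctc_grad_norm_type)


print(OnlineModelSketch().loss.grad_norm_type)                            # instance
print(OnlineModelSketch(ctc_grad_norm_type='batch').loss.grad_norm_type)  # batch
```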
deepspeech/utils/log.py

@@ -127,6 +127,7 @@ class Autolog:
         else:
             gpu_id = None
         infer_config = inference.Config()
         self.autolog = auto_log.AutoLogger(
             model_name=model_name,
             model_precision=model_precision,