Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
090e7947
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
090e7947
编写于
4月 12, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
model init from config
上级
a7244593
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
97 addition
and
69 deletion
+97
-69
.notebook/dataloader_with_tokens_tokenids.ipynb
.notebook/dataloader_with_tokens_tokenids.ipynb
+7
-28
deepspeech/models/u2.py
deepspeech/models/u2.py
+52
-34
examples/tiny/s1/conf/conformer.yaml
examples/tiny/s1/conf/conformer.yaml
+35
-3
tests/u2_model_test.py
tests/u2_model_test.py
+3
-4
未找到文件。
.notebook/dataloader_with_tokens_tokenids.ipynb
浏览文件 @
090e7947
...
...
@@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "
downtown-invalid
",
"id": "
medieval-monday
",
"metadata": {},
"outputs": [
{
...
...
@@ -213,27 +213,6 @@
}
],
"source": [
"# batch_reader = create_dataloader(\n",
"# manifest_path=args.infer_manifest,\n",
"# vocab_filepath=args.vocab_path,\n",
"# mean_std_filepath=args.mean_std_path,\n",
"# augmentation_config='{}',\n",
"# #max_duration=float('inf'),\n",
"# max_duration=27.0,\n",
"# min_duration=0.0,\n",
"# stride_ms=10.0,\n",
"# window_ms=20.0,\n",
"# max_freq=None,\n",
"# specgram_type=args.specgram_type,\n",
"# use_dB_normalization=True,\n",
"# random_seed=0,\n",
"# keep_transcription_text=True,\n",
"# is_training=False,\n",
"# batch_size=args.num_samples,\n",
"# sortagrad=True,\n",
"# shuffle_method=None,\n",
"# dist=False)\n",
"\n",
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n",
"from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n",
...
...
@@ -375,7 +354,7 @@
},
{
"cell_type": "code",
"execution_count":
9
,
"execution_count":
6
,
"id": "minus-modern",
"metadata": {},
"outputs": [
...
...
@@ -391,8 +370,6 @@
" [97, 37, 26, 79, 26, 1, 38, 82, 1, 58, 102, 1, 17, 79, 64, 87, 37, 26, 79, 1, 61, 64, 97]])\n",
"test raw: W%\u001a\u0001Wa\u001a=W&\u001aR\n",
"test raw: a%\u001aO\u001a\u0001&R\u0001:f\u0001\u0011O@W%\u001aO\u0001=@a\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 173, 184, 190, 203])\n",
"test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
" [12, 13, 11, 22, 23])\n",
"audio: Tensor(shape=[5, 203, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n",
...
...
@@ -434,7 +411,9 @@
" ...,\n",
" [-4.81728077 , -10.65084648, 0.00000000 , ..., 3.19982862 , 8.42359638 , 7.95100546 ],\n",
" [-7.54755068 , -12.56441689, 0.00000000 , ..., 4.12789631 , 6.98472023 , 7.79936218 ],\n",
" [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n"
" [-8.79256725 , -11.23776722, 0.00000000 , ..., 1.31829071 , 1.30352044 , 6.80789280 ]]])\n",
"audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n",
" [163, 173, 184, 190, 203])\n"
]
}
],
...
...
@@ -472,16 +451,16 @@
" print('test:', text)\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
" print('audio len:', audio_len)\n",
" print('test len:', text_len)\n",
" print('audio:', audio)\n",
" print('audio len:', audio_len)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c
hronic-diagram
",
"id": "c
ompetitive-mounting
",
"metadata": {},
"outputs": [],
"source": []
...
...
deepspeech/models/u2.py
浏览文件 @
090e7947
...
...
@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
__all__
=
[
'U2TransformerModel'
,
"U2ConformerModel"
]
class
U2Model
(
nn
.
Module
):
class
U2
Base
Model
(
nn
.
Module
):
"""CTC-Attention hybrid Encoder-Decoder model"""
def
__init__
(
self
,
...
...
@@ -635,28 +635,9 @@ class U2Model(nn.Module):
return
decoder_out
class
U2
TransformerModel
(
U2
Model
):
class
U2
Model
(
U2Base
Model
):
def
__init__
(
self
,
configs
:
dict
):
if
configs
[
'cmvn_file'
]
is
not
None
:
mean
,
istd
=
load_cmvn
(
configs
[
'cmvn_file'
],
configs
[
'cmvn_file_type'
])
global_cmvn
=
GlobalCMVN
(
paddle
.
to_tensor
(
mean
).
float
(),
paddle
.
to_tensor
(
istd
).
float
())
else
:
global_cmvn
=
None
input_dim
=
configs
[
'input_dim'
]
vocab_size
=
configs
[
'output_dim'
]
encoder_type
=
configs
.
get
(
'encoder'
,
'transformer'
)
assert
encoder_type
==
'transformer'
encoder
=
TransformerEncoder
(
input_dim
,
global_cmvn
=
global_cmvn
,
**
configs
[
'encoder_conf'
])
decoder
=
TransformerDecoder
(
vocab_size
,
encoder
.
output_size
(),
**
configs
[
'decoder_conf'
])
ctc
=
CTCDecoder
(
vocab_size
,
encoder
.
output_size
())
vocab_size
,
encoder
,
decoder
,
ctc
=
U2Model
.
_init_from_config
(
configs
)
super
().
__init__
(
vocab_size
=
vocab_size
,
...
...
@@ -665,9 +646,19 @@ class U2TransformerModel(U2Model):
ctc
=
ctc
,
**
configs
[
'model_conf'
])
@
classmethod
def
_init_from_config
(
cls
,
configs
:
dict
):
"""init sub module for model.
class
U2ConformerModel
(
U2Model
):
def
__init__
(
self
,
configs
:
dict
):
Args:
configs (dict): config dict.
Raises:
ValueError: raise when using not support encoder type.
Returns:
int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
"""
if
configs
[
'cmvn_file'
]
is
not
None
:
mean
,
istd
=
load_cmvn
(
configs
[
'cmvn_file'
],
configs
[
'cmvn_file_type'
])
...
...
@@ -679,19 +670,46 @@ class U2ConformerModel(U2Model):
input_dim
=
configs
[
'input_dim'
]
vocab_size
=
configs
[
'output_dim'
]
encoder_type
=
configs
.
get
(
'encoder'
,
'conformer'
)
assert
encoder_type
==
'conformer'
encoder
=
ConformerEncoder
(
input_dim
,
global_cmvn
=
global_cmvn
,
**
configs
[
'encoder_conf'
])
encoder_type
=
configs
.
get
(
'encoder'
,
'transformer'
)
logger
.
info
(
f
"U2 Encoder type:
{
encoder_type
}
"
)
if
encoder_type
==
'transformer'
:
encoder
=
TransformerEncoder
(
input_dim
,
global_cmvn
=
global_cmvn
,
**
configs
[
'encoder_conf'
])
elif
encoder_type
==
'conformer'
:
encoder
=
ConformerEncoder
(
input_dim
,
global_cmvn
=
global_cmvn
,
**
configs
[
'encoder_conf'
])
else
:
raise
ValueError
(
"not support encoder type:{encoder_type}"
)
decoder
=
TransformerDecoder
(
vocab_size
,
encoder
.
output_size
(),
**
configs
[
'decoder_conf'
])
ctc
=
CTCDecoder
(
vocab_size
,
encoder
.
output_size
())
return
vocab_size
,
encoder
,
decoder
,
ctc
super
().
__init__
(
vocab_size
=
vocab_size
,
encoder
=
encoder
,
decoder
=
decoder
,
ctc
=
ctc
,
**
configs
[
'model_conf'
])
@
classmethod
def
from_pretrained
(
cls
,
dataset
,
config
,
checkpoint_path
):
"""Build a DeepSpeech2Model model from a pretrained model.
Args:
dataset (paddle.io.Dataset): [description]
config (yacs.config.CfgNode): model configs
checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
Returns:
DeepSpeech2Model: The model built from pretrained result.
"""
vocab_size
,
encoder
,
decoder
,
ctc
=
U2Model
.
_init_from_config
(
configs
)
model
=
cls
(
vocab_size
=
vocab_size
,
encoder
=
encoder
,
decoder
=
decoder
,
ctc
=
ctc
,
**
configs
[
'model_conf'
])
infos
=
checkpoint
.
load_parameters
(
model
,
checkpoint_path
=
checkpoint_path
)
logger
.
info
(
f
"checkpoint info:
{
infos
}
"
)
layer_tools
.
summary
(
model
)
return
model
examples/tiny/s1/conf/conformer.yaml
浏览文件 @
090e7947
# https://yaml.org/type/float.html
data
:
train_manifest
:
data/manifest.tiny
dev_manifest
:
data/manifest.tiny
test_manifest
:
data/manifest.tiny
vocab_filepath
:
data/vocab.txt
unit_type
:
'
spm'
spm_model_prefix
:
'
bpe_unigram_200'
mean_std_filepath
:
data/mean_std.npz
augmentation_config
:
conf/augmentation.config
batch_size
:
4
max_input_len
:
27.0
min_input_len
:
0.0
max_output_len
:
.INF
min_output_len
:
0.0
max_output_input_ratio
:
.INF
min_output_input_ratio
:
0.0
raw_wav
:
True
# use raw_wav or kaldi feature
specgram_type
:
fbank
#linear, mfcc, fbank
feat_dim
:
80
delta_delta
:
False
target_sample_rate
:
16000
max_freq
:
None
n_fft
:
None
stride_ms
:
10.0
window_ms
:
20.0
use_dB_normalization
:
True
target_dB
:
-20
random_seed
:
0
keep_transcription_text
:
False
sortagrad
:
True
shuffle_method
:
batch_shuffle
num_workers
:
0
# network architecture
# encoder related
encoder
:
conformer
...
...
@@ -34,9 +69,6 @@ model_conf:
lsm_weight
:
0.1
# label smoothing option
length_normalized_loss
:
false
# use raw_wav or kaldi feature
raw_wav
:
true
# feature extraction
collate_conf
:
# waveform level config
...
...
tests/u2_model_test.py
浏览文件 @
090e7947
...
...
@@ -18,8 +18,7 @@ import unittest
import
numpy
as
np
from
yacs.config
import
CfgNode
as
CN
from
deepspeech.models.u2
import
U2TransformerModel
from
deepspeech.models.u2
import
U2ConformerModel
from
deepspeech.models.u2
import
U2Model
from
deepspeech.utils.layer_tools
import
summary
...
...
@@ -84,7 +83,7 @@ class TestU2Model(unittest.TestCase):
cfg
.
cmvn_file
=
None
cfg
.
cmvn_file_type
=
'npz'
cfg
.
freeze
()
model
=
U2
Transformer
Model
(
cfg
)
model
=
U2Model
(
cfg
)
summary
(
model
,
None
)
total_loss
,
attention_loss
,
ctc_loss
=
model
(
self
.
audio
,
self
.
audio_len
,
self
.
text
,
self
.
text_len
)
...
...
@@ -136,7 +135,7 @@ class TestU2Model(unittest.TestCase):
cfg
.
cmvn_file
=
None
cfg
.
cmvn_file_type
=
'npz'
cfg
.
freeze
()
model
=
U2
Conformer
Model
(
cfg
)
model
=
U2Model
(
cfg
)
summary
(
model
,
None
)
total_loss
,
attention_loss
,
ctc_loss
=
model
(
self
.
audio
,
self
.
audio_len
,
self
.
text
,
self
.
text_len
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录