PaddlePaddle / DeepSpeech

Commit 718ae52e
Authored Aug 10, 2021 by huangyuxin
Parent: 7a3d1641

    add from_config function to ds2_online and ds2

Showing 5 changed files with 141 additions and 202 deletions (+141 -202):
    deepspeech/exps/deepspeech2/model.py                +28  -107
    deepspeech/models/ds2/deepspeech2.py                +33    -0
    deepspeech/models/ds2_online/deepspeech2.py         +34   -10
    examples/aishell/s0/conf/deepspeech2_online.yaml     +6    -6
    tests/deepspeech2_online_model_test.py              +40   -79
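The pattern this commit introduces is a classmethod constructor that reads hyperparameters from a yacs CfgNode instead of threading a dozen keyword arguments through each call site. A minimal consumer-side sketch of that pattern (the numeric values are hypothetical, and the import path assumes `DeepSpeech2Model` is re-exported from its package):

```python
from yacs.config import CfgNode

from deepspeech.models.ds2 import DeepSpeech2Model  # assumed re-export

# Hypothetical values; in the trainer, feat_size and dict_size are filled in
# from the dataloader's collate_fn before the model is built (see below).
model_conf = CfgNode()
model_conf.feat_size = 161
model_conf.dict_size = 4300
model_conf.num_conv_layers = 2
model_conf.num_rnn_layers = 3
model_conf.rnn_layer_size = 1024
model_conf.use_gru = True
model_conf.share_rnn_weights = False

model = DeepSpeech2Model.from_config(model_conf)
```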
deepspeech/exps/deepspeech2/model.py

```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Contains DeepSpeech2 model."""
+"""Contains DeepSpeech2 and DeepSpeech2Online model."""
 import time
 from collections import defaultdict
 from pathlib import Path
@@ -38,8 +38,6 @@ from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Autolog
 from deepspeech.utils.log import Log
-#from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
-#from deepspeech.models.ds2_online import DeepSpeech2ModelOnline

 logger = Log(__name__).getlog()
```
```diff
@@ -123,40 +121,20 @@ class DeepSpeech2Trainer(Trainer):
         return total_loss, num_seen_utts

     def setup_model(self):
-        config = self.config
-        if hasattr(self, "train_loader"):
-            config.defrost()
-            config.model.feat_size = self.train_loader.collate_fn.feature_size
-            config.model.dict_size = self.train_loader.collate_fn.vocab_size
-            config.freeze()
-        elif hasattr(self, "test_loader"):
-            config.defrost()
-            config.model.feat_size = self.test_loader.collate_fn.feature_size
-            config.model.dict_size = self.test_loader.collate_fn.vocab_size
-            config.freeze()
-        else:
-            raise Exception("Please setup the dataloader first")
+        config = self.config.clone()
+        config.defrost()
+        assert (self.train_loader.collate_fn.feature_size ==
+                self.test_loader.collate_fn.feature_size)
+        assert (self.train_loader.collate_fn.vocab_size ==
+                self.test_loader.collate_fn.vocab_size)
+        config.model.feat_size = self.train_loader.collate_fn.feature_size
+        config.model.dict_size = self.train_loader.collate_fn.vocab_size
+        config.freeze()
         if self.args.model_type == 'offline':
-            model = DeepSpeech2Model(
-                feat_size=config.model.feat_size,
-                dict_size=config.model.dict_size,
-                num_conv_layers=config.model.num_conv_layers,
-                num_rnn_layers=config.model.num_rnn_layers,
-                rnn_size=config.model.rnn_layer_size,
-                use_gru=config.model.use_gru,
-                share_rnn_weights=config.model.share_rnn_weights)
+            model = DeepSpeech2Model.from_config(config.model)
         elif self.args.model_type == 'online':
-            model = DeepSpeech2ModelOnline(
-                feat_size=config.model.feat_size,
-                dict_size=config.model.dict_size,
-                num_conv_layers=config.model.num_conv_layers,
-                num_rnn_layers=config.model.num_rnn_layers,
-                rnn_size=config.model.rnn_layer_size,
-                rnn_direction=config.model.rnn_direction,
-                num_fc_layers=config.model.num_fc_layers,
-                fc_layers_size_list=config.model.fc_layers_size_list,
-                use_gru=config.model.use_gru)
+            model = DeepSpeech2ModelOnline.from_config(config.model)
         else:
             raise Exception("wrong model type")

         if self.parallel:
```
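The clone()/defrost()/freeze() sequence above is the usual yacs idiom for deriving a mutated private copy of a shared, frozen config. A minimal sketch of just that idiom, with hypothetical values:

```python
from yacs.config import CfgNode

base = CfgNode()
base.model = CfgNode()
base.model.feat_size = -1   # placeholder until the dataloader exists
base.model.dict_size = -1
base.freeze()               # immutable from here on

config = base.clone()       # private copy; `base` stays frozen
config.defrost()            # make the copy mutable
config.model.feat_size = 161   # e.g. train_loader.collate_fn.feature_size
config.model.dict_size = 4300  # e.g. train_loader.collate_fn.vocab_size
config.freeze()             # re-freeze before handing it to from_config
```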
```diff
@@ -194,6 +172,9 @@ class DeepSpeech2Trainer(Trainer):
         config.data.manifest = config.data.dev_manifest
         dev_dataset = ManifestDataset.from_config(config)

+        config.data.manifest = config.data.test_manifest
+        test_dataset = ManifestDataset.from_config(config)
+
         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
@@ -217,19 +198,29 @@ class DeepSpeech2Trainer(Trainer):
         config.collator.augmentation_config = ""
         collate_fn_dev = SpeechCollator.from_config(config)

+        config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
+        collate_fn_test = SpeechCollator.from_config(config)
+
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
             collate_fn=collate_fn_train,
             num_workers=config.collator.num_workers)
+        print("feature_size", self.train_loader.collate_fn.feature_size)
         self.valid_loader = DataLoader(
             dev_dataset,
             batch_size=config.collator.batch_size,
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_dev)
-        logger.info("Setup train/valid Dataloader!")
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=collate_fn_test)
+        logger.info("Setup train/valid/test Dataloader!")


 class DeepSpeech2Tester(DeepSpeech2Trainer):
```
```diff
@@ -371,20 +362,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         infer_model.eval()
         feat_dim = self.test_loader.collate_fn.feature_size
-        if self.args.model_type == 'offline':
-            static_model = paddle.jit.to_static(
-                infer_model,
-                input_spec=[
-                    paddle.static.InputSpec(
-                        shape=[None, None, feat_dim],
-                        dtype='float32'),  # audio, [B,T,D]
-                    paddle.static.InputSpec(shape=[None],
-                                            dtype='int64'),  # audio_length, [B]
-                ])
-        elif self.args.model_type == 'online':
-            static_model = infer_model.export()
-        else:
-            raise Exception("wrong model type")
+        static_model = infer_model.export()
         logger.info(f"Export code: {static_model.forward.code}")
         paddle.jit.save(static_model, self.args.export_path)
```
```diff
@@ -408,63 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.iteration = 0
         self.epoch = 0

-    '''
-    def setup_model(self):
-        config = self.config
-        if self.args.model_type == 'offline':
-            model = DeepSpeech2Model(
-                feat_size=self.test_loader.collate_fn.feature_size,
-                dict_size=self.test_loader.collate_fn.vocab_size,
-                num_conv_layers=config.model.num_conv_layers,
-                num_rnn_layers=config.model.num_rnn_layers,
-                rnn_size=config.model.rnn_layer_size,
-                use_gru=config.model.use_gru,
-                share_rnn_weights=config.model.share_rnn_weights)
-        elif self.args.model_type == 'online':
-            model = DeepSpeech2ModelOnline(
-                feat_size=self.test_loader.collate_fn.feature_size,
-                dict_size=self.test_loader.collate_fn.vocab_size,
-                num_conv_layers=config.model.num_conv_layers,
-                num_rnn_layers=config.model.num_rnn_layers,
-                rnn_size=config.model.rnn_layer_size,
-                rnn_direction=config.model.rnn_direction,
-                num_fc_layers=config.model.num_fc_layers,
-                fc_layers_size_list=config.model.fc_layers_size_list,
-                use_gru=config.model.use_gru)
-        else:
-            raise Exception("Wrong model type")
-        self.model = model
-        logger.info("Setup model!")
-    '''
-
-    def setup_dataloader(self):
-        config = self.config.clone()
-        config.defrost()
-        # return raw text
-
-        config.data.manifest = config.data.test_manifest
-        # filter test examples, will cause less examples, but no mismatch with training
-        # and can use large batch size , save training time, so filter test egs now.
-        # config.data.min_input_len = 0.0  # second
-        # config.data.max_input_len = float('inf')  # second
-        # config.data.min_output_len = 0.0  # tokens
-        # config.data.max_output_len = float('inf')  # tokens
-        # config.data.min_output_input_ratio = 0.00
-        # config.data.max_output_input_ratio = float('inf')
-
-        test_dataset = ManifestDataset.from_config(config)
-
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        # return text ord id
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=SpeechCollator.from_config(config))
-        logger.info("Setup test Dataloader!")
-
     def setup_output_dir(self):
         """Create a directory used for output.
         """
```
deepspeech/models/ds2/deepspeech2.py

```diff
@@ -228,6 +228,27 @@ class DeepSpeech2Model(nn.Layer):
         layer_tools.summary(model)
         return model

+    @classmethod
+    def from_config(cls, config):
+        """Build a DeepSpeech2Model from config.
+
+        Parameters
+        ----------
+        config: yacs.config.CfgNode
+            config.model
+
+        Returns
+        -------
+        DeepSpeech2Model
+            The model built from config.
+        """
+        model = cls(feat_size=config.feat_size,
+                    dict_size=config.dict_size,
+                    num_conv_layers=config.num_conv_layers,
+                    num_rnn_layers=config.num_rnn_layers,
+                    rnn_size=config.rnn_layer_size,
+                    use_gru=config.use_gru,
+                    share_rnn_weights=config.share_rnn_weights)
+        return model


 class DeepSpeech2InferModel(DeepSpeech2Model):
     def __init__(self,
```
```diff
@@ -260,3 +281,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
         eouts, eouts_len = self.encoder(audio, audio_len)
         probs = self.decoder.softmax(eouts)
         return probs
+
+    def export(self):
+        static_model = paddle.jit.to_static(
+            self,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None, None, self.encoder.feat_size],
+                    dtype='float32'),  # audio, [B,T,D]
+                paddle.static.InputSpec(shape=[None],
+                                        dtype='int64'),  # audio_length, [B]
+            ])
+        return static_model
```
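With export() on the model itself, the exps code above can save the static graph uniformly via paddle.jit.save. A hedged round-trip sketch; the path and input shapes are placeholders:

```python
import paddle

# `infer_model` stands for a constructed DeepSpeech2InferModel (placeholder).
static_model = infer_model.export()
paddle.jit.save(static_model, "exported/deepspeech2")  # placeholder path

# Reload for deployment-style inference; inputs mirror the InputSpecs above.
loaded = paddle.jit.load("exported/deepspeech2")
loaded.eval()
audio = paddle.randn([1, 100, 161])                 # [B, T, D], hypothetical D=161
audio_len = paddle.to_tensor([100], dtype='int64')  # [B]
probs = loaded(audio, audio_len)
```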
deepspeech/models/ds2_online/deepspeech2.py

```diff
@@ -51,8 +51,9 @@ class CRNNEncoder(nn.Layer):
         self.use_gru = use_gru

         self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)
-        i_size = self.conv.output_dim
+        self.output_dim = self.conv.output_dim
+        i_size = self.conv.output_dim

         self.rnn = nn.LayerList()
         self.layernorm_list = nn.LayerList()
         self.fc_layers_list = nn.LayerList()
@@ -82,16 +83,18 @@ class CRNNEncoder(nn.Layer):
                         num_layers=1,
                         direction=rnn_direction))
             self.layernorm_list.append(nn.LayerNorm(layernorm_size))
+            self.output_dim = layernorm_size

         fc_input_size = layernorm_size
         for i in range(self.num_fc_layers):
             self.fc_layers_list.append(
                 nn.Linear(fc_input_size, fc_layers_size_list[i]))
             fc_input_size = fc_layers_size_list[i]
+            self.output_dim = fc_layers_size_list[i]

     @property
     def output_size(self):
-        return self.fc_layers_size_list[-1]
+        return self.output_dim

     def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None):
         """Compute Encoder outputs
```
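The point of routing output_size through self.output_dim is that the encoder's reported width now follows whichever layer was built last, whereas fc_layers_size_list[-1] fails when num_fc_layers is 0. A standalone sketch of that bookkeeping pattern (not the repo's class; sizes are hypothetical):

```python
import paddle.nn as nn


class TinyEncoder(nn.Layer):
    """Standalone sketch of the output_dim bookkeeping pattern."""

    def __init__(self, rnn_size, fc_sizes):
        super().__init__()
        self.output_dim = rnn_size          # width after the RNN stack
        self.fc_layers_list = nn.LayerList()
        in_size = rnn_size
        for size in fc_sizes:               # may be an empty list
            self.fc_layers_list.append(nn.Linear(in_size, size))
            in_size = size
            self.output_dim = size          # track the last layer built

    @property
    def output_size(self):
        return self.output_dim              # safe even with no FC layers
```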
```diff
@@ -190,9 +193,6 @@ class CRNNEncoder(nn.Layer):
         for i in range(0, num_chunk):
             start = i * chunk_stride
             end = start + chunk_size
-            # end = min(start + chunk_size, max_len)
-            # if (end - start < receptive_field_length):
-            #     break
             x_chunk = padded_x[:, start:end, :]

             x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
@@ -221,8 +221,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
     :type text_data: Variable
     :param audio_len: Valid sequence length data layer.
     :type audio_len: Variable
-    :param masks: Masks data layer to reset padding.
-    :type masks: Variable
     :param dict_size: Dictionary size for tokenized transcription.
     :type dict_size: int
     :param num_conv_layers: Number of stacking convolution layers.
@@ -231,6 +229,10 @@ class DeepSpeech2ModelOnline(nn.Layer):
     :type num_rnn_layers: int
     :param rnn_size: RNN layer size (dimension of RNN cells).
     :type rnn_size: int
+    :param num_fc_layers: Number of stacking FC layers.
+    :type num_fc_layers: int
+    :param fc_layers_size_list: The list of FC layer sizes.
+    :type fc_layers_size_list: [int,]
     :param use_gru: Use gru if set True. Use simple rnn if set False.
     :type use_gru: bool
     :return: A tuple of an output unnormalized log probability layer (
@@ -274,7 +276,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
             fc_layers_size_list=fc_layers_size_list,
             rnn_size=rnn_size,
             use_gru=use_gru)
-        assert (self.encoder.output_size == fc_layers_size_list[-1])
         self.decoder = CTCDecoder(
             odim=dict_size,  # <blank> is in vocab
```
```diff
@@ -337,7 +338,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
         Returns
         -------
-        DeepSpeech2Model
+        DeepSpeech2ModelOnline
             The model built from pretrained result.
         """
         model = cls(feat_size=dataloader.collate_fn.feature_size,
@@ -355,6 +356,29 @@ class DeepSpeech2ModelOnline(nn.Layer):
         layer_tools.summary(model)
         return model

+    @classmethod
+    def from_config(cls, config):
+        """Build a DeepSpeech2ModelOnline from config.
+
+        Parameters
+        ----------
+        config: yacs.config.CfgNode
+            config.model
+
+        Returns
+        -------
+        DeepSpeech2ModelOnline
+            The model built from config.
+        """
+        model = cls(feat_size=config.feat_size,
+                    dict_size=config.dict_size,
+                    num_conv_layers=config.num_conv_layers,
+                    num_rnn_layers=config.num_rnn_layers,
+                    rnn_size=config.rnn_layer_size,
+                    rnn_direction=config.rnn_direction,
+                    num_fc_layers=config.num_fc_layers,
+                    fc_layers_size_list=config.fc_layers_size_list,
+                    use_gru=config.use_gru)
+        return model


 class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
     def __init__(self,
```
```diff
@@ -392,7 +416,7 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
                 paddle.static.InputSpec(
-                    shape=[None, None, self.encoder.feat_size],
-                    dtype='float32'),  # audio, [B,T,D]
+                    shape=[None, None,
+                           self.encoder.feat_size],  # [B, chunk_size, feat_dim]
+                    dtype='float32'),
                 paddle.static.InputSpec(shape=[None],
                                         dtype='int64'),  # audio_length, [B]
                 paddle.static.InputSpec(
```
examples/aishell/s0/conf/deepspeech2_online.yaml

```diff
@@ -36,17 +36,17 @@ collator:
 model:
   num_conv_layers: 2
-  num_rnn_layers: 4
+  num_rnn_layers: 3
   rnn_layer_size: 1024
-  rnn_direction: bidirect
-  num_fc_layers: 2
-  fc_layers_size_list: 512, 256
+  rnn_direction: forward # [forward, bidirect]
+  num_fc_layers: 1
+  fc_layers_size_list: 512,
   use_gru: True

 training:
   n_epoch: 50
   lr: 2e-3
-  lr_decay: 0.83
+  lr_decay: 0.83 # 0.83
   weight_decay: 1e-06
   global_grad_clip: 3.0
   log_interval: 100
@@ -55,7 +55,7 @@ training:
   latest_n: 5

 decoding:
-  batch_size: 64
+  batch_size: 32
   error_rate_type: cer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
```
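These model options land in config.model and feed the new from_config constructors. A sketch of reading this file with yacs (the repo's actual default-config plumbing may differ; this is schematic):

```python
from yacs.config import CfgNode

from deepspeech.models.ds2_online import DeepSpeech2ModelOnline  # assumed re-export

# Load the experiment yaml into a CfgNode.
with open("examples/aishell/s0/conf/deepspeech2_online.yaml") as f:
    config = CfgNode.load_cfg(f)

# feat_size/dict_size are not in the yaml; the trainer injects them from the
# dataloader's collate_fn before calling from_config (see model.py above).
config.model.feat_size = 161    # hypothetical; normally from the collator
config.model.dict_size = 4300   # hypothetical; normally from the collator

model = DeepSpeech2ModelOnline.from_config(config.model)
```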
tests/deepspeech2_online_model_test.py

```diff
@@ -106,18 +106,34 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
         self.assertEqual(loss.numel(), 1)

     def test_ds2_6(self):
         model = DeepSpeech2ModelOnline(
             feat_size=self.feat_dim,
             dict_size=10,
             num_conv_layers=2,
-            num_rnn_layers=1,
+            num_rnn_layers=3,
             rnn_size=1024,
-            rnn_direction='forward',
+            rnn_direction='bidirect',
             num_fc_layers=2,
             fc_layers_size_list=[512, 256],
-            use_gru=True)
+            use_gru=False)
+        loss = model(self.audio, self.audio_len, self.text, self.text_len)
+        self.assertEqual(loss.numel(), 1)
+
+    def test_ds2_7(self):
+        use_gru = False
+        model = DeepSpeech2ModelOnline(
+            feat_size=self.feat_dim,
+            dict_size=10,
+            num_conv_layers=2,
+            num_rnn_layers=1,
+            rnn_size=1024,
+            rnn_direction='forward',
+            num_fc_layers=2,
+            fc_layers_size_list=[512, 256],
+            use_gru=use_gru)
         model.eval()
         paddle.device.set_device("cpu")
-        de_ch_size = 9
+        de_ch_size = 8

         eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
             self.audio, self.audio_len)
@@ -126,99 +142,44 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
         eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
         eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
         decode_max_len = eouts.shape[1]
-        print("dml", decode_max_len)
         eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
-        self.assertEqual(
-            paddle.sum(
-                paddle.abs(paddle.subtract(eouts_lens, eouts_lens_by_chk))), 0)
-        self.assertEqual(
-            paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))), 0)
         self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
         self.assertEqual(
             paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
-        self.assertEqual(
-            paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
-        """
-        print ("conv_x", conv_x)
-        print ("conv_x_by_chk", conv_x_by_chk)
-        print ("final_state_list", final_state_list)
-        #print ("final_state_list_by_chk", final_state_list_by_chk)
-        print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,:de_ch_size,:], eouts_by_chk[:,:de_ch_size,:]))))
-        print (paddle.allclose(eouts[:,:de_ch_size,:], eouts_by_chk[:,:de_ch_size,:]))
-        print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,de_ch_size:de_ch_size*2,:], eouts_by_chk[:,de_ch_size:de_ch_size*2,:]))))
-        print (paddle.allclose(eouts[:,de_ch_size:de_ch_size*2,:], eouts_by_chk[:,de_ch_size:de_ch_size*2,:]))
-        print (paddle.sum(paddle.abs(paddle.subtract(eouts[:,de_ch_size*2:de_ch_size*3,:], eouts_by_chk[:,de_ch_size*2:de_ch_size*3,:]))))
-        print (paddle.allclose(eouts[:,de_ch_size*2:de_ch_size*3,:], eouts_by_chk[:,de_ch_size*2:de_ch_size*3,:]))
-        print (paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))))
-        print (paddle.allclose(eouts[:,:,:], eouts_by_chk[:,:,:]))
-        """
-        """
-        def split_into_chunk(self, x, x_lens, decoder_chunk_size, subsampling_rate,
-                             receptive_field_length):
-            chunk_size = (decoder_chunk_size - 1
-                          ) * subsampling_rate + receptive_field_length
-            chunk_stride = subsampling_rate * decoder_chunk_size
-            max_len = x.shape[1]
-            assert (chunk_size <= max_len)
-            x_chunk_list = []
-            x_chunk_lens_list = []
-            padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
-            padding = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
-            padded_x = paddle.concat([x, padding], axis=1)
-            num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1
-            num_chunk = int(num_chunk)
-            for i in range(0, num_chunk):
-                start = i * chunk_stride
-                end = start + chunk_size
-                x_chunk = padded_x[:, start:end, :]
-                x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
-                                          paddle.zeros_like(x_lens),
-                                          x_lens - i * chunk_stride)
-                x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size
-                x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
-                                            x_len_left, x_chunk_len_tmp)
-                x_chunk_list.append(x_chunk)
-                x_chunk_lens_list.append(x_chunk_lens)
-            return x_chunk_list, x_chunk_lens_list
-        """
+        if use_gru == False:
+            self.assertEqual(
+                paddle.allclose(final_state_c_box, final_state_c_box_chk), True)

-    def test_ds2_7(self):
+    def test_ds2_8(self):
+        use_gru = True
         model = DeepSpeech2ModelOnline(
             feat_size=self.feat_dim,
             dict_size=10,
             num_conv_layers=2,
             num_rnn_layers=1,
             rnn_size=1024,
             rnn_direction='forward',
             num_fc_layers=2,
             fc_layers_size_list=[512, 256],
-            use_gru=True)
+            use_gru=use_gru)
         model.eval()
         paddle.device.set_device("cpu")
-        de_ch_size = 9
-        audio_chunk_list, audio_chunk_lens_list = self.split_into_chunk(
-            self.audio, self.audio_len, de_ch_size,
-            model.encoder.conv.subsampling_rate,
-            model.encoder.conv.receptive_field_length)
-        eouts_prefix = None
-        eouts_lens_prefix = None
-        chunk_state_list = [None] * model.encoder.num_rnn_layers
-        for i, audio_chunk in enumerate(audio_chunk_list):
-            audio_chunk_lens = audio_chunk_lens_list[i]
-            eouts_prefix, eouts_lens_prefix, chunk_state_list = model.decode_prob_by_chunk(
-                audio_chunk, audio_chunk_lens, eouts_prefix, eouts_lens_prefix,
-                chunk_state_list)
-            # print (i, probs_pre_chunks.shape)
-        probs, eouts, eouts_lens, final_state_list = model.decode_prob(
-            self.audio, self.audio_len)
-        decode_max_len = probs.shape[1]
-        probs_pre_chunks = probs_pre_chunks[:, :decode_max_len, :]
-        self.assertEqual(paddle.allclose(probs, probs_pre_chunks), True)
+        de_ch_size = 8

+        eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
+            self.audio, self.audio_len)
+        eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
+            self.audio, self.audio_len, de_ch_size)
+        eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
+        eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
+        decode_max_len = eouts.shape[1]
+        eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
+        self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
+        self.assertEqual(
+            paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
+        if use_gru == False:
+            self.assertEqual(
+                paddle.allclose(final_state_c_box, final_state_c_box_chk), True)


 if __name__ == '__main__':
```
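For intuition on de_ch_size, the chunk geometry in the removed split_into_chunk helper maps a decoder chunk of D frames to (D - 1) * subsampling_rate + receptive_field_length input frames, advancing by subsampling_rate * D. A quick numeric check for the new de_ch_size = 8, assuming the conv frontend subsamples by 4 with a receptive field of 7 (treat both rates as assumptions about Conv2dSubsampling4Online):

```python
# Chunk geometry from the removed split_into_chunk helper, evaluated for the
# new decoder chunk size. subsampling_rate=4 and receptive_field_length=7 are
# assumed values for the Conv2dSubsampling4Online frontend.
decoder_chunk_size = 8
subsampling_rate = 4
receptive_field_length = 7

chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size
print(chunk_size, chunk_stride)  # -> 35 32
```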