未验证 提交 38d95784 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #735 from Jackwaterveg/ds2_online

Ds2 online
......@@ -30,11 +30,15 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
args = parser.parse_args()
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
print_arguments(args)
# https://yaml.org/type/float.html
config = get_cfg_defaults()
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
......
......@@ -30,11 +30,15 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html
config = get_cfg_defaults()
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
......
......@@ -35,11 +35,15 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
args = parser.parse_args()
if args.model_type is None:
args.model_type = 'offline'
print("model_type:{}".format(args.model_type))
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = get_cfg_defaults()
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
......
......@@ -18,21 +18,19 @@ from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.ds2 import DeepSpeech2Model
_C = CfgNode()
_C.data = ManifestDataset.params()
_C.collator = SpeechCollator.params()
_C.model = DeepSpeech2Model.params()
_C.training = DeepSpeech2Trainer.params()
_C.decoding = DeepSpeech2Tester.params()
def get_cfg_defaults():
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
def get_cfg_defaults(model_type='offline'):
_C = CfgNode()
_C.data = ManifestDataset.params()
_C.collator = SpeechCollator.params()
_C.training = DeepSpeech2Trainer.params()
_C.decoding = DeepSpeech2Tester.params()
if model_type == 'offline':
_C.model = DeepSpeech2Model.params()
else:
_C.model = DeepSpeech2ModelOnline.params()
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 model."""
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import time
from collections import defaultdict
from pathlib import Path
......@@ -29,6 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
......@@ -119,16 +121,22 @@ class DeepSpeech2Trainer(Trainer):
return total_loss, num_seen_utts
def setup_model(self):
config = self.config
model = DeepSpeech2Model(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
config = self.config.clone()
config.defrost()
assert (self.train_loader.collate_fn.feature_size ==
self.test_loader.collate_fn.feature_size)
assert (self.train_loader.collate_fn.vocab_size ==
self.test_loader.collate_fn.vocab_size)
config.model.feat_size = self.train_loader.collate_fn.feature_size
config.model.dict_size = self.train_loader.collate_fn.vocab_size
config.freeze()
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)
elif self.args.model_type == 'online':
model = DeepSpeech2ModelOnline.from_config(config.model)
else:
raise Exception("wrong model type")
if self.parallel:
model = paddle.DataParallel(model)
......@@ -164,6 +172,9 @@ class DeepSpeech2Trainer(Trainer):
config.data.manifest = config.data.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.test_manifest
test_dataset = ManifestDataset.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
......@@ -187,6 +198,11 @@ class DeepSpeech2Trainer(Trainer):
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
collate_fn_test = SpeechCollator.from_config(config)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
......@@ -198,7 +214,13 @@ class DeepSpeech2Trainer(Trainer):
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
logger.info("Setup train/valid Dataloader!")
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_test)
logger.info("Setup train/valid/test Dataloader!")
class DeepSpeech2Tester(DeepSpeech2Trainer):
......@@ -329,19 +351,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
exit(-1)
def export(self):
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
if self.args.model_type == 'offline':
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
elif self.args.model_type == 'online':
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
else:
raise Exception("wrong model type")
infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size
static_model = paddle.jit.to_static(
infer_model,
input_spec=[
paddle.static.InputSpec(
shape=[None, None, feat_dim],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
])
static_model = infer_model.export()
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
......@@ -365,46 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
self.iteration = 0
self.epoch = 0
def setup_model(self):
config = self.config
model = DeepSpeech2Model(
feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
self.model = model
logger.info("Setup model!")
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
# return raw text
config.data.manifest = config.data.test_manifest
# filter test examples, will cause less examples, but no mismatch with training
# and can use large batch size , save training time, so filter test egs now.
# config.data.min_input_len = 0.0 # second
# config.data.max_input_len = float('inf') # second
# config.data.min_output_len = 0.0 # tokens
# config.data.max_output_len = float('inf') # tokens
# config.data.min_output_input_ratio = 0.00
# config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
# return text ord id
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup test Dataloader!")
def setup_output_dir(self):
"""Create a directory used for output.
"""
......
......@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
......
......@@ -228,6 +228,27 @@ class DeepSpeech2Model(nn.Layer):
layer_tools.summary(model)
return model
@classmethod
def from_config(cls, config):
"""Build a DeepSpeec2Model from config
Parameters
config: yacs.config.CfgNode
config.model
Returns
-------
DeepSpeech2Model
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights)
return model
class DeepSpeech2InferModel(DeepSpeech2Model):
def __init__(self,
......@@ -260,3 +281,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return probs
def export(self):
static_model = paddle.jit.to_static(
self,
input_spec=[
paddle.static.InputSpec(
shape=[None, None, self.encoder.feat_size],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
])
return static_model
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModelOnline
from .deepspeech2 import DeepSpeech2ModelOnline
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4
class Conv2dSubsampling4Online(Conv2dSubsampling4):
def __init__(self, idim: int, odim: int, dropout_rate: float):
super().__init__(idim, odim, dropout_rate, None)
self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
self.receptive_field_length = 2 * (
3 - 1) + 3 # stride_1 * (kernel_size_2 - 1) + kerel_size_1
def forward(self, x: paddle.Tensor,
x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]:
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
#b, c, t, f = paddle.shape(x) #not work under jit
x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1])
x_len = ((x_len - 1) // 2 - 1) // 2
return x, x_len
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Online Model"""
from typing import Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from yacs.config import CfgNode
from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline']
class CRNNEncoder(nn.Layer):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False):
super().__init__()
self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size
self.num_rnn_layers = num_rnn_layers
self.num_fc_layers = num_fc_layers
self.rnn_direction = rnn_direction
self.fc_layers_size_list = fc_layers_size_list
self.use_gru = use_gru
self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)
self.output_dim = self.conv.output_dim
i_size = self.conv.output_dim
self.rnn = nn.LayerList()
self.layernorm_list = nn.LayerList()
self.fc_layers_list = nn.LayerList()
if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional':
layernorm_size = 2 * rnn_size
elif rnn_direction == 'forward':
layernorm_size = rnn_size
else:
raise Exception("Wrong rnn direction")
for i in range(0, num_rnn_layers):
if i == 0:
rnn_input_size = i_size
else:
rnn_input_size = layernorm_size
if use_gru == True:
self.rnn.append(
nn.GRU(
input_size=rnn_input_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
else:
self.rnn.append(
nn.LSTM(
input_size=rnn_input_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
self.layernorm_list.append(nn.LayerNorm(layernorm_size))
self.output_dim = layernorm_size
fc_input_size = layernorm_size
for i in range(self.num_fc_layers):
self.fc_layers_list.append(
nn.Linear(fc_input_size, fc_layers_size_list[i]))
fc_input_size = fc_layers_size_list[i]
self.output_dim = fc_layers_size_list[i]
@property
def output_size(self):
return self.output_dim
def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None):
"""Compute Encoder outputs
Args:
x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
Returns:
x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
"""
if init_state_h_box is not None:
init_state_list = None
if self.use_gru == True:
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_list = init_state_h_list
else:
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_c_list = paddle.split(
init_state_c_box, self.num_rnn_layers, axis=0)
init_state_list = [(init_state_h_list[i], init_state_c_list[i])
for i in range(self.num_rnn_layers)]
else:
init_state_list = [None] * self.num_rnn_layers
x, x_lens = self.conv(x, x_lens)
final_chunk_state_list = []
for i in range(0, self.num_rnn_layers):
x, final_state = self.rnn[i](x, init_state_list[i],
x_lens) #[B, T, D]
final_chunk_state_list.append(final_state)
x = self.layernorm_list[i](x)
for i in range(self.num_fc_layers):
x = self.fc_layers_list[i](x)
x = F.relu(x)
if self.use_gru == True:
final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box)
else:
final_chunk_state_h_list = [
final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
]
final_chunk_state_c_list = [
final_chunk_state_list[i][1] for i in range(self.num_rnn_layers)
]
final_chunk_state_h_box = paddle.concat(
final_chunk_state_h_list, axis=0)
final_chunk_state_c_box = paddle.concat(
final_chunk_state_c_list, axis=0)
return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box
def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8):
"""Compute Encoder outputs
Args:
x (Tensor): [B, T, D]
x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder
Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
"""
subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length
chunk_size = (decoder_chunk_size - 1
) * subsampling_rate + receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size
max_len = x.shape[1]
assert (chunk_size <= max_len)
eouts_chunk_list = []
eouts_chunk_lens_list = []
padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
padding = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
padded_x = paddle.concat([x, padding], axis=1)
num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1
num_chunk = int(num_chunk)
chunk_state_h_box = None
chunk_state_c_box = None
final_state_h_box = None
final_state_c_box = None
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
x_chunk = padded_x[:, start:end, :]
x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
paddle.zeros_like(x_lens),
x_lens - i * chunk_stride)
x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size
x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
x_len_left, x_chunk_len_tmp)
eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward(
x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box)
eouts_chunk_list.append(eouts_chunk)
eouts_chunk_lens_list.append(eouts_chunk_lens)
final_state_h_box = chunk_state_h_box
final_state_c_box = chunk_state_c_box
return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box
class DeepSpeech2ModelOnline(nn.Layer):
"""The DeepSpeech2 network structure for online.
:param audio_data: Audio spectrogram data layer.
:type audio_data: Variable
:param text_data: Transcription text data layer.
:type text_data: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
:type num_conv_layers: int
:param num_rnn_layers: Number of stacking RNN layers.
:type num_rnn_layers: int
:param rnn_size: RNN layer size (dimension of RNN cells).
:type rnn_size: int
:param num_fc_layers: Number of stacking FC layers.
:type num_fc_layers: int
:param fc_layers_size_list: The list of FC layer sizes.
:type fc_layers_size_list: [int,]
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=4, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True, #Use gru if set True. Use simple rnn if set False.
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
rnn_size=rnn_size,
use_gru=use_gru)
self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size,
blank_id=0, # first token is <blank>
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
Args:
audio (Tenosr): [B, T, D]
audio_len (Tensor): [B]
text (Tensor): [B, U]
text_len (Tensor): [B]
Returns:
loss (Tenosr): [1]
"""
eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len, None, None)
loss = self.decoder(eouts, eouts_len, text, text_len)
return loss
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len, None, None)
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Parameters
----------
dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
DeepSpeech2ModelOnline
The model built from pretrained result.
"""
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
@classmethod
def from_config(cls, config):
"""Build a DeepSpeec2ModelOnline from config
Parameters
config: yacs.config.CfgNode
config.model
Returns
-------
DeepSpeech2ModelOnline
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru)
return model
class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
chunk_state_c_box):
eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder(
audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box)
probs_chunk = self.decoder.softmax(eouts_chunk)
return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box
def export(self):
static_model = paddle.jit.to_static(
self,
input_spec=[
paddle.static.InputSpec(
shape=[None, None,
self.encoder.feat_size], #[B, chunk_size, feat_dim]
dtype='float32'),
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
paddle.static.InputSpec(
shape=[None, None, None], dtype='float32'),
paddle.static.InputSpec(
shape=[None, None, None], dtype='float32')
])
return static_model
......@@ -92,7 +92,7 @@ class Conv2dSubsampling4(BaseSubsampling):
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
"""Construct an Conv2dSubsampling4 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
......@@ -143,7 +143,7 @@ class Conv2dSubsampling6(BaseSubsampling):
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
"""Construct an Conv2dSubsampling6 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
......@@ -196,7 +196,7 @@ class Conv2dSubsampling8(BaseSubsampling):
dropout_rate: float,
pos_enc_class: nn.Layer=PositionalEncoding):
"""Construct an Conv2dSubsampling8 object.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
......
......@@ -4,7 +4,7 @@ export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 32 # one gpu
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear #linear, mfcc, fbank
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
rnn_direction: forward # [forward, bidirect]
num_fc_layers: 1
fc_layers_size_list: 512,
use_gru: True
training:
n_epoch: 50
lr: 2e-3
lr_decay: 0.83 # 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 32
error_rate_type: cer
decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 1.9
beta: 5.0
beam_size: 300
cutoff_prob: 0.99
cutoff_top_n: 40
num_proc_bsearch: 10
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
if [ $# != 4 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
exit -1
fi
......@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
model_type=$4
device=gpu
if [ ${ngpu} == 0 ];then
......@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
--export_path ${jit_model_export_path} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in export!"
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
......@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
fi
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_ch.sh
......@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1
fi
......@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
model_type=$3
device=gpu
if [ ${ngpu} == 0 ];then
......@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in training!"
......
......@@ -7,6 +7,7 @@ stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
......@@ -21,7 +22,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......@@ -31,10 +32,10 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type}|| exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev-clean
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
collator:
batch_size: 20
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 20.0
delta_delta: False
dither: 1.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
model:
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 2048
rnn_direction: forward
num_fc_layers: 2
fc_layers_size_list: 512, 256
use_gru: False
training:
n_epoch: 50
lr: 1e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 1.9
beta: 0.3
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
if [ $# != 4 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
exit -1
fi
......@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
model_type=$4
device=gpu
if [ ${ngpu} == 0 ];then
......@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
--export_path ${jit_model_export_path} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in export!"
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
......@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
fi
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_en.sh
......@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1
fi
......@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
model_type=$3
device=gpu
if [ ${ngpu} == 0 ];then
......@@ -23,7 +24,8 @@ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in training!"
......
......@@ -6,6 +6,7 @@ stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=30
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
......@@ -19,7 +20,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......@@ -29,10 +30,10 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
# https://yaml.org/type/float.html
data:
train_manifest: data/manifest.tiny
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
min_input_len: 0.0
max_input_len: 27.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
batch_size: 4
model:
num_conv_layers: 2
num_rnn_layers: 4
rnn_layer_size: 2048
rnn_direction: forward
num_fc_layers: 2
fc_layers_size_list: 512, 256
use_gru: True
training:
n_epoch: 10
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 1
checkpoint:
kbest_n: 3
latest_n: 2
decoding:
batch_size: 128
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
if [ $# != 3 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path"
if [ $# != 4 ];then
echo "usage: $0 config_path ckpt_prefix jit_model_path model_type"
exit -1
fi
......@@ -11,6 +11,7 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_path_prefix=$2
jit_model_export_path=$3
model_type=$4
device=gpu
if [ ${ngpu} == 0 ];then
......@@ -22,8 +23,8 @@ python3 -u ${BIN_DIR}/export.py \
--nproc ${ngpu} \
--config ${config_path} \
--checkpoint_path ${ckpt_path_prefix} \
--export_path ${jit_model_export_path}
--export_path ${jit_model_export_path} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in export!"
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
if [ $# != 3 ];then
echo "usage: ${0} config_path ckpt_path_prefix model_type"
exit -1
fi
......@@ -14,6 +14,7 @@ if [ ${ngpu} == 0 ];then
fi
config_path=$1
ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_en.sh
......@@ -26,7 +27,8 @@ python3 -u ${BIN_DIR}/test.py \
--nproc 1 \
--config ${config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix}
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
......
#!/bin/bash
if [ $# != 2 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
if [ $# != 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type"
exit -1
fi
......@@ -10,6 +10,7 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
model_type=$3
device=gpu
if [ ${ngpu} == 0 ];then
......@@ -22,7 +23,8 @@ python3 -u ${BIN_DIR}/train.py \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name}
--output exp/${ckpt_name} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in training!"
......
......@@ -7,6 +7,7 @@ stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
......@@ -21,7 +22,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......@@ -31,10 +32,10 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}
fi
......@@ -16,7 +16,7 @@ import unittest
import numpy as np
import paddle
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2 import DeepSpeech2Model
class TestDeepSpeech2Model(unittest.TestCase):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
class TestDeepSpeech2ModelOnline(unittest.TestCase):
def setUp(self):
paddle.set_device('cpu')
self.batch_size = 2
self.feat_dim = 161
max_len = 210
# (B, T, D)
audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
audio_len = np.random.randint(max_len, size=self.batch_size)
audio_len[-1] = max_len
# (B, U)
text = np.array([[1, 2], [1, 2]])
text_len = np.array([2] * self.batch_size)
self.audio = paddle.to_tensor(audio, dtype='float32')
self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
self.text = paddle.to_tensor(text, dtype='int32')
self.text_len = paddle.to_tensor(text_len, dtype='int64')
def test_ds2_1(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_2(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_3(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_4(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_5(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_6(self):
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
rnn_direction='bidirect',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False)
loss = model(self.audio, self.audio_len, self.text, self.text_len)
self.assertEqual(loss.numel(), 1)
def test_ds2_7(self):
use_gru = False
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=1,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=use_gru)
model.eval()
paddle.device.set_device("cpu")
de_ch_size = 8
eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
self.audio, self.audio_len)
eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
self.audio, self.audio_len, de_ch_size)
eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
decode_max_len = eouts.shape[1]
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
def test_ds2_8(self):
use_gru = True
model = DeepSpeech2ModelOnline(
feat_size=self.feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=1,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=use_gru)
model.eval()
paddle.device.set_device("cpu")
de_ch_size = 8
eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder(
self.audio, self.audio_len)
eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk(
self.audio, self.audio_len, de_ch_size)
eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
decode_max_len = eouts.shape[1]
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册