提交 3fb9f688 编写于 作者: H huangyuxin

complete model export for ds2_online

上级 e8a39134
......@@ -32,7 +32,8 @@ if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
args = parser.parse_args()
if args.model_type is None:
args.model_type = 'offline'
print_arguments(args)
# https://yaml.org/type/float.html
......
......@@ -33,6 +33,8 @@ if __name__ == "__main__":
parser.add_argument("--model_type")
args = parser.parse_args()
print_arguments(args, globals())
if args.model_type is None:
args.model_type = 'offline'
# https://yaml.org/type/float.html
config = get_cfg_defaults(args.model_type)
......
......@@ -37,6 +37,8 @@ if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--model_type")
args = parser.parse_args()
if args.model_type is None:
args.model_type = 'offline'
print_arguments(args, globals())
# https://yaml.org/type/float.html
......
......@@ -21,7 +21,7 @@ from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
def get_cfg_defaults(model_type):
def get_cfg_defaults(model_type='offline'):
_C = CfgNode()
if (model_type == 'offline'):
_C.data = ManifestDataset.params()
......
......@@ -134,6 +134,7 @@ class DeepSpeech2Trainer(Trainer):
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
elif self.args.model_type == 'online':
print("fc_layers_size_list", config.model.fc_layers_size_list)
model = DeepSpeech2ModelOnline(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
......@@ -352,10 +353,11 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
else:
raise Exception("wrong model tyep")
raise Exception("wrong model type")
infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size
if self.args.model_type == 'offline':
static_model = paddle.jit.to_static(
infer_model,
input_spec=[
......@@ -365,6 +367,29 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
])
elif self.args.model_type == 'online':
static_model = paddle.jit.to_static(
infer_model,
input_spec=[
paddle.static.InputSpec(
shape=[None, None,
feat_dim], #[B, chunk_size, feat_dim]
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
[
(
paddle.static.InputSpec(
shape=[None, None, None], dtype='float32'
), #num_rnn_layers * num_directions, rnn_size
paddle.static.InputSpec(
shape=[None, None, None], dtype='float32'
) #num_rnn_layers * num_directions, rnn_size
) for i in range(self.config.model.num_rnn_layers)
]
])
else:
raise Exception("wrong model type")
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
......
......@@ -29,7 +29,7 @@ class Conv2dSubsampling4Online(Conv2dSubsampling4):
x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]:
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = paddle.shape(x)
x = x.transpose([0, 2, 1, 3]).reshape([b, t, c * f])
#b, c, t, f = paddle.shape(x) # does not work under jit
x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1])
x_len = ((x_len - 1) // 2 - 1) // 2
return x, x_len
......@@ -61,7 +61,7 @@ class CRNNEncoder(nn.Layer):
rnn_input_size = i_size
else:
rnn_input_size = rnn_size
if (use_gru == True):
if use_gru == True:
self.rnn.append(
nn.GRU(
input_size=rnn_input_size,
......@@ -146,6 +146,17 @@ class CRNNEncoder(nn.Layer):
return x, x_lens, chunk_final_state_list
def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8):
"""Compute Encoder outputs
Args:
x (Tensor): [B, T, D]
x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder
Returns:
eouts_chunk_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
eouts_chunk_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks
final_chunk_state_list: list of final_states for RNN layers, [num_directions, batch_size, hidden_size] * num_rnn_layers
"""
subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length
chunk_size = (decoder_chunk_size - 1
......@@ -183,8 +194,8 @@ class CRNNEncoder(nn.Layer):
eouts_chunk_list.append(eouts_chunk)
eouts_chunk_lens_list.append(eouts_chunk_lens)
return eouts_chunk_list, eouts_chunk_lens_list, chunk_state_list
final_chunk_state_list = chunk_state_list
return eouts_chunk_list, eouts_chunk_lens_list, final_chunk_state_list
class DeepSpeech2ModelOnline(nn.Layer):
......@@ -208,7 +219,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
:type rnn_size: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:type share_weights: bool
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
......@@ -295,97 +305,6 @@ class DeepSpeech2ModelOnline(nn.Layer):
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
"""
@paddle.no_grad()
def decode_by_chunk(self, eouts_prefix, eouts_len_prefix, chunk_state_list,
audio_chunk, audio_len_chunk, vocab_list,
decoding_method, lang_model_path, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
eouts_chunk, eouts_chunk_len, final_state_list = self.encoder.forward_chunk(
audio_chunk, audio_len_chunk, chunk_state_list)
if eouts_prefix is not None:
eouts = paddle.concat([eouts_prefix, eouts_chunk], axis=1)
eouts_len = paddle.add_n([eouts_len_prefix, eouts_chunk_len])
else:
eouts = eouts_chunk
eouts_len = eouts_chunk_len
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes), eouts, eouts_len, final_state_list
@paddle.no_grad()
def decode_chunk_by_chunk(self, audio, audio_len, vocab_list,
decoding_method, lang_model_path, beam_alpha,
beam_beta, beam_size, cutoff_prob, cutoff_top_n,
num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
eouts_chunk_list, eouts_chunk_len_list, final_state_list = self.encoder.forward_chunk_by_chunk(
audio, audio_len)
eouts = paddle.concat(eouts_chunk_list, axis=1)
eouts_len = paddle.add_n(eouts_chunk_len_list)
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
"""
"""
decode_prob,
decode_prob_chunk_by_chunk,
decode_prob_by_chunk
are only used for tests
"""
"""
@paddle.no_grad()
def decode_prob(self, audio, audio_len):
eouts, eouts_len, final_state_list = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return probs, eouts, eouts_len, final_state_list
@paddle.no_grad()
def decode_prob_chunk_by_chunk(self, audio, audio_len, decoder_chunk_size):
eouts_chunk_list, eouts_chunk_len_list, final_state_list = self.encoder.forward_chunk_by_chunk(
audio, audio_len, decoder_chunk_size)
eouts = paddle.concat(eouts_chunk_list, axis=1)
eouts_len = paddle.add_n(eouts_chunk_len_list)
probs = self.decoder.softmax(eouts)
return probs, eouts, eouts_len, final_state_list
@paddle.no_grad()
def decode_prob_by_chunk(self, audio, audio_len, eouts_prefix,
eouts_lens_prefix, chunk_state_list):
eouts_chunk, eouts_chunk_lens, final_state_list = self.encoder.forward_chunk(
audio, audio_len, chunk_state_list)
if eouts_prefix is not None:
eouts = paddle.concat([eouts_prefix, eouts_chunk], axis=1)
eouts_lens = paddle.add_n([eouts_lens_prefix, eouts_chunk_lens])
else:
eouts = eouts_chunk
eouts_lens = eouts_chunk_lens
probs = self.decoder.softmax(eouts)
return probs, eouts, eouts_lens, final_state_list
"""
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
......@@ -443,42 +362,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru)
def forward(self, audio, audio_len):
"""export model function
Args:
audio (Tensor): [B, T, D]
audio_len (Tensor): [B]
Returns:
probs: probs after softmax
"""
eouts, eouts_len, final_state_list = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return probs
def forward_chunk(self, audio_chunk, audio_chunk_lens):
eouts_chunkt, eouts_chunk_lens, final_state_list = self.encoder.forward_chunk(
audio_chunk, audio_chunk_lens)
probs = self.decoder.softmax(eouts)
return probs
def forward(self, eouts_chunk_prefix, eouts_chunk_lens_prefix, audio_chunk,
audio_chunk_lens, chunk_state_list):
"""export model function
Args:
audio_chunk (Tensor): [B, T, D]
audio_chunk_len (Tensor): [B]
Returns:
probs: probs after softmax
"""
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_list):
eouts_chunk, eouts_chunk_lens, final_state_list = self.encoder.forward_chunk(
audio_chunk, audio_chunk_lens, chunk_state_list)
eouts_chunk_new_prefix = paddle.concat(
[eouts_chunk_prefix, eouts_chunk], axis=1)
eouts_chunk_lens_new_prefix = paddle.add(eouts_chunk_lens_prefix,
eouts_chunk_lens)
probs_chunk = self.decoder.softmax(eouts_chunk_new_prefix)
return probs_chunk, eouts_chunk_new_prefix, eouts_chunk_lens_new_prefix, final_state_list
probs_chunk = self.decoder.softmax(eouts_chunk)
return probs_chunk, final_state_list
......@@ -7,7 +7,7 @@ stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
avg_num=1
model_type=online
model_type=offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
......
......@@ -4,10 +4,10 @@ source path.sh
gpus=7
stage=1
stop_stage=100
conf_path=conf/deepspeech2.yaml
stop_stage=1
conf_path=conf/deepspeech2_online.yaml
avg_num=1
model_type=online
model_type=online #online | offline
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
......
......@@ -19,7 +19,6 @@ import paddle
from deepspeech.models.ds2 import DeepSpeech2Model
class TestDeepSpeech2Model(unittest.TestCase):
def setUp(self):
paddle.set_device('cpu')
......
......@@ -119,14 +119,14 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
paddle.device.set_device("cpu")
de_ch_size = 9
eouts, eouts_lens, final_state_list = model.encoder(
self.audio, self.audio_len)
eouts, eouts_lens, final_state_list = model.encoder(self.audio,
self.audio_len)
eouts_by_chk_list, eouts_lens_by_chk_list, final_state_list_by_chk = model.encoder.forward_chunk_by_chunk(
self.audio, self.audio_len, de_ch_size)
eouts_by_chk = paddle.concat(eouts_by_chk_list, axis = 1)
eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1)
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
decode_max_len = eouts.shape[1]
print ("dml", decode_max_len)
print("dml", decode_max_len)
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
self.assertEqual(
paddle.sum(
......@@ -149,6 +149,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
print (paddle.sum(paddle.abs(paddle.subtract(eouts, eouts_by_chk))))
print (paddle.allclose(eouts[:,:,:], eouts_by_chk[:,:,:]))
"""
"""
def split_into_chunk(self, x, x_lens, decoder_chunk_size, subsampling_rate,
receptive_field_length):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册