提交 2cacbaf4 编写于 作者: H huangyuxin

修改了deepspeech2.py部分LSTM和GRU的代码,增加了LayerNorm

上级 ce1e8ab5
...@@ -127,7 +127,8 @@ class DeepSpeech2Trainer(Trainer): ...@@ -127,7 +127,8 @@ class DeepSpeech2Trainer(Trainer):
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights,
apply_online=config.model.apply_online)
if self.parallel: if self.parallel:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
...@@ -374,7 +375,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -374,7 +375,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights,
apply_online=config.model.apply_online)
self.model = model self.model = model
logger.info("Setup model!") logger.info("Setup model!")
......
...@@ -25,6 +25,11 @@ from deepspeech.utils import layer_tools ...@@ -25,6 +25,11 @@ from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from paddle.nn import LSTM, GRU
from paddle.nn import LayerNorm
from paddle.nn import LayerList
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferMode'] __all__ = ['DeepSpeech2Model', 'DeepSpeech2InferMode']
...@@ -38,25 +43,50 @@ class CRNNEncoder(nn.Layer): ...@@ -38,25 +43,50 @@ class CRNNEncoder(nn.Layer):
num_rnn_layers=3, num_rnn_layers=3,
rnn_size=1024, rnn_size=1024,
use_gru=False, use_gru=False,
share_rnn_weights=True): share_rnn_weights=True,
apply_online=True):
super().__init__() super().__init__()
self.rnn_size = rnn_size self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size self.dict_size = dict_size
self.num_rnn_layers = num_rnn_layers
self.apply_online = apply_online
self.conv = ConvStack(feat_size, num_conv_layers) self.conv = ConvStack(feat_size, num_conv_layers)
i_size = self.conv.output_height # H after conv stack i_size = self.conv.output_height # H after conv stack
self.rnn = LayerList()
self.layernorm_list = LayerList()
if (apply_online == True):
rnn_direction = 'forward'
else:
rnn_direction = 'bidirect'
if use_gru == True:
self.rnn.append(GRU(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
self.layernorm_list.append(LayerNorm(rnn_size))
for i in range(1, num_rnn_layers):
self.rnn.append(GRU(input_size=rnn_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
self.layernorm_list.append(LayerNorm(rnn_size))
else:
self.rnn.append(LSTM(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
self.layernorm_list.append(LayerNorm(rnn_size))
for i in range(1, num_rnn_layers):
self.rnn.append(LSTM(input_size=rnn_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction))
self.layernorm_list.append(LayerNorm(rnn_size))
"""
self.rnn = RNNStack( self.rnn = RNNStack(
i_size=i_size, i_size=i_size,
h_size=rnn_size, h_size=rnn_size,
num_stacks=num_rnn_layers, num_stacks=num_rnn_layers,
use_gru=use_gru, use_gru=use_gru,
share_rnn_weights=share_rnn_weights) share_rnn_weights=share_rnn_weights)
"""
@property @property
def output_size(self): def output_size(self):
return self.rnn_size * 2 return self.rnn_size
def forward(self, audio, audio_len): def forward(self, audio, audio_len):
"""Compute Encoder outputs """Compute Encoder outputs
...@@ -86,7 +116,15 @@ class CRNNEncoder(nn.Layer): ...@@ -86,7 +116,15 @@ class CRNNEncoder(nn.Layer):
x = x.reshape([0, 0, -1]) #[B, T, C*D] x = x.reshape([0, 0, -1]) #[B, T, C*D]
# remove padding part # remove padding part
x, x_lens = self.rnn(x, x_lens) #[B, T, D] print ("x.shape:", x.shape)
x, output_state = self.rnn[0](x, None, x_lens)
x = self.layernorm_list[0](x)
for i in range(1, self.num_rnn_layers):
x, output_state = self.rnn[i](x, output_state, x_lens) #[B, T, D]
x = self.layernorm_list[i](x)
"""
x, x_lens = self.rnn(x, x_lens)
"""
return x, x_lens return x, x_lens
...@@ -141,7 +179,8 @@ class DeepSpeech2Model(nn.Layer): ...@@ -141,7 +179,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=3, num_rnn_layers=3,
rnn_size=1024, rnn_size=1024,
use_gru=False, use_gru=False,
share_rnn_weights=True): share_rnn_weights=True,
apply_online = True):
super().__init__() super().__init__()
self.encoder = CRNNEncoder( self.encoder = CRNNEncoder(
feat_size=feat_size, feat_size=feat_size,
...@@ -150,8 +189,9 @@ class DeepSpeech2Model(nn.Layer): ...@@ -150,8 +189,9 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=num_rnn_layers, num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size, rnn_size=rnn_size,
use_gru=use_gru, use_gru=use_gru,
share_rnn_weights=share_rnn_weights) share_rnn_weights=share_rnn_weights,
assert (self.encoder.output_size == rnn_size * 2) apply_online=apply_online)
assert (self.encoder.output_size == rnn_size)
self.decoder = CTCDecoder( self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab odim=dict_size, # <blank> is in vocab
...@@ -221,7 +261,8 @@ class DeepSpeech2Model(nn.Layer): ...@@ -221,7 +261,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru, use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights) share_rnn_weights=config.model.share_rnn_weights,
apply_online=config.model.apply_online)
infos = Checkpoint().load_parameters( infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path) model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}") logger.info(f"checkpoint info: {infos}")
...@@ -237,7 +278,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model): ...@@ -237,7 +278,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=3, num_rnn_layers=3,
rnn_size=1024, rnn_size=1024,
use_gru=False, use_gru=False,
share_rnn_weights=True): share_rnn_weights=True,
apply_online = True):
super().__init__( super().__init__(
feat_size=feat_size, feat_size=feat_size,
dict_size=dict_size, dict_size=dict_size,
...@@ -245,7 +287,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model): ...@@ -245,7 +287,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
num_rnn_layers=num_rnn_layers, num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size, rnn_size=rnn_size,
use_gru=use_gru, use_gru=use_gru,
share_rnn_weights=share_rnn_weights) share_rnn_weights=share_rnn_weights,
apply_online=apply_online)
def forward(self, audio, audio_len): def forward(self, audio, audio_len):
"""export model function """export model function
......
...@@ -36,10 +36,11 @@ collator: ...@@ -36,10 +36,11 @@ collator:
model: model:
num_conv_layers: 2 num_conv_layers: 2
num_rnn_layers: 3 num_rnn_layers: 4
rnn_layer_size: 1024 rnn_layer_size: 1024
use_gru: True use_gru: True
share_rnn_weights: False share_rnn_weights: False
apply_online: False
training: training:
n_epoch: 50 n_epoch: 50
......
...@@ -40,6 +40,7 @@ model: ...@@ -40,6 +40,7 @@ model:
rnn_layer_size: 2048 rnn_layer_size: 2048
use_gru: False use_gru: False
share_rnn_weights: True share_rnn_weights: True
apply_online: False
training: training:
n_epoch: 50 n_epoch: 50
......
...@@ -41,6 +41,7 @@ model: ...@@ -41,6 +41,7 @@ model:
rnn_layer_size: 2048 rnn_layer_size: 2048
use_gru: False use_gru: False
share_rnn_weights: True share_rnn_weights: True
apply_online: True
training: training:
n_epoch: 10 n_epoch: 10
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册