diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 2f84b686c69c01814551c50b854aa567caf87ae4..544d57d1bcc975f8cc2e71955f9cec258e38cbcc 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -127,7 +127,8 @@ class DeepSpeech2Trainer(Trainer): num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + share_rnn_weights=config.model.share_rnn_weights, + apply_online=config.model.apply_online) if self.parallel: model = paddle.DataParallel(model) @@ -374,7 +375,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + share_rnn_weights=config.model.share_rnn_weights, + apply_online=config.model.apply_online) self.model = model logger.info("Setup model!") diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 0bd5fb95d3738b7c48a5884e707157ea10baae63..7f173ce292167c7fb5aeccf6481e01cc47abd8e3 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -25,6 +25,11 @@ from deepspeech.utils import layer_tools from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log +from paddle.nn import LSTM, GRU +from paddle.nn import LayerNorm +from paddle.nn import LayerList + + logger = Log(__name__).getlog() __all__ = ['DeepSpeech2Model', 'DeepSpeech2InferMode'] @@ -38,25 +43,50 @@ class CRNNEncoder(nn.Layer): num_rnn_layers=3, rnn_size=1024, use_gru=False, - share_rnn_weights=True): + share_rnn_weights=True, + apply_online=True): super().__init__() self.rnn_size = rnn_size self.feat_size = feat_size # 161 for linear self.dict_size = dict_size - + self.num_rnn_layers = num_rnn_layers + self.apply_online = apply_online self.conv = ConvStack(feat_size, num_conv_layers) i_size = self.conv.output_height # H after conv stack + + + self.rnn = LayerList() + self.layernorm_list = LayerList() + + if (apply_online == True): + rnn_direction = 'forward' + else: + rnn_direction = 'bidirect' + + if use_gru == True: + self.rnn.append(GRU(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction)) + self.layernorm_list.append(LayerNorm(rnn_size)) + for i in range(1, num_rnn_layers): + self.rnn.append(GRU(input_size=rnn_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction)) + self.layernorm_list.append(LayerNorm(rnn_size)) + else: + self.rnn.append(LSTM(input_size=i_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction)) + self.layernorm_list.append(LayerNorm(rnn_size)) + for i in range(1, num_rnn_layers): + self.rnn.append(LSTM(input_size=rnn_size, hidden_size=rnn_size, num_layers=1, direction = rnn_direction)) + self.layernorm_list.append(LayerNorm(rnn_size)) + """ self.rnn = RNNStack( i_size=i_size, h_size=rnn_size, num_stacks=num_rnn_layers, use_gru=use_gru, share_rnn_weights=share_rnn_weights) - + """ @property def output_size(self): - return self.rnn_size * 2 + return self.rnn_size def forward(self, audio, audio_len): """Compute Encoder outputs @@ -86,7 +116,15 @@ class CRNNEncoder(nn.Layer): x = x.reshape([0, 0, -1]) #[B, T, C*D] # remove padding part - x, x_lens = self.rnn(x, x_lens) #[B, T, D] + print ("x.shape:", x.shape) + x, output_state = self.rnn[0](x, None, x_lens) + x = self.layernorm_list[0](x) + for i in range(1, self.num_rnn_layers): + x, output_state = self.rnn[i](x, output_state, x_lens) #[B, T, D] + x = self.layernorm_list[i](x) + """ + x, x_lens = self.rnn(x, x_lens) + """ return x, x_lens @@ -141,7 +179,8 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=3, rnn_size=1024, use_gru=False, - share_rnn_weights=True): + share_rnn_weights=True, + apply_online = True): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, @@ -150,8 +189,9 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=num_rnn_layers, rnn_size=rnn_size, use_gru=use_gru, - share_rnn_weights=share_rnn_weights) - assert (self.encoder.output_size == rnn_size * 2) + share_rnn_weights=share_rnn_weights, + apply_online=apply_online) + assert (self.encoder.output_size == rnn_size) self.decoder = CTCDecoder( odim=dict_size, # is in vocab @@ -221,7 +261,8 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + share_rnn_weights=config.model.share_rnn_weights, + apply_online=config.model.apply_online) infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") @@ -237,7 +278,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model): num_rnn_layers=3, rnn_size=1024, use_gru=False, - share_rnn_weights=True): + share_rnn_weights=True, + apply_online = True): super().__init__( feat_size=feat_size, dict_size=dict_size, @@ -245,7 +287,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model): num_rnn_layers=num_rnn_layers, rnn_size=rnn_size, use_gru=use_gru, - share_rnn_weights=share_rnn_weights) + share_rnn_weights=share_rnn_weights, + apply_online=apply_online) def forward(self, audio, audio_len): """export model function diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 1c97fc6072166c50089e8217d15f6134778b5de6..7d0d1f8956ba9b6a516c60ba6d6fbb2ec7a802c1 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -36,10 +36,11 @@ collator: model: num_conv_layers: 2 - num_rnn_layers: 3 + num_rnn_layers: 4 rnn_layer_size: 1024 use_gru: True share_rnn_weights: False + apply_online: False training: n_epoch: 50 diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index acee94c3e71cda61ec2405f90c4b7940916a91e4..be1918d010f313072637e0a0255fd7d8ca808ef0 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -40,6 +40,7 @@ model: rnn_layer_size: 2048 use_gru: False share_rnn_weights: True + apply_online: False training: n_epoch: 50 diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index ea433f341577104e65f0b9fa274613c46e15cfe0..8c719e5cd02dd22e489922b7ef3560c08babd2e2 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -41,6 +41,7 @@ model: rnn_layer_size: 2048 use_gru: False share_rnn_weights: True + apply_online: True training: n_epoch: 10