未验证 提交 9ac6d65a 编写于 作者: J Jackwaterveg 提交者: GitHub

Merge pull request #780 from Jackwaterveg/ds2_online

修改pre_commit, 注释以及增加ds2的seed
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Trainer for DeepSpeech2 model.""" """Trainer for DeepSpeech2 model."""
import os
from paddle import distributed as dist from paddle import distributed as dist
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
...@@ -53,5 +55,7 @@ if __name__ == "__main__": ...@@ -53,5 +55,7 @@ if __name__ == "__main__":
if args.dump_config: if args.dump_config:
with open(args.dump_config, 'w') as f: with open(args.dump_config, 'w') as f:
print(config, file=f) print(config, file=f)
if config.training.seed is not None:
os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args) main(config, args)
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model.""" """Contains DeepSpeech2 and DeepSpeech2Online model."""
import random
import time import time
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
...@@ -53,6 +54,7 @@ class DeepSpeech2Trainer(Trainer): ...@@ -53,6 +54,7 @@ class DeepSpeech2Trainer(Trainer):
weight_decay=1e-6, # the coeff of weight decay weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip global_grad_clip=5.0, # the global norm clip
n_epoch=50, # train epochs n_epoch=50, # train epochs
seed=1024, #train seed
)) ))
if config is not None: if config is not None:
...@@ -61,6 +63,13 @@ class DeepSpeech2Trainer(Trainer): ...@@ -61,6 +63,13 @@ class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
if config.training.seed is not None:
self.set_seed(config.training.seed)
def set_seed(self, seed):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def train_batch(self, batch_index, batch_data, msg): def train_batch(self, batch_index, batch_data, msg):
start = time.time() start = time.time()
......
...@@ -12,9 +12,7 @@ ...@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle import paddle
from paddle import nn
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4 from deepspeech.modules.subsampling import Conv2dSubsampling4
......
...@@ -26,7 +26,7 @@ from deepspeech.utils.checkpoint import Checkpoint ...@@ -26,7 +26,7 @@ from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline'] __all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
class CRNNEncoder(nn.Layer): class CRNNEncoder(nn.Layer):
...@@ -68,7 +68,7 @@ class CRNNEncoder(nn.Layer): ...@@ -68,7 +68,7 @@ class CRNNEncoder(nn.Layer):
rnn_input_size = i_size rnn_input_size = i_size
else: else:
rnn_input_size = layernorm_size rnn_input_size = layernorm_size
if use_gru == True: if use_gru is True:
self.rnn.append( self.rnn.append(
nn.GRU( nn.GRU(
input_size=rnn_input_size, input_size=rnn_input_size,
...@@ -102,18 +102,18 @@ class CRNNEncoder(nn.Layer): ...@@ -102,18 +102,18 @@ class CRNNEncoder(nn.Layer):
Args: Args:
x (Tensor): [B, feature_size, D] x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B] x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
Returns: Return:
x (Tensor): encoder outputs, [B, size, D] x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B] x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
""" """
if init_state_h_box is not None: if init_state_h_box is not None:
init_state_list = None init_state_list = None
if self.use_gru == True: if self.use_gru is True:
init_state_h_list = paddle.split( init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0) init_state_h_box, self.num_rnn_layers, axis=0)
init_state_list = init_state_h_list init_state_list = init_state_h_list
...@@ -139,10 +139,10 @@ class CRNNEncoder(nn.Layer): ...@@ -139,10 +139,10 @@ class CRNNEncoder(nn.Layer):
x = self.fc_layers_list[i](x) x = self.fc_layers_list[i](x)
x = F.relu(x) x = F.relu(x)
if self.use_gru == True: if self.use_gru is True:
final_chunk_state_h_box = paddle.concat( final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0) final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box) final_chunk_state_c_box = init_state_c_box
else: else:
final_chunk_state_h_list = [ final_chunk_state_h_list = [
final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
...@@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer): ...@@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer):
x_lens (Tensor): [B] x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder decoder_chunk_size: The chunk size of decoder
Returns: Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
""" """
subsampling_rate = self.conv.subsampling_rate subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length receptive_field_length = self.conv.receptive_field_length
...@@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer): ...@@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer):
class DeepSpeech2ModelOnline(nn.Layer): class DeepSpeech2ModelOnline(nn.Layer):
"""The DeepSpeech2 network structure for online. """The DeepSpeech2 network structure for online.
:param audio_data: Audio spectrogram data layer. :param audio: Audio spectrogram data layer.
:type audio_data: Variable :type audio: Variable
:param text_data: Transcription text data layer. :param text: Transcription text data layer.
:type text_data: Variable :type text: Variable
:param audio_len: Valid sequence length data layer. :param audio_len: Valid sequence length data layer.
:type audio_len: Variable :type audio_len: Variable
:param feat_size: feature size for audio.
:type feat_size: int
:param dict_size: Dictionary size for tokenized transcription. :param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int :type dict_size: int
:param num_conv_layers: Number of stacking convolution layers. :param num_conv_layers: Number of stacking convolution layers.
......
...@@ -146,7 +146,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): ...@@ -146,7 +146,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True) paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False: if use_gru is False:
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True) paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
...@@ -177,7 +177,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): ...@@ -177,7 +177,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True) paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False: if use_gru is False:
self.assertEqual( self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True) paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册