Unverified commit 9ac6d65a authored by J Jackwaterveg, committed by GitHub

Merge pull request #780 from Jackwaterveg/ds2_online

Modify pre_commit and comments, and add a seed for ds2
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for DeepSpeech2 model."""
import os
from paddle import distributed as dist
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
......@@ -53,5 +55,7 @@ if __name__ == "__main__":
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
if config.training.seed is not None:
os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args)
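The hunk above only exports FLAGS_cudnn_deterministic when a training seed is configured, and it does so in the entry script so the flag is in place before Paddle touches cuDNN. A minimal sketch of the same gating pattern (the `SimpleNamespace` config here is a stand-in for the real config object, not the repository's code):

```python
import os
from types import SimpleNamespace

def set_determinism(training_seed):
    """Export FLAGS_cudnn_deterministic before Paddle initializes cuDNN.

    setdefault keeps any value the user already exported, so a run can
    still opt out of the slower deterministic cuDNN kernels.
    """
    if training_seed is not None:
        os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')

# Hypothetical config mirroring the training.seed field added in this PR.
config = SimpleNamespace(training=SimpleNamespace(seed=1024))
set_determinism(config.training.seed)
print(os.environ.get('FLAGS_cudnn_deterministic'))  # -> 'True'
```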
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import random
import time
from collections import defaultdict
from pathlib import Path
......@@ -53,6 +54,7 @@ class DeepSpeech2Trainer(Trainer):
weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip
n_epoch=50, # train epochs
seed=1024, #train seed
))
if config is not None:
......@@ -61,6 +63,13 @@ class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
if config.training.seed is not None:
self.set_seed(config.training.seed)
def set_seed(self, seed):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def train_batch(self, batch_index, batch_data, msg):
start = time.time()
......
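The new `set_seed` helper seeds all three RNG sources the trainer relies on: NumPy, Python's `random`, and Paddle itself. A small sketch, using the new default seed of 1024, of how seeding makes two runs reproduce the same random draws (the check itself is illustrative, not part of the PR):

```python
import random

import numpy as np
import paddle

def set_seed(seed: int):
    """Seed every RNG the trainer touches, matching the method added above."""
    np.random.seed(seed)   # NumPy ops, e.g. data augmentation
    random.seed(seed)      # Python-level shuffling
    paddle.seed(seed)      # Paddle parameter init and dropout

# Two runs with the same seed should produce identical random tensors.
set_seed(1024)
a = paddle.rand([2, 3])
set_seed(1024)
b = paddle.rand([2, 3])
assert bool(paddle.allclose(a, b))
```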
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4
......
......@@ -26,7 +26,7 @@ from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline']
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
class CRNNEncoder(nn.Layer):
......@@ -68,7 +68,7 @@ class CRNNEncoder(nn.Layer):
rnn_input_size = i_size
else:
rnn_input_size = layernorm_size
if use_gru == True:
if use_gru is True:
self.rnn.append(
nn.GRU(
input_size=rnn_input_size,
......@@ -102,18 +102,18 @@ class CRNNEncoder(nn.Layer):
Args:
x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
Returns:
init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
Return:
x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
if init_state_h_box is not None:
init_state_list = None
if self.use_gru == True:
if self.use_gru is True:
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_list = init_state_h_list
......@@ -139,10 +139,10 @@ class CRNNEncoder(nn.Layer):
x = self.fc_layers_list[i](x)
x = F.relu(x)
if self.use_gru == True:
if self.use_gru is True:
final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box)
final_chunk_state_c_box = init_state_c_box
else:
final_chunk_state_h_list = [
final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
......@@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer):
x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder
Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length
......@@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer):
class DeepSpeech2ModelOnline(nn.Layer):
"""The DeepSpeech2 network structure for online.
:param audio_data: Audio spectrogram data layer.
:type audio_data: Variable
:param text_data: Transcription text data layer.
:type text_data: Variable
:param audio: Audio spectrogram data layer.
:type audio: Variable
:param text: Transcription text data layer.
:type text: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param feat_size: feature size for audio.
:type feat_size: int
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
......
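The corrected docstrings spell out that the hidden-state "box" tensors stack the per-layer states along axis 0, i.e. [num_rnn_layers * num_directions, batch_size, hidden_size]. A toy illustration (all sizes invented) of packing such a box and splitting it back into one state per layer, as the GRU branch of `forward` does with `paddle.split`:

```python
import paddle

# Illustrative sizes only; the real values come from the model config.
num_rnn_layers, num_directions, batch_size, hidden_size = 3, 1, 4, 8

# The "box" packs the per-layer hidden states along axis 0, which is the
# shape the updated docstrings describe.
init_state_h_box = paddle.zeros(
    [num_rnn_layers * num_directions, batch_size, hidden_size])

# Inside the encoder the box is split back into one state per RNN layer.
init_state_h_list = paddle.split(init_state_h_box, num_rnn_layers, axis=0)
print(len(init_state_h_list), init_state_h_list[0].shape)  # 3 [1, 4, 8]
```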
......@@ -146,7 +146,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False:
if use_gru is False:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
......@@ -177,7 +177,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False:
if use_gru is False:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
......
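These assertions check that the chunk-by-chunk pass of the online encoder agrees with the full-utterance pass via `paddle.allclose`. The toy sketch below mimics the shape of that check, with a cumulative sum standing in for the stateful encoder; none of it is the test's actual code:

```python
import paddle

x = paddle.rand([1, 16, 8])             # [B, T, D], sizes made up
full = x.cumsum(axis=1)                  # "encoder" with running state

chunks = paddle.split(x, 4, axis=1)      # 4 chunks along the time axis
state = paddle.zeros([1, 1, 8])
outs = []
for chunk in chunks:
    out = state + chunk.cumsum(axis=1)   # carry state across chunks
    state = out[:, -1:, :]
    outs.append(out)
by_chunk = paddle.concat(outs, axis=1)

# Streaming and full passes should produce the same outputs.
assert bool(paddle.allclose(by_chunk, full))
```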