提交 718407b7 编写于 作者: H huangyuxin

add seed

上级 08b68e4b
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for DeepSpeech2 model."""
import os
from paddle import distributed as dist
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
......@@ -53,5 +55,7 @@ if __name__ == "__main__":
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
if config.training.seed != None:
os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args)
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import os
import random
import time
from collections import defaultdict
from pathlib import Path
......@@ -53,6 +55,7 @@ class DeepSpeech2Trainer(Trainer):
weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip
n_epoch=50, # train epochs
seed=1024, #train seed
))
if config is not None:
......@@ -61,6 +64,13 @@ class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
if config.training.seed != None:
self.set_seed(config.training.seed)
def set_seed(self, seed):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def train_batch(self, batch_index, batch_data, msg):
start = time.time()
......
......@@ -52,7 +52,10 @@ if __name__ == "__main__":
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
if config.training.seed != None:
os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
......
......@@ -55,7 +55,7 @@ class U2Trainer(Trainer):
log_interval=100, # steps
accum_grad=1, # accum grad by # steps
global_grad_clip=5.0, # the global norm clip
))
seed=1024, ))
default.optim = 'adam'
default.optim_conf = CfgNode(
dict(
......@@ -75,6 +75,12 @@ class U2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
if config.training.seed != None:
self.set_seed(config.training.seed)
def set_seed(self, seed):
np.random.seed(seed)
paddle.seed(seed)
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
......
......@@ -102,13 +102,13 @@ class CRNNEncoder(nn.Layer):
Args:
x (Tensor): [B, feature_size, D]
x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
init_state_c_box(Tensor): init_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
Returns:
init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
Return:
x (Tensor): encoder outputs, [B, size, D]
x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
if init_state_h_box is not None:
init_state_list = None
......@@ -142,7 +142,7 @@ class CRNNEncoder(nn.Layer):
if self.use_gru == True:
final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box)
final_chunk_state_c_box = init_state_c_box
else:
final_chunk_state_h_list = [
final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
......@@ -165,10 +165,10 @@ class CRNNEncoder(nn.Layer):
x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder
Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size, [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size, [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
final_state_c_box(Tensor): final_states c for RNN layers, num_rnn_layers * num_directions, batch_size, hidden_size
eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length
......@@ -215,12 +215,14 @@ class CRNNEncoder(nn.Layer):
class DeepSpeech2ModelOnline(nn.Layer):
"""The DeepSpeech2 network structure for online.
:param audio_data: Audio spectrogram data layer.
:type audio_data: Variable
:param text_data: Transcription text data layer.
:type text_data: Variable
:param audio: Audio spectrogram data layer.
:type audio: Variable
:param text: Transcription text data layer.
:type text: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param feat_size: feature size for audio.
:type feat_size: int
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
......
......@@ -143,7 +143,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
decode_max_len = eouts.shape[1]
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(paddle.allclose(eouts_by_chk, eouts, atol=1e-5), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册