Commit 9068c0d4 authored by huangyuxin

Merge branch 'HEAD_1' into ds2_online

......@@ -55,7 +55,7 @@ if __name__ == "__main__":
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
-if config.training.seed != None:
+if config.training.seed is not None:
os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args)
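Note on the hunk above: it switches the seed check to the idiomatic `is not None` test before enabling deterministic cuDNN. A minimal sketch of the same entry-point pattern, with a stand-in `main` and config object rather than the repo's actual classes:

```python
import os
from types import SimpleNamespace

def main(config, args):
    # Stand-in for the real training entry point.
    print("training with seed", config.training.seed)

# Stand-in config; the real script loads a CfgNode plus argparse args.
config = SimpleNamespace(training=SimpleNamespace(seed=1024))
args = None

# `is not None` is the PEP 8 way to test for None and avoids surprises
# from a custom __eq__ on the config value.
if config.training.seed is not None:
    # FLAGS_cudnn_deterministic asks Paddle/cuDNN for deterministic kernels,
    # so a seeded run is reproducible.
    os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args)
```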
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import os
import random
import time
from collections import defaultdict
......@@ -64,7 +63,7 @@ class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
-if config.training.seed != None:
+if config.training.seed is not None:
self.set_seed(config.training.seed)
def set_seed(self, seed):
......
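The trainer's `set_seed` helper (its body appears in a later hunk) seeds both NumPy and Paddle. A standalone sketch of that helper, written here as a free function instead of a method:

```python
import numpy as np
import paddle

def set_seed(seed: int) -> None:
    # Seed NumPy (data shuffling, augmentation) and Paddle (parameter init,
    # dropout) so repeated runs with the same config are comparable.
    np.random.seed(seed)
    paddle.seed(seed)

set_seed(1024)  # e.g. 1024, the value that appears in the U2 default-config hunk
```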
......@@ -52,10 +52,7 @@ if __name__ == "__main__":
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
if config.training.seed != None:
os.environ.setdefault('FLAGS_cudnn_deterministic', 'True')
main(config, args)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
......
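For the profiling lines in this script, `cProfile.Profile().runcall` runs the callable under the profiler without modifying it. A small self-contained sketch; the `train.profile` output path is illustrative, since the hunk does not show where the stats are dumped:

```python
import cProfile
import pstats

def main():
    # Stand-in for the training entry point.
    return sum(i * i for i in range(100_000))

pr = cProfile.Profile()
pr.runcall(main)                  # profile main() and return its result
pr.dump_stats('train.profile')    # illustrative filename, not from the diff

# Show the ten most expensive calls by cumulative time.
pstats.Stats('train.profile').sort_stats('cumulative').print_stats(10)
```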
......@@ -55,7 +55,7 @@ class U2Trainer(Trainer):
log_interval=100, # steps
accum_grad=1, # accum grad by # steps
global_grad_clip=5.0, # the global norm clip
-seed=1024, ))
+))
default.optim = 'adam'
default.optim_conf = CfgNode(
dict(
......@@ -75,12 +75,6 @@ class U2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
-if config.training.seed != None:
-self.set_seed(config.training.seed)
-def set_seed(self, seed):
-np.random.seed(seed)
-paddle.seed(seed)
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
......
......@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
-batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
+batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
......
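The `zip(*[iter(...)] * batch_size)` line in `_batch_shuffle` is terse: it walks one shared iterator `batch_size` steps at a time, so the flat index list is grouped into fixed-size tuples and any remainder is dropped. A toy reproduction with made-up indices:

```python
import numpy as np

indices = list(range(12))            # pretend these are utterance indices
batch_size = 4
epoch = 3

rng = np.random.RandomState(epoch)   # same epoch -> same pseudo-random order
shift_len = rng.randint(0, batch_size - 1)

# One iterator repeated batch_size times: zip pulls from it round-robin,
# grouping consecutive elements into batch_size-sized tuples.
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)           # shuffle whole batches, keep order inside each
flat = [item for batch in batch_indices for item in batch]
print(batch_indices)
print(flat)
```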
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4
......
......@@ -26,7 +26,7 @@ from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
-__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline']
+__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
class CRNNEncoder(nn.Layer):
......@@ -68,7 +68,7 @@ class CRNNEncoder(nn.Layer):
rnn_input_size = i_size
else:
rnn_input_size = layernorm_size
-if use_gru == True:
+if use_gru is True:
self.rnn.append(
nn.GRU(
input_size=rnn_input_size,
......@@ -113,7 +113,7 @@ class CRNNEncoder(nn.Layer):
if init_state_h_box is not None:
init_state_list = None
-if self.use_gru == True:
+if self.use_gru is True:
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_list = init_state_h_list
......@@ -139,7 +139,7 @@ class CRNNEncoder(nn.Layer):
x = self.fc_layers_list[i](x)
x = F.relu(x)
-if self.use_gru == True:
+if self.use_gru is True:
final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box
......
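The GRU branch above splits a stacked initial-state tensor back into per-layer states with `paddle.split`. A small sketch of just that reshaping step; shapes are illustrative, and the LSTM branch additionally carries a cell-state box:

```python
import paddle

num_rnn_layers, batch, hidden = 3, 4, 8

# All layers' hidden states stacked along axis 0, the way the encoder
# passes them around as a single "box" tensor.
init_state_h_box = paddle.randn([num_rnn_layers, batch, hidden])

# One [1, batch, hidden] tensor per RNN layer.
init_state_h_list = paddle.split(init_state_h_box, num_rnn_layers, axis=0)
print(len(init_state_h_list), init_state_h_list[0].shape)   # 3 [1, 4, 8]
```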
......@@ -46,7 +46,7 @@ model:
training:
n_epoch: 50
lr: 2e-3
-lr_decay: 0.9 # 0.83
+lr_decay: 0.91 # 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
......
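The `lr` / `lr_decay` pair in the training config describes a per-epoch exponential schedule, i.e. lr_epoch = lr * lr_decay ** epoch. A sketch with Paddle's stock scheduler; whether the repo wires `lr_decay` into exactly this class is not visible in the hunk:

```python
import paddle

# 2e-3 and 0.91 are the values from the config hunk above.
scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=2e-3, gamma=0.91)

for epoch in range(3):
    print(epoch, scheduler.get_lr())   # 2e-3, 2e-3*0.91, 2e-3*0.91**2
    scheduler.step()                   # advance one epoch
```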
......@@ -143,10 +143,10 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list)
decode_max_len = eouts.shape[1]
eouts_by_chk = eouts_by_chk[:, :decode_max_len, :]
-self.assertEqual(paddle.allclose(eouts_by_chk, eouts, atol=1e-5), True)
+self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
-if use_gru == False:
+if use_gru is False:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
......@@ -177,7 +177,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
-if use_gru == False:
+if use_gru is False:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
......
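The assertions above rely on `paddle.allclose`, which returns a single-element boolean tensor; its defaults are `rtol=1e-5, atol=1e-8`, so a call without an explicit `atol` compares more tightly than one with `atol=1e-5`. A minimal unittest in the same style, with made-up tensor values:

```python
import unittest
import paddle

class TestAllclose(unittest.TestCase):
    def test_chunked_output_matches_full_output(self):
        eouts = paddle.to_tensor([1.0, 2.0, 3.0])
        eouts_by_chk = eouts + 1e-9          # tiny numerical difference

        # allclose yields a one-element bool tensor; convert before asserting.
        self.assertTrue(bool(paddle.allclose(eouts_by_chk, eouts)))
        self.assertFalse(bool(paddle.allclose(eouts_by_chk, eouts + 1.0)))

if __name__ == '__main__':
    unittest.main()
```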