diff --git a/examples/aishell/asr1/conf/conformer.yaml b/examples/aishell/asr1/conf/conformer.yaml index d5d883a031fafd05c99e7aa3f0b11474f95862d5..a150a04d55671edf25e5871b1695bcad14710367 100644 --- a/examples/aishell/asr1/conf/conformer.yaml +++ b/examples/aishell/asr1/conf/conformer.yaml @@ -85,7 +85,7 @@ scheduler: warmuplr scheduler_conf: warmup_steps: 25000 lr_decay: 1.0 -log_interval: 1 +log_interval: 100 checkpoint: kbest_n: 50 latest_n: 5 diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index bcbc15d64ed2308eb197f5be13fff5567519d7a1..efcc9629fdbf63981cfdc4cc5b91693e5f3a85ee 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -239,7 +239,7 @@ class U2Trainer(Trainer): n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1, - dist_sampler=True, + dist_sampler=config.get('dist_sampler', False), shortest_first=False) self.valid_loader = BatchDataLoader( @@ -260,7 +260,7 @@ class U2Trainer(Trainer): n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1, - dist_sampler=True, + dist_sampler=config.get('dist_sampler', False), shortest_first=False) logger.info("Setup train/valid Dataloader!") else: diff --git a/paddlespeech/s2t/modules/initializer.py b/paddlespeech/s2t/modules/initializer.py index 98466ebdb0a9e28749f107bcac52a2ecb29463b5..30a04e44fb2965d03be8c6346ef16448ed257bbc 100644 --- a/paddlespeech/s2t/modules/initializer.py +++ b/paddlespeech/s2t/modules/initializer.py @@ -160,9 +160,12 @@ class DefaultInitializerContext(object): self.init_type = init_type def __enter__(self): - from paddlespeech.s2t.modules import align - align.global_init_type = self.init_type - return self + if self.init_type is None: + return + else: + from paddlespeech.s2t.modules import align + align.global_init_type = self.init_type + return def __exit__(self, exc_type, exc_val, exc_tb): from paddlespeech.s2t.modules import align