提交 123c1a4f 编写于 作者: H Hui Zhang

format

上级 17092cbb
...@@ -187,7 +187,8 @@ class DeepSpeech2Trainer(Trainer): ...@@ -187,7 +187,8 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.data.sortagrad, sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method) shuffle_method=config.data.shuffle_method)
collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False) collate_fn = SpeechCollator(
keep_transcription_text=False, return_utts=False)
self.train_loader = DataLoader( self.train_loader = DataLoader(
train_dataset, train_dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
...@@ -220,7 +221,8 @@ class DeepSpeech2Trainer(Trainer): ...@@ -220,7 +221,8 @@ class DeepSpeech2Trainer(Trainer):
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True)) collate_fn=SpeechCollator(
keep_transcription_text=True, return_utts=True))
logger.info("Setup test Dataloader!") logger.info("Setup test Dataloader!")
......
...@@ -222,7 +222,8 @@ class U2Trainer(Trainer): ...@@ -222,7 +222,8 @@ class U2Trainer(Trainer):
config.data.augmentation_config = "" config.data.augmentation_config = ""
dev_dataset = ManifestDataset.from_config(config) dev_dataset = ManifestDataset.from_config(config)
collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False) collate_fn = SpeechCollator(
keep_transcription_text=False, return_utts=False)
if self.parallel: if self.parallel:
batch_sampler = SortagradDistributedBatchSampler( batch_sampler = SortagradDistributedBatchSampler(
train_dataset, train_dataset,
...@@ -272,7 +273,8 @@ class U2Trainer(Trainer): ...@@ -272,7 +273,8 @@ class U2Trainer(Trainer):
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True)) collate_fn=SpeechCollator(
keep_transcription_text=True, return_utts=True))
logger.info("Setup train/valid/test Dataloader!") logger.info("Setup train/valid/test Dataloader!")
def setup_model(self): def setup_model(self):
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
"""Contains the data augmentation pipeline.""" """Contains the data augmentation pipeline."""
import json import json
from pprint import pformat
from collections.abc import Sequence from collections.abc import Sequence
from inspect import signature from inspect import signature
from pprint import pformat
import numpy as np import numpy as np
...@@ -112,7 +112,8 @@ class AugmentationPipeline(): ...@@ -112,7 +112,8 @@ class AugmentationPipeline():
'audio') 'audio')
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from( self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
'feature') 'feature')
logger.info(f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}") logger.info(
f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")
def __call__(self, xs, uttid_list=None, **kwargs): def __call__(self, xs, uttid_list=None, **kwargs):
if not isinstance(xs, Sequence): if not isinstance(xs, Sequence):
...@@ -203,7 +204,7 @@ class AugmentationPipeline(): ...@@ -203,7 +204,7 @@ class AugmentationPipeline():
aug_confs = all_confs aug_confs = all_confs
else: else:
raise ValueError(f"Not support: {aug_type}") raise ValueError(f"Not support: {aug_type}")
augmentors = [ augmentors = [
self._get_augmentor(config["type"], config["params"]) self._get_augmentor(config["type"], config["params"])
for config in aug_confs for config in aug_confs
......
...@@ -12,17 +12,17 @@ ...@@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains the text featurizer class.""" """Contains the text featurizer class."""
import sentencepiece as spm
from pprint import pformat from pprint import pformat
import sentencepiece as spm
from ..utility import BLANK
from ..utility import EOS from ..utility import EOS
from ..utility import load_dict
from ..utility import MASKCTC
from ..utility import SOS
from ..utility import SPACE from ..utility import SPACE
from ..utility import UNK from ..utility import UNK
from ..utility import SOS
from ..utility import BLANK
from ..utility import MASKCTC
from ..utility import load_dict
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
......
...@@ -81,4 +81,4 @@ class SpeechCollator(): ...@@ -81,4 +81,4 @@ class SpeechCollator():
if self.return_utts: if self.return_utts:
return padded_audios, audio_lens, padded_texts, text_lens, utts return padded_audios, audio_lens, padded_texts, text_lens, utts
else: else:
return padded_audios, audio_lens, padded_texts, text_lens return padded_audios, audio_lens, padded_texts, text_lens
\ No newline at end of file
...@@ -23,7 +23,11 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"] ...@@ -23,7 +23,11 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"]
class CTCLoss(nn.Layer): class CTCLoss(nn.Layer):
def __init__(self, blank=0, reduction='sum', batch_average=False, grad_norm_type=None): def __init__(self,
blank=0,
reduction='sum',
batch_average=False,
grad_norm_type=None):
super().__init__() super().__init__()
# last token id as blank id # last token id as blank id
self.loss = nn.CTCLoss(blank=blank, reduction=reduction) self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
......
...@@ -148,8 +148,8 @@ class Trainer(): ...@@ -148,8 +148,8 @@ class Trainer():
"lr": self.optimizer.get_lr() "lr": self.optimizer.get_lr()
}) })
Checkpoint().save_parameters(self.checkpoint_dir, self.iteration Checkpoint().save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model, if tag is None else tag, self.model,
self.optimizer, infos) self.optimizer, infos)
def resume_or_scratch(self): def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output """Resume from latest checkpoint at checkpoints in the output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册