提交 123c1a4f 编写于 作者: H Hui Zhang

format

上级 17092cbb
......@@ -187,7 +187,8 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False)
collate_fn = SpeechCollator(
keep_transcription_text=False, return_utts=False)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
......@@ -220,7 +221,8 @@ class DeepSpeech2Trainer(Trainer):
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True))
collate_fn=SpeechCollator(
keep_transcription_text=True, return_utts=True))
logger.info("Setup test Dataloader!")
......
......@@ -222,7 +222,8 @@ class U2Trainer(Trainer):
config.data.augmentation_config = ""
dev_dataset = ManifestDataset.from_config(config)
collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False)
collate_fn = SpeechCollator(
keep_transcription_text=False, return_utts=False)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
......@@ -272,7 +273,8 @@ class U2Trainer(Trainer):
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True))
collate_fn=SpeechCollator(
keep_transcription_text=True, return_utts=True))
logger.info("Setup train/valid/test Dataloader!")
def setup_model(self):
......
......@@ -13,9 +13,9 @@
# limitations under the License.
"""Contains the data augmentation pipeline."""
import json
from pprint import pformat
from collections.abc import Sequence
from inspect import signature
from pprint import pformat
import numpy as np
......@@ -112,7 +112,8 @@ class AugmentationPipeline():
'audio')
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
'feature')
logger.info(f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")
logger.info(
f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")
def __call__(self, xs, uttid_list=None, **kwargs):
if not isinstance(xs, Sequence):
......@@ -203,7 +204,7 @@ class AugmentationPipeline():
aug_confs = all_confs
else:
raise ValueError(f"Not support: {aug_type}")
augmentors = [
self._get_augmentor(config["type"], config["params"])
for config in aug_confs
......
......@@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
import sentencepiece as spm
from pprint import pformat
import sentencepiece as spm
from ..utility import BLANK
from ..utility import EOS
from ..utility import load_dict
from ..utility import MASKCTC
from ..utility import SOS
from ..utility import SPACE
from ..utility import UNK
from ..utility import SOS
from ..utility import BLANK
from ..utility import MASKCTC
from ..utility import load_dict
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
......
......@@ -81,4 +81,4 @@ class SpeechCollator():
if self.return_utts:
return padded_audios, audio_lens, padded_texts, text_lens, utts
else:
return padded_audios, audio_lens, padded_texts, text_lens
\ No newline at end of file
return padded_audios, audio_lens, padded_texts, text_lens
......@@ -23,7 +23,11 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"]
class CTCLoss(nn.Layer):
def __init__(self, blank=0, reduction='sum', batch_average=False, grad_norm_type=None):
def __init__(self,
blank=0,
reduction='sum',
batch_average=False,
grad_norm_type=None):
super().__init__()
# last token id as blank id
self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
......
......@@ -148,8 +148,8 @@ class Trainer():
"lr": self.optimizer.get_lr()
})
Checkpoint().save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
if tag is None else tag, self.model,
self.optimizer, infos)
def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册