Commit 123c1a4f, authored by: H Hui Zhang

format

Parent 17092cbb
@@ -187,7 +187,8 @@ class DeepSpeech2Trainer(Trainer):
             sortagrad=config.data.sortagrad,
             shuffle_method=config.data.shuffle_method)
-        collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False)
+        collate_fn = SpeechCollator(
+            keep_transcription_text=False, return_utts=False)
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
@@ -220,7 +221,8 @@ class DeepSpeech2Trainer(Trainer):
             batch_size=config.decoding.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True))
+            collate_fn=SpeechCollator(
+                keep_transcription_text=True, return_utts=True))
         logger.info("Setup test Dataloader!")
...
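Note: the hunks above only reflow the `SpeechCollator(...)` calls onto multiple lines; behavior is unchanged. For context on what the two flags in those calls control, here is a minimal, hypothetical sketch of a collate function with the same `keep_transcription_text` / `return_utts` options. It is not the repo's `SpeechCollator` (which also does feature extraction and augmentation); the batch layout and helper names are assumptions.

```python
# Hypothetical sketch only -- not the deepspeech implementation.
# return_utts controls whether utterance ids are emitted (handy at test
# time); keep_transcription_text keeps raw transcript strings instead of
# token ids prepared elsewhere.
import numpy as np


def collate(batch, keep_transcription_text=False, return_utts=False):
    """batch: list of (utt_id, feat [T, D], token_ids, text) tuples."""
    utts, feats, token_ids, texts = zip(*batch)

    feat_lens = np.array([f.shape[0] for f in feats], dtype=np.int64)
    padded = np.zeros((len(batch), int(feat_lens.max()), feats[0].shape[1]),
                      dtype=np.float32)
    for i, f in enumerate(feats):
        padded[i, :f.shape[0]] = f  # zero-pad to the longest utterance

    targets = list(texts) if keep_transcription_text else list(token_ids)
    if return_utts:
        return list(utts), padded, feat_lens, targets
    return padded, feat_lens, targets
```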
@@ -222,7 +222,8 @@ class U2Trainer(Trainer):
             config.data.augmentation_config = ""
         dev_dataset = ManifestDataset.from_config(config)
-        collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False)
+        collate_fn = SpeechCollator(
+            keep_transcription_text=False, return_utts=False)
         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
@@ -272,7 +273,8 @@ class U2Trainer(Trainer):
             batch_size=config.decoding.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True))
+            collate_fn=SpeechCollator(
+                keep_transcription_text=True, return_utts=True))
         logger.info("Setup train/valid/test Dataloader!")

     def setup_model(self):
...
@@ -13,9 +13,9 @@
 # limitations under the License.
 """Contains the data augmentation pipeline."""
 import json
-from pprint import pformat
 from collections.abc import Sequence
 from inspect import signature
+from pprint import pformat

 import numpy as np
@@ -112,7 +112,8 @@ class AugmentationPipeline():
             'audio')
         self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
             'feature')
-        logger.info(f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")
+        logger.info(
+            f"Augmentation: {pformat(list(zip(self._augmentors, self._rates)))}")

     def __call__(self, xs, uttid_list=None, **kwargs):
         if not isinstance(xs, Sequence):
...
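The reflowed `logger.info` call only moves the f-string onto its own line. For reference, a small standalone sketch of the `pformat(list(zip(...)))` pattern it uses; the augmentor names and rates below are made up and merely stand in for `self._augmentors` and `self._rates`.

```python
from pprint import pformat

# Made-up augmentor/rate pairs standing in for the pipeline's state.
augmentors = ["shift", "speed", "specaug"]
rates = [1.0, 0.5, 1.0]

# zip() pairs each augmentor with its apply rate; pformat() wraps the list
# of tuples across lines once it grows long, keeping the log line readable.
print(f"Augmentation: {pformat(list(zip(augmentors, rates)))}")
# Augmentation: [('shift', 1.0), ('speed', 0.5), ('specaug', 1.0)]
```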
@@ -12,17 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Contains the text featurizer class."""
-import sentencepiece as spm
 from pprint import pformat

+import sentencepiece as spm
+
+from ..utility import BLANK
 from ..utility import EOS
+from ..utility import load_dict
+from ..utility import MASKCTC
+from ..utility import SOS
 from ..utility import SPACE
 from ..utility import UNK
-from ..utility import SOS
-from ..utility import BLANK
-from ..utility import MASKCTC
-from ..utility import load_dict
 from deepspeech.utils.log import Log

 logger = Log(__name__).getlog()
...
@@ -23,7 +23,11 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"]
 class CTCLoss(nn.Layer):
-    def __init__(self, blank=0, reduction='sum', batch_average=False, grad_norm_type=None):
+    def __init__(self,
+                 blank=0,
+                 reduction='sum',
+                 batch_average=False,
+                 grad_norm_type=None):
         super().__init__()
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
...
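This last hunk only reflows the constructor signature to one argument per line. For readers wanting the surrounding shape of such a wrapper around `paddle.nn.CTCLoss`, here is a minimal sketch; only the `__init__` argument list comes from the diff. The `forward` below is an assumption (the diff does not show it), `batch_average` is guessed to mean "divide the summed loss by batch size", and `grad_norm_type` is merely stored.

```python
# Sketch of a CTC loss wrapper -- not the repo's implementation.
from paddle import nn


class CTCLossSketch(nn.Layer):
    def __init__(self,
                 blank=0,
                 reduction='sum',
                 batch_average=False,
                 grad_norm_type=None):
        super().__init__()
        # paddle.nn.CTCLoss accepts `blank` and `reduction` directly;
        # the two extra flags are handled by the wrapper.
        self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
        self.batch_average = batch_average
        self.grad_norm_type = grad_norm_type  # unused in this sketch

    def forward(self, logits, labels, logits_len, labels_len):
        # Assumption: with reduction='sum', dividing by the batch size
        # (labels is [batch, max_label_len]) gives a per-utterance average.
        loss = self.loss(logits, labels, logits_len, labels_len)
        if self.batch_average:
            loss = loss / labels.shape[0]
        return loss
```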