提交 698d7a9b 编写于 作者: H Haoxin Ma

move batch_size, work_nums, shuffle_method, sortagrad to collator

上级 89a00eab
...@@ -28,20 +28,6 @@ _C.data = CN( ...@@ -28,20 +28,6 @@ _C.data = CN(
augmentation_config="", augmentation_config="",
max_duration=float('inf'), max_duration=float('inf'),
min_duration=0.0, min_duration=0.0,
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delat_delta=False, # 'mfcc', 'fbank'
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle'
)) ))
_C.model = CN( _C.model = CN(
...@@ -72,7 +58,11 @@ _C.collator =CN( ...@@ -72,7 +58,11 @@ _C.collator =CN(
use_dB_normalization=True, use_dB_normalization=True,
target_dB=-20, target_dB=-20,
dither=1.0, # feature dither dither=1.0, # feature dither
keep_transcription_text=False keep_transcription_text=False,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle'
)) ))
DeepSpeech2Model.params(_C.model) DeepSpeech2Model.params(_C.model)
......
...@@ -55,7 +55,7 @@ class DeepSpeech2Trainer(Trainer): ...@@ -55,7 +55,7 @@ class DeepSpeech2Trainer(Trainer):
'train_loss': float(loss), 'train_loss': float(loss),
} }
msg += "train time: {:>.3f}s, ".format(iteration_time) msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.data.batch_size) msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items()) for k, v in losses_np.items())
logger.info(msg) logger.info(msg)
...@@ -149,31 +149,31 @@ class DeepSpeech2Trainer(Trainer): ...@@ -149,31 +149,31 @@ class DeepSpeech2Trainer(Trainer):
if self.parallel: if self.parallel:
batch_sampler = SortagradDistributedBatchSampler( batch_sampler = SortagradDistributedBatchSampler(
train_dataset, train_dataset,
batch_size=config.data.batch_size, batch_size=config.collator.batch_size,
num_replicas=None, num_replicas=None,
rank=None, rank=None,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
sortagrad=config.data.sortagrad, sortagrad=config.collator.sortagrad,
shuffle_method=config.data.shuffle_method) shuffle_method=config.collator.shuffle_method)
else: else:
batch_sampler = SortagradBatchSampler( batch_sampler = SortagradBatchSampler(
train_dataset, train_dataset,
shuffle=True, shuffle=True,
batch_size=config.data.batch_size, batch_size=config.collator.batch_size,
drop_last=True, drop_last=True,
sortagrad=config.data.sortagrad, sortagrad=config.collator.sortagrad,
shuffle_method=config.data.shuffle_method) shuffle_method=config.collator.shuffle_method)
collate_fn = SpeechCollator.from_config(config) collate_fn = SpeechCollator.from_config(config)
self.train_loader = DataLoader( self.train_loader = DataLoader(
train_dataset, train_dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
collate_fn=collate_fn, collate_fn=collate_fn,
num_workers=config.data.num_workers) num_workers=config.collator.num_workers)
self.valid_loader = DataLoader( self.valid_loader = DataLoader(
dev_dataset, dev_dataset,
batch_size=config.data.batch_size, batch_size=config.collator.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=collate_fn) collate_fn=collate_fn)
......
...@@ -26,7 +26,11 @@ _C.collator =CfgNode( ...@@ -26,7 +26,11 @@ _C.collator =CfgNode(
dict( dict(
augmentation_config="", augmentation_config="",
unit_type="char", unit_type="char",
keep_transcription_text=False keep_transcription_text=False,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle" # 'batch_shuffle', 'instance_shuffle'
)) ))
_C.model = U2Model.params() _C.model = U2Model.params()
......
...@@ -151,13 +151,3 @@ class SpeechFeaturizer(object): ...@@ -151,13 +151,3 @@ class SpeechFeaturizer(object):
TextFeaturizer: object. TextFeaturizer: object.
""" """
return self._text_featurizer return self._text_featurizer
# @property
# def text_feature(self):
# """Return the text feature object.
# Returns:
# TextFeaturizer: object.
# """
# return self._text_featurizer
...@@ -203,34 +203,22 @@ class SpeechCollator(): ...@@ -203,34 +203,22 @@ class SpeechCollator():
where transcription part could be token ids or text. where transcription part could be token ids or text.
:rtype: tuple of (2darray, list) :rtype: tuple of (2darray, list)
""" """
start_time = time.time()
if isinstance(audio_file, str) and audio_file.startswith('tar:'): if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file( speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), transcript) self._subfile_from_tar(audio_file), transcript)
else: else:
speech_segment = SpeechSegment.from_file(audio_file, transcript) speech_segment = SpeechSegment.from_file(audio_file, transcript)
load_wav_time = time.time() - start_time
#logger.debug(f"load wav time: {load_wav_time}")
# audio augment # audio augment
start_time = time.time()
self._augmentation_pipeline.transform_audio(speech_segment) self._augmentation_pipeline.transform_audio(speech_segment)
audio_aug_time = time.time() - start_time
#logger.debug(f"audio augmentation time: {audio_aug_time}")
start_time = time.time()
specgram, transcript_part = self._speech_featurizer.featurize( specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text) speech_segment, self._keep_transcription_text)
if self._normalizer: if self._normalizer:
specgram = self._normalizer.apply(specgram) specgram = self._normalizer.apply(specgram)
feature_time = time.time() - start_time
#logger.debug(f"audio & test feature time: {feature_time}")
# specgram augment # specgram augment
start_time = time.time()
specgram = self._augmentation_pipeline.transform_feature(specgram) specgram = self._augmentation_pipeline.transform_feature(specgram)
feature_aug_time = time.time() - start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return specgram, transcript_part return specgram, transcript_part
def __call__(self, batch): def __call__(self, batch):
...@@ -283,16 +271,6 @@ class SpeechCollator(): ...@@ -283,16 +271,6 @@ class SpeechCollator():
return utts, padded_audios, audio_lens, padded_texts, text_lens return utts, padded_audios, audio_lens, padded_texts, text_lens
# @property
# def text_feature(self):
# return self._speech_featurizer.text_feature
# @property
# def stride_ms(self):
# return self._speech_featurizer.stride_ms
###########
@property @property
def manifest(self): def manifest(self):
......
...@@ -5,16 +5,13 @@ data: ...@@ -5,16 +5,13 @@ data:
test_manifest: data/manifest.test test_manifest: data/manifest.test
mean_std_filepath: data/mean_std.json mean_std_filepath: data/mean_std.json
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
batch_size: 64 # one gpu
min_input_len: 0.0 min_input_len: 0.0
max_input_len: 27.0 # second max_input_len: 27.0 # second
min_output_len: 0.0 min_output_len: 0.0
max_output_len: .inf max_output_len: .inf
min_output_input_ratio: 0.00 min_output_input_ratio: 0.00
max_output_input_ratio: .inf max_output_input_ratio: .inf
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
collator: collator:
augmentation_config: conf/augmentation.json augmentation_config: conf/augmentation.json
...@@ -32,6 +29,10 @@ collator: ...@@ -32,6 +29,10 @@ collator:
target_dB: -20 target_dB: -20
dither: 1.0 dither: 1.0
keep_transcription_text: False keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
batch_size: 64 # one gpu
model: model:
num_conv_layers: 2 num_conv_layers: 2
......
...@@ -6,16 +6,13 @@ data: ...@@ -6,16 +6,13 @@ data:
mean_std_filepath: data/mean_std.json mean_std_filepath: data/mean_std.json
unit_type: char unit_type: char
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
batch_size: 4
min_input_len: 0.0 min_input_len: 0.0
max_input_len: 27.0 max_input_len: 27.0
min_output_len: 0.0 min_output_len: 0.0
max_output_len: 400.0 max_output_len: 400.0
min_output_input_ratio: 0.05 min_output_input_ratio: 0.05
max_output_input_ratio: 10.0 max_output_input_ratio: 10.0
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
collator: collator:
augmentation_config: conf/augmentation.json augmentation_config: conf/augmentation.json
...@@ -33,6 +30,10 @@ collator: ...@@ -33,6 +30,10 @@ collator:
target_dB: -20 target_dB: -20
dither: 1.0 dither: 1.0
keep_transcription_text: False keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
batch_size: 4
model: model:
num_conv_layers: 2 num_conv_layers: 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册