提交 698d7a9b 编写于 作者: H Haoxin Ma

move batch_size, work_nums, shuffle_method, sortagrad to collator

上级 89a00eab
......@@ -28,20 +28,6 @@ _C.data = CN(
augmentation_config="",
max_duration=float('inf'),
min_duration=0.0,
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delat_delta=False, # 'mfcc', 'fbank'
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle'
))
_C.model = CN(
......@@ -72,7 +58,11 @@ _C.collator =CN(
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=False
keep_transcription_text=False,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle'
))
DeepSpeech2Model.params(_C.model)
......
......@@ -55,7 +55,7 @@ class DeepSpeech2Trainer(Trainer):
'train_loss': float(loss),
}
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.data.batch_size)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
logger.info(msg)
......@@ -149,31 +149,31 @@ class DeepSpeech2Trainer(Trainer):
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
collate_fn = SpeechCollator.from_config(config)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=config.data.num_workers)
num_workers=config.collator.num_workers)
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.data.batch_size,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn)
......
......@@ -26,7 +26,11 @@ _C.collator =CfgNode(
dict(
augmentation_config="",
unit_type="char",
keep_transcription_text=False
keep_transcription_text=False,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle" # 'batch_shuffle', 'instance_shuffle'
))
_C.model = U2Model.params()
......
......@@ -151,13 +151,3 @@ class SpeechFeaturizer(object):
TextFeaturizer: object.
"""
return self._text_featurizer
# @property
# def text_feature(self):
# """Return the text feature object.
# Returns:
# TextFeaturizer: object.
# """
# return self._text_featurizer
......@@ -203,34 +203,22 @@ class SpeechCollator():
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
start_time = time.time()
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), transcript)
else:
speech_segment = SpeechSegment.from_file(audio_file, transcript)
load_wav_time = time.time() - start_time
#logger.debug(f"load wav time: {load_wav_time}")
# audio augment
start_time = time.time()
self._augmentation_pipeline.transform_audio(speech_segment)
audio_aug_time = time.time() - start_time
#logger.debug(f"audio augmentation time: {audio_aug_time}")
start_time = time.time()
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
if self._normalizer:
specgram = self._normalizer.apply(specgram)
feature_time = time.time() - start_time
#logger.debug(f"audio & test feature time: {feature_time}")
# specgram augment
start_time = time.time()
specgram = self._augmentation_pipeline.transform_feature(specgram)
feature_aug_time = time.time() - start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return specgram, transcript_part
def __call__(self, batch):
......@@ -283,16 +271,6 @@ class SpeechCollator():
return utts, padded_audios, audio_lens, padded_texts, text_lens
# @property
# def text_feature(self):
# return self._speech_featurizer.text_feature
# @property
# def stride_ms(self):
# return self._speech_featurizer.stride_ms
###########
@property
def manifest(self):
......
......@@ -5,16 +5,13 @@ data:
test_manifest: data/manifest.test
mean_std_filepath: data/mean_std.json
vocab_filepath: data/vocab.txt
batch_size: 64 # one gpu
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
collator:
augmentation_config: conf/augmentation.json
......@@ -32,6 +29,10 @@ collator:
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
batch_size: 64 # one gpu
model:
num_conv_layers: 2
......
......@@ -6,16 +6,13 @@ data:
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
batch_size: 4
min_input_len: 0.0
max_input_len: 27.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
collator:
augmentation_config: conf/augmentation.json
......@@ -33,6 +30,10 @@ collator:
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
batch_size: 4
model:
num_conv_layers: 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册