提交 98b15eda 编写于 作者: H Hui Zhang

batch WaveDataset

上级 ecb5d4f8
......@@ -199,11 +199,11 @@ class U2Trainer(Trainer):
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
report('iter', batch_index + 1)
report('total',len(self.train_loader))
report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
report('total', len(self.train_loader))
report('reader_cost', dataload_time)
observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost']
......
......@@ -292,10 +292,6 @@ class SpeechCollator():
olens = np.array(text_lens).astype(np.int64)
return utts, xs_pad, ilens, ys_pad, olens
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
......
......@@ -147,3 +147,131 @@ class TransformDataset(Dataset):
def __getitem__(self, idx):
"""[] operator."""
return self.converter([self.reader(self.data[idx], return_uttid=True)])
class AudioDataset(Dataset):
def __init__(self,
data_file,
max_length=10240,
min_length=0,
token_max_length=200,
token_min_length=1,
batch_type='static',
batch_size=1,
max_frames_in_batch=0,
sort=True,
raw_wav=True,
stride_ms=10):
"""Dataset for loading audio data.
Attributes::
data_file: input data file
Plain text data file, each line contains following 7 fields,
which is split by '\t':
utt:utt1
feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30
feat_shape: 4.95(in seconds) or feat_shape:495,80(495 is in frames)
text:i love you
token: i <space> l o v e <space> y o u
tokenid: int id of this token
token_shape: M,N # M is the number of token, N is vocab size
max_length: drop utterance which is greater than max_length(10ms), unit 10ms.
min_length: drop utterance which is less than min_length(10ms), unit 10ms.
token_max_length: drop utterance which is greater than token_max_length,
especially when use char unit for english modeling
token_min_length: drop utterance which is less than token_max_length
batch_type: static or dynamic, see max_frames_in_batch(dynamic)
batch_size: number of utterances in a batch,
it's for static batch size.
max_frames_in_batch: max feature frames in a batch,
when batch_type is dynamic, it's for dynamic batch size.
Then batch_size is ignored, we will keep filling the
batch until the total frames in batch up to max_frames_in_batch.
sort: whether to sort all data, so the utterance with the same
length could be filled in a same batch.
raw_wav: use raw wave or extracted featute.
if raw wave is used, dynamic waveform-level augmentation could be used
and the feature is extracted by torchaudio.
if extracted featute(e.g. by kaldi) is used, only feature-level
augmentation such as specaug could be used.
"""
assert batch_type in ['static', 'dynamic']
# read manifest
data = read_manifest(data_file)
if sort:
data = sorted(data, key=lambda x: x["feat_shape"][0])
if raw_wav:
assert data[0]['feat'].split(':')[0].splitext()[-1] not in ('.ark',
'.scp')
data = map(lambda x: (float(x['feat_shape'][0]) * 1000 / stride_ms))
self.input_dim = data[0]['feat_shape'][1]
self.output_dim = data[0]['token_shape'][1]
# with open(data_file, 'r') as f:
# for line in f:
# arr = line.strip().split('\t')
# if len(arr) != 7:
# continue
# key = arr[0].split(':')[1]
# tokenid = arr[5].split(':')[1]
# output_dim = int(arr[6].split(':')[1].split(',')[1])
# if raw_wav:
# wav_path = ':'.join(arr[1].split(':')[1:])
# duration = int(float(arr[2].split(':')[1]) * 1000 / 10)
# data.append((key, wav_path, duration, tokenid))
# else:
# feat_ark = ':'.join(arr[1].split(':')[1:])
# feat_info = arr[2].split(':')[1].split(',')
# feat_dim = int(feat_info[1].strip())
# num_frames = int(feat_info[0].strip())
# data.append((key, feat_ark, num_frames, tokenid))
# self.input_dim = feat_dim
# self.output_dim = output_dim
valid_data = []
for i in range(len(data)):
length = data[i]['feat_shape'][0]
token_length = data[i]['token_shape'][0]
# remove too lang or too short utt for both input and output
# to prevent from out of memory
if length > max_length or length < min_length:
# logging.warn('ignore utterance {} feature {}'.format(
# data[i][0], length))
pass
elif token_length > token_max_length or token_length < token_min_length:
pass
else:
valid_data.append(data[i])
data = valid_data
self.minibatch = []
num_data = len(data)
# Dynamic batch size
if batch_type == 'dynamic':
assert (max_frames_in_batch > 0)
self.minibatch.append([])
num_frames_in_batch = 0
for i in range(num_data):
length = data[i]['feat_shape'][0]
num_frames_in_batch += length
if num_frames_in_batch > max_frames_in_batch:
self.minibatch.append([])
num_frames_in_batch = length
self.minibatch[-1].append(data[i])
# Static batch size
else:
cur = 0
while cur < num_data:
end = min(cur + batch_size, num_data)
item = []
for i in range(cur, end):
item.append(data[i])
self.minibatch.append(item)
cur = end
def __len__(self):
return len(self.minibatch)
def __getitem__(self, idx):
instance = self.minibatch[idx]
return instance["utt"], instance["feat"], instance["text"]
......@@ -247,11 +247,11 @@ class Trainer():
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
report('iter', batch_index + 1)
report('total',len(self.train_loader))
report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
report('total', len(self.train_loader))
report('reader_cost', dataload_time)
observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册