未验证 提交 c81a3f0f 编写于 作者: H Hui Zhang 提交者: GitHub

[s2t] DataLoader with BatchSampler or DistributeBatchSampler (#1242)

* batchsampler or distributebatchsampler

* format
上级 6d93f3e5
...@@ -292,7 +292,8 @@ class U2STTrainer(Trainer): ...@@ -292,7 +292,8 @@ class U2STTrainer(Trainer):
n_iter_processes=config.collator.num_workers, n_iter_processes=config.collator.num_workers,
subsampling_factor=1, subsampling_factor=1,
load_aux_output=load_transcript, load_aux_output=load_transcript,
num_encs=1) num_encs=1,
dist_sampler=True)
self.valid_loader = BatchDataLoader( self.valid_loader = BatchDataLoader(
json_file=config.data.dev_manifest, json_file=config.data.dev_manifest,
...@@ -313,7 +314,8 @@ class U2STTrainer(Trainer): ...@@ -313,7 +314,8 @@ class U2STTrainer(Trainer):
n_iter_processes=config.collator.num_workers, n_iter_processes=config.collator.num_workers,
subsampling_factor=1, subsampling_factor=1,
load_aux_output=load_transcript, load_aux_output=load_transcript,
num_encs=1) num_encs=1,
dist_sampler=True)
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
# test dataset, return raw text # test dataset, return raw text
...@@ -335,7 +337,8 @@ class U2STTrainer(Trainer): ...@@ -335,7 +337,8 @@ class U2STTrainer(Trainer):
augmentation_config, # aug will be off when train_mode=False augmentation_config, # aug will be off when train_mode=False
n_iter_processes=config.collator.num_workers, n_iter_processes=config.collator.num_workers,
subsampling_factor=1, subsampling_factor=1,
num_encs=1) num_encs=1,
dist_sampler=False)
logger.info("Setup test Dataloader!") logger.info("Setup test Dataloader!")
...@@ -542,7 +545,8 @@ class U2STTester(U2STTrainer): ...@@ -542,7 +545,8 @@ class U2STTester(U2STTrainer):
len_refs += metrics['len_refs'] len_refs += metrics['len_refs']
num_ins += metrics['num_ins'] num_ins += metrics['num_ins']
rtf = num_time / (num_frames * stride_ms) rtf = num_time / (num_frames * stride_ms)
logger.info("RTF: %f, instance (%d), batch BELU = %f" % (rtf, num_ins, bleu)) logger.info("RTF: %f, instance (%d), batch BELU = %f" %
(rtf, num_ins, bleu))
rtf = num_time / (num_frames * stride_ms) rtf = num_time / (num_frames * stride_ms)
msg = "Test: " msg = "Test: "
......
...@@ -65,8 +65,9 @@ class CustomConverter(): ...@@ -65,8 +65,9 @@ class CustomConverter():
# text data (output): (text_len, ) # text data (output): (text_len, )
ys_data.append(ud) ys_data.append(ud)
assert xs_data[0][0] is not None, "please check Reader and Augmentation impl." assert xs_data[0][
0] is not None, "please check Reader and Augmentation impl."
xs_pad, ilens = [], [] xs_pad, ilens = [], []
for xs in xs_data: for xs in xs_data:
# perform subsampling # perform subsampling
...@@ -79,22 +80,26 @@ class CustomConverter(): ...@@ -79,22 +80,26 @@ class CustomConverter():
# perform padding and convert to tensor # perform padding and convert to tensor
# currently only support real number # currently only support real number
xs_pad.append(pad_list(xs, 0).astype(self.dtype)) xs_pad.append(pad_list(xs, 0).astype(self.dtype))
if not self.load_aux_input: if not self.load_aux_input:
xs_pad, ilens = xs_pad[0], ilens[0] xs_pad, ilens = xs_pad[0], ilens[0]
break break
# NOTE: this is for multi-output (e.g., speech translation) # NOTE: this is for multi-output (e.g., speech translation)
ys_pad, olens = [], [] ys_pad, olens = [], []
for ys in ys_data: for ys in ys_data:
ys_pad.append(pad_list( ys_pad.append(
[np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys], pad_list([
self.ignore_id)) np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys
], self.ignore_id))
olens.append(
np.array([
y[0].shape[0] if isinstance(y, tuple) else y.shape[0]
for y in ys
]))
olens.append(np.array(
[y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys]))
if not self.load_aux_output: if not self.load_aux_output:
ys_pad, olens = ys_pad[0], olens[0] ys_pad, olens = ys_pad[0], olens[0]
break break
......
...@@ -18,6 +18,7 @@ from typing import Text ...@@ -18,6 +18,7 @@ from typing import Text
import jsonlines import jsonlines
import numpy as np import numpy as np
from paddle.io import BatchSampler
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
...@@ -76,7 +77,8 @@ class BatchDataLoader(): ...@@ -76,7 +77,8 @@ class BatchDataLoader():
subsampling_factor: int=1, subsampling_factor: int=1,
load_aux_input: bool=False, load_aux_input: bool=False,
load_aux_output: bool=False, load_aux_output: bool=False,
num_encs: int=1): num_encs: int=1,
dist_sampler: bool=False):
self.json_file = json_file self.json_file = json_file
self.train_mode = train_mode self.train_mode = train_mode
self.use_sortagrad = sortagrad == -1 or sortagrad > 0 self.use_sortagrad = sortagrad == -1 or sortagrad > 0
...@@ -94,6 +96,7 @@ class BatchDataLoader(): ...@@ -94,6 +96,7 @@ class BatchDataLoader():
self.n_iter_processes = n_iter_processes self.n_iter_processes = n_iter_processes
self.load_aux_input = load_aux_input self.load_aux_input = load_aux_input
self.load_aux_output = load_aux_output self.load_aux_output = load_aux_output
self.dist_sampler = dist_sampler
# read json data # read json data
with jsonlines.open(json_file, 'r') as reader: with jsonlines.open(json_file, 'r') as reader:
...@@ -145,11 +148,18 @@ class BatchDataLoader(): ...@@ -145,11 +148,18 @@ class BatchDataLoader():
self.dataset = TransformDataset(self.minibaches, self.converter, self.dataset = TransformDataset(self.minibaches, self.converter,
self.reader) self.reader)
self.sampler = DistributedBatchSampler( if self.dist_sampler:
dataset=self.dataset, self.sampler = DistributedBatchSampler(
batch_size=1, dataset=self.dataset,
shuffle=not self.use_sortagrad if self.train_mode else False, batch_size=1,
) shuffle=not self.use_sortagrad if self.train_mode else False,
drop_last=False, )
else:
self.sampler = BatchSampler(
dataset=self.dataset,
batch_size=1,
shuffle=not self.use_sortagrad if self.train_mode else False,
drop_last=False, )
self.dataloader = DataLoader( self.dataloader = DataLoader(
dataset=self.dataset, dataset=self.dataset,
...@@ -181,5 +191,8 @@ class BatchDataLoader(): ...@@ -181,5 +191,8 @@ class BatchDataLoader():
echo += f"subsampling_factor: {self.subsampling_factor}, " echo += f"subsampling_factor: {self.subsampling_factor}, "
echo += f"num_encs: {self.num_encs}, " echo += f"num_encs: {self.num_encs}, "
echo += f"num_workers: {self.n_iter_processes}, " echo += f"num_workers: {self.n_iter_processes}, "
echo += f"load_aux_input: {self.load_aux_input}, "
echo += f"load_aux_output: {self.load_aux_output}, "
echo += f"dist_sampler: {self.dist_sampler}, "
echo += f"file: {self.json_file}" echo += f"file: {self.json_file}"
return echo return echo
...@@ -203,12 +203,15 @@ def evaluate(args): ...@@ -203,12 +203,15 @@ def evaluate(args):
get_tone_ids = True get_tone_ids = True
if args.lang == 'zh': if args.lang == 'zh':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
if get_tone_ids: if get_tone_ids:
tone_ids = input_ids["tone_ids"] tone_ids = input_ids["tone_ids"]
elif args.lang == 'en': elif args.lang == 'en':
input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences) input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
else: else:
print("lang should in {'zh', 'en'}!") print("lang should in {'zh', 'en'}!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册