提交 26ff946f 编写于 作者: G guosheng

Update Transformer dataloader, fit, parallel.

上级 0e47f4c4
......@@ -257,23 +257,21 @@ class Seq2SeqDataset(Dataset):
def load_src_trg_ids(self, fpattern, tar_fname):
converters = [
Converter(
vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False)
Converter(vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False)
]
if not self._only_src:
converters.append(
Converter(
vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=True))
Converter(vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=True))
converters = ComposedConverter(converters)
......@@ -301,8 +299,9 @@ class Seq2SeqDataset(Dataset):
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
yield fields
else:
for fpath in fpaths:
......@@ -332,7 +331,8 @@ class Seq2SeqDataset(Dataset):
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
self._trg_seq_ids[idx][1:]
) if not self._only_src else self._src_seq_ids[idx]
def __len__(self):
......@@ -365,13 +365,14 @@ class Seq2SeqBatchSampler(BatchSampler):
def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self.dataset._sample_infos, key=lambda x: x.max_len)
infos = sorted(self._dataset._sample_infos,
key=lambda x: x.max_len)
else:
if self._shuffle:
infos = self.dataset._sample_infos
infos = self._dataset._sample_infos
self._random.shuffle(infos)
else:
infos = self.dataset._sample_infos
infos = self._dataset._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
......@@ -385,9 +386,9 @@ class Seq2SeqBatchSampler(BatchSampler):
batches = []
batch_creator = TokenBatchCreator(
self.
_batch_size) if self._use_token_batch else SentenceBatchCreator(
self._batch_size * self._nranks)
self._batch_size
) if self._use_token_batch else SentenceBatchCreator(self._batch_size *
self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
......@@ -422,8 +423,4 @@ class Seq2SeqBatchSampler(BatchSampler):
yield batch_indices
def __len__(self):
pass
@property
def dev_id(self):
return self._dev_id
return 100
......@@ -123,6 +123,7 @@ def do_train(args):
num_workers=0,
return_list=True)
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册