Commit 26ff946f authored by guosheng

Update Transformer dataloader, fit, parallel.

Parent 0e47f4c4
@@ -257,23 +257,21 @@ class Seq2SeqDataset(Dataset):
     def load_src_trg_ids(self, fpattern, tar_fname):
         converters = [
-            Converter(
-                vocab=self._src_vocab,
-                beg=self._bos_idx,
-                end=self._eos_idx,
-                unk=self._unk_idx,
-                delimiter=self._token_delimiter,
-                add_beg=False)
+            Converter(vocab=self._src_vocab,
+                      beg=self._bos_idx,
+                      end=self._eos_idx,
+                      unk=self._unk_idx,
+                      delimiter=self._token_delimiter,
+                      add_beg=False)
         ]
         if not self._only_src:
             converters.append(
-                Converter(
-                    vocab=self._trg_vocab,
-                    beg=self._bos_idx,
-                    end=self._eos_idx,
-                    unk=self._unk_idx,
-                    delimiter=self._token_delimiter,
-                    add_beg=True))
+                Converter(vocab=self._trg_vocab,
+                          beg=self._bos_idx,
+                          end=self._eos_idx,
+                          unk=self._unk_idx,
+                          delimiter=self._token_delimiter,
+                          add_beg=True))
         converters = ComposedConverter(converters)
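
Note: for readers outside the repo, Converter turns one delimiter-separated token line into a list of vocabulary ids, and ComposedConverter applies one converter per field of a parallel line; add_beg=True prepends the begin-of-sentence id on the target side only. A minimal sketch of that contract (the class bodies are illustrative assumptions, not this file's code):

    # Illustrative sketch of the Converter / ComposedConverter contract;
    # the bodies are assumptions, not the repo's exact implementation.
    class Converter(object):
        def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
            self._vocab = vocab              # dict: token -> id
            self._beg, self._end, self._unk = beg, end, unk
            self._delimiter = delimiter      # token separator in a field
            self._add_beg = add_beg          # prepend <s> (target side)

        def __call__(self, sentence):
            return ([self._beg] if self._add_beg else []) + [
                self._vocab.get(w, self._unk)
                for w in sentence.split(self._delimiter)
            ] + [self._end]

    class ComposedConverter(object):
        def __init__(self, converters):
            self._converters = converters

        def __call__(self, fields):
            # one converter per parallel field (source, optionally target)
            return [c(f) for c, f in zip(self._converters, fields)]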
@@ -301,8 +299,9 @@ class Seq2SeqDataset(Dataset):
             f = tarfile.open(fpaths[0], "rb")
             for line in f.extractfile(tar_fname):
                 fields = line.strip(b"\n").split(self._field_delimiter)
-                if (not self._only_src and len(fields) == 2) or (
-                        self._only_src and len(fields) == 1):
+                if (not self._only_src
+                        and len(fields) == 2) or (self._only_src
+                                                  and len(fields) == 1):
                     yield fields
         else:
             for fpath in fpaths:
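
Note: the reshaped condition is behavior-preserving; it still keeps only well-formed lines, i.e. exactly two delimiter-separated fields for parallel data or one field in source-only mode. A hypothetical example of the check (line and delimiter invented):

    # Hypothetical line and delimiter, for illustration only.
    field_delimiter = b"\t"
    line = b"ein Haus .\ta house .\n"
    fields = line.strip(b"\n").split(field_delimiter)
    assert len(fields) == 2  # parallel line: (source, target)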
@@ -332,7 +331,8 @@ class Seq2SeqDataset(Dataset):
                 self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx

     def __getitem__(self, idx):
-        return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
+        return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
+                self._trg_seq_ids[idx][1:]
                 ) if not self._only_src else self._src_seq_ids[idx]

     def __len__(self):
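
Note: the new __getitem__ pre-splits the target into the standard teacher-forcing pair: trg[:-1] is the decoder input and trg[1:] is the label sequence, so each sample becomes (src, decoder_input, labels) instead of (src, trg). A hypothetical sample (ids invented for illustration):

    bos, eos = 0, 1
    trg_seq_ids = [bos, 5, 9, 12, eos]  # <s> w1 w2 w3 </s>

    decoder_input = trg_seq_ids[:-1]    # [0, 5, 9, 12]
    labels = trg_seq_ids[1:]            # [5, 9, 12, 1]
    # at step t the decoder consumes decoder_input[t] and is trained
    # to predict labels[t], the next token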
@@ -365,13 +365,14 @@ class Seq2SeqBatchSampler(BatchSampler):
     def __iter__(self):
         # global sort or global shuffle
         if self._sort_type == SortType.GLOBAL:
-            infos = sorted(self.dataset._sample_infos, key=lambda x: x.max_len)
+            infos = sorted(self._dataset._sample_infos,
+                           key=lambda x: x.max_len)
         else:
             if self._shuffle:
-                infos = self.dataset._sample_infos
+                infos = self._dataset._sample_infos
                 self._random.shuffle(infos)
             else:
-                infos = self.dataset._sample_infos
+                infos = self._dataset._sample_infos

             if self._sort_type == SortType.POOL:
                 reverse = True
@@ -385,9 +386,9 @@ class Seq2SeqBatchSampler(BatchSampler):
         batches = []
         batch_creator = TokenBatchCreator(
-            self.
-            _batch_size) if self._use_token_batch else SentenceBatchCreator(
-                self._batch_size * self._nranks)
+            self._batch_size
+        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size *
+                                                             self._nranks)
         batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                      batch_creator)
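
Note: the two creators implement different batch budgets: SentenceBatchCreator caps the number of samples per batch (scaled by _nranks so the global batch can later be split one shard per device), while TokenBatchCreator caps the padded token count. A sketch of the two policies (illustrative; the repo's classes may differ in detail):

    # Illustrative sketch; the in-repo classes may differ in detail.
    class SentenceBatchCreator(object):
        def __init__(self, batch_size):
            self.batch, self._batch_size = [], batch_size  # max samples

        def append(self, info):
            self.batch.append(info)
            if len(self.batch) == self._batch_size:
                result, self.batch = self.batch, []
                return result  # emit a full batch

    class TokenBatchCreator(object):
        def __init__(self, batch_size):
            self.batch, self._max_len = [], 0
            self._batch_size = batch_size  # max padded tokens per batch

        def append(self, info):
            max_len = max(self._max_len, info.max_len)
            # batches are padded to their longest sample, so adding this
            # one would cost max_len * (len(batch) + 1) tokens
            if max_len * (len(self.batch) + 1) > self._batch_size:
                result, self.batch = self.batch, [info]
                self._max_len = info.max_len
                return result
            self._max_len = max_len
            self.batch.append(info)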
@@ -422,8 +423,4 @@ class Seq2SeqBatchSampler(BatchSampler):
             yield batch_indices

     def __len__(self):
-        pass
-
-    @property
-    def dev_id(self):
-        return self._dev_id
+        return 100
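
Note: DataLoader-style consumers call len() on the sampler, so __len__ must return an int rather than pass; the hard-coded 100 is a stand-in, since with token-based batching the batch count is only known after iterating. A sketch of a more faithful length (assumes the attributes used elsewhere in this class):

    # Sketch only; assumes the attributes used elsewhere in this class.
    def __len__(self):
        if not self._use_token_batch:
            global_batch = self._batch_size * self._nranks
            return (len(self._dataset) + global_batch - 1) // global_batch
        # token-budget batches are data-dependent, so fall back to an
        # estimate like the hard-coded value above
        return 100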
@@ -123,6 +123,7 @@ def do_train(args):
         num_workers=0,
         return_list=True)

     transformer = Transformer(
         args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
         args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
......