Commit ee442428 authored by G guosheng

Fix distribute BatchSampler

Parent 6431daed
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -15,8 +15,9 @@
import glob import glob
import six import six
import os import os
import tarfile import io
import itertools import itertools
from functools import partial
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
@@ -24,16 +25,67 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset from paddle.fluid.io import BatchSampler, DataLoader, Dataset
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): def create_data_loader(args, device):
data_loaders = [None, None]
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = Seq2SeqDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
byte_data=True)
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length,
distribute_mode=True
if i == 0 else False) # every device eval all data
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
collate_fn=partial(
prepare_train_input,
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0, # TODO: use multi-process
return_list=True)
data_loaders[i] = data_loader
return data_loaders
def prepare_train_input(insts, bos_idx, eos_idx, src_pad_idx, trg_pad_idx,
n_head):
""" """
Put all padded data needed by training into a list. Put all padded data needed by training into a list.
""" """
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) [inst[0] + [eos_idx] for inst in insts],
src_pad_idx,
n_head,
is_target=False)
src_word = src_word.reshape(-1, src_max_len) src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len) src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) [[bos_idx] + inst[1] for inst in insts],
trg_pad_idx,
n_head,
is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len) trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len) trg_pos = trg_pos.reshape(-1, trg_max_len)
@@ -41,7 +93,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
[1, 1, trg_max_len, 1]).astype("float32") [1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data( lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts], [inst[1] + [eos_idx] for inst in insts],
trg_pad_idx, trg_pad_idx,
n_head, n_head,
is_target=False, is_target=False,
@@ -71,9 +123,7 @@ def prepare_infer_input(insts, src_pad_idx, n_head):
src_word = src_word.reshape(-1, src_max_len) src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len) src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [ data_inputs = [src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias]
src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias
]
return data_inputs return data_inputs
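With the changes above, the special-token handling happens at batch time rather than in the dataset: the source sequence gets the end mark appended, the decoder input gets the begin mark prepended, and the label gets the end mark appended. A small illustration with made-up indices (bos_idx and eos_idx would come from get_vocab_summary):

bos_idx, eos_idx = 0, 1               # illustrative values only
src_ids = [7, 8, 9]                   # source token ids without special tokens
trg_ids = [4, 5]                      # target token ids without special tokens

encoder_input = src_ids + [eos_idx]   # [7, 8, 9, 1]
decoder_input = [bos_idx] + trg_ids   # [0, 4, 5]
label = trg_ids + [eos_idx]           # [4, 5, 1]
assert len(decoder_input) == len(label)  # shifted by one position, same length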
@@ -142,29 +192,30 @@ class SortType(object):
class Converter(object): class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg): def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
self._vocab = vocab self._vocab = vocab
self._beg = beg self._beg = beg
self._end = end self._end = end
self._unk = unk self._unk = unk
self._delimiter = delimiter self._delimiter = delimiter
self._add_beg = add_beg self._add_beg = add_beg
self._add_end = add_end
def __call__(self, sentence): def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [ return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk) self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter) for w in sentence.split(self._delimiter)
] + [self._end] ] + ([self._end] if self._add_end else [])
class ComposedConverter(object): class ComposedConverter(object):
def __init__(self, converters): def __init__(self, converters):
self._converters = converters self._converters = converters
def __call__(self, parallel_sentence): def __call__(self, fields):
return [ return [
self._converters[i](parallel_sentence[i]) converter(field)
for i in range(len(self._converters)) for field, converter in zip(fields, self._converters)
] ]
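ComposedConverter now pairs each field with its converter via zip instead of indexing into both lists. A tiny sketch with plain functions standing in for the src/trg Converter objects:

converters = [str.upper, str.lower]   # stand-ins for the source/target converters
fields = ["Hello", "World"]
print([convert(field) for field, convert in zip(fields, converters)])  # ['HELLO', 'world']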
@@ -201,10 +252,11 @@ class TokenBatchCreator(object):
class SampleInfo(object): class SampleInfo(object):
def __init__(self, i, max_len, min_len): def __init__(self, i, lens):
self.i = i self.i = i
self.min_len = min_len # take bos and eos into account
self.max_len = max_len self.min_len = min(lens[0] + 1, lens[1] + 2)
self.max_len = max(lens[0] + 1, lens[1] + 2)
class MinMaxFilter(object): class MinMaxFilter(object):
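SampleInfo above now derives its length bounds from the raw source/target lengths while reserving room for the special tokens added later: one end mark on the source, begin and end marks on the target. A worked example, assuming lens is ordered [src_len, trg_len]:

lens = [5, 2]                              # source has 5 tokens, target has 2
min_len = min(lens[0] + 1, lens[1] + 2)    # min(6, 4) -> 4
max_len = max(lens[0] + 1, lens[1] + 2)    # max(6, 4) -> 6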
@@ -229,98 +281,109 @@ class Seq2SeqDataset(Dataset):
src_vocab_fpath, src_vocab_fpath,
trg_vocab_fpath, trg_vocab_fpath,
fpattern, fpattern,
tar_fname=None,
field_delimiter="\t", field_delimiter="\t",
token_delimiter=" ", token_delimiter=" ",
start_mark="<s>", start_mark="<s>",
end_mark="<e>", end_mark="<e>",
unk_mark="<unk>", unk_mark="<unk>",
only_src=False): only_src=False,
# convert str to bytes, and use byte data trg_fpattern=None,
byte_data=False):
if byte_data:
# The WMT16 bpe data used here seems including bytes can not be
# decoded by utf8. Thus convert str to bytes, and use byte data
field_delimiter = field_delimiter.encode("utf8") field_delimiter = field_delimiter.encode("utf8")
token_delimiter = token_delimiter.encode("utf8") token_delimiter = token_delimiter.encode("utf8")
start_mark = start_mark.encode("utf8") start_mark = start_mark.encode("utf8")
end_mark = end_mark.encode("utf8") end_mark = end_mark.encode("utf8")
unk_mark = unk_mark.encode("utf8") unk_mark = unk_mark.encode("utf8")
self._src_vocab = self.load_dict(src_vocab_fpath) self._byte_data = byte_data
self._trg_vocab = self.load_dict(trg_vocab_fpath) self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
self._bos_idx = self._src_vocab[start_mark] self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark] self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark] self._unk_idx = self._src_vocab[unk_mark]
self._only_src = only_src
self._field_delimiter = field_delimiter self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, tar_fname) self.load_src_trg_ids(fpattern, trg_fpattern)
def load_src_trg_ids(self, fpattern, tar_fname): def load_src_trg_ids(self, fpattern, trg_fpattern=None):
converters = [ src_converter = Converter(
Converter(vocab=self._src_vocab, vocab=self._src_vocab,
beg=self._bos_idx, beg=self._bos_idx,
end=self._eos_idx, end=self._eos_idx,
unk=self._unk_idx, unk=self._unk_idx,
delimiter=self._token_delimiter, delimiter=self._token_delimiter,
add_beg=False) add_beg=False,
] add_end=False)
if not self._only_src:
converters.append( trg_converter = Converter(
Converter(vocab=self._trg_vocab, vocab=self._trg_vocab,
beg=self._bos_idx, beg=self._bos_idx,
end=self._eos_idx, end=self._eos_idx,
unk=self._unk_idx, unk=self._unk_idx,
delimiter=self._token_delimiter, delimiter=self._token_delimiter,
add_beg=True)) add_beg=False,
add_end=False)
converters = ComposedConverter(converters) converters = ComposedConverter([src_converter, trg_converter])
self._src_seq_ids = [] self._src_seq_ids = []
self._trg_seq_ids = None if self._only_src else [] self._trg_seq_ids = []
self._sample_infos = [] self._sample_infos = []
for i, line in enumerate(self._load_lines(fpattern, tar_fname)): slots = [self._src_seq_ids, self._trg_seq_ids]
src_trg_ids = converters(line) for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
self._src_seq_ids.append(src_trg_ids[0]) lens = []
lens = [len(src_trg_ids[0])] for field, slot in zip(converters(line), slots):
if not self._only_src: slot.append(field)
self._trg_seq_ids.append(src_trg_ids[1]) lens.append(len(field))
lens.append(len(src_trg_ids[1])) self._sample_infos.append(SampleInfo(i, lens))
self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
def _load_lines(self, fpattern, tar_fname): def _load_lines(self, fpattern, trg_fpattern=None):
fpaths = glob.glob(fpattern) fpaths = glob.glob(fpattern)
fpaths = sorted(fpaths) # TODO: Add custum sort
assert len(fpaths) > 0, "no matching file to the provided data path" assert len(fpaths) > 0, "no matching file to the provided data path"
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]): (f_mode, f_encoding,
if tar_fname is None: endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8",
raise Exception("If tar file provided, please set tar_fname.") "\n")
if trg_fpattern is None:
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
yield fields
else:
for fpath in fpaths: for fpath in fpaths:
if not os.path.isfile(fpath): with io.open(fpath, f_mode, encoding=f_encoding) as f:
raise IOError("Invalid file: %s" % fpath)
with open(fpath, "rb") as f:
for line in f: for line in f:
fields = line.strip(b"\n").split(self._field_delimiter) fields = line.strip(endl).split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or ( yield fields
self._only_src and len(fields) == 1): else:
# separated source and target language data files
# assume we can get aligned data by sort the two language files
# TODO: Need more rigorous check
trg_fpaths = glob.glob(trg_fpattern)
trg_fpaths = sorted(trg_fpaths)
assert len(fpaths) == len(
trg_fpaths
), "the number of source language data files must equal \
with that of source language"
for fpath, trg_fpath in zip(fpaths, trg_fpaths):
with io.open(fpath, f_mode, encoding=f_encoding) as f:
with io.open(
trg_fpath, f_mode, encoding=f_encoding) as trg_f:
for line in zip(f, trg_f):
fields = [field.strip(endl) for field in line]
yield fields yield fields
@staticmethod @staticmethod
def load_dict(dict_path, reverse=False): def load_dict(dict_path, reverse=False, byte_data=False):
word_dict = {} word_dict = {}
with open(dict_path, "rb") as fdict: (f_mode, f_encoding,
endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
for idx, line in enumerate(fdict): for idx, line in enumerate(fdict):
if reverse: if reverse:
word_dict[idx] = line.strip(b"\n") word_dict[idx] = line.strip(endl)
else: else:
word_dict[line.strip(b"\n")] = idx word_dict[line.strip(endl)] = idx
return word_dict return word_dict
def get_vocab_summary(self): def get_vocab_summary(self):
@@ -328,9 +391,8 @@ class Seq2SeqDataset(Dataset):
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx): def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1], return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
self._trg_seq_ids[idx][1:] ) if self._trg_seq_ids else self._src_seq_ids[idx]
) if not self._only_src else self._src_seq_ids[idx]
def __len__(self): def __len__(self):
return len(self._sample_infos) return len(self._sample_infos)
@@ -348,6 +410,7 @@ class Seq2SeqBatchSampler(BatchSampler):
shuffle_batch=False, shuffle_batch=False,
use_token_batch=False, use_token_batch=False,
clip_last_batch=False, clip_last_batch=False,
distribute_mode=True,
seed=0): seed=0):
for arg, value in locals().items(): for arg, value in locals().items():
if arg != "self": if arg != "self":
@@ -355,6 +418,7 @@ class Seq2SeqBatchSampler(BatchSampler):
self._random = np.random self._random = np.random
self._random.seed(seed) self._random.seed(seed)
# for multi-devices # for multi-devices
self._distribute_mode = distribute_mode
self._nranks = ParallelEnv().nranks self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank self._local_rank = ParallelEnv().local_rank
self._device_id = ParallelEnv().dev_id self._device_id = ParallelEnv().dev_id
@@ -362,8 +426,8 @@ class Seq2SeqBatchSampler(BatchSampler):
def __iter__(self): def __iter__(self):
# global sort or global shuffle # global sort or global shuffle
if self._sort_type == SortType.GLOBAL: if self._sort_type == SortType.GLOBAL:
infos = sorted(self._dataset._sample_infos, infos = sorted(
key=lambda x: x.max_len) self._dataset._sample_infos, key=lambda x: x.max_len)
else: else:
if self._shuffle: if self._shuffle:
infos = self._dataset._sample_infos infos = self._dataset._sample_infos
@@ -383,9 +447,9 @@ class Seq2SeqBatchSampler(BatchSampler):
batches = [] batches = []
batch_creator = TokenBatchCreator( batch_creator = TokenBatchCreator(
self._batch_size self.
) if self._use_token_batch else SentenceBatchCreator(self._batch_size * _batch_size) if self._use_token_batch else SentenceBatchCreator(
self._nranks) self._batch_size * self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length, batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator) batch_creator)
@@ -413,11 +477,15 @@ class Seq2SeqBatchSampler(BatchSampler):
# for multi-device # for multi-device
for batch_id, batch in enumerate(batches): for batch_id, batch in enumerate(batches):
if batch_id % self._nranks == self._local_rank: if not self._distribute_mode or (
batch_id % self._nranks == self._local_rank):
batch_indices = [info.i for info in batch] batch_indices = [info.i for info in batch]
yield batch_indices yield batch_indices
if self._local_rank > len(batches) % self._nranks: if self._distribute_mode and len(batches) % self._nranks != 0:
if self._local_rank >= len(batches) % self._nranks:
# use previous data to pad
yield batch_indices yield batch_indices
def __len__(self): def __len__(self):
return 100 # TODO(guosheng): fix the uncertain length
return 0
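The distribute_mode fix above hands whole batches to ranks round-robin and, when the final round is incomplete, pads the ranks that missed out by repeating their previous batch, so every device runs the same number of steps. A standalone sketch of that behaviour; the function name and toy batches are illustrative, not part of the commit:

def shard_batches(batches, nranks, local_rank, distribute_mode=True):
    # round-robin assignment of whole batches to ranks
    last = None
    for batch_id, batch in enumerate(batches):
        if not distribute_mode or batch_id % nranks == local_rank:
            last = batch
            yield batch
    # pad ranks that received nothing in the last, incomplete round
    if distribute_mode and len(batches) % nranks != 0:
        if local_rank >= len(batches) % nranks:
            yield last

batches = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]   # 5 batches, 2 ranks
print(list(shard_batches(batches, 2, 0)))  # [[0, 1], [4, 5], [8, 9]]
print(list(shard_batches(batches, 2, 1)))  # [[2, 3], [6, 7], [6, 7]] -- padded with previous data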
@@ -17,7 +17,6 @@ import os
import six import six
import sys import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from functools import partial
import numpy as np import numpy as np
import paddle import paddle
@@ -29,14 +28,18 @@ from utils.check import check_gpu, check_version
from model import Input, set_device from model import Input, set_device
from callbacks import ProgBarLogger from callbacks import ProgBarLogger
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler from reader import create_data_loader
from transformer import Transformer, CrossEntropyCriterion from transformer import Transformer, CrossEntropyCriterion
class TrainCallback(ProgBarLogger): class TrainCallback(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.): def __init__(self, args, verbose=2):
super(TrainCallback, self).__init__(log_freq, verbose) super(TrainCallback, self).__init__(args.print_step, verbose)
# TODO: wrap these override function to simplify # the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
self.loss_normalizer = loss_normalizer self.loss_normalizer = loss_normalizer
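The loss_normalizer computed above is the entropy of the label-smoothed target distribution, i.e. the lowest cross-entropy the model could reach, so the logged loss can be read relative to that floor. A small sketch of the same computation with illustrative values:

import numpy as np

def best_cross_entropy(eps, vocab_size):
    # entropy of a distribution putting 1 - eps on the gold token and
    # eps / (vocab_size - 1) on every other token
    return -((1. - eps) * np.log(1. - eps) +
             eps * np.log(eps / (vocab_size - 1) + 1e-20))

print(best_cross_entropy(0.1, 10000))  # about 1.25 for these illustrative values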
def on_train_begin(self, logs=None): def on_train_begin(self, logs=None):
@@ -100,42 +103,7 @@ def do_train(args):
] ]
# def dataloader # def dataloader
data_loaders = [None, None] train_loader, eval_loader = create_data_loader(args, device)
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = Seq2SeqDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
collate_fn=partial(
prepare_train_input,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0, # TODO: use multi-process
return_list=True)
data_loaders[i] = data_loader
train_loader, eval_loader = data_loaders
# define model # define model
transformer = Transformer( transformer = Transformer(
@@ -166,12 +134,6 @@ def do_train(args):
if args.init_from_pretrain_model: if args.init_from_pretrain_model:
transformer.load(args.init_from_pretrain_model, reset_optimizer=True) transformer.load(args.init_from_pretrain_model, reset_optimizer=True)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
# model train # model train
transformer.fit(train_data=train_loader, transformer.fit(train_data=train_loader,
eval_data=eval_loader, eval_data=eval_loader,
@@ -180,11 +142,7 @@ def do_train(args):
save_freq=1, save_freq=1,
save_dir=args.save_model, save_dir=args.save_model,
verbose=2, verbose=2,
callbacks=[ callbacks=[TrainCallback(args)])
TrainCallback(
log_freq=args.print_step,
loss_normalizer=loss_normalizer)
])
if __name__ == "__main__": if __name__ == "__main__":
......