提交 ee442428 编写于 作者: G guosheng

Fix distribute BatchSampler

上级 6431daed
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,8 +15,9 @@
import glob
import six
import os
import tarfile
import io
import itertools
from functools import partial
import numpy as np
import paddle.fluid as fluid
......@@ -24,16 +25,67 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
def create_data_loader(args, device):
data_loaders = [None, None]
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = Seq2SeqDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
byte_data=True)
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length,
distribute_mode=True
if i == 0 else False) # every device eval all data
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
collate_fn=partial(
prepare_train_input,
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0, # TODO: use multi-process
return_list=True)
data_loaders[i] = data_loader
return data_loaders
def prepare_train_input(insts, bos_idx, eos_idx, src_pad_idx, trg_pad_idx,
n_head):
"""
Put all padded data needed by training into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
[inst[0] + [eos_idx] for inst in insts],
src_pad_idx,
n_head,
is_target=False)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
[[bos_idx] + inst[1] for inst in insts],
trg_pad_idx,
n_head,
is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len)
......@@ -41,7 +93,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
[inst[1] + [eos_idx] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
......@@ -71,9 +123,7 @@ def prepare_infer_input(insts, src_pad_idx, n_head):
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias
]
data_inputs = [src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias]
return data_inputs
......@@ -142,29 +192,30 @@ class SortType(object):
class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
self._add_beg = add_beg
self._add_end = add_end
def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + [self._end]
] + ([self._end] if self._add_end else [])
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, parallel_sentence):
def __call__(self, fields):
return [
self._converters[i](parallel_sentence[i])
for i in range(len(self._converters))
converter(field)
for field, converter in zip(fields, self._converters)
]
......@@ -201,10 +252,11 @@ class TokenBatchCreator(object):
class SampleInfo(object):
def __init__(self, i, max_len, min_len):
def __init__(self, i, lens):
self.i = i
self.min_len = min_len
self.max_len = max_len
# take bos and eos into account
self.min_len = min(lens[0] + 1, lens[1] + 2)
self.max_len = max(lens[0] + 1, lens[1] + 2)
class MinMaxFilter(object):
......@@ -229,98 +281,109 @@ class Seq2SeqDataset(Dataset):
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
tar_fname=None,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
only_src=False):
# convert str to bytes, and use byte data
field_delimiter = field_delimiter.encode("utf8")
token_delimiter = token_delimiter.encode("utf8")
start_mark = start_mark.encode("utf8")
end_mark = end_mark.encode("utf8")
unk_mark = unk_mark.encode("utf8")
self._src_vocab = self.load_dict(src_vocab_fpath)
self._trg_vocab = self.load_dict(trg_vocab_fpath)
only_src=False,
trg_fpattern=None,
byte_data=False):
if byte_data:
# The WMT16 bpe data used here seems including bytes can not be
# decoded by utf8. Thus convert str to bytes, and use byte data
field_delimiter = field_delimiter.encode("utf8")
token_delimiter = token_delimiter.encode("utf8")
start_mark = start_mark.encode("utf8")
end_mark = end_mark.encode("utf8")
unk_mark = unk_mark.encode("utf8")
self._byte_data = byte_data
self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
self._only_src = only_src
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, tar_fname)
def load_src_trg_ids(self, fpattern, tar_fname):
converters = [
Converter(vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False)
]
if not self._only_src:
converters.append(
Converter(vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=True))
converters = ComposedConverter(converters)
self.load_src_trg_ids(fpattern, trg_fpattern)
def load_src_trg_ids(self, fpattern, trg_fpattern=None):
src_converter = Converter(
vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
trg_converter = Converter(
vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False,
add_end=False)
converters = ComposedConverter([src_converter, trg_converter])
self._src_seq_ids = []
self._trg_seq_ids = None if self._only_src else []
self._trg_seq_ids = []
self._sample_infos = []
for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
src_trg_ids = converters(line)
self._src_seq_ids.append(src_trg_ids[0])
lens = [len(src_trg_ids[0])]
if not self._only_src:
self._trg_seq_ids.append(src_trg_ids[1])
lens.append(len(src_trg_ids[1]))
self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
slots = [self._src_seq_ids, self._trg_seq_ids]
for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
lens = []
for field, slot in zip(converters(line), slots):
slot.append(field)
lens.append(len(field))
self._sample_infos.append(SampleInfo(i, lens))
def _load_lines(self, fpattern, tar_fname):
def _load_lines(self, fpattern, trg_fpattern=None):
fpaths = glob.glob(fpattern)
fpaths = sorted(fpaths) # TODO: Add custum sort
assert len(fpaths) > 0, "no matching file to the provided data path"
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
yield fields
else:
(f_mode, f_encoding,
endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8",
"\n")
if trg_fpattern is None:
for fpath in fpaths:
if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath)
with open(fpath, "rb") as f:
with io.open(fpath, f_mode, encoding=f_encoding) as f:
for line in f:
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
fields = line.strip(endl).split(self._field_delimiter)
yield fields
else:
# separated source and target language data files
# assume we can get aligned data by sort the two language files
# TODO: Need more rigorous check
trg_fpaths = glob.glob(trg_fpattern)
trg_fpaths = sorted(trg_fpaths)
assert len(fpaths) == len(
trg_fpaths
), "the number of source language data files must equal \
with that of source language"
for fpath, trg_fpath in zip(fpaths, trg_fpaths):
with io.open(fpath, f_mode, encoding=f_encoding) as f:
with io.open(
trg_fpath, f_mode, encoding=f_encoding) as trg_f:
for line in zip(f, trg_f):
fields = [field.strip(endl) for field in line]
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
def load_dict(dict_path, reverse=False, byte_data=False):
word_dict = {}
with open(dict_path, "rb") as fdict:
(f_mode, f_encoding,
endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip(b"\n")
word_dict[idx] = line.strip(endl)
else:
word_dict[line.strip(b"\n")] = idx
word_dict[line.strip(endl)] = idx
return word_dict
def get_vocab_summary(self):
......@@ -328,9 +391,8 @@ class Seq2SeqDataset(Dataset):
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
self._trg_seq_ids[idx][1:]
) if not self._only_src else self._src_seq_ids[idx]
return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
) if self._trg_seq_ids else self._src_seq_ids[idx]
def __len__(self):
return len(self._sample_infos)
......@@ -348,6 +410,7 @@ class Seq2SeqBatchSampler(BatchSampler):
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
distribute_mode=True,
seed=0):
for arg, value in locals().items():
if arg != "self":
......@@ -355,6 +418,7 @@ class Seq2SeqBatchSampler(BatchSampler):
self._random = np.random
self._random.seed(seed)
# for multi-devices
self._distribute_mode = distribute_mode
self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank
self._device_id = ParallelEnv().dev_id
......@@ -362,8 +426,8 @@ class Seq2SeqBatchSampler(BatchSampler):
def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._dataset._sample_infos,
key=lambda x: x.max_len)
infos = sorted(
self._dataset._sample_infos, key=lambda x: x.max_len)
else:
if self._shuffle:
infos = self._dataset._sample_infos
......@@ -383,9 +447,9 @@ class Seq2SeqBatchSampler(BatchSampler):
batches = []
batch_creator = TokenBatchCreator(
self._batch_size
) if self._use_token_batch else SentenceBatchCreator(self._batch_size *
self._nranks)
self.
_batch_size) if self._use_token_batch else SentenceBatchCreator(
self._batch_size * self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
......@@ -413,11 +477,15 @@ class Seq2SeqBatchSampler(BatchSampler):
# for multi-device
for batch_id, batch in enumerate(batches):
if batch_id % self._nranks == self._local_rank:
if not self._distribute_mode or (
batch_id % self._nranks == self._local_rank):
batch_indices = [info.i for info in batch]
yield batch_indices
if self._local_rank > len(batches) % self._nranks:
yield batch_indices
if self._distribute_mode and len(batches) % self._nranks != 0:
if self._local_rank >= len(batches) % self._nranks:
# use previous data to pad
yield batch_indices
def __len__(self):
return 100
# TODO(guosheng): fix the uncertain length
return 0
......@@ -17,7 +17,6 @@ import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from functools import partial
import numpy as np
import paddle
......@@ -29,14 +28,18 @@ from utils.check import check_gpu, check_version
from model import Input, set_device
from callbacks import ProgBarLogger
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
from reader import create_data_loader
from transformer import Transformer, CrossEntropyCriterion
class TrainCallback(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
super(TrainCallback, self).__init__(log_freq, verbose)
# TODO: wrap these override function to simplify
def __init__(self, args, verbose=2):
super(TrainCallback, self).__init__(args.print_step, verbose)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
self.loss_normalizer = loss_normalizer
def on_train_begin(self, logs=None):
......@@ -100,42 +103,7 @@ def do_train(args):
]
# def dataloader
data_loaders = [None, None]
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = Seq2SeqDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
collate_fn=partial(
prepare_train_input,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0, # TODO: use multi-process
return_list=True)
data_loaders[i] = data_loader
train_loader, eval_loader = data_loaders
train_loader, eval_loader = create_data_loader(args, device)
# define model
transformer = Transformer(
......@@ -166,12 +134,6 @@ def do_train(args):
if args.init_from_pretrain_model:
transformer.load(args.init_from_pretrain_model, reset_optimizer=True)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
# model train
transformer.fit(train_data=train_loader,
eval_data=eval_loader,
......@@ -180,11 +142,7 @@ def do_train(args):
save_freq=1,
save_dir=args.save_model,
verbose=2,
callbacks=[
TrainCallback(
log_freq=args.print_step,
loss_normalizer=loss_normalizer)
])
callbacks=[TrainCallback(args)])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册