Unverified commit b460dfd0, authored by liu zhengxi, committed by GitHub

Support pad sequence and batch size can be multiplied by a factor (#5175)

* support pad seq and bsz multi

* alter vocab size

* alter md5sum

* root for reader
Parent fea4ea64
# The frequency to save trained models when training.
save_step: 10000
# The frequency to fetch and print output when training.
print_step: 100
# Path of the checkpoint, to resume the previous training
init_from_checkpoint: ""
# Path of the pretrain model, to better solve the current task
init_from_pretrain_model: ""
# Path of trained parameter, to make prediction
init_from_params: "./trained_models/step_final/"
# The directory for saving model
save_model: "trained_models"
# Set seed for CE or debug
random_seed: None
# The file to output the translation results of predict_file to.
output_file: "predict.txt"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# The directory to store data.
root: None
# Whether to use cuda
use_gpu: True
# Args for reader, see reader.py for details
pool_size: 200000
sort_type: "global"
batch_size: 4096
infer_batch_size: 64
shuffle_batch: True
# Data shuffle only works when sort_type is pool or none
shuffle: True
# shuffle_seed must be set when shuffle is True and multiple cards are used for training.
# Otherwise, the number of batches cannot be guaranteed.
shuffle_seed: 128
# Hyperparams for training:
# The number of epochs for training
epoch: 30
# The hyper parameters for Adam optimizer.
# This static learning_rate will be applied to the LearningRateScheduler
# derived learning rate to get the final learning rate.
learning_rate: 2.0
beta1: 0.9
beta2: 0.997
eps: 1e-9
# The parameters for learning rate scheduling.
warmup_steps: 8000
# The weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps: 0.1
# Hyperparams for generation:
# The parameters for beam search.
beam_size: 5
max_out_len: 256
# The number of decoded sentences to output.
n_best: 1
# Hyperparams for model:
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# Size of source word dictionary.
src_vocab_size: 10000
# Size of target word dictionary.
trg_vocab_size: 10000
# Used to pad vocab size to be multiple of pad_factor.
pad_factor: 8
# Used to pad sequence length to be multiple of pad_seq.
pad_seq: 8
# Used to make batch size to be multiple of bsz_multi.
bsz_multi: 8
# Index for <bos> token
bos_idx: 0
# Index for <eos> token
eos_idx: 1
# Index for <unk> token
unk_idx: 2
# Max length of sequences deciding the size of position encoding table.
max_length: 256
# The dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model: 512
# Size of the hidden layer in position-wise feed-forward networks.
d_inner_hid: 2048
# Number of heads used in multi-head attention.
n_head: 8
# Number of sub-layers to be stacked in the encoder and decoder.
n_layer: 6
# Dropout rates.
dropout: 0.1
# The flag indicating whether to share embedding and softmax weights.
# Vocabularies in source and target should be the same for weight sharing.
weight_sharing: True
max_iter: None
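The pad_factor, pad_seq and bsz_multi options above share one arithmetic idea: the vocabulary size and the padded sequence length are rounded up to a multiple of the given factor, and the number of samples per batch is kept a multiple of bsz_multi. Below is a minimal sketch of the two round-up rules, using illustrative helpers (round_up, padded_seq_len) that are not part of the repo; round_up mirrors the padding_vocab lambda in reader.py, and padded_seq_len mirrors the pad_seq arithmetic in prepare_train_input, which adds pad_seq rather than pad_seq - 1 so there is always room for the appended <s>/<e> token.

```python
# Illustrative helpers only, not repo code.
def round_up(x, factor):
    """Smallest multiple of `factor` that is >= x (pad_factor rule for vocab size)."""
    return (x + factor - 1) // factor * factor


def padded_seq_len(max_tokens, pad_seq):
    """Padded width of a batch whose longest sentence has `max_tokens` tokens (pad_seq rule)."""
    return (max_tokens + pad_seq) // pad_seq * pad_seq


assert round_up(33712, 8) == 33712   # vocab size already a multiple of pad_factor
assert round_up(33710, 8) == 33712   # otherwise it is padded up
assert padded_seq_len(37, 8) == 40   # 37 tokens plus <e> still fit in a width of 40
assert padded_seq_len(40, 8) == 48   # an exact multiple is bumped to leave room for <e>
```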
@@ -68,6 +68,10 @@ src_vocab_size: 10000
 trg_vocab_size: 10000
 # Used to pad vocab size to be multiple of pad_factor.
 pad_factor: 8
+# Used to pad sequence length to be multiple of pad_seq.
+pad_seq: 8
+# Used to make batch size to be multiple of bsz_multi.
+bsz_multi: 8
 # Index for <bos> token
 bos_idx: 0
 # Index for <eos> token
......
@@ -33,8 +33,15 @@ def min_max_filer(data, max_len, min_len=0):
     return (data_min_len >= min_len) and (data_max_len <= max_len)

-def create_data_loader(args, places=None):
+def create_data_loader(args, places=None, use_all_vocab=False):
     root = None if args.root == "None" else args.root
+    if not use_all_vocab:
+        WMT14ende.VOCAB_INFO = (os.path.join(
+            "WMT14.en-de", "wmt14_ende_data_bpe",
+            "vocab_all.bpe.33712"), os.path.join(
+                "WMT14.en-de", "wmt14_ende_data_bpe", "vocab_all.bpe.33712"),
+            "de485e3c2e17e23acf4b4b70b54682dd",
+            "de485e3c2e17e23acf4b4b70b54682dd")
     (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
     padding_vocab = (
         lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
@@ -44,7 +51,8 @@ def create_data_loader(args, places=None):
     transform_func = WMT14ende.get_default_transform_func(root=root)
     datasets = [
         WMT14ende.get_datasets(
-            mode=m, transform_func=transform_func) for m in ["train", "dev"]
+            mode=m, root=root, transform_func=transform_func)
+        for m in ["train", "dev"]
     ]
     data_loaders = [(None)] * 2
@@ -63,7 +71,9 @@ def create_data_loader(args, places=None):
             max_length=args.max_length,
             distribute_mode=True if i == 0 else False,
             world_size=dist.get_world_size(),
-            rank=dist.get_rank())
+            rank=dist.get_rank(),
+            pad_seq=args.pad_seq,
+            bsz_multi=args.bsz_multi)
         data_loader = DataLoader(
             dataset=dataset,
@@ -73,14 +83,22 @@ def create_data_loader(args, places=None):
                 prepare_train_input,
                 bos_idx=args.bos_idx,
                 eos_idx=args.eos_idx,
-                pad_idx=args.bos_idx),
+                pad_idx=args.bos_idx,
+                pad_seq=args.pad_seq),
             num_workers=0)
         data_loaders[i] = (data_loader)
     return data_loaders

-def create_infer_loader(args):
+def create_infer_loader(args, use_all_vocab=False):
     root = None if args.root == "None" else args.root
+    if not use_all_vocab:
+        WMT14ende.VOCAB_INFO = (os.path.join(
+            "WMT14.en-de", "wmt14_ende_data_bpe",
+            "vocab_all.bpe.33712"), os.path.join(
+                "WMT14.en-de", "wmt14_ende_data_bpe", "vocab_all.bpe.33712"),
+            "de485e3c2e17e23acf4b4b70b54682dd",
+            "de485e3c2e17e23acf4b4b70b54682dd")
     (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
     padding_vocab = (
         lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
@@ -89,7 +107,7 @@ def create_infer_loader(args):
     args.trg_vocab_size = padding_vocab(len(trg_vocab))
     transform_func = WMT14ende.get_default_transform_func(root=root)
     dataset = WMT14ende.get_datasets(
-        mode="test", transform_func=transform_func).filter(
+        mode="test", root=root, transform_func=transform_func).filter(
             partial(
                 min_max_filer, max_len=args.max_length))
@@ -103,33 +121,51 @@ def create_infer_loader(args):
             prepare_infer_input,
             bos_idx=args.bos_idx,
             eos_idx=args.eos_idx,
-            pad_idx=args.bos_idx),
+            pad_idx=args.bos_idx,
+            pad_seq=args.pad_seq),
         num_workers=0,
         return_list=True)
     return data_loader, trg_vocab.to_tokens

-def prepare_train_input(insts, bos_idx, eos_idx, pad_idx):
+def prepare_train_input(insts, bos_idx, eos_idx, pad_idx, pad_seq=1):
     """
     Put all padded data needed by training into a list.
     """
     word_pad = Pad(pad_idx)
-    src_word = word_pad([inst[0] + [eos_idx] for inst in insts])
-    trg_word = word_pad([[bos_idx] + inst[1] for inst in insts])
+    src_max_len = (
+        max([len(inst[0]) for inst in insts]) + pad_seq) // pad_seq * pad_seq
+    trg_max_len = (
+        max([len(inst[1]) for inst in insts]) + pad_seq) // pad_seq * pad_seq
+    src_word = word_pad([
+        inst[0] + [eos_idx] + [pad_idx] * (src_max_len - 1 - len(inst[0]))
+        for inst in insts
+    ])
+    trg_word = word_pad([[bos_idx] + inst[1] + [pad_idx] *
+                         (trg_max_len - 1 - len(inst[1])) for inst in insts])
     lbl_word = np.expand_dims(
-        word_pad([inst[1] + [eos_idx] for inst in insts]), axis=2)
+        word_pad([
+            inst[1] + [eos_idx] + [pad_idx] * (trg_max_len - 1 - len(inst[1]))
+            for inst in insts
+        ]),
+        axis=2)
     data_inputs = [src_word, trg_word, lbl_word]
     return data_inputs

-def prepare_infer_input(insts, bos_idx, eos_idx, pad_idx):
+def prepare_infer_input(insts, bos_idx, eos_idx, pad_idx, pad_seq=1):
     """
     Put all padded data needed by beam search decoder into a list.
     """
     word_pad = Pad(pad_idx)
-    src_word = word_pad([inst[0] + [eos_idx] for inst in insts])
+    src_max_len = (
+        max([len(inst[0]) for inst in insts]) + pad_seq) // pad_seq * pad_seq
+    src_word = word_pad([
+        inst[0] + [eos_idx] + [pad_idx] * (src_max_len - 1 - len(inst[0]))
+        for inst in insts
+    ])
     return [src_word, ]
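To make the new pad_seq behaviour concrete, the following stand-alone sketch reproduces the source-side padding of prepare_train_input / prepare_infer_input with plain numpy instead of paddlenlp.data.Pad; the token ids are made up, and bos_idx / eos_idx / pad_idx follow the yaml above (0, 1, and bos_idx respectively).

```python
# Illustrative sketch, not repo code: how pad_seq shapes a source batch.
import numpy as np


def pad_batch_src(insts, eos_idx=1, pad_idx=0, pad_seq=8):
    # Round the longest source length up so (len + 1 for <e>) fits a multiple of pad_seq.
    src_max_len = (max(len(x) for x in insts) + pad_seq) // pad_seq * pad_seq
    return np.array([
        x + [eos_idx] + [pad_idx] * (src_max_len - 1 - len(x)) for x in insts
    ])


batch = [[5, 6, 7], [5, 6, 7, 8, 9]]  # hypothetical token ids without special tokens
src = pad_batch_src(batch, pad_seq=8)
print(src.shape)                      # (2, 8): the padded width is a multiple of pad_seq
```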
@@ -154,18 +190,24 @@ class SentenceBatchCreator(object):

 class TokenBatchCreator(object):
-    def __init__(self, batch_size):
+    def __init__(self, batch_size, bsz_multi=1):
         self._batch = []
         self.max_len = -1
         self._batch_size = batch_size
+        self._bsz_multi = bsz_multi

     def append(self, info):
         cur_len = info.max_len
         max_len = max(self.max_len, cur_len)
         if max_len * (len(self._batch) + 1) > self._batch_size:
-            result = self._batch
-            self._batch = [info]
-            self.max_len = cur_len
+            # Make sure the batch size won't be empty.
+            mode_len = max(
+                len(self._batch) // self._bsz_multi * self._bsz_multi,
+                len(self._batch) % self._bsz_multi)
+            result = self._batch[:mode_len]
+            self._batch = self._batch[mode_len:]
+            self._batch.append(info)
+            self.max_len = max([b.max_len for b in self._batch])
             return result
         else:
             self.max_len = max_len
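The effect of bsz_multi in TokenBatchCreator.append can be seen in isolation with the sketch below; split_batch is a hypothetical helper that mirrors the mode_len computation above. Once the token budget overflows, only a prefix whose size is a multiple of bsz_multi is emitted and the remainder (plus the newly arriving sample) stays queued for the next batch, unless fewer than bsz_multi samples are queued, in which case everything is emitted so the batch is never empty.

```python
# Illustrative sketch of the mode_len trimming rule, not repo code.
def split_batch(batch, bsz_multi):
    mode_len = max(len(batch) // bsz_multi * bsz_multi,
                   len(batch) % bsz_multi)
    return batch[:mode_len], batch[mode_len:]


emitted, carried = split_batch(list(range(19)), bsz_multi=8)
print(len(emitted), len(carried))   # 16 3 -> emitted batch size is a multiple of 8

emitted, carried = split_batch(list(range(5)), bsz_multi=8)
print(len(emitted), len(carried))   # 5 0 -> small batches are still emitted whole
```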
@@ -177,11 +219,12 @@ class TokenBatchCreator(object):

 class SampleInfo(object):
-    def __init__(self, i, lens):
+    def __init__(self, i, lens, pad_seq=1):
         self.i = i
         # Take bos and eos into account
-        self.min_len = min(lens[0] + 1, lens[1] + 1)
-        self.max_len = max(lens[0] + 1, lens[1] + 1)
+        self.min_len = min(lens[0], lens[1]) + 1
+        self.max_len = (max(lens[0], lens[1]) + pad_seq) // pad_seq * pad_seq
+        self.seq_max_len = max(lens[0], lens[1]) + 1
         self.src_len = lens[0] + 1
         self.trg_len = lens[1] + 1
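A short note on the two lengths SampleInfo now tracks, with a hypothetical helper for illustration: max_len is the pad_seq-rounded width that TokenBatchCreator charges against the token budget (the width the padded batch will actually have), while seq_max_len keeps the exact length plus one special token and is what the pool sorting below switches to.

```python
# Illustrative helper, not repo code: the two lengths kept per sample.
def sample_lens(src_len, trg_len, pad_seq=8):
    max_len = (max(src_len, trg_len) + pad_seq) // pad_seq * pad_seq
    seq_max_len = max(src_len, trg_len) + 1
    return max_len, seq_max_len


print(sample_lens(21, 17))   # (24, 22): token budget uses 24, pool sorting uses 22
```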
@@ -201,7 +244,9 @@ class TransformerBatchSampler(BatchSampler):
                  distribute_mode=True,
                  seed=0,
                  world_size=1,
-                 rank=0):
+                 rank=0,
+                 pad_seq=1,
+                 bsz_multi=8):
         for arg, value in locals().items():
             if arg != "self":
                 setattr(self, "_" + arg, value)
@@ -214,7 +259,7 @@ class TransformerBatchSampler(BatchSampler):
         self._sample_infos = []
         for i, data in enumerate(self._dataset):
             lens = [len(data[0]), len(data[1])]
-            self._sample_infos.append(SampleInfo(i, lens))
+            self._sample_infos.append(SampleInfo(i, lens, self._pad_seq))

     def __iter__(self):
         # global sort or global shuffle
@@ -235,13 +280,13 @@ class TransformerBatchSampler(BatchSampler):
                     reverse = not reverse
                 infos[i:i + self._pool_size] = sorted(
                     infos[i:i + self._pool_size],
-                    key=lambda x: x.max_len,
+                    key=lambda x: x.seq_max_len,
                     reverse=reverse)
         batches = []
         batch_creator = TokenBatchCreator(
-            self.
-            _batch_size) if self._use_token_batch else SentenceBatchCreator(
+            self._batch_size,
+            self._bsz_multi) if self._use_token_batch else SentenceBatchCreator(
                 self._batch_size * self._nranks)
         for info in infos:
......
@@ -44,7 +44,8 @@ def create_data_loader(args):
     transform_func = WMT14ende.get_default_transform_func(root=root)
     datasets = [
         WMT14ende.get_datasets(
-            mode=m, transform_func=transform_func) for m in ["train", "dev"]
+            mode=m, root=root, transform_func=transform_func)
+        for m in ["train", "dev"]
     ]
     data_loaders = [(None)] * 2
@@ -89,7 +90,7 @@ def create_infer_loader(args):
     args.trg_vocab_size = padding_vocab(len(trg_vocab))
     transform_func = WMT14ende.get_default_transform_func(root=root)
     dataset = WMT14ende.get_datasets(
-        mode="test", transform_func=transform_func).filter(
+        mode="test", root=root, transform_func=transform_func).filter(
             partial(
                 min_max_filer, max_len=args.max_length))
......
@@ -339,7 +339,7 @@ class WMT14ende(TranslationDataset):
     BOS_TOKEN = "<s>"
     EOS_TOKEN = "<e>"
-    MD5 = "5506d213dba4124121c682368257bae4"
+    MD5 = "a2b8410709ff760a3b40b84bd62dfbd8"

     def __init__(self, mode="train", root=None, transform_func=None):
         if mode not in ("train", "dev", "test", "dev-eval", "test-eval"):
......