提交 b12e0d90 编写于 作者: Y Yibing Liu

Enable batching not in tokens in pretraining

上级 8a0753a5
......@@ -36,6 +36,7 @@ class DataReader(object):
data_dir,
vocab_path,
batch_size=4096,
in_tokens=True,
max_seq_len=512,
shuffle_files=True,
epoch=100,
......@@ -46,6 +47,7 @@ class DataReader(object):
self.vocab = self.load_vocab(vocab_path)
self.data_dir = data_dir
self.batch_size = batch_size
self.in_tokens = in_tokens
self.shuffle_files = shuffle_files
self.epoch = epoch
self.current_epoch = 0
......@@ -60,8 +62,6 @@ class DataReader(object):
self.mask_id = self.vocab["[MASK]"]
self.is_test = is_test
self.generate_neg_sample = generate_neg_sample
assert self.batch_size > 100, "Current batch size means total token's number, \
it should not be set to too small number."
if self.is_test:
self.epoch = 1
......@@ -245,12 +245,16 @@ class DataReader(object):
continue
yield sample
def batch_reader(reader, batch_size):
def batch_reader(reader, batch_size, in_tokens):
batch, total_token_num, max_len = [], 0, 0
for parsed_line in reader():
token_ids, sent_ids, pos_ids, label = parsed_line
max_len = max(max_len, len(token_ids))
if (len(batch) + 1) * max_len <= batch_size:
if in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(parsed_line)
total_token_num += len(token_ids)
else:
......@@ -261,8 +265,8 @@ class DataReader(object):
if len(batch) > 0:
yield batch, total_token_num
for batch_data, total_token_num in batch_reader(reader,
self.batch_size):
for batch_data, total_token_num in batch_reader(
reader, self.batch_size, self.in_tokens):
yield prepare_batch_data(
batch_data,
total_token_num,
......
......@@ -61,12 +61,13 @@ log_g.add_arg("verbose", bool, False, "Whether to output verbose l
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_dir", str, "./data/train/", "Path to training data.")
data_g.add_arg("validation_set_dir", str, "./data/validation/", "Path to training data.")
data_g.add_arg("test_set_dir", str, None, "Path to training data.")
data_g.add_arg("validation_set_dir", str, "./data/validation/", "Path to validation data.")
data_g.add_arg("test_set_dir", str, None, "Path to test data.")
data_g.add_arg("vocab_path", str, "./config/vocab.txt", "Vocabulary path.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
data_g.add_arg("batch_size", int, 16, "Total examples' number in batch for training. see also --in_tokens.")
data_g.add_arg("in_tokens", bool, False,
data_g.add_arg("max_seq_len", int, 512, "Tokens' number of the longest seqence allowed.")
data_g.add_arg("batch_size", int, 8192,
"The total number of examples in one batch for training, see also --in_tokens.")
data_g.add_arg("in_tokens", bool, True,
"If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.")
......@@ -128,6 +129,7 @@ def predict_wrapper(args,
data_path,
vocab_path=args.vocab_path,
batch_size=args.batch_size,
in_tokens=args.in_tokens,
voc_size=bert_config['vocab_size'],
shuffle_files=False,
epoch=1,
......@@ -250,9 +252,16 @@ def train(args):
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
print("Device count %d" % dev_count)
print("theoretical memory usage: ")
print(fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size // args.max_seq_len))
if args.verbose:
if args.in_tokens:
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program,
batch_size=args.batch_size // args.max_seq_len)
else:
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size)
print("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit))
nccl2_num_trainers = 1
nccl2_trainer_id = 0
......@@ -293,6 +302,7 @@ def train(args):
data_reader = DataReader(
data_dir=args.data_dir,
batch_size=args.batch_size,
in_tokens=args.in_tokens,
vocab_path=args.vocab_path,
voc_size=bert_config['vocab_size'],
epoch=args.epoch,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册