提交 b12e0d90 编写于 作者: Y Yibing Liu

Enable batching not in tokens in pretraining

上级 8a0753a5
...@@ -36,6 +36,7 @@ class DataReader(object): ...@@ -36,6 +36,7 @@ class DataReader(object):
data_dir, data_dir,
vocab_path, vocab_path,
batch_size=4096, batch_size=4096,
in_tokens=True,
max_seq_len=512, max_seq_len=512,
shuffle_files=True, shuffle_files=True,
epoch=100, epoch=100,
...@@ -46,6 +47,7 @@ class DataReader(object): ...@@ -46,6 +47,7 @@ class DataReader(object):
self.vocab = self.load_vocab(vocab_path) self.vocab = self.load_vocab(vocab_path)
self.data_dir = data_dir self.data_dir = data_dir
self.batch_size = batch_size self.batch_size = batch_size
self.in_tokens = in_tokens
self.shuffle_files = shuffle_files self.shuffle_files = shuffle_files
self.epoch = epoch self.epoch = epoch
self.current_epoch = 0 self.current_epoch = 0
...@@ -60,8 +62,6 @@ class DataReader(object): ...@@ -60,8 +62,6 @@ class DataReader(object):
self.mask_id = self.vocab["[MASK]"] self.mask_id = self.vocab["[MASK]"]
self.is_test = is_test self.is_test = is_test
self.generate_neg_sample = generate_neg_sample self.generate_neg_sample = generate_neg_sample
assert self.batch_size > 100, "Current batch size means total token's number, \
it should not be set to too small number."
if self.is_test: if self.is_test:
self.epoch = 1 self.epoch = 1
...@@ -245,12 +245,16 @@ class DataReader(object): ...@@ -245,12 +245,16 @@ class DataReader(object):
continue continue
yield sample yield sample
def batch_reader(reader, batch_size): def batch_reader(reader, batch_size, in_tokens):
batch, total_token_num, max_len = [], 0, 0 batch, total_token_num, max_len = [], 0, 0
for parsed_line in reader(): for parsed_line in reader():
token_ids, sent_ids, pos_ids, label = parsed_line token_ids, sent_ids, pos_ids, label = parsed_line
max_len = max(max_len, len(token_ids)) max_len = max(max_len, len(token_ids))
if (len(batch) + 1) * max_len <= batch_size: if in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(parsed_line) batch.append(parsed_line)
total_token_num += len(token_ids) total_token_num += len(token_ids)
else: else:
...@@ -261,8 +265,8 @@ class DataReader(object): ...@@ -261,8 +265,8 @@ class DataReader(object):
if len(batch) > 0: if len(batch) > 0:
yield batch, total_token_num yield batch, total_token_num
for batch_data, total_token_num in batch_reader(reader, for batch_data, total_token_num in batch_reader(
self.batch_size): reader, self.batch_size, self.in_tokens):
yield prepare_batch_data( yield prepare_batch_data(
batch_data, batch_data,
total_token_num, total_token_num,
......
...@@ -61,14 +61,15 @@ log_g.add_arg("verbose", bool, False, "Whether to output verbose l ...@@ -61,14 +61,15 @@ log_g.add_arg("verbose", bool, False, "Whether to output verbose l
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options") data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_dir", str, "./data/train/", "Path to training data.") data_g.add_arg("data_dir", str, "./data/train/", "Path to training data.")
data_g.add_arg("validation_set_dir", str, "./data/validation/", "Path to training data.") data_g.add_arg("validation_set_dir", str, "./data/validation/", "Path to validation data.")
data_g.add_arg("test_set_dir", str, None, "Path to training data.") data_g.add_arg("test_set_dir", str, None, "Path to test data.")
data_g.add_arg("vocab_path", str, "./config/vocab.txt", "Vocabulary path.") data_g.add_arg("vocab_path", str, "./config/vocab.txt", "Vocabulary path.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.") data_g.add_arg("max_seq_len", int, 512, "Tokens' number of the longest seqence allowed.")
data_g.add_arg("batch_size", int, 16, "Total examples' number in batch for training. see also --in_tokens.") data_g.add_arg("batch_size", int, 8192,
data_g.add_arg("in_tokens", bool, False, "The total number of examples in one batch for training, see also --in_tokens.")
"If set, the batch size will be the maximum number of tokens in one batch. " data_g.add_arg("in_tokens", bool, True,
"Otherwise, it will be the maximum number of examples in one batch.") "If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.") run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.") run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
...@@ -128,6 +129,7 @@ def predict_wrapper(args, ...@@ -128,6 +129,7 @@ def predict_wrapper(args,
data_path, data_path,
vocab_path=args.vocab_path, vocab_path=args.vocab_path,
batch_size=args.batch_size, batch_size=args.batch_size,
in_tokens=args.in_tokens,
voc_size=bert_config['vocab_size'], voc_size=bert_config['vocab_size'],
shuffle_files=False, shuffle_files=False,
epoch=1, epoch=1,
...@@ -250,9 +252,16 @@ def train(args): ...@@ -250,9 +252,16 @@ def train(args):
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
print("Device count %d" % dev_count) print("Device count %d" % dev_count)
print("theoretical memory usage: ") if args.verbose:
print(fluid.contrib.memory_usage( if args.in_tokens:
program=train_program, batch_size=args.batch_size // args.max_seq_len)) lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program,
batch_size=args.batch_size // args.max_seq_len)
else:
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size)
print("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit))
nccl2_num_trainers = 1 nccl2_num_trainers = 1
nccl2_trainer_id = 0 nccl2_trainer_id = 0
...@@ -293,6 +302,7 @@ def train(args): ...@@ -293,6 +302,7 @@ def train(args):
data_reader = DataReader( data_reader = DataReader(
data_dir=args.data_dir, data_dir=args.data_dir,
batch_size=args.batch_size, batch_size=args.batch_size,
in_tokens=args.in_tokens,
vocab_path=args.vocab_path, vocab_path=args.vocab_path,
voc_size=bert_config['vocab_size'], voc_size=bert_config['vocab_size'],
epoch=args.epoch, epoch=args.epoch,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册