diff --git a/BERT/reader/pretraining.py b/BERT/reader/pretraining.py
index 182e627ab2b1056031d13c29df2ad0cce413020b..35a3720eaf8d383e19584ed114322685c517ce3a 100644
--- a/BERT/reader/pretraining.py
+++ b/BERT/reader/pretraining.py
@@ -36,6 +36,7 @@ class DataReader(object):
                  data_dir,
                  vocab_path,
                  batch_size=4096,
+                 in_tokens=True,
                  max_seq_len=512,
                  shuffle_files=True,
                  epoch=100,
@@ -46,6 +47,7 @@ class DataReader(object):
         self.vocab = self.load_vocab(vocab_path)
         self.data_dir = data_dir
         self.batch_size = batch_size
+        self.in_tokens = in_tokens
         self.shuffle_files = shuffle_files
         self.epoch = epoch
         self.current_epoch = 0
@@ -60,8 +62,9 @@ class DataReader(object):
         self.mask_id = self.vocab["[MASK]"]
         self.is_test = is_test
         self.generate_neg_sample = generate_neg_sample
-        assert self.batch_size > 100, "Current batch size means total token's number, \
-            it should not be set to too small number."
+        if self.in_tokens:
+            assert self.batch_size >= self.max_seq_len, "The number of " \
+                "tokens in batch should not be smaller than max seq length."
 
         if self.is_test:
             self.epoch = 1
@@ -245,12 +248,16 @@ class DataReader(object):
                         continue
                     yield sample
 
-        def batch_reader(reader, batch_size):
+        def batch_reader(reader, batch_size, in_tokens):
             batch, total_token_num, max_len = [], 0, 0
             for parsed_line in reader():
                 token_ids, sent_ids, pos_ids, label = parsed_line
                 max_len = max(max_len, len(token_ids))
-                if (len(batch) + 1) * max_len <= batch_size:
+                if in_tokens:
+                    to_append = (len(batch) + 1) * max_len <= batch_size
+                else:
+                    to_append = len(batch) < batch_size
+                if to_append:
                     batch.append(parsed_line)
                     total_token_num += len(token_ids)
                 else:
@@ -261,8 +268,8 @@ class DataReader(object):
             if len(batch) > 0:
                 yield batch, total_token_num
 
-        for batch_data, total_token_num in batch_reader(reader,
-                                                        self.batch_size):
+        for batch_data, total_token_num in batch_reader(
+                reader, self.batch_size, self.in_tokens):
             yield prepare_batch_data(
                 batch_data,
                 total_token_num,
diff --git a/BERT/train.py b/BERT/train.py
index 51df8705d9adcd18f1a8ae0f9674158e776bd49e..64c751c44735fdc90e9bd284b356cde3467e13f7 100644
--- a/BERT/train.py
+++ b/BERT/train.py
@@ -61,14 +61,15 @@ log_g.add_arg("verbose", bool, False, "Whether to output verbose l
 
 data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
 data_g.add_arg("data_dir",           str,  "./data/train/",      "Path to training data.")
-data_g.add_arg("validation_set_dir", str,  "./data/validation/", "Path to training data.")
-data_g.add_arg("test_set_dir",       str,  None,                 "Path to training data.")
+data_g.add_arg("validation_set_dir", str,  "./data/validation/", "Path to validation data.")
+data_g.add_arg("test_set_dir",       str,  None,                 "Path to test data.")
 data_g.add_arg("vocab_path",         str,  "./config/vocab.txt", "Vocabulary path.")
-data_g.add_arg("max_seq_len",        int,  512,  "Number of words of the longest seqence.")
-data_g.add_arg("batch_size",         int,  16,   "Total examples' number in batch for training. see also --in_tokens.")
-data_g.add_arg("in_tokens",          bool, False,
-               "If set, the batch size will be the maximum number of tokens in one batch. "
-               "Otherwise, it will be the maximum number of examples in one batch.")
+data_g.add_arg("max_seq_len",        int,  512,  "Tokens' number of the longest seqence allowed.")
+data_g.add_arg("batch_size",         int,  8192,
+               "The total number of examples in one batch for training, see also --in_tokens.")
+data_g.add_arg("in_tokens",          bool, True,
+               "If set, the batch size will be the maximum number of tokens in one batch. "
+               "Otherwise, it will be the maximum number of examples in one batch.")
 
 run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
 run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
@@ -128,6 +129,7 @@ def predict_wrapper(args,
         data_path,
         vocab_path=args.vocab_path,
         batch_size=args.batch_size,
+        in_tokens=args.in_tokens,
         voc_size=bert_config['vocab_size'],
         shuffle_files=False,
         epoch=1,
@@ -250,9 +252,16 @@ def train(args):
         dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
 
     print("Device count %d" % dev_count)
-    print("theoretical memory usage: ")
-    print(fluid.contrib.memory_usage(
-        program=train_program, batch_size=args.batch_size // args.max_seq_len))
+    if args.verbose:
+        if args.in_tokens:
+            lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
+                program=train_program,
+                batch_size=args.batch_size // args.max_seq_len)
+        else:
+            lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
+                program=train_program, batch_size=args.batch_size)
+        print("Theoretical memory usage in training: %.3f - %.3f %s" %
+              (lower_mem, upper_mem, unit))
 
     nccl2_num_trainers = 1
     nccl2_trainer_id = 0
@@ -293,6 +302,7 @@ def train(args):
     data_reader = DataReader(
         data_dir=args.data_dir,
         batch_size=args.batch_size,
+        in_tokens=args.in_tokens,
         vocab_path=args.vocab_path,
         voc_size=bert_config['vocab_size'],
         epoch=args.epoch,