diff --git a/BERT/reader/pretraining.py b/BERT/reader/pretraining.py
index 37026df04f43825b1dae6a5066b936e207a2c099..35a3720eaf8d383e19584ed114322685c517ce3a 100644
--- a/BERT/reader/pretraining.py
+++ b/BERT/reader/pretraining.py
@@ -62,6 +62,13 @@ class DataReader(object):
         self.mask_id = self.vocab["[MASK]"]
         self.is_test = is_test
         self.generate_neg_sample = generate_neg_sample
+        # When in_tokens is set, batch_size is checked as a token budget and
+        # must be at least max_seq_len.  Raise explicitly rather than assert:
+        # assert statements are stripped when Python runs with -O.
+        if self.in_tokens and self.batch_size < self.max_seq_len:
+            raise ValueError(
+                "The number of tokens in batch should not be smaller "
+                "than max seq length.")
 
         if self.is_test:
             self.epoch = 1