From e65ba4150ee0082def4fb577be3d7943adda7872 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 6 Mar 2019 14:22:33 +0000 Subject: [PATCH] assert batch_size >= max_seq_len --- BERT/reader/pretraining.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/BERT/reader/pretraining.py b/BERT/reader/pretraining.py index 37026df..35a3720 100644 --- a/BERT/reader/pretraining.py +++ b/BERT/reader/pretraining.py @@ -62,6 +62,9 @@ class DataReader(object): self.mask_id = self.vocab["[MASK]"] self.is_test = is_test self.generate_neg_sample = generate_neg_sample + if self.in_tokens: + assert self.batch_size >= self.max_seq_len, "The number of " \ + "tokens in batch should not be smaller than max seq length." if self.is_test: self.epoch = 1 -- GitLab