提交 e65ba415 编写于 作者: Y Yibing Liu

assert batch_size >= max_seq_len

上级 b12e0d90
......@@ -62,6 +62,9 @@ class DataReader(object):
self.mask_id = self.vocab["[MASK]"]
self.is_test = is_test
self.generate_neg_sample = generate_neg_sample
if self.in_tokens:
assert self.batch_size >= self.max_seq_len, "The number of " \
"tokens in batch should not be smaller than max seq length."
if self.is_test:
self.epoch = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册