diff --git a/BERT/reader/cls.py b/BERT/reader/cls.py index cebd583549eb8d138984e2a0d92b05e7dfbda8c5..2c75479dd76467c76674f34e93044752bf4aeebf 100644 --- a/BERT/reader/cls.py +++ b/BERT/reader/cls.py @@ -196,10 +196,11 @@ class DataProcessor(object): return_num_token=False) if len(all_dev_batches) < dev_count: all_dev_batches.append(batch_data) - else: + + if len(all_dev_batches) == dev_count: for batch in all_dev_batches: yield batch - all_dev_batches = [batch_data] + all_dev_batches = [] return wrapper diff --git a/BERT/reader/squad.py b/BERT/reader/squad.py index 4c69763fdb7ec3cc72fb1d13ca89956965ee3c68..90c3496cca1993c4aa0680e214ba49736b3f6da4 100644 --- a/BERT/reader/squad.py +++ b/BERT/reader/squad.py @@ -566,10 +566,11 @@ class DataProcessor(object): return_num_token=False) if len(all_dev_batches) < dev_count: all_dev_batches.append(batch_data) - else: + + if len(all_dev_batches) == dev_count: for batch in all_dev_batches: yield batch - all_dev_batches = [batch_data] + all_dev_batches = [] return wrapper