未验证 提交 8c66b85b 编写于 作者: Xiaoyao Xi 提交者: GitHub

Merge pull request #21 from xixiaoyao/master

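fix bugs: raise an error when a dataset has fewer samples than batch_size, and print data-loading progress on a single line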
......@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import importlib
......@@ -397,8 +399,9 @@ class Controller(object):
# load data
for inst in instances:
print(inst.name+": preparing data...")
print(inst.name+": preparing data...", end='')
inst.reader['train'].load_data()
print('ok!')
# merge dataset iterators and create net input vars
iterators = []
......
......@@ -222,6 +222,8 @@ class BaseReader(object):
def _prepare_batch_data(self, examples, batch_size, phase=None):
"""generate batch records"""
batch_records, max_len = [], 0
if len(examples) < batch_size:
raise Exception('CLS dataset contains too few samples. Expect more than '+str(batch_size))
for index, example in enumerate(examples):
if phase == "train":
self.current_example = index
......@@ -308,7 +310,6 @@ class MaskLMReader(BaseReader):
tokens_a = tokenizer.tokenize(text_a)
tokens_b = None
has_text_b = False
if isinstance(example, dict):
has_text_b = "text_b" in example.keys()
......@@ -379,6 +380,8 @@ class MaskLMReader(BaseReader):
def batch_reader(self, examples, batch_size, in_tokens, phase):
batch = []
total_token_num = 0
if len(examples) < batch_size:
raise Exception('MaskLM dataset contains too few samples. Expect more than '+str(batch_size))
for e in examples:
parsed_line = self._convert_example_to_record(e, self.max_seq_len, self.tokenizer)
to_append = len(batch) < batch_size
......@@ -866,6 +869,9 @@ class MRCReader(BaseReader):
"""generate batch records"""
batch_records, max_len = [], 0
if len(records) < batch_size:
raise Exception('mrc dataset contains too few samples. Expect more than '+str(batch_size))
for index, record in enumerate(records):
if phase == "train":
self.current_example = index
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册