提交 51d7ddd8 编写于 作者: X xixiaoyao

fix bugs

上级 c6e33be8
......@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import importlib
......@@ -397,8 +399,9 @@ class Controller(object):
# load data
for inst in instances:
print(inst.name+": preparing data...")
print(inst.name+": preparing data...", end='')
inst.reader['train'].load_data()
print('ok!')
# merge dataset iterators and create net input vars
iterators = []
......
......@@ -222,6 +222,8 @@ class BaseReader(object):
def _prepare_batch_data(self, examples, batch_size, phase=None):
"""generate batch records"""
batch_records, max_len = [], 0
if len(examples) < batch_size:
raise Exception('CLS dataset contains too few samples. Expect more than '+str(batch_size))
for index, example in enumerate(examples):
if phase == "train":
self.current_example = index
......@@ -308,7 +310,6 @@ class MaskLMReader(BaseReader):
tokens_a = tokenizer.tokenize(text_a)
tokens_b = None
has_text_b = False
if isinstance(example, dict):
has_text_b = "text_b" in example.keys()
......@@ -379,6 +380,8 @@ class MaskLMReader(BaseReader):
def batch_reader(self, examples, batch_size, in_tokens, phase):
batch = []
total_token_num = 0
if len(examples) < batch_size:
raise Exception('MaskLM dataset contains too few samples. Expect more than '+str(batch_size))
for e in examples:
parsed_line = self._convert_example_to_record(e, self.max_seq_len, self.tokenizer)
to_append = len(batch) < batch_size
......@@ -866,6 +869,9 @@ class MRCReader(BaseReader):
"""generate batch records"""
batch_records, max_len = [], 0
if len(records) < batch_size:
raise Exception('mrc dataset contains too few samples. Expect more than '+str(batch_size))
for index, record in enumerate(records):
if phase == "train":
self.current_example = index
......
......@@ -118,10 +118,12 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
results = _zero_batch(joint_shape_and_dtypes)
outbuf = {}
for id in task_ids:
print(id)
outputs = next(iterators[id]) # dict type
outbuf[id] = outputs
prefix = iterator_prefixes[id]
for outname, val in outputs.items():
print(outname)
task_outname = prefix + '/' + outname
if outname in outname_to_pos:
......@@ -133,6 +135,7 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
idx = outname_to_pos[task_outname]
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
results[idx] = val
print('ok')
fake_batch = results
dev_count_bak = dev_count
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册