From 51d7ddd810802eaa93e9e235e0a5edcf35b5f39b Mon Sep 17 00:00:00 2001
From: xixiaoyao
Date: Thu, 7 Nov 2019 13:19:57 +0800
Subject: [PATCH] fix bugs

---
 paddlepalm/mtl_controller.py            | 5 ++++-
 paddlepalm/reader/utils/reader4ernie.py | 8 +++++++-
 paddlepalm/utils/reader_helper.py       | 3 +++
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py
index 1f9b48b..aa94c34 100755
--- a/paddlepalm/mtl_controller.py
+++ b/paddlepalm/mtl_controller.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import print_function
+
 import os
 import sys
 import importlib
@@ -397,8 +399,9 @@ class Controller(object):
 
         # load data
         for inst in instances:
-            print(inst.name+": preparing data...")
+            print(inst.name+": preparing data...", end='')
             inst.reader['train'].load_data()
+            print('ok!')
 
         # merge dataset iterators and create net input vars
         iterators = []
diff --git a/paddlepalm/reader/utils/reader4ernie.py b/paddlepalm/reader/utils/reader4ernie.py
index fcf25e7..ea844db 100644
--- a/paddlepalm/reader/utils/reader4ernie.py
+++ b/paddlepalm/reader/utils/reader4ernie.py
@@ -222,6 +222,8 @@ class BaseReader(object):
     def _prepare_batch_data(self, examples, batch_size, phase=None):
         """generate batch records"""
         batch_records, max_len = [], 0
+        if len(examples) < batch_size:
+            raise Exception('CLS dataset contains too few samples. Expect more than '+str(batch_size))
         for index, example in enumerate(examples):
             if phase == "train":
                 self.current_example = index
@@ -308,7 +310,6 @@ class MaskLMReader(BaseReader):
 
         tokens_a = tokenizer.tokenize(text_a)
         tokens_b = None
-
         has_text_b = False
         if isinstance(example, dict):
             has_text_b = "text_b" in example.keys()
@@ -379,6 +380,8 @@ class MaskLMReader(BaseReader):
     def batch_reader(self, examples, batch_size, in_tokens, phase):
         batch = []
         total_token_num = 0
+        if len(examples) < batch_size:
+            raise Exception('MaskLM dataset contains too few samples. Expect more than '+str(batch_size))
         for e in examples:
             parsed_line = self._convert_example_to_record(e, self.max_seq_len, self.tokenizer)
             to_append = len(batch) < batch_size
@@ -866,6 +869,9 @@ class MRCReader(BaseReader):
         """generate batch records"""
         batch_records, max_len = [], 0
 
+        if len(records) < batch_size:
+            raise Exception('mrc dataset contains too few samples. Expect more than '+str(batch_size))
+
         for index, record in enumerate(records):
             if phase == "train":
                 self.current_example = index
diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py
index 63362b3..d38fc9e 100644
--- a/paddlepalm/utils/reader_helper.py
+++ b/paddlepalm/utils/reader_helper.py
@@ -118,10 +118,12 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
     results = _zero_batch(joint_shape_and_dtypes)
     outbuf = {}
     for id in task_ids:
+        print(id)
         outputs = next(iterators[id]) # dict type
         outbuf[id] = outputs
         prefix = iterator_prefixes[id]
         for outname, val in outputs.items():
+            print(outname)
             task_outname = prefix + '/' + outname
 
             if outname in outname_to_pos:
@@ -133,6 +135,7 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
                     idx = outname_to_pos[task_outname]
                     val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
                     results[idx] = val
+                    print('ok')
 
     fake_batch = results
     dev_count_bak = dev_count
--
GitLab