From c087d295bf9ff3cccc3181b678fed0b95a28789b Mon Sep 17 00:00:00 2001 From: wangxiao Date: Wed, 4 Dec 2019 15:44:51 +0800 Subject: [PATCH] fix bugs --- paddlepalm/backbone/bert.py | 2 +- paddlepalm/backbone/ernie.py | 2 +- paddlepalm/reader/cls.py | 6 +++--- paddlepalm/reader/match.py | 6 +++--- paddlepalm/reader/mlm.py | 6 +++--- paddlepalm/reader/mrc.py | 4 ++-- paddlepalm/task_paradigm/cls.py | 2 +- paddlepalm/task_paradigm/match.py | 2 +- paddlepalm/task_paradigm/mlm.py | 4 ++-- paddlepalm/task_paradigm/mrc.py | 4 ++-- paddlepalm/utils/reader_helper.py | 2 ++ 11 files changed, 21 insertions(+), 19 deletions(-) diff --git a/paddlepalm/backbone/bert.py b/paddlepalm/backbone/bert.py index addae98..d3592a5 100644 --- a/paddlepalm/backbone/bert.py +++ b/paddlepalm/backbone/bert.py @@ -55,7 +55,7 @@ class Model(backbone): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32']} + "input_mask": [[-1, -1, 1], 'float32']} @property def outputs_attr(self): diff --git a/paddlepalm/backbone/ernie.py b/paddlepalm/backbone/ernie.py index cc841af..ded1963 100644 --- a/paddlepalm/backbone/ernie.py +++ b/paddlepalm/backbone/ernie.py @@ -65,7 +65,7 @@ class Model(backbone): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], + "input_mask": [[-1, -1, 1], 'float32'], "task_ids": [[-1,-1], 'int64']} @property diff --git a/paddlepalm/reader/cls.py b/paddlepalm/reader/cls.py index 5c28c84..c13cbfe 100644 --- a/paddlepalm/reader/cls.py +++ b/paddlepalm/reader/cls.py @@ -65,8 +65,8 @@ class Reader(reader): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], - "label_ids": [[-1], 'int64'], + "input_mask": [[-1, -1, 1], 'float32'], + "label_ids": [[-1, 1], 'int64'], "task_ids": [[-1, -1], 'int64'] } else: @@ -74,7 +74,7 @@ class Reader(reader): "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'] + "input_mask": [[-1, -1, 1], 'float32'] } diff --git a/paddlepalm/reader/match.py b/paddlepalm/reader/match.py index 4520c35..2dcdc12 100644 --- a/paddlepalm/reader/match.py +++ b/paddlepalm/reader/match.py @@ -63,8 +63,8 @@ class Reader(reader): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], - "label_ids": [[-1], 'int64'], + "input_mask": [[-1, -1, 1], 'float32'], + "label_ids": [[-1, 1], 'int64'], "task_ids": [[-1, -1], 'int64'] } else: @@ -72,7 +72,7 @@ class Reader(reader): "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'] + "input_mask": [[-1, -1, 1], 'float32'] } diff --git a/paddlepalm/reader/mlm.py b/paddlepalm/reader/mlm.py index c1db351..5230844 100644 --- a/paddlepalm/reader/mlm.py +++ b/paddlepalm/reader/mlm.py @@ -63,10 +63,10 @@ class Reader(reader): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], + "input_mask": [[-1, -1, 1], 'float32'], "task_ids": [[-1, -1], 'int64'], - "mask_label": [[-1], 'int64'], - "mask_pos": [[-1], 'int64'], + "mask_label": [[-1, 1], 'int64'], + "mask_pos": [[-1, 1], 'int64'], } diff --git a/paddlepalm/reader/mrc.py b/paddlepalm/reader/mrc.py index ea4e726..2906b97 100644 --- a/paddlepalm/reader/mrc.py +++ b/paddlepalm/reader/mrc.py @@ -71,7 +71,7 @@ class Reader(reader): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], + "input_mask": [[-1, -1, 1], 'float32'], "start_positions": [[-1], 'int64'], "end_positions": [[-1], 'int64'], "task_ids": [[-1, -1], 'int64'] @@ -81,7 +81,7 @@ class Reader(reader): "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], + "input_mask": [[-1, -1, 1], 'float32'], "unique_ids": [[-1], 'int64'] } diff --git a/paddlepalm/task_paradigm/cls.py b/paddlepalm/task_paradigm/cls.py index 69fb184..6cbacf7 100644 --- a/paddlepalm/task_paradigm/cls.py +++ b/paddlepalm/task_paradigm/cls.py @@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm): @property def inputs_attrs(self): if self._is_training: - reader = {"label_ids": [[-1], 'int64']} + reader = {"label_ids": [[-1, 1], 'int64']} else: reader = {} bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']} diff --git a/paddlepalm/task_paradigm/match.py b/paddlepalm/task_paradigm/match.py index 72e303d..ee0d175 100644 --- a/paddlepalm/task_paradigm/match.py +++ b/paddlepalm/task_paradigm/match.py @@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm): @property def inputs_attrs(self): if self._is_training: - reader = {"label_ids": [[-1], 'int64']} + reader = {"label_ids": [[-1, 1], 'int64']} else: reader = {} bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']} diff --git a/paddlepalm/task_paradigm/mlm.py b/paddlepalm/task_paradigm/mlm.py index 211002f..ec86dd1 100644 --- a/paddlepalm/task_paradigm/mlm.py +++ b/paddlepalm/task_paradigm/mlm.py @@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm): @property def inputs_attrs(self): reader = { - "mask_label": [[-1], 'int64'], - "mask_pos": [[-1], 'int64']} + "mask_label": [[-1, 1], 'int64'], + "mask_pos": [[-1, 1], 'int64']} if not self._is_training: del reader['mask_label'] del reader['batchsize_x_seqlen'] diff --git a/paddlepalm/task_paradigm/mrc.py b/paddlepalm/task_paradigm/mrc.py index edb5bff..218885a 100644 --- a/paddlepalm/task_paradigm/mrc.py +++ b/paddlepalm/task_paradigm/mrc.py @@ -68,8 +68,8 @@ class TaskParadigm(task_paradigm): if self._is_training: return {'loss': [[1], 'float32']} else: - return {'start_logits': [[-1, -1], 'float32'], - 'end_logits': [[-1, -1], 'float32'], + return {'start_logits': [[-1, -1, 1], 'float32'], + 'end_logits': [[-1, -1, 1], 'float32'], 'unique_ids': [[-1], 'int64']} diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py index 09d5c0c..6e9c9d6 100644 --- a/paddlepalm/utils/reader_helper.py +++ b/paddlepalm/utils/reader_helper.py @@ -22,6 +22,8 @@ from paddle import fluid def _check_and_adapt_shape_dtype(rt_val, attr, message=""): + print(rt_val) + print(attr) if not isinstance(rt_val, np.ndarray): rt_val = np.array(rt_val) assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)." -- GitLab