diff --git a/paddlepalm/backbone/bert.py b/paddlepalm/backbone/bert.py index addae9853ab771bee227212823d1387b082f7570..d3592a5526447694e8a14d01dee2b9987740b2ed 100644 --- a/paddlepalm/backbone/bert.py +++ b/paddlepalm/backbone/bert.py @@ -55,7 +55,7 @@ class Model(backbone): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32']} + "input_mask": [[-1, -1, 1], 'float32']} @property def outputs_attr(self): diff --git a/paddlepalm/backbone/ernie.py b/paddlepalm/backbone/ernie.py index cc841affcb17f22bf87169ebc2859e8ea53a08d6..ded196385112513d001c6db4505cdc3883592984 100644 --- a/paddlepalm/backbone/ernie.py +++ b/paddlepalm/backbone/ernie.py @@ -65,7 +65,7 @@ class Model(backbone): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], + "input_mask": [[-1, -1, 1], 'float32'], "task_ids": [[-1,-1], 'int64']} @property diff --git a/paddlepalm/reader/cls.py b/paddlepalm/reader/cls.py index 5c28c846ff2daf17cc7760aeae07ea66b7225348..c13cbfe9806e1d9988ba80055570ce0cb6916434 100644 --- a/paddlepalm/reader/cls.py +++ b/paddlepalm/reader/cls.py @@ -65,8 +65,8 @@ class Reader(reader): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], - "label_ids": [[-1], 'int64'], + "input_mask": [[-1, -1, 1], 'float32'], + "label_ids": [[-1, 1], 'int64'], "task_ids": [[-1, -1], 'int64'] } else: @@ -74,7 +74,7 @@ class Reader(reader): "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'] + "input_mask": [[-1, -1, 1], 'float32'] } diff --git a/paddlepalm/reader/match.py b/paddlepalm/reader/match.py index 4520c35c535f8d570a24d726e99e94ab88ed84f9..2dcdc12c4afac4e004777ff378a76bae027963b4 100644 --- a/paddlepalm/reader/match.py +++ b/paddlepalm/reader/match.py @@ -63,8 +63,8 @@ class Reader(reader): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], - "label_ids": [[-1], 'int64'], + "input_mask": [[-1, -1, 1], 'float32'], + "label_ids": [[-1, 1], 'int64'], "task_ids": [[-1, -1], 'int64'] } else: @@ -72,7 +72,7 @@ class Reader(reader): "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'] + "input_mask": [[-1, -1, 1], 'float32'] } diff --git a/paddlepalm/reader/mlm.py b/paddlepalm/reader/mlm.py index c1db35160251f2aeb850989a2b03d39ac08a40a8..5230844d052c6725cf747f448c723bd22a801dd9 100644 --- a/paddlepalm/reader/mlm.py +++ b/paddlepalm/reader/mlm.py @@ -63,10 +63,10 @@ class Reader(reader): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], + "input_mask": [[-1, -1, 1], 'float32'], "task_ids": [[-1, -1], 'int64'], - "mask_label": [[-1], 'int64'], - "mask_pos": [[-1], 'int64'], + "mask_label": [[-1, 1], 'int64'], + "mask_pos": [[-1, 1], 'int64'], } diff --git a/paddlepalm/reader/mrc.py b/paddlepalm/reader/mrc.py index ea4e72678bf02b7d0d4fd625c7986f63c2d3f177..2906b97ecb591fd6cc65f3a246c6d88e87dfccb8 100644 --- a/paddlepalm/reader/mrc.py +++ b/paddlepalm/reader/mrc.py @@ -71,7 +71,7 @@ class Reader(reader): return {"token_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], + "input_mask": [[-1, -1, 1], 'float32'], "start_positions": [[-1], 'int64'], "end_positions": [[-1], 'int64'], "task_ids": [[-1, -1], 'int64'] @@ -81,7 +81,7 @@ class Reader(reader): "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1], 'float32'], + "input_mask": [[-1, -1, 1], 'float32'], "unique_ids": [[-1], 'int64'] } diff --git a/paddlepalm/task_paradigm/cls.py b/paddlepalm/task_paradigm/cls.py index 69fb1845cf8fc628cf0021b05e90191ef15f0450..6cbacf79dd12622c4d952c29040c0c42768e2d11 100644 --- a/paddlepalm/task_paradigm/cls.py +++ b/paddlepalm/task_paradigm/cls.py @@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm): @property def inputs_attrs(self): if self._is_training: - reader = {"label_ids": [[-1], 'int64']} + reader = {"label_ids": [[-1, 1], 'int64']} else: reader = {} bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']} diff --git a/paddlepalm/task_paradigm/match.py b/paddlepalm/task_paradigm/match.py index 72e303d8e6c168adcd591ee4719aa94287f349e5..ee0d175b01e09ede242aa7fe404366dc48804580 100644 --- a/paddlepalm/task_paradigm/match.py +++ b/paddlepalm/task_paradigm/match.py @@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm): @property def inputs_attrs(self): if self._is_training: - reader = {"label_ids": [[-1], 'int64']} + reader = {"label_ids": [[-1, 1], 'int64']} else: reader = {} bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']} diff --git a/paddlepalm/task_paradigm/mlm.py b/paddlepalm/task_paradigm/mlm.py index 211002fef31603319dffa99bb88675c74e78e6eb..ec86dd151e8b0f86c345120f4a5907f0afb91d5c 100644 --- a/paddlepalm/task_paradigm/mlm.py +++ b/paddlepalm/task_paradigm/mlm.py @@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm): @property def inputs_attrs(self): reader = { - "mask_label": [[-1], 'int64'], - "mask_pos": [[-1], 'int64']} + "mask_label": [[-1, 1], 'int64'], + "mask_pos": [[-1, 1], 'int64']} if not self._is_training: del reader['mask_label'] del reader['batchsize_x_seqlen'] diff --git a/paddlepalm/task_paradigm/mrc.py b/paddlepalm/task_paradigm/mrc.py index edb5bffbadfcce5e33dd1e7be15e3e9b29d0f57d..218885ab21b20522d2b4f74c041f8cc09d4b3cdc 100644 --- a/paddlepalm/task_paradigm/mrc.py +++ b/paddlepalm/task_paradigm/mrc.py @@ -68,8 +68,8 @@ class TaskParadigm(task_paradigm): if self._is_training: return {'loss': [[1], 'float32']} else: - return {'start_logits': [[-1, -1], 'float32'], - 'end_logits': [[-1, -1], 'float32'], + return {'start_logits': [[-1, -1, 1], 'float32'], + 'end_logits': [[-1, -1, 1], 'float32'], 'unique_ids': [[-1], 'int64']} diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py index 09d5c0c6632efd8e29fd2d14331a4b62b151abf1..6e9c9d6574be708c844a70ac3b921c1af658815d 100644 --- a/paddlepalm/utils/reader_helper.py +++ b/paddlepalm/utils/reader_helper.py @@ -22,6 +22,8 @@ from paddle import fluid def _check_and_adapt_shape_dtype(rt_val, attr, message=""): + print(rt_val) + print(attr) if not isinstance(rt_val, np.ndarray): rt_val = np.array(rt_val) assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)."