From 56f73f88b6fb7bdbeed9ec537e3cbba6b6d2d26d Mon Sep 17 00:00:00 2001 From: wangxiao Date: Wed, 4 Dec 2019 16:33:47 +0800 Subject: [PATCH] fix label_ids --- paddlepalm/reader/cls.py | 2 +- paddlepalm/reader/match.py | 2 +- paddlepalm/reader/mlm.py | 4 ++-- paddlepalm/task_paradigm/cls.py | 2 +- paddlepalm/task_paradigm/match.py | 2 +- paddlepalm/task_paradigm/mlm.py | 4 ++-- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddlepalm/reader/cls.py b/paddlepalm/reader/cls.py index c13cbfe..dd5e7f3 100644 --- a/paddlepalm/reader/cls.py +++ b/paddlepalm/reader/cls.py @@ -66,7 +66,7 @@ class Reader(reader): "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], "input_mask": [[-1, -1, 1], 'float32'], - "label_ids": [[-1, 1], 'int64'], + "label_ids": [[-1], 'int64'], "task_ids": [[-1, -1], 'int64'] } else: diff --git a/paddlepalm/reader/match.py b/paddlepalm/reader/match.py index 2dcdc12..d6be0f8 100644 --- a/paddlepalm/reader/match.py +++ b/paddlepalm/reader/match.py @@ -64,7 +64,7 @@ class Reader(reader): "position_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'], "input_mask": [[-1, -1, 1], 'float32'], - "label_ids": [[-1, 1], 'int64'], + "label_ids": [[-1], 'int64'], "task_ids": [[-1, -1], 'int64'] } else: diff --git a/paddlepalm/reader/mlm.py b/paddlepalm/reader/mlm.py index 5230844..e4dff34 100644 --- a/paddlepalm/reader/mlm.py +++ b/paddlepalm/reader/mlm.py @@ -65,8 +65,8 @@ class Reader(reader): "segment_ids": [[-1, -1], 'int64'], "input_mask": [[-1, -1, 1], 'float32'], "task_ids": [[-1, -1], 'int64'], - "mask_label": [[-1, 1], 'int64'], - "mask_pos": [[-1, 1], 'int64'], + "mask_label": [[-1], 'int64'], + "mask_pos": [[-1], 'int64'], } diff --git a/paddlepalm/task_paradigm/cls.py b/paddlepalm/task_paradigm/cls.py index ed40709..b590b6f 100644 --- a/paddlepalm/task_paradigm/cls.py +++ b/paddlepalm/task_paradigm/cls.py @@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm): @property def inputs_attrs(self): if self._is_training: - reader = {"label_ids": [[-1, 1], 'int64']} + reader = {"label_ids": [[-1], 'int64']} else: reader = {} bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']} diff --git a/paddlepalm/task_paradigm/match.py b/paddlepalm/task_paradigm/match.py index a286cbf..d42c64a 100644 --- a/paddlepalm/task_paradigm/match.py +++ b/paddlepalm/task_paradigm/match.py @@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm): @property def inputs_attrs(self): if self._is_training: - reader = {"label_ids": [[-1, 1], 'int64']} + reader = {"label_ids": [[-1], 'int64']} else: reader = {} bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']} diff --git a/paddlepalm/task_paradigm/mlm.py b/paddlepalm/task_paradigm/mlm.py index b77483a..81dda86 100644 --- a/paddlepalm/task_paradigm/mlm.py +++ b/paddlepalm/task_paradigm/mlm.py @@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm): @property def inputs_attrs(self): reader = { - "mask_label": [[-1, 1], 'int64'], - "mask_pos": [[-1, 1], 'int64']} + "mask_label": [[-1], 'int64'], + "mask_pos": [[-1], 'int64']} if not self._is_training: del reader['mask_label'] del reader['batchsize_x_seqlen'] -- GitLab