提交 56f73f88 编写于 作者: W wangxiao

fix label_ids

上级 58e7a556
...@@ -66,7 +66,7 @@ class Reader(reader): ...@@ -66,7 +66,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1, 1], 'int64'], "label_ids": [[-1], 'int64'],
"task_ids": [[-1, -1], 'int64'] "task_ids": [[-1, -1], 'int64']
} }
else: else:
......
...@@ -64,7 +64,7 @@ class Reader(reader): ...@@ -64,7 +64,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1, 1], 'int64'], "label_ids": [[-1], 'int64'],
"task_ids": [[-1, -1], 'int64'] "task_ids": [[-1, -1], 'int64']
} }
else: else:
......
...@@ -65,8 +65,8 @@ class Reader(reader): ...@@ -65,8 +65,8 @@ class Reader(reader):
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'],
"mask_label": [[-1, 1], 'int64'], "mask_label": [[-1], 'int64'],
"mask_pos": [[-1, 1], 'int64'], "mask_pos": [[-1], 'int64'],
} }
......
...@@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm): ...@@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm):
@property @property
def inputs_attrs(self): def inputs_attrs(self):
if self._is_training: if self._is_training:
reader = {"label_ids": [[-1, 1], 'int64']} reader = {"label_ids": [[-1], 'int64']}
else: else:
reader = {} reader = {}
bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']} bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']}
......
...@@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm): ...@@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm):
@property @property
def inputs_attrs(self): def inputs_attrs(self):
if self._is_training: if self._is_training:
reader = {"label_ids": [[-1, 1], 'int64']} reader = {"label_ids": [[-1], 'int64']}
else: else:
reader = {} reader = {}
bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']} bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']}
......
...@@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm): ...@@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm):
@property @property
def inputs_attrs(self): def inputs_attrs(self):
reader = { reader = {
"mask_label": [[-1, 1], 'int64'], "mask_label": [[-1], 'int64'],
"mask_pos": [[-1, 1], 'int64']} "mask_pos": [[-1], 'int64']}
if not self._is_training: if not self._is_training:
del reader['mask_label'] del reader['mask_label']
del reader['batchsize_x_seqlen'] del reader['batchsize_x_seqlen']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册