提交 56f73f88 编写于 作者: W wangxiao

fix label_ids

上级 58e7a556
......@@ -66,7 +66,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1, 1], 'int64'],
"label_ids": [[-1], 'int64'],
"task_ids": [[-1, -1], 'int64']
}
else:
......
......@@ -64,7 +64,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1, 1], 'int64'],
"label_ids": [[-1], 'int64'],
"task_ids": [[-1, -1], 'int64']
}
else:
......
......@@ -65,8 +65,8 @@ class Reader(reader):
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1, -1], 'int64'],
"mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1, 1], 'int64'],
"mask_label": [[-1], 'int64'],
"mask_pos": [[-1], 'int64'],
}
......
......@@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm):
@property
def inputs_attrs(self):
if self._is_training:
reader = {"label_ids": [[-1, 1], 'int64']}
reader = {"label_ids": [[-1], 'int64']}
else:
reader = {}
bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']}
......
......@@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm):
@property
def inputs_attrs(self):
if self._is_training:
reader = {"label_ids": [[-1, 1], 'int64']}
reader = {"label_ids": [[-1], 'int64']}
else:
reader = {}
bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']}
......
......@@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm):
@property
def inputs_attrs(self):
reader = {
"mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1, 1], 'int64']}
"mask_label": [[-1], 'int64'],
"mask_pos": [[-1], 'int64']}
if not self._is_training:
del reader['mask_label']
del reader['batchsize_x_seqlen']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册