提交 c087d295 编写于 作者: W wangxiao

fix bugs

上级 1a762548
......@@ -55,7 +55,7 @@ class Model(backbone):
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32']}
"input_mask": [[-1, -1, 1], 'float32']}
@property
def outputs_attr(self):
......
......@@ -65,7 +65,7 @@ class Model(backbone):
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'],
"input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1,-1], 'int64']}
@property
......
......@@ -65,8 +65,8 @@ class Reader(reader):
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'],
"label_ids": [[-1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1, 1], 'int64'],
"task_ids": [[-1, -1], 'int64']
}
else:
......@@ -74,7 +74,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32']
"input_mask": [[-1, -1, 1], 'float32']
}
......
......@@ -63,8 +63,8 @@ class Reader(reader):
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'],
"label_ids": [[-1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1, 1], 'int64'],
"task_ids": [[-1, -1], 'int64']
}
else:
......@@ -72,7 +72,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32']
"input_mask": [[-1, -1, 1], 'float32']
}
......
......@@ -63,10 +63,10 @@ class Reader(reader):
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'],
"input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1, -1], 'int64'],
"mask_label": [[-1], 'int64'],
"mask_pos": [[-1], 'int64'],
"mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1, 1], 'int64'],
}
......
......@@ -71,7 +71,7 @@ class Reader(reader):
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'],
"input_mask": [[-1, -1, 1], 'float32'],
"start_positions": [[-1], 'int64'],
"end_positions": [[-1], 'int64'],
"task_ids": [[-1, -1], 'int64']
......@@ -81,7 +81,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'],
"input_mask": [[-1, -1, 1], 'float32'],
"unique_ids": [[-1], 'int64']
}
......
......@@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm):
@property
def inputs_attrs(self):
if self._is_training:
reader = {"label_ids": [[-1], 'int64']}
reader = {"label_ids": [[-1, 1], 'int64']}
else:
reader = {}
bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']}
......
......@@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm):
@property
def inputs_attrs(self):
if self._is_training:
reader = {"label_ids": [[-1], 'int64']}
reader = {"label_ids": [[-1, 1], 'int64']}
else:
reader = {}
bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']}
......
......@@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm):
@property
def inputs_attrs(self):
reader = {
"mask_label": [[-1], 'int64'],
"mask_pos": [[-1], 'int64']}
"mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1, 1], 'int64']}
if not self._is_training:
del reader['mask_label']
del reader['batchsize_x_seqlen']
......
......@@ -68,8 +68,8 @@ class TaskParadigm(task_paradigm):
if self._is_training:
return {'loss': [[1], 'float32']}
else:
return {'start_logits': [[-1, -1], 'float32'],
'end_logits': [[-1, -1], 'float32'],
return {'start_logits': [[-1, -1, 1], 'float32'],
'end_logits': [[-1, -1, 1], 'float32'],
'unique_ids': [[-1], 'int64']}
......
......@@ -22,6 +22,8 @@ from paddle import fluid
def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
print(rt_val)
print(attr)
if not isinstance(rt_val, np.ndarray):
rt_val = np.array(rt_val)
assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册