提交 c087d295 编写于 作者: W wangxiao

fix bugs

上级 1a762548
...@@ -55,7 +55,7 @@ class Model(backbone): ...@@ -55,7 +55,7 @@ class Model(backbone):
return {"token_ids": [[-1, -1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32']} "input_mask": [[-1, -1, 1], 'float32']}
@property @property
def outputs_attr(self): def outputs_attr(self):
......
...@@ -65,7 +65,7 @@ class Model(backbone): ...@@ -65,7 +65,7 @@ class Model(backbone):
return {"token_ids": [[-1, -1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1,-1], 'int64']} "task_ids": [[-1,-1], 'int64']}
@property @property
......
...@@ -65,8 +65,8 @@ class Reader(reader): ...@@ -65,8 +65,8 @@ class Reader(reader):
return {"token_ids": [[-1, -1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1], 'int64'], "label_ids": [[-1, 1], 'int64'],
"task_ids": [[-1, -1], 'int64'] "task_ids": [[-1, -1], 'int64']
} }
else: else:
...@@ -74,7 +74,7 @@ class Reader(reader): ...@@ -74,7 +74,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'] "input_mask": [[-1, -1, 1], 'float32']
} }
......
...@@ -63,8 +63,8 @@ class Reader(reader): ...@@ -63,8 +63,8 @@ class Reader(reader):
return {"token_ids": [[-1, -1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1], 'int64'], "label_ids": [[-1, 1], 'int64'],
"task_ids": [[-1, -1], 'int64'] "task_ids": [[-1, -1], 'int64']
} }
else: else:
...@@ -72,7 +72,7 @@ class Reader(reader): ...@@ -72,7 +72,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'] "input_mask": [[-1, -1, 1], 'float32']
} }
......
...@@ -63,10 +63,10 @@ class Reader(reader): ...@@ -63,10 +63,10 @@ class Reader(reader):
return {"token_ids": [[-1, -1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'],
"mask_label": [[-1], 'int64'], "mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1], 'int64'], "mask_pos": [[-1, 1], 'int64'],
} }
......
...@@ -71,7 +71,7 @@ class Reader(reader): ...@@ -71,7 +71,7 @@ class Reader(reader):
return {"token_ids": [[-1, -1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"start_positions": [[-1], 'int64'], "start_positions": [[-1], 'int64'],
"end_positions": [[-1], 'int64'], "end_positions": [[-1], 'int64'],
"task_ids": [[-1, -1], 'int64'] "task_ids": [[-1, -1], 'int64']
...@@ -81,7 +81,7 @@ class Reader(reader): ...@@ -81,7 +81,7 @@ class Reader(reader):
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1], 'int64'], "task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"unique_ids": [[-1], 'int64'] "unique_ids": [[-1], 'int64']
} }
......
...@@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm): ...@@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm):
@property @property
def inputs_attrs(self): def inputs_attrs(self):
if self._is_training: if self._is_training:
reader = {"label_ids": [[-1], 'int64']} reader = {"label_ids": [[-1, 1], 'int64']}
else: else:
reader = {} reader = {}
bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']} bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']}
......
...@@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm): ...@@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm):
@property @property
def inputs_attrs(self): def inputs_attrs(self):
if self._is_training: if self._is_training:
reader = {"label_ids": [[-1], 'int64']} reader = {"label_ids": [[-1, 1], 'int64']}
else: else:
reader = {} reader = {}
bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']} bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']}
......
...@@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm): ...@@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm):
@property @property
def inputs_attrs(self): def inputs_attrs(self):
reader = { reader = {
"mask_label": [[-1], 'int64'], "mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1], 'int64']} "mask_pos": [[-1, 1], 'int64']}
if not self._is_training: if not self._is_training:
del reader['mask_label'] del reader['mask_label']
del reader['batchsize_x_seqlen'] del reader['batchsize_x_seqlen']
......
...@@ -68,8 +68,8 @@ class TaskParadigm(task_paradigm): ...@@ -68,8 +68,8 @@ class TaskParadigm(task_paradigm):
if self._is_training: if self._is_training:
return {'loss': [[1], 'float32']} return {'loss': [[1], 'float32']}
else: else:
return {'start_logits': [[-1, -1], 'float32'], return {'start_logits': [[-1, -1, 1], 'float32'],
'end_logits': [[-1, -1], 'float32'], 'end_logits': [[-1, -1, 1], 'float32'],
'unique_ids': [[-1], 'int64']} 'unique_ids': [[-1], 'int64']}
......
...@@ -22,6 +22,8 @@ from paddle import fluid ...@@ -22,6 +22,8 @@ from paddle import fluid
def _check_and_adapt_shape_dtype(rt_val, attr, message=""): def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
print(rt_val)
print(attr)
if not isinstance(rt_val, np.ndarray): if not isinstance(rt_val, np.ndarray):
rt_val = np.array(rt_val) rt_val = np.array(rt_val)
assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)." assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册