Commit aff521f0 authored by wangxiao

change tensorshape
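This commit removes the trailing singleton dimension from the tensor shapes used throughout the readers, backbones, and task paradigms: sequence tensors such as token_ids go from [-1, -1, 1] to [-1, -1], and per-example scalars such as labels and mask positions go from [-1, 1] to [-1]. A minimal sketch of the convention change (toy data, not from this repo):

    import numpy as np

    labels = [3, 1, 2]

    # Old convention: per-example scalars carried a trailing singleton axis.
    old_labels = np.array(labels).astype("int64").reshape([-1, 1])  # shape (3, 1)

    # New convention: the same data as a flat vector.
    new_labels = np.array(labels).astype("int64").reshape([-1])     # shape (3,)

    assert old_labels.shape == (3, 1)
    assert new_labels.shape == (3,)
    assert (old_labels.squeeze(-1) == new_labels).all()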

Parent bba10bb6
@@ -52,10 +52,10 @@ class Model(backbone):
     @property
     def inputs_attr(self):
-        return {"token_ids": [[-1, -1, 1], 'int64'],
-                "position_ids": [[-1, -1, 1], 'int64'],
-                "segment_ids": [[-1, -1, 1], 'int64'],
-                "input_mask": [[-1, -1, 1], 'float32']}
+        return {"token_ids": [[-1, -1], 'int64'],
+                "position_ids": [[-1, -1], 'int64'],
+                "segment_ids": [[-1, -1], 'int64'],
+                "input_mask": [[-1, -1], 'float32']}
     @property
     def outputs_attr(self):
......
@@ -62,11 +62,11 @@ class Model(backbone):
     @property
     def inputs_attr(self):
-        return {"token_ids": [[-1, -1, 1], 'int64'],
-                "position_ids": [[-1, -1, 1], 'int64'],
-                "segment_ids": [[-1, -1, 1], 'int64'],
-                "input_mask": [[-1, -1, 1], 'float32'],
-                "task_ids": [[-1,-1, 1], 'int64']}
+        return {"token_ids": [[-1, -1], 'int64'],
+                "position_ids": [[-1, -1], 'int64'],
+                "segment_ids": [[-1, -1], 'int64'],
+                "input_mask": [[-1, -1], 'float32'],
+                "task_ids": [[-1,-1], 'int64']}
     @property
     def outputs_attr(self):
......
@@ -60,13 +60,13 @@ class Reader(reader):
     @property
     def outputs_attr(self):
-        return {"token_ids": [[-1, -1, 1], 'int64'],
-                "position_ids": [[-1, -1, 1], 'int64'],
-                "segment_ids": [[-1, -1, 1], 'int64'],
-                "input_mask": [[-1, -1, 1], 'float32'],
-                "task_ids": [[-1, -1, 1], 'int64'],
-                "mask_label": [[-1, 1], 'int64'],
-                "mask_pos": [[-1, 1], 'int64'],
+        return {"token_ids": [[-1, -1], 'int64'],
+                "position_ids": [[-1, -1], 'int64'],
+                "segment_ids": [[-1, -1], 'int64'],
+                "input_mask": [[-1, -1], 'float32'],
+                "task_ids": [[-1, -1], 'int64'],
+                "mask_label": [[-1], 'int64'],
+                "mask_pos": [[-1], 'int64'],
                 }
......
@@ -67,8 +67,8 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
                 sent[token_index] = MASK
                 mask_flag = True
             mask_pos.append(sent_index * max_len + token_index)
-    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
-    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
+    mask_label = np.array(mask_label).astype("int64").reshape([-1])
+    mask_pos = np.array(mask_pos).astype("int64").reshape([-1])
     return batch_tokens, mask_label, mask_pos
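In mask(), mask_pos already stores flattened indices into the padded batch (sent_index * max_len + token_index); with this change it is returned as a 1-D vector. A toy sketch (made-up sizes, not this repo's code) of how such flat positions pick out the masked-token rows from a [batch * seq_len, hidden] encoder output:

    import numpy as np

    batch_size, max_len, hidden = 2, 4, 3
    # Encoder output flattened over batch and sequence: [batch * seq_len, hidden].
    enc_out = np.arange(batch_size * max_len * hidden, dtype="float32").reshape(
        batch_size * max_len, hidden)

    # Say token 1 of sentence 0 and token 2 of sentence 1 were masked.
    mask_pos = np.array([0 * max_len + 1,
                         1 * max_len + 2]).astype("int64").reshape([-1])

    masked_repr = enc_out[mask_pos]  # rows fed to the MLM head
    assert masked_repr.shape == (2, hidden)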
@@ -96,7 +96,7 @@ def prepare_batch_data(insts,
     # or unique id
     for i in range(3, len(insts[0]), 1):
         labels = [inst[i] for inst in insts]
-        labels = np.array(labels).astype("int64").reshape([-1, 1])
+        labels = np.array(labels).astype("int64").reshape([-1])
         labels_list.append(labels)
     # First step: do mask without padding
     if mask_id >= 0:
@@ -154,14 +154,14 @@ def pad_batch_data(insts,
     inst_data = np.array([
         list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
     ])
-    return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
+    return_list += [inst_data.astype("int64").reshape([-1, max_len])]
     # position data
     if return_pos:
         inst_pos = np.array([
             list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
             for inst in insts
         ])
-        return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
+        return_list += [inst_pos.astype("int64").reshape([-1, max_len])]
     if return_input_mask:
         # This is used to avoid attention on paddings.
         input_mask_data = np.array([[1] * len(inst) + [0] *
......
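pad_batch_data right-pads every instance to the longest sequence in the batch; with this change it returns 2-D [batch, max_len] tensors instead of [batch, max_len, 1]. A self-contained sketch of the token-id branch under the new convention (simplified, not the full function):

    import numpy as np

    def pad_token_ids(insts, pad_idx=0):
        # Simplified sketch of pad_batch_data's token-id branch, new shapes.
        max_len = max(len(inst) for inst in insts)
        inst_data = np.array(
            [list(inst) + [pad_idx] * (max_len - len(inst)) for inst in insts])
        return inst_data.astype("int64").reshape([-1, max_len])

    batch = pad_token_ids([[5, 6, 7], [8, 9]])
    assert batch.shape == (2, 3)                     # [batch, max_len], no trailing 1
    assert batch.tolist() == [[5, 6, 7], [8, 9, 0]]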
@@ -113,8 +113,8 @@ def mask(batch_tokens,
             pre_sent_len = len(sent)
-    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
-    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
+    mask_label = np.array(mask_label).astype("int64").reshape([-1])
+    mask_pos = np.array(mask_pos).astype("int64").reshape([-1])
     return batch_tokens, mask_label, mask_pos
@@ -136,7 +136,7 @@ def pad_batch_data(insts,
     inst_data = np.array(
         [inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
-    return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
+    return_list += [inst_data.astype("int64").reshape([-1, max_len])]
     # position data
     if return_pos:
@@ -145,7 +145,7 @@ def pad_batch_data(insts,
             for inst in insts
         ])
-        return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
+        return_list += [inst_pos.astype("int64").reshape([-1, max_len])]
     if return_input_mask:
         # This is used to avoid attention on paddings.
@@ -165,7 +165,7 @@ def pad_batch_data(insts,
     if return_seq_lens:
         seq_lens = np.array([len(inst) for inst in insts])
-        return_list += [seq_lens.astype("int64").reshape([-1, 1])]
+        return_list += [seq_lens.astype("int64").reshape([-1])]
     return return_list if len(return_list) > 1 else return_list[0]
......
@@ -67,8 +67,8 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
                 sent[token_index] = MASK
                 mask_flag = True
             mask_pos.append(sent_index * max_len + token_index)
-    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
-    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
+    mask_label = np.array(mask_label).astype("int64").reshape([-1])
+    mask_pos = np.array(mask_pos).astype("int64").reshape([-1])
     return batch_tokens, mask_label, mask_pos
@@ -147,14 +147,14 @@ def pad_batch_data(insts,
     inst_data = np.array([
         list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
     ])
-    return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
+    return_list += [inst_data.astype("int64").reshape([-1, max_len])]
     # position data
     if return_pos:
         inst_pos = np.array([
             list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
             for inst in insts
         ])
-        return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
+        return_list += [inst_pos.astype("int64").reshape([-1, max_len])]
     if return_input_mask:
         # This is used to avoid attention on paddings.
         input_mask_data = np.array([[1] * len(inst) + [0] *
......
@@ -479,17 +479,17 @@ class ClassifyReader(BaseReader):
         batch_labels = [record.label_id for record in batch_records]
         if self.is_classify:
             batch_labels = np.array(batch_labels).astype("int64").reshape(
-                [-1, 1])
+                [-1])
         elif self.is_regression:
             batch_labels = np.array(batch_labels).astype("float32").reshape(
-                [-1, 1])
+                [-1])
         if batch_records[0].qid:
             batch_qids = [record.qid for record in batch_records]
             batch_qids = np.array(batch_qids).astype("int64").reshape(
-                [-1, 1])
+                [-1])
         else:
-            batch_qids = np.array([]).astype("int64").reshape([-1, 1])
+            batch_qids = np.array([]).astype("int64").reshape([-1])
         # padding
         padded_token_ids, input_mask = pad_batch_data(
@@ -908,15 +908,15 @@ class MRCReader(BaseReader):
                 record.end_position for record in batch_records
             ]
             batch_start_position = np.array(batch_start_position).astype(
-                "int64").reshape([-1, 1])
+                "int64").reshape([-1])
             batch_end_position = np.array(batch_end_position).astype(
-                "int64").reshape([-1, 1])
+                "int64").reshape([-1])
         else:
             batch_size = len(batch_token_ids)
             batch_start_position = np.zeros(
-                shape=[batch_size, 1], dtype="int64")
-            batch_end_position = np.zeros(shape=[batch_size, 1], dtype="int64")
+                shape=[batch_size], dtype="int64")
+            batch_end_position = np.zeros(shape=[batch_size], dtype="int64")
         batch_unique_ids = [record.unique_id for record in batch_records]
         batch_unique_ids = np.array(batch_unique_ids).astype("int64").reshape(
......
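All of the readers' per-example scalars (label ids, qids, start/end positions, unique ids) move from [-1, 1] to [-1] in the same commit. Changing every site at once matters because mixing the two conventions invites silent broadcasting bugs; a toy illustration:

    import numpy as np

    preds = np.array([0, 1, 1]).astype("int64").reshape([-1])          # new: (3,)
    labels_old = np.array([0, 1, 0]).astype("int64").reshape([-1, 1])  # old: (3, 1)

    # (3,) compared against (3, 1) silently broadcasts to a (3, 3) matrix:
    assert (preds == labels_old).shape == (3, 3)
    # With both sides flat, the comparison is elementwise as intended:
    assert (preds == labels_old.reshape([-1])).shape == (3,)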
@@ -43,7 +43,7 @@ class TaskParadigm(task_paradigm):
     @property
     def inputs_attrs(self):
         if self._is_training:
-            reader = {"label_ids": [[-1, 1], 'int64']}
+            reader = {"label_ids": [[-1], 'int64']}
         else:
             reader = {}
         bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']}
......
@@ -44,7 +44,7 @@ class TaskParadigm(task_paradigm):
     @property
     def inputs_attrs(self):
         if self._is_training:
-            reader = {"label_ids": [[-1, 1], 'int64']}
+            reader = {"label_ids": [[-1], 'int64']}
         else:
             reader = {}
         bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']}
......
@@ -33,8 +33,8 @@ class TaskParadigm(task_paradigm):
     @property
     def inputs_attrs(self):
         reader = {
-            "mask_label": [[-1, 1], 'int64'],
-            "mask_pos": [[-1, 1], 'int64']}
+            "mask_label": [[-1], 'int64'],
+            "mask_pos": [[-1], 'int64']}
         if not self._is_training:
             del reader['mask_label']
             del reader['batchsize_x_seqlen']
......
@@ -49,11 +49,11 @@ class TaskParadigm(task_paradigm):
     @property
     def inputs_attrs(self):
         if self._is_training:
-            reader = {"start_positions": [[-1, 1], 'int64'],
-                      "end_positions": [[-1, 1], 'int64'],
+            reader = {"start_positions": [[-1], 'int64'],
+                      "end_positions": [[-1], 'int64'],
                       }
         else:
-            reader = {'unique_ids': [[-1, 1], 'int64']}
+            reader = {'unique_ids': [[-1], 'int64']}
         bb = {"encoder_outputs": [[-1, -1, self._hidden_size], 'float32']}
         return {'reader': reader, 'backbone': bb}
@@ -68,9 +68,9 @@ class TaskParadigm(task_paradigm):
         if self._is_training:
             return {'loss': [[1], 'float32']}
         else:
-            return {'start_logits': [[-1, -1, 1], 'float32'],
-                    'end_logits': [[-1, -1, 1], 'float32'],
-                    'unique_ids': [[-1, 1], 'int64']}
+            return {'start_logits': [[-1, -1], 'float32'],
+                    'end_logits': [[-1, -1], 'float32'],
+                    'unique_ids': [[-1], 'int64']}
     def build(self, inputs, scope_name=""):
......
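The inputs_attrs/outputs_attr dictionaries pair each tensor name with a [shape, dtype] spec, so the edits above keep the declared graph interfaces consistent with the arrays the readers now emit. A hedged sketch of how such specs could be turned into feed placeholders (build_feed_vars is a hypothetical helper; fluid.data is the Paddle 1.x static-graph API, and -1 marks a dimension unknown until runtime):

    import paddle.fluid as fluid

    inputs_attr = {"token_ids": [[-1, -1], 'int64'],
                   "input_mask": [[-1, -1], 'float32']}

    def build_feed_vars(attr):
        # Hypothetical helper: one placeholder per declared input spec.
        return {name: fluid.data(name=name, shape=shape, dtype=dtype)
                for name, (shape, dtype) in attr.items()}

    feed_vars = build_feed_vars(inputs_attr)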