提交 74ac7d56 · 作者: wangxiao

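Change tensor shapes: drop the trailing size-1 dimension (e.g. [-1, -1, 1] -> [-1, -1], [-1, 1] -> [-1], [1, 1] -> [1]) from the reader output attrs, the MRC unique-id batch reshape, and the inserted __task_id attr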

上级 acfcb1c5
......@@ -76,8 +76,8 @@ def _download(item, scope, path, silent=False):
report_hook(bytes_so_far, total_size)
return bytes_so_far
response = urlopen(data_url)
_chunk_read(response, data_url, report_hook=_chunk_report)
# response = urlopen(data_url)
# _chunk_read(response, data_url, report_hook=_chunk_report)
if not silent:
print(' done!')
......
......@@ -62,19 +62,19 @@ class Reader(reader):
@property
def outputs_attr(self):
if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1,1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'],
"label_ids": [[-1], 'int64'],
"task_ids": [[-1, -1], 'int64']
}
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32']
}
......
......@@ -60,19 +60,19 @@ class Reader(reader):
@property
def outputs_attr(self):
if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1,1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32'],
"label_ids": [[-1], 'int64'],
"task_ids": [[-1, -1], 'int64']
}
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']
return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1], 'float32']
}
......
......@@ -920,7 +920,7 @@ class MRCReader(BaseReader):
batch_unique_ids = [record.unique_id for record in batch_records]
batch_unique_ids = np.array(batch_unique_ids).astype("int64").reshape(
[-1, 1])
[-1])
# padding
padded_token_ids, input_mask = pad_batch_data(
......
......@@ -218,7 +218,7 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc
names = []
start = 0
if insert_taskid:
ret.append(([1,1], 'int64'))
ret.append(([1], 'int64'))
names.append('__task_id')
start += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册