提交 74ac7d56 编写于 作者: W wangxiao

change tensorshape

上级 acfcb1c5
...@@ -76,8 +76,8 @@ def _download(item, scope, path, silent=False): ...@@ -76,8 +76,8 @@ def _download(item, scope, path, silent=False):
report_hook(bytes_so_far, total_size) report_hook(bytes_so_far, total_size)
return bytes_so_far return bytes_so_far
response = urlopen(data_url) # response = urlopen(data_url)
_chunk_read(response, data_url, report_hook=_chunk_report) # _chunk_read(response, data_url, report_hook=_chunk_report)
if not silent: if not silent:
print(' done!') print(' done!')
......
...@@ -62,19 +62,19 @@ class Reader(reader): ...@@ -62,19 +62,19 @@ class Reader(reader):
@property @property
def outputs_attr(self): def outputs_attr(self):
if self._is_training: if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'], "input_mask": [[-1, -1], 'float32'],
"label_ids": [[-1,1], 'int64'], "label_ids": [[-1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'] "task_ids": [[-1, -1], 'int64']
} }
else: else:
return {"token_ids": [[-1, -1, 1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'], "task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'] "input_mask": [[-1, -1], 'float32']
} }
......
...@@ -60,19 +60,19 @@ class Reader(reader): ...@@ -60,19 +60,19 @@ class Reader(reader):
@property @property
def outputs_attr(self): def outputs_attr(self):
if self._is_training: if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'], "input_mask": [[-1, -1], 'float32'],
"label_ids": [[-1,1], 'int64'], "label_ids": [[-1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'] "task_ids": [[-1, -1], 'int64']
} }
else: else:
return {"token_ids": [[-1, -1, 1], 'int64'], return {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'], "task_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'] "input_mask": [[-1, -1], 'float32']
} }
......
...@@ -920,7 +920,7 @@ class MRCReader(BaseReader): ...@@ -920,7 +920,7 @@ class MRCReader(BaseReader):
batch_unique_ids = [record.unique_id for record in batch_records] batch_unique_ids = [record.unique_id for record in batch_records]
batch_unique_ids = np.array(batch_unique_ids).astype("int64").reshape( batch_unique_ids = np.array(batch_unique_ids).astype("int64").reshape(
[-1, 1]) [-1])
# padding # padding
padded_token_ids, input_mask = pad_batch_data( padded_token_ids, input_mask = pad_batch_data(
......
...@@ -218,7 +218,7 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc ...@@ -218,7 +218,7 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc
names = [] names = []
start = 0 start = 0
if insert_taskid: if insert_taskid:
ret.append(([1,1], 'int64')) ret.append(([1], 'int64'))
names.append('__task_id') names.append('__task_id')
start += 1 start += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册