From 74ac7d56dd083d08932b8981d268aa6f278d62d3 Mon Sep 17 00:00:00 2001 From: wangxiao Date: Wed, 4 Dec 2019 13:54:37 +0800 Subject: [PATCH] change tensorshape --- paddlepalm/_downloader.py | 4 ++-- paddlepalm/reader/cls.py | 22 +++++++++++----------- paddlepalm/reader/match.py | 22 +++++++++++----------- paddlepalm/reader/utils/reader4ernie.py | 2 +- paddlepalm/utils/reader_helper.py | 2 +- 5 files changed, 26 insertions(+), 26 deletions(-) diff --git a/paddlepalm/_downloader.py b/paddlepalm/_downloader.py index 52521e8..2031e6f 100644 --- a/paddlepalm/_downloader.py +++ b/paddlepalm/_downloader.py @@ -76,8 +76,8 @@ def _download(item, scope, path, silent=False): report_hook(bytes_so_far, total_size) return bytes_so_far - response = urlopen(data_url) - _chunk_read(response, data_url, report_hook=_chunk_report) + # response = urlopen(data_url) + # _chunk_read(response, data_url, report_hook=_chunk_report) if not silent: print(' done!') diff --git a/paddlepalm/reader/cls.py b/paddlepalm/reader/cls.py index 1ecf6cb..5c28c84 100644 --- a/paddlepalm/reader/cls.py +++ b/paddlepalm/reader/cls.py @@ -62,19 +62,19 @@ class Reader(reader): @property def outputs_attr(self): if self._is_training: - return {"token_ids": [[-1, -1, 1], 'int64'], - "position_ids": [[-1, -1, 1], 'int64'], - "segment_ids": [[-1, -1, 1], 'int64'], - "input_mask": [[-1, -1, 1], 'float32'], - "label_ids": [[-1,1], 'int64'], - "task_ids": [[-1, -1, 1], 'int64'] + return {"token_ids": [[-1, -1], 'int64'], + "position_ids": [[-1, -1], 'int64'], + "segment_ids": [[-1, -1], 'int64'], + "input_mask": [[-1, -1], 'float32'], + "label_ids": [[-1], 'int64'], + "task_ids": [[-1, -1], 'int64'] } else: - return {"token_ids": [[-1, -1, 1], 'int64'], - "position_ids": [[-1, -1, 1], 'int64'], - "segment_ids": [[-1, -1, 1], 'int64'], - "task_ids": [[-1, -1, 1], 'int64'], - "input_mask": [[-1, -1, 1], 'float32'] + return {"token_ids": [[-1, -1], 'int64'], + "position_ids": [[-1, -1], 'int64'], + "segment_ids": [[-1, -1], 'int64'], + "task_ids": [[-1, -1], 'int64'], + "input_mask": [[-1, -1], 'float32'] } diff --git a/paddlepalm/reader/match.py b/paddlepalm/reader/match.py index 965b830..4520c35 100644 --- a/paddlepalm/reader/match.py +++ b/paddlepalm/reader/match.py @@ -60,19 +60,19 @@ class Reader(reader): @property def outputs_attr(self): if self._is_training: - return {"token_ids": [[-1, -1, 1], 'int64'], - "position_ids": [[-1, -1, 1], 'int64'], - "segment_ids": [[-1, -1, 1], 'int64'], - "input_mask": [[-1, -1, 1], 'float32'], - "label_ids": [[-1,1], 'int64'], - "task_ids": [[-1, -1, 1], 'int64'] + return {"token_ids": [[-1, -1], 'int64'], + "position_ids": [[-1, -1], 'int64'], + "segment_ids": [[-1, -1], 'int64'], + "input_mask": [[-1, -1], 'float32'], + "label_ids": [[-1], 'int64'], + "task_ids": [[-1, -1], 'int64'] } else: - return {"token_ids": [[-1, -1, 1], 'int64'], - "position_ids": [[-1, -1, 1], 'int64'], - "segment_ids": [[-1, -1, 1], 'int64'], - "task_ids": [[-1, -1, 1], 'int64'], - "input_mask": [[-1, -1, 1], 'float32'] + return {"token_ids": [[-1, -1], 'int64'], + "position_ids": [[-1, -1], 'int64'], + "segment_ids": [[-1, -1], 'int64'], + "task_ids": [[-1, -1], 'int64'], + "input_mask": [[-1, -1], 'float32'] } diff --git a/paddlepalm/reader/utils/reader4ernie.py b/paddlepalm/reader/utils/reader4ernie.py index 7048bf6..a57a747 100644 --- a/paddlepalm/reader/utils/reader4ernie.py +++ b/paddlepalm/reader/utils/reader4ernie.py @@ -920,7 +920,7 @@ class MRCReader(BaseReader): batch_unique_ids = [record.unique_id for record in batch_records] batch_unique_ids = np.array(batch_unique_ids).astype("int64").reshape( - [-1, 1]) + [-1]) # padding padded_token_ids, input_mask = pad_batch_data( diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py index 7ba4ba7..09d5c0c 100644 --- a/paddlepalm/utils/reader_helper.py +++ b/paddlepalm/utils/reader_helper.py @@ -218,7 +218,7 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc names = [] start = 0 if insert_taskid: - ret.append(([1,1], 'int64')) + ret.append(([1], 'int64')) names.append('__task_id') start += 1 -- GitLab