diff --git a/paddlepalm/reader/mrc.py b/paddlepalm/reader/mrc.py index 6cac89adca1b8244d271cab4605b6a834a7faa37..ea4e72678bf02b7d0d4fd625c7986f63c2d3f177 100644 --- a/paddlepalm/reader/mrc.py +++ b/paddlepalm/reader/mrc.py @@ -68,21 +68,21 @@ class Reader(reader): @property def outputs_attr(self): if self._is_training: - return {"token_ids": [[-1, -1, 1], 'int64'], - "position_ids": [[-1, -1, 1], 'int64'], - "segment_ids": [[-1, -1, 1], 'int64'], - "input_mask": [[-1, -1, 1], 'float32'], - "start_positions": [[-1, 1], 'int64'], - "end_positions": [[-1, 1], 'int64'], - "task_ids": [[-1, -1, 1], 'int64'] + return {"token_ids": [[-1, -1], 'int64'], + "position_ids": [[-1, -1], 'int64'], + "segment_ids": [[-1, -1], 'int64'], + "input_mask": [[-1, -1], 'float32'], + "start_positions": [[-1], 'int64'], + "end_positions": [[-1], 'int64'], + "task_ids": [[-1, -1], 'int64'] } else: - return {"token_ids": [[-1, -1, 1], 'int64'], - "position_ids": [[-1, -1, 1], 'int64'], - "segment_ids": [[-1, -1, 1], 'int64'], - "task_ids": [[-1, -1, 1], 'int64'], - "input_mask": [[-1, -1, 1], 'float32'], - "unique_ids": [[-1, 1], 'int64'] + return {"token_ids": [[-1, -1], 'int64'], + "position_ids": [[-1, -1], 'int64'], + "segment_ids": [[-1, -1], 'int64'], + "task_ids": [[-1, -1], 'int64'], + "input_mask": [[-1, -1], 'float32'], + "unique_ids": [[-1], 'int64'] } @property