mrc.py 4.6 KB
Newer Older
X
xixiaoyao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MRCReader
X
xixiaoyao 已提交
18
import numpy as np
X
xixiaoyao 已提交
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33

class Reader(reader):
    """Machine reading comprehension (MRC) reader.

    Wraps ``MRCReader`` (from ``paddlepalm.reader.utils.reader4ernie``) and
    converts each batch it produces into a dict keyed by the names declared
    in :attr:`outputs_attr`.
    """

    def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
        """Build the reader from a config dict.

        Args:
            config: dict of reader options. Always read: ``vocab_path``,
                ``max_seq_len``, ``doc_stride``, ``max_query_len``,
                ``batch_size``; plus ``train_file`` / ``dev_file`` /
                ``pred_file`` depending on ``phase``. Optional (with
                defaults): ``do_lower_case`` (False), ``for_cn`` (False),
                ``remove_noanswer`` (True), ``seed`` (None), ``shuffle``
                (True, train only), ``shuffle_buffer`` (5000, train only),
                ``pred_batch_size`` (falls back to ``batch_size``, eval/pred
                only), ``print_first_n`` (1), ``with_slide_window`` (False).
            phase: one of 'train', 'eval', 'pred'.
            dev_count: number of devices; forwarded to the data generator.
            print_prefix: not used by this class; kept for interface
                compatibility.
        """
        self._is_training = phase == 'train'

        # Named ``mrc_reader`` so it does not shadow the imported ``reader``
        # base class.
        mrc_reader = MRCReader(config['vocab_path'],
            max_seq_len=config['max_seq_len'],
            do_lower_case=config.get('do_lower_case', False),
            tokenizer='FullTokenizer',
            for_cn=config.get('for_cn', False),
            doc_stride=config['doc_stride'],
            remove_noanswer=config.get('remove_noanswer', True),
            max_query_length=config['max_query_len'],
            random_seed=config.get('seed', None))
        self._reader = mrc_reader
        self._dev_count = dev_count

        self._batch_size = config['batch_size']
        self._max_seq_len = config['max_seq_len']
        if phase == 'train':
            self._input_file = config['train_file']
            # num_epochs is deliberately None (instead of
            # config['num_epochs']) to keep the training iterator from
            # terminating.
            self._num_epochs = None
            self._shuffle = config.get('shuffle', True)
            self._shuffle_buffer = config.get('shuffle_buffer', 5000)
        elif phase == 'eval':
            self._input_file = config['dev_file']
            self._num_epochs = 1
            self._shuffle = False
            self._batch_size = config.get('pred_batch_size', self._batch_size)
        elif phase == 'pred':
            self._input_file = config['pred_file']
            self._num_epochs = 1
            self._shuffle = False
            self._batch_size = config.get('pred_batch_size', self._batch_size)
        # NOTE(review): any other phase leaves _input_file / _num_epochs /
        # _shuffle unset, so load_data() would fail with AttributeError.

        self._phase = phase
        self._print_first_n = config.get('print_first_n', 1)

        # TODO: without slide window version
        self._with_slide_window = config.get('with_slide_window', False)

    @property
    def outputs_attr(self):
        """Map each output name to ``[shape, dtype]`` (-1 = variable dim).

        Training outputs include answer start/end positions; eval/pred
        outputs include ``unique_ids`` instead.
        """
        if self._is_training:
            return {"token_ids": [[-1, -1], 'int64'],
                    "position_ids": [[-1, -1], 'int64'],
                    "segment_ids": [[-1, -1], 'int64'],
                    "input_mask": [[-1, -1, 1], 'float32'],
                    "start_positions": [[-1], 'int64'],
                    "end_positions": [[-1], 'int64'],
                    "task_ids": [[-1, -1], 'int64']
                    }
        else:
            return {"token_ids": [[-1, -1], 'int64'],
                    "position_ids": [[-1, -1], 'int64'],
                    "segment_ids": [[-1, -1], 'int64'],
                    "task_ids": [[-1, -1], 'int64'],
                    "input_mask": [[-1, -1, 1], 'float32'],
                    "unique_ids": [[-1], 'int64']
                    }

    @property
    def epoch_outputs_attr(self):
        """Names of per-epoch outputs; None during training."""
        if self._is_training:
            return None
        return {"examples": None,
                "features": None}

    def load_data(self):
        """Create the underlying batch generator for the configured phase."""
        self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)

    def iterator(self):
        """Yield each batch as a dict of named fields.

        ``load_data`` must be called first. Training batches drop
        ``unique_ids``; eval/pred batches drop ``start_positions`` and
        ``end_positions``.
        """
        # Assumed order of the arrays in each batch produced by MRCReader;
        # TODO confirm against reader4ernie if that reader changes.
        names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids',
                 'input_mask', 'start_positions', 'end_positions',
                 'unique_ids']

        def list_to_dict(x):
            outputs = dict(zip(names, x))
            if self._is_training:
                del outputs['unique_ids']
            else:
                del outputs['start_positions']
                del outputs['end_positions']
            return outputs

        # Convert each batch once; the former debug print converted it twice
        # and wrote the field count to stdout on every batch.
        for batch in self._data_generator():
            yield list_to_dict(batch)

    def get_epoch_outputs(self):
        """Return the examples and features the reader kept for this phase."""
        return {'examples': self._reader.get_examples(self._phase),
                'features': self._reader.get_features(self._phase)}

    @property
    def num_examples(self):
        """Number of examples the underlying reader reports for this phase."""
        return self._reader.get_num_examples(phase=self._phase)