# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddlepalm.reader.base_reader import Reader
from paddlepalm.reader.utils.reader4ernie import MaskLMReader as MLMReader
import numpy as np

class MaskLMReader(Reader):
    """Adapter around ``reader4ernie.MaskLMReader`` for the PALM ``Reader`` API.

    Builds the underlying reader in ``__init__``, creates its batch generator
    in ``load_data`` and yields each batch from ``_iterator`` as a dict keyed
    by the attribute names this reader has registered.
    """

    def __init__(self, vocab_path, max_len, tokenizer='wordpiece',
                 lang='en', seed=None, do_lower_case=False, phase='train',
                 dev_count=1, print_prefix=''):
        """Create the reader.

        Args:
            vocab_path: vocabulary file, forwarded to the underlying reader.
            max_len: forwarded to the underlying reader as ``max_seq_len``.
            tokenizer: currently unused by this class.
            lang: 'en', 'english', 'cn' or 'chinese' (case-insensitive);
                'cn'/'chinese' sets ``for_cn=True`` on the underlying reader.
            seed: forwarded to the underlying reader as ``random_seed``.
            do_lower_case: forwarded to the underlying reader.
            phase: 'train' or 'predict'. 'mask_label' is registered as an
                output only in the 'train' phase.
            dev_count: stored as ``self._dev_count``.
            print_prefix: currently unused by this class.

        Raises:
            AssertionError: if ``lang`` or ``phase`` is not supported.
        """
        Reader.__init__(self, phase)

        lang_key = lang.lower()
        assert lang_key in ['en', 'cn', 'english', 'chinese'], "supported language: en (English), cn (Chinese)."
        assert phase in ['train', 'predict'], "supported phase: train, predict."

        for_cn = lang_key in ('cn', 'chinese')

        # 'mask_pos' is always produced; the labels only exist for training.
        self._register.add('mask_pos')
        if phase == 'train':
            self._register.add('mask_label')
        self._is_training = phase == 'train'

        self._reader = MLMReader(vocab_path,
                                 max_seq_len=max_len,
                                 do_lower_case=do_lower_case,
                                 for_cn=for_cn,
                                 random_seed=seed)

        self._phase = phase
        self._dev_count = dev_count

    @property
    def outputs_attr(self):
        """Return ``{name: [shape, dtype]}`` for the registered outputs.

        The full candidate table is filtered through the base class's
        ``_get_registed_attrs`` so only attributes registered in ``__init__``
        are returned.
        """
        attrs = {"token_ids": [[-1, -1], 'int64'],
                 "position_ids": [[-1, -1], 'int64'],
                 "segment_ids": [[-1, -1], 'int64'],
                 "input_mask": [[-1, -1, 1], 'float32'],
                 "task_ids": [[-1, -1], 'int64'],
                 "mask_label": [[-1], 'int64'],
                 "mask_pos": [[-1], 'int64']
                 }

        return self._get_registed_attrs(attrs)

    def load_data(self, input_file, batch_size, num_epochs=None,
                  file_format='csv', shuffle_train=True):
        """Create the batch generator for ``input_file``.

        Args:
            input_file: passed to the underlying reader's ``data_generator``.
            batch_size: batch size; also stored as ``self._batch_size``.
            num_epochs: epochs to run in the 'train' phase. Stored unchanged
                as ``self._num_epochs``; any other phase runs one epoch.
            file_format: currently unused by this method.
            shuffle_train: shuffle in the 'train' phase; other phases never
                shuffle.
        """
        self._batch_size = batch_size
        self._num_epochs = num_epochs
        is_train = self._phase == 'train'
        self._data_generator = self._reader.data_generator(
            input_file, batch_size, num_epochs if is_train else 1,
            shuffle=shuffle_train if is_train else False,
            phase=self._phase)

    def _iterator(self):
        """Yield each batch as a dict restricted to the registered outputs.

        Requires ``load_data`` to have been called first (it sets
        ``self._data_generator``).
        """
        # Positional order of the fields in each batch from the underlying
        # generator.
        names = ['token_ids', 'position_ids', 'segment_ids', 'input_mask',
                 'task_ids', 'mask_label', 'mask_pos']
        # ``outputs_attr`` rebuilds its table on every access and does not
        # depend on the batch, so resolve the wanted keys once per pass.
        wanted = list(self.outputs_attr.keys())
        for batch in self._data_generator():
            outputs = dict(zip(names, batch))
            # TODO: move runtime shape check here
            yield {attr: outputs[attr] for attr in wanted}

    def get_epoch_outputs(self):
        """Return the underlying reader's examples and features for this phase."""
        return {'examples': self._reader.get_examples(self._phase),
                'features': self._reader.get_features(self._phase)}

    @property
    def num_examples(self):
        """Number of examples reported by the underlying reader for this phase."""
        return self._reader.get_num_examples(phase=self._phase)

    @property
    def num_epochs(self):
        """The ``num_epochs`` value last passed to ``load_data`` (may be None)."""
        return self._num_epochs