cls.py 3.2 KB
Newer Older
X
xixiaoyao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
xixiaoyao 已提交
16 17 18 19 20
from paddlepalm.reader.base_reader import BaseReader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader as CLSReader


class ClassifyReader(BaseReader):
    """Dataset reader for text classification tasks.

    Thin wrapper around ``paddlepalm.reader.utils.reader4ernie.ClassifyReader``
    (imported here as ``CLSReader``). It yields each batch as a dict keyed by
    the output names registered on this reader.

    Args:
        vocab_path: vocabulary file path, forwarded to the underlying reader.
        max_len: forwarded to the underlying reader as ``max_seq_len``.
        tokenizer: accepted but not used by this class.
            NOTE(review): confirm whether it should reach the underlying reader.
        lang: 'en'/'english' or 'cn'/'chinese' (case-insensitive).
        seed: forwarded to the underlying reader as ``random_seed``.
        do_lower_case: forwarded to the underlying reader.
        phase: 'train' or 'pred'. 'label_ids' is registered as an output
            only in the 'train' phase.
    """

    def __init__(self, vocab_path, max_len, tokenizer='wordpiece',
                 lang='en', seed=None, do_lower_case=False, phase='train'):
        BaseReader.__init__(self, phase)

        lang_key = lang.lower()
        assert lang_key in ['en', 'cn', 'english', 'chinese'], "supported language: en (English), cn (Chinese)."
        assert phase in ['train', 'pred'], "supported phase: train, pred."

        # The underlying reader takes a boolean flag for Chinese input.
        for_cn = lang_key in ('cn', 'chinese')

        self._register.add('token_ids')
        if phase == 'train':
            # Labels are only requested as an output when training.
            self._register.add('label_ids')

        self._is_training = phase == 'train'

        self._reader = CLSReader(vocab_path,
                                 max_seq_len=max_len,
                                 do_lower_case=do_lower_case,
                                 for_cn=for_cn,
                                 random_seed=seed)

        self._phase = phase

    @property
    def outputs_attr(self):
        """Shape/dtype spec for the outputs this reader can produce.

        The full table is passed through ``self._get_registed_attrs``, which
        is defined on BaseReader. That method presumably keeps only the names
        added to ``self._register``; confirm against BaseReader. A ``-1``
        dimension presumably means a variable size (e.g. batch or sequence
        length).
        """
        attrs = {"token_ids": [[-1, -1], 'int64'],
                 "position_ids": [[-1, -1], 'int64'],
                 "segment_ids": [[-1, -1], 'int64'],
                 "input_mask": [[-1, -1, 1], 'float32'],
                 "label_ids": [[-1], 'int64'],
                 "task_ids": [[-1, -1], 'int64']
                 }
        return self._get_registed_attrs(attrs)

    def _load_data(self, input_file, batch_size, num_epochs=None,
                   file_format='csv', shuffle_train=True):
        """Create the underlying batch generator for ``input_file``.

        Shuffling follows ``shuffle_train`` in the 'train' phase and is
        disabled otherwise. ``file_format`` is accepted but currently ignored.
        """
        shuffle = shuffle_train if self._phase == 'train' else False
        self._data_generator = self._reader.data_generator(
            input_file, batch_size, num_epochs,
            shuffle=shuffle, phase=self._phase)

    def _iterator(self):
        """Yield one dict per batch, keyed by the registered output names.

        Each batch from the underlying generator is a positional sequence.
        It is zipped with ``names``, so ``names`` must follow the field order
        the underlying reader emits. NOTE(review): confirm against
        reader4ernie.
        """
        names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids',
                 'input_mask', 'label_ids', 'unique_ids']
        # The registered output set is assumed not to change mid-iteration,
        # so resolve it once rather than once per batch.
        wanted = list(self.outputs_attr.keys())
        for batch in self._data_generator():
            outputs = dict(zip(names, batch))
            # TODO: move runtime shape check here
            # Outputs are keyed by attribute name, so build a dict (not a list).
            yield {attr: outputs[attr] for attr in wanted}

    def get_epoch_outputs(self):
        """Return the underlying reader's examples and features for this phase."""
        return {'examples': self._reader.get_examples(self._phase),
                'features': self._reader.get_features(self._phase)}

    @property
    def num_examples(self):
        """Number of examples for this phase, as reported by the underlying reader."""
        return self._reader.get_num_examples(phase=self._phase)

X
xixiaoyao 已提交
88