cls.py 3.5 KB
Newer Older
X
xixiaoyao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
xixiaoyao 已提交
16 17 18 19 20
from paddlepalm.reader.base_reader import BaseReader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader as CLSReader


class ClassifyReader(BaseReader):
    """Text-classification reader built on the ERNIE ``ClassifyReader``.

    Wraps ``paddlepalm.reader.utils.reader4ernie.ClassifyReader`` and exposes
    each batch as a dict keyed by the output names registered on this reader
    (``token_ids`` always; ``label_ids`` additionally in the 'train' phase).
    """

    def __init__(self, vocab_path, max_len, tokenizer='wordpiece',
                 lang='en', seed=None, do_lower_case=False, phase='train'):
        """Create a classification reader.

        Argument:
          - vocab_path: vocabulary file path, forwarded to the ERNIE reader.
          - max_len: maximum sequence length, forwarded as ``max_seq_len``.
          - tokenizer: tokenizer name. NOTE: currently not used by this class;
            accepted only for interface compatibility.
          - lang: 'en'/'english' or 'cn'/'chinese' (case-insensitive). The
            Chinese variants enable ``for_cn`` on the ERNIE reader.
          - seed: forwarded to the ERNIE reader as ``random_seed``.
          - do_lower_case: forwarded to the ERNIE reader.
          - phase: 'train' or 'predict'. Only 'train' registers ``label_ids``.

        Raises AssertionError for an unsupported ``lang`` or ``phase``.
        """
        BaseReader.__init__(self, phase)

        assert lang.lower() in ['en', 'cn', 'english', 'chinese'], "supported language: en (English), cn (Chinese)."
        assert phase in ['train', 'predict'], "supported phase: train, predict."

        for_cn = lang.lower() in ('cn', 'chinese')

        # ``_register`` is provided by BaseReader; it decides which of the
        # entries in ``outputs_attr`` are actually exposed.
        self._register.add('token_ids')
        if phase == 'train':
            self._register.add('label_ids')

        self._is_training = phase == 'train'

        self._reader = CLSReader(vocab_path,
                                 max_seq_len=max_len,
                                 do_lower_case=do_lower_case,
                                 for_cn=for_cn,
                                 random_seed=seed)

        self._phase = phase

    @property
    def outputs_attr(self):
        """Return ``{name: [shape, dtype]}`` for the registered outputs.

        The full spec below is filtered through ``_get_registed_attrs`` so only
        names registered on this reader are returned. A -1 in a shape
        presumably marks a dimension only known at runtime (batch / sequence
        length) — confirm against BaseReader.
        """
        attrs = {"token_ids": [[-1, -1], 'int64'],
                 "position_ids": [[-1, -1], 'int64'],
                 "segment_ids": [[-1, -1], 'int64'],
                 "input_mask": [[-1, -1, 1], 'float32'],
                 "label_ids": [[-1], 'int64'],
                 "task_ids": [[-1, -1], 'int64']
                 }
        return self._get_registed_attrs(attrs)

    def load_data(self, input_file, batch_size, num_epochs=None,
                  file_format='csv', shuffle_train=True):
        """Build the batch generator for ``input_file``.

        Argument:
          - input_file: data file handed to the ERNIE reader's data_generator.
          - batch_size: number of examples per batch.
          - num_epochs: number of passes in the 'train' phase; any other phase
            always makes exactly one pass.
          - file_format: NOTE: currently not used; kept for interface
            compatibility.
          - shuffle_train: shuffle in the 'train' phase; other phases never
            shuffle.
        """
        is_train = self._phase == 'train'
        self._batch_size = batch_size
        self._num_epochs = num_epochs
        self._data_generator = self._reader.data_generator(
            input_file, batch_size, num_epochs if is_train else 1,
            shuffle=shuffle_train if is_train else False,
            phase=self._phase)

    def _iterator(self):
        """Yield one dict per batch containing only the registered outputs.

        The ERNIE reader yields positional fields; they are named in the order
        of ``names`` (``zip`` stops at the shorter of the two, so a batch with
        fewer fields simply lacks the trailing names).
        """
        names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids',
                 'input_mask', 'label_ids', 'unique_ids']
        # The registered outputs do not change while iterating: compute them
        # once instead of rebuilding the attr spec for every batch.
        wanted = list(self.outputs_attr.keys())
        for batch in self._data_generator():
            outputs = dict(zip(names, batch))
            # TODO: move runtime shape check here
            yield {attr: outputs[attr] for attr in wanted}

    def get_epoch_outputs(self):
        """Return the ERNIE reader's examples and features for this phase."""
        return {'examples': self._reader.get_examples(self._phase),
                'features': self._reader.get_features(self._phase)}

    @property
    def num_examples(self):
        """Number of examples the ERNIE reader reports for this phase."""
        return self._reader.get_num_examples(phase=self._phase)

    @property
    def num_epochs(self):
        """The ``num_epochs`` value last passed to ``load_data``.

        Note: outside the 'train' phase ``load_data`` always iterates once,
        regardless of this value.
        """
        return self._num_epochs

X
xixiaoyao 已提交
102