base_reader.py 5.6 KB
Newer Older
X
xixiaoyao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
X
Xiaoyao Xi 已提交
15

X
xixiaoyao 已提交
16
from copy import copy
W
wangxiao1021 已提交
17
class Reader(object):
    """Interface of a data reader.

    Concrete readers implement dataset loading and iteration. This base class
    holds the bookkeeping shared by every reader: the running phase, the batch
    size, the number of epochs, and the register of output attributes that a
    connected backbone requires.
    """

    def __init__(self, phase='train'):
        """Construct a Reader; at minimum a `phase` argument is required.

        Note: a subclass implementing its own constructor must call this base
        class constructor so the framework's built-in member variables exist.

        Args:
            phase: str. The running phase in which the backbone is invoked.
                Currently the training phase 'train' and the prediction
                phase 'predict' are supported.
        """
        self._phase = phase
        self._batch_size = None
        self._num_epochs = 1
        # Names of the output attributes this reader must yield; populated by
        # require_attr / register_with and consumed by _get_registed_attrs.
        self._register = set()
        self._registered_backbone = None

    @classmethod
    def create_register(cls):
        """Return a new, empty attribute register."""
        return set()

    def clone(self, phase='train'):
        """Return a shallow copy of this reader that runs in `phase`.

        The copy receives its own attribute register, initialised with the
        attributes registered so far. Requiring attributes on the clone
        therefore does not change this reader's register, and vice versa.
        All other members are shared exactly as with `copy.copy`.

        Args:
            phase: str. Running phase of the new reader.

        Returns:
            A new reader of the same class.
        """
        ret = copy(self)
        ret._phase = phase
        # copy() is shallow: without this, the clone and the original would
        # mutate one and the same register set through require_attr.
        if hasattr(self, '_register'):
            ret._register = copy(self._register)
        return ret

    def require_attr(self, attr_name):
        """Add an object the reader must produce to the register.

        Args:
            attr_name: name of the object to produce, e.g. 'segment_ids'.
        """
        self._register.add(attr_name)

    def register_with(self, backbone):
        """Register every input object that `backbone` depends on.

        Each name obtained by iterating `backbone.inputs_attr` is added to the
        register, and the backbone is remembered for get_registered_backbone.

        Args:
            backbone: the backbone network to connect to. It must expose an
                iterable `inputs_attr` whose iteration yields attribute names.
        """
        for attr in backbone.inputs_attr:
            self.require_attr(attr)
        self._registered_backbone = backbone

    def get_registered_backbone(self):
        """Return the backbone this reader was registered with (None if none)."""
        return self._registered_backbone

    def _get_registed_attrs(self, attrs):
        """Select the registered attributes from `attrs`.

        Args:
            attrs: mapping from attribute name to the produced object.

        Returns:
            dict containing exactly the registered attribute names and their
            values taken from `attrs`.

        Raises:
            NotImplementedError: if a registered attribute is missing from
                `attrs`, i.e. this reader cannot produce it.
        """
        ret = {}
        for name in self._register:
            if name not in attrs:
                raise NotImplementedError(
                    'output attr {} is not found in this reader.'.format(name))
            ret[name] = attrs[name]
        return ret

    def load_data(self, input_file, batch_size, num_epochs=None,
                  file_format='tsv', shuffle_train=True):
        """Load data into the reader.

        Implementations of this method are required to set
        self._batch_size and self._num_epochs.

        Args:
            input_file: the dataset file path. The file format should meet the
                requirement of the `file_format` argument.
            batch_size: number of examples yielded at once. CAUTION! If your
                environment has multiple GPU devices (their number marked as
                dev_count), batch_size must be divisible by dev_count with no
                remainder!
            num_epochs: the number of traversals over the input examples.
                Default is None, meaning once for single-task learning and
                automatically calculated for multi-task learning. This
                argument only works in the train phase.
            file_format: the file format of the input file. Supported
                format: tsv. Default is tsv.
            shuffle_train: whether to shuffle the training dataset. Default is
                True. This argument only works in the train phase.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @property
    def outputs_attr(self):
        """Describe the objects the reader outputs (yields).

        The description covers each object's name, shape and data type. For
        a scalar object (str, int, float, ...) the shape is the empty list
        []. A dimension whose length varies is set to -1 in the shape.

        Note: when a mini-batch gradient-descent strategy is used, regular
        input objects should carry a batch_size dimension (usually -1).

        Returns:
            dict describing the attributes of each output object. For example,
            for text classification and matching tasks the yielded content may
            contain the following objects (downstream backbones and tasks can
            access them as needed):
                {"token_ids": ([-1, max_len], 'int64'),
                 "input_ids": ([-1, max_len], 'int64'),
                 "segment_ids": ([-1, max_len], 'int64'),
                 "input_mask": ([-1, max_len], 'float32'),
                 "label": ([-1], 'int')}

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def _iterator(self):
        """Dataset traversal interface.

        When the traversal reaches the end of the dataset, implementations
        should reset the position automatically, i.e. start a new traversal
        from the head of the dataset.

        Yields:
            dict. The output objects of the current step, conforming to the
            description in outputs_attr.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def get_epoch_outputs(self):
        """Return the output objects produced after each epoch of the dataset.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @property
    def num_examples(self):
        """Number of examples in the dataset.

        This is the number of examples the iterator produces per epoch. When
        a strategy that can change the number of examples is used (e.g. a
        sliding window), implementations should return the actual runtime
        count.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @property
    def num_epochs(self):
        """Number of traversals over the dataset."""
        return self._num_epochs