# dataset.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']


class DatasetFactory(object):
    """
    DatasetFactory is a factory which creates a dataset by its class name.
    You can create "QueueDataset" or "InMemoryDataset";
    the default is "QueueDataset".

    Example:
        import paddle.fluid as fluid
        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    """

    def __init__(self):
        """
        Init. The factory keeps no state of its own.
        """
        pass

    def create_dataset(self, datafeed_class="QueueDataset"):
        """
        Create "QueueDataset" or "InMemoryDataset",
        the default is "QueueDataset".

        Examples:
            import paddle.fluid as fluid
            dataset = fluid.DatasetFactory().create_dataset()

        Args:
            datafeed_class(str): name of a dataset class defined in this
                module, e.g. "QueueDataset" or "InMemoryDataset".

        Returns:
            A new instance of the requested dataset class.

        Raises:
            ValueError: if ``datafeed_class`` does not name a dataset class
                (a subclass of DatasetBase) defined in this module.
        """
        try:
            dataset_class = globals()[datafeed_class]
        except (KeyError, TypeError):
            # KeyError: unknown name; TypeError: unhashable argument.
            dataset_class = None
        # Only dataset classes may be created; a bare globals() lookup would
        # let callers instantiate arbitrary module-level names.
        if not (isinstance(dataset_class, type) and
                issubclass(dataset_class, DatasetBase)):
            raise ValueError("datafeed class %s does not exist" %
                             datafeed_class)
        # Constructed outside any try block so genuine initialization errors
        # propagate instead of being misreported as "does not exist".
        return dataset_class()


class DatasetBase(object):
    """
    Base dataset class.

    It holds a DataFeedDesc protobuf (``proto_desc``) describing how input
    data is fed, and a ``core.Dataset`` (``dataset``) that it configures.
    Subclasses set ``proto_desc.name`` (e.g. "MultiSlotDataFeed").
    """

    def __init__(self):
        """
        Init
        """
        # Data feed description; its name is filled in by subclasses.
        self.proto_desc = data_feed_pb2.DataFeedDesc()
        # Default pipe command: "cat" forwards the input unchanged.
        self.proto_desc.pipe_command = "cat"
        self.dataset = core.Dataset("MultiSlotDataset")
        self.thread_num = 0

    def set_pipe_command(self, pipe_command):
        """
        Set pipe command of current dataset.
        A pipe command is a UNIX pipeline command; the default is "cat".

        Example:
            >>> dataset.set_pipe_command("python my_script.py")

        Args:
            pipe_command: pipe command
        """
        self.proto_desc.pipe_command = pipe_command

    def set_batch_size(self, batch_size):
        """
        Set batch size. Will be effective during training

        Example:
            >>> dataset.set_batch_size(128)

        Args:
            batch_size: batch size
        """
        self.proto_desc.batch_size = batch_size

    def set_thread(self, thread_num):
        """
        Set thread num, it is the num of readers.

        Example:
            >>> dataset.set_thread(12)

        Args:
            thread_num: thread num
        """
        self.dataset.set_thread_num(thread_num)
        self.thread_num = thread_num

    def set_filelist(self, filelist):
        """
        Set file list in current worker.

        Example:
            >>> dataset.set_filelist(['a.txt', 'b.txt'])

        Args:
            filelist: file list
        """
        self.dataset.set_filelist(filelist)

    def set_use_var(self, var_list):
        """
        Set Variables which you will use.

        Each variable becomes one used slot of the data feed description.
        Variables with lod_level == 0 are marked dense. float32 variables
        get slot type "float" and int64 variables get slot type "uint64".

        Example:
            >>> dataset.set_use_var([data, label])

        Args:
            var_list: variable list

        Raises:
            ValueError: if any variable's dtype is neither float32 nor int64.
                All dtypes are checked before any slot is added, so a
                failing call leaves the description unchanged.
        """
        multi_slot = self.proto_desc.multi_slot_desc
        # Validate every dtype up front so an unsupported variable cannot
        # leave a partially configured slot behind in proto_desc.
        typed_vars = [(var, self._slot_type(var)) for var in var_list]
        for var, slot_type in typed_vars:
            slot_var = multi_slot.slots.add()
            slot_var.is_used = True
            slot_var.name = var.name
            if var.lod_level == 0:
                slot_var.is_dense = True
            slot_var.type = slot_type

    @staticmethod
    def _slot_type(var):
        """Map a variable's dtype to its data feed slot type string."""
        if var.dtype == core.VarDesc.VarType.FP32:
            return "float"
        if var.dtype == core.VarDesc.VarType.INT64:
            return "uint64"
        raise ValueError(
            "Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
        )

    def set_hdfs_config(self, fs_name, fs_ugi):
        """
        Set hdfs config: fs name and ugi

        Example:
            >>> dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")

        Args:
            fs_name: fs name
            fs_ugi: fs ugi
        """
        self.dataset.set_hdfs_config(fs_name, fs_ugi)

    def _prepare_to_run(self):
        """
        Pass the current data feed description (see ``desc``) to the
        underlying core dataset. Called internally before data is loaded;
        users do not need to call this function.
        """
        self.dataset.set_data_feed_desc(self.desc())

    def desc(self):
        """
        Return this dataset's DataFeedDesc in protobuf text format.

        Example:
            >>> print(dataset.desc())

        Returns:
            str: the DataFeedDesc message rendered as text.
        """
        return text_format.MessageToString(self.proto_desc)


class InMemoryDataset(DatasetBase):
    """
    InMemoryDataset, it will load data into memory
    and shuffle data before training.
    This class should be created by DatasetFactory

    Example:
        import paddle.fluid as fluid
        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    """

    def __init__(self):
        """
        Init
        """
        super(InMemoryDataset, self).__init__()
        self.proto_desc.name = "MultiSlotInMemoryDataFeed"

    def load_into_memory(self):
        """
        Load data into memory. The current data feed description is
        passed to the core dataset first.

        Example:
            >>> import paddle.fluid as fluid
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
        """
        self._prepare_to_run()
        self.dataset.load_into_memory()

    def local_shuffle(self):
        """
        Local shuffle

        Example:
            >>> import paddle.fluid as fluid
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> dataset.local_shuffle()
        """
        self.dataset.local_shuffle()

    def global_shuffle(self, fleet=None):
        """
        Global shuffle.
        Global shuffle can be used only in distributed mode. i.e. multiple
        processes on single machine or multiple machines training together.
        If you run in distributed mode, you should pass fleet instead of None.

        Examples:
            >>> import paddle.fluid as fluid
            >>> import paddle.fluid.incubate.fleet.parameter_server as fleet
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> dataset.global_shuffle(fleet)

        Args:
            fleet: fleet singleton. Default None, in which case the trainer
                number is 1 and no worker barriers are issued.
        """
        trainer_num = 1
        if fleet is not None:
            self._barrier_worker(fleet)
            trainer_num = fleet.worker_num()
        self.dataset.register_client2client_msg_handler()
        self.dataset.set_trainer_num(trainer_num)
        # Barriers bracket the shuffle so every worker has registered its
        # handler before any worker starts, and finishes before any proceeds.
        self._barrier_worker(fleet)
        self.dataset.global_shuffle()
        self._barrier_worker(fleet)

    @staticmethod
    def _barrier_worker(fleet):
        """Invoke fleet's worker barrier; a no-op when fleet is None."""
        if fleet is not None:
            fleet.fleet_instance.role_maker_._barrier_worker()


class QueueDataset(DatasetBase):
    """
    QueueDataset processes data as a stream; it has no load-into-memory
    step and therefore supports neither local nor global shuffle.

    Example:
        import paddle.fluid as fluid
        dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
    """

    def __init__(self):
        """
        Initialize QueueDataset with the "MultiSlotDataFeed" data feed.
        This class should be created by DatasetFactory.
        """
        super(QueueDataset, self).__init__()
        self.proto_desc.name = "MultiSlotDataFeed"

    @staticmethod
    def _unsupported(kind):
        """Build the error raised by a shuffle method QueueDataset lacks."""
        return NotImplementedError(
            "QueueDataset does not support %s shuffle, "
            "please use InMemoryDataset for %s_shuffle" % (kind, kind))

    def local_shuffle(self):
        """
        Local shuffle is not available for QueueDataset.

        Raises:
            NotImplementedError: always; use InMemoryDataset instead.
        """
        raise self._unsupported("local")

    def global_shuffle(self, fleet=None):
        """
        Global shuffle is not available for QueueDataset.

        Args:
            fleet: ignored; accepted for interface compatibility.

        Raises:
            NotImplementedError: always; use InMemoryDataset instead.
        """
        raise self._unsupported("global")