#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']


class DatasetFactory(object):
    """
    DatasetFactory is a factory that creates a dataset by its name.
    You can create a "QueueDataset" or an "InMemoryDataset";
    the default is "QueueDataset".

    Example:
        dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
    """

    def __init__(self):
        """
        Init
        """
        pass

    def create_dataset(self, datafeed_class="QueueDataset"):
        """
        Create "QueueDataset" or "InMemoryDataset",
        the default is "QueueDataset".

        Examples:
            import paddle.fluid as fluid
            dataset = fluid.DatasetFactory().create_dataset()
        """
        try:
            dataset = globals()[datafeed_class]()
            return dataset
        except KeyError:
            raise ValueError("datafeed class %s does not exist" %
                             datafeed_class)


class DatasetBase(object):
    """
    Base dataset class
    """

    def __init__(self):
        """
        Init
        """
        # the class name defined here decides whether
        # an in-memory instance needs to be created
        self.proto_desc = data_feed_pb2.DataFeedDesc()
        self.proto_desc.pipe_command = "cat"
        self.dataset = core.Dataset("MultiSlotDataset")
        self.thread_num = 0

    def set_pipe_command(self, pipe_command):
        """
        Set the pipe command of the current dataset.
        A pipe command is a UNIX pipeline command; each input file is piped
        through it before its output is consumed by the data feed.

        Example:
            >>> dataset.set_pipe_command("python my_script.py")

        Args:
            pipe_command: pipe command

        """
        self.proto_desc.pipe_command = pipe_command
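        # Usage sketch (the command below is an illustrative assumption): the
        # pipeline reads each raw input file from stdin and must write records
        # in the format expected by the configured data feed to stdout, e.g.
        #     dataset.set_pipe_command("zcat | python my_parser.py")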

    def set_batch_size(self, batch_size):
        """
        Set the batch size. It will take effect during training.

        Example:
            >>> dataset.set_batch_size(128)

        Args:
            batch_size: batch size

        """
        self.proto_desc.batch_size = batch_size

    def set_thread(self, thread_num):
        """
        Set the thread number, i.e. the number of data readers.

        Example:
            >>> dataset.set_thread(12)

        Args:
            thread_num: thread num
        """
        self.dataset.set_thread_num(thread_num)
        self.thread_num = thread_num

    def set_filelist(self, filelist):
        """
        Set the file list for the current worker.

        Example:
            >>> dataset.set_filelist(['a.txt', 'b.txt'])

        Args:
            filelist: file list
        """
        self.dataset.set_filelist(filelist)

    def set_use_var(self, var_list):
        """
        Set the Variables that will be used.

        Example:
            >>> dataset.set_use_var([data, label])

        Args:
            var_list: variable list
        """
        multi_slot = self.proto_desc.multi_slot_desc
        for var in var_list:
            slot_var = multi_slot.slots.add()
            slot_var.is_used = True
            slot_var.name = var.name
            if var.lod_level == 0:
                slot_var.is_dense = True
                slot_var.shape.extend(var.shape)
            if var.dtype == core.VarDesc.VarType.FP32:
                slot_var.type = "float"
            elif var.dtype == core.VarDesc.VarType.INT64:
                slot_var.type = "uint64"
            else:
                raise ValueError(
                    "Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
                )
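        # Usage sketch (slot names and shapes are illustrative assumptions):
        #     data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
        #     label = fluid.layers.data(name="label", shape=[1], dtype="int64")
        #     dataset.set_use_var([data, label])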

    def set_hdfs_config(self, fs_name, fs_ugi):
        """
        Set the HDFS config: fs name and ugi.

        Example:
            >>> dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")

        Args:
            fs_name: fs name
            fs_ugi: fs ugi
        """
        self.dataset.set_hdfs_config(fs_name, fs_ugi)

    def _prepare_to_run(self):
        """
        Set data_feed_desc before load or shuffle;
        users do not need to call this function.
        """
        self.dataset.set_data_feed_desc(self.desc())

    def desc(self):
        """
        Returns a protobuf message for this DataFeedDesc

        Example:
            >>> print(dataset.desc())

        Returns:
            A string message
        """
        return text_format.MessageToString(self.proto_desc)


class InMemoryDataset(DatasetBase):
    """
    InMemoryDataset loads data into memory
    and shuffles the data before training.
    This class should be created by DatasetFactory.

    Example:
        dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
    """

    def __init__(self):
        """
        Init
        """
        super(InMemoryDataset, self).__init__()
        self.proto_desc.name = "MultiSlotInMemoryDataFeed"

    def load_into_memory(self):
        """
        Load data into memory

        Example:
            >>> import paddle.fluid as fluid
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
        """
        self._prepare_to_run()
        self.dataset.load_into_memory()

    def local_shuffle(self):
        """
        Local shuffle

        Example:
            >>> import paddle.fluid as fluid
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> dataset.local_shuffle()
        """
        self.dataset.local_shuffle()

    def global_shuffle(self, fleet=None):
        """
        Global shuffle.
        Global shuffle can only be used in distributed mode, i.e. when multiple
        processes on a single machine or multiple machines train together.
        If you run in distributed mode, you should pass fleet instead of None.

        Examples:
            >>> import paddle.fluid as fluid
            >>> from paddle.fluid.incubate.fleet.pslib import fleet
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> dataset.global_shuffle(fleet)

        Args:
            fleet: fleet singleton. Default None.
        """
        trainer_num = 1
        fleet_send_batch_size = 80000
        if fleet is not None:
            fleet.fleet_instance.role_maker_._barrier_worker()
            trainer_num = fleet.worker_num()
        self.dataset.register_client2client_msg_handler()
        self.dataset.set_trainer_num(trainer_num)
        self.dataset.set_fleet_send_batch_size(fleet_send_batch_size)
        if fleet is not None:
            fleet.fleet_instance.role_maker_._barrier_worker()
        self.dataset.global_shuffle()
        if fleet is not None:
            fleet.fleet_instance.role_maker_._barrier_worker()

    def release_memory(self):
        """
        Release InMemoryDataset memory data when the data will not be used again.

        Example:
            >>> import paddle.fluid as fluid
            >>> import paddle.fluid.incubate.fleet.parameter_server as fleet
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> dataset.global_shuffle(fleet)
            >>> exe = fluid.Executor(fluid.CPUPlace())
            >>> exe.run(fluid.default_startup_program())
            >>> exe.train_from_dataset(fluid.default_main_program(), dataset)
            >>> dataset.release_memory()
        """
        self.dataset.release_memory()


class QueueDataset(DatasetBase):
    """
    QueueDataset processes data in a streaming fashion.

    Example:
        import paddle.fluid as fluid
        dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
    """

    def __init__(self):
        """
        Initialize QueueDataset.
        This class should be created by DatasetFactory.
        """
        super(QueueDataset, self).__init__()
        self.proto_desc.name = "MultiSlotDataFeed"

    def local_shuffle(self):
        """
        Local shuffle

        Local shuffle is not supported in QueueDataset;
        a NotImplementedError will be raised.
        """
        raise NotImplementedError(
            "QueueDataset does not support local shuffle, "
            "please use InMemoryDataset for local_shuffle")

    def global_shuffle(self, fleet=None):
        """
        Global shuffle is not supported in QueueDataset;
        a NotImplementedError will be raised.
        """
        raise NotImplementedError(
            "QueueDataset does not support global shuffle, "
            "please use InMemoryDataset for global_shuffle")