#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is definition of dataset class, which is high performance IO."""

from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
import paddle.fluid.core as core

__all__ = []

class DatasetBase:
    """Base dataset class.

    Keeps a ``DataFeedDesc`` protobuf in ``self.proto_desc`` and a
    ``core.Dataset`` handle in ``self.dataset``; the setters below write the
    user's settings into one or the other.
    """

    def __init__(self):
        """Init."""
        # define class name here
        # to decide whether we need create in memory instance
        self.proto_desc = data_feed_pb2.DataFeedDesc()
        self.proto_desc.pipe_command = "cat"
        self.dataset = core.Dataset("MultiSlotDataset")
        self.thread_num = 1
        self.filelist = []
        self.use_ps_gpu = False
        self.psgpu = None

    def init(
        self,
        batch_size=1,
        thread_num=1,
        use_var=None,
        pipe_command="cat",
        input_type=0,
        fs_name="",
        fs_ugi="",
        download_cmd="cat",
    ):
        """
        Should be called only once in user's python scripts to initialize settings of the dataset instance.
        Normally, it is called by InMemoryDataset or QueueDataset.

        Args:
            batch_size(int): batch size. It will be effective during training. default is 1.
            thread_num(int): thread num, it is the num of readers. default is 1.
            use_var(list|None): list of variables. Variables which you will use. default is None, which is treated as [].
            pipe_command(str): pipe command of current dataset. Only a UNIX pipeline command can be used. default is "cat"
            input_type(int): the input type of generated input. 0 is for one sample, 1 is for one batch. default is 0.
            fs_name(str): fs name. default is "".
            fs_ugi(str): fs ugi. default is "".
            download_cmd(str): customized download command. default is "cat"
        """
        # ``None`` replaces a shared mutable ``[]`` default.
        if use_var is None:
            use_var = []
        self._set_batch_size(batch_size)
        self._set_thread(thread_num)
        self._set_use_var(use_var)
        self._set_pipe_command(pipe_command)
        self._set_input_type(input_type)
        self._set_hdfs_config(fs_name, fs_ugi)
        self._set_download_cmd(download_cmd)

    def _set_pipe_command(self, pipe_command):
        """
        Set pipe command of current dataset.
        Only a UNIX pipeline command can be used.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.dataset.DatasetBase()
              dataset._set_pipe_command("python my_script.py")

        Args:
            pipe_command(str): pipe command, stored in ``proto_desc.pipe_command``.

        """
        self.proto_desc.pipe_command = pipe_command

    def _set_batch_size(self, batch_size):
        """
        Set batch size. Will be effective during training.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.DatasetBase()
              dataset._set_batch_size(128)

        Args:
            batch_size(int): batch size, stored in ``proto_desc.batch_size``.

        """
        self.proto_desc.batch_size = batch_size

    def _set_thread(self, thread_num):
        """
        Set thread num, it is the num of readers.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.DatasetBase()
              dataset._set_thread(12)

        Args:
            thread_num(int): thread num. It is passed to the core dataset and
                also remembered in ``self.thread_num``.
        """
        self.dataset.set_thread_num(thread_num)
        self.thread_num = thread_num

    def set_filelist(self, filelist):
        """
        Set file list in current worker. The filelist is indicated by a list of file names (string).

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.DatasetBase()
              dataset.set_filelist(['a.txt', 'b.txt'])

        Args:
            filelist(list[str]): list of file names of inputs.
        """
        self.dataset.set_filelist(filelist)
        self.filelist = filelist

    def _set_input_type(self, input_type):
        """
        Set the input type of generated input, stored in ``proto_desc.input_type``.

        Args:
            input_type(int): 0 is for one sample, 1 is for one batch.
        """
        self.proto_desc.input_type = input_type

    def _set_uid_slot(self, uid_slot):
        """
        Set user slot name.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.DatasetBase()
              dataset._set_uid_slot('6048')

        Args:
            uid_slot(str): user slot name, stored in
                ``proto_desc.multi_slot_desc.uid_slot``.
        """
        multi_slot = self.proto_desc.multi_slot_desc
        multi_slot.uid_slot = uid_slot

    def _set_use_var(self, var_list):
        """
        Set Variables which you will use.

        Each variable becomes one used slot in ``proto_desc.multi_slot_desc``.
        Variables with ``lod_level == 0`` are marked dense and carry their
        shape. Only float32 and int64 variables are accepted.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.DatasetBase()
              dataset._set_use_var([data, label])

        Args:
            var_list(list): variable list

        Raises:
            ValueError: if a variable's dtype is neither float32 nor int64.
        """
        multi_slot = self.proto_desc.multi_slot_desc
        # NOTE(review): slots are appended and never cleared, so calling this
        # twice (e.g. init followed by update_settings(use_var=...)) adds the
        # variables again -- confirm whether that is intended.
        for var in var_list:
            slot_var = multi_slot.slots.add()
            slot_var.is_used = True
            slot_var.name = var.name
            if var.lod_level == 0:
                slot_var.is_dense = True
                slot_var.shape.extend(var.shape)
            if var.dtype == core.VarDesc.VarType.FP32:
                slot_var.type = "float"
            elif var.dtype == core.VarDesc.VarType.INT64:
                slot_var.type = "uint64"
            else:
                raise ValueError(
                    "Currently, paddle.distributed.fleet.dataset only supports dtype=float32 and dtype=int64"
                )

    def _set_hdfs_config(self, fs_name, fs_ugi):
        """
        Set hdfs config: fs name and ugi.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.DatasetBase()
              dataset._set_hdfs_config("my_fs_name", "my_fs_ugi")

        Args:
            fs_name(str): fs name
            fs_ugi(str): fs ugi
        """
        self.dataset.set_hdfs_config(fs_name, fs_ugi)

    def _set_download_cmd(self, download_cmd):
        """
        Set customized download cmd: download_cmd

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.DatasetBase()
              dataset._set_download_cmd("./read_from_afs")

        Args:
            download_cmd(str): customized download command
        """
        self.dataset.set_download_cmd(download_cmd)

    def _prepare_to_run(self):
        """
        Set data_feed_desc before load or shuffle,
        user no need to call this function.

        Caps the thread num at the number of files, then pushes the thread
        num and the serialized data feed desc to the core dataset and creates
        its readers.
        """
        # NOTE(review): with an empty filelist this sets thread_num to 0;
        # confirm that the core dataset accepts that.
        if self.thread_num > len(self.filelist):
            self.thread_num = len(self.filelist)
        self.dataset.set_thread_num(self.thread_num)
        self.dataset.set_data_feed_desc(self._desc())
        self.dataset.create_readers()

    def _set_use_ps_gpu(self, use_ps_gpu):
        """
        Set use_ps_gpu flag.

        The flag is forced to False when paddle is not compiled with heterps.
        Otherwise, a truthy flag creates ``self.psgpu``.

        Args:
            use_ps_gpu: bool
        """
        self.use_ps_gpu = use_ps_gpu
        # if not defined heterps with paddle, users will not use psgpu
        if not core._is_compiled_with_heterps():
            self.use_ps_gpu = False
        elif self.use_ps_gpu:
            self.psgpu = core.PSGPU()

    def _finish_to_run(self):
        """Destroy the readers created by ``_prepare_to_run``."""
        self.dataset.destroy_readers()

    def _desc(self):
        """
        Returns a protobuf message for this DataFeedDesc

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.DatasetBase()
              print(dataset._desc())

        Returns:
            str: the text-format serialization of ``self.proto_desc``.
        """
        return text_format.MessageToString(self.proto_desc)

    def _dynamic_adjust_before_train(self, thread_num):
        """Hook for subclasses; the base dataset does nothing."""
        pass

    def _dynamic_adjust_after_train(self):
        """Hook for subclasses; the base dataset does nothing."""
        pass

    def _check_use_var_with_data_generator(
        self, var_list, data_generator_class, test_file
    ):
        """
        Var consistency inspection of use_var_list and data_generator data.

        Every line of ``test_file`` is fed to
        ``data_generator_class.generate_sample``. Each parsed sample is
        checked against ``var_list`` for length and element type.

        Examples:
            .. code-block:: python

              # required: skiptest
              import paddle
              from dataset_generator import CTRDataset
              dataset = paddle.distributed.fleet.DatasetBase()
              generator_class = CTRDataset()
              dataset._check_use_var_with_data_generator([data, label], generator_class, "data/part-00000")

        Args:
            var_list(list): variable list
            data_generator_class(class): data_generator class
            test_file(str): local test file path

        Raises:
            ValueError: if a sample's length differs from ``len(var_list)``,
                or a var has no values.
            TypeError: if a float32 var gets non-float values, or an
                int64/int32 var gets non-int values.
        """
        var_len = len(var_list)

        # ``with`` closes the file even when one of the checks below raises.
        with open(test_file, "r") as f:
            for line in f:
                line_iter = data_generator_class.generate_sample(line)
                for user_parsed_line in line_iter():
                    data_gen_len = len(user_parsed_line)
                    if var_len != data_gen_len:
                        raise ValueError(
                            "var length mismatch error: var_list = %s vs data_generator = %s"
                            % (var_len, data_gen_len)
                        )

                    # each parsed element is a (var name, values) pair
                    for i, ele in enumerate(user_parsed_line):
                        if len(ele[1]) == 0:
                            raise ValueError(
                                "var length error: var %s's length in data_generator is 0"
                                % ele[0]
                            )

                        var_dtype = var_list[i].dtype
                        if var_dtype == core.VarDesc.VarType.FP32 and not all(
                            isinstance(value, float) for value in ele[1]
                        ):
                            raise TypeError(
                                "var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-float value, which is %s \n"
                                "Please check if order of var_list and data_generator are aligned. \n"
                                "Please check if var's type in data_generator is correct."
                                % (ele[0], "float", ele[1])
                            )

                        if (
                            var_dtype == core.VarDesc.VarType.INT64
                            or var_dtype == core.VarDesc.VarType.INT32
                        ) and not all(isinstance(value, int) for value in ele[1]):
                            raise TypeError(
                                "var dtype mismatch error: var name = %s, var type in var_list = %s, while var in data_generator contains non-int value, which is %s \n"
                                "Please check if order of var_list and data_generator are aligned. \n"
                                "Please check if var's type in data_generator is correct."
                                % (ele[0], "int", ele[1])
                            )


class InMemoryDataset(DatasetBase):
    """
    :api_attr: Static Graph

    It will load data into memory and shuffle data before training.

    Examples:
        .. code-block:: python

            import paddle
            paddle.enable_static()
            dataset = paddle.distributed.InMemoryDataset()

    """

    def __init__(self):
        """Init.

        Switches the data feed to ``MultiSlotInMemoryDataFeed`` and sets the
        in-memory/distributed options to their defaults. They are changed
        later through ``init``, ``update_settings`` or
        ``_init_distributed_settings``.
        """
        super().__init__()
        self.proto_desc.name = "MultiSlotInMemoryDataFeed"
        self.fleet_send_batch_size = None
        # becomes True once the user sets queue_num explicitly
        self.is_user_set_queue_num = False
        # None means "same as thread_num", resolved in _prepare_to_run
        self.queue_num = None
        self.parse_ins_id = False
        self.parse_content = False
        self.parse_logkey = False
        self.merge_by_sid = True
        self.enable_pv_merge = False
        self.merge_by_lineid = False
        self.fleet_send_sleep_seconds = None

    def _init_distributed_settings(self, **kwargs):
        """
381 382
        :api_attr: Static Graph

383 384 385 386
        should be called only once in user's python scripts to initialize distributed-related setings of dataset instance
        Args:
            kwargs: Keyword arguments. Currently, we support following keys in **kwargs:

387 388
            merge_size(int): ins size to merge, if merge_size > 0, set merge by line id,
                             instances of same line id will be merged after shuffle,
389 390 391 392 393 394 395 396 397 398 399 400 401
                             you should parse line id in data generator. default is -1.
            parse_ins_id(bool): Set if Dataset need to parse ins_id. default is False.
            parse_content(bool): Set if Dataset need to parse content. default is False.
            fleet_send_batch_size(int): Set fleet send batch size in one rpc, default is 1024
            fleet_send_sleep_seconds(int): Set fleet send sleep time, default is 0
            fea_eval(bool): Set if Dataset need to do feature importance evaluation using slots shuffle.
                            default is False.
            candidate_size(int): if fea_eval is set True, set the candidate size used in slots shuffle.

        Examples:
            .. code-block:: python

              import paddle
S
ShenLiang 已提交
402
              paddle.enable_static()
403 404 405 406 407 408 409 410 411 412 413 414
              dataset = paddle.distributed.InMemoryDataset()
              dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=[])
              dataset._init_distributed_settings(
                    parse_ins_id=True,
                    parse_content=True,
                    fea_eval=True,
                    candidate_size=10000)
415

416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441
        """
        merge_size = kwargs.get("merge_size", -1)
        if merge_size > 0:
            self._set_merge_by_lineid(merge_size)

        parse_ins_id = kwargs.get("parse_ins_id", False)
        self._set_parse_ins_id(parse_ins_id)

        parse_content = kwargs.get("parse_content", False)
        self._set_parse_content(parse_content)

        fleet_send_batch_size = kwargs.get("fleet_send_batch_size", None)
        if fleet_send_batch_size:
            self._set_fleet_send_batch_size(fleet_send_batch_size)

        fleet_send_sleep_seconds = kwargs.get("fleet_send_sleep_seconds", None)
        if fleet_send_sleep_seconds:
            self._set_fleet_send_sleep_seconds(fleet_send_sleep_seconds)

        fea_eval = kwargs.get("fea_eval", False)
        if fea_eval:
            candidate_size = kwargs.get("candidate_size", 10000)
            self._set_fea_eval(candidate_size, True)

    def update_settings(self, **kwargs):
        """
442 443
        :api_attr: Static Graph

S
ShenLiang 已提交
444 445
        should be called in user's python scripts to update setings of dataset instance.

446 447 448 449 450 451
        Args:
            kwargs: Keyword arguments. Currently, we support following keys in **kwargs,
                    including single node settings and advanced distributed related settings:
            batch_size(int): batch size. It will be effective during training. default is 1.
            thread_num(int): thread num, it is the num of readers. default is 1.
            use_var(list): list of variables. Variables which you will use. default is [].
452
            input_type(int): the input type of generated input. 0 is for one sample, 1 is for one batch. default is 0.
453 454 455 456 457 458 459
            fs_name(str): fs name. default is "".
            fs_ugi(str): fs ugi. default is "".
            pipe_command(str): pipe command of current dataset. A pipe command is a UNIX pipeline command that can be used only. default is "cat"
            download_cmd(str): customized download command. default is "cat"
            data_feed_type(str): data feed type used in c++ code. default is "MultiSlotInMemoryDataFeed".
            queue_num(int): Dataset output queue num, training threads get data from queues. default is-1, which is set same as thread number in c++.

460 461
            merge_size(int): ins size to merge, if merge_size > 0, set merge by line id,
                             instances of same line id will be merged after shuffle,
462 463 464 465 466 467 468 469 470 471 472 473
                             you should parse line id in data generator. default is -1.
            parse_ins_id(bool): Set if Dataset need to parse ins_id. default is False.
            parse_content(bool): Set if Dataset need to parse content. default is False.
            fleet_send_batch_size(int): Set fleet send batch size in one rpc, default is 1024
            fleet_send_sleep_seconds(int): Set fleet send sleep time, default is 0
            fea_eval(bool): Set if Dataset need to do feature importance evaluation using slots shuffle.
                            default is False.
            candidate_size(int): if fea_eval is set True, set the candidate size used in slots shuffle.

        Examples:
            .. code-block:: python

474
                import paddle
S
ShenLiang 已提交
475 476 477 478
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                dataset.init(
479 480 481 482 483
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=[])
S
ShenLiang 已提交
484
                dataset._init_distributed_settings(
485 486 487 488
                    parse_ins_id=True,
                    parse_content=True,
                    fea_eval=True,
                    candidate_size=10000)
S
ShenLiang 已提交
489
                dataset.update_settings(batch_size=2)
490

491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516
        """
        for key in kwargs:
            if key == "pipe_command":
                self._set_pipe_command(kwargs[key])
            elif key == "batch_size":
                self._set_batch_size(kwargs[key])
            elif key == "thread_num":
                self._set_thread(kwargs[key])
            elif key == "use_var":
                self._set_use_var(kwargs[key])
            elif key == "input_type":
                self._set_input_type(kwargs[key])
            elif key == "fs_name" and "fs_ugi" in kwargs:
                self._set_hdfs_config(kwargs[key], kwargs["fs_ugi"])
            elif key == "download_cmd":
                self._set_download_cmd(kwargs[key])
            elif key == "merge_size" and kwargs.get("merge_size", -1) > 0:
                self._set_merge_by_lineid(kwargs[key])
            elif key == "parse_ins_id":
                self._set_parse_ins_id(kwargs[key])
            elif key == "parse_content":
                self._set_parse_content(kwargs[key])
            elif key == "fleet_send_batch_size":
                self._set_fleet_send_batch_size(kwargs[key])
            elif key == "fleet_send_sleep_seconds":
                self._set_fleet_send_sleep_seconds(kwargs[key])
517
            elif key == "fea_eval" and kwargs[key]:
518 519 520 521 522
                candidate_size = kwargs.get("candidate_size", 10000)
                self._set_fea_eval(candidate_size, True)

    def init(self, **kwargs):
        """
        :api_attr: Static Graph

        Should be called only once in user's python scripts to initialize settings of the dataset instance.

        Args:
            kwargs: Keyword arguments. Currently, we support following keys in **kwargs:

            batch_size(int): batch size. It will be effective during training. default is 1.
            thread_num(int): thread num, it is the num of readers. default is 1.
            use_var(list): list of variables. Variables which you will use. default is [].
            input_type(int): the input type of generated input. 0 is for one sample, 1 is for one batch. default is 0.
            fs_name(str): fs name. default is "".
            fs_ugi(str): fs ugi. default is "".
            pipe_command(str): pipe command of current dataset. Only a UNIX pipeline command can be used. default is "cat"
            download_cmd(str): customized download command. default is "cat"
            data_feed_type(str): currently ignored. The feed type is "SlotRecordInMemoryDataFeed"
                                 when use_ps_gpu is set and "MultiSlotInMemoryDataFeed" otherwise.
            queue_num(int): Dataset output queue num, training threads get data from queues.
                            default is -1 (not set), in which case the queue num follows the thread num.

        Examples:
            .. code-block:: python

                import paddle
                import os
                paddle.enable_static()

                with open("test_queue_dataset_run_a.txt", "w") as f:
                    data = "2 1 2 2 5 4 2 2 7 2 1 3"
                    f.write(data)
                with open("test_queue_dataset_run_b.txt", "w") as f:
                    data = "2 1 2 2 5 4 2 2 7 2 1 3"
                    f.write(data)

                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)

                dataset = paddle.distributed.InMemoryDataset()
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                dataset.set_filelist(
                    ["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
                dataset.load_into_memory()

                place = paddle.CPUPlace()
                exe = paddle.static.Executor(place)
                startup_program = paddle.static.Program()
                main_program = paddle.static.Program()
                exe.run(startup_program)

                exe.train_from_dataset(main_program, dataset)

                os.remove("./test_queue_dataset_run_a.txt")
                os.remove("./test_queue_dataset_run_b.txt")

        """
        batch_size = kwargs.get("batch_size", 1)
        thread_num = kwargs.get("thread_num", 1)
        use_var = kwargs.get("use_var", [])
        input_type = kwargs.get("input_type", 0)
        fs_name = kwargs.get("fs_name", "")
        fs_ugi = kwargs.get("fs_ugi", "")
        pipe_command = kwargs.get("pipe_command", "cat")
        download_cmd = kwargs.get("download_cmd", "cat")

        # The feed type must be set before DatasetBase.init: _set_feed_type may
        # replace self.dataset, and the thread/hdfs/download settings applied
        # by DatasetBase.init must land on that replacement.
        if self.use_ps_gpu:
            data_feed_type = "SlotRecordInMemoryDataFeed"
        else:
            data_feed_type = "MultiSlotInMemoryDataFeed"
        self._set_feed_type(data_feed_type)

        super().init(
            batch_size=batch_size,
            thread_num=thread_num,
            use_var=use_var,
            pipe_command=pipe_command,
            input_type=input_type,
            fs_name=fs_name,
            fs_ugi=fs_ugi,
            download_cmd=download_cmd,
        )

        queue_num = kwargs.get("queue_num", -1)
        if queue_num > 0:
            self._set_queue_num(queue_num)

    def _set_feed_type(self, data_feed_type):
616 617 618 619
        """
        Set data_feed_desc
        """
        self.proto_desc.name = data_feed_type
620
        if self.proto_desc.name == "SlotRecordInMemoryDataFeed":
621
            self.dataset = core.Dataset("SlotRecordDataset")
622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638

    def _prepare_to_run(self):
        """
        Set data_feed_desc before load or shuffle,
        user no need to call this function.

        Pushes the Python-side settings (thread num, queue num, parse/merge
        flags and the serialized data feed desc) into ``self.dataset``, then
        creates its channel and readers.
        """
        # at least one reader thread
        if self.thread_num <= 0:
            self.thread_num = 1
        self.dataset.set_thread_num(self.thread_num)
        # queue num follows the thread num unless it was set explicitly
        if self.queue_num is None:
            self.queue_num = self.thread_num
        self.dataset.set_queue_num(self.queue_num)
        self.dataset.set_parse_ins_id(self.parse_ins_id)
        self.dataset.set_parse_content(self.parse_content)
        self.dataset.set_parse_logkey(self.parse_logkey)
        self.dataset.set_merge_by_sid(self.merge_by_sid)
        self.dataset.set_enable_pv_merge(self.enable_pv_merge)
        self.dataset.set_data_feed_desc(self._desc())
        self.dataset.create_channel()
        self.dataset.create_readers()

    def _dynamic_adjust_before_train(self, thread_num):
        """
        Resize the dataset for ``thread_num`` training threads.

        The channel num is adjusted only when the user has not set queue_num
        explicitly. The reader num is always adjusted.

        Args:
            thread_num(int): number of training threads.
        """
        if not self.is_user_set_queue_num:
            # the second argument is the use_ps_gpu flag as a strict bool
            self.dataset.dynamic_adjust_channel_num(
                thread_num, bool(self.use_ps_gpu)
            )
        self.dataset.dynamic_adjust_readers_num(thread_num)

    def _dynamic_adjust_after_train(self):
        """
        Resize the dataset back to ``self.thread_num`` after training.

        Like ``_dynamic_adjust_before_train``, the channel num is left alone
        when the user has set queue_num explicitly.
        """
        if not self.is_user_set_queue_num:
            # the second argument is the use_ps_gpu flag as a strict bool
            self.dataset.dynamic_adjust_channel_num(
                self.thread_num, bool(self.use_ps_gpu)
            )
        self.dataset.dynamic_adjust_readers_num(self.thread_num)

    def _set_queue_num(self, queue_num):
660 661 662 663 664 665 666 667 668
        """
        Set Dataset output queue num, training threads get data from queues

        Args:
            queue_num(int): dataset output queue num

        Examples:
            .. code-block:: python

669
              import paddle
S
ShenLiang 已提交
670
              paddle.enable_static()
671 672
              dataset = paddle.distributed.InMemoryDataset()
              dataset._set_queue_num(12)
673 674 675 676 677

        """
        self.is_user_set_queue_num = True
        self.queue_num = queue_num

678
    def _set_parse_ins_id(self, parse_ins_id):
679
        """
680
        Set if Dataset need to parse insid
681 682 683 684 685 686 687

        Args:
            parse_ins_id(bool): if parse ins_id or not

        Examples:
            .. code-block:: python

688
              import paddle
S
ShenLiang 已提交
689
              paddle.enable_static()
690 691
              dataset = paddle.distributed.InMemoryDataset()
              dataset._set_parse_ins_id(True)
692 693 694 695

        """
        self.parse_ins_id = parse_ins_id

696
    def _set_parse_content(self, parse_content):
697 698 699 700 701 702 703 704 705
        """
        Set if Dataset need to parse content

        Args:
            parse_content(bool): if parse content or not

        Examples:
            .. code-block:: python

706
              import paddle
S
ShenLiang 已提交
707
              paddle.enable_static()
708 709
              dataset = paddle.distributed.InMemoryDataset()
              dataset._set_parse_content(True)
710 711 712 713

        """
        self.parse_content = parse_content

714
    def _set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
715 716 717 718 719 720 721 722 723
        """
        Set fleet send batch size, default is 1024

        Args:
            fleet_send_batch_size(int): fleet send batch size

        Examples:
            .. code-block:: python

724
              import paddle
S
ShenLiang 已提交
725
              paddle.enable_static()
726 727
              dataset = paddle.distributed.InMemoryDataset()
              dataset._set_fleet_send_batch_size(800)
728 729 730 731

        """
        self.fleet_send_batch_size = fleet_send_batch_size

732
    def _set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0):
733 734 735 736 737 738 739 740 741
        """
        Set fleet send sleep time, default is 0

        Args:
            fleet_send_sleep_seconds(int): fleet send sleep time

        Examples:
            .. code-block:: python

742
              import paddle
S
ShenLiang 已提交
743
              paddle.enable_static()
744 745
              dataset = paddle.distributed.InMemoryDataset()
              dataset._set_fleet_send_sleep_seconds(2)
746 747 748 749

        """
        self.fleet_send_sleep_seconds = fleet_send_sleep_seconds

750
    def _set_merge_by_lineid(self, merge_size=2):
        """
        Set merge by line id, instances of same line id will be merged after
        shuffle, you should parse line id in data generator.

        Besides configuring the core dataset, this sets ``merge_by_lineid``
        and switches ``parse_ins_id`` on.

        Args:
            merge_size(int): ins size to merge. default is 2.

        Examples:
            .. code-block:: python

              import paddle
              paddle.enable_static()
              dataset = paddle.distributed.InMemoryDataset()
              dataset._set_merge_by_lineid()

        """
        self.dataset.set_merge_by_lineid(merge_size)
        self.merge_by_lineid = True
        # ins_id parsing is switched on together with merge-by-line-id
        self.parse_ins_id = True

    def _set_shuffle_by_uid(self, enable_shuffle_uid):
        """
        Set if Dataset need to shuffle by uid.

        The flag is forwarded directly to the core dataset.

        Args:
            enable_shuffle_uid(bool): if shuffle according to uid or not

        Examples:
            .. code-block:: python

              import paddle
              paddle.enable_static()
              dataset = paddle.distributed.InMemoryDataset()
              dataset._set_shuffle_by_uid(True)
        """
        self.dataset.set_shuffle_by_uid(enable_shuffle_uid)

    def _set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
789 790 791 792
        self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
        self.gen_uni_feasigns = generate_uni_feasigns
        self.local_shard_num = shard_num

793 794 795 796 797 798
    def _generate_local_tables_unlock(
        self, table_id, fea_dim, read_thread_num, consume_thread_num, shard_num
    ):
        self.dataset.generate_local_tables_unlock(
            table_id, fea_dim, read_thread_num, consume_thread_num, shard_num
        )
799

800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835
    def set_date(self, date):
        """
        :api_attr: Static Graph

        Set training date for pull sparse parameters, saving and loading model.
        Only used in psgpu: the date is forwarded to ``self.psgpu`` only when
        ``use_ps_gpu`` is enabled and Paddle is compiled with HeterPS;
        otherwise the date is parsed but not used.

        Args:
            date(str): training date in ``YYYYMMDD`` format, e.g. "20211111".
                The first four characters are the year, the next two the
                month and the remainder the day.

        Raises:
            ValueError: if the year, month or day part is not an integer.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                dataset.set_date("20211111")
        """
        # Parse unconditionally so that malformed dates are reported even
        # when PS-GPU is not in use.
        year = int(date[:4])
        month = int(date[4:6])
        day = int(date[6:])
        if self.use_ps_gpu and core._is_compiled_with_heterps():
            self.psgpu.set_date(year, month, day)

    def tdm_sample(
        self,
        tree_name,
        tree_path,
        tdm_layer_counts,
        start_sample_layer,
        with_hierachy,
        seed,
        id_slot,
    ):
        """
        Run TDM sampling on the underlying dataset.

        This is a thin pass-through: every argument is forwarded, in the same
        order, to the backend's ``tdm_sample``. (The parameter name
        ``with_hierachy`` is kept as-is for compatibility with callers.)
        """
        sample_args = [
            tree_name,
            tree_path,
            tdm_layer_counts,
            start_sample_layer,
            with_hierachy,
            seed,
            id_slot,
        ]
        self.dataset.tdm_sample(*sample_args)

    def load_into_memory(self, is_shuffle=False):
        """
        :api_attr: Static Graph

        Load data into memory.

        ``_prepare_to_run`` is always called first. Without PS-GPU the
        underlying dataset loads the data itself; with PS-GPU (and a HeterPS
        build) the dataset is handed to ``self.psgpu``, which does the loading.

        Args:
            is_shuffle(bool): whether to use local shuffle, default is False.
                It is only forwarded on the PS-GPU path.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                filelist = ["a.txt", "b.txt"]
                dataset.set_filelist(filelist)
                dataset.load_into_memory()
        """
        self._prepare_to_run()
        if self.use_ps_gpu:
            # NOTE(review): when use_ps_gpu is set but Paddle was not built
            # with HeterPS, nothing is loaded here — confirm this is intended.
            if core._is_compiled_with_heterps():
                self.psgpu.set_dataset(self.dataset)
                self.psgpu.load_into_memory(is_shuffle)
        else:
            self.dataset.load_into_memory()

    def preload_into_memory(self, thread_num=None):
        """
        :api_attr: Static Graph

        Start loading data into memory asynchronously.

        Call ``wait_preload_done`` afterwards to block until loading is done.

        Args:
            thread_num(int): preload thread num. When None, the dataset's
                own ``thread_num`` (read after ``_prepare_to_run``) is used.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                filelist = ["a.txt", "b.txt"]
                dataset.set_filelist(filelist)
                dataset.preload_into_memory()
                dataset.wait_preload_done()
        """
        self._prepare_to_run()
        preload_threads = self.thread_num if thread_num is None else thread_num
        backend = self.dataset
        backend.set_preload_thread_num(preload_threads)
        backend.create_preload_readers()
        backend.preload_into_memory()

    def wait_preload_done(self):
        """
        :api_attr: Static Graph

        Block until ``preload_into_memory`` has finished, then destroy the
        preload readers.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                filelist = ["a.txt", "b.txt"]
                dataset.set_filelist(filelist)
                dataset.preload_into_memory()
                dataset.wait_preload_done()
        """
        backend = self.dataset
        backend.wait_preload_done()
        backend.destroy_preload_readers()

    def local_shuffle(self):
        """
        :api_attr: Static Graph

        Shuffle the in-memory data locally by delegating to the underlying
        dataset's ``local_shuffle``.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                filelist = ["a.txt", "b.txt"]
                dataset.set_filelist(filelist)
                dataset.load_into_memory()
                dataset.local_shuffle()
        """
        backend = self.dataset
        backend.local_shuffle()

    def global_shuffle(self, fleet=None, thread_num=12):
        """
        :api_attr: Static Graph

        Global shuffle.
        Global shuffle can be used only in distributed mode. i.e. multiple
        processes on single machine or multiple machines training together.
        If you run in distributed mode, you should pass fleet instead of None.

        When ``fleet`` is given, workers are synchronised with a barrier
        before configuration, before and after the shuffle itself, and after
        the optional line-id merge. Unset send batch size / sleep time
        default to 1024 and 0 respectively.

        Args:
            fleet(Fleet): fleet singleton. Default None.
            thread_num(int): shuffle thread num. Default is 12.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                filelist = ["a.txt", "b.txt"]
                dataset.set_filelist(filelist)
                dataset.load_into_memory()
                dataset.global_shuffle()

        """

        def _barrier():
            # Synchronise all workers; does nothing when running without fleet.
            if fleet is not None:
                fleet._role_maker.barrier_worker()

        _barrier()
        trainer_num = 1 if fleet is None else fleet.worker_num()

        if self.fleet_send_batch_size is None:
            self.fleet_send_batch_size = 1024
        if self.fleet_send_sleep_seconds is None:
            self.fleet_send_sleep_seconds = 0

        backend = self.dataset
        backend.register_client2client_msg_handler()
        backend.set_trainer_num(trainer_num)
        backend.set_fleet_send_batch_size(self.fleet_send_batch_size)
        backend.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)

        _barrier()
        backend.global_shuffle(thread_num)
        _barrier()
        if self.merge_by_lineid:
            backend.merge_by_lineid()
        _barrier()

    def release_memory(self):
        """
        :api_attr: Static Graph

        Free the in-memory data held by the underlying dataset. Call it once
        the loaded data will not be used again.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                filelist = ["a.txt", "b.txt"]
                dataset.set_filelist(filelist)
                dataset.load_into_memory()
                dataset.global_shuffle()
                exe = paddle.static.Executor(paddle.CPUPlace())
                startup_program = paddle.static.Program()
                main_program = paddle.static.Program()
                exe.run(startup_program)
                exe.train_from_dataset(main_program, dataset)
                dataset.release_memory()

        """
        backend = self.dataset
        backend.release_memory()

    def get_memory_data_size(self, fleet=None):
        """
        :api_attr: Static Graph

        Get memory data size, user can call this function to know the num
        of ins in all workers after load into memory.

        Note:
            This function may cause bad performance, because it has barrier

        Args:
            fleet(Fleet): Fleet Object. When None, only the local size is
                returned; otherwise the local size is all-reduced across
                workers through ``fleet._role_maker.all_reduce_worker``.

        Returns:
            The size of memory data.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                filelist = ["a.txt", "b.txt"]
                dataset.set_filelist(filelist)
                dataset.load_into_memory()
                print(dataset.get_memory_data_size())

        """
        import numpy as np

        local_data_size = np.array([self.dataset.get_memory_data_size()])
        if fleet is None:
            return local_data_size[0]
        # Output buffer for the all-reduce, same shape/dtype as the input.
        global_data_size = np.zeros_like(local_data_size)
        fleet._role_maker.all_reduce_worker(local_data_size, global_data_size)
        return global_data_size[0]

    def get_shuffle_data_size(self, fleet=None):
        """
        :api_attr: Static Graph

        Get shuffle data size, user can call this function to know the num
        of ins in all workers after local/global shuffle.

        Note:
            This function may cause bad performance to local shuffle,
            because it has barrier. It does not affect global shuffle.

        Args:
            fleet(Fleet): Fleet Object. When None, only the local size is
                returned; otherwise the local size is all-reduced across
                workers through ``fleet._role_maker.all_reduce_worker``.

        Returns:
            The size of shuffle data.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                filelist = ["a.txt", "b.txt"]
                dataset.set_filelist(filelist)
                dataset.load_into_memory()
                dataset.global_shuffle()
                print(dataset.get_shuffle_data_size())

        """
        import numpy as np

        local_data_size = np.array([self.dataset.get_shuffle_data_size()])
        if fleet is None:
            return local_data_size[0]
        # Output buffer for the all-reduce, same shape/dtype as the input.
        global_data_size = np.zeros_like(local_data_size)
        fleet._role_maker.all_reduce_worker(local_data_size, global_data_size)
        return global_data_size[0]

    def _set_fea_eval(self, record_candidate_size, fea_eval=True):
        """
        set fea eval mode for slots shuffle to debug the importance level of
        slots(features), fea_eval need to be set True for slots shuffle.
1212

1213
        Args:
1214
            record_candidate_size(int): size of instances candidate to shuffle
1215 1216 1217
                                        one slot
            fea_eval(bool): whether enable fea eval mode to enable slots shuffle.
                            default is True.
1218

1219 1220 1221 1222
        Examples:
            .. code-block:: python

            import paddle
S
ShenLiang 已提交
1223
            paddle.enable_static()
1224 1225 1226 1227 1228 1229 1230 1231 1232 1233
            dataset = paddle.distributed.InMemoryDataset()
            dataset._set_fea_eval(1000000, True)

        """
        if fea_eval:
            self.dataset.set_fea_eval(fea_eval, record_candidate_size)
        self.fea_eval = fea_eval

    def slots_shuffle(self, slots):
        """
        Slots Shuffle
        Slots Shuffle is a shuffle method in slots level, which is usually used
        in sparse feature with large scale of instances. To compare the metric, i.e.
        auc while doing slots shuffle on one or several slots with baseline to
        evaluate the importance level of slots(features).

        This is a no-op unless ``self.fea_eval`` is truthy (see
        ``_set_fea_eval``). Duplicate slot names are collapsed into a set
        before being passed to the backend.

        Args:
            slots(list[string]): the set of slots(string) to do slots shuffle.

        Examples:
            .. code-block:: python

                import paddle
                paddle.enable_static()

                dataset = paddle.distributed.InMemoryDataset()
                dataset._init_distributed_settings(fea_eval=True)
                slots = ["slot1", "slot2", "slot3", "slot4"]
                slots_vars = []
                for slot in slots:
                    var = paddle.static.data(
                        name=slot, shape=[None, 1], dtype="int64", lod_level=1)
                    slots_vars.append(var)
                dataset.init(
                    batch_size=1,
                    thread_num=2,
                    input_type=1,
                    pipe_command="cat",
                    use_var=slots_vars)
                filelist = ["a.txt", "b.txt"]
                dataset.set_filelist(filelist)
                dataset.load_into_memory()
                dataset.slots_shuffle(['slot1'])
        """
        if not self.fea_eval:
            return
        self.dataset.slots_shuffle(set(slots))


class QueueDataset(DatasetBase):
    """
    :api_attr: Static Graph

    QueueDataset, it will process data in a streaming way.

    Examples:
        .. code-block:: python

          import paddle
          dataset = paddle.distributed.QueueDataset()

    """

    def __init__(self):
        """
        Initialize QueueDataset
        """
        super().__init__()
        self.proto_desc.name = "MultiSlotDataFeed"

    def init(self, **kwargs):
        """
        :api_attr: Static Graph

        should be called only once in user's python scripts to initialize settings of dataset instance

        All keyword arguments are forwarded unchanged to ``DatasetBase.init``.
        """
        super().init(**kwargs)

    def _prepare_to_run(self):
        """
        Set data_feed_desc/thread num/filelist before run,
        user no need to call this function.
        """
        # Never use more reader threads than there are files, but keep at
        # least one thread when the file list is empty.
        if self.thread_num > len(self.filelist):
            self.thread_num = len(self.filelist)
        if self.thread_num == 0:
            self.thread_num = 1
        self.dataset.set_thread_num(self.thread_num)
        self.dataset.set_filelist(self.filelist)
        self.dataset.set_data_feed_desc(self._desc())
        self.dataset.create_readers()


class FileInstantDataset(DatasetBase):
    """
    FileInstantDataset, it will process data in a streaming way.

    Examples:
        .. code-block:: python

          import paddle
          dataset = paddle.distributed.fleet.FileInstantDataset()
    """

    def __init__(self):
        """
        Initialize FileInstantDataset
        """
        super().__init__()
        self.proto_desc.name = "MultiSlotFileInstantDataFeed"

    def init(self, **kwargs):
        """
        should be called only once in user's python scripts to initialize settings of dataset instance

        All keyword arguments are forwarded unchanged to ``DatasetBase.init``.
        """
        super().init(**kwargs)


class BoxPSDataset(InMemoryDataset):
    """
    BoxPSDataset: derived from InMemoryDataset.

    Examples:
        .. code-block:: python

          import paddle
          dataset = paddle.distributed.fleet.BoxPSDataset()
    """

    def __init__(self):
        """
        Initialize BoxPSDataset
        """
        super().__init__()
        self.boxps = core.BoxPS(self.dataset)
        self.proto_desc.name = "PaddleBoxDataFeed"

    def init(self, **kwargs):
        """
        should be called only once in user's python scripts to initialize settings of dataset instance

        All keyword arguments are forwarded to the parent ``init``. In
        addition, the following BoxPS-specific keys are read from ``kwargs``:

            rank_offset(str): rank_offset's name, default "".
            pv_batch_size(int): pv batch size, default 1.
            parse_logkey(bool): whether to parse logkey, default False.
            merge_by_sid(bool): whether to merge by sid, default False.
            enable_pv_merge(bool): whether to merge pv, default False.
        """
        super().init(**kwargs)

        rank_offset = kwargs.get("rank_offset", "")
        self._set_rank_offset(rank_offset)
        pv_batch_size = kwargs.get("pv_batch_size", 1)
        self._set_pv_batch_size(pv_batch_size)
        parse_logkey = kwargs.get("parse_logkey", False)
        self._set_parse_logkey(parse_logkey)
        merge_by_sid = kwargs.get("merge_by_sid", False)
        self._set_merge_by_sid(merge_by_sid)
        enable_pv_merge = kwargs.get("enable_pv_merge", False)
        self._set_enable_pv_merge(enable_pv_merge)

    def _set_rank_offset(self, rank_offset):
        """
        Set rank_offset for merge_pv. It set the message of Pv.

        Args:
            rank_offset(str): rank_offset's name

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              dataset._set_rank_offset("rank_offset")

        """
        self.proto_desc.rank_offset = rank_offset

    def _set_pv_batch_size(self, pv_batch_size):
        """
        Set pv batch size. It will be effective during enable_pv_merge

        Args:
            pv_batch_size(int): pv batch size

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              dataset._set_pv_batch_size(128)

        """
        self.proto_desc.pv_batch_size = pv_batch_size

    def _set_parse_logkey(self, parse_logkey):
        """
        Set if Dataset need to parse logkey

        Args:
            parse_logkey(bool): if parse logkey or not

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              dataset._set_parse_logkey(True)

        """
        self.parse_logkey = parse_logkey

    def _set_merge_by_sid(self, merge_by_sid):
        """
        Set if Dataset need to merge sid. If not, one ins means one Pv.

        Args:
            merge_by_sid(bool): if merge sid or not

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              dataset._set_merge_by_sid(True)

        """
        self.merge_by_sid = merge_by_sid

    def _set_enable_pv_merge(self, enable_pv_merge):
        """
        Set if Dataset need to merge pv.

        Args:
            enable_pv_merge(bool): if enable_pv_merge or not

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              dataset._set_enable_pv_merge(True)

        """
        self.enable_pv_merge = enable_pv_merge

    def set_date(self, date):
        """
        Set the training date and forward it to BoxPS.

        Args:
            date(str): training date in ``YYYYMMDD`` format, e.g. "20211111".
                The first four characters are the year, the next two the
                month and the remainder the day.

        Raises:
            ValueError: if the year, month or day part is not an integer.
        """
        year = int(date[:4])
        month = int(date[4:6])
        day = int(date[6:])
        self.boxps.set_date(year, month, day)

    def begin_pass(self):
        """
        Begin Pass
        Notify BoxPS to load sparse parameters of next pass to GPU Memory

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              dataset.begin_pass()
        """
        self.boxps.begin_pass()

    def end_pass(self, need_save_delta):
        """
        End Pass
        Notify BoxPS that current pass ended

        Args:
            need_save_delta(bool): forwarded unchanged to BoxPS ``end_pass``
                (presumably whether to save the delta model — confirm).

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              dataset.end_pass(True)
        """
        self.boxps.end_pass(need_save_delta)

    def wait_preload_done(self):
        """
        Wait async preload done
        Wait Until Feed Pass Done

        Unlike the parent implementation, this waits on BoxPS
        (``wait_feed_pass_done``) instead of the dataset's preload readers.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.preload_into_memory()
              dataset.wait_preload_done()
        """
        self.boxps.wait_feed_pass_done()

    def load_into_memory(self):
        """
        Load next pass into memory and notify boxps to fetch its emb from SSD

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
        """
        self._prepare_to_run()
        self.boxps.load_into_memory()

    def preload_into_memory(self):
        """
        Begin async preload next pass while current pass may be training

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.preload_into_memory()
        """
        self._prepare_to_run()
        self.boxps.preload_into_memory()

    def _dynamic_adjust_before_train(self, thread_num):
        """
        Adjust the dataset's reader number (and, unless the user set the
        queue number explicitly, its channel number) to ``thread_num``.
        """
        if not self.is_user_set_queue_num:
            self.dataset.dynamic_adjust_channel_num(thread_num, True)
        self.dataset.dynamic_adjust_readers_num(thread_num)

    def _dynamic_adjust_after_train(self):
        """Nothing needs to be restored after training for BoxPS."""
        pass

    def slots_shuffle(self, slots):
        """
        Slots Shuffle
        Slots Shuffle is a shuffle method in slots level, which is usually used
        in sparse feature with large scale of instances. To compare the metric, i.e.
        auc while doing slots shuffle on one or several slots with baseline to
        evaluate the importance level of slots(features).

        Unlike the parent implementation, this always delegates to BoxPS and
        does not check ``fea_eval``.

        Args:
            slots(list[string]): the set of slots(string) to do slots shuffle.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              dataset._set_merge_by_lineid()
              #suppose there is a slot 0
              dataset.slots_shuffle(['0'])
        """
        slots_set = set(slots)
        self.boxps.slots_shuffle(slots_set)

    def set_current_phase(self, current_phase):
        """
        Set current phase in train. It is useful for unittest.
        current_phase : 1 for join, 0 for update.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
              dataset.set_current_phase(1)

        """
        self.dataset.set_current_phase(current_phase)

    def get_pv_data_size(self):
        """
        Get memory data size of Pv, user can call this function to know the pv num
        of ins in all workers after load into memory.

        Note:
            This function may cause bad performance, because it has barrier

        Returns:
            The size of memory pv data.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
              print(dataset.get_pv_data_size())

        """
        return self.dataset.get_pv_data_size()

    def preprocess_instance(self):
        """
        Merge pv instance and convey it from input_channel to input_pv_channel.
        It will be effective when enable_pv_merge_ is True.

        Examples:
            .. code-block:: python

              import paddle
              dataset = paddle.distributed.fleet.BoxPSDataset()
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
              dataset.preprocess_instance()

        """
        self.dataset.preprocess_instance()

    def postprocess_instance(self):
        """
        Divide pv instance and convey it to input_channel.

        Examples:
            .. code-block:: python

              import paddle
              paddle.enable_static()
              dataset = paddle.distributed.fleet.BoxPSDataset()
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
              dataset.preprocess_instance()
              exe = paddle.static.Executor(paddle.CPUPlace())
              main_program = paddle.static.Program()
              exe.train_from_dataset(main_program, dataset)
              dataset.postprocess_instance()

        """
        self.dataset.postprocess_instance()