# dataset.py (last committed by Chengmo)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import warnings

import paddle.fluid as fluid
from paddlerec.core.utils import envs
from paddlerec.core.utils import dataloader_instance
from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode

# Public classes exported by this module.
__all__ = [
    "DatasetBase",
    "DataLoader",
    "QueueDataset",
]


class DatasetBase(object):
    """Common base class for the dataset builders in this module.

    The base implementation keeps no state: the constructor ignores its
    ``context`` argument and ``get_dataset`` always returns ``None``.
    """

    def __init__(self, context):
        # The runtime context is accepted for interface compatibility only;
        # nothing is stored on the instance.
        del context

    def get_dataset(self, context):
        """Return a dataset for *context*; the base class returns ``None``."""
        return None


class DataLoader(DatasetBase):
    """Attach a PaddleRec reader to a data loader as its sample generator.

    The reader is chosen from the ``dataset.<name>.*`` global settings: when
    neither ``sparse_slots`` nor ``dense_slots`` is configured, the user
    reader class named by ``data_converter`` / ``reader_class_name`` is used;
    otherwise the built-in ``SlotReader`` path is used.
    """

    def __init__(self, context):
        pass

    def get_dataloader(self, context, dataset_name, dataloader):
        """Install the generator for *dataset_name* on *dataloader*.

        Args:
            context: runtime context dict; ``context["config_yaml"]`` is read
                here and the whole dict is forwarded to the reader factory.
            dataset_name: key under ``dataset.`` in the global config.
            dataloader: loader exposing ``set_sample_list_generator`` and
                ``set_sample_generator``.

        Returns:
            The same *dataloader* object, with its generator set.
        """
        prefix = "dataset." + dataset_name + "."

        def setting(key, *default):
            # Thin accessor over the global config for this dataset's keys.
            return envs.get_global_env(prefix + key, *default)

        sparse_slots = setting("sparse_slots", "").strip()
        dense_slots = setting("dense_slots", "").strip()
        batch_size = setting("batch_size")
        converter_path = setting("data_converter")
        reader_class_name = setting("reader_class_name", "Reader")

        uses_slots = not (sparse_slots == "" and dense_slots == "")
        config_yaml = context["config_yaml"]

        if uses_slots:
            reader = dataloader_instance.slotdataloader_by_name(
                "", dataset_name, config_yaml, context)
            reader_ins = SlotReader(config_yaml)
        else:
            reader = dataloader_instance.dataloader_by_name(
                converter_path,
                dataset_name,
                config_yaml,
                context,
                reader_class_name=reader_class_name)
            user_reader_cls = envs.lazy_instance_by_fliename(
                converter_path, reader_class_name)
            reader_ins = user_reader_cls(config_yaml)

        # Readers exposing generate_batch_from_trainfiles are registered as
        # sample-list generators with no batch_size (presumably they batch
        # themselves); all others are batched by the loader using batch_size.
        if not hasattr(reader_ins, 'generate_batch_from_trainfiles'):
            dataloader.set_sample_generator(reader, batch_size)
        else:
            dataloader.set_sample_list_generator(reader)
        return dataloader


class QueueDataset(DatasetBase):
    """Build a ``fluid`` dataset whose samples are produced by a pipe command.

    The pipe command runs ``utils/dataset_instance.py`` (resolved relative to
    this file) either with a user reader class or in "slot" mode, depending
    on the ``dataset.<name>.sparse_slots`` / ``dense_slots`` settings.
    """

    def __init__(self, context):
        pass

    def create_dataset(self, dataset_name, context):
        """Return a configured dataset, or ``None`` if a DataLoader is wanted.

        ``dataset.<name>.type`` selects the mode. On any platform other than
        ``"LINUX"`` the mode is forced to ``"DataLoader"``.
        """
        name = "dataset." + dataset_name + "."
        type_name = envs.get_global_env(name + "type")
        if envs.get_platform() != "LINUX":
            print("platform ", envs.get_platform(), "Reader To Dataloader")
            type_name = "DataLoader"

        if type_name == "DataLoader":
            return None
        else:
            return self._get_dataset(dataset_name, context)

    def _get_dataset(self, dataset_name, context):
        """Create and configure the dataset for *dataset_name*.

        Reads ``dataset.<name>.batch_size`` and ``data_path`` plus the reader
        settings used by ``_pipe_command``. The thread count and input
        variables come from the first entry of ``context["phases"]`` whose
        ``dataset_name`` matches. If no phase matches, they are left unset.
        """
        name = "dataset." + dataset_name + "."
        pipe_cmd = self._pipe_command(name, context["config_yaml"])

        batch_size = envs.get_global_env(name + "batch_size")
        dataset = fluid.DatasetFactory().create_dataset()
        dataset.set_batch_size(batch_size)
        dataset.set_pipe_command(pipe_cmd)
        dataset.set_filelist(self._file_list(name, context))

        for model_dict in context["phases"]:
            if model_dict["dataset_name"] != dataset_name:
                continue
            model = context["model"][model_dict["name"]]["model"]
            dataset.set_thread(int(model_dict["thread_num"]))
            # Inference phases feed a separate set of input variables.
            if context["is_infer"]:
                inputs = model._infer_data_var
            else:
                inputs = model._data_var
            dataset.set_use_var(inputs)
            break
        return dataset

    @staticmethod
    def _pipe_command(name, config_yaml):
        """Return the shell command that turns raw data files into samples.

        Every argument is shell-quoted. Paddle executes the pipe command
        through a shell, so unquoted "?" slot placeholders would be subject
        to glob expansion, and paths containing spaces or shell
        metacharacters would be split or misinterpreted.
        """
        try:
            from shlex import quote
        except ImportError:  # Python 2
            from pipes import quote

        reader_class = envs.get_global_env(name + "data_converter")
        reader_class_name = envs.get_global_env(name + "reader_class_name",
                                                "Reader")
        abs_dir = os.path.dirname(os.path.abspath(__file__))
        reader = os.path.join(abs_dir, '../../utils', 'dataset_instance.py')
        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()

        if sparse_slots == "" and dense_slots == "":
            args = [reader, reader_class, reader_class_name, config_yaml]
        else:
            padding = envs.get_global_env(name + "padding", 0)
            # Each slot list is passed as a single argument: an empty list
            # becomes "?" and inner spaces become "?" (presumably decoded
            # again by dataset_instance.py).
            args = [
                reader, "slot", "slot", config_yaml, "fake",
                (sparse_slots or "?").replace(" ", "?"),
                (dense_slots or "?").replace(" ", "?"),
                padding,
            ]
        return "python " + " ".join(quote(str(arg)) for arg in args)

    @staticmethod
    def _file_list(name, context):
        """Return the files under ``dataset.<name>.data_path``, sorted by name.

        ``os.listdir`` returns entries in arbitrary order. Sorting keeps the
        list passed to ``fleet.split_files`` (LOCAL_CLUSTER mode) identical
        on every trainer, so the resulting shards are deterministic.
        """
        data_path = envs.get_global_env(name + "data_path")
        file_list = [
            os.path.join(data_path, entry)
            for entry in sorted(os.listdir(data_path))
        ]
        if context["engine"] == EngineMode.LOCAL_CLUSTER:
            file_list = context["fleet"].split_files(file_list)
        return file_list