dataset_generator.py 11.5 KB
Newer Older
D
dongdaxiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

# Names exported by a star-import of this module. Only MultiSlotDataset is
# listed, so the DatasetGenerator base class is not star-exported.
__all__ = ['MultiSlotDataset']


class DatasetGenerator(object):
    '''
    Base class that turns raw data rows into text lines written to stdout
    and, when slot information is available, writes a protofile.

    Subclasses are expected to implement generate_sample, _gen_str and
    _get_proto_desc. generate_batch may be overridden to post-process each
    batch of samples before it is written.
    '''

    def __init__(self):
        # The protofile is only written when this is not None; this class
        # never assigns it itself.
        self._proto_info = None
        self._proto_filename = None
        # Number of samples collected before generate_batch is invoked.
        self._batch_size = 32
        self._hadoop_host = None
        self._hadoop_ugi = None
        # Destination on hadoop that the protofile is uploaded to.
        self._proto_output_path = None
        # Never assigned by this class; kept so code that reads the
        # attribute keeps working.
        self._hadoop_path = None

    def _set_proto_filename(self, proto_filename):
        '''
        Validate and remember the name of the protofile to write.

        Args:
            proto_filename(str): a non-empty file name

        Raises:
            ValueError: if proto_filename is not a str or is empty.
        '''
        if not isinstance(proto_filename, str):
            raise ValueError("proto_filename%s must be in str type" %
                             type(proto_filename))
        if not proto_filename:
            raise ValueError("proto_filename can not be empty")
        self._proto_filename = proto_filename

    def generate_sample(self, line):
        '''
        This function needs to be overridden by the user to process the
        original data row into samples.

        Args:
            line(str): the original data row. run_from_memory passes None.

        Returns:
            A callable taking no arguments that returns an iterable of
            samples (typically a local generator function). The callers in
            this class invoke the returned object and iterate the result;
            samples that are None are skipped.

            Each sample is a list or tuple:
            [(name, [feasign, ...]), ...]
              or ((name, [feasign, ...]), ...)

            For example:
            [("words", [1926, 8, 17]), ("label", [1])]
              or (("words", [1926, 8, 17]), ("label", [1]))

        Note:
            The type of feasigns must be in int or float. Once the float
            element appears in the feasign, the type of that slot will be
            processed into a float.
        '''
        raise NotImplementedError(
            "please rewrite this function to return a list" +
            "[(name, [int, int ...]), ...]")

    def set_batch(self, batch):
        '''
        Set how many samples are collected before generate_batch is called.

        Args:
            batch(int): the batch size
        '''
        self._batch_size = batch
        # This setter used to write only the public attribute below, which
        # the run_* methods never read. It is still assigned so existing
        # readers of it are not broken.
        self.batch = batch

    def generate_batch(self, samples):
        '''
        This function can be overridden by the user to process batch
        data, a user can define how to generate batch with this function

        Args:
            samples(list): results collected from generate_sample

        Returns:
            A callable taking no arguments that returns an iterable over
            the processed batch:
            [[(name, [int, ...]), ...],
             [(name, [int, ...]), ...],
             [(name, [int, ...])]]

        Default:
            Do nothing about current batch
        '''

        def batch_iter():
            for sample in samples:
                yield sample

        return batch_iter

    def _gen_str(self, line):
        '''
        Convert one sample into the string written to stdout. Must be
        implemented by subclasses.
        '''
        raise NotImplementedError(
            "Please inherit this class and implement _gen_str")

    def _get_proto_desc(self, proto_info):
        '''
        Build the protofile text for proto_info. Must be implemented by
        subclasses; it is only called when self._proto_info is not None.
        '''
        raise NotImplementedError(
            "Please inherit this class and implement _get_proto_desc")

    def _upload_proto_file(self):
        '''
        Upload the protofile to hadoop using the values stored by
        set_hadoop_config.

        Raises:
            ValueError: if the proto output path, hadoop host or hadoop ugi
                has not been set.
        '''
        if self._proto_output_path is None:
            raise ValueError("If you are running data generation on hadoop, "
                             "please set proto output path first")

        if self._hadoop_host is None or self._hadoop_ugi is None:
            raise ValueError("If you are running data generation on hadoop, "
                             "please set hadoop_host, hadoop_ugi first")
        # NOTE(review): the values are interpolated into a shell command
        # without escaping; they must come from a trusted configuration.
        cmd = "$HADOOP_HOME/bin/hadoop fs" \
              + " -Dhadoop.job.ugi=" + self._hadoop_ugi \
              + " -Dfs.default.name=" + self._hadoop_host \
              + " -put " + self._proto_filename + " " + self._proto_output_path
        status = os.system(cmd)
        if status != 0:
            # stdout carries the generated data, so report on stderr. The
            # command itself is not echoed because it contains the ugi.
            sys.stderr.write(
                "uploading %s to %s failed with exit status %s\n" %
                (self._proto_filename, self._proto_output_path, status))

    def set_hadoop_config(self,
                          hadoop_host=None,
                          hadoop_ugi=None,
                          proto_path=None):
        '''
        This function set hadoop configuration for map-reduce based data
        generation.

        Args:
            hadoop_host(str): The host name of the hadoop. It should be
                              in this format: "hdfs://${HOST}:${PORT}".
            hadoop_ugi(str): The ugi of the hadoop. It should be in this
                             format: "${USERNAME},${PASSWORD}".
            proto_path(str): The hadoop path you want to upload the
                             protofile to.
        '''
        # Private attributes are the ones _upload_proto_file reads.
        self._hadoop_host = hadoop_host
        self._hadoop_ugi = hadoop_ugi
        self._proto_output_path = proto_path
        # Public names this setter has always written; kept in sync for
        # callers that read them back.
        self.hadoop_host = hadoop_host
        self.hadoop_ugi = hadoop_ugi
        self.proto_output_path = proto_path

    def _write_batch(self, batch_data):
        '''Run generate_batch on batch_data and write each result to stdout.'''
        batched_iter = self.generate_batch(batch_data)
        for batched_line in batched_iter():
            sys.stdout.write(self._gen_str(batched_line))

    def _collect_samples(self, line_iter, batch_data):
        '''
        Append the non-None samples produced by line_iter() to batch_data,
        writing out a batch every time it reaches self._batch_size.

        Returns:
            The list of samples still waiting to be written.
        '''
        for user_parsed_line in line_iter():
            if user_parsed_line is None:
                continue
            batch_data.append(user_parsed_line)
            if len(batch_data) == self._batch_size:
                self._write_batch(batch_data)
                # Start a fresh list rather than clearing in place, so a
                # generate_batch override that kept a reference is unaffected.
                batch_data = []
        return batch_data

    def _write_proto_file(self, is_local):
        '''
        Write the protofile if slot information was collected, and upload
        it to hadoop when is_local is false.
        '''
        if self._proto_info is None:
            return
        with open(self._proto_filename, "w") as f:
            f.write(self._get_proto_desc(self._proto_info))
        if not is_local:
            self._upload_proto_file()

    def run_from_memory(self, is_local=True, proto_filename='data_feed.proto'):
        '''
        This function generates data from memory, user needs to
        define how to generate samples by define generate_sample
        and generate_batch. generate_sample is called once, with None.

        Args:
            is_local(bool): Whether user wants to run this function from local
            proto_filename(str): The name of protofile. The default value
                                 is "data_feed.proto". It is not
                                 recommended to modify it.
        '''
        self._set_proto_filename(proto_filename)
        batch_data = self._collect_samples(self.generate_sample(None), [])
        if batch_data:
            self._write_batch(batch_data)
        self._write_proto_file(is_local)

    def run_from_stdin(self, is_local=True, proto_filename='data_feed.proto'):
        '''
        This function reads the data row from stdin, parses it with
        generate_sample, groups the results with generate_batch and
        converts each of them with the _gen_str function. The parsed data
        is written to stdout and the corresponding protofile is generated.
        If is_local is set to False, the protofile is uploaded to hadoop.

        Args:
            is_local(bool): Whether user wants to run this function from local
            proto_filename(str): The name of protofile. The default value
                                 is "data_feed.proto". It is not
                                 recommended to modify it.
        '''
        self._set_proto_filename(proto_filename)
        batch_data = []
        for line in sys.stdin:
            batch_data = self._collect_samples(
                self.generate_sample(line), batch_data)
        if batch_data:
            self._write_batch(batch_data)
        self._write_proto_file(is_local)


class MultiSlotDataset(DatasetGenerator):
    '''
    Generator whose samples are lists (or tuples) of (slot name, element
    list) pairs. Each sample becomes one output line of the form
    "<count> <elem> ... <count> <elem> ...", one group per slot.
    '''

    # Python 2 has a separate ``long`` integer type. The name does not exist
    # on Python 3, so it is only evaluated when running on Python 2.
    _INT_TYPES = (int, long) if sys.version_info[0] < 3 else (int, )  # noqa: F821

    def _get_proto_desc(self, proto_info):
        '''
        Build the protofile text for the given slot list.

        Args:
            proto_info(list): (slot name, slot type) pairs

        Returns:
            The MultiSlotDataFeed description as a str.
        '''
        parts = ["name: \"MultiSlotDataFeed\"\n",
                 "batch_size: 32\nmulti_slot_desc {\n"]
        for slot_name, slot_type in proto_info:
            parts.append("  slots {\n")
            parts.append("    name: \"%s\"\n" % slot_name)
            parts.append("    type: \"%s\"\n" % slot_type)
            parts.append("    is_dense: false\n")
            parts.append("    is_used: false\n")
            parts.append("  }\n")
        parts.append("}")
        return "".join(parts)

    def generate_batch(self, samples):
        '''
        Return a zero-argument callable that yields the samples unchanged,
        which is the behaviour of the base class.
        '''
        return super(MultiSlotDataset, self).generate_batch(samples)

    @staticmethod
    def _check_slot(name, elements):
        '''Raise ValueError unless name is a str and elements a non-empty list.'''
        if not isinstance(name, str):
            raise ValueError("name%s must be in str type" % type(name))
        if not isinstance(elements, list):
            raise ValueError("elements%s must be in list type" %
                             type(elements))
        if not elements:
            raise ValueError(
                "the elements of each field can not be empty, you need padding it in process()."
            )

    def _gen_str(self, line):
        '''
        Validate one sample and convert it to an output line.

        The first sample seen defines the slot names and their order; later
        samples must match it. A slot starts as "uint64" and is switched to
        "float" as soon as a float element is seen in it.

        Args:
            line(list|tuple): [(name, [element, ...]), ...]

        Returns:
            The space-separated text for the sample, terminated by a
            newline.

        Raises:
            ValueError: if the sample is malformed or does not match the
                slots recorded from the first sample.
        '''
        if not isinstance(line, (list, tuple)):
            raise ValueError(
                "the output of process() must be in list or tuple type")

        first_line = self._proto_info is None
        if not first_line and len(line) != len(self._proto_info):
            raise ValueError(
                "the complete field set of two given line are inconsistent.")

        # Work on a private list and publish it only after the whole line
        # has validated, so a rejected line cannot leave a half-built schema.
        slots = [] if first_line else list(self._proto_info)
        tokens = []
        for index, item in enumerate(line):
            name, elements = item
            self._check_slot(name, elements)
            if first_line:
                slots.append((name, "uint64"))
            elif name != slots[index][0]:
                # %s for both values: the received name is a str, so
                # formatting it with %d would raise TypeError.
                raise ValueError(
                    "the field name of two given line are not match: require<%s>, get<%s>."
                    % (slots[index][0], name))
            tokens.append(str(len(elements)))
            for elem in elements:
                # Every element is checked, including those of slots that
                # are already float, so no unvalidated value is written.
                if isinstance(elem, float):
                    slots[index] = (name, "float")
                elif not isinstance(elem, self._INT_TYPES):
                    raise ValueError(
                        "the type of element%s must be in int or float" %
                        type(elem))
                tokens.append(str(elem))
        self._proto_info = slots
        return " ".join(tokens) + "\n"