# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

# Names exported by ``from <this package> import *``.
__all__ = [
    'MultiSlotDataGenerator',
    'MultiSlotStringDataGenerator',
]

class DataGenerator:
    """
    DataGenerator is a general base class for users to inherit.
    A user who wants to define his/her own python processing logic
    with paddle.fluid.dataset should inherit this class.

    A subclass overrides ``generate_sample`` (and optionally
    ``generate_batch``). ``_gen_str`` must be provided by a concrete
    subclass such as ``MultiSlotDataGenerator``; it turns every sample
    into the text that is written to stdout.
    """

    def __init__(self):
        # Cached slot schema; left as None here and filled in by
        # subclasses (e.g. MultiSlotDataGenerator._gen_str).
        self._proto_info = None
        # Number of parsed samples handed to generate_batch at a time.
        self.batch_size_ = 32

    def _set_line_limit(self, line_limit):
        '''
        Validate ``line_limit`` and store it as ``self._line_limit``.

        Args:
            line_limit(int): a positive integer.
        Raises:
            ValueError: if ``line_limit`` is not an int or is less than 1.
        '''
        # NOTE(review): _line_limit is not read anywhere in this file;
        # presumably it is consumed by external code -- confirm.
        if not isinstance(line_limit, int):
            raise ValueError(
                "line_limit%s must be in int type" % type(line_limit)
            )
        if line_limit < 1:
            raise ValueError("line_limit can not less than 1")
        self._line_limit = line_limit

    def set_batch(self, batch_size):
        '''
        Set batch size of current DataGenerator.
        This is necessary only if a user wants to define generate_batch.

        Args:
            batch_size(int): number of samples passed to generate_batch at
                a time; a trailing partial batch is flushed at the end.
                No validation is done: a value below 1 means every sample
                ends up in one single final batch.

        Example:
            .. code-block:: python

                import paddle.fluid.incubate.data_generator as dg

                class MyData(dg.MultiSlotDataGenerator):
                    def generate_sample(self, line):
                        def local_iter():
                            int_words = [int(x) for x in line.split()]
                            yield [("words", int_words)]
                        return local_iter

                    def generate_batch(self, samples):
                        def local_iter():
                            for s in samples:
                                words = s[0][1]
                                # append the first word to every sample
                                yield [("words", words + [words[0]])]
                        return local_iter

                mydata = MyData()
                mydata.set_batch(128)
        '''
        self.batch_size_ = batch_size

    def run_from_memory(self):
        '''
        Generate data from memory and write it to stdout; this is usually
        used for debugging and benchmarking. ``generate_sample`` is called
        once, with ``None`` as its line.

        Example:
            .. code-block:: python

                import paddle.fluid.incubate.data_generator as dg

                class MyData(dg.MultiSlotDataGenerator):
                    def generate_sample(self, line):
                        def local_iter():
                            yield [("words", [1, 2, 3, 4])]
                        return local_iter

                mydata = MyData()
                mydata.run_from_memory()
        '''
        line_iter = self.generate_sample(None)
        self._write_batches(line_iter())

    def run_from_stdin(self):
        '''
        Read data rows from stdin, parse each one with ``generate_sample``,
        group the results into batches for ``generate_batch`` and write every
        resulting sample, formatted by ``_gen_str``, to stdout.

        Example:
            .. code-block:: python

                import paddle.fluid.incubate.data_generator as dg

                class MyData(dg.MultiSlotDataGenerator):
                    def generate_sample(self, line):
                        def local_iter():
                            int_words = [int(x) for x in line.split()]
                            yield [("words", int_words)]
                        return local_iter

                mydata = MyData()
                mydata.run_from_stdin()
        '''
        # Lazy: stdin is consumed line by line while batches are written.
        self._write_batches(
            user_parsed_line
            for line in sys.stdin
            for user_parsed_line in self.generate_sample(line)()
        )

    def _write_batches(self, parsed_samples):
        '''
        Group ``parsed_samples`` into batches of ``batch_size_`` samples and
        write each batch with ``_write_batch``. ``None`` samples are skipped
        and a final partial batch is flushed.
        '''
        batch_samples = []
        for user_parsed_line in parsed_samples:
            if user_parsed_line is None:
                continue
            batch_samples.append(user_parsed_line)
            if len(batch_samples) == self.batch_size_:
                self._write_batch(batch_samples)
                # Start a fresh list instead of clearing in place:
                # generate_batch may keep a reference to the list it got.
                batch_samples = []
        if batch_samples:
            self._write_batch(batch_samples)

    def _write_batch(self, batch_samples):
        '''Run generate_batch on one batch and write every result to stdout.'''
        batch_iter = self.generate_batch(batch_samples)
        for sample in batch_iter():
            sys.stdout.write(self._gen_str(sample))

    def _gen_str(self, line):
        '''
        Further process one sample produced by ``generate_batch``, returning
        text that can be read directly by the datafeed. Must be implemented
        by a subclass.

        Args:
            line(list|tuple): one sample, e.g. ``[(name, [feasign, ...]), ...]``.
        Returns:
            Return a string data that can be read directly by the datafeed.
        Raises:
            NotImplementedError: always, in this base class.
        '''
        raise NotImplementedError(
            "pls use MultiSlotDataGenerator or PairWiseDataGenerator"
        )

    def generate_sample(self, line):
        '''
        This function needs to be overridden by the user to turn an original
        data row into samples.

        Args:
            line(str|None): the original data row (``None`` when called from
                ``run_from_memory``).
        Returns:
            A callable; calling it yields samples. Each sample is a list or
            tuple:
                [(name, [feasign, ...]), ...]
              or ((name, [feasign, ...]), ...)
            A yielded ``None`` is skipped.

            For example:
            [("words", [1926, 8, 17]), ("label", [1])]
              or (("words", [1926, 8, 17]), ("label", [1]))
        Note:
            The type of feasigns must be in int or float. Once a float
            element appears in the feasign, the type of that slot will be
            processed into a float.
        Example:
            .. code-block:: python

                import paddle.fluid.incubate.data_generator as dg

                class MyData(dg.MultiSlotDataGenerator):
                    def generate_sample(self, line):
                        def local_iter():
                            int_words = [int(x) for x in line.split()]
                            yield [("words", int_words)]
                        return local_iter
        '''
        raise NotImplementedError(
            "Please rewrite this function to return a list or tuple: "
            + "[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)"
        )

    def generate_batch(self, samples):
        '''
        This function may be overridden by the user to process a batch of
        samples generated by ``generate_sample``. It is usually used when a
        user wants to preprocess a whole batch, e.g. padding according to
        the max length of a sample in the batch. The default implementation
        yields the samples unchanged.

        Args:
            samples(list): samples produced by generate_sample.
        Returns:
            A callable; calling it yields samples in the same format as the
            ones produced by generate_sample.
        Example:
            .. code-block:: python

                import paddle.fluid.incubate.data_generator as dg

                class MyData(dg.MultiSlotDataGenerator):
                    def generate_sample(self, line):
                        def local_iter():
                            int_words = [int(x) for x in line.split()]
                            yield [("words", int_words)]
                        return local_iter

                    def generate_batch(self, samples):
                        def local_iter():
                            for s in samples:
                                words = s[0][1]
                                # append the first word to every sample
                                yield [("words", words + [words[0]])]
                        return local_iter

                mydata = MyData()
                mydata.set_batch(128)
        '''

        def local_iter():
            for sample in samples:
                yield sample

        return local_iter


# TODO: guru4elephant
# add more generalized DataGenerator that can adapt user-defined slot
# for example, [(name, float_list), (name, str_list), (name, int_list)]
class MultiSlotStringDataGenerator(DataGenerator):
    """Serialize samples whose feasigns are already strings.

    Elements are written verbatim: no type checking is performed and
    ``_proto_info`` is not touched.
    """

    def _gen_str(self, line):
        '''
        Further process one sample rewritten by the user, returning data
        that can be directly read by the MultiSlotDataFeed.
        The input line will be in this format:
            >>> [(name, [str(feasign), ...]), ...]
            >>> or ((name, [str(feasign), ...]), ...)
        The output will be in this format:
            >>> [ids_num id1 id2 ...] ...
        For example, if the input is like this:
            >>> [("words", ["1926", "08", "17"]), ("label", ["1"])]
            >>> or (("words", ["1926", "08", "17"]), ("label", ["1"]))
        the output will be:
            >>> 3 1926 08 17 1 1
        Args:
            line(list|tuple): one sample produced by generate_batch.
        Returns:
            Return a string data that can be read directly by the
            MultiSlotDataFeed, terminated by a newline.
        Raises:
            ValueError: if ``line`` is not a list or tuple, or an item cannot
                be unpacked into ``(name, elements)``.
            TypeError: if an element is not a str (raised by ``str.join``).
        '''
        if not isinstance(line, (list, tuple)):
            raise ValueError(
                "the output of process() must be in list or tuple type. "
                "Examples: [('words', ['1926', '08', '17']), ('label', ['1'])]"
            )
        fields = []
        for _name, elements in line:
            # Each field is "<count> <e1> <e2> ..."; str.join requires every
            # element to already be a str.
            fields.append(" ".join([str(len(elements))] + list(elements)))
        return " ".join(fields) + "\n"


class MultiSlotDataGenerator(DataGenerator):
    """Serialize samples of int/float feasigns for the MultiSlotDataFeed.

    Slot names and types are taken from the first sample and cached in
    ``self._proto_info``; every later sample must contain the same slot
    names in the same order. An int ("uint64") slot is promoted to "float"
    as soon as a float feasign is seen in it.
    """

    @staticmethod
    def _check_field(name, elements):
        """Validate the ``(name, elements)`` pair of a single slot.

        Raises:
            ValueError: if ``name`` is not a str, ``elements`` is not a list,
                or ``elements`` is empty.
        """
        if not isinstance(name, str):
            raise ValueError("name%s must be in str type" % type(name))
        if not isinstance(elements, list):
            raise ValueError(
                "elements%s must be in list type" % type(elements)
            )
        if not elements:
            raise ValueError(
                "the elements of each field can not be empty, you need padding it in process()."
            )

    @staticmethod
    def _is_float_elem(elem):
        """Return True for a float feasign and False for an int one.

        Raises:
            ValueError: if ``elem`` is neither an int nor a float.
        """
        if isinstance(elem, float):
            return True
        # Python 3 has no separate ``long`` type (PEP 237): int is unbounded.
        # Referencing ``long`` here used to raise NameError instead of the
        # intended ValueError below.
        if isinstance(elem, int):
            return False
        raise ValueError(
            "the type of element%s must be in int or float" % type(elem)
        )

    @staticmethod
    def _format_field(elements):
        """Render one slot as "<count> <e1> <e2> ..."."""
        return " ".join([str(len(elements))] + [str(elem) for elem in elements])

    def _gen_str(self, line):
        '''
        Further process one sample rewritten by the user, returning data
        that can be directly read by the MultiSlotDataFeed, and record or
        update the slot schema in ``self._proto_info``.
        The input line will be in this format:
            >>> [(name, [feasign, ...]), ...]
            >>> or ((name, [feasign, ...]), ...)
        The output will be in this format:
            >>> [ids_num id1 id2 ...] ...
        The proto_info will be in this format:
            >>> [(name, type), ...]

        For example, if the input is like this:
            >>> [("words", [1926, 8, 17]), ("label", [1])]
            >>> or (("words", [1926, 8, 17]), ("label", [1]))
        the output will be:
            >>> 3 1926 8 17 1 1
        the proto_info will be:
            >>> [("words", "uint64"), ("label", "uint64")]
        Args:
            line(list|tuple): one sample produced by generate_batch.
        Returns:
            Return a string data that can be read directly by the
            MultiSlotDataFeed, terminated by a newline.
        Raises:
            ValueError: if the sample is malformed or inconsistent with the
                schema recorded from the first sample.
        '''
        if not isinstance(line, (list, tuple)):
            raise ValueError(
                "the output of process() must be in list or tuple type. "
                "Example: [('words', [1926, 08, 17]), ('label', [1])]"
            )
        fields = []
        if self._proto_info is None:
            # Build the schema locally and publish it only once the whole
            # first sample validated; otherwise a bad first sample would
            # leave a half-filled schema that rejects every later line.
            proto_info = []
            for name, elements in line:
                self._check_field(name, elements)
                slot_type = "uint64"
                for elem in elements:
                    if self._is_float_elem(elem):
                        slot_type = "float"
                proto_info.append((name, slot_type))
                fields.append(self._format_field(elements))
            self._proto_info = proto_info
        else:
            if len(line) != len(self._proto_info):
                raise ValueError(
                    "the complete field set of two given line are inconsistent."
                )
            for index, (name, elements) in enumerate(line):
                self._check_field(name, elements)
                expected_name, slot_type = self._proto_info[index]
                if name != expected_name:
                    raise ValueError(
                        "the field name of two given line are not match: require<%s>, get<%s>."
                        % (expected_name, name)
                    )
                if slot_type != "float":
                    for elem in elements:
                        if self._is_float_elem(elem):
                            self._proto_info[index] = (name, "float")
                            # NOTE: as before, elements after the first float
                            # (and every element of an already-float slot)
                            # are not type-checked.
                            break
                fields.append(self._format_field(elements))
        return " ".join(fields) + "\n"