#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic Dataset API"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import sys
import logging
import os
import itertools
import random
import inspect
import multiprocessing
from contextlib import contextmanager
import gzip
import struct
import functools

import six
from six.moves import zip, map, filter
import numpy as np

from propeller.util import map_structure

log = logging.getLogger(__name__)

__all__ = ['Dataset']


@contextmanager
def _open_file(filename, format=None):
    if format is None:
        fd = open(filename, 'rb')
    elif format == 'GZIP':
        fd = gzip.open(filename, 'rb')
    else:
        raise ValueError('unknown file format %s' % format)
    try:
        yield fd
    finally:
        fd.close()


def _open_record(filename):
    def _gen():
        with _open_file(filename, format='GZIP') as f:
            while True:
                data = f.read(struct.calcsize('i'))
                if not data:
                    return  # EOF; PEP 479 forbids raising StopIteration here
                l, = struct.unpack('i', data)
                data = f.read(l)
                yield data

    return _gen
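

# Illustrative sketch, not part of the original API: records read by
# `_open_record` are length-prefixed blobs in a GZIP stream -- a packed
# int32 ('i') byte count followed by that many raw bytes. The helper below
# (hypothetical name) shows one way such a file could be written.
def _example_write_record_file(filename, records):
    with gzip.open(filename, 'wb') as f:
        for r in records:
            f.write(struct.pack('i', len(r)))
            f.write(r)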


def _shuffle_func(dataset, buffer_size):
    def _gen():
        buf = []
        iterable = dataset()
        try:
            while len(buf) < buffer_size:
                buf.append(next(iterable))
            while 1:
                i = random.randint(0, buffer_size - 1)
                n = next(iterable)
                yield buf[i]
                buf[i] = n
        except StopIteration:
            if len(buf):
                random.shuffle(buf)
                for i in buf:
                    yield i

    return _gen
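

# Minimal usage sketch (illustrative only): `_shuffle_func` takes a
# zero-argument generator function and returns a new one whose elements
# are approximately shuffled through a bounded buffer of `buffer_size`.
def _example_shuffle_usage():
    def _ints():
        for i in range(10):
            yield i

    shuffled = _shuffle_func(_ints, buffer_size=4)
    return list(shuffled())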


def _interleave_func(iterable, map_fn, cycle_length, block_length):
    def _gen():
        ls = itertools.tee(iterable(), cycle_length)
        buf = []
        for i, j in enumerate(ls):
            j = itertools.islice(j, i, None, cycle_length)
            j = map(map_fn, j)
            j = (jjj for jj in j for jjj in jj)  #flatten
            buf.append(j)

        for tup in six.moves.zip_longest(*buf):
            for ii in (i for i in tup if i is not None):
                yield ii

    return _gen
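

# Usage sketch (illustrative only): `map_fn` must return an iterable per
# element; its outputs are split into `cycle_length` sub-streams consumed
# round-robin. Note that `block_length` is accepted but unused above.
def _example_interleave_usage():
    def _ints():
        for i in range(4):
            yield i

    interleaved = _interleave_func(
        _ints, map_fn=lambda x: [x, x * 10], cycle_length=2, block_length=1)
    return list(interleaved())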


def _repeat_func(dataset, n):
    def _gen():
        iterable = dataset()
        if n >= 0:
            ret = itertools.chain(*itertools.tee(iterable, n))
        else:
            ret = itertools.cycle(iterable)

        for i in ret:
            yield i

    return _gen


def _filter_func(dataset, fn):
    def _gen():
        for i in dataset():
            if isinstance(i, (tuple, list)):
                if fn(*i) is True:
                    yield i
            else:
                if fn(i) is True:
                    yield i

    return _gen


def _map_func(dataset, fn):
    def _gen():
        for i in dataset():
            if isinstance(i, (tuple, list)):
                yield fn(*i)
            else:
                yield fn(i)

    return _gen


def _shard_func(dataset, num_shards, index):
    def _gen():
        iterable = dataset()
        ret = itertools.islice(iterable, index, None, num_shards)
        for i in ret:
            yield i

    return _gen


def _take_func(dataset, count):
    def _gen():
        iterable = dataset()
        ret = itertools.islice(iterable, count)
        for i in ret:
            yield i

    return _gen


def _buffered_func(dataset, size):
    """
    Creates a buffered data reader.

    The buffered data reader will read and save data entries into a
    buffer. Reading from the buffered data reader will proceed as long
    as the buffer is not empty.

    :param reader: the data reader to read from.
    :type reader: callable
    :param size: max buffer size.
    :type size: int

    :returns: the buffered data reader.
    """

    class _EndSignal(object):
        pass

    end = _EndSignal()

    def _read_worker(r, q):
        for d in r:
            q.put(d)
        q.put(end)

    def _data_reader():
        r = dataset()
        q = multiprocessing.Queue(maxsize=size)
        t = multiprocessing.Process(
            target=_read_worker, args=(
                r,
                q, ))
        t.daemon = True
        t.start()
        e = q.get()
        # the end signal is pickled across the process boundary, so compare
        # by type rather than by identity
        while not isinstance(e, _EndSignal):
            yield e
            e = q.get()

    return _data_reader
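

# Usage sketch (illustrative only): wrap a generator function so its items
# are prefetched by a background process. Assumes a fork-based start method,
# since the generator object is handed to the child process directly.
def _example_buffered_usage():
    def _ints():
        for i in range(100):
            yield i

    buffered = _buffered_func(_ints, size=10)
    return list(buffered())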


def _padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None):
    if not isinstance(batch_size, int):
        raise ValueError('unknown batch_size: %s' % repr(batch_size))

    def _gen():
        iterable = dataset()
        pad_value_t = pad_value
        while True:
            buf = list(itertools.islice(iterable, batch_size))
            if not buf:
                return  # exhausted; PEP 479 forbids raising StopIteration
            buf = list(zip(*buf))  # transpose
            if type(pad_value_t) not in [list, tuple]:
                pad_value_t = [pad_value_t] * len(buf)
            padded = []
            assert len(buf) == len(
                pad_value_t), 'pad_value [%d] != element size[%d]' % (
                    len(pad_value_t), len(buf))
            for e, pv in zip(buf, pad_value_t):
                elem = e[0]
                if (not np.isscalar(elem)) and elem.shape != ():
                    max_len = max(map(len,
                                      e)) if max_seqlen is None else max_seqlen

                    def _fn(i):
                        if max_len >= len(i):
                            return np.pad(i, [0, max_len - len(i)],
                                          'constant',
                                          constant_values=pv)
                        else:
                            return i[:max_len]

                    e = map(_fn, e)
                padded.append(np.stack(list(e)))
            yield padded

    return _gen
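

# Semantics sketch (illustrative only): each dataset element is a tuple of
# fields; within a batch every variable-length field is padded with
# `pad_value` up to the longest row (or cut to `max_seqlen`), then stacked
# into a single ndarray. The helper name below is made up.
def _example_padded_batch():
    def _pairs():
        yield (np.array([1, 2], dtype='int64'), )
        yield (np.array([3, 4, 5], dtype='int64'), )

    batched = _padded_batch_func(_pairs, batch_size=2, pad_value=0)
    # -> [[array([[1, 2, 0], [3, 4, 5]])]]: one batch holding one padded field
    return list(batched())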


class Dataset(object):
    """Python Wrapper for PyReader"""

    @classmethod
    def from_generator_func(cls, _gen, data_shapes=None, data_types=None):
        """doc"""
        if not inspect.isgeneratorfunction(_gen):
            raise ValueError('expect generator function, got %s' % repr(_gen))

        def _wrapper():  # compat with py3.7+ (PEP 479)
            try:
                for item in _gen():
                    yield item
            except RuntimeError as e:
                if str(e) != 'generator raised StopIteration':
                    raise e

        ret = cls()
        ret.generator = _wrapper
        ret.data_shapes = data_shapes
        ret.data_types = data_types
        return ret

    @classmethod
    def from_file(cls, filename, format=None):
        """doc"""
        if os.path.getsize(filename) == 0:
            raise RuntimeError('%s is empty' % filename)

        def _gen():
            with _open_file(filename, format) as f:
                for line in f:
                    yield line

        ret = cls()
        ret.generator = _gen
        ret.data_shapes = []
        ret.data_types = str
        return ret

    @classmethod
    def from_record_file(cls, filename):
        """doc"""
        if os.path.getsize(filename) == 0:
            raise RuntimeError('%s is empty' % filename)
        _gen = _open_record(filename)
        ret = cls()
        ret.generator = _gen
        ret.data_shapes = []
        ret.data_types = str
        return ret

    @classmethod
    def from_list(cls, ls):
        """doc"""
        if not isinstance(ls, list):
            raise ValueError('expect list, got %s' % repr(ls))

        def _gen():
            for i in ls:
                yield i

        ret = cls()
        ret.generator = _gen
        ret.data_shapes = []
        ret.data_types = str
        return ret

    def __init__(self):
        self.name = None
        self._data_shapes = None
        self._data_types = None
        self.generator = None
        self.pyreader = None

    def __repr__(self):
        return 'Dataset: name: %s, data_shapes %s, data_types %s' % (
            self.name, self._data_shapes, self._data_types)

    def __eq__(self, other):
        return self.name == other.name and \
               self._data_shapes == other._data_shapes and \
               self._data_types == other._data_types

    def __iter__(self):
        return self.generator()

    #def __call__(self):
    #    return self.generator()

    def _infer_shapes_and_types(self):
        if self.generator is not None and self.name is not None:
            log.info('Try to infer data shapes & types from generator')
            first_value = next(self.generator())
            shapes, types = [], []
            for v in first_value:
                if not isinstance(v, np.ndarray):
                    raise ValueError(
                        'dataset generator should use numpy elements, got %s' %
                        first_value)
                shapes.append(v.shape)
                types.append(v.dtype.name)
            self._data_shapes = shapes
            self._data_types = types
            log.info('Dataset `%s` has data_shapes: %s data_types: %s' %
                     (self.name, repr(shapes), repr(types)))
        else:
            raise ValueError(
                'Try to infer data shapes or types from incomplete Dataset')

    @property
    def data_shapes(self):
        """doc"""
        if self._data_shapes is None:
            self._infer_shapes_and_types()
        return self._data_shapes

    @data_shapes.setter
    def data_shapes(self, val):
        """doc"""
        self._data_shapes = val

    @property
    def data_types(self):
        """doc"""
        if self._data_types is None:
            self._infer_shapes_and_types()
        return self._data_types

    @data_types.setter
    def data_types(self, val):
        """doc"""
        self._data_types = val

    def apply(self, transform_func):
        """apply transform func to datasets"""
        #input_shapes = transform_func.input_shapes
        #input_types = transform_func.input_types
        #data_shapes = transform_func.data_shapes
        #data_types = transform_func.data_types
        #assert input_shapes == self._data_shapes
        #assert input_types = self._data_types
        ret_gen = transform_func(self.generator)
        ret = type(self).from_generator_func(ret_gen)
        if self.name is not None:
            ret.name = self.name
        #ret.data_shapes = data_shapes
        #ret.data_types = data_types
        return ret

    def shuffle(self, buffer_size):
        """doc"""
        func = functools.partial(_shuffle_func, buffer_size=buffer_size)
        return self.apply(func)

    def repeat(self, n=-1):
        """doc"""
        func = functools.partial(_repeat_func, n=n)
        return self.apply(func)

    def map(self, fn):
        """doc"""
        func = functools.partial(_map_func, fn=fn)
        return self.apply(func)

    def filter(self, fn):
        """doc"""
        func = functools.partial(_filter_func, fn=fn)
        return self.apply(func)

    def shard(self, num_shards, index):
        """doc"""
        func = functools.partial(
            _shard_func, num_shards=num_shards, index=index)
        return self.apply(func)

    def interleave(self, map_fn, cycle_length, block_length):
        """doc"""
        func = functools.partial(
            _interleave_func,
            map_fn=map_fn,
            cycle_length=cycle_length,
            block_length=block_length)
        return self.apply(func)

    def padded_batch(self, batch_size, pad_value=0, max_seqlen=None):
        """doc"""
        func = functools.partial(
            _padded_batch_func,
            batch_size=batch_size,
            pad_value=pad_value,
            max_seqlen=max_seqlen)
        return self.apply(func)

    def take(self, count=1):
        """doc"""
        func = functools.partial(_take_func, count=count)
        return self.apply(func)

    def buffered(self, size=10):
        """doc"""
        func = functools.partial(_buffered_func, size=size)
        return self.apply(func)
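

# End-to-end usage sketch (illustrative only; names and sizes are made up
# and nothing here runs on import): chain the transformations defined above
# into a small input pipeline.
def _example_dataset_pipeline():
    ds = (Dataset.from_list(list(range(32)))
          .map(lambda x: (np.arange(x % 5 + 1, dtype='int64'), ))
          .shuffle(buffer_size=8)
          .padded_batch(batch_size=4, pad_value=0))
    return [batch for batch in ds]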