functional.py 13.3 KB
Newer Older
C
chenxuyi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
C
chenxuyi 已提交
14
"""Basic Dataset API"""
C
chenxuyi 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import sys
import logging
import os
import itertools
import random
import inspect
import multiprocessing
from contextlib import contextmanager
import gzip
import struct
import functools

import six
from six.moves import zip, map, filter
import numpy as np

from propeller.util import map_structure

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Only `Dataset` is exported by `from <module> import *`.
__all__ = ['Dataset']


@contextmanager
def _open_file(filename, format=None):
    """Open `filename` for binary reading and close it when the block exits.

    Args:
        filename: path of the file to open.
        format: `None` for a plain file, `'GZIP'` for a gzip-compressed file.

    Yields:
        The opened binary file object.

    Raises:
        ValueError: if `format` is neither `None` nor `'GZIP'`.
    """
    if format is None:
        fd = open(filename, 'rb')
    elif format == 'GZIP':
        fd = gzip.open(filename, 'rb')
    else:
        raise ValueError('unknown file format %s' % format)
    try:
        yield fd
    finally:
        # Runs even when the `with` body raises or when a generator that is
        # suspended inside the `with` block gets closed early.
        fd.close()


C
chenxuyi 已提交
54 55 56
def _open_record(filename):
    """Build a generator function reading length-prefixed records from a gzip file.

    Each record is a `struct` `'i'` header holding the payload length,
    followed by that many payload bytes.

    Args:
        filename: path of the gzip-compressed record file.

    Returns:
        A zero-argument generator function yielding each payload as bytes.
    """
    header_size = struct.calcsize('i')

    def _gen():
        with _open_file(filename, format='GZIP') as f:
            while True:
                header = f.read(header_size)
                if not header:
                    # Plain `return` ends the generator. `raise StopIteration`
                    # inside a generator is turned into RuntimeError by
                    # PEP 479 (default since Python 3.7).
                    return
                length, = struct.unpack('i', header)
                yield f.read(length)

    return _gen
C
chenxuyi 已提交
66 67


C
chenxuyi 已提交
68 69
def _shuffle_func(dataset, buffer_size):
    """Build a generator function that shuffles `dataset` through a buffer.

    Args:
        dataset: zero-argument callable returning an iterator.
        buffer_size: number of elements held in the shuffle buffer.

    Returns:
        A zero-argument generator function yielding the elements of
        `dataset` in randomised order.
    """

    def _gen():
        pool = []
        stream = dataset()
        try:
            # Warm-up: fill the pool before emitting anything.
            while buffer_size > len(pool):
                pool.append(next(stream))
            # Steady state: pick a random slot, emit its content and put the
            # freshly read element in its place.
            while True:
                slot = random.randint(0, buffer_size - 1)
                incoming = next(stream)
                outgoing, pool[slot] = pool[slot], incoming
                yield outgoing
        except StopIteration:
            # Source exhausted: flush whatever is left, in random order.
            if pool:
                random.shuffle(pool)
                for leftover in pool:
                    yield leftover

    return _gen
C
chenxuyi 已提交
87 88


C
chenxuyi 已提交
89 90
def _interleave_func(iterable, map_fn, cycle_length, block_length):
    """Build a generator function that interleaves mapped sub-streams.

    The source is split into `cycle_length` branches; branch `i` receives the
    source elements at positions `i, i + cycle_length, ...`. Every element is
    expanded with `map_fn` and the branches are then read round-robin, one
    item at a time, until all of them are exhausted.

    Args:
        iterable: zero-argument callable returning the source iterator.
        map_fn: callable mapping one source element to an iterable of items.
        cycle_length: number of branches read in turn.
        block_length: accepted for interface compatibility; not used, one
            item is taken from each branch per turn.

    Returns:
        A zero-argument generator function yielding the interleaved items.
    """

    def _expand(elements):
        # Flatten the iterables produced by map_fn into one stream.
        for element in elements:
            for item in map_fn(element):
                yield item

    def _gen():
        branches = itertools.tee(iterable(), cycle_length)
        streams = [
            _expand(itertools.islice(branch, offset, None, cycle_length))
            for offset, branch in enumerate(branches)
        ]
        # Explicit round-robin, not zip_longest plus a `None` filter, so
        # that items which are legitimately `None` are not dropped.
        while streams:
            alive = []
            for stream in streams:
                try:
                    item = next(stream)
                except StopIteration:
                    continue
                yield item
                alive.append(stream)
            streams = alive

    return _gen
C
chenxuyi 已提交
104 105


C
chenxuyi 已提交
106 107
def _repeat_func(dataset, n):
    """Build a generator function that replays `dataset` several times.

    Every pass calls `dataset()` again, so nothing is cached in memory
    between passes.

    Args:
        dataset: zero-argument callable returning an iterator.
        n: number of passes; a negative value repeats indefinitely.

    Returns:
        A zero-argument generator function yielding the repeated elements.
    """

    def _gen():
        if n >= 0:
            for _ in range(n):
                for item in dataset():
                    yield item
            return
        while True:
            produced = False
            for item in dataset():
                produced = True
                yield item
            if not produced:
                # An empty dataset would otherwise spin forever.
                return

    return _gen
C
chenxuyi 已提交
118 119


C
chenxuyi 已提交
120 121
def _filter_func(dataset, fn):
    """Build a generator function keeping the elements accepted by `fn`.

    Args:
        dataset: zero-argument callable returning an iterator.
        fn: predicate; tuple or list elements are unpacked into positional
            arguments, any other element is passed as a single argument.
            An element is kept when the result is truthy, so results such
            as `numpy.bool_` are honoured as well as the builtin `True`.

    Returns:
        A zero-argument generator function yielding the accepted elements
        unchanged.
    """

    def _gen():
        for item in dataset():
            if isinstance(item, (tuple, list)):
                keep = fn(*item)
            else:
                keep = fn(item)
            if keep:
                yield item

    return _gen
C
chenxuyi 已提交
131 132


C
chenxuyi 已提交
133 134
def _map_func(dataset, fn):
    """Build a generator function applying `fn` to every element.

    Args:
        dataset: zero-argument callable returning an iterator.
        fn: callable; tuple or list elements are unpacked into positional
            arguments, any other element is passed as a single argument.

    Returns:
        A zero-argument generator function yielding the results of `fn`.
    """

    def _gen():
        for sample in dataset():
            unpack = isinstance(sample, (tuple, list))
            yield fn(*sample) if unpack else fn(sample)

    return _gen
C
chenxuyi 已提交
142 143


C
chenxuyi 已提交
144 145
def _shard_func(dataset, num_shards, index):
    """Build a generator function selecting one shard of `dataset`.

    Args:
        dataset: zero-argument callable returning an iterator.
        num_shards: stride between two selected elements.
        index: position of the first selected element.

    Returns:
        A zero-argument generator function yielding the elements at
        positions `index, index + num_shards, ...`.
    """

    def _gen():
        strided = itertools.islice(dataset(), index, None, num_shards)
        for item in strided:
            yield item

    return _gen
C
chenxuyi 已提交
152 153


C
chenxuyi 已提交
154 155
def _take_func(dataset, count):
    """Build a generator function yielding at most `count` leading elements.

    Args:
        dataset: zero-argument callable returning an iterator.
        count: maximum number of elements to yield.

    Returns:
        A zero-argument generator function over the truncated stream.
    """

    def _gen():
        head = itertools.islice(dataset(), count)
        for item in head:
            yield item

    return _gen
C
chenxuyi 已提交
162 163


M
Meiyim 已提交
164 165 166 167 168 169 170 171 172 173 174
def _chain_func(dataset, dataset2):
    """Build a generator function yielding `dataset` followed by `dataset2`.

    Args:
        dataset: zero-argument callable returning the first iterator.
        dataset2: zero-argument callable returning the second iterator.

    Returns:
        A zero-argument generator function over both streams in order.
    """

    def _gen():
        # Both iterators are created up front, before anything is consumed.
        first, second = dataset(), dataset2()
        for source in (first, second):
            for item in source:
                yield item

    return _gen


C
chenxuyi 已提交
175
def _buffered_func(dataset, size):
    """
    Creates a buffered data reader.

    The buffered data reader reads entries from `dataset` in a background
    process and hands them over through a bounded queue. Reading from the
    buffered data reader proceeds until the background process reports the
    end of the data.

    :param dataset: zero-argument callable returning the iterable to read.
    :type dataset: callable
    :param size: max buffer size.
    :type size: int

    :returns: the buffered data reader.
    :raises RuntimeError: (while iterating) if reading `dataset` fails in
        the background process.
    """
    import traceback

    # Every queue message is a `(kind, payload)` tuple. Plain integer tags
    # survive pickling, unlike an instance of a function-local class, and
    # comparing them never touches the payload (which may be an ndarray).
    _ITEM, _END, _ERROR = 0, 1, 2

    def _read_worker(r, q):
        try:
            for d in r:
                q.put((_ITEM, d))
        except Exception:
            # Report the failure; otherwise the consumer would block forever.
            q.put((_ERROR, traceback.format_exc()))
        else:
            q.put((_END, None))

    def _data_reader():
        r = dataset()
        # The iterator and the nested worker cannot be pickled, so the child
        # has to inherit them through `fork`; request that start method
        # explicitly where the API exists.
        try:
            ctx = multiprocessing.get_context('fork')
        except (AttributeError, ValueError):
            ctx = multiprocessing
        q = ctx.Queue(maxsize=size)
        t = ctx.Process(target=_read_worker, args=(r, q))
        t.daemon = True
        t.start()
        try:
            while True:
                kind, payload = q.get()
                if kind == _END:
                    return
                if kind == _ERROR:
                    raise RuntimeError(
                        'buffered reader worker failed:\n%s' % payload)
                yield payload
        finally:
            # Also reached when the consumer stops early: do not leave the
            # worker blocked on a full queue.
            if t.is_alive():
                t.terminate()
            t.join()

    return _data_reader
C
chenxuyi 已提交
216 217


M
Meiyim 已提交
218 219 220 221 222 223 224 225 226 227 228 229 230 231
def _batch_func(dataset, batch_size):
    """Build a generator function grouping samples into stacked batches.

    Args:
        dataset: zero-argument callable returning an iterator of samples;
            each sample is a sequence of fields.
        batch_size: maximum number of samples per batch; the last batch may
            be smaller.

    Returns:
        A zero-argument generator function yielding, per batch, a list with
        one `np.stack`-ed array per field.
    """

    def _gen():
        iterable = dataset()
        while True:
            samples = list(itertools.islice(iterable, batch_size))
            if not samples:
                # Plain `return` ends the generator. `raise StopIteration`
                # inside a generator becomes RuntimeError under PEP 479.
                return
            # zip(*samples) transposes samples-of-fields into fields-of-samples.
            yield [np.stack(column) for column in zip(*samples)]

    return _gen


C
chenxuyi 已提交
232
def _padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None):
    """Build a generator function producing padded, stacked batches.

    Fields whose first sample has at least one dimension are padded (or
    truncated) along their first axis to a common length before stacking;
    scalar fields are stacked as they are.

    Args:
        dataset: zero-argument callable returning an iterator of samples;
            each sample is a sequence of fields.
        batch_size: maximum number of samples per batch; the last batch may
            be smaller.
        pad_value: fill value, or a list/tuple with one fill value per field.
        max_seqlen: target length of the first axis; when `None` the longest
            sample of the batch sets the length.

    Returns:
        A zero-argument generator function yielding, per batch, a list with
        one stacked array per field.

    Raises:
        ValueError: if `batch_size` is not an `int`.
    """
    if not isinstance(batch_size, int):
        raise ValueError('unknown batch_size: %s' % repr(batch_size))

    def _fit(sample, length, value):
        # Pad or truncate `sample` along its first axis to `length`.
        missing = length - len(sample)
        if missing < 0:
            return sample[:length]
        # Pad the first axis only; a bare (before, after) pair would be
        # applied by np.pad to every axis of a multi-dimensional sample.
        widths = [(0, missing)] + [(0, 0)] * (np.ndim(sample) - 1)
        return np.pad(sample, widths, 'constant', constant_values=value)

    def _gen():
        iterable = dataset()
        pad_values = pad_value
        while True:
            samples = list(itertools.islice(iterable, batch_size))
            if not samples:
                # Plain `return` ends the generator. `raise StopIteration`
                # inside a generator becomes RuntimeError under PEP 479.
                return
            columns = list(zip(*samples))  # transpose
            if not isinstance(pad_values, (list, tuple)):
                # Broadcast a single fill value to every field (done once).
                pad_values = [pad_values] * len(columns)
            assert len(columns) == len(
                pad_values), 'pad_value [%d] != element size[%d]' % (
                    len(pad_values), len(columns))
            padded = []
            for column, value in zip(columns, pad_values):
                if np.ndim(column[0]) > 0:
                    if max_seqlen is None:
                        target = max(len(sample) for sample in column)
                    else:
                        target = max_seqlen
                    column = [_fit(sample, target, value) for sample in column]
                padded.append(np.stack(list(column)))
            yield padded

    return _gen
C
chenxuyi 已提交
269 270 271


class Dataset(object):
    """Lazily evaluated dataset backed by a generator function.

    `generator` holds a zero-argument generator function; iterating the
    dataset calls it to obtain a fresh iterator. Every transformation
    returns a new `Dataset` and leaves the receiver untouched.
    """

    @classmethod
    def from_generator_func(cls, _gen, data_shapes=None, data_types=None):
        """Create a dataset from a generator function.

        Args:
            _gen: zero-argument generator function producing the elements.
            data_shapes: optional shapes stored on the dataset as given.
            data_types: optional types stored on the dataset as given.

        Raises:
            ValueError: if `_gen` is not a generator function.
        """
        if not inspect.isgeneratorfunction(_gen):
            raise ValueError('expect generator function, got %s' % repr(_gen))

        def _wrapper():  #compat to py3.7
            # PEP 479 turns a StopIteration raised inside a generator into
            # this RuntimeError; treat it as a normal end of the stream.
            try:
                for item in _gen():
                    yield item
            except RuntimeError as e:
                if str(e) != 'generator raised StopIteration':
                    raise

        ret = cls()
        ret.generator = _wrapper
        ret.data_shapes = data_shapes
        ret.data_types = data_types
        return ret

    @classmethod
    def from_file(cls, filename, format=None):
        """Create a dataset yielding the lines of `filename` as bytes.

        Args:
            filename: path of the file to read.
            format: passed to `_open_file`; `None` or `'GZIP'`.

        Raises:
            RuntimeError: if the file is empty.
        """
        if os.path.getsize(filename) == 0:
            raise RuntimeError('%s is empty' % filename)

        def _gen():
            with _open_file(filename, format) as f:
                for line in f:
                    yield line

        return cls.from_generator_func(_gen, data_shapes=[], data_types=str)

    @classmethod
    def from_record_file(cls, filename):
        """Create a dataset yielding the records stored in `filename`.

        Raises:
            RuntimeError: if the file is empty.
        """
        if os.path.getsize(filename) == 0:
            raise RuntimeError('%s is empty' % filename)
        # Going through from_generator_func puts the record reader behind
        # the same PEP 479 compatibility wrapper as every other dataset.
        return cls.from_generator_func(
            _open_record(filename), data_shapes=[], data_types=str)

    @classmethod
    def from_list(cls, ls):
        """Create a dataset yielding the items of the list `ls`.

        Raises:
            ValueError: if `ls` is not a list.
        """
        if not isinstance(ls, list):
            raise ValueError('expect list, got %s' % repr(ls))

        def _gen():
            for i in ls:
                yield i

        return cls.from_generator_func(_gen, data_shapes=[], data_types=str)

    def __init__(self):
        self.name = None
        self._data_shapes = None
        self._data_types = None
        self.generator = None
        self.pyreader = None

    def __repr__(self):
        return 'Dataset: name: %s, data_shapes %s, data_types %s' % (
            self.name, self._data_shapes, self._data_types)

    def __eq__(self, other):
        # Comparing against a foreign object is "not equal", not an error.
        if not isinstance(other, Dataset):
            return NotImplemented
        return (self.name == other.name and
                self._data_shapes == other._data_shapes and
                self._data_types == other._data_types)

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __iter__(self):
        return self.generator()

    def _infer_shapes_and_types(self):
        """Fill `_data_shapes` / `_data_types` from the first element.

        Raises:
            ValueError: if the dataset has no generator or no name, or if a
                field of the first element is not a numpy array.
        """
        if self.generator is None or self.name is None:
            raise ValueError(
                'Try to infer data shapes or types from incomplete Dataset')

        log.info('Try to infer data shapes & types from generator')
        first_value = next(self.generator())
        shapes, types = [], []
        for v in first_value:
            if not isinstance(v, np.ndarray):
                # Wrap in a 1-tuple: a tuple-valued `first_value` would
                # otherwise be unpacked as several format arguments.
                raise ValueError(
                    'dataset generator should use numpy elements, got %s' %
                    (first_value, ))
            shapes.append(v.shape)
            types.append(v.dtype.name)
        self._data_shapes = shapes
        self._data_types = types
        log.info('Dataset `%s` has data_shapes: %s data_types: %s',
                 self.name, repr(shapes), repr(types))

    @property
    def data_shapes(self):
        """Shapes of the fields; inferred from the first element if unset."""
        if self._data_shapes is None:
            self._infer_shapes_and_types()
        return self._data_shapes

    @data_shapes.setter
    def data_shapes(self, val):
        """Set the shapes of the fields."""
        self._data_shapes = val

    @property
    def data_types(self):
        """Types of the fields; inferred from the first element if unset."""
        if self._data_types is None:
            self._infer_shapes_and_types()
        return self._data_types

    @data_types.setter
    def data_types(self, val):
        """Set the types of the fields."""
        self._data_types = val

    def apply(self, transform_func):
        """Apply `transform_func` to this dataset and return the result.

        Args:
            transform_func: callable taking this dataset's generator
                function and returning a new generator function.

        Returns:
            A new dataset of the same class carrying this dataset's name.
            Shapes and types are left unset on the result.
        """
        ret_gen = transform_func(self.generator)
        ret = type(self).from_generator_func(ret_gen)
        if self.name is not None:
            ret.name = self.name
        return ret

    def shuffle(self, buffer_size):
        """Return a dataset shuffled through a buffer of `buffer_size`."""
        func = functools.partial(_shuffle_func, buffer_size=buffer_size)
        return self.apply(func)

    def repeat(self, n=-1):
        """Return a dataset repeated `n` times; negative `n` never ends."""
        func = functools.partial(_repeat_func, n=n)
        return self.apply(func)

    def map(self, fn):
        """Return a dataset with `fn` applied to every element.

        Tuple or list elements are unpacked into positional arguments.
        """
        func = functools.partial(_map_func, fn=fn)
        return self.apply(func)

    def filter(self, fn):
        """Return a dataset keeping the elements selected by `fn`.

        Tuple or list elements are unpacked into positional arguments.
        """
        func = functools.partial(_filter_func, fn=fn)
        return self.apply(func)

    def shard(self, num_shards, index):
        """Return the elements at positions `index, index + num_shards, ...`."""
        func = functools.partial(
            _shard_func, num_shards=num_shards, index=index)
        return self.apply(func)

    def interleave(self, map_fn, cycle_length, block_length):
        """Return a dataset interleaving the streams produced by `map_fn`.

        Args:
            map_fn: callable mapping one element to an iterable of items.
            cycle_length: number of streams read in turn.
            block_length: forwarded to the underlying helper.
        """
        func = functools.partial(
            _interleave_func,
            map_fn=map_fn,
            cycle_length=cycle_length,
            block_length=block_length)
        return self.apply(func)

    def batch(self, batch_size):
        """Return a dataset of batches holding up to `batch_size` samples."""
        func = functools.partial(_batch_func, batch_size=batch_size)
        return self.apply(func)

    def padded_batch(self, batch_size, pad_value=0, max_seqlen=None):
        """Return a dataset of padded batches.

        Args:
            batch_size: maximum number of samples per batch.
            pad_value: fill value, or one fill value per field.
            max_seqlen: target length; `None` uses the longest sample of
                each batch.
        """
        func = functools.partial(
            _padded_batch_func,
            batch_size=batch_size,
            pad_value=pad_value,
            max_seqlen=max_seqlen)
        return self.apply(func)

    def take(self, count=1):
        """Return a dataset with at most the first `count` elements."""
        func = functools.partial(_take_func, count=count)
        return self.apply(func)

    def buffered(self, size=10):
        """Return a dataset read ahead through a buffer of `size` entries."""
        func = functools.partial(_buffered_func, size=size)
        return self.apply(func)

    def chain(self, other):
        """Return a dataset yielding this dataset followed by `other`."""
        func = functools.partial(_chain_func, dataset2=other.generator)
        return self.apply(func)