functional.py 17.7 KB
Newer Older
C
chenxuyi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
C
chenxuyi 已提交
14
"""Basic Dataset API"""
C
chenxuyi 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import sys
import logging
import os
import itertools
import random
import inspect
import multiprocessing
from contextlib import contextmanager
import gzip
import struct
import functools

import six
from six.moves import zip, map, filter
import numpy as np

from propeller.util import map_structure

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = ['Dataset']


@contextmanager
def _open_file(filename, format=None):
    """
    Open `filename` for binary reading and close it when the block exits.

    :param filename: path of the file to open.
    :param format: None for a plain file, 'GZIP' for a gzip-compressed file.

    :returns: context manager yielding the opened binary file object.
    :raises ValueError: if `format` is neither None nor 'GZIP'.
    """
    if format is None:
        fd = open(filename, 'rb')
    elif format == 'GZIP':
        fd = gzip.open(filename, 'rb')
    else:
        raise ValueError('unknown file format %s' % format)
    try:
        yield fd
    finally:
        # Close the file even when the `with` body raises or the generator
        # that holds this context manager is closed before exhaustion.
        fd.close()


C
chenxuyi 已提交
54 55 56
def _open_record(filename):
    """
    Build a generator function reading length-prefixed records from a
    gzip-compressed file.

    Each record is read as a native `int` (struct format 'i') holding the
    payload length, followed by that many payload bytes.

    :param filename: path of the gzip-compressed record file.

    :returns: zero-argument generator function yielding each payload as bytes.
    """
    header_size = struct.calcsize('i')

    def _gen():
        with _open_file(filename, format='GZIP') as f:
            while True:
                header = f.read(header_size)
                if not header:
                    # End of file. A generator must `return` here: under
                    # PEP 479 `raise StopIteration` becomes a RuntimeError.
                    return
                length, = struct.unpack('i', header)
                yield f.read(length)

    return _gen
C
chenxuyi 已提交
66 67


C
chenxuyi 已提交
68 69
def _shuffle_func(dataset, buffer_size):
    """
    Build a generator function that shuffles `dataset` through a
    fixed-size buffer.

    The buffer is first filled with `buffer_size` items. After that, every
    new item read from the source replaces a randomly chosen buffered item,
    which is yielded. When the source is exhausted, the remaining buffered
    items are shuffled and yielded.

    :param dataset: zero-argument callable returning an iterable of items.
    :param buffer_size: number of items held in the buffer, at least 1.

    :returns: zero-argument generator function yielding the shuffled items.
    :raises ValueError: if `buffer_size` is smaller than 1.
    """
    if buffer_size < 1:
        # Fail early with a clear message instead of an opaque
        # `random.randint` range error in the middle of iteration.
        raise ValueError('buffer_size must be >= 1, got %s' %
                         repr(buffer_size))

    def _gen():
        buf = []
        # iter() lets `dataset` return any iterable, not only an iterator.
        iterable = iter(dataset())
        try:
            while len(buf) < buffer_size:
                buf.append(next(iterable))
            while True:
                slot = random.randint(0, buffer_size - 1)
                incoming = next(iterable)
                yield buf[slot]
                buf[slot] = incoming
        except StopIteration:
            # Source exhausted: flush what is left in random order.
            random.shuffle(buf)
            for item in buf:
                yield item

    return _gen
C
chenxuyi 已提交
87 88


M
Meiyim 已提交
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
def _cache_shuffle_shard_func(dataset, num_shards, index, seed, drop_last,
                              repeat):
    """
    Build a generator function that caches `dataset` in memory and, for each
    epoch, shuffles the cached list and yields the items of one shard.

    The shuffle is driven by `np.random.RandomState(seed)`, so every shard
    created with the same seed shuffles the cached list in the same order.

    :param dataset: zero-argument callable returning an iterable of items.
    :param num_shards: total number of shards.
    :param index: index of the shard to yield; items `index`,
        `index + num_shards`, ... of the shuffled list are selected.
    :param seed: seed of the shuffling random state.
    :param drop_last: if true, keep only `len(data) // num_shards` items per
        epoch; otherwise a shard that received one item fewer than the
        others is topped up with one randomly chosen item.
    :param repeat: number of epochs; a negative value repeats forever.

    :returns: zero-argument generator function yielding the shard's items.
    """

    def _gen():
        data_list = list(dataset())
        if not data_list:
            # Nothing to yield; without this guard a negative `repeat`
            # would spin forever without producing anything.
            return
        len_per_shard = len(data_list) // num_shards
        remainder = len(data_list) % num_shards
        rng = np.random.RandomState(seed)
        cnt = 0
        while cnt != repeat:
            cnt += 1
            # `random.shuffle(x, random)` lost its second argument in
            # Python 3.11; shuffle in place with the seeded state instead.
            rng.shuffle(data_list)

            iter_data_list = [
                data_list[i] for i in range(index, len(data_list), num_shards)
            ]

            if drop_last:
                iter_data_list = iter_data_list[:len_per_shard]
            elif 0 < remainder <= index:
                # Drawn from the global `random` module on purpose: using
                # `rng` here would advance it only on some shards and
                # desynchronise the per-epoch shuffles between shards.
                iter_data_list.append(random.choice(data_list))

            for data in iter_data_list:
                yield data

    return _gen


C
chenxuyi 已提交
118 119
def _interleave_func(iterable, map_fn, cycle_length, block_length):
    """
    Build a generator function that interleaves the outputs of `map_fn`.

    The source is split into `cycle_length` streams (item `k` goes to stream
    `k % cycle_length`). Every item of a stream is passed to `map_fn`, whose
    iterable result is flattened into that stream. The streams are then
    consumed round-robin, one element at a time, until all are exhausted.

    :param iterable: zero-argument callable returning the source iterable.
    :param map_fn: callable mapping one source item to an iterable.
    :param cycle_length: number of streams consumed in turn.
    :param block_length: accepted for interface compatibility; not used by
        this implementation.

    :returns: zero-argument generator function yielding interleaved elements.
    """
    # Private fill marker: exhausted streams are padded with it, so that a
    # genuine `None` element produced by `map_fn` is not mistaken for padding.
    missing = object()

    def _gen():
        sources = itertools.tee(iterable(), cycle_length)
        streams = []
        for offset, source in enumerate(sources):
            strided = itertools.islice(source, offset, None, cycle_length)
            streams.append(
                itertools.chain.from_iterable(map_fn(x) for x in strided))

        for row in itertools.zip_longest(*streams, fillvalue=missing):
            for element in row:
                if element is not missing:
                    yield element

    return _gen
C
chenxuyi 已提交
135 136


C
chenxuyi 已提交
137 138
def _repeat_func(dataset, n):
    """
    Build a generator function that repeats `dataset`.

    :param dataset: zero-argument callable returning an iterable of items.
    :param n: number of passes over the dataset. For `n >= 0`, `dataset` is
        called again for every pass. For `n < 0` the dataset is repeated
        forever through `itertools.cycle`, i.e. `dataset` is called once and
        the items of the first pass are kept in memory and replayed.

    :returns: zero-argument generator function yielding the repeated items.
    """

    def _gen():
        if n >= 0:
            # Create each pass lazily rather than building all `n` iterators
            # before the first item is produced.
            passes = itertools.chain.from_iterable(
                dataset() for _ in range(n))
        else:
            passes = itertools.cycle(dataset())

        for item in passes:
            yield item

    return _gen
C
chenxuyi 已提交
152 153


C
chenxuyi 已提交
154 155
def _filter_func(dataset, fn):
    """
    Build a generator function keeping only the items accepted by `fn`.

    Tuple and list items are unpacked into positional arguments of `fn`;
    any other item is passed as a single argument. An item is kept only
    when `fn` returns a boolean true value (Python `True` or a true
    `numpy.bool_`); other truthy results such as `1` do not keep the item.

    :param dataset: zero-argument callable returning an iterable of items.
    :param fn: predicate applied to every item.

    :returns: zero-argument generator function yielding the kept items.
    """

    def _accepted(verdict):
        # Comparisons on NumPy values return `numpy.bool_`, which is never
        # identical to `True`; accept it explicitly so that NumPy-based
        # predicates do not silently drop every item.
        return verdict is True or (isinstance(verdict, np.bool_) and
                                   bool(verdict))

    def _gen():
        for item in dataset():
            if isinstance(item, (tuple, list)):
                verdict = fn(*item)
            else:
                verdict = fn(item)
            if _accepted(verdict):
                yield item

    return _gen
C
chenxuyi 已提交
165 166


C
chenxuyi 已提交
167 168
def _map_func(dataset, fn):
    """
    Build a generator function applying `fn` to every item of `dataset`.

    Tuple and list items are unpacked into positional arguments of `fn`;
    any other item is passed as a single argument.

    :param dataset: zero-argument callable returning an iterable of items.
    :param fn: callable applied to every item.

    :returns: zero-argument generator function yielding the results of `fn`.
    """

    def _gen():
        for item in dataset():
            if isinstance(item, (tuple, list)):
                yield fn(*item)
            else:
                yield fn(item)

    return _gen
C
chenxuyi 已提交
176 177


C
chenxuyi 已提交
178 179
def _shard_func(dataset, num_shards, index):
    """
    Build a generator function yielding every `num_shards`-th item of
    `dataset`, starting at position `index`.

    :param dataset: zero-argument callable returning an iterable of items.
    :param num_shards: total number of shards (the stride between items).
    :param index: position of the first item to yield.

    :returns: zero-argument generator function yielding the shard's items.
    """

    def _gen():
        for item in itertools.islice(dataset(), index, None, num_shards):
            yield item

    return _gen
C
chenxuyi 已提交
186 187


M
Meiyim 已提交
188 189 190 191 192 193 194 195 196 197 198 199 200 201
def _chunk_func(dataset, num_shards):
    """
    Build a generator function that truncates `dataset` to a multiple of
    `num_shards` items.

    Items are read in groups of `num_shards`; a group is yielded only when
    it is complete, so a trailing incomplete group is dropped.

    :param dataset: zero-argument callable returning an iterable of items.
    :param num_shards: size of each group.

    :returns: zero-argument generator function yielding the kept items.
    """

    def _gen():
        # iter() makes sure consecutive islice() calls continue where the
        # previous one stopped, even if `dataset` returns a plain sequence.
        iterable = iter(dataset())
        while True:
            group = list(itertools.islice(iterable, num_shards))
            if not group or len(group) != num_shards:
                # Plain `return`: under PEP 479 `raise StopIteration` inside
                # a generator is turned into a RuntimeError. The emptiness
                # check also stops the loop when `num_shards` is 0.
                return
            for item in group:
                yield item

    return _gen


C
chenxuyi 已提交
202 203
def _take_func(dataset, count):
    """
    Build a generator function yielding at most the first `count` items of
    `dataset`.

    :param dataset: zero-argument callable returning an iterable of items.
    :param count: maximum number of items to yield.

    :returns: zero-argument generator function yielding the leading items.
    """

    def _gen():
        for item in itertools.islice(dataset(), count):
            yield item

    return _gen
C
chenxuyi 已提交
210 211


M
Meiyim 已提交
212 213 214 215 216 217 218 219 220 221 222
def _chain_func(dataset, dataset2):
    """
    Build a generator function yielding all items of `dataset` followed by
    all items of `dataset2`.

    :param dataset: zero-argument callable returning the first iterable.
    :param dataset2: zero-argument callable returning the second iterable.

    :returns: zero-argument generator function yielding the concatenation.
    """

    def _gen():
        # Both sources are created up front, before the first item is read.
        head, tail = dataset(), dataset2()
        for item in itertools.chain(head, tail):
            yield item

    return _gen


C
chenxuyi 已提交
223
def _buffered_func(dataset, size):
    """
    Creates a buffered data reader.

    The buffered data reader reads data entries in a separate daemon
    process and hands them over through a bounded queue. Reading from the
    buffered data reader proceeds as long as the queue is not empty.

    :param dataset: the data reader to read from.
    :type dataset: callable
    :param size: max buffer size.
    :type size: int

    :returns: the buffered data reader.
    :raises RuntimeError: (while iterating) if the reading process failed
        with an exception.
    """
    # Every queue entry is a `(tag, payload)` tuple. A tagged tuple survives
    # pickling through the queue, whereas an instance of a function-local
    # sentinel class can neither be pickled nor compared by identity on the
    # receiving side.
    tag_item, tag_end, tag_failed = 0, 1, 2

    # The worker target is a local function and receives a live iterator;
    # neither can be pickled, so the child has to inherit them via fork.
    # Fall back to the platform default where fork is unavailable.
    try:
        ctx = multiprocessing.get_context('fork')
    except ValueError:
        ctx = multiprocessing

    def _read_worker(r, q):
        try:
            for d in r:
                q.put((tag_item, d))
        except Exception as exc:
            # Report the failure instead of leaving the consumer blocked
            # on an end marker that would never arrive.
            q.put((tag_failed, repr(exc)))
        else:
            q.put((tag_end, None))

    def _data_reader():
        r = dataset()
        q = ctx.Queue(maxsize=size)
        t = ctx.Process(target=_read_worker, args=(r, q))
        t.daemon = True
        t.start()
        try:
            while True:
                tag, payload = q.get()
                if tag == tag_end:
                    break
                if tag == tag_failed:
                    raise RuntimeError('buffered reader worker failed: %s' %
                                       payload)
                yield payload
        finally:
            # Also reached when the consumer stops early: do not leave the
            # worker blocked on a full queue.
            if t.is_alive():
                t.terminate()
            t.join()

    return _data_reader
C
chenxuyi 已提交
264 265


M
Meiyim 已提交
266 267 268 269 270 271 272 273 274 275 276 277 278 279
def _batch_func(dataset, batch_size):
    """
    Build a generator function grouping items of `dataset` into batches.

    Up to `batch_size` consecutive items are collected and transposed, and
    every field is stacked with `np.stack`, so each batch is a list with one
    array per field. The last batch may hold fewer than `batch_size` items.

    :param dataset: zero-argument callable returning an iterable whose items
        are sequences of fields with matching shapes across items.
    :param batch_size: maximum number of items per batch.

    :returns: zero-argument generator function yielding the batches.
    """

    def _gen():
        # iter() makes sure consecutive islice() calls continue where the
        # previous one stopped, even if `dataset` returns a plain sequence.
        iterable = iter(dataset())
        while True:
            rows = list(itertools.islice(iterable, batch_size))
            if not rows:
                # Plain `return`: under PEP 479 `raise StopIteration` inside
                # a generator is turned into a RuntimeError.
                return
            # zip(*rows) transposes items-of-fields into fields-of-items.
            yield [np.stack(column) for column in zip(*rows)]

    return _gen


M
Meiyim 已提交
280 281 282 283 284
def _padded_batch_func(dataset,
                       batch_size,
                       pad_value=0,
                       max_seqlen=None,
                       droplast=False):
    """
    Build a generator function grouping items of `dataset` into batches,
    padding non-scalar fields to a common length.

    Up to `batch_size` consecutive items are collected and transposed. For a
    field whose first value is not a scalar, every value is padded at the
    end with that field's pad value up to the target length, or cut to the
    target length when longer; the field is then stacked with `np.stack`.
    Scalar fields are stacked unchanged.

    :param dataset: zero-argument callable returning an iterable whose items
        are sequences of fields.
    :param batch_size: maximum number of items per batch.
    :type batch_size: int
    :param pad_value: a single pad value used for every field, or a
        list/tuple holding one pad value per field.
    :param max_seqlen: target length; when None, the longest value of the
        field within the batch is used.
    :param droplast: if true, a final batch with fewer than `batch_size`
        items is dropped.

    :returns: zero-argument generator function yielding the batches, each a
        list with one array per field.
    :raises ValueError: if `batch_size` is not an int.
    """
    if not isinstance(batch_size, int):
        raise ValueError('unknown batch_size: %s' % repr(batch_size))

    def _fit(value, length, fill):
        # Pad at the end with `fill`, or truncate when longer than `length`.
        if length >= len(value):
            return np.pad(value, [0, length - len(value)],
                          'constant',
                          constant_values=fill)
        return value[:length]

    def _gen():
        # iter() makes sure consecutive islice() calls continue where the
        # previous one stopped, even if `dataset` returns a plain sequence.
        iterable = iter(dataset())
        pad_values = pad_value
        while True:
            rows = list(itertools.islice(iterable, batch_size))
            if not rows or (droplast and len(rows) != batch_size):
                # Plain `return`: under PEP 479 `raise StopIteration` inside
                # a generator is turned into a RuntimeError.
                return
            columns = list(zip(*rows))  # transpose
            if not isinstance(pad_values, (list, tuple)):
                # A single pad value applies to every field.
                pad_values = [pad_values] * len(columns)
            assert len(columns) == len(
                pad_values), 'pad_value [%d] != element size[%d]' % (
                    len(pad_values), len(columns))
            padded = []
            for column, fill in zip(columns, pad_values):
                first = column[0]
                # np.shape also works for plain Python sequences, which
                # have no `.shape` attribute.
                if (not np.isscalar(first)) and np.shape(first) != ():
                    if max_seqlen is None:
                        target = max(len(value) for value in column)
                    else:
                        target = max_seqlen
                    column = [_fit(value, target, fill) for value in column]
                padded.append(np.stack(list(column)))
            yield padded

    return _gen
C
chenxuyi 已提交
323 324


M
Meiyim 已提交
325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359
def flatten(structure):
    """
    Split a nested structure of numpy arrays into a flat list of arrays and
    a schema describing where each array came from.

    The schema mirrors the nesting of `structure` (lists, tuples and dicts)
    with every array replaced by its position in the flat list. Dict
    entries are visited in sorted key order.

    :param structure: a numpy array, or a list/tuple/dict nesting of arrays.

    :returns: tuple `(arrays, schema)`.
    :raises TypeError: if a value is not an array, list, tuple or dict.
    """
    leaves = []

    def _index(node):
        if isinstance(node, np.ndarray):
            leaves.append(node)
            return len(leaves) - 1
        if isinstance(node, list):
            return [_index(child) for child in node]
        if isinstance(node, tuple):
            return tuple(_index(child) for child in node)
        if isinstance(node, dict):
            return {key: _index(node[key]) for key in sorted(node.keys())}
        raise TypeError

    schema = _index(structure)
    return leaves, schema


def unflatten(flt, schema):
    """
    Rebuild a nested structure from a flat sequence and a schema.

    Every int in `schema` is replaced by `flt[int]`; lists, tuples and dicts
    of the schema are rebuilt with the same nesting. Dict entries are
    visited in sorted key order.

    :param flt: flat sequence of values, indexed by the ints in `schema`.
    :param schema: an int, or a list/tuple/dict nesting of ints.

    :returns: the structure described by `schema`, filled with values of
        `flt`.
    :raises TypeError: if a schema node is not an int, list, tuple or dict.
    """

    def _restore(node):
        if isinstance(node, int):
            return flt[node]
        if isinstance(node, list):
            return [_restore(child) for child in node]
        if isinstance(node, tuple):
            return tuple(_restore(child) for child in node)
        if isinstance(node, dict):
            return {key: _restore(node[key]) for key in sorted(node.keys())}
        raise TypeError

    return _restore(schema)


C
chenxuyi 已提交
360
class Dataset(object):
    """Python Wrapper for PyReader

    A Dataset holds a zero-argument generator function (`generator`) plus
    optional `name`, `data_shapes`, `data_types` and `data_schema`
    descriptions of the yielded items. Every transformation method returns
    a new Dataset wrapping a transformed generator function; the source
    Dataset is left untouched.
    """

    @classmethod
    def from_generator_func(cls, _gen, data_shapes=None, data_types=None):
        """
        Create a Dataset from a zero-argument generator function.

        :param _gen: zero-argument callable returning an iterable of items.
        :param data_shapes: shapes description stored on the Dataset.
        :param data_types: types description stored on the Dataset.

        :returns: the new Dataset.
        """

        def _wrapper():  # compat to py3.7
            try:
                for item in _gen():
                    yield item
            except RuntimeError as e:
                # Under PEP 479 a StopIteration raised inside a generator
                # surfaces as this RuntimeError; treat it as a normal end.
                if str(e) != 'generator raised StopIteration':
                    raise

        ret = cls()
        ret.generator = _wrapper
        ret.data_shapes = data_shapes
        ret.data_types = data_types
        return ret

    @classmethod
    def from_file(cls, filename, format=None):
        """
        Create a Dataset yielding the lines of a file as bytes.

        :param filename: path of the file to read.
        :param format: None for a plain file, 'GZIP' for a gzip file.

        :returns: the new Dataset, with `data_shapes` set to `[]` and
            `data_types` set to `str`.
        :raises RuntimeError: if the file is empty.
        """
        if os.path.getsize(filename) == 0:
            raise RuntimeError('%s is empty' % filename)

        def _gen():
            with _open_file(filename, format) as f:
                for line in f:
                    yield line

        ret = cls()
        ret.generator = _gen
        ret.data_shapes = []
        ret.data_types = str
        return ret

    @classmethod
    def from_record_file(cls, filename):
        """
        Create a Dataset yielding the records of a gzip-compressed record
        file, as read by `_open_record`.

        :param filename: path of the record file.

        :returns: the new Dataset, with `data_shapes` set to `[]` and
            `data_types` set to `str`.
        :raises RuntimeError: if the file is empty.
        """
        if os.path.getsize(filename) == 0:
            raise RuntimeError('%s is empty' % filename)
        ret = cls()
        ret.generator = _open_record(filename)
        ret.data_shapes = []
        ret.data_types = str
        return ret

    @classmethod
    def from_list(cls, ls):
        """
        Create a Dataset yielding the elements of a list.

        :param ls: the list to iterate.

        :returns: the new Dataset, with `data_shapes` set to `[]` and
            `data_types` set to `str`.
        :raises ValueError: if `ls` is not a list.
        """
        if not isinstance(ls, list):
            raise ValueError('expect list, got %s' % repr(ls))

        def _gen():
            for i in ls:
                yield i

        ret = cls()
        ret.generator = _gen
        ret.data_shapes = []
        ret.data_types = str
        return ret

    def __init__(self):
        self.name = None
        self._data_shapes = None
        self._data_types = None
        self._data_schema = None
        self.generator = None
        self.pyreader = None

    def __repr__(self):
        return 'Dataset: name: %s, data_shapes %s, data_types %s' % (
            self.name, self._data_shapes, self._data_types)

    def __eq__(self, other):
        """Datasets are equal when name, data shapes and data types match."""
        if not isinstance(other, Dataset):
            # Let Python fall back to its default comparison instead of
            # failing with AttributeError on unrelated objects.
            return NotImplemented
        return self.name == other.name and \
               self._data_shapes == other._data_shapes and \
               self._data_types == other._data_types

    def __iter__(self):
        return self.generator()

    def _infer_shapes_and_types_and_schema(self):
        """
        Infer `_data_shapes`, `_data_types` and `_data_schema` from the
        first item produced by the generator.

        The first item is read and flattened with `flatten`. The generator
        is then replaced by a wrapper whose first call replays that item in
        front of the already started iterator, so the item is not lost;
        later calls start the original generator afresh.

        :raises ValueError: if `generator` or `name` is not set, or if a
            flattened element is not a numpy array.
        :raises StopIteration: if the generator yields nothing.
        """
        if self.generator is None or self.name is None:
            raise ValueError(
                'Try to infer data shapes or types from incomplete Dataset')

        log.info('Try to infer data shapes & types from generator')
        first_gen = self.generator()
        first_value = next(first_gen)
        # Keep `first_value` itself untouched: it is replayed below and must
        # have the same structure as every other item of the dataset.
        flat_values, self._data_schema = flatten(first_value)
        shapes, types = [], []
        for v in flat_values:
            if not isinstance(v, np.ndarray):
                raise ValueError(
                    'dataset generator should use numpy elements, got %s' %
                    (flat_values, ))
            # use black magic to keep the same dataset shape: a dimension
            # larger than 1 is recorded as 2, any other dimension as 1.
            shapes.append([(i > 1) + 1 for i in v.shape])
            types.append(v.dtype.name)
        self._data_shapes = shapes
        self._data_types = types
        log.info('Dataset `%s` has data_shapes: %s data_types: %s',
                 self.name, repr(shapes), repr(types))
        original_generator = self.generator
        self.is_first_call = True

        def _gen():
            if self.is_first_call:
                self.is_first_call = False
                generator = itertools.chain([first_value], first_gen)
            else:
                generator = original_generator()
            yield from generator

        self.generator = _gen

    @property
    def data_shapes(self):
        """Shapes of the dataset elements, inferred on first access when
        they have not been set."""
        if self._data_shapes is None:
            self._infer_shapes_and_types_and_schema()
        return self._data_shapes

    @data_shapes.setter
    def data_shapes(self, val):
        """Set the shapes of the dataset elements."""
        self._data_shapes = val

    @property
    def data_schema(self):
        """Schema of the dataset elements.

        When it has not been set, it is `[0, 1, ...]` (one index per entry
        of the data shapes) if both shapes and types are known; otherwise it
        is inferred from the generator.
        """
        if self._data_schema is None:
            if self._data_shapes is not None and self._data_types is not None:
                self._data_schema = list(range(len(self._data_shapes)))
            else:
                self._infer_shapes_and_types_and_schema()
        return self._data_schema

    @data_schema.setter
    def data_schema(self, val):
        """Set the schema of the dataset elements."""
        self._data_schema = val

    @property
    def data_types(self):
        """Types of the dataset elements, inferred on first access when
        they have not been set."""
        if self._data_types is None:
            self._infer_shapes_and_types_and_schema()
        return self._data_types

    @data_types.setter
    def data_types(self, val):
        """Set the types of the dataset elements."""
        self._data_types = val

    def apply(self, transform_func):
        """
        apply transform func to datasets

        :param transform_func: callable taking this Dataset's generator
            function and returning a new zero-argument generator function.

        :returns: a new Dataset of the same class wrapping the transformed
            generator; it inherits `name` when that is set.
        """
        ret_gen = transform_func(self.generator)
        ret = type(self).from_generator_func(ret_gen)
        if self.name is not None:
            ret.name = self.name
        return ret

    def shuffle(self, buffer_size):
        """Return a Dataset shuffled through a buffer of `buffer_size`
        items (see `_shuffle_func`)."""
        func = functools.partial(_shuffle_func, buffer_size=buffer_size)
        return self.apply(func)

    def repeat(self, n=-1):
        """Return a Dataset repeated `n` times; the default `n=-1` repeats
        forever (see `_repeat_func`)."""
        func = functools.partial(_repeat_func, n=n)
        return self.apply(func)

    def map(self, fn):
        """Return a Dataset with `fn` applied to every item; tuple and list
        items are unpacked into arguments (see `_map_func`)."""
        func = functools.partial(_map_func, fn=fn)
        return self.apply(func)

    def filter(self, fn):
        """Return a Dataset keeping the items accepted by the predicate
        `fn` (see `_filter_func`)."""
        func = functools.partial(_filter_func, fn=fn)
        return self.apply(func)

    def shard(self, num_shards, index):
        """Return a Dataset holding every `num_shards`-th item, starting at
        position `index` (see `_shard_func`)."""
        func = functools.partial(
            _shard_func, num_shards=num_shards, index=index)
        return self.apply(func)

    def chunk(self, num_shards):
        """Return a Dataset processed in groups of `num_shards` items
        (see `_chunk_func`)."""
        func = functools.partial(_chunk_func, num_shards=num_shards)
        return self.apply(func)

    def interleave(self, map_fn, cycle_length, block_length):
        """Return a Dataset interleaving the iterables produced by `map_fn`
        over `cycle_length` streams (see `_interleave_func`)."""
        func = functools.partial(
            _interleave_func,
            map_fn=map_fn,
            cycle_length=cycle_length,
            block_length=block_length)
        return self.apply(func)

    def batch(self, batch_size):
        """Return a Dataset of stacked batches of up to `batch_size` items
        (see `_batch_func`)."""
        func = functools.partial(_batch_func, batch_size=batch_size)
        return self.apply(func)

    def padded_batch(self,
                     batch_size,
                     pad_value=0,
                     max_seqlen=None,
                     droplast=False):
        """Return a Dataset of padded batches of up to `batch_size` items
        (see `_padded_batch_func`)."""
        func = functools.partial(
            _padded_batch_func,
            batch_size=batch_size,
            pad_value=pad_value,
            max_seqlen=max_seqlen,
            droplast=droplast)
        return self.apply(func)

    def take(self, count=1):
        """Return a Dataset limited to its first `count` items
        (see `_take_func`)."""
        func = functools.partial(_take_func, count=count)
        return self.apply(func)

    def buffered(self, size=10):
        """Return a Dataset read ahead through a buffer of at most `size`
        items (see `_buffered_func`)."""
        func = functools.partial(_buffered_func, size=size)
        return self.apply(func)

    def chain(self, other):
        """Return a Dataset yielding this Dataset's items followed by the
        items of `other` (see `_chain_func`)."""
        func = functools.partial(_chain_func, dataset2=other.generator)
        return self.apply(func)

    def cache_shuffle_shard(self,
                            num_shards,
                            index,
                            seed=0,
                            drop_last=True,
                            repeat=-1):
        """Return a Dataset that caches all items in memory, then shuffles
        and shards them for each epoch (see `_cache_shuffle_shard_func`)."""
        func = functools.partial(
            _cache_shuffle_shard_func,
            num_shards=num_shards,
            index=index,
            seed=seed,
            repeat=repeat,
            drop_last=drop_last, )

        return self.apply(func)