mix.py 2.3 KB
Newer Older
H
huangyuxin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#

"""Classes for mixing samples from multiple sources."""

import itertools, os, random, time, sys
from functools import reduce, wraps

import numpy as np

from . import autodecode, utils
from .paddle_utils import PaddleTensor, IterableDataset
from .utils import PipelineStage


def round_robin_shortest(*sources):
    """Interleave samples from ``sources``, stopping at the first exhausted one.

    Sources are visited cyclically, one sample at a time. As soon as the
    source whose turn it is raises ``StopIteration``, the whole generator
    stops (samples already yielded from earlier sources in that round are
    kept).

    Args:
        *sources: iterators to draw from. Passing no sources yields nothing.

    Yields:
        One sample from each source in turn.
    """
    # cycle() over an empty tuple yields nothing, so no-source calls simply
    # produce an empty stream instead of dividing by zero.
    for source in itertools.cycle(sources):
        try:
            sample = next(source)
        except StopIteration:
            return
        yield sample


def round_robin_longest(*sources):
    """Interleave samples from ``sources`` until every source is exhausted.

    Sources are visited in rotation, one sample each. A source that raises
    ``StopIteration`` is dropped and the remaining sources keep rotating.

    Args:
        *sources: iterators to draw from. Passing no sources yields nothing.

    Yields:
        The next sample of each still-live source, in rotation.
    """
    # *sources arrives as a tuple; copy it into a list so exhausted entries
    # can be deleted.
    live = list(sources)
    i = 0
    while live:
        # Wrap around at the end of a round; also re-normalizes i after a
        # deletion has shrunk the list.
        i %= len(live)
        try:
            sample = next(live[i])
        except StopIteration:
            # Do not advance i: the following source has shifted into slot i.
            del live[i]
            continue
        i += 1
        yield sample


class RoundRobin(IterableDataset):
    """Dataset that interleaves samples from several datasets in turn.

    Args:
        datasets: iterables to draw from; ``iter`` is called on each of them
            every time this object is iterated.
        longest: if true, iteration is delegated to ``round_robin_longest``;
            otherwise to ``round_robin_shortest``, which stops as soon as any
            source is exhausted.
    """

    def __init__(self, datasets, longest=False):
        self.datasets = datasets
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        iterators = list(map(iter, self.datasets))
        interleave = round_robin_longest if self.longest else round_robin_shortest
        return interleave(*iterators)


def random_samples(sources, probs=None, longest=False):
    """Yield samples drawn from randomly chosen sources.

    Each draw consumes one ``random.random()`` value, picks a source with
    probability proportional to its weight, and yields that source's next
    sample.

    Args:
        sources: iterators to draw from. Any sequence (list or tuple) is
            accepted; it is copied, so the caller's container is not modified.
        probs: optional non-negative weights, one per source. ``None`` means
            equal weights. Sources with zero weight are never drawn.
        longest: if False, stop as soon as the chosen source is exhausted.
            If True, drop exhausted sources and keep drawing from the rest
            until no source with positive weight remains.

    Raises:
        ValueError: if ``probs`` and ``sources`` differ in length (raised on
            the first ``next()`` call, since this is a generator).
    """
    sources = list(sources)
    probs = [1] * len(sources) if probs is None else list(probs)
    if len(probs) != len(sources):
        raise ValueError(
            f"got {len(probs)} probabilities for {len(sources)} sources")
    while sources:
        # Recomputed only when the set of sources changes, not on every draw.
        cum = np.cumsum(np.asarray(probs, dtype=float))
        total = cum[-1]
        if not total > 0:
            # Only zero-weight sources remain (or weights are invalid/NaN);
            # none of them can be drawn.
            return
        # Dividing by cum[-1] itself makes the last entry exactly 1.0, so a
        # draw r in [0, 1) can never land past the last source.
        cum /= total
        while True:
            # side="right" gives zero-weight sources (zero-width intervals)
            # no probability mass, even for r == 0.0.
            i = int(np.searchsorted(cum, random.random(), side="right"))
            try:
                sample = next(sources[i])
            except StopIteration:
                break
            yield sample
        if not longest:
            return
        del sources[i]
        del probs[i]


class RandomMix(IterableDataset):
    """Dataset that draws each sample from a randomly chosen dataset.

    Args:
        datasets: iterables to sample from; ``iter`` is called on each of them
            every time this object is iterated.
        probs: optional relative weights, one per dataset, forwarded to
            ``random_samples``; ``None`` means equal weights.
        longest: if true, exhausted datasets are dropped and sampling
            continues from the rest; otherwise iteration stops at the first
            exhausted dataset.
    """

    def __init__(self, datasets, probs=None, longest=False):
        self.datasets = datasets
        self.probs = probs
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        iterators = list(map(iter, self.datasets))
        weights, keep_going = self.probs, self.longest
        return random_samples(iterators, weights, longest=keep_going)