data_utils.py 3.2 KB
Newer Older
R
Rosun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
"""
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""

import multiprocessing
import threading
import time
import traceback

import numpy as np

try:
    import queue
except ImportError:
    import Queue as queue


class GeneratorEnqueuer(object):
    """
    Run one worker process per generator that pushes the generator's items
    into a shared queue (``self.queue``) for a consumer to read.

    Despite the attribute name ``_threads``, workers are
    ``multiprocessing.Process`` objects and the queue is served by a
    ``multiprocessing.Manager``.

    Args:
        generators (list): iterators/generators; worker ``pid`` reads from
            ``generators[pid]``, so ``start(workers=n)`` needs at least
            ``n`` entries.
        wait_time (float): time to sleep while the queue is full, and the
            ``put()`` timeout used while re-checking for a stop request.
    """

    def __init__(self, generators, wait_time=0.05):
        self.wait_time = wait_time
        self._generators = generators
        self._threads = []  # worker processes (name kept for compatibility)
        self._stop_events = []  # one event per started worker, index == pid
        self.queue = None  # manager-backed queue; None when not started
        self._manager = None
        self.workers = 1

    def start(self, workers=1, max_queue_size=16):
        """
        Start worker processes which move data from the generators into
        the queue.

        Worker ``pid`` (0 <= pid < workers) reads from ``generators[pid]``.

        NOTE: the worker target and the generators are handed to the child
        processes without pickling (generators cannot be pickled), which
        only works with the 'fork' start method.

        Args:
            workers (int): number of worker processes
            max_queue_size (int): queue capacity; workers pause while the
                queue holds this many items
        """
        self.workers = workers

        try:
            self._manager = multiprocessing.Manager()
            self.queue = self._manager.Queue(maxsize=max_queue_size)
            for pid in range(self.workers):
                self._stop_events.append(multiprocessing.Event())
                process = multiprocessing.Process(
                    target=self._worker_loop, args=(pid, max_queue_size))
                process.daemon = True
                self._threads.append(process)
                process.start()
        except BaseException:
            # Tear down whatever was created before re-raising; stop()
            # copes with a partially started state.
            self.stop()
            raise

    def _worker_loop(self, pid, max_queue_size):
        """
        Body of worker ``pid``: move items from ``generators[pid]`` into
        the queue until asked to stop or the generator ends/raises.

        The worker's stop event is always set on exit. ``is_running()``
        uses it to tell that this worker produces no more items.
        """
        stop_event = self._stop_events[pid]
        try:
            while not stop_event.is_set():
                if self.queue.qsize() >= max_queue_size:
                    time.sleep(self.wait_time)
                    continue
                item = next(self._generators[pid])
                self._put_until_stopped(item, stop_event)
        except StopIteration:
            pass  # generator exhausted: normal end of this worker's data
        except Exception:
            # Previously swallowed silently; surface it so a failing
            # generator does not just look like an early end of data.
            # Errors after a stop request (e.g. the manager going away
            # during shutdown) are expected and not reported.
            if not stop_event.is_set():
                traceback.print_exc()
        finally:
            stop_event.set()

    def _put_until_stopped(self, item, stop_event):
        """Put ``item`` on the queue, giving up once a stop is requested."""
        # The qsize() check in _worker_loop is racy across workers, so the
        # queue can still be full here; a plain blocking put() could hang
        # forever and deadlock stop(). Retry with a timeout instead.
        while not stop_event.is_set():
            try:
                self.queue.put(item, timeout=self.wait_time)
                return
            except queue.Full:
                continue

    def is_running(self):
        """
        Returns:
            bool: Whether any worker is still producing or the queue still
            holds items for the consumer. False before ``start()`` and
            after ``stop()``.
        """
        if self.queue is None:
            return False

        # Check the workers BEFORE the queue. A worker sets its stop event
        # only after its last put() returned, so if every event is set
        # here, the emptiness check below sees every item produced. The
        # reverse order can miss an item put between the two checks.
        workers_alive = any(not event.is_set() for event in self._stop_events)
        return workers_alive or not self.queue.empty()

    def stop(self, timeout=None):
        """
        Ask all workers to stop, wait for them to exit, and release the
        queue.

        Safe to call before ``start()``, after a failed ``start()``, and
        more than once. Should be called by the same thread which called
        `start()`.

        Args:
            timeout (float|None): maximum time to wait on each worker's
                ``join()``; None waits until each worker exits.
        """
        # Iterate the events that actually exist rather than
        # range(self.workers): a start() that failed part-way created
        # fewer events than `workers`.
        for stop_event in self._stop_events:
            stop_event.set()

        for process in self._threads:
            if process.is_alive():
                process.join(timeout)
        if self._manager is not None:
            # self.queue is served by the manager, so it is unusable after
            # this (including for any worker still alive after the join
            # timeout).
            self._manager.shutdown()

        self._manager = None
        self._threads = []
        self._stop_events = []
        self.queue = None