"""Contains data helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import codecs
import json
import os
import subprocess
import tarfile
import time
from multiprocessing import Manager, Process
from Queue import Queue
from threading import Thread

from paddle.v2.dataset.common import md5file


def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file.

    Instances with durations outside [min_duration, max_duration] will be
    filtered out.

    :param manifest_path: Manifest file to load and parse.
    :type manifest_path: basestring
    :param max_duration: Maximal duration in seconds for instance filter.
    :type max_duration: float
    :param min_duration: Minimal duration in seconds for instance filter.
    :type min_duration: float
    :return: Manifest parsing results. List of dict.
    :rtype: list
    :raises IOError: If failed to parse the manifest.
    """
    manifest = []
    # Context manager guarantees the file handle is closed even when a
    # malformed line raises mid-iteration (the original leaked it).
    with codecs.open(manifest_path, 'r', 'utf-8') as manifest_file:
        for json_line in manifest_file:
            try:
                json_data = json.loads(json_line)
            except Exception as e:
                raise IOError("Error reading manifest: %s" % str(e))
            # Keep only instances whose duration lies in the allowed range.
            if min_duration <= json_data["duration"] <= max_duration:
                manifest.append(json_data)
    return manifest


def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum.

    If a file with a matching md5sum already exists, the download is skipped.

    :param url: URL to download from (fetched with wget, resume enabled).
    :type url: basestring
    :param md5sum: Expected md5 checksum of the downloaded file.
    :type md5sum: basestring
    :param target_dir: Directory to store the downloaded file in.
    :type target_dir: basestring
    :return: Path of the downloaded (or already present) file.
    :rtype: basestring
    :raises RuntimeError: If the download fails or the checksum mismatches.
    """
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, url.split("/")[-1])
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        # Pass an argument list instead of a concatenated shell string so a
        # crafted url cannot inject shell commands, and so failures surface
        # through the return code instead of being silently ignored.
        ret = subprocess.call(["wget", "-c", url, "-P", target_dir])
        if ret != 0 or not os.path.exists(filepath):
            raise RuntimeError("Download of %s failed." % url)
        print("\nMD5 Checksum %s ..." % filepath)
        if md5file(filepath) != md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath


def unpack(filepath, target_dir, rm_tar=False):
    """Unpack the tar file to the target_dir.

    :param filepath: Path of the tar archive to unpack.
    :type filepath: basestring
    :param target_dir: Directory to extract the archive into.
    :type target_dir: basestring
    :param rm_tar: Remove the tar file after extraction if True.
    :type rm_tar: bool
    """
    print("Unpacking %s ..." % filepath)
    # Context manager closes the archive even if extraction raises
    # (the original left the handle open on error).
    with tarfile.open(filepath) as tar:
        tar.extractall(target_dir)
    if rm_tar:
        os.remove(filepath)


class XmapEndSignal():
    """Sentinel object put on a queue to signal that no more samples follow."""
    pass


def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
    """A multiprocessing pipeline wrapper for the data reader.

    Samples from ``reader`` flow through one reader process, ``process_num``
    mapper processes, and a flushing thread before being yielded.

    :param mapper:  Function to map sample.
    :type mapper: callable
    :param reader: Given data reader.
    :type reader: callable
    :param process_num: Number of processes in the pipeline
    :type process_num: int
    :param buffer_size: Maximal buffer size.
    :type buffer_size: int
    :param order: Preserve the order of samples from the given reader.
    :type order: bool
    :return: A tuple of the wrapped reader and a cleanup function that
             hard-exits the interpreter to tear down all worker processes.
    :rtype: tuple
    """
    end_flag = XmapEndSignal()

    # worker: feed samples from reader into in_queue, then signal completion
    def read_worker(reader, in_queue):
        for sample in reader():
            in_queue.put(sample)
        in_queue.put(end_flag)

    # worker: like read_worker, but tag each sample with its order id so
    # downstream workers can restore the original order
    def order_read_worker(reader, in_queue):
        for order_id, sample in enumerate(reader()):
            in_queue.put((order_id, sample))
        in_queue.put(end_flag)

    # worker: apply mapper to samples from in_queue and put results
    # into out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, XmapEndSignal):
            out_queue.put(mapper(sample))
            sample = in_queue.get()
        # re-insert the end flag so sibling handle workers also terminate,
        # and emit one end flag per worker for flush_worker to count
        in_queue.put(end_flag)
        out_queue.put(end_flag)

    # worker: apply mapper and emit results in original order; each worker
    # busy-waits (1 ms polls) until out_order[0] reaches its sample's id
    def order_handle_worker(in_queue, out_queue, mapper, out_order):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            order_id, sample = ins
            result = mapper(sample)
            while order_id != out_order[0]:
                time.sleep(0.001)
            out_queue.put(result)
            out_order[0] += 1
            ins = in_queue.get()
        in_queue.put(end_flag)
        out_queue.put(end_flag)

    # thread worker: drain the (slow, proxied) Manager.Queue into a local
    # Queue for faster consumption; done after one end flag per mapper process
    def flush_worker(in_queue, out_queue):
        finish = 0
        while finish < process_num:
            sample = in_queue.get()
            if isinstance(sample, XmapEndSignal):
                finish += 1
            else:
                out_queue.put(sample)
        out_queue.put(end_flag)

    def cleanup():
        # kill all sub process and threads
        # NOTE: os._exit terminates the interpreter immediately without
        # running any cleanup handlers -- the daemonized workers die with it
        os._exit(0)

    def xreader():
        # prepare shared memory
        manager = Manager()
        in_queue = manager.Queue(buffer_size)
        out_queue = manager.Queue(buffer_size)
        out_order = manager.list([0])

        # start a read worker in a process
        target = order_read_worker if order else read_worker
        p = Process(target=target, args=(reader, in_queue))
        p.daemon = True
        p.start()

        # start handle_workers with multiple processes
        target = order_handle_worker if order else handle_worker
        args = (in_queue, out_queue, mapper, out_order) if order else (
            in_queue, out_queue, mapper)
        workers = [
            Process(target=target, args=args) for _ in xrange(process_num)
        ]
        for w in workers:
            w.daemon = True
            w.start()

        # start a thread to read data from slow Manager.Queue
        flush_queue = Queue(buffer_size)
        t = Thread(target=flush_worker, args=(out_queue, flush_queue))
        t.daemon = True
        t.start()

        # get results
        sample = flush_queue.get()
        while not isinstance(sample, XmapEndSignal):
            yield sample
            sample = flush_queue.get()

    return xreader, cleanup