# reader.py

from . import gpu_dev_count, cpu_dev_count
import Queue
from threading import Thread

dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count

def yield_pieces(data, distribute_strategy, batch_size):
    """Slice one multi-tensor batch into `dev_count` per-device pieces.

    Args:
        data: list or dict of tensors making up one batch.
        distribute_strategy: structure parallel to `data` (same type, same
            length/keys); each entry is one of
            's'/'split'   - slice the tensor along dim 0 into device shards,
            'u'/'unstack' - give device i the i-th element along dim 0,
            'c'/'copy'    - replicate the tensor to every device.
        batch_size: total batch size; must be divisible by `dev_count`.

    Yields:
        One piece per device, in the same container type as `data`
        (a dict with the same keys, or a list in the same order).
        Stops early if there are not enough examples for every device.
    """
    assert batch_size % dev_count == 0, "batch_size need to be integer times larger than dev_count."
    assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)]
    assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)]

    if isinstance(data, dict):
        keys = list(data.keys())
        data_list = [data[k] for k in keys]
        ds_list = [distribute_strategy[k] for k in keys]
    else:
        assert isinstance(data, list), "the input data must be a list or dict, and contained with multiple tensors."
        data_list = data
        ds_list = distribute_strategy

    # Each device consumes a contiguous slice of `stride` examples;
    # p is the (exclusive) end index of the current device's slice.
    stride = batch_size // dev_count
    p = stride
    while p <= batch_size:
        temp = []
        for d, s in zip(data_list, ds_list):
            s = s.strip().lower()
            if s == 's' or s == 'split':
                if p - stride >= len(d):
                    # no more examples to feed the remaining devices
                    return
                temp.append(d[p - stride:p])
            elif s == 'u' or s == 'unstack':
                assert len(d) <= dev_count, 'Tensor size on dim 0 must be less equal to dev_count when unstack is applied.'
                if p // stride > len(d):
                    # no more examples to feed the remaining devices
                    return
                temp.append(d[p // stride - 1])
            elif s == 'c' or s == 'copy':
                temp.append(d)
            else:
                raise NotImplementedError()

        p += stride
        if isinstance(data, dict):
            yield dict(zip(keys, temp))
        else:
            yield temp

def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
    """Wrap `reader` with a background thread that prefetches and groups
    batches into multi-device steps.

    Args:
        reader: a zero-argument callable returning an iterable of batches.
        postprocess_fn: optional callable applied to each device batch
            before yielding; defaults to the identity.
        prefetch_steps: how many multi-device steps to buffer ahead.

    Yields:
        (batch_buf, flag_buf): a list of `dev_count` batches and a parallel
        list of bools; False marks padded (duplicated) batches whose
        outputs should be discarded downstream.
    """
    # Py2/Py3 compatibility: the module-level `import Queue` only works on
    # Python 2, so resolve the stdlib queue module locally.
    try:
        import queue as _queue_mod  # Python 3
    except ImportError:
        import Queue as _queue_mod  # Python 2

    if postprocess_fn is None:
        def postprocess_fn(batch):
            return batch

    def worker(reader, dev_count, queue):
        # Accumulate one batch per device, then hand the group to the queue.
        dev_batches = []
        for data in reader():
            dev_batches.append(data)
            if len(dev_batches) == dev_count:
                queue.put((dev_batches, 0))
                dev_batches = []
        # For the prediction of the remained batches, pad more batches to
        # the number of devices by duplicating the last real batch; the
        # padded samples are removed from prediction outputs downstream.
        if dev_batches:
            num_pad = dev_count - len(dev_batches)
            while len(dev_batches) < dev_count:
                dev_batches.append(dev_batches[-1])
            queue.put((dev_batches, num_pad))
        queue.put(None)  # sentinel: no more data

    queue = _queue_mod.Queue(dev_count * prefetch_steps)
    p = Thread(
        target=worker, args=(reader, dev_count, queue))
    p.daemon = True  # don't block interpreter exit on the prefetch thread
    p.start()
    while True:
        ret = queue.get()
        queue.task_done()
        if ret is None:
            break
        batches, num_pad = ret
        batch_buf = []
        flag_buf = []
        for idx, batch in enumerate(batches):
            # The last `num_pad` entries are duplicates; flag them False.
            flag = idx < len(batches) - num_pad
            batch_buf.append(postprocess_fn(batch))
            flag_buf.append(flag)
        yield batch_buf, flag_buf
    queue.join()




def decode_fake(nums, mask, bs):
    """Compute how many of `nums` output examples came from padded batches.

    `data_feeder` pads the final multi-device step by duplicating the last
    real batch; `mask` is the flag list it yields (a True prefix for real
    batches, then False for each padded duplicate).

    Args:
        nums: total number of output examples across all devices.
        mask: per-device flags; True prefix = real, False suffix = padded.
        bs: batch size per device for the full (non-final) batches.

    Returns:
        int: the number of padded examples to strip from the outputs.
    """
    # Count the leading True flags: the real (non-padded) device batches.
    n_t = 0
    for flag in mask:
        if not flag:
            break
        n_t = n_t + 1

    n_f = len(mask) - n_t  # number of padded duplicate batches
    # Examples left after the first n_t-1 full batches are shared equally
    # by the last real batch and its n_f duplicates.
    remainder = nums - (n_t - 1) * bs
    # BUGFIX: use floor division — the original `/` returned a float under
    # Python 3 where an integer example count is required.
    each_f = remainder // (n_f + 1)
    return each_f * n_f