data_feeder.py 3.4 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5
from __future__ import print_function
import core
import numpy
import six.moves as six

F
fengjiayi 已提交
6
from framework import Variable, default_main_program
Y
Yu Yang 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55

# Names exported by a star-import of this module: only DataFeeder is listed,
# so DataToLoDTensorConverter is not part of the star-import surface.
__all__ = ['DataFeeder']


class DataToLoDTensorConverter(object):
    """Accumulate nested Python samples and pack them into one LoDTensor.

    Samples are added one at a time with ``feed``; ``done`` then builds the
    tensor from everything fed so far. When ``lod_level > 0`` every sample is
    a nested sequence, and the converter records one list of cumulative
    offsets per nesting level (each list starts at 0).
    """

    def __init__(self, place, lod_level, shape, dtype):
        """Set up an empty converter.

        Args:
            place: handed unchanged to ``LoDTensor.set`` in ``done``.
            lod_level: number of sequence-nesting levels in each sample.
            shape: shape the accumulated data is reshaped to in ``done``.
            dtype: a ``core.DataType`` value; FP32, INT64, FP64 or INT32.

        Raises:
            ValueError: if ``dtype`` is not one of the four supported types.
        """
        self.place = place
        self.lod_level = lod_level
        self.shape = shape
        if dtype == core.DataType.FP32:
            self.dtype = 'float32'
        elif dtype == core.DataType.INT64:
            self.dtype = 'int64'
        elif dtype == core.DataType.FP64:
            self.dtype = 'float64'
        elif dtype == core.DataType.INT32:
            self.dtype = 'int32'
        else:
            raise ValueError("dtype must be any of [int32, float32, int64, "
                             "float64]")

        self.data = []
        # One offset list per nesting level, each seeded with the leading 0.
        self.lod = [[0] for _ in six.range(lod_level)]

    def feed(self, data):
        """Append one sample (nested ``lod_level`` deep) to the converter."""
        self._feed_impl_(data, self.lod, self.lod_level)

    def _feed_impl_(self, data, lod, lod_level):
        """Recursively flatten ``data`` and record offsets for each level.

        Args:
            data: the (sub-)sequence to consume; a leaf when ``lod_level == 0``.
            lod: offset lists for the levels still to be descended,
                outermost level first.
            lod_level: number of nesting levels left below ``data``.
        """
        if lod_level == 0:
            self.data.append(data)
            return

        # lod[0] is the outermost remaining level: the sequence we are
        # looking at right now contributes len(data) entries to it. Using the
        # last list here would store the levels in reverse order whenever
        # lod_level >= 2.
        current_level = lod[0]
        current_level.append(current_level[-1] + len(data))

        # Slicing copies only the outer list; the inner offset lists are the
        # same objects as in self.lod, so the recursion still mutates them.
        inner_levels = lod[1:]
        for each_data in data:
            self._feed_impl_(each_data, inner_levels, lod_level - 1)

    def done(self):
        """Build and return a ``core.LoDTensor`` from everything fed so far.

        The accumulated leaves are converted to a numpy array of the
        configured dtype and reshaped to ``self.shape``; the LoD is attached
        only when ``lod_level > 0``.
        """
        arr = numpy.array(self.data, dtype=self.dtype).reshape(self.shape)
        t = core.LoDTensor()
        t.set(arr, self.place)
        if self.lod_level > 0:
            t.set_lod(self.lod)
        return t


class DataFeeder(object):
    """Turn a mini-batch of Python samples into a dict of LoDTensors.

    The feeder is configured once with the variables to feed; ``feed`` then
    converts an iterable of samples into ``{variable name: LoDTensor}``.
    """

    def __init__(self, feed_list, place, program=None):
        """Record dtype, name, shape and lod level of every feed variable.

        Args:
            feed_list: iterable of ``Variable`` objects, or of variable names
                which are looked up in block 0 of ``program``.
            place: stored and passed on to each converter built in ``feed``.
            program: program used to resolve names; when ``None`` the result
                of ``default_main_program()`` is used.

        Raises:
            TypeError: if an entry is neither a string nor a ``Variable``.
            ValueError: if a variable's shape has no negative dimension
                (i.e. no batch size dimension).
        """
        self.feed_dtypes = []
        self.feed_names = []
        self.feed_shapes = []
        self.feed_lod_level = []
        if program is None:
            program = default_main_program()

        # ``basestring`` only exists on Python 2; fall back to ``str`` so the
        # name lookup below does not raise NameError on Python 3.
        try:
            string_types = basestring  # noqa: F821
        except NameError:
            string_types = str

        for each_var in feed_list:
            if isinstance(each_var, string_types):
                each_var = program.block(0).var(each_var)
            if not isinstance(each_var, Variable):
                raise TypeError("Feed list should contain a list of variable")
            self.feed_dtypes.append(each_var.dtype)
            self.feed_names.append(each_var.name)
            shape = each_var.shape
            # A negative entry marks the batch size dimension; at least one
            # is required.
            if not any(s < 0 for s in shape):
                raise ValueError(
                    "Variable {0} must have a batch size dimension".format(
                        each_var.name))
            self.feed_lod_level.append(each_var.lod_level)
            self.feed_shapes.append(shape)

        self.place = place

    def feed(self, iterable):
        """Convert a mini-batch into a feed dictionary.

        Args:
            iterable: iterable of samples; each sample is an iterable with
                one slot per feed variable, in ``feed_list`` order.

        Returns:
            dict mapping each feed variable name to the tensor returned by
            its converter's ``done()``.
        """
        # A fresh converter per variable for every call, so batches do not
        # accumulate into each other.
        converters = [
            DataToLoDTensorConverter(
                place=self.place,
                lod_level=lod_level,
                shape=shape,
                dtype=dtype)
            for lod_level, shape, dtype in six.zip(
                self.feed_lod_level, self.feed_shapes, self.feed_dtypes)
        ]

        for each_sample in iterable:
            for each_converter, each_slot in six.zip(converters, each_sample):
                each_converter.feed(each_slot)

        return {
            each_name: each_converter.done()
            for each_name, each_converter in six.zip(self.feed_names,
                                                     converters)
        }