import numpy
import py_paddle.swig_paddle as api
import collections
import topology
import minibatch
from data_feeder import DataFeeder

__all__ = ['infer']


class Inference(object):
    """
    Run a trained network forward (test mode) to produce predictions.

    Builds a testing ``GradientMachine`` from the network ``output`` layer's
    topology and loads the trained ``parameters`` into it; ``infer`` then
    feeds batches through the machine and collects the requested output
    field.

    :param output: output layer(s) describing the network to be inferred.
    :type output: paddle.v2.config_base.Layer
    :param parameters: trained parameter values, looked up by parameter name.
    :type parameters: paddle.v2.parameters.Parameters
    """

    def __init__(self, output, parameters):
        topo = topology.Topology(output)
        gm = api.GradientMachine.createFromConfigProto(
            topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
        # Copy every trained parameter value into the machine's buffers.
        for param in gm.getParameters():
            val = param.getBuf(api.PARAMETER_VALUE)
            name = param.getName()
            assert isinstance(val, api.Vector)
            val.copyFromNumpyArray(parameters.get(name).flatten())
        self.__gradient_machine__ = gm
        self.__data_types__ = topo.data_type()

    def iter_infer(self,
                   input=None,
                   batch_size=None,
                   reader=None,
                   reader_dict=None):
        """
        Lazily yield one forward-pass result per data batch.

        Exactly one of ``input`` (whole inference data set, iterable) or
        ``reader`` (batch reader creator) must be given.

        :raises TypeError: if ``reader`` is None and ``input`` is missing or
                           not iterable.
        :raises ValueError: if both ``input`` and ``reader`` are set, or if
                            ``batch_size`` is unset and ``input`` has no
                            ``len()``.
        """
        if reader_dict is None:
            reader_dict = self.default_reader_dict()

        if reader is None:
            # NOTE: the original code asserted `input is not None` first,
            # which raised AssertionError before the intended TypeError could
            # ever fire (and asserts vanish under `python -O`). Validate
            # explicitly instead.
            if input is None or not isinstance(input, collections.Iterable):
                raise TypeError("When reader is None, input should be whole "
                                "inference data and should be iterable")

            if batch_size is None:
                if not hasattr(input, '__len__'):
                    raise ValueError("Should set batch size when input data "
                                     "don't contain length.")
                batch_size = len(input)

            def __reader_impl__():
                # A single data type means each sample is one slot; wrap it
                # in a list so the feeder always receives a slot sequence.
                for each_sample in input:
                    if len(reader_dict) == 1:
                        yield [each_sample]
                    else:
                        yield each_sample

            reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
        else:
            if input is not None:
                raise ValueError("User should set either input or reader, "
                                 "should not set them both.")

        feeder = DataFeeder(self.__data_types__, reader_dict)
        self.__gradient_machine__.start()
        try:
            for data_batch in reader():
                yield self.__gradient_machine__.forwardTest(feeder(data_batch))
        finally:
            # Ensure the machine is shut down even if the consumer stops
            # iterating early or an exception escapes the loop.
            self.__gradient_machine__.finish()

    def iter_infer_field(self, field, **kwargs):
        """Yield, per batch, the list of each output layer's ``field`` value."""
        for result in self.iter_infer(**kwargs):
            yield [each_result[field] for each_result in result]

    def infer(self, field='value', **kwargs):
        """
        Run the whole inference and concatenate per-batch results.

        :param field: which result field to collect (e.g. 'value' or 'ids').
        :return: one numpy array if the network has a single output, else a
                 list of numpy arrays (one per output).
        """
        retv = None
        for result in self.iter_infer_field(field=field, **kwargs):
            if retv is None:
                # One independent accumulator per output layer.
                # BUG FIX: `[[]] * len(result)` replicated a reference to a
                # single shared list, so every output collected every item.
                retv = [[] for _ in range(len(result))]
            for i, item in enumerate(result):
                retv[i].append(item)
        retv = [numpy.concatenate(out) for out in retv]
        if len(retv) == 1:
            return retv[0]
        else:
            return retv

    def default_reader_dict(self):
        """Map each data-type name to its positional slot index."""
        reader_dict = dict()
        for i, tp in enumerate(self.__data_types__):
            reader_dict[tp[0]] = i
        return reader_dict


def infer(output,
          parameters,
          input=None,
          batch_size=None,
          reader=None,
          reader_dict=None,
          field='value'):
    """
    Convenience wrapper that runs inference over a trained network.

    Exactly one of ``input`` (a whole batch of inference data) or ``reader``
    (a batch reader creator) should be supplied.

    Example usages:

    ..  code-block:: python

        result = paddle.infer(prediction, parameters, input=SomeData,
                              batch_size=32)
        print result

    :param output: output of the neural network that would be inferred
    :type output: paddle.v2.config_base.Layer
    :param parameters: parameters of the neural network.
    :type parameters: paddle.v2.parameters.Parameters
    :param input: input data batch. Should be a python iterable object, and
                  each element is the data batch.
    :type input: collections.Iterable
    :param batch_size: the batch size when perform inference. Default is the
                       length of input.
    :type batch_size: int
    :param reader: input data reader creator in batch. If this field is set,
                   the `input` and `batch_size` will be ignored.
    :type reader: callable
    :param reader_dict: Reader dictionary. Default could generate from input
                        value.
    :param field: The prediction field. It should in [`value`, `ids`]. `value`
                  means return the prediction probabilities, `ids` means
                  return the prediction labels. Default is `value`
    :type field: str
    :return: a numpy array
    :rtype: numpy.ndarray
    """
    engine = Inference(output=output, parameters=parameters)
    infer_kwargs = dict(
        input=input,
        batch_size=batch_size,
        reader=reader,
        reader_dict=reader_dict)
    return engine.infer(field=field, **infer_kwargs)