import numpy
import py_paddle.swig_paddle as api
import collections
import topology
import minibatch
from data_feeder import DataFeeder

__all__ = ['infer']
class Inference(object):
    """
    Run forward passes (prediction) of a trained network.

    Combines a network topology (built from its output layer) with trained
    parameter values in a test-mode gradient machine.

    :param output: output layer(s) of the neural network to be inferred.
    :type output: paddle.v2.config_base.Layer
    :param parameters: trained parameters copied into the machine before
                       inference.
    :type parameters: paddle.v2.parameters.Parameters
    """

    def __init__(self, output, parameters):
        topo = topology.Topology(output)
        # Test mode: only parameter values are needed, no gradients.
        gm = api.GradientMachine.createFromConfigProto(
            topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
        for param in gm.getParameters():
            val = param.getBuf(api.PARAMETER_VALUE)
            name = param.getName()
            assert isinstance(val, api.Vector)
            # Copy each trained parameter into the machine's value buffer.
            val.copyFromNumpyArray(parameters.get(name).flatten())

        self.__gradient_machine__ = gm
        self.__data_types__ = topo.data_type()

    def iter_infer(self, input=None, batch_size=None, reader=None,
                   feeding=None):
        """
        Lazily yield the raw forward-pass result of each input batch.

        Exactly one of `input` (whole in-memory data set) or `reader`
        (batch reader creator) must be supplied.
        """
        if reader is None:
            # Explicit validation instead of `assert`: asserts are stripped
            # when Python runs with -O, which previously made the TypeError
            # below unreachable.
            if input is None or not isinstance(input, collections.Iterable):
                raise TypeError("When reader is None, input should be whole "
                                "inference data and should be iterable")

            if batch_size is None:
                if not hasattr(input, '__len__'):
                    raise ValueError("Should set batch size when input data "
                                     "don't contain length.")
                # Default: a single batch covering the whole input.
                batch_size = len(input)

            def __reader_impl__():
                for each_sample in input:
                    yield each_sample

            reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
        else:
            if input is not None:
                raise ValueError("User should set either input or reader, "
                                 "should not set them both.")

        feeder = DataFeeder(self.__data_types__, feeding)
        self.__gradient_machine__.start()
        try:
            for data_batch in reader():
                yield self.__gradient_machine__.forwardTest(feeder(data_batch))
        finally:
            # Guarantee finish() even if the consumer stops iterating early
            # or forwardTest raises.
            self.__gradient_machine__.finish()

    def iter_infer_field(self, field, **kwargs):
        """Yield, per batch, the list of `field` values of each output."""
        for result in self.iter_infer(**kwargs):
            yield [each_result[field] for each_result in result]

    def infer(self, field='value', **kwargs):
        """
        Concatenate all batch results per network output.

        :return: a single numpy array when the network has one output,
                 otherwise a list of numpy arrays (one per output).
        """
        retv = None
        for result in self.iter_infer_field(field=field, **kwargs):
            if retv is None:
                # One *independent* list per output. The former
                # `[[]] * len(result)` aliased a single list len(result)
                # times, so every output accumulated every item.
                retv = [[] for _ in range(len(result))]
            for i, item in enumerate(result):
                retv[i].append(item)
        retv = [numpy.concatenate(out) for out in retv]
        if len(retv) == 1:
            return retv[0]
        else:
            return retv
def infer(output,
          parameters,
          input=None,
          batch_size=None,
          reader=None,
          feeding=None,
          field='value'):
    """
    Infer a neural network from its output layer and trained parameters.
    Pass either a whole batch of input data or a batch-reader creator.

    Example usages:

    ..  code-block:: python

        result = paddle.infer(prediction, parameters, input=SomeData,
                              batch_size=32)
        print result

    :param output: output of the neural network that would be inferred
    :type output: paddle.v2.config_base.Layer
    :param parameters: parameters of the neural network.
    :type parameters: paddle.v2.parameters.Parameters
    :param input: input data batch. Should be a python iterable object, and each
                  element is the data batch.
    :type input: collections.Iterable
    :param batch_size: the batch size when perform inference. Default is the
                       length of input.
    :type batch_size: int
    :param reader: input data reader creator in batch. If this field is set, the
                   `input` and `batch_size` will be ignored.
    :type reader: callable
    :param feeding: Reader dictionary. Default could generate from input
                    value.
    :param field: The prediction field. It should in [`value`, `ids`]. `value`
                  means return the prediction probabilities, `ids` means return
                  the prediction labels. Default is `value`
    :type field: str
    :return: a numpy array
    :rtype: numpy.ndarray
    """
    # Build the inference engine and delegate immediately; all batching
    # and feeding logic lives in Inference.
    return Inference(output=output, parameters=parameters).infer(
        field=field,
        input=input,
        batch_size=batch_size,
        reader=reader,
        feeding=feeding)