# inference.py -- committed by Yu Yang.
import py_paddle.swig_paddle as api

import topology
from data_feeder import DataFeeder
import itertools
import numpy

# Only the functional ``infer`` entry point is exported by ``import *``.
__all__ = ["infer"]


class Inference(object):
    """Run testing-mode forward passes of a network over a data reader.

    :param output: network output passed to ``topology.Topology`` to build
        the model configuration.
    :param parameters: parameter store; for every parameter of the created
        gradient machine, ``parameters.get(name)`` must return an array whose
        ``flatten()`` result is copied into the machine.
    """

    def __init__(self, output, parameters):
        topo = topology.Topology(output)
        gm = api.GradientMachine.createFromConfigProto(
            topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
        # Copy the supplied parameter values into the freshly created machine.
        for param in gm.getParameters():
            val = param.getBuf(api.PARAMETER_VALUE)
            name = param.getName()
            assert isinstance(val, api.Vector)
            val.copyFromNumpyArray(parameters.get(name).flatten())
        self.__gradient_machine__ = gm
        self.__data_types__ = topo.data_type()

    def iter_infer(self, reader, feeding=None):
        """Yield the raw ``forwardTest`` result for each batch of ``reader()``.

        :param reader: callable taking no arguments that returns an iterable
            of data batches.
        :param feeding: passed through unchanged to ``DataFeeder``.
        :return: generator producing one forward result per batch.
        """
        feeder = DataFeeder(self.__data_types__, feeding)
        self.__gradient_machine__.start()
        # try/finally so finish() still runs when the consumer stops early
        # (the generator is closed) or when a batch raises.
        try:
            for data_batch in reader():
                yield self.__gradient_machine__.forwardTest(feeder(data_batch))
        finally:
            self.__gradient_machine__.finish()

    def iter_infer_field(self, field, **kwargs):
        """Yield, per batch, a list holding ``field`` of each output's result.

        :param field: key looked up in every per-output result.
        :param kwargs: forwarded to :meth:`iter_infer` (``reader``, ``feeding``).
        """
        for result in self.iter_infer(**kwargs):
            yield [each_result[field] for each_result in result]

    def infer(self, field='value', **kwargs):
        """Run inference over the whole reader and concatenate batch results.

        :param field: key selected from each output's result.
        :param kwargs: forwarded to :meth:`iter_infer` (``reader``, ``feeding``).
        :return: a single concatenated array when there is one output,
            otherwise a list with one concatenated array per output.
        :raises ValueError: if the reader produced no batches.
        """
        retv = None
        for result in self.iter_infer_field(field=field, **kwargs):
            if retv is None:
                # One independent list per output; ``[[]] * n`` would alias a
                # single list n times and mix all outputs together.
                retv = [[] for _ in result]
            for i, item in enumerate(result):
                retv[i].append(item)
        if retv is None:
            raise ValueError('reader produced no data to infer on')
        retv = [numpy.concatenate(out) for out in retv]
        if len(retv) == 1:
            return retv[0]
        return retv


def infer(output, parameters, reader, feeding=None, field='value'):
    """Build an ``Inference`` for ``output`` and run it over ``reader``.

    Convenience wrapper around ``Inference(output, parameters).infer(...)``.

    :param output: network output used to build the inference topology.
    :param parameters: parameter store supplying values by parameter name.
    :param reader: callable returning an iterable of data batches.
    :param feeding: passed through to the data feeder.
    :param field: key selected from each output's result.
    :return: whatever ``Inference.infer`` returns for these arguments.
    """
    return Inference(output=output, parameters=parameters).infer(
        reader=reader, feeding=feeding, field=field)