From 4c24ac1a9952349c0965d5c50a07878ec2632e17 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 1 Mar 2017 14:50:55 +0800 Subject: [PATCH] Init inferencer. --- python/paddle/v2/__init__.py | 6 +++- python/paddle/v2/inferencer.py | 54 ++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 python/paddle/v2/inferencer.py diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index b31efe170d..cc8f33f980 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -24,12 +24,13 @@ from . import dataset from . import reader import attr import pooling +import inferencer import py_paddle.swig_paddle as api __all__ = [ 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader', - 'topology' + 'topology', 'inferencer', 'infer' ] @@ -39,3 +40,6 @@ def init(**kwargs): args.append('--%s=%s' % (key, str(kwargs[key]))) api.initPaddle(*args) + + +infer = inferencer.infer diff --git a/python/paddle/v2/inferencer.py b/python/paddle/v2/inferencer.py new file mode 100644 index 0000000000..36a8ee3711 --- /dev/null +++ b/python/paddle/v2/inferencer.py @@ -0,0 +1,54 @@ +import py_paddle.swig_paddle as api + +import topology +from data_feeder import DataFeeder +import itertools +import numpy + +__all__ = ['InferenceEngine', 'infer'] + + +class InferenceEngine(object): + def __init__(self, output, parameters): + topo = topology.Topology(output) + gm = api.GradientMachine.createFromConfigProto( + topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) + for param in gm.getParameters(): + val = param.getBuf(api.PARAMETER_VALUE) + name = param.getName() + assert isinstance(val, api.Matrix) + val.copyFromNumpyMat(parameters.get(name)) + self.__gradient_machine__ = gm + self.__data_types__ = topo.data_type() + + def iter_infer(self, reader, reader_dict=None): + feeder = DataFeeder(self.__data_types__, reader_dict) + out_args = api.Arguments.createArguments(0) + self.__gradient_machine__.start() + for data_batch in reader(): + yield self.__gradient_machine__.forwardTest( + feeder(data_batch), out_args, api.PASS_TEST) + self.__gradient_machine__.finish() + + def iter_infer_field(self, field, **kwargs): + for result in self.iter_infer(**kwargs): + yield [each_result[field] for each_result in result] + + def infer(self, field='value', **kwargs): + retv = [] + for result in itertools.izip( + self.iter_infer_field( + field=field, **kwargs)): + retv.append(numpy.concatenate(result)) + return retv + + def default_reader_dict(self): + reader_dict = dict() + for i, tp in enumerate(self.__data_types__): + reader_dict[tp[0]] = i + return reader_dict + + +def infer(output, parameters, reader, reader_dict=None, field='value'): + inferer = InferenceEngine(output=output, parameters=parameters) + return inferer.infer(field=field, reader=reader, reader_dict=reader_dict) -- GitLab