提交 00761972 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #2067 from reyoung/feature/expose_inference_in_py_api

Expose Inference in Python V2 API.
...@@ -5,15 +5,22 @@ import topology ...@@ -5,15 +5,22 @@ import topology
import minibatch import minibatch
from data_feeder import DataFeeder from data_feeder import DataFeeder
__all__ = ['infer'] __all__ = ['infer', 'Inference']
class Inference(object): class Inference(object):
""" """
Inference combines neural network output and parameters together Inference combines neural network output and parameters together
to do inference. to do inference.
.. code-block:: python
inferer = Inference(output_layer=prediction, parameters=parameters)
for data_batch in batches:
print inferer.infer(data_batch)
:param outptut_layer: The neural network that should be inferenced. :param output_layer: The neural network that should be inferenced.
:type output_layer: paddle.v2.config_base.Layer or the sequence :type output_layer: paddle.v2.config_base.Layer or the sequence
of paddle.v2.config_base.Layer of paddle.v2.config_base.Layer
:param parameters: The parameters dictionary. :param parameters: The parameters dictionary.
...@@ -56,8 +63,14 @@ class Inference(object): ...@@ -56,8 +63,14 @@ class Inference(object):
item = [each_result[each_field] for each_field in field] item = [each_result[each_field] for each_field in field]
yield item yield item
def infer(self, field='value', **kwargs): def infer(self, input, field='value', **kwargs):
"""
Infer a data by model.
:param input: input data batch. Should be python iterable object.
:param field: output field.
"""
retv = None retv = None
kwargs['input'] = input
for result in self.iter_infer_field(field=field, **kwargs): for result in self.iter_infer_field(field=field, **kwargs):
if retv is None: if retv is None:
retv = [[] for i in xrange(len(result))] retv = [[] for i in xrange(len(result))]
...@@ -79,7 +92,7 @@ def infer(output_layer, parameters, input, feeding=None, field='value'): ...@@ -79,7 +92,7 @@ def infer(output_layer, parameters, input, feeding=None, field='value'):
.. code-block:: python .. code-block:: python
result = paddle.infer(outptut_layer=prediction, result = paddle.infer(output_layer=prediction,
parameters=parameters, parameters=parameters,
input=SomeData) input=SomeData)
print result print result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册