提交 43493b2a 编写于 作者: Y Yu Yang

Expose Inference in Python V2 API.

上级 f003a63b
...@@ -5,7 +5,7 @@ import topology ...@@ -5,7 +5,7 @@ import topology
import minibatch import minibatch
from data_feeder import DataFeeder from data_feeder import DataFeeder
__all__ = ['infer'] __all__ = ['infer', 'Inference']
class Inference(object): class Inference(object):
...@@ -13,7 +13,14 @@ class Inference(object): ...@@ -13,7 +13,14 @@ class Inference(object):
Inference combines neural network output and parameters together Inference combines neural network output and parameters together
to do inference. to do inference.
:param outptut_layer: The neural network that should be inferenced. .. code-block:: python
inferer = Inference(output_layer=prediction, parameters=parameters)
for data_batch in batches:
print inferer.infer(data_batch)
:param output_layer: The neural network that should be inferenced.
:type output_layer: paddle.v2.config_base.Layer or the sequence :type output_layer: paddle.v2.config_base.Layer or the sequence
of paddle.v2.config_base.Layer of paddle.v2.config_base.Layer
:param parameters: The parameters dictionary. :param parameters: The parameters dictionary.
...@@ -56,8 +63,14 @@ class Inference(object): ...@@ -56,8 +63,14 @@ class Inference(object):
item = [each_result[each_field] for each_field in field] item = [each_result[each_field] for each_field in field]
yield item yield item
def infer(self, field='value', **kwargs): def infer(self, input, field='value', **kwargs):
"""
Infer a data by model.
:param input: input data batch. Should be python iterable object.
:param field: output field.
"""
retv = None retv = None
kwargs['input'] = input
for result in self.iter_infer_field(field=field, **kwargs): for result in self.iter_infer_field(field=field, **kwargs):
if retv is None: if retv is None:
retv = [[] for i in xrange(len(result))] retv = [[] for i in xrange(len(result))]
...@@ -79,7 +92,7 @@ def infer(output_layer, parameters, input, feeding=None, field='value'): ...@@ -79,7 +92,7 @@ def infer(output_layer, parameters, input, feeding=None, field='value'):
.. code-block:: python .. code-block:: python
result = paddle.infer(outptut_layer=prediction, result = paddle.infer(output_layer=prediction,
parameters=parameters, parameters=parameters,
input=SomeData) input=SomeData)
print result print result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册