提交 71ab4df3 编写于 作者: Y Yu Yang

Follow comments, remove reader/batch_size in interface.

上级 5905d0e8
......@@ -132,7 +132,8 @@ def main():
# output is a softmax layer. It returns probabilities.
# Shape should be (100, 10)
probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
probs = paddle.infer(
output_layer=predict, parameters=parameters, input=test_data)
print probs.shape
......
......@@ -9,8 +9,8 @@ __all__ = ['infer']
class Inference(object):
def __init__(self, output, parameters):
topo = topology.Topology(output)
def __init__(self, output_layer, parameters):
topo = topology.Topology(output_layer)
gm = api.GradientMachine.createFromConfigProto(
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
for param in gm.getParameters():
......@@ -70,13 +70,7 @@ class Inference(object):
return retv
def infer(output,
parameters,
input=None,
batch_size=None,
reader=None,
feeding=None,
field='value'):
def infer(output_layer, parameters, input=None, feeding=None, field='value'):
"""
Infer a neural network by given neural network output and parameters. The
user should pass either a batch of input data or reader method.
......@@ -89,19 +83,13 @@ def infer(output,
batch_size=32)
print result
:param output: output of the neural network that would be inferred
:type output: paddle.v2.config_base.Layer
:param output_layer: output of the neural network that would be inferred
:type output_layer: paddle.v2.config_base.Layer
:param parameters: parameters of the neural network.
:type parameters: paddle.v2.parameters.Parameters
:param input: input data batch. Should be a python iterable object, and each
element is the data batch.
:type input: collections.Iterable
:param batch_size: the batch size when perform inference. Default is the
length of input.
:type batch_size: int
:param reader: input data reader creator in batch. If this field is set, the
`input` and `batch_size` will be ignored.
:type reader: callable
:param feeding: Reader dictionary. Default could generate from input
value.
:param field: The prediction field. It should in [`value`, `ids`]. `value`
......@@ -112,10 +100,5 @@ def infer(output,
:rtype: numpy.ndarray
"""
inferer = Inference(output=output, parameters=parameters)
return inferer.infer(
field=field,
input=input,
batch_size=batch_size,
reader=reader,
feeding=feeding)
inferer = Inference(output_layer=output_layer, parameters=parameters)
return inferer.infer(field=field, input=input, feeding=feeding)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册