diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index 75c2f08132dcae291ea4d5d70edfe804a702dd18..3aa2199bcb7126325e573f3c84442b52f4a3f21c 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -132,7 +132,8 @@ def main():
 
     # output is a softmax layer. It returns probabilities.
     # Shape should be (100, 10)
-    probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
+    probs = paddle.infer(
+        output_layer=predict, parameters=parameters, input=test_data)
     print probs.shape
 
 
diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py
index 2ad4d9d1ab037a210b81b4ee63bce267f296bd83..53510d80c9d92b42ebbe120cc6f4166b09198ae5 100644
--- a/python/paddle/v2/inference.py
+++ b/python/paddle/v2/inference.py
@@ -9,8 +9,8 @@ __all__ = ['infer']
 
 
 class Inference(object):
-    def __init__(self, output, parameters):
-        topo = topology.Topology(output)
+    def __init__(self, output_layer, parameters):
+        topo = topology.Topology(output_layer)
         gm = api.GradientMachine.createFromConfigProto(
             topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
         for param in gm.getParameters():
@@ -70,13 +70,7 @@ class Inference(object):
         return retv
 
 
-def infer(output,
-          parameters,
-          input=None,
-          batch_size=None,
-          reader=None,
-          feeding=None,
-          field='value'):
+def infer(output_layer, parameters, input=None, feeding=None, field='value'):
     """
     Infer a neural network by given neural network output and parameters. The
     user should pass either a batch of input data or reader method.
@@ -89,19 +83,13 @@ def infer(output,
                            batch_size=32)
         print result
 
-    :param output: output of the neural network that would be inferred
-    :type output: paddle.v2.config_base.Layer
+    :param output_layer: output of the neural network that would be inferred
+    :type output_layer: paddle.v2.config_base.Layer
     :param parameters: parameters of the neural network.
     :type parameters: paddle.v2.parameters.Parameters
    :param input: input data batch. Should be a python iterable object, and each
                  element is the data batch.
     :type input: collections.Iterable
-    :param batch_size: the batch size when perform inference. Default is the
-                       length of input.
-    :type batch_size: int
-    :param reader: input data reader creator in batch. If this field is set, the
-                   `input` and `batch_size` will be ignored.
-    :type reader: callable
     :param feeding: Reader dictionary. Default could generate from input
                     value.
     :param field: The prediction field. It should in [`value`, `ids`]. `value`
@@ -112,10 +100,5 @@ def infer(output,
     :rtype: numpy.ndarray
     """
 
-    inferer = Inference(output=output, parameters=parameters)
-    return inferer.infer(
-        field=field,
-        input=input,
-        batch_size=batch_size,
-        reader=reader,
-        feeding=feeding)
+    inferer = Inference(output_layer=output_layer, parameters=parameters)
+    return inferer.infer(field=field, input=input, feeding=feeding)
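
For context (not part of the patch): a minimal sketch of the calling convention after this change. It assumes the MNIST demo's trained `predict` output layer and `parameters` are already in scope; `some_images` is a hypothetical stand-in for real test samples, not anything defined by this patch.

    import paddle.v2 as paddle

    # Each element of `input` is one sample: a tuple with one entry per
    # input data layer. `some_images` is hypothetical placeholder data.
    test_data = [(image, ) for image in some_images]

    # The network output is now passed as `output_layer`; since the
    # `batch_size` and `reader` keywords were removed, the whole batch
    # is supplied through `input` rather than a reader creator.
    probs = paddle.infer(
        output_layer=predict, parameters=parameters, input=test_data)
    print probs.shape  # (len(test_data), 10) for the MNIST softmax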