From 71ab4df36625ea5ae6637afbee2b588a513db608 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 8 Mar 2017 13:26:28 +0800
Subject: [PATCH] Follow comments, remove reader/batch_size in interface.

---
 demo/mnist/api_train_v2.py    |  3 ++-
 python/paddle/v2/inference.py | 31 +++++++------------------------
 2 files changed, 9 insertions(+), 25 deletions(-)

diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index 75c2f0813..3aa2199bc 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -132,7 +132,8 @@ def main():
 
     # output is a softmax layer. It returns probabilities.
     # Shape should be (100, 10)
-    probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
+    probs = paddle.infer(
+        output_layer=predict, parameters=parameters, input=test_data)
 
     print probs.shape
 
diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py
index 2ad4d9d1a..53510d80c 100644
--- a/python/paddle/v2/inference.py
+++ b/python/paddle/v2/inference.py
@@ -9,8 +9,8 @@ __all__ = ['infer']
 
 
 class Inference(object):
-    def __init__(self, output, parameters):
-        topo = topology.Topology(output)
+    def __init__(self, output_layer, parameters):
+        topo = topology.Topology(output_layer)
         gm = api.GradientMachine.createFromConfigProto(
             topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
         for param in gm.getParameters():
@@ -70,13 +70,7 @@ class Inference(object):
         return retv
 
 
-def infer(output,
-          parameters,
-          input=None,
-          batch_size=None,
-          reader=None,
-          feeding=None,
-          field='value'):
+def infer(output_layer, parameters, input=None, feeding=None, field='value'):
     """
     Infer a neural network by given neural network output and parameters. The
     user should pass either a batch of input data or reader method.
@@ -89,19 +83,13 @@ def infer(output,
                               batch_size=32)
         print result
 
-    :param output: output of the neural network that would be inferred
-    :type output: paddle.v2.config_base.Layer
+    :param output_layer: output of the neural network that would be inferred
+    :type output_layer: paddle.v2.config_base.Layer
     :param parameters: parameters of the neural network.
     :type parameters: paddle.v2.parameters.Parameters
     :param input: input data batch. Should be a python iterable object, and each
                   element is the data batch.
     :type input: collections.Iterable
-    :param batch_size: the batch size when perform inference. Default is the
-                       length of input.
-    :type batch_size: int
-    :param reader: input data reader creator in batch. If this field is set, the
-                   `input` and `batch_size` will be ignored.
-    :type reader: callable
     :param feeding: Reader dictionary. Default could generate from input
                     value.
     :param field: The prediction field. It should in [`value`, `ids`]. `value`
@@ -112,10 +100,5 @@ def infer(output,
     :rtype: numpy.ndarray
     """
-    inferer = Inference(output=output, parameters=parameters)
-    return inferer.infer(
-        field=field,
-        input=input,
-        batch_size=batch_size,
-        reader=reader,
-        feeding=feeding)
+    inferer = Inference(output_layer=output_layer, parameters=parameters)
+    return inferer.infer(field=field, input=input, feeding=feeding)
-- 
GitLab
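
Usage illustration (not part of the patch): a minimal sketch of calling the renamed interface after this change. Only the infer(output_layer=..., parameters=..., input=...) call shape comes from the patch; the toy network, layer sizes, data, and paddle.init flags below are illustrative assumptions.

    import paddle.v2 as paddle

    paddle.init(use_gpu=False, trainer_count=1)

    # Hypothetical two-class classifier over an 8-dimensional dense input,
    # chosen only to make the example self-contained.
    x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(8))
    predict = paddle.layer.fc(input=x, size=2, act=paddle.activation.Softmax())
    parameters = paddle.parameters.create(predict)

    # `input` is the whole batch: an iterable of instances, each instance a
    # tuple with one entry per data layer (a single dense vector here).
    test_data = [([0.1] * 8, ), ([0.9] * 8, )]

    # New-style call: the keyword is `output_layer`, and the removed
    # `reader` / `batch_size` arguments are no longer accepted.
    probs = paddle.infer(
        output_layer=predict, parameters=parameters, input=test_data)
    print probs.shape  # one softmax row per input instance, e.g. (2, 2)

Because reader and batch_size were dropped from the public signature, the caller now passes the prepared batch directly through input, as the updated demo/mnist/api_train_v2.py does above.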