diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index cc893ef0f5748906225570a06da0d8e8bef63460..6b95a88042a13a280bcb80f753b3887fcef37296 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -122,13 +122,14 @@ def main():
     test_creator = paddle.dataset.mnist.test()
     test_data = []
     for item in test_creator():
-        test_data.append(item[0])
+        test_data.append((item[0], ))
         if len(test_data) == 100:
             break
 
     # output is a softmax layer. It returns probabilities.
     # Shape should be (100, 10)
-    probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
+    probs = paddle.infer(
+        output_layer=predict, parameters=parameters, input=test_data)
     print probs.shape
 
 
diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py
index 35949622abb7a704b0b23d4f9457738a1177a795..ec3c67d89548f68d705a9b5de80e28597e9829da 100644
--- a/python/paddle/v2/inference.py
+++ b/python/paddle/v2/inference.py
@@ -9,8 +9,8 @@ __all__ = ['infer']
 
 
 class Inference(object):
-    def __init__(self, output, parameters):
-        topo = topology.Topology(output)
+    def __init__(self, output_layer, parameters):
+        topo = topology.Topology(output_layer)
         gm = api.GradientMachine.createFromConfigProto(
             topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
         for param in gm.getParameters():
@@ -21,33 +21,16 @@ class Inference(object):
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()
 
-    def iter_infer(self, input=None, batch_size=None, reader=None,
-                   feeding=None):
+    def iter_infer(self, input, feeding=None):
         feeder = DataFeeder(self.__data_types__, feeding)
-        if reader is None:
-            assert input is not None and isinstance(input, collections.Iterable)
-            if not isinstance(input, collections.Iterable):
-                raise TypeError("When reader is None, input should be whole "
-                                "inference data and should be iterable")
-
-            if batch_size is None:
-                if not hasattr(input, '__len__'):
-                    raise ValueError("Should set batch size when input data "
-                                     "don't contain length.")
-                batch_size = len(input)
-
-            def __reader_impl__():
-                for each_sample in input:
-                    if len(feeder) == 1:
-                        yield [each_sample]
-                    else:
-                        yield each_sample
-
-            reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
-        else:
-            if input is not None:
-                raise ValueError("User should set either input or reader, "
-                                 "should not set them both.")
+        batch_size = len(input)
+
+        def __reader_impl__():
+            for each_sample in input:
+                yield each_sample
+
+        reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
+
         self.__gradient_machine__.start()
         for data_batch in reader():
             yield self.__gradient_machine__.forwardTest(feeder(data_batch))
@@ -71,13 +54,7 @@ class Inference(object):
         return retv
 
 
-def infer(output,
-          parameters,
-          input=None,
-          batch_size=None,
-          reader=None,
-          feeding=None,
-          field='value'):
+def infer(output_layer, parameters, input, feeding=None, field='value'):
     """
     Infer a neural network by given neural network output and parameters. The user
     should pass either a batch of input data or reader method.
@@ -84,24 +61,18 @@ def infer(output,
 
     Example usage:
 
     .. code-block:: python
 
         result = paddle.infer(output=prediction, parameters=parameters,
                               input=SomeData, batch_size=32)
         print result
 
-    :param output: output of the neural network that would be inferred
-    :type output: paddle.v2.config_base.Layer
+    :param output_layer: output of the neural network that would be inferred
+    :type output_layer: paddle.v2.config_base.Layer
     :param parameters: parameters of the neural network.
     :type parameters: paddle.v2.parameters.Parameters
     :param input: input data batch.
                   Should be a python iterable object, and each element is the data batch.
     :type input: collections.Iterable
-    :param batch_size: the batch size when perform inference. Default is the
-                       length of input.
-    :type batch_size: int
-    :param reader: input data reader creator in batch. If this field is set, the
-                   `input` and `batch_size` will be ignored.
-    :type reader: callable
     :param feeding: Reader dictionary. Default could generate from input value.
     :param field: The prediction field. It should in [`value`, `ids`]. `value`
@@ -113,10 +84,5 @@ def infer(output,
 
     :rtype: numpy.ndarray
     """
-    inferer = Inference(output=output, parameters=parameters)
-    return inferer.infer(
-        field=field,
-        input=input,
-        batch_size=batch_size,
-        reader=reader,
-        feeding=feeding)
+    inferer = Inference(output_layer=output_layer, parameters=parameters)
+    return inferer.infer(field=field, input=input, feeding=feeding)
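
After this change, `paddle.infer` consumes the whole `input` list in a single call: `len(input)` becomes the batch size, so `input` must be a sequence rather than a bare generator, and each sample must be a tuple of input fields, since the new `__reader_impl__` no longer wraps single-field samples the way the removed `len(feeder) == 1` branch did. A minimal caller-side sketch of the new interface (assuming `predict` and `parameters` are the prediction layer and trained parameters built earlier in `api_train_v2.py`, whose definitions this diff does not show):

    import paddle.v2 as paddle

    # Each sample is a tuple of input fields, hence (item[0], ) rather
    # than a bare item[0]; the label in item[1] is dropped because only
    # the image is needed for inference.
    test_creator = paddle.dataset.mnist.test()
    test_data = []
    for item in test_creator():
        test_data.append((item[0], ))
        if len(test_data) == 100:
            break

    # reader and batch_size are gone: the whole list is one inference
    # call, and len(test_data) is used as the batch size internally.
    probs = paddle.infer(
        output_layer=predict, parameters=parameters, input=test_data)
    print probs.shape  # (100, 10): one softmax row per image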