Commit 4feb5013, authored by Yi Wang, committed by GitHub

Merge pull request #1561 from reyoung/feature/better_infer_interface

Add input data interface for inference
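For context, here is a minimal sketch of the interface this PR introduces: `output` becomes `output_layer`, `input` is now a required whole batch (a list of samples, one tuple per data layer), and the `reader`/`batch_size` arguments are dropped. This assumes the PaddlePaddle v2 Python API as of this commit; the toy network, its sizes, and the dummy samples are illustrative placeholders, not part of the change itself.

import paddle.v2 as paddle

paddle.init(use_gpu=False, trainer_count=1)

# Toy classifier; layer names and sizes are placeholders.
pixel = paddle.layer.data(
    name='pixel', type=paddle.data_type.dense_vector(784))
predict = paddle.layer.fc(
    input=pixel, size=10, act=paddle.activation.Softmax())

# Freshly created (untrained) parameters, only to keep the sketch runnable.
parameters = paddle.parameters.create(predict)

# `input` is the whole inference batch: one tuple per sample, with one
# entry per data layer of the network.
test_data = [([0.0] * 784, ), ([1.0] * 784, )]

# New keyword `output_layer` (was `output`); the batch size is len(input).
probs = paddle.infer(
    output_layer=predict, parameters=parameters, input=test_data)
print probs.shape  # (2, 10)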
@@ -122,13 +122,14 @@ def main():
     test_creator = paddle.dataset.mnist.test()
     test_data = []
     for item in test_creator():
-        test_data.append(item[0])
+        test_data.append((item[0], ))
         if len(test_data) == 100:
             break
 
     # output is a softmax layer. It returns probabilities.
     # Shape should be (100, 10)
-    probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
+    probs = paddle.infer(
+        output_layer=predict, parameters=parameters, input=test_data)
     print probs.shape
...
@@ -9,8 +9,8 @@ __all__ = ['infer']
 
 
 class Inference(object):
-    def __init__(self, output, parameters):
-        topo = topology.Topology(output)
+    def __init__(self, output_layer, parameters):
+        topo = topology.Topology(output_layer)
         gm = api.GradientMachine.createFromConfigProto(
             topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
         for param in gm.getParameters():
@@ -21,33 +21,16 @@ class Inference(object):
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()
 
-    def iter_infer(self, input=None, batch_size=None, reader=None,
-                   feeding=None):
+    def iter_infer(self, input, feeding=None):
         feeder = DataFeeder(self.__data_types__, feeding)
-        if reader is None:
-            assert input is not None and isinstance(input, collections.Iterable)
-            if not isinstance(input, collections.Iterable):
-                raise TypeError("When reader is None, input should be whole "
-                                "inference data and should be iterable")
-            if batch_size is None:
-                if not hasattr(input, '__len__'):
-                    raise ValueError("Should set batch size when input data "
-                                     "don't contain length.")
-                batch_size = len(input)
-
-            def __reader_impl__():
-                for each_sample in input:
-                    if len(feeder) == 1:
-                        yield [each_sample]
-                    else:
-                        yield each_sample
-
-            reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
-        else:
-            if input is not None:
-                raise ValueError("User should set either input or reader, "
-                                 "should not set them both.")
+        batch_size = len(input)
+
+        def __reader_impl__():
+            for each_sample in input:
+                yield each_sample
+
+        reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
         self.__gradient_machine__.start()
         for data_batch in reader():
             yield self.__gradient_machine__.forwardTest(feeder(data_batch))
@@ -71,13 +54,7 @@ class Inference(object):
         return retv
 
 
-def infer(output,
-          parameters,
-          input=None,
-          batch_size=None,
-          reader=None,
-          feeding=None,
-          field='value'):
+def infer(output_layer, parameters, input, feeding=None, field='value'):
     """
     Infer a neural network by given neural network output and parameters. The
     user should pass either a batch of input data or reader method.
@@ -90,19 +67,13 @@ def infer(output,
                            batch_size=32)
         print result
 
-    :param output: output of the neural network that would be inferred
-    :type output: paddle.v2.config_base.Layer
+    :param output_layer: output of the neural network that would be inferred
+    :type output_layer: paddle.v2.config_base.Layer
     :param parameters: parameters of the neural network.
     :type parameters: paddle.v2.parameters.Parameters
     :param input: input data batch. Should be a python iterable object, and each
                   element is the data batch.
     :type input: collections.Iterable
-    :param batch_size: the batch size when perform inference. Default is the
-                       length of input.
-    :type batch_size: int
-    :param reader: input data reader creator in batch. If this field is set, the
-                   `input` and `batch_size` will be ignored.
-    :type reader: callable
     :param feeding: Reader dictionary. Default could generate from input
                     value.
     :param field: The prediction field. It should in [`value`, `ids`]. `value`
@@ -113,10 +84,5 @@ def infer(output,
     :rtype: numpy.ndarray
     """
-    inferer = Inference(output=output, parameters=parameters)
-    return inferer.infer(
-        field=field,
-        input=input,
-        batch_size=batch_size,
-        reader=reader,
-        feeding=feeding)
+    inferer = Inference(output_layer=output_layer, parameters=parameters)
+    return inferer.infer(field=field, input=input, feeding=feeding)
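A design note on the removed plumbing: instead of accepting a user-supplied `reader`, `iter_infer` now wraps the whole `input` list in a one-shot reader and batches it with `batch_size = len(input)`, so the entire input is forwarded as a single batch. The snippet below is a simplified, self-contained stand-in for `paddle.v2.minibatch.batch` written for illustration only; the real implementation lives in the minibatch module, and only the grouping behaviour is sketched here.

def batch(reader, batch_size):
    # Simplified stand-in for paddle.v2.minibatch.batch: groups the
    # samples yielded by `reader` into lists of at most `batch_size`.
    def batch_reader():
        buf = []
        for sample in reader():
            buf.append(sample)
            if len(buf) == batch_size:
                yield buf
                buf = []
        if buf:
            yield buf

    return batch_reader


input = [(0, ), (1, ), (2, )]


def __reader_impl__():
    for each_sample in input:
        yield each_sample


# With batch_size = len(input), the entire input becomes a single batch.
reader = batch(__reader_impl__, batch_size=len(input))
print list(reader())  # [[(0,), (1,), (2,)]]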