Commit 52dc6a9c authored by Yu Yang

Merge branch 'feature/better_infer_interface' into feature/recommendation_v2_api

@@ -92,12 +92,8 @@ def main():
     def event_handler(event):
         if isinstance(event, paddle.event.EndIteration):
             if event.batch_id % 1000 == 0:
-                result = trainer.test(reader=paddle.batch(
-                    paddle.dataset.mnist.test(), batch_size=256))
-                print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
-                    event.pass_id, event.batch_id, event.cost, event.metrics,
-                    result.metrics)
+                print "Pass %d, Batch %d, Cost %f, %s" % (
+                    event.pass_id, event.batch_id, event.cost, event.metrics)
                 with gzip.open('params.tar.gz', 'w') as f:
                     parameters.to_tar(f)
@@ -123,17 +119,16 @@ def main():
     print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
     print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)

+    test_creator = paddle.dataset.mnist.test()
+    test_data = []
+    for item in test_creator():
+        test_data.append(item[0])
+        if len(test_data) == 100:
+            break
+
     # output is a softmax layer. It returns probabilities.
     # Shape should be (100, 10)
-    probs = paddle.infer(
-        output=predict,
-        parameters=parameters,
-        reader=paddle.batch(
-            paddle.reader.firstn(
-                paddle.reader.map_readers(lambda item: (item[0], ),
-                                          paddle.dataset.mnist.test()),
-                n=100),
-            batch_size=32))
+    probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
     print probs.shape
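Note that `test_data` keeps only `item[0]`: each MNIST test sample is an (image, label) pair, and inference consumes just the input fields. A minimal sketch of the sample layout, assuming the standard MNIST reader:

    # Each sample from the reader is a (784-d image vector, int label) pair.
    sample = paddle.dataset.mnist.test()().next()
    image, label = sample
    # Only the image goes into infer(); the label is dropped.
    test_data.append(image)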
@@ -2,6 +2,7 @@
 Trainer API
 ###########

 ==========
 Parameters
 ==========
@@ -24,3 +25,10 @@ Event
 ..  automodule:: paddle.v2.event
     :members:
+
+=========
+Inference
+=========
+
+..  autofunction:: paddle.v2.infer
\ No newline at end of file
@@ -85,6 +85,9 @@ class DataFeeder(DataProviderConverter):
             input_types.append(each[1])
         DataProviderConverter.__init__(self, input_types)

+    def __len__(self):
+        return len(self.input_names)
+
     def convert(self, dat, argument=None):
         """
         :param dat: A list of mini-batch data. Each sample is a list or tuple
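`len(feeder)` reports the number of declared input fields; the new inference path uses it to decide whether bare samples need wrapping. A sketch under the assumption of the v2 `DataFeeder(data_types, feeding)` signature, with an illustrative `image` field:

    import paddle.v2 as paddle
    from paddle.v2.data_feeder import DataFeeder

    # One declared input field, so len(feeder) == 1 and iter_infer will wrap
    # each bare sample into a one-element list before feeding it.
    feeder = DataFeeder([('image', paddle.data_type.dense_vector(784))],
                        {'image': 0})
    assert len(feeder) == 1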
-import numpy
 import py_paddle.swig_paddle as api
+import collections
 import topology
+import minibatch
 from data_feeder import DataFeeder
-import itertools
+import numpy

 __all__ = ['infer']
@@ -21,8 +21,33 @@ class Inference(object):
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()

-    def iter_infer(self, reader, feeding=None):
+    def iter_infer(self, input=None, batch_size=None, reader=None,
+                   feeding=None):
         feeder = DataFeeder(self.__data_types__, feeding)
+        if reader is None:
+            assert input is not None
+            if not isinstance(input, collections.Iterable):
+                raise TypeError("When reader is None, input should be whole "
+                                "inference data and should be iterable")
+            if batch_size is None:
+                if not hasattr(input, '__len__'):
+                    raise ValueError("Should set batch size when input data "
+                                     "don't contain length.")
+                batch_size = len(input)
+
+            def __reader_impl__():
+                for each_sample in input:
+                    if len(feeder) == 1:
+                        yield [each_sample]
+                    else:
+                        yield each_sample
+
+            reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
+        else:
+            if input is not None:
+                raise ValueError("User should set either input or reader, "
+                                 "should not set them both.")
         self.__gradient_machine__.start()
         for data_batch in reader():
             yield self.__gradient_machine__.forwardTest(feeder(data_batch))
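Given that branch, passing `input=` is shorthand for building a batched reader by hand. Roughly, for a single-input network (a sketch reusing `predict`, `parameters`, and `test_data` from the MNIST example above):

    import paddle.v2 as paddle

    # 1) Raw samples; iter_infer builds the reader internally.
    probs_a = paddle.infer(output=predict, parameters=parameters,
                           input=test_data, batch_size=32)

    # 2) The equivalent hand-built reader.
    def whole_data():
        for sample in test_data:
            yield [sample]  # wrap each bare sample into a one-field row

    probs_b = paddle.infer(output=predict, parameters=parameters,
                           reader=paddle.batch(whole_data, batch_size=32))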
@@ -46,6 +71,52 @@ class Inference(object):
         return retv


-def infer(output, parameters, reader, feeding=None, field='value'):
+def infer(output,
+          parameters,
+          input=None,
+          batch_size=None,
+          reader=None,
+          feeding=None,
+          field='value'):
+    """
+    Infer a neural network by the given neural network output and parameters.
+    The user should pass either a batch of input data or a reader method.
+
+    Example usage:
+
+    ..  code-block:: python
+
+        result = paddle.infer(prediction, parameters, input=SomeData,
+                              batch_size=32)
+        print result
+
+    :param output: the output of the neural network to be inferred
+    :type output: paddle.v2.config_base.Layer
+    :param parameters: the parameters of the neural network
+    :type parameters: paddle.v2.parameters.Parameters
+    :param input: the whole input data for inference. Should be a Python
+                  iterable object whose elements are data samples.
+    :type input: collections.Iterable
+    :param batch_size: the batch size used when performing inference. Defaults
+                       to the length of input.
+    :type batch_size: int
+    :param reader: a reader creator that yields batched input data. If set,
+                   `input` and `batch_size` are ignored.
+    :type reader: callable
+    :param feeding: the feeding dictionary. By default it is generated from
+                    the input order.
+    :param field: the prediction field. It should be in [`value`, `ids`].
+                  `value` means returning the prediction probabilities; `ids`
+                  means returning the prediction labels. Default is `value`.
+    :type field: str
+    :return: a numpy array of inference results
+    :rtype: numpy.ndarray
+    """
     inferer = Inference(output=output, parameters=parameters)
-    return inferer.infer(field=field, reader=reader, feeding=feeding)
+    return inferer.infer(
+        field=field,
+        input=input,
+        batch_size=batch_size,
+        reader=reader,
+        feeding=feeding)
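End to end, the MNIST example above can turn the returned probabilities into hard labels with plain numpy; a short usage sketch reusing `predict`, `parameters`, and `test_data` from that example:

    import numpy

    # Probabilities with shape (100, 10): one softmax row per test image.
    probs = paddle.infer(output=predict, parameters=parameters, input=test_data)

    # Pick the most likely digit per row.
    labels = numpy.argmax(probs, axis=1)
    print labels[:10]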