提交 9ba231d3 编写于 作者: Y Yu Yang

Complete inferencer.

上级 4c24ac1a
...@@ -44,6 +44,19 @@ def main(): ...@@ -44,6 +44,19 @@ def main():
batch_size=32), batch_size=32),
event_handler=event_handler) event_handler=event_handler)
# output is a softmax layer. It returns probabilities.
# Shape should be (100, 10)
probs = paddle.infer(
output=inference,
parameters=parameters,
reader=paddle.reader.batched(
paddle.reader.limited(
paddle.reader.map_readers(lambda item: (item[0], ),
paddle.dataset.mnist.test()),
limit=100),
batch_size=32))
print probs.shape
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -35,24 +35,25 @@ def reader_creator(image_filename, label_filename, buffer_size): ...@@ -35,24 +35,25 @@ def reader_creator(image_filename, label_filename, buffer_size):
l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE) l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
l.stdout.read(8) # skip some magic bytes l.stdout.read(8) # skip some magic bytes
while True: try: # reader could be break.
labels = numpy.fromfile( while True:
l.stdout, 'ubyte', count=buffer_size).astype("int") labels = numpy.fromfile(
l.stdout, 'ubyte', count=buffer_size).astype("int")
if labels.size != buffer_size: if labels.size != buffer_size:
break # numpy.fromfile returns empty slice after EOF. break # numpy.fromfile returns empty slice after EOF.
images = numpy.fromfile( images = numpy.fromfile(
m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape( m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
(buffer_size, 28 * 28)).astype('float32') (buffer_size, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0 images = images / 255.0 * 2.0 - 1.0
for i in xrange(buffer_size): for i in xrange(buffer_size):
yield images[i, :], int(labels[i]) yield images[i, :], int(labels[i])
finally:
m.terminate() m.terminate()
l.terminate() l.terminate()
return reader return reader
......
...@@ -16,18 +16,18 @@ class InferenceEngine(object): ...@@ -16,18 +16,18 @@ class InferenceEngine(object):
for param in gm.getParameters(): for param in gm.getParameters():
val = param.getBuf(api.PARAMETER_VALUE) val = param.getBuf(api.PARAMETER_VALUE)
name = param.getName() name = param.getName()
assert isinstance(val, api.Matrix) assert isinstance(val, api.Vector)
val.copyFromNumpyMat(parameters.get(name)) val.copyFromNumpyArray(parameters.get(name).flatten())
self.__gradient_machine__ = gm self.__gradient_machine__ = gm
self.__data_types__ = topo.data_type() self.__data_types__ = topo.data_type()
def iter_infer(self, reader, reader_dict=None):
    """
    Lazily run a forward (inference) pass over every batch from *reader*.

    :param reader: no-argument callable returning an iterable of data
        batches.
    :param reader_dict: maps data-layer names to positions inside each
        sample; defaults to ``self.default_reader_dict()``.
    :return: a generator yielding the gradient machine's forwardTest
        output for each batch.
    """
    if reader_dict is None:
        reader_dict = self.default_reader_dict()
    feeder = DataFeeder(self.__data_types__, reader_dict)
    self.__gradient_machine__.start()
    try:
        for data_batch in reader():
            yield self.__gradient_machine__.forwardTest(feeder(data_batch))
    finally:
        # Guarantee finish() runs even when the consumer abandons the
        # generator early (GeneratorExit) or an exception escapes the
        # loop; the original only called finish() on normal completion,
        # leaking the gradient machine's started state.
        self.__gradient_machine__.finish()
def iter_infer_field(self, field, **kwargs): def iter_infer_field(self, field, **kwargs):
...@@ -35,12 +35,17 @@ class InferenceEngine(object): ...@@ -35,12 +35,17 @@ class InferenceEngine(object):
yield [each_result[field] for each_result in result] yield [each_result[field] for each_result in result]
def infer(self, field='value', **kwargs):
    """
    Run inference and gather the named output *field* across all batches.

    :param field: which per-batch output field to collect (default
        ``'value'``, e.g. softmax probabilities).
    :return: a single numpy array when the network has exactly one
        output, otherwise a list of numpy arrays (one per output).
        Returns an empty list when the reader yields no batches.
    """
    retv = None
    for result in self.iter_infer_field(field=field, **kwargs):
        if retv is None:
            # One independent accumulator per network output.
            # BUG FIX: the original used `[[]] * len(result)`, which
            # repeats a single shared list, so every output's items
            # were appended to the same list.
            retv = [[] for _ in result]
        for i, item in enumerate(result):
            retv[i].append(item)
    if retv is None:
        # Reader produced no batches; the original raised TypeError
        # here by iterating over None.
        return []
    retv = [numpy.concatenate(out) for out in retv]
    if len(retv) == 1:
        return retv[0]
    return retv
def default_reader_dict(self): def default_reader_dict(self):
reader_dict = dict() reader_dict = dict()
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
__all__ = [ __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'batched' 'ComposeNotAligned', 'batched', 'limited'
] ]
from Queue import Queue
from threading import Thread
import itertools import itertools
import random import random
from Queue import Queue
from threading import Thread
def map_readers(func, *readers): def map_readers(func, *readers):
...@@ -213,3 +213,17 @@ def batched(reader, batch_size): ...@@ -213,3 +213,17 @@ def batched(reader, batch_size):
yield batch yield batch
return batched_reader return batched_reader
def limited(reader, limit):
    """
    Wrap *reader* so it yields at most *limit* samples.

    :param reader: no-argument callable returning an iterable of samples.
    :param limit: maximum number of samples to pass through.
    :return: a new reader callable with the same sample format.
    """

    def limited_reader():
        produced = 0
        for sample in reader():
            # Stop once the cap is reached; samples beyond it are
            # never pulled from the underlying reader.
            if produced == limit:
                break
            produced += 1
            yield sample

    return limited_reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册