提交 500d8836 编写于 作者: Y Yu Yang

Follow comments

上级 39b85d9c
...@@ -50,10 +50,10 @@ def main(): ...@@ -50,10 +50,10 @@ def main():
output=inference, output=inference,
parameters=parameters, parameters=parameters,
reader=paddle.reader.batched( reader=paddle.reader.batched(
paddle.reader.limited( paddle.reader.firstn(
paddle.reader.map_readers(lambda item: (item[0], ), paddle.reader.map_readers(lambda item: (item[0], ),
paddle.dataset.mnist.test()), paddle.dataset.mnist.test()),
limit=100), n=100),
batch_size=32)) batch_size=32))
print probs.shape print probs.shape
......
...@@ -5,10 +5,10 @@ from data_feeder import DataFeeder ...@@ -5,10 +5,10 @@ from data_feeder import DataFeeder
import itertools import itertools
import numpy import numpy
__all__ = ['InferenceEngine', 'infer'] __all__ = ['Inference', 'infer']
class InferenceEngine(object): class Inference(object):
def __init__(self, output, parameters): def __init__(self, output, parameters):
topo = topology.Topology(output) topo = topology.Topology(output)
gm = api.GradientMachine.createFromConfigProto( gm = api.GradientMachine.createFromConfigProto(
...@@ -55,5 +55,5 @@ class InferenceEngine(object): ...@@ -55,5 +55,5 @@ class InferenceEngine(object):
def infer(output, parameters, reader, reader_dict=None, field='value'): def infer(output, parameters, reader, reader_dict=None, field='value'):
inferer = InferenceEngine(output=output, parameters=parameters) inferer = Inference(output=output, parameters=parameters)
return inferer.infer(field=field, reader=reader, reader_dict=reader_dict) return inferer.infer(field=field, reader=reader, reader_dict=reader_dict)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
__all__ = [ __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'batched', 'limited' 'ComposeNotAligned', 'batched', 'firstn'
] ]
import itertools import itertools
...@@ -215,15 +215,18 @@ def batched(reader, batch_size): ...@@ -215,15 +215,18 @@ def batched(reader, batch_size):
return batched_reader return batched_reader
def limited(reader, limit): def firstn(reader, n):
""" """
Limit the max number of samples that reader could return. Limit the max number of samples that reader could return.
""" """
def limited_reader(): # TODO(yuyang18): Check if just drop the reader, could clean the opened
# resource or not?
def firstn_reader():
for i, item in enumerate(reader()): for i, item in enumerate(reader()):
if i == limit: if i == n:
break break
yield item yield item
return limited_reader return firstn_reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册