diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 637596e7bc1c0a5fd231b9e96010f0e946a1155b..06beb7024d1fd07dc327cb4c09d74e1b89a7b8ff 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -50,10 +50,10 @@ def main(): output=inference, parameters=parameters, reader=paddle.reader.batched( - paddle.reader.limited( + paddle.reader.firstn( paddle.reader.map_readers(lambda item: (item[0], ), paddle.dataset.mnist.test()), - limit=100), + n=100), batch_size=32)) print probs.shape diff --git a/python/paddle/v2/inferencer.py b/python/paddle/v2/inferencer.py index 33f5ad1c07c6e2c562d0a549e42c62adc870998d..ac03b016c9b8bfbc586072855402ed3a373e9b54 100644 --- a/python/paddle/v2/inferencer.py +++ b/python/paddle/v2/inferencer.py @@ -5,10 +5,10 @@ from data_feeder import DataFeeder import itertools import numpy -__all__ = ['InferenceEngine', 'infer'] +__all__ = ['Inference', 'infer'] -class InferenceEngine(object): +class Inference(object): def __init__(self, output, parameters): topo = topology.Topology(output) gm = api.GradientMachine.createFromConfigProto( @@ -55,5 +55,5 @@ class InferenceEngine(object): def infer(output, parameters, reader, reader_dict=None, field='value'): - inferer = InferenceEngine(output=output, parameters=parameters) + inferer = Inference(output=output, parameters=parameters) return inferer.infer(field=field, reader=reader, reader_dict=reader_dict) diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py index fe5acbdff594471480c3f372b569e1c36e068525..b7657e27764f099334ba3030c493a7607f323abe 100644 --- a/python/paddle/v2/reader/decorator.py +++ b/python/paddle/v2/reader/decorator.py @@ -14,7 +14,7 @@ __all__ = [ 'map_readers', 'buffered', 'compose', 'chain', 'shuffle', - 'ComposeNotAligned', 'batched', 'limited' + 'ComposeNotAligned', 'batched', 'firstn' ] import itertools @@ -215,15 +215,18 @@ def batched(reader, batch_size): return batched_reader -def limited(reader, limit): 
+def firstn(reader, n): """ Limit the max number of samples that reader could return. """ - def limited_reader(): + # TODO(yuyang18): Check whether simply dropping the reader is enough + # to release the opened resource, or if explicit cleanup is needed. + + def firstn_reader(): for i, item in enumerate(reader()): - if i == limit: + if i == n: break yield item - return limited_reader + return firstn_reader