From 500d8836d03287621e983f815a919d9a29749c6d Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 1 Mar 2017 18:49:13 +0800
Subject: [PATCH] Follow comments

---
 demo/mnist/api_train_v2.py           |  4 ++--
 python/paddle/v2/inferencer.py       |  6 +++---
 python/paddle/v2/reader/decorator.py | 13 ++++++++-----
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index 637596e7bc..06beb7024d 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -50,10 +50,10 @@ def main():
         output=inference,
         parameters=parameters,
         reader=paddle.reader.batched(
-            paddle.reader.limited(
+            paddle.reader.firstn(
                 paddle.reader.map_readers(lambda item: (item[0], ),
                                           paddle.dataset.mnist.test()),
-                limit=100),
+                n=100),
             batch_size=32))
     print probs.shape
 
diff --git a/python/paddle/v2/inferencer.py b/python/paddle/v2/inferencer.py
index 33f5ad1c07..ac03b016c9 100644
--- a/python/paddle/v2/inferencer.py
+++ b/python/paddle/v2/inferencer.py
@@ -5,10 +5,10 @@ from data_feeder import DataFeeder
 import itertools
 import numpy
 
-__all__ = ['InferenceEngine', 'infer']
+__all__ = ['Inference', 'infer']
 
 
-class InferenceEngine(object):
+class Inference(object):
     def __init__(self, output, parameters):
         topo = topology.Topology(output)
         gm = api.GradientMachine.createFromConfigProto(
@@ -55,5 +55,5 @@ class InferenceEngine(object):
 
 
 def infer(output, parameters, reader, reader_dict=None, field='value'):
-    inferer = InferenceEngine(output=output, parameters=parameters)
+    inferer = Inference(output=output, parameters=parameters)
     return inferer.infer(field=field, reader=reader, reader_dict=reader_dict)
diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py
index fe5acbdff5..b7657e2776 100644
--- a/python/paddle/v2/reader/decorator.py
+++ b/python/paddle/v2/reader/decorator.py
@@ -14,7 +14,7 @@
 
 __all__ = [
     'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
-    'ComposeNotAligned', 'batched', 'limited'
+    'ComposeNotAligned', 'batched', 'firstn'
 ]
 
 import itertools
@@ -215,15 +215,18 @@ def batched(reader, batch_size):
     return batched_reader
 
 
-def limited(reader, limit):
+def firstn(reader, n):
     """
     Limit the max number of samples that the reader can return.
     """
 
-    def limited_reader():
+    # TODO(yuyang18): Check whether just dropping the reader could clean up
+    # the opened resource or not.
+
+    def firstn_reader():
         for i, item in enumerate(reader()):
-            if i == limit:
+            if i == n:
                 break
             yield item
 
-    return limited_reader
+    return firstn_reader
-- 
GitLab
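
Usage note: firstn composes with the other reader decorators exactly where
limited used to, as the api_train_v2.py hunk shows. Below is a minimal,
self-contained sketch of that composition; toy_reader is a hypothetical
stand-in for paddle.dataset.mnist.test() and is not part of this patch.

    # Hypothetical reader creator standing in for paddle.dataset.mnist.test();
    # it yields single-field tuples the way dataset reader creators do.
    def toy_reader():
        def reader():
            for x in range(10):
                yield (x, )

        return reader


    # firstn as introduced by this patch: cap the sample stream at n items.
    def firstn(reader, n):
        def firstn_reader():
            for i, item in enumerate(reader()):
                if i == n:
                    break
                yield item

        return firstn_reader


    capped = firstn(toy_reader(), n=3)
    print list(capped())  # [(0,), (1,), (2,)]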
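
On the TODO about resource cleanup: a sketch probing the question, assuming
CPython semantics. When firstn_reader breaks out early, the abandoned source
generator is reclaimed by reference counting; deallocation calls close(),
which raises GeneratorExit at the suspended yield, so finally (and with)
blocks inside the reader do run. The names below are illustrative only.

    def resourceful_reader():
        def reader():
            try:
                for x in xrange(100):
                    yield x
            finally:
                # Stands in for closing a file or releasing a connection.
                print 'resource released'

        return reader


    for i, item in enumerate(resourceful_reader()()):
        if i == 3:
            break
    # On CPython the refcount of the abandoned generator hits zero right
    # after the break, so 'resource released' prints before 'done'.
    # Interpreters without refcounting (e.g. PyPy) may defer the cleanup to
    # a GC pass, which is why an explicit close() hook on readers may still
    # be worth having.
    print 'done'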