From 6bf6034ac63d3dc9336c89d5d38aa9f0d6262238 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 28 Feb 2017 18:43:32 +0800 Subject: [PATCH] follow comments --- paddle/api/PaddleAPI.h | 2 +- paddle/py_paddle/dataprovider_converter.py | 11 ++++------- python/paddle/v2/tests/test_data_feeder.py | 1 - 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 0223dec4f42..d99e9a4ad48 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -47,7 +47,7 @@ void setUseGpu(bool useGpu); /// Return true if this py_paddle is compiled in GPU Version bool isGpuVersion(); -/// Return FLAGS_use_gpu +/// Return FLAGS_trainer_count int getTrainerCount(); /// The Error of IO Operation. Such as file not found, etc. diff --git a/paddle/py_paddle/dataprovider_converter.py b/paddle/py_paddle/dataprovider_converter.py index 8d5c57e10fd..60a0f2c0444 100644 --- a/paddle/py_paddle/dataprovider_converter.py +++ b/paddle/py_paddle/dataprovider_converter.py @@ -26,6 +26,8 @@ class IScanner(object): if not isinstance(self.input_type, dp2.InputType): raise ValueError("input type should be dataprovider2.InputType") self.pos = pos + self.use_gpu = True if swig_paddle.isUsingGpu() and ( + swig_paddle.getTrainerCount() == 1) else False def scan(self, dat): pass @@ -33,11 +35,6 @@ class IScanner(object): def finish_scan(self, argument): pass - def use_gpu(self): - gpu = True if swig_paddle.isUsingGpu() and ( - swig_paddle.getTrainerCount() == 1) else False - return gpu - class DenseScanner(IScanner): """ @@ -59,7 +56,7 @@ class DenseScanner(IScanner): if self.__mat__.dtype != numpy.float32: self.__mat__ = self.__mat__.astype(numpy.float32) m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, - self.use_gpu()) + self.use_gpu) argument.setSlotValue(self.pos, m) @@ -111,7 +108,7 @@ class IndexScanner(IScanner): self.__ids__.append(dat) def finish_scan(self, argument): - ids = swig_paddle.IVector.create(self.__ids__, self.use_gpu()) + ids = swig_paddle.IVector.create(self.__ids__, self.use_gpu) assert isinstance(argument, swig_paddle.Arguments) argument.setSlotIds(self.pos, ids) diff --git a/python/paddle/v2/tests/test_data_feeder.py b/python/paddle/v2/tests/test_data_feeder.py index 916d3aa4cdc..ab2bc5df76c 100644 --- a/python/paddle/v2/tests/test_data_feeder.py +++ b/python/paddle/v2/tests/test_data_feeder.py @@ -240,4 +240,3 @@ if __name__ == '__main__': if api.isGpuVersion(): api.setUseGpu(True) unittest.main() - unittest.main() -- GitLab