diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 1831b8e170087c909f77948f2d9077c946c72507..0223dec4f42efe066b3f3c69f21328c2af8cb0fe 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -47,6 +47,9 @@ void setUseGpu(bool useGpu); /// Return true if this py_paddle is compiled in GPU Version bool isGpuVersion(); +/// Return FLAGS_use_gpu +int getTrainerCount(); + /// The Error of IO Operation. Such as file not found, etc. class IOError {}; diff --git a/paddle/api/Util.cpp b/paddle/api/Util.cpp index 54d67aa62f4d87ad03282962c722019698dc621a..d369df5d4e04b4a8d822db0e72a8051150868ce6 100644 --- a/paddle/api/Util.cpp +++ b/paddle/api/Util.cpp @@ -54,5 +54,7 @@ bool isGpuVersion() { #endif } +int getTrainerCount() { return FLAGS_trainer_count; } + static_assert(NUM_PARAMETER_TYPES == paddle::NUM_PARAMETER_TYPES, "The Parameter Type should be same in core/api and core/common"); diff --git a/paddle/py_paddle/dataprovider_converter.py b/paddle/py_paddle/dataprovider_converter.py index 2690cafe1d8d32bf52cd9e5fa4dc69fbacb2d66c..8d5c57e10fd3584398c696e29f8ee11b858f7c53 100644 --- a/paddle/py_paddle/dataprovider_converter.py +++ b/paddle/py_paddle/dataprovider_converter.py @@ -33,6 +33,11 @@ class IScanner(object): def finish_scan(self, argument): pass + def use_gpu(self): + gpu = True if swig_paddle.isUsingGpu() and ( + swig_paddle.getTrainerCount() == 1) else False + return gpu + class DenseScanner(IScanner): """ @@ -53,7 +58,8 @@ class DenseScanner(IScanner): assert isinstance(argument, swig_paddle.Arguments) if self.__mat__.dtype != numpy.float32: self.__mat__ = self.__mat__.astype(numpy.float32) - m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False) + m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, + self.use_gpu()) argument.setSlotValue(self.pos, m) @@ -75,10 +81,13 @@ class SparseBinaryScanner(IScanner): def finish_scan(self, argument): assert isinstance(argument, swig_paddle.Arguments) - m = swig_paddle.Matrix.createSparse(self.__height__, - self.input_type.dim, - len(self.__cols__), - len(self.__value__) == 0) + m = swig_paddle.Matrix.createSparse( + self.__height__, + self.input_type.dim, + len(self.__cols__), + len(self.__value__) == 0, + False, # trans + False) # TODO supoort GPU assert isinstance(m, swig_paddle.Matrix) m.sparseCopyFrom(self.__rows__, self.__cols__, self.__value__) argument.setSlotValue(self.pos, m) @@ -102,7 +111,7 @@ class IndexScanner(IScanner): self.__ids__.append(dat) def finish_scan(self, argument): - ids = swig_paddle.IVector.create(self.__ids__) + ids = swig_paddle.IVector.create(self.__ids__, self.use_gpu()) assert isinstance(argument, swig_paddle.Arguments) argument.setSlotIds(self.pos, ids) diff --git a/python/paddle/v2/tests/test_data_feeder.py b/python/paddle/v2/tests/test_data_feeder.py index 5f67da6a5b32d74228d727d94ec79b9f7a06dab7..916d3aa4cdcbca689844e6ebefa2c3b861c92a05 100644 --- a/python/paddle/v2/tests/test_data_feeder.py +++ b/python/paddle/v2/tests/test_data_feeder.py @@ -235,4 +235,9 @@ class DataFeederTest(unittest.TestCase): if __name__ == '__main__': api.initPaddle("--use_gpu=0") + suite = unittest.TestLoader().loadTestsFromTestCase(DataFeederTest) + unittest.TextTestRunner().run(suite) + if api.isGpuVersion(): + api.setUseGpu(True) + unittest.main() unittest.main()