提交 6bf6034a 编写于 作者: D dangqingqing

follow comments

上级 9c6b1a46
...@@ -47,7 +47,7 @@ void setUseGpu(bool useGpu); ...@@ -47,7 +47,7 @@ void setUseGpu(bool useGpu);
/// Return true if this py_paddle is compiled in GPU Version /// Return true if this py_paddle is compiled in GPU Version
bool isGpuVersion(); bool isGpuVersion();
/// Return FLAGS_use_gpu /// Return FLAGS_trainer_count
int getTrainerCount(); int getTrainerCount();
/// The Error of IO Operation. Such as file not found, etc. /// The Error of IO Operation. Such as file not found, etc.
......
...@@ -26,6 +26,8 @@ class IScanner(object): ...@@ -26,6 +26,8 @@ class IScanner(object):
if not isinstance(self.input_type, dp2.InputType): if not isinstance(self.input_type, dp2.InputType):
raise ValueError("input type should be dataprovider2.InputType") raise ValueError("input type should be dataprovider2.InputType")
self.pos = pos self.pos = pos
self.use_gpu = True if swig_paddle.isUsingGpu() and (
swig_paddle.getTrainerCount() == 1) else False
def scan(self, dat): def scan(self, dat):
pass pass
...@@ -33,11 +35,6 @@ class IScanner(object): ...@@ -33,11 +35,6 @@ class IScanner(object):
def finish_scan(self, argument): def finish_scan(self, argument):
pass pass
def use_gpu(self):
gpu = True if swig_paddle.isUsingGpu() and (
swig_paddle.getTrainerCount() == 1) else False
return gpu
class DenseScanner(IScanner): class DenseScanner(IScanner):
""" """
...@@ -59,7 +56,7 @@ class DenseScanner(IScanner): ...@@ -59,7 +56,7 @@ class DenseScanner(IScanner):
if self.__mat__.dtype != numpy.float32: if self.__mat__.dtype != numpy.float32:
self.__mat__ = self.__mat__.astype(numpy.float32) self.__mat__ = self.__mat__.astype(numpy.float32)
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True,
self.use_gpu()) self.use_gpu)
argument.setSlotValue(self.pos, m) argument.setSlotValue(self.pos, m)
...@@ -111,7 +108,7 @@ class IndexScanner(IScanner): ...@@ -111,7 +108,7 @@ class IndexScanner(IScanner):
self.__ids__.append(dat) self.__ids__.append(dat)
def finish_scan(self, argument): def finish_scan(self, argument):
ids = swig_paddle.IVector.create(self.__ids__, self.use_gpu()) ids = swig_paddle.IVector.create(self.__ids__, self.use_gpu)
assert isinstance(argument, swig_paddle.Arguments) assert isinstance(argument, swig_paddle.Arguments)
argument.setSlotIds(self.pos, ids) argument.setSlotIds(self.pos, ids)
......
...@@ -240,4 +240,3 @@ if __name__ == '__main__': ...@@ -240,4 +240,3 @@ if __name__ == '__main__':
if api.isGpuVersion(): if api.isGpuVersion():
api.setUseGpu(True) api.setUseGpu(True)
unittest.main() unittest.main()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册