diff --git a/paddle/py_paddle/dataprovider_converter.py b/paddle/py_paddle/dataprovider_converter.py index 60a0f2c04448636a39b6fc550c51faf33ea93b09..57e58e077d7fbe6f8a3700047625470652117c45 100644 --- a/paddle/py_paddle/dataprovider_converter.py +++ b/paddle/py_paddle/dataprovider_converter.py @@ -26,7 +26,14 @@ class IScanner(object): if not isinstance(self.input_type, dp2.InputType): raise ValueError("input type should be dataprovider2.InputType") self.pos = pos - self.use_gpu = True if swig_paddle.isUsingGpu() and ( + # data_in_gpu is used to indicate whether to create argument on GPU + # or not in GPU mode. Now if using one thread (trainer_count=1), + # trainer uses NeuralNetwork which needs to create argument on GPU + # before calling forward function. So, set data_in_gpu to True. + # Otherwise, trainer uses MultiGradientMachine which will transfer + # data from CPU to GPU in the forward function, set data_in_gpu to + # False in this case. + self.data_in_gpu = True if swig_paddle.isUsingGpu() and ( swig_paddle.getTrainerCount() == 1) else False def scan(self, dat):