diff --git a/paddle/trainer/ThreadParameterUpdater.cpp b/paddle/trainer/ThreadParameterUpdater.cpp
index 65d827787ee78fe7a572869d7115c7abe27304a6..91f7f4d29df938e88a0e8c54b7046194c7adfb35 100644
--- a/paddle/trainer/ThreadParameterUpdater.cpp
+++ b/paddle/trainer/ThreadParameterUpdater.cpp
@@ -141,7 +141,7 @@ void SgdThreadUpdater::traverse(GetTraverseCallback getTraverseCallback) {
   } else if (hasCpuPara) {
     getGlobalSyncThreadPool()->exec(cpuTraverse);
   } else if (hasGpuPara) {
-    cpuTraverse(0, 0);
+    gpuTraverse(0, 0);
   }
 }
 
diff --git a/paddle/trainer/TrainerInternal.cpp b/paddle/trainer/TrainerInternal.cpp
index 76b6b9bc3ee38341a67d7ec111196e28a28a0e9b..6029a4b2c1d0a0c04058bbd979523f26b72b5a5e 100644
--- a/paddle/trainer/TrainerInternal.cpp
+++ b/paddle/trainer/TrainerInternal.cpp
@@ -101,6 +101,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
           // it
           //! to ParameterHook.
           auto& grad = para->getBuf(PARAMETER_GRADIENT);
+          SetDevice device(para->getDeviceId());
           paraStats[para->getID()].avgAbsGrad = grad->getAbsSum() / para->getSize();
           paraStats[para->getID()].maxAbsGrad = grad->getAbsMax();
         }