From 77ed98d1a86f408db1f4043bd9d6fc14c48295e9 Mon Sep 17 00:00:00 2001 From: Zrachel Date: Sat, 24 Sep 2016 00:29:52 +0800 Subject: [PATCH] fix bugs under kSgdSparseCpuTraining mode (#100) Local training with "sparse_update = True" parameter triggers kSgdSparseCpuTraining mode, fix bugs under it. --- paddle/trainer/ThreadParameterUpdater.cpp | 2 +- paddle/trainer/TrainerInternal.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/trainer/ThreadParameterUpdater.cpp b/paddle/trainer/ThreadParameterUpdater.cpp index 65d827787e..91f7f4d29d 100644 --- a/paddle/trainer/ThreadParameterUpdater.cpp +++ b/paddle/trainer/ThreadParameterUpdater.cpp @@ -141,7 +141,7 @@ void SgdThreadUpdater::traverse(GetTraverseCallback getTraverseCallback) { } else if (hasCpuPara) { getGlobalSyncThreadPool()->exec(cpuTraverse); } else if (hasGpuPara) { - cpuTraverse(0, 0); + gpuTraverse(0, 0); } } diff --git a/paddle/trainer/TrainerInternal.cpp b/paddle/trainer/TrainerInternal.cpp index 76b6b9bc3e..6029a4b2c1 100644 --- a/paddle/trainer/TrainerInternal.cpp +++ b/paddle/trainer/TrainerInternal.cpp @@ -101,6 +101,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId, // it //! to ParameterHook. auto& grad = para->getBuf(PARAMETER_GRADIENT); + SetDevice device(para->getDeviceId()); paraStats[para->getID()].avgAbsGrad = grad->getAbsSum() / para->getSize(); paraStats[para->getID()].maxAbsGrad = grad->getAbsMax(); } -- GitLab