提交 77ed98d1 编写于 作者: Z Zrachel 提交者: emailweixu

fix bugs under kSgdSparseCpuTraining mode (#100)

Local training with "sparse_update = True" parameter triggers kSgdSparseCpuTraining mode, fix bugs under it.
上级 341486d5
...@@ -141,7 +141,7 @@ void SgdThreadUpdater::traverse(GetTraverseCallback getTraverseCallback) { ...@@ -141,7 +141,7 @@ void SgdThreadUpdater::traverse(GetTraverseCallback getTraverseCallback) {
} else if (hasCpuPara) { } else if (hasCpuPara) {
getGlobalSyncThreadPool()->exec(cpuTraverse); getGlobalSyncThreadPool()->exec(cpuTraverse);
} else if (hasGpuPara) { } else if (hasGpuPara) {
cpuTraverse(0, 0); gpuTraverse(0, 0);
} }
} }
......
...@@ -101,6 +101,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId, ...@@ -101,6 +101,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
// it // it
//! to ParameterHook. //! to ParameterHook.
auto& grad = para->getBuf(PARAMETER_GRADIENT); auto& grad = para->getBuf(PARAMETER_GRADIENT);
SetDevice device(para->getDeviceId());
paraStats[para->getID()].avgAbsGrad = grad->getAbsSum() / para->getSize(); paraStats[para->getID()].avgAbsGrad = grad->getAbsSum() / para->getSize();
paraStats[para->getID()].maxAbsGrad = grad->getAbsMax(); paraStats[para->getID()].maxAbsGrad = grad->getAbsMax();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册