提交 77ed98d1 编写于 作者: Z Zrachel 提交者: emailweixu

fix bugs under kSgdSparseCpuTraining mode (#100)

Local training with "sparse_update = True" parameter triggers kSgdSparseCpuTraining mode, fix bugs under it.
上级 341486d5
......@@ -141,7 +141,7 @@ void SgdThreadUpdater::traverse(GetTraverseCallback getTraverseCallback) {
} else if (hasCpuPara) {
getGlobalSyncThreadPool()->exec(cpuTraverse);
} else if (hasGpuPara) {
cpuTraverse(0, 0);
gpuTraverse(0, 0);
}
}
......
......@@ -101,6 +101,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
// it
//! to ParameterHook.
auto& grad = para->getBuf(PARAMETER_GRADIENT);
SetDevice device(para->getDeviceId());
paraStats[para->getID()].avgAbsGrad = grad->getAbsSum() / para->getSize();
paraStats[para->getID()].maxAbsGrad = grad->getAbsMax();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册