提交 88952fba 编写于 作者: T tensor-tang

use existed sgd updater function

上级 d6a27ade
......@@ -92,28 +92,6 @@ public:
const T* getData() const { return this->data_; }
T* getData() { return this->data_; }
#ifdef PADDLE_USE_MKLDNN
/**
* sgd update with openmp to speedup
*/
void sgdUpdateWithOMP(VectorT& gradVec,
VectorT& momVec,
T learningRate,
T momentum,
T decayRate) {
size_t size = this->getSize();
T* val = this->getData();
T* grd = gradVec.getData();
T* mom = momVec.getData();
decayRate *= learningRate;
#pragma omp parallel for
for (size_t i = 0; i < size; ++i) {
mom[i] = momentum * mom[i] - learningRate * grd[i] - decayRate * val[i];
val[i] += mom[i];
}
}
#endif
virtual void zeroMem() = 0;
// set all elements to value
virtual void reset(const T& value) = 0;
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "ParameterOptimizer.h"
#include "ParameterUpdateFunctions.h"
#include "Regularizer.h"
namespace paddle {
......@@ -38,13 +39,13 @@ public:
? 1.0 - paraConfig.momentum()
: 1.0;
#ifdef PADDLE_USE_MKLDNN
vecs[PARAMETER_VALUE]->sgdUpdateWithOMP(
*vecs[PARAMETER_GRADIENT],
*vecs[PARAMETER_MOMENTUM],
learningRate_ * paraConfig.learning_rate() *
(firstTime_ ? 1.0 : torch_learningRate),
paraConfig.momentum(),
applyDecay_ ? paraConfig.decay_rate() : 0);
sgdUpdate(learningRate_ * paraConfig.learning_rate() *
(firstTime_ ? 1.0 : torch_learningRate),
paraConfig.momentum(),
applyDecay_ ? paraConfig.decay_rate() : 0,
vecs[PARAMETER_VALUE].get(),
vecs[PARAMETER_GRADIENT].get(),
vecs[PARAMETER_MOMENTUM].get());
#else
vecs[PARAMETER_VALUE]->sgdUpdate(
*vecs[PARAMETER_GRADIENT],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册