diff --git a/paddle/math/Vector.h b/paddle/math/Vector.h
index 7dbf3cfb0d5433c1b44947fe7e24c7ab1f9ec183..80b9775fccf10c57bb48145ef56165ec7c86d8b8 100644
--- a/paddle/math/Vector.h
+++ b/paddle/math/Vector.h
@@ -92,28 +92,6 @@ public:
   const T* getData() const { return this->data_; }
   T* getData() { return this->data_; }
 
-#ifdef PADDLE_USE_MKLDNN
-  /**
-   * sgd update with openmp to speedup
-   */
-  void sgdUpdateWithOMP(VectorT& gradVec,
-                        VectorT& momVec,
-                        T learningRate,
-                        T momentum,
-                        T decayRate) {
-    size_t size = this->getSize();
-    T* val = this->getData();
-    T* grd = gradVec.getData();
-    T* mom = momVec.getData();
-    decayRate *= learningRate;
-#pragma omp parallel for
-    for (size_t i = 0; i < size; ++i) {
-      mom[i] = momentum * mom[i] - learningRate * grd[i] - decayRate * val[i];
-      val[i] += mom[i];
-    }
-  }
-#endif
-
   virtual void zeroMem() = 0;
   // set all elements to value
   virtual void reset(const T& value) = 0;
diff --git a/paddle/parameter/FirstOrderOptimizer.h b/paddle/parameter/FirstOrderOptimizer.h
index 73e09aee2366bed095be532ab11f3c0d40f6d01f..895e8d6a63d1fad0ee7a6f5647402435d418b2f1 100644
--- a/paddle/parameter/FirstOrderOptimizer.h
+++ b/paddle/parameter/FirstOrderOptimizer.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "ParameterOptimizer.h"
+#include "ParameterUpdateFunctions.h"
 #include "Regularizer.h"
 
 namespace paddle {
@@ -38,13 +39,13 @@ public:
                                   ? 1.0 - paraConfig.momentum()
                                   : 1.0;
 #ifdef PADDLE_USE_MKLDNN
-    vecs[PARAMETER_VALUE]->sgdUpdateWithOMP(
-        *vecs[PARAMETER_GRADIENT],
-        *vecs[PARAMETER_MOMENTUM],
-        learningRate_ * paraConfig.learning_rate() *
-            (firstTime_ ? 1.0 : torch_learningRate),
-        paraConfig.momentum(),
-        applyDecay_ ? paraConfig.decay_rate() : 0);
+    sgdUpdate(learningRate_ * paraConfig.learning_rate() *
+                  (firstTime_ ? 1.0 : torch_learningRate),
+              paraConfig.momentum(),
+              applyDecay_ ? paraConfig.decay_rate() : 0,
+              vecs[PARAMETER_VALUE].get(),
+              vecs[PARAMETER_GRADIENT].get(),
+              vecs[PARAMETER_MOMENTUM].get());
 #else
     vecs[PARAMETER_VALUE]->sgdUpdate(
         *vecs[PARAMETER_GRADIENT],
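
For reference, a minimal sketch of the per-element momentum-SGD update that the removed sgdUpdateWithOMP loop performed, and that the sgdUpdate helper declared in ParameterUpdateFunctions.h is now expected to compute. The function name sgdUpdateSketch and the std::vector-based signature below are illustrative only, not PaddlePaddle API; the recurrence itself is taken directly from the deleted loop.

#include <cstddef>
#include <vector>

// Sketch (not PaddlePaddle code): momentum SGD with weight decay, applied
// element-wise over value/grad/mom buffers of equal length.
void sgdUpdateSketch(float learningRate,
                     float momentum,
                     float decayRate,
                     std::vector<float>& value,
                     const std::vector<float>& grad,
                     std::vector<float>& mom) {
  // The removed loop scaled the decay rate by the learning rate up front.
  const float decay = decayRate * learningRate;
  for (std::size_t i = 0; i < value.size(); ++i) {
    // mom <- momentum * mom - lr * grad - lr * decayRate * value
    mom[i] = momentum * mom[i] - learningRate * grad[i] - decay * value[i];
    // value <- value + mom
    value[i] += mom[i];
  }
}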