From 88952fbab9a11fb4f939c9046699ba9ecfa64314 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 25 Sep 2017 20:03:01 +0800 Subject: [PATCH] use existing sgd updater function --- paddle/math/Vector.h | 22 ---------------------- paddle/parameter/FirstOrderOptimizer.h | 15 ++++++++------- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/paddle/math/Vector.h b/paddle/math/Vector.h index 7dbf3cfb0d5..80b9775fccf 100644 --- a/paddle/math/Vector.h +++ b/paddle/math/Vector.h @@ -92,28 +92,6 @@ public: const T* getData() const { return this->data_; } T* getData() { return this->data_; } -#ifdef PADDLE_USE_MKLDNN - /** - * sgd update with openmp to speedup - */ - void sgdUpdateWithOMP(VectorT& gradVec, - VectorT& momVec, - T learningRate, - T momentum, - T decayRate) { - size_t size = this->getSize(); - T* val = this->getData(); - T* grd = gradVec.getData(); - T* mom = momVec.getData(); - decayRate *= learningRate; -#pragma omp parallel for - for (size_t i = 0; i < size; ++i) { - mom[i] = momentum * mom[i] - learningRate * grd[i] - decayRate * val[i]; - val[i] += mom[i]; - } - } -#endif - virtual void zeroMem() = 0; // set all elements to value virtual void reset(const T& value) = 0; diff --git a/paddle/parameter/FirstOrderOptimizer.h b/paddle/parameter/FirstOrderOptimizer.h index 73e09aee236..895e8d6a63d 100644 --- a/paddle/parameter/FirstOrderOptimizer.h +++ b/paddle/parameter/FirstOrderOptimizer.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "ParameterOptimizer.h" +#include "ParameterUpdateFunctions.h" #include "Regularizer.h" namespace paddle { @@ -38,13 +39,13 @@ public: ? 1.0 - paraConfig.momentum() : 1.0; #ifdef PADDLE_USE_MKLDNN - vecs[PARAMETER_VALUE]->sgdUpdateWithOMP( - *vecs[PARAMETER_GRADIENT], - *vecs[PARAMETER_MOMENTUM], - learningRate_ * paraConfig.learning_rate() * - (firstTime_ ? 1.0 : torch_learningRate), - paraConfig.momentum(), - applyDecay_ ? 
paraConfig.decay_rate() : 0); + sgdUpdate(learningRate_ * paraConfig.learning_rate() * + (firstTime_ ? 1.0 : torch_learningRate), + paraConfig.momentum(), + applyDecay_ ? paraConfig.decay_rate() : 0, + vecs[PARAMETER_VALUE].get(), + vecs[PARAMETER_GRADIENT].get(), + vecs[PARAMETER_MOMENTUM].get()); #else vecs[PARAMETER_VALUE]->sgdUpdate( *vecs[PARAMETER_GRADIENT], -- GitLab