From af5d954bdf70b553186539621baf8badcb9940c8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 19 Dec 2016 15:03:13 +0800 Subject: [PATCH] Clean BatchNorm Code. --- .../layers/BatchNormalizationLayer.cpp | 26 ++++++------------- paddle/math/Matrix.h | 4 +-- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/paddle/gserver/layers/BatchNormalizationLayer.cpp b/paddle/gserver/layers/BatchNormalizationLayer.cpp index e6a06246363..412762d3847 100644 --- a/paddle/gserver/layers/BatchNormalizationLayer.cpp +++ b/paddle/gserver/layers/BatchNormalizationLayer.cpp @@ -59,24 +59,14 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) { void BatchNormalizationLayer::calMovingMeanAndVar() { // calculating and saving moving mean and variance - MatrixPtr movingMean = movingMean_->getW(); - MatrixPtr movingVar = movingVar_->getW(); - - if (!useGpu_ && FLAGS_trainer_count > 1) { - auto mvMean = std::dynamic_pointer_cast(movingMean); - auto mvVar = std::dynamic_pointer_cast(movingVar); - CHECK(mvMean && mvVar); - - mvMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_); - mvVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_); - } else { - // movingMean = movingMean * movingAvgFraction_ - // + savedMean_ * (1 - movingAvgFraction_) - movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_); - // movingVar = movingVar * movingAvgFraction_ - // + savedInvVar_ * (1 - movingAvgFraction_) - movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_); - } + auto& movingMean = movingMean_->getW(); + auto& movingVar = movingVar_->getW(); + // movingMean = movingMean * movingAvgFraction_ + // + savedMean_ * (1 - movingAvgFraction_) + movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_); + // movingVar = movingVar * movingAvgFraction_ + // + savedInvVar_ * (1 - movingAvgFraction_) + movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_); } void BatchNormalizationLayer::setMeanAndStd() { diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 5685cb7bcbb..1cfb90a9dbf 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -1973,8 +1973,8 @@ public: public: virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); - void add(Matrix& b, real p1, real p2); - void add(real p1, real p2); + virtual void add(Matrix& b, real p1, real p2); + virtual void add(real p1, real p2); private: using Matrix::mul; -- GitLab