diff --git a/paddle/gserver/layers/BatchNormalizationLayer.cpp b/paddle/gserver/layers/BatchNormalizationLayer.cpp index e6a0624636380e0e8ed5e6ee5066fbcf0439f507..412762d38475422be98ffeb87ffcfb028c3e035f 100644 --- a/paddle/gserver/layers/BatchNormalizationLayer.cpp +++ b/paddle/gserver/layers/BatchNormalizationLayer.cpp @@ -59,24 +59,14 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) { void BatchNormalizationLayer::calMovingMeanAndVar() { // calculating and saving moving mean and variance - MatrixPtr movingMean = movingMean_->getW(); - MatrixPtr movingVar = movingVar_->getW(); - - if (!useGpu_ && FLAGS_trainer_count > 1) { - auto mvMean = std::dynamic_pointer_cast(movingMean); - auto mvVar = std::dynamic_pointer_cast(movingVar); - CHECK(mvMean && mvVar); - - mvMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_); - mvVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_); - } else { - // movingMean = movingMean * movingAvgFraction_ - // + savedMean_ * (1 - movingAvgFraction_) - movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_); - // movingVar = movingVar * movingAvgFraction_ - // + savedInvVar_ * (1 - movingAvgFraction_) - movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_); - } + auto& movingMean = movingMean_->getW(); + auto& movingVar = movingVar_->getW(); + // movingMean = movingMean * movingAvgFraction_ + // + savedMean_ * (1 - movingAvgFraction_) + movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_); + // movingVar = movingVar * movingAvgFraction_ + // + savedInvVar_ * (1 - movingAvgFraction_) + movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_); } void BatchNormalizationLayer::setMeanAndStd() { diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 5685cb7bcbbb6b90687790953d676e3792f36f36..1cfb90a9dbf19537f7f2ecd7d8ccea6ffe9929ef 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -1973,8 +1973,8 @@ public: public: virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); - void add(Matrix& b, real p1, real p2); - void add(real p1, real p2); + virtual void add(Matrix& b, real p1, real p2); + virtual void add(real p1, real p2); private: using Matrix::mul;