提交 af5d954b 编写于 作者: Y Yu Yang

Clean BatchNorm Code.

上级 fefb3c13
...@@ -59,24 +59,14 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) { ...@@ -59,24 +59,14 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) {
void BatchNormalizationLayer::calMovingMeanAndVar() { void BatchNormalizationLayer::calMovingMeanAndVar() {
// calculating and saving moving mean and variance // calculating and saving moving mean and variance
MatrixPtr movingMean = movingMean_->getW(); auto& movingMean = movingMean_->getW();
MatrixPtr movingVar = movingVar_->getW(); auto& movingVar = movingVar_->getW();
// movingMean = movingMean * movingAvgFraction_
if (!useGpu_ && FLAGS_trainer_count > 1) { // + savedMean_ * (1 - movingAvgFraction_)
auto mvMean = std::dynamic_pointer_cast<SharedCpuMatrix>(movingMean); movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
auto mvVar = std::dynamic_pointer_cast<SharedCpuMatrix>(movingVar); // movingVar = movingVar * movingAvgFraction_
CHECK(mvMean && mvVar); // + savedInvVar_ * (1 - movingAvgFraction_)
movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_);
mvMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
mvVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_);
} else {
// movingMean = movingMean * movingAvgFraction_
// + savedMean_ * (1 - movingAvgFraction_)
movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
// movingVar = movingVar * movingAvgFraction_
// + savedInvVar_ * (1 - movingAvgFraction_)
movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_);
}
} }
void BatchNormalizationLayer::setMeanAndStd() { void BatchNormalizationLayer::setMeanAndStd() {
......
...@@ -1973,8 +1973,8 @@ public: ...@@ -1973,8 +1973,8 @@ public:
public: public:
virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT);
void add(Matrix& b, real p1, real p2); virtual void add(Matrix& b, real p1, real p2);
void add(real p1, real p2); virtual void add(real p1, real p2);
private: private:
using Matrix::mul; using Matrix::mul;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册