提交 e823c956 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #947 from reyoung/feature/clean_bn_code

Clean BatchNorm Code.
...@@ -59,24 +59,14 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) { ...@@ -59,24 +59,14 @@ void BatchNormalizationLayer::calMeanAndStd(const MatrixPtr& mat) {
void BatchNormalizationLayer::calMovingMeanAndVar() { void BatchNormalizationLayer::calMovingMeanAndVar() {
// calculating and saving moving mean and variance // calculating and saving moving mean and variance
MatrixPtr movingMean = movingMean_->getW(); auto& movingMean = movingMean_->getW();
MatrixPtr movingVar = movingVar_->getW(); auto& movingVar = movingVar_->getW();
if (!useGpu_ && FLAGS_trainer_count > 1) {
auto mvMean = std::dynamic_pointer_cast<SharedCpuMatrix>(movingMean);
auto mvVar = std::dynamic_pointer_cast<SharedCpuMatrix>(movingVar);
CHECK(mvMean && mvVar);
mvMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
mvVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_);
} else {
// movingMean = movingMean * movingAvgFraction_ // movingMean = movingMean * movingAvgFraction_
// + savedMean_ * (1 - movingAvgFraction_) // + savedMean_ * (1 - movingAvgFraction_)
movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_); movingMean->add(*savedMean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
// movingVar = movingVar * movingAvgFraction_ // movingVar = movingVar * movingAvgFraction_
// + savedInvVar_ * (1 - movingAvgFraction_) // + savedInvVar_ * (1 - movingAvgFraction_)
movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_); movingVar->add(*savedInvVar_, movingAvgFraction_, 1.0 - movingAvgFraction_);
}
} }
void BatchNormalizationLayer::setMeanAndStd() { void BatchNormalizationLayer::setMeanAndStd() {
......
...@@ -1973,8 +1973,8 @@ public: ...@@ -1973,8 +1973,8 @@ public:
public: public:
virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT);
void add(Matrix& b, real p1, real p2); virtual void add(Matrix& b, real p1, real p2);
void add(real p1, real p2); virtual void add(real p1, real p2);
private: private:
using Matrix::mul; using Matrix::mul;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册