Commit 70394792 authored by tensor-tang

refine comment and code

Parent 88452186
MKLDNNBatchNormLayer.cpp:

@@ -109,19 +109,10 @@ void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
 void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
   // calculating and saving moving mean and variance
   CHECK_EQ(useGlobalStats_, false);
-  MatrixPtr movingMean = movingMean_->getW();
-  MatrixPtr movingVar = movingVar_->getW();
-  if (FLAGS_trainer_count > 1) {
-    auto mvMean = std::dynamic_pointer_cast<SharedCpuMatrix>(movingMean);
-    auto mvVar = std::dynamic_pointer_cast<SharedCpuMatrix>(movingVar);
-    CHECK(mvMean && mvVar);
-    mvMean->add(*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
-    mvVar->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
-  } else {
-    movingMean->add(*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
-    // here var is v^2
-    movingVar->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
-  }
+  movingMean_->getW()->add(
+      *mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
+  // here var is v^2
+  movingVar_->getW()->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
 }

 void MKLDNNBatchNormLayer::reshape(
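The refactor collapses the FLAGS_trainer_count branch into a single add() call on each weight matrix. Below is a minimal standalone sketch of the exponential moving average this performs, assuming Matrix::add(b, p1, p2) computes this = p1 * this + p2 * b; updateMovingStats and the plain-vector types are hypothetical, for illustration only:

#include <cstddef>
#include <vector>

// Hypothetical plain-vector version of calMovingMeanAndVar():
// moving = fraction * moving + (1 - fraction) * current
void updateMovingStats(std::vector<float>& movingMean,
                       std::vector<float>& movingVar,
                       const std::vector<float>& batchMean,
                       const std::vector<float>& batchVar,  // var is v^2
                       float movingAvgFraction) {
  for (std::size_t i = 0; i < movingMean.size(); ++i) {
    movingMean[i] = movingAvgFraction * movingMean[i] +
                    (1.0f - movingAvgFraction) * batchMean[i];
    movingVar[i] = movingAvgFraction * movingVar[i] +
                   (1.0f - movingAvgFraction) * batchVar[i];
  }
}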
@@ -142,8 +133,9 @@ void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
                                     MKLDNNMatrixPtr& wgt,
                                     MKLDNNMatrixPtr& bias,
                                     MKLDNNMatrixPtr& out) {
-  // in training always calculate mean and var, so useGlobalStats must be false
-  // in test depends on useGlobalStats
+  // In the training phase, mean and var are always calculated from the
+  // mini-batch, so useGlobalStats must be false.
+  // In the scoring (inference) phase, it depends on the useGlobalStats setting.
   if (passType_ != PASS_TEST && useGlobalStats_ == true) {
     LOG(WARNING) << "use_global_stats is invalid setting in training phase";
     useGlobalStats_ = false;
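The rewritten comment encodes a simple decision rule; here is a hedged standalone sketch of it (shouldUseGlobalStats and the simplified enum are hypothetical, not the layer's actual API):

enum PassType { PASS_TRAIN, PASS_TEST };  // simplified stand-in

// Hypothetical helper mirroring the comment's rule: training always uses
// mini-batch statistics; inference honors the user's use_global_stats flag.
bool shouldUseGlobalStats(PassType passType, bool useGlobalStatsConfig) {
  if (passType != PASS_TEST) {
    return false;  // training: always compute mean/var from the mini-batch
  }
  return useGlobalStatsConfig;  // inference: user's choice
}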
@@ -173,7 +165,7 @@ void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
 void MKLDNNBatchNormLayer::forward(PassType passType) {
   MKLDNNLayer::forward(passType);

-  // calculating and saving moving mean and variance
+  // calculate and save moving mean and variance
   if (passType_ != PASS_TEST) {
     calMovingMeanAndVar();
   }
MKLDNNBatchNormLayer.h:

@@ -56,8 +56,10 @@ protected:
   bool hasInitedWgt_;

   // local mean and variance
-  MKLDNNMatrixPtr mean_;  // output of mkldnn: m
-  MKLDNNMatrixPtr var_;   // output of mkldnn: v^2
+  // when useGlobalStats_ is true, they are loaded from the moving mean
+  // and variance; otherwise they are calculated from the current mini-batch
+  MKLDNNMatrixPtr mean_;
+  MKLDNNMatrixPtr var_;

 public:
   explicit MKLDNNBatchNormLayer(const LayerConfig& config)
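One detail preserved from the removed inline comments: var_ holds the variance (v^2), not the standard deviation, so any consumer normalizing with these statistics must take a square root. A generic sketch of the standard batch-norm inference transform (gamma, beta, and epsilon are the usual BN parameters, not names taken from this layer):

#include <cmath>

// Generic batch-norm inference transform using stored statistics:
// y = gamma * (x - mean) / sqrt(var + epsilon) + beta, where var is v^2.
float batchNormInference(float x, float mean, float var, float gamma,
                         float beta, float epsilon = 1e-5f) {
  return gamma * (x - mean) / std::sqrt(var + epsilon) + beta;
}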