提交 70394792 编写于 作者: T tensor-tang

refine comment and code

上级 88452186
......@@ -109,19 +109,10 @@ void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
// calculating and saving moving mean and variance
CHECK_EQ(useGlobalStats_, false);
MatrixPtr movingMean = movingMean_->getW();
MatrixPtr movingVar = movingVar_->getW();
if (FLAGS_trainer_count > 1) {
auto mvMean = std::dynamic_pointer_cast<SharedCpuMatrix>(movingMean);
auto mvVar = std::dynamic_pointer_cast<SharedCpuMatrix>(movingVar);
CHECK(mvMean && mvVar);
mvMean->add(*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
mvVar->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
} else {
movingMean->add(*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
// here var is v^2
movingVar->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
}
movingMean_->getW()->add(
*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
// here var is v^2
movingVar_->getW()->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
}
void MKLDNNBatchNormLayer::reshape(
......@@ -142,8 +133,9 @@ void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
// in training always calculate mean and var, so useGlobalStats must be false
// in test depends on useGlobalStats
// In training phase, it will always calculate mean and var,
// so useGlobalStats must be false.
// In scoring phase, it depends on useGlobalStats choice.
if (passType_ != PASS_TEST && useGlobalStats_ == true) {
LOG(WARNING) << "use_global_stats is invalid setting in training phase";
useGlobalStats_ = false;
......@@ -173,7 +165,7 @@ void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
void MKLDNNBatchNormLayer::forward(PassType passType) {
MKLDNNLayer::forward(passType);
// calculating and saving moving mean and variance
// calculate and save moving mean and variance
if (passType_ != PASS_TEST) {
calMovingMeanAndVar();
}
......
......@@ -56,8 +56,10 @@ protected:
bool hasInitedWgt_;
// local mean and variance
MKLDNNMatrixPtr mean_; // output of mkldnn: m
MKLDNNMatrixPtr var_; // output of mkldnn: v^2
// when useGlobalStats_ they are loaded from moving mean and variance
// when do not useGlobalStats_ they are calculated from this mini-batch
MKLDNNMatrixPtr mean_;
MKLDNNMatrixPtr var_;
public:
explicit MKLDNNBatchNormLayer(const LayerConfig& config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册