From b2bd67133aa609225ea46d12d1f091340ab000e4 Mon Sep 17 00:00:00 2001
From: tensor-tang <jian.j.tang@intel.com>
Date: Wed, 9 Aug 2017 22:52:47 +0800
Subject: [PATCH] rename and refine functions

---
 paddle/gserver/layers/MkldnnBase.h      |  16 +-
 paddle/gserver/layers/MkldnnFcLayer.cpp | 167 ++++++++++++++----
 paddle/gserver/layers/MkldnnFcLayer.h   |  21 ++-
 paddle/gserver/layers/MkldnnLayer.cpp   | 222 ------------------------
 paddle/gserver/layers/MkldnnLayer.h     |  78 ++++-----
 paddle/gserver/tests/MkldnnTester.cpp   |  22 ++-
 paddle/gserver/tests/MkldnnTester.h     |   4 +-
 paddle/gserver/tests/test_Mkldnn.cpp    |  13 +-
 python/paddle/trainer/config_parser.py  |   7 +-
 9 files changed, 217 insertions(+), 333 deletions(-)
 delete mode 100644 paddle/gserver/layers/MkldnnLayer.cpp

diff --git a/paddle/gserver/layers/MkldnnBase.h b/paddle/gserver/layers/MkldnnBase.h
index 260dbe45e44..63fd67a8508 100644
--- a/paddle/gserver/layers/MkldnnBase.h
+++ b/paddle/gserver/layers/MkldnnBase.h
@@ -19,12 +19,12 @@ limitations under the License. */
 namespace paddle {
 
 typedef enum {
-  DNN_BASE = 1,
-  DNN_TESTS = 1,
-  DNN_SIZES,
-  DNN_FMTS,
-  DNN_ALL,
-} DNN_LOG_LEVEL;
+  MKLDNN_BASE = 1,   // basical info of MKLDNN
+  MKLDNN_TESTS = 1,  // gtest info of MKLDNN
+  MKLDNN_SIZES = 2,  // size info of MKLDNN
+  MKLDNN_FMTS = 3,   // format info of MKLDNN
+  MKLDNN_ALL = 4,    // show all info of MKLDNN
+} MKLDNN_LOG_LEVEL;
 
 /**
  * @brief MKLDNN CPU engine.
@@ -68,7 +68,7 @@ public:
   /**
    * @brief Submit stream
    * @param prims The primitives vector
-   *        block Waiting for the stream to complete
+   * @param block Waiting for the stream to complete
    */
   void submit(std::vector<mkldnn::primitive>& prims, bool block = true) {
     resetState();
@@ -84,8 +84,8 @@ public:
       return;
     }
     // TODO(TJ): change me when mkldnn have method to reset this state
-    stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
     // stream_.reset(new mkldnn::stream(mkldnn::stream::kind::lazy));
+    stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
     ready_ = true;
   }
 
diff --git a/paddle/gserver/layers/MkldnnFcLayer.cpp b/paddle/gserver/layers/MkldnnFcLayer.cpp
index e4c4d4675d1..f89db169efa 100644
--- a/paddle/gserver/layers/MkldnnFcLayer.cpp
+++ b/paddle/gserver/layers/MkldnnFcLayer.cpp
@@ -16,6 +16,12 @@ limitations under the License. */
 #include "paddle/utils/Logging.h"
 #include "paddle/utils/Stat.h"
 
+using namespace mkldnn;  // NOLINT
+typedef memory::format format;
+typedef inner_product_forward fc_fwd;
+typedef inner_product_backward_weights fc_bwdWgt;
+typedef inner_product_backward_data fc_bwdData;
+
 namespace paddle {
 
 REGISTER_LAYER(mkldnn_fc, MkldnnFcLayer);
@@ -26,7 +32,7 @@ bool MkldnnFcLayer::init(const LayerMap& layerMap,
     return false;
   }
 
-  CHECK_EQ(inputLayers_.size(), 1) << "Only support one input layer yet!";
+  CHECK_EQ(inputLayers_.size(), 1) << "Only support one input layer yet";
   CHECK_EQ(inputLayers_.size(), parameters_.size());
   CHECK(!parameters_[0]->isSparse()) << "Do not support sparse yet";
 
@@ -63,14 +69,14 @@ void MkldnnFcLayer::convertWeightsFromPaddle() {
   MatrixPtr paddleWgt = Matrix::create(
       weight_->getW()->getData(), iLayerSize_, oc_, false, false);
 
+  // TODO(TJ): remove this print when do not need differ weights
   std::ostringstream ostr;
   paddleWgt->print(ostr);
-  VLOG(DNN_ALL) << "Initial Weight from paddle: " << std::endl << ostr.str();
+  VLOG(MKLDNN_ALL) << "Initial Weight from paddle: " << std::endl << ostr.str();
 
   // The mkldnn weight is transposed from initial paddle matrix
   MatrixPtr paddleWgtT;
   paddleWgt->transpose(paddleWgtT, true);
-
   weight_->getW()->copyFrom(*paddleWgtT);
   hasInitedWgt_ = true;
 }
@@ -101,6 +107,10 @@ void MkldnnFcLayer::reshape() {
   if (iw_ == 0) {
     iw_ = 1;
   }
+  hasSpatial_ = true;
+  if (ih_ == 1 && iw_ == 1) {
+    hasSpatial_ = false;
+  }
   CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize());
   ic_ = iLayerSize_ / (ih_ * iw_);
   CHECK_EQ(size_t(ic_ * ih_ * iw_), iLayerSize_) << "not divisible";
@@ -111,6 +121,114 @@ void MkldnnFcLayer::reshape() {
   output_.setFrameHeight(oh_);
   output_.setFrameWidth(ow_);
   resetOutput(bs_, oc_);
+
+  // reset mkldnn forward
+  resetFwd();
+  needResetBwd_ = true;
+
+  convertWeightsFromPaddle();
+}
+
+void MkldnnFcLayer::resetFwd() {
+  bool hasBias = biases_ && biases_->getW();
+  real* iData = getInputValue(0)->getData();
+  real* oData = getOutputValue()->getData();
+  real* wData = weight_->getW()->getData();
+  real* bData = hasBias ? biases_->getW()->getData() : NULL;
+
+  // TODO(TJ): below create should be covered in MkldnnMatrix
+  // create memory desc
+  memory::desc iMD = hasSpatial_ ? createMD({bs_, ic_, ih_, iw_}, format::nchw)
+                                 : createMD({bs_, ic_}, format::nc);
+  memory::desc wMD = hasSpatial_ ? createMD({oc_, ic_, ih_, iw_}, format::oihw)
+                                 : createMD({oc_, ic_}, format::oi);
+  memory::desc bMD = bData != NULL ? createMD({oc_}, format::x)
+                                   : createMD({}, format::format_undef);
+  memory::desc oMD = createMD({bs_, oc_}, format::nc);
+
+  // create memory primitive desc and memory self
+  inVal_.reset(new memory(memory::primitive_desc(iMD, engine_), iData));
+  wgtVal_.reset(new memory(memory::primitive_desc(wMD, engine_), wData));
+  outVal_.reset(new memory(memory::primitive_desc(oMD, engine_), oData));
+
+  prop_kind pk = prop_kind::forward;
+  fc_fwd::desc fwdDesc = bData != NULL ? fc_fwd::desc(pk, iMD, wMD, bMD, oMD)
+                                       : fc_fwd::desc(pk, iMD, wMD, oMD);
+  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
+
+  if (bData != NULL) {
+    biasVal_.reset(new memory(memory::primitive_desc(bMD, engine_), bData));
+    fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_));
+  } else {
+    fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *outVal_));
+  }
+  pipelineFwd_.clear();
+  pipelineFwd_.push_back(*fwd_);
+}
+
+void MkldnnFcLayer::resetBwd() {
+  if (!needResetBwd_) {
+    return;
+  }
+  needResetBwd_ = false;
+
+  bool hasBias = biases_ && biases_->getWGrad();
+  real* iData = getInputValue(0)->getData();
+  real* iDiff = getInputGrad(0) != nullptr ? getInputGrad(0)->getData() : NULL;
+  real* oDiff = getOutputGrad()->getData();
+  real* wDiff = weight_->getWGrad()->getData();
+  real* bDiff = hasBias ? biases_->getWGrad()->getData() : NULL;
+
+  /// backward weight
+  // create memory desc for backward memory
+  memory::desc iMD = hasSpatial_ ? createMD({bs_, ic_, ih_, iw_}, format::nchw)
+                                 : createMD({bs_, ic_}, format::nc);
+  memory::desc wMD = hasSpatial_ ? createMD({oc_, ic_, ih_, iw_}, format::oihw)
+                                 : createMD({oc_, ic_}, format::oi);
+  memory::desc oMD = createMD({bs_, oc_}, format::nc);
+  memory::desc bMD = bDiff != NULL ? createMD({oc_}, format::x)
+                                   : createMD({}, format::format_undef);
+
+  if (inVal_) {
+    // update data
+    inVal_->set_data_handle(iData);
+  } else {
+    inVal_.reset(new memory(memory::primitive_desc(iMD, engine_), iData));
+  }
+
+  // create memory primitive desc and memory self
+  wgtGrad_.reset(new memory(memory::primitive_desc(wMD, engine_), wDiff));
+  outGrad_.reset(new memory(memory::primitive_desc(oMD, engine_), oDiff));
+
+  fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward, iMD, wMD, oMD);
+  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
+  fc_bwdWgt::desc bwdWgtDesc = bDiff != NULL
+                                   ? fc_bwdWgt::desc(iMD, wMD, bMD, oMD)
+                                   : fc_bwdWgt::desc(iMD, wMD, oMD);
+  fc_bwdWgt::primitive_desc bwdWgtPD =
+      fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
+
+  if (bDiff != NULL) {
+    biasGrad_.reset(new memory(memory::primitive_desc(bMD, engine_), bDiff));
+    bwdWgt_.reset(
+        new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_));
+  } else {
+    bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_));
+  }
+  pipelineBwd_.clear();
+  pipelineBwd_.push_back(*bwdWgt_);
+
+  /// backward data
+  if (iDiff == NULL) {
+    return;
+  }
+  fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(iMD, wMD, oMD);
+  fc_bwdData::primitive_desc bwdDataPD =
+      fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
+  inGrad_.reset(new memory(memory::primitive_desc(iMD, engine_), iDiff));
+  CHECK(wgtVal_) << "Should have weight memory";
+  bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
+  pipelineBwd_.push_back(*bwdData_);
 }
 
 void MkldnnFcLayer::forward(PassType passType) {
@@ -119,12 +237,14 @@ void MkldnnFcLayer::forward(PassType passType) {
 
   {
     REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
-    real* input = getInputValue(0)->getData();
-    real* output = getOutputValue()->getData();
-    real* wgt = weight_->getW()->getData();
-    bool hasBias = biases_ && biases_->getW();
-    real* bias = hasBias ? biases_->getW()->getData() : NULL;
-    mkldnnForwardFC(bs_, ic_, ih_, iw_, input, oc_, output, wgt, bias);
+
+    // update input data
+    // since it might be changed if this is after data layer
+    real* iData = getInputValue(0)->getData();
+    inVal_->set_data_handle(iData);
+
+    // just submit forward pipeline
+    stream_->submit(pipelineFwd_);
   }
 
   /* activation */ {
@@ -139,33 +259,22 @@ void MkldnnFcLayer::backward(const UpdateCallback& callback) {
     backwardActivation();
   }
 
-  bool hasBias = biases_ && biases_->getWGrad();
   {
     REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
-    real* inVal = getInputValue(0)->getData();
-    real* inGrad =
-        getInputGrad(0) != nullptr ? getInputGrad(0)->getData() : NULL;
-    real* outGrad = getOutputGrad()->getData();
-    real* wgtGrad = weight_->getWGrad()->getData();
-    real* wgtVal = weight_->getW()->getData();
-    real* biasGrad = hasBias ? biases_->getWGrad()->getData() : NULL;
-    mkldnnBackwardFC(bs_,
-                     ic_,
-                     ih_,
-                     iw_,
-                     inGrad,
-                     inVal,
-                     oc_,
-                     outGrad,
-                     wgtGrad,
-                     wgtVal,
-                     biasGrad);
+    resetBwd();
+
+    // update diff
+    real* oDiff = getOutputGrad()->getData();
+    outGrad_->set_data_handle(oDiff);
+
+    // just sumbmit backward pipeline
+    stream_->submit(pipelineBwd_);
   }
 
   {
     REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
     weight_->getParameterPtr()->incUpdate(callback);
-    if (hasBias) {
+    if (biases_ && biases_->getWGrad()) {
       biases_->getParameterPtr()->incUpdate(callback);
     }
   }
diff --git a/paddle/gserver/layers/MkldnnFcLayer.h b/paddle/gserver/layers/MkldnnFcLayer.h
index f8910522849..c4c0fa1c41f 100644
--- a/paddle/gserver/layers/MkldnnFcLayer.h
+++ b/paddle/gserver/layers/MkldnnFcLayer.h
@@ -30,6 +30,7 @@ protected:
   size_t iLayerSize_;  // == ic * ih * iw
 
   bool hasInitedWgt_;
+  bool hasSpatial_;
 
   // fc weight and bias
   std::unique_ptr<Weight> weight_;
@@ -37,7 +38,7 @@ protected:
 
 public:
   explicit MkldnnFcLayer(const LayerConfig& config)
-      : MkldnnLayer(config), hasInitedWgt_(false) {}
+      : MkldnnLayer(config), hasInitedWgt_(false), hasSpatial_(true) {}
 
   ~MkldnnFcLayer() {}
 
@@ -52,7 +53,25 @@ public:
 
   void backward(const UpdateCallback& callback) override;
 
+protected:
+  /**
+   * reshape the input image sizes
+   * and reset output buffer size
+   * and reset mkldnn forward
+   */
   void reshape();
+
+  /**
+   * reset the forward primitve and memory
+   * only would be called when input size changes
+   */
+  void resetFwd();
+
+  /**
+   * reset the backward primitve and memory for mkldnn fc
+   * only would be called when needed
+   */
+  void resetBwd();
 };
 
 }  // namespace paddle
diff --git a/paddle/gserver/layers/MkldnnLayer.cpp b/paddle/gserver/layers/MkldnnLayer.cpp
deleted file mode 100644
index 6bd2b15a171..00000000000
--- a/paddle/gserver/layers/MkldnnLayer.cpp
+++ /dev/null
@@ -1,222 +0,0 @@
-/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "MkldnnLayer.h"
-
-using mem = mkldnn::memory;  // NOLINT
-typedef mem::format format;
-typedef mkldnn::inner_product_forward fc_fwd;
-typedef mkldnn::inner_product_backward_weights fc_bwdWgt;
-typedef mkldnn::inner_product_backward_data fc_bwdData;
-
-namespace paddle {
-
-bool MkldnnLayer::init(const LayerMap& layerMap,
-                       const ParameterMap& parameterMap) {
-  if (!Layer::init(layerMap, parameterMap)) {
-    return false;
-  }
-
-  CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
-                          << "Please set WITH_MKLDNN=ON "
-                          << "and set use_mkldnn=True";
-  stream_.reset(new MkldnnStream());
-  engine_ = CpuEngine::Instance().getEngine();
-
-  // TODO(TJ): deivecId
-  return true;
-}
-
-void MkldnnLayer::resetForwardFC(int bs,
-                                 int ic,
-                                 int ih,
-                                 int iw,
-                                 real* botData,
-                                 int oc,
-                                 real* topData,
-                                 real* wgtData,
-                                 real* biasData) {
-  bool hasSpatial = ih == 1 && iw == 1 ? false : true;
-  mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
-                               : createMD({bs, ic}, format::nc);
-  mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
-                               : createMD({oc, ic}, format::oi);
-  mem::desc biasMD = biasData != NULL ? createMD({oc}, format::x)
-                                      : createMD({}, format::format_undef);
-  mem::desc topMD = createMD({bs, oc}, format::nc);
-
-  mem::primitive_desc botPD = mem::primitive_desc(botMD, engine_);
-  if (inVal_ && inVal_->get_primitive_desc() == botPD) {
-    return;
-  }
-
-  inVal_.reset(new mem(botPD, botData));
-  wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
-  outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData));
-
-  mkldnn::prop_kind pk = mkldnn::prop_kind::forward;
-  fc_fwd::desc fwdDesc = biasData != NULL
-                             ? fc_fwd::desc(pk, botMD, wgtMD, biasMD, topMD)
-                             : fc_fwd::desc(pk, botMD, wgtMD, topMD);
-  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
-
-  if (biasData != NULL) {
-    biasVal_.reset(new mem(mem::primitive_desc(biasMD, engine_), biasData));
-    fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_));
-  } else {
-    fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *outVal_));
-  }
-  pipelineFwd_.clear();
-  pipelineFwd_.push_back(*fwd_);
-}
-
-void MkldnnLayer::mkldnnForwardFC(int bs,
-                                  int ic,
-                                  int ih,
-                                  int iw,
-                                  real* botData,
-                                  int oc,
-                                  real* topData,
-                                  real* wgtData,
-                                  real* biasData) {
-  // if input size changed, reset it
-  resetForwardFC(bs, ic, ih, iw, botData, oc, topData, wgtData, biasData);
-
-  this->convertWeightsFromPaddle();
-
-  // update input, since the data might be changed if this is after data layer
-  inVal_->set_data_handle(botData);
-
-  // just forward
-  stream_->submit(pipelineFwd_);
-}
-
-void MkldnnLayer::resetBackwardFC(int bs,
-                                  int ic,
-                                  int ih,
-                                  int iw,
-                                  real* botDiff,
-                                  real* botData,
-                                  int oc,
-                                  real* topDiff,
-                                  real* wgtDiff,
-                                  real* wgtData,
-                                  real* biasDiff) {
-  bool hasSpatial = ih == 1 && iw == 1 ? false : true;
-
-  // backward weight
-  mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
-                               : createMD({bs, ic}, format::nc);
-  mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
-                               : createMD({oc, ic}, format::oi);
-  mem::desc topMD = createMD({bs, oc}, format::nc);
-  mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
-                                      : createMD({}, format::format_undef);
-
-  mem::primitive_desc topPD = mem::primitive_desc(botMD, engine_);
-  if (outGrad_ && outGrad_->get_primitive_desc() == topPD) {
-    return;
-  }
-
-  if (inVal_) {
-    // update data
-    inVal_->set_data_handle(botData);
-  } else {
-    inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
-  }
-  wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff));
-  outGrad_.reset(new mem(topPD, topDiff));
-
-  fc_fwd::desc fwdDesc =
-      fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
-  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
-  fc_bwdWgt::desc bwdWgtDesc =
-      biasDiff != NULL ? fc_bwdWgt::desc(botMD, wgtMD, biasMD, topMD)
-                       : fc_bwdWgt::desc(botMD, wgtMD, topMD);
-  fc_bwdWgt::primitive_desc bwdWgtPD =
-      fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
-
-  if (biasDiff != NULL) {
-    biasGrad_.reset(new mem(mem::primitive_desc(biasMD, engine_), biasDiff));
-    bwdWgt_.reset(
-        new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_));
-  } else {
-    bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_));
-  }
-  pipelineBwd_.clear();
-  pipelineBwd_.push_back(*bwdWgt_);
-
-  // backward data
-  if (botDiff == NULL) {
-    return;
-  }
-
-  fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(botMD, wgtMD, topMD);
-  fc_bwdData::primitive_desc bwdDataPD =
-      fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
-  inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff));
-  if (wgtVal_) {
-    // update data
-    wgtVal_->set_data_handle(wgtData);
-  } else {
-    wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
-  }
-  bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
-  pipelineBwd_.push_back(*bwdData_);
-}
-
-void MkldnnLayer::mkldnnBackwardFC(int bs,
-                                   int ic,
-                                   int ih,
-                                   int iw,
-                                   real* botDiff,
-                                   real* botData,
-                                   int oc,
-                                   real* topDiff,
-                                   real* wgtDiff,
-                                   real* wgtData,
-                                   real* biasDiff) {
-  // if input size changed, reset it
-  resetBackwardFC(bs,
-                  ic,
-                  ih,
-                  iw,
-                  botDiff,
-                  botData,
-                  oc,
-                  topDiff,
-                  wgtDiff,
-                  wgtData,
-                  biasDiff);
-
-  // update data
-  outGrad_->set_data_handle(topDiff);
-
-  stream_->submit(pipelineBwd_);
-}
-
-void MkldnnLayer::printSizeInfo() {
-  VLOG(DNN_SIZES) << getName() << ": bs: " << bs_ << ", ic: " << ic_
-                  << ", ih: " << ih_ << ", iw: " << iw_ << ", oc: " << oc_
-                  << ", oh: " << oh_ << ", ow: " << ow_;
-}
-
-mem::desc MkldnnLayer::createMD(mem::dims dims,
-                                mem::format fmt,
-                                mem::data_type type) {
-  // TODO(TJ): isFmtSuppoted(fmt)
-  return mem::desc(dims, type, fmt);
-}
-
-}  // namespace paddle
diff --git a/paddle/gserver/layers/MkldnnLayer.h b/paddle/gserver/layers/MkldnnLayer.h
index e5c93500c75..620bdfc9848 100644
--- a/paddle/gserver/layers/MkldnnLayer.h
+++ b/paddle/gserver/layers/MkldnnLayer.h
@@ -40,6 +40,9 @@ protected:
   // output image channel, height and width
   int oc_, oh_, ow_;
 
+  // backward also need reset after reset forward handle
+  bool needResetBwd_;
+
   // mkldnn engine, stream and primivtives
   mkldnn::engine engine_;
   std::shared_ptr<MkldnnStream> stream_;
@@ -50,8 +53,6 @@ protected:
   std::vector<mkldnn::primitive> pipelineBwd_;
 
   // TODO(TJ): change below memory as MkldnnMatrixPtr type
-  // input == bottom, output == top
-  // value == data, grad == diff
   std::shared_ptr<mkldnn::memory> inVal_;
   std::shared_ptr<mkldnn::memory> inGrad_;
   std::shared_ptr<mkldnn::memory> outVal_;
@@ -71,6 +72,7 @@ public:
         oc_(0),
         oh_(0),
         ow_(0),
+        needResetBwd_(true),
         engine_(mkldnn::engine::cpu, 0),
         stream_(nullptr),
         fwd_(nullptr),
@@ -79,9 +81,21 @@ public:
 
   ~MkldnnLayer() {}
 
-  virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
+  virtual bool init(const LayerMap& layerMap,
+                    const ParameterMap& parameterMap) {
+    if (!Layer::init(layerMap, parameterMap)) {
+      return false;
+    }
+
+    CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
+                            << "Please set WITH_MKLDNN=ON "
+                            << "and set use_mkldnn=True";
+    stream_.reset(new MkldnnStream());
+    engine_ = CpuEngine::Instance().getEngine();
 
-  virtual void printSizeInfo();
+    // TODO(TJ): deivecId
+    return true;
+  }
 
   /**
    * convert weight from paddle format to mkldnn format
@@ -95,56 +109,24 @@ public:
    */
   virtual void convertWeightsToPaddle() {}
 
-  void resetForwardFC(int bs,
-                      int ic,
-                      int ih,
-                      int iw,
-                      real* botData,
-                      int oc,
-                      real* topData,
-                      real* wgtData,
-                      real* biasData);
-
-  void mkldnnForwardFC(int bs,
-                       int ic,
-                       int ih,
-                       int iw,
-                       real* botData,
-                       int oc,
-                       real* topData,
-                       real* wgtData,
-                       real* biasData);
-
-  void resetBackwardFC(int bs,
-                       int ic,
-                       int ih,
-                       int iw,
-                       real* botDiff,
-                       real* botData,
-                       int oc,
-                       real* topDiff,
-                       real* wgtDiff,
-                       real* wgtData,
-                       real* biasDiff);
-
-  void mkldnnBackwardFC(int bs,
-                        int ic,
-                        int ih,
-                        int iw,
-                        real* botDiff,
-                        real* botData,
-                        int oc,
-                        real* topDiff,
-                        real* wgtDiff,
-                        real* wgtData,
-                        real* biasDiff);
+  /**
+   * print info about sizes
+   */
+  virtual void printSizeInfo() {
+    VLOG(MKLDNN_SIZES) << getName() << ": bs: " << bs_ << ", ic: " << ic_
+                       << ", ih: " << ih_ << ", iw: " << iw_ << ", oc: " << oc_
+                       << ", oh: " << oh_ << ", ow: " << ow_;
+  }
 
   // TODO(TJ): move to MkldnnMatrix
   // create memory desc
   inline mkldnn::memory::desc createMD(
       mkldnn::memory::dims dims,
       mkldnn::memory::format fmt,
-      mkldnn::memory::data_type type = mkldnn::memory::data_type::f32);
+      mkldnn::memory::data_type type = mkldnn::memory::data_type::f32) {
+    // TODO(TJ): isFmtSuppoted(fmt)
+    return mkldnn::memory::desc(dims, type, fmt);
+  }
 };
 
 }  // namespace paddle
diff --git a/paddle/gserver/tests/MkldnnTester.cpp b/paddle/gserver/tests/MkldnnTester.cpp
index 59b3861df81..9232e2fdcd8 100644
--- a/paddle/gserver/tests/MkldnnTester.cpp
+++ b/paddle/gserver/tests/MkldnnTester.cpp
@@ -118,7 +118,7 @@ void MkldnnTester::checkForward() {
   printTopDatas();
   double delta = compareMatrix(testLayers_[DNN]->getOutputValue(),
                                testLayers_[REF]->getOutputValue());
-  VLOG(DNN_ALL) << "Check Forward";
+  VLOG(MKLDNN_ALL) << "Check Forward";
   EXPECT_LE(fabs(delta), eps_);
 }
 
@@ -162,7 +162,7 @@ void MkldnnTester::checkBackwardWgts() {
     EXPECT_LE(fabs(delta), eps_);
   }
 
-  VLOG(DNN_ALL) << "Restore dnn weights before comapre";
+  VLOG(MKLDNN_ALL) << "Restore dnn weights before comapre";
   restoreWgt(dnnWgts, parameters_[DNN]);
 }
 
@@ -275,8 +275,8 @@ double MkldnnTester::getDelta(const real* d1,
   EXPECT_TRUE(std::isnormal(sum));
   EXPECT_FALSE(std::isinf(sum));
   EXPECT_FALSE(std::isnan(delta));
-  VLOG(DNN_ALL) << "reference avg data: " << sum / len
-                << ", delta: " << delta / sum << ", failCnt:" << failCnt;
+  VLOG(MKLDNN_ALL) << "reference avg data: " << sum / len
+                   << ", delta: " << delta / sum << ", failCnt:" << failCnt;
   return (failCnt / (float)len) > failRate ? maxOut : delta / sum;
 }
 
@@ -306,10 +306,8 @@ void MkldnnTester::runOnce() {
 
   // clear buffers
   // ref code will addto the diff, dnn code will writeto it
+  // and clearTopDatas() and clearWgtDiffs() should be coverd by test layers
   clearBotDiffs(REF);
-  // below two should be coverd by test layers
-  // clearTopDatas();
-  // clearWgtDiffs();
 }
 
 void MkldnnTester::run(const TestConfig& dnn,
@@ -321,8 +319,8 @@ void MkldnnTester::run(const TestConfig& dnn,
                        float epsilon,
                        bool log,
                        int level) {
-  VLOG(DNN_TESTS) << "Test MKLDNN functionality: " << dnn.layerConfig.type()
-                  << " vs " << ref.layerConfig.type();
+  VLOG(MKLDNN_TESTS) << "Test MKLDNN functionality: " << dnn.layerConfig.type()
+                     << " vs " << ref.layerConfig.type();
   ih_ = inputImgH;
   iw_ = inputImgW;
   iter_ = iter;
@@ -338,14 +336,14 @@ void MkldnnTester::run(const TestConfig& dnn,
   clearWgtDiffs();
   clearBotDiffs();
   for (size_t i = 0; i < iter_; ++i) {
-    VLOG(DNN_TESTS) << "Check Iteration " << i;
+    VLOG(MKLDNN_TESTS) << "Check Iteration " << i;
     runOnce();
   }
 
   // Then test FLAGS_use_mkldnn_wgt = true
   FLAGS_use_mkldnn_wgt = true;
   // after run once the mkldnn weight has been stored in dnnlayer
-  // then save the weigths and restart again
+  // then save the weights and restart again
   vector<VectorPtr> dnnWgts, refWgts;
   CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size());
   saveWgt(parameters_[DNN], dnnWgts);
@@ -361,7 +359,7 @@ void MkldnnTester::run(const TestConfig& dnn,
   clearBotDiffs();
 
   for (size_t i = 0; i < iter_; ++i) {
-    VLOG(DNN_TESTS) << "Check Iteration " << i;
+    VLOG(MKLDNN_TESTS) << "Check Iteration " << i;
     runOnce();
   }
 }
diff --git a/paddle/gserver/tests/MkldnnTester.h b/paddle/gserver/tests/MkldnnTester.h
index 8b3049b5c26..7d1db870d12 100644
--- a/paddle/gserver/tests/MkldnnTester.h
+++ b/paddle/gserver/tests/MkldnnTester.h
@@ -58,7 +58,7 @@ public:
     iter_ = iter;
     eps_ = epsilon;
     log_ = false;
-    lvl_ = DNN_ALL;
+    lvl_ = MKLDNN_ALL;
   }
 
   ~MkldnnTester() {}
@@ -72,7 +72,7 @@ public:
            size_t iter = 3,
            float epsilon = 1e-4,
            bool log = false,
-           int level = DNN_ALL);
+           int level = MKLDNN_ALL);
   void setLogLevel(int lvl) { lvl_ = lvl; }
 
 private:
diff --git a/paddle/gserver/tests/test_Mkldnn.cpp b/paddle/gserver/tests/test_Mkldnn.cpp
index 0516a059de0..8e4a8595d3c 100644
--- a/paddle/gserver/tests/test_Mkldnn.cpp
+++ b/paddle/gserver/tests/test_Mkldnn.cpp
@@ -23,7 +23,6 @@ using namespace paddle;  // NOLINT
 DECLARE_bool(thread_local_rand_use_global_seed);
 DECLARE_bool(use_gpu);
 DECLARE_bool(use_mkldnn);
-DECLARE_bool(use_mkldnn_wgt);
 
 struct testFCDesc {
   int bs;
@@ -56,12 +55,12 @@ void testFcLayer(const testFCDesc& pm) {
 }
 
 TEST(MkldnnLayer, fcLayer) {
-  testFcLayer({2, 2, 3, 1, 1});
-  testFcLayer({3, 7, 19, 1, 1});
-  testFcLayer({8, 16, 32, 13, 13});
-  testFcLayer({4, 12, 18, 13, 11});
-  testFcLayer({2, 64, 32, 16, 16});
-  testFcLayer({15, 3, 6, 16, 16});
+  testFcLayer({/*bs*/ 2, /*ic*/ 2, /*oc*/ 3, /*ih*/ 1, /*iw*/ 1});
+  testFcLayer({/*bs*/ 3, /*ic*/ 7, /*oc*/ 19, /*ih*/ 1, /*iw*/ 1});
+  testFcLayer({/*bs*/ 8, /*ic*/ 16, /*oc*/ 32, /*ih*/ 13, /*iw*/ 13});
+  testFcLayer({/*bs*/ 4, /*ic*/ 12, /*oc*/ 18, /*ih*/ 13, /*iw*/ 11});
+  testFcLayer({/*bs*/ 2, /*ic*/ 64, /*oc*/ 32, /*ih*/ 16, /*iw*/ 16});
+  testFcLayer({/*bs*/ 15, /*ic*/ 3, /*oc*/ 6, /*ih*/ 16, /*iw*/ 16});
 }
 
 // TODO(TJ): add branch test
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 3213df51860..da99e5bd534 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -1626,15 +1626,14 @@ class FCLayer(LayerBase):
         for input_index in xrange(len(self.inputs)):
             input_layer = self.get_input_layer(input_index)
             psize = self.config.size * input_layer.size
+            dims = [input_layer.size, self.config.size]
             format = self.inputs[input_index].format
             sparse = format == "csr" or format == "csc"
             if use_mkldnn:
                 config_assert(not sparse,
                               "MkldnnFCLayer do not support sparse format yet")
-            if use_mkldnn and use_mkldnn_wgt:
-                dims = [self.config.size, input_layer.size]
-            else:
-                dims = [input_layer.size, self.config.size]
+                if use_mkldnn_wgt:
+                    dims = [self.config.size, input_layer.size]
             if sparse:
                 psize = self.inputs[input_index].nnz
             else:
-- 
GitLab