diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 5bfa1aaa696d5cbe8bdcb94d708746259952740f..909cd5895b2260cacb6e4ed56077b65ea6a8d62d 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -18,9 +18,6 @@
 namespace paddle {
 namespace operators {
 
-using conv_bwd_data = mkldnn::convolution_backward_data;
-using conv_bwd_weights = mkldnn::convolution_backward_weights;
-using conv_fwd = mkldnn::convolution_forward;
 using framework::DataLayout;
 using mkldnn::memory;
 using mkldnn::primitive;
@@ -39,6 +36,72 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
     conv_pd_ = conv_pd;
   }
 
+  ConvMKLDNNHandler(
+      std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd,
+      std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
+          conv_bwd_data_pd,
+      std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
+          conv_bwd_weights_pd,
+      const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
+      const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        conv_pd_(conv_pd),
+        conv_bwd_weights_pd_(conv_bwd_weights_pd),
+        conv_bwd_data_pd_(conv_bwd_data_pd) {
+    // If we are in the Grad operator then update the key with a BWD suffix to
+    // distinguish from FWD memory primitives
+    key_ += "-BWD";
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
+      const std::shared_ptr<mkldnn::memory> user_memory_p,
+      std::vector<mkldnn::primitive>& pipeline) {
+    auto src_pd = conv_bwd_weights_pd_->src_primitive_desc();
+    auto user_pd = user_memory_p->get_primitive_desc();
+    return this->AcquireMemory(src_pd, user_pd, user_memory_p,
+                               "@weights-src_mem_p", pipeline);
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromWeightsPrimitive(
+      const std::shared_ptr<mkldnn::memory> user_memory_p,
+      std::vector<mkldnn::primitive>& pipeline) {
+    auto diff_dst_pd = conv_bwd_weights_pd_->diff_dst_primitive_desc();
+    auto user_pd = user_memory_p->get_primitive_desc();
+    return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
+                               "@weights-diff_dst_mem_p", pipeline);
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemoryFromWeightsPrimitive(
+      void* ptr) {
+    return this->AcquireMemoryFromPrimitive(
+        conv_bwd_weights_pd_->diff_weights_primitive_desc(), ptr,
+        "@diff_weights_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive(
+      const std::shared_ptr<mkldnn::memory> user_memory_p,
+      std::vector<mkldnn::primitive>& pipeline) {
+    auto diff_dst_pd = conv_bwd_data_pd_->diff_dst_primitive_desc();
+    auto user_pd = user_memory_p->get_primitive_desc();
+    return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
+                               "@data-diff_dst_mem_p", pipeline);
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromDataPrimitive(
+      const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
+      std::vector<mkldnn::primitive>& pipeline) {
+    auto weights_pd = conv_bwd_data_pd_->weights_primitive_desc();
+    auto user_pd = user_weights_memory_p->get_primitive_desc();
+    return this->AcquireMemory(weights_pd, user_pd, user_weights_memory_p,
+                               "@data-weights_mem_p", pipeline);
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive(
+      void* ptr) {
+    return this->AcquireMemoryFromPrimitive(
+        conv_bwd_data_pd_->diff_src_primitive_desc(), ptr, "@diff_src_mem_p");
+  }
+
   std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
     return this->AcquireMemoryFromPrimitive(conv_pd_->dst_primitive_desc(), ptr,
                                             "@dst_mem_p");
@@ -68,7 +131,6 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
       std::shared_ptr<mkldnn::memory> weights_memory_p,
       std::shared_ptr<mkldnn::memory> dst_memory_p) {
     auto prim_key = key_ + "@conv_p";
-    auto prim_desc_key = key_ + "@conv_pd";
     auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
         dev_ctx_.GetBlob(prim_key));
     PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
@@ -85,6 +147,54 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
     return conv_p;
   }
 
+  std::shared_ptr<mkldnn::convolution_backward_weights>
+  AcquireConvolutionBackwardWeights(
+      std::shared_ptr<mkldnn::memory> src_memory_p,
+      std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
+      std::shared_ptr<mkldnn::memory> diff_weights_memory_p) {
+    auto prim_key = key_ + "@conv_bwd_weights_p";
+    auto conv_bwd_weights_p =
+        std::static_pointer_cast<mkldnn::convolution_backward_weights>(
+            dev_ctx_.GetBlob(prim_key));
+    PADDLE_ENFORCE(  // if is_reusing_ is set, the primitive must exist
+        (conv_bwd_weights_p != nullptr) || (is_reusing_ == false),
+        "Fail to find convolution bwd weights primitive in device context");
+    if (conv_bwd_weights_p == nullptr) {
+      // not found in device context: create bwd weights primitive and store it
+      conv_bwd_weights_p =
+          std::make_shared<mkldnn::convolution_backward_weights>(
+              *conv_bwd_weights_pd_, *src_memory_p, *diff_dst_memory_p,
+              *diff_weights_memory_p);
+      dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p);
+    } else {
+      is_reusing_ = true;  // primitive was found in the device context
+    }
+    return conv_bwd_weights_p;
+  }
+
+  std::shared_ptr<mkldnn::convolution_backward_data>
+  AcquireConvolutionBackwardData(
+      std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
+      std::shared_ptr<mkldnn::memory> weights_memory_p,
+      std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
+    auto prim_key = key_ + "@conv_bwd_data_p";
+    auto conv_bwd_data_p =
+        std::static_pointer_cast<mkldnn::convolution_backward_data>(
+            dev_ctx_.GetBlob(prim_key));
+    PADDLE_ENFORCE(  // if is_reusing_ is set, the primitive must exist
+        (conv_bwd_data_p != nullptr) || (is_reusing_ == false),
+        "Fail to find convolution bwd data primitive in device context");
+    if (conv_bwd_data_p == nullptr) {  // not in device context yet: create it
+      conv_bwd_data_p = std::make_shared<mkldnn::convolution_backward_data>(
+          *conv_bwd_data_pd_, *diff_dst_memory_p, *weights_memory_p,
+          *diff_src_memory_p);
+      dev_ctx_.SetBlob(prim_key, conv_bwd_data_p);
+    } else {
+      is_reusing_ = true;  // primitive was found in the device context
+    }
+    return conv_bwd_data_p;
+  }
+
   // Generate keys for storing/retriving primitives for this operator
   // TODO(jczaja): Make hashing function more optimial
   static std::string GetHash(memory::dims& input_dims,
@@ -100,6 +210,10 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
 
  private:
   std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd_;
+  std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
+      conv_bwd_weights_pd_;
+  std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
+      conv_bwd_data_pd_;
 };
 
 template <typename T>
@@ -174,8 +288,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
 
     // create a conv primitive descriptor and save it for usage in backward
-    std::shared_ptr<conv_fwd::primitive_desc> conv_pd = ConvFwdPrimitiveDesc(
-        src_md, weights_md, dst_md, strides, paddings, mkldnn_engine);
+    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
+        ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+                             mkldnn_engine);
     // Save conv_pd/src_memory/weights_memory for backward pass
     dev_ctx.SetBlob(key_conv_pd, conv_pd);
 
@@ -208,21 +323,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   }
 
  private:
-  std::unique_ptr<conv_fwd::primitive_desc> ConvFwdPrimitiveDesc(
-      const memory::desc& src, const memory::desc& weights,
-      const memory::desc& dst, const std::vector<int>& strides,
-      const std::vector<int>& paddings, const mkldnn::engine& engine) const {
+  std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
+  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
+                       const memory::desc& dst, const std::vector<int>& strides,
+                       const std::vector<int>& paddings,
+                       const mkldnn::engine& engine) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
-    auto conv_desc =
-        conv_fwd::desc(mkldnn::prop_kind::forward, mkldnn::convolution_direct,
-                       src, weights, dst, stride_dims, padding_dims,
-                       padding_dims, mkldnn::padding_kind::zero);
+    auto conv_desc = mkldnn::convolution_forward::desc(
+        mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
+        dst, stride_dims, padding_dims, padding_dims,
+        mkldnn::padding_kind::zero);
 
-    auto p_conv_pd = new conv_fwd::primitive_desc(conv_desc, engine);
+    auto p_conv_pd =
+        new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
 
-    return std::unique_ptr<conv_fwd::primitive_desc>(p_conv_pd);
+    return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
+        p_conv_pd);
   }
 };
 
@@ -290,147 +408,108 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
                                    dilations, groups, ctx.op().Input("Output"));
 
     const std::string key_conv_pd = key + "@conv_pd";
+    std::vector<primitive> pipeline;
 
-    // create mkldnn memory from input tensors (input/weights/output_grad)
-    auto user_src_memory = memory(
-        {{{src_tz}, memory::data_type::f32, input->format()}, mkldnn_engine},
-        to_void_cast(input_data));
-    auto user_weights_memory =
-        memory({{{weights_tz}, memory::data_type::f32, filter->format()},
-                mkldnn_engine},
-               to_void_cast(filter_data));
-    auto user_diff_dst_memory =
-        memory({{{dst_tz}, memory::data_type::f32, output_grad->format()},
-                mkldnn_engine},
-               to_void_cast(output_grad_data));
+    // Create user memory descriptors
+    auto user_src_md = platform::MKLDNNMemDesc(
+        {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
+    auto user_weights_md = platform::MKLDNNMemDesc(
+        {weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
+    auto user_diff_dst_md = platform::MKLDNNMemDesc(
+        {dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
 
     /* create memory descriptor for conv backward without specified format
      * ('any') which lets a primitive (conv backward in this case) choose
      * the memory format preferred for best performance
      */
-    auto src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32,
-                                          memory::format::any);
-    auto diff_src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32,
-                                               memory::format::any);
+    auto src_md = platform::MKLDNNMemDesc(
+        src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+    auto diff_src_md = platform::MKLDNNMemDesc(
+        src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, memory::data_type::f32, memory::format::any);
+        weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
     auto diff_weights_md = platform::MKLDNNMemDesc(
-        weights_tz, memory::data_type::f32, memory::format::any);
-    auto diff_dst_md = platform::MKLDNNMemDesc(dst_tz, memory::data_type::f32,
-                                               memory::format::any);
+        weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+    auto diff_dst_md = platform::MKLDNNMemDesc(
+        dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
 
     // Retrieve conv_pd from device context
-    auto conv_pd = std::static_pointer_cast<conv_fwd::primitive_desc>(
-        dev_ctx.GetBlob(key_conv_pd));
+    auto conv_pd =
+        std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
+            dev_ctx.GetBlob(key_conv_pd));
     PADDLE_ENFORCE(conv_pd != nullptr,
                    "Fail to find conv_pd in device context");
 
+    // create backward convolution weights primitive descriptor
+    auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc(
+        mkldnn::convolution_direct, src_md, diff_weights_md, diff_dst_md,
+        strides, paddings, paddings, mkldnn::padding_kind::zero);
+    auto conv_bwd_weights_pd =
+        std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
+            conv_bwd_weights_desc, mkldnn_engine, *conv_pd);
+
+    // create backward convolution data primitive descriptor
+    auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc(
+        mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md,
+        strides, paddings, paddings, mkldnn::padding_kind::zero);
+    auto conv_bwd_data_pd =
+        std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(
+            conv_bwd_data_desc, mkldnn_engine, *conv_pd);
+
+    ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd, conv_bwd_weights_pd,
+                              dev_ctx, mkldnn_engine, key);
+
+    // create mkldnn memory from input tensors (input/weights/output_grad)
+    auto user_src_memory_p =
+        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
+    auto user_weights_memory_p = handler.AcquireWeightsMemory(
+        user_weights_md, to_void_cast<T>(filter_data));
+    auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
+        user_diff_dst_md, to_void_cast<T>(output_grad_data));
+
     // create backward conv primitive for weights
     if (filter_grad) {
-      // create backward convolution primitive descriptor
-      auto conv_bwd_weights_desc = conv_bwd_weights::desc(
-          mkldnn::convolution_direct, src_md, diff_weights_md, diff_dst_md,
-          strides, paddings, paddings, mkldnn::padding_kind::zero);
-      auto conv_bwd_weights_pd = conv_bwd_weights::primitive_desc(
-          conv_bwd_weights_desc, mkldnn_engine, *conv_pd);
-
-      // create reorder primitive if the input format is not the preferred one
-      auto src_memory = user_src_memory;
-      primitive reorder_src;
-      bool is_src_reordered = false;
-      if (memory::primitive_desc(conv_bwd_weights_pd.src_primitive_desc()) !=
-          user_src_memory.get_primitive_desc()) {
-        src_memory = memory(conv_bwd_weights_pd.src_primitive_desc());
-        reorder_src = reorder(user_src_memory, src_memory);
-        is_src_reordered = true;
-      }
-
-      auto diff_dst_memory_4filter = user_diff_dst_memory;
-      primitive reorder_diff_dst_4filter;
-      bool is_diff_dst_reordered_4filter = false;
-      if (memory::primitive_desc(
-              conv_bwd_weights_pd.diff_dst_primitive_desc()) !=
-          user_diff_dst_memory.get_primitive_desc()) {
-        diff_dst_memory_4filter =
-            memory(conv_bwd_weights_pd.diff_dst_primitive_desc());
-        reorder_diff_dst_4filter =
-            reorder(user_diff_dst_memory, diff_dst_memory_4filter);
-        is_diff_dst_reordered_4filter = true;
-      }
-
-      // create mkldnn memory for output (i.e. diff weights)
-      auto diff_weights_memory =
-          memory(conv_bwd_weights_pd.diff_weights_primitive_desc(),
-                 reinterpret_cast<void*>(filter_grad_data));
+      auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
+          user_src_memory_p, pipeline);
 
-      // create backward conv primitive for weights
-      auto conv_bwd_weights_prim =
-          conv_bwd_weights(conv_bwd_weights_pd, src_memory,
-                           diff_dst_memory_4filter, diff_weights_memory);
-
-      // push primitive and execute it
-      std::vector<primitive> pipeline;
-      if (is_src_reordered) pipeline.push_back(reorder_src);
-      if (is_diff_dst_reordered_4filter)
-        pipeline.push_back(reorder_diff_dst_4filter);
-      pipeline.push_back(conv_bwd_weights_prim);
-      stream(stream::kind::eager).submit(pipeline).wait();
+      auto diff_dst_memory_4filter_p =
+          handler.AcquireDiffDstMemoryFromWeightsPrimitive(
+              user_diff_dst_memory_p, pipeline);
+
+      auto diff_weights_memory_p =
+          handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
+              reinterpret_cast<void*>(filter_grad_data));
+
+      auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights(
+          src_memory_p, diff_dst_memory_4filter_p, diff_weights_memory_p);
+
+      // push primitive to pipeline; it is submitted to the stream at the end
+      pipeline.push_back(*conv_bwd_weights_p);
 
       filter_grad->set_layout(DataLayout::kMKLDNN);
-      filter_grad->set_format(GetMKLDNNFormat(diff_weights_memory));
+      filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p));
     }
 
     if (input_grad) {
-      // create backward convolution primitive descriptor
-      auto conv_bwd_data_desc = conv_bwd_data::desc(
-          mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md,
-          strides, paddings, paddings, mkldnn::padding_kind::zero);
-      auto conv_bwd_data_pd = conv_bwd_data::primitive_desc(
-          conv_bwd_data_desc, mkldnn_engine, *conv_pd);
-
-      // create reorder primitive if the input format is not the preferred one
-      auto weights_memory = user_weights_memory;
-      primitive reorder_weights;
-      bool is_weights_reordered = false;
-      if (memory::primitive_desc(conv_bwd_data_pd.weights_primitive_desc()) !=
-          user_weights_memory.get_primitive_desc()) {
-        weights_memory = memory(conv_bwd_data_pd.weights_primitive_desc());
-        reorder_weights = reorder(user_weights_memory, weights_memory);
-        is_weights_reordered = true;
-      }
-
-      auto diff_dst_memory_4data = user_diff_dst_memory;
-      primitive reorder_diff_dst_4data;
-      bool is_diff_dst_reordered_4data = false;
-      if (memory::primitive_desc(conv_bwd_data_pd.diff_dst_primitive_desc()) !=
-          user_diff_dst_memory.get_primitive_desc()) {
-        diff_dst_memory_4data =
-            memory(conv_bwd_data_pd.diff_dst_primitive_desc());
-        reorder_diff_dst_4data =
-            reorder(user_diff_dst_memory, diff_dst_memory_4data);
-        is_diff_dst_reordered_4data = true;
-      }
-
-      // create mkldnn memory for output (i.e. diff src)
-      auto diff_src_memory = memory(conv_bwd_data_pd.diff_src_primitive_desc(),
-                                    reinterpret_cast<void*>(input_grad_data));
-
-      // create backward conv primitive for data
-      auto conv_bwd_data_prim =
-          conv_bwd_data(conv_bwd_data_pd, diff_dst_memory_4data, weights_memory,
-                        diff_src_memory);
-
-      // push primitive and execute it
-      std::vector<primitive> pipeline;
-      if (is_weights_reordered) pipeline.push_back(reorder_weights);
-      if (is_diff_dst_reordered_4data)
-        pipeline.push_back(reorder_diff_dst_4data);
-      pipeline.push_back(conv_bwd_data_prim);
-      stream(stream::kind::eager).submit(pipeline).wait();
+      auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
+          user_weights_memory_p, pipeline);
+
+      auto diff_dst_memory_4data_p =
+          handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
+                                                        pipeline);
+
+      auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
+          reinterpret_cast<void*>(input_grad_data));
+
+      auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData(
+          diff_dst_memory_4data_p, weights_memory_p, diff_src_memory_p);
+
+      pipeline.push_back(*conv_bwd_data_p);
 
       input_grad->set_layout(DataLayout::kMKLDNN);
-      input_grad->set_format(GetMKLDNNFormat(diff_src_memory));
+      input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
     }
+    stream(stream::kind::eager).submit(pipeline).wait();
   }  // Compute()
 };