diff --git a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc index d073822fc548d88a0878739684af2fca4d20ef48..fe1ead8fed6a7dc075be6afb57b815a49e90fb4e 100644 --- a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc @@ -32,16 +32,11 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { "MKLDNN LRN must use CPUPlace."); auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); auto x = ctx.Input("X"); auto out = ctx.Output("Out"); auto mid = ctx.Output("MidOut"); - auto input_data = x->data(); - auto output_data = out->mutable_data(ctx.GetPlace()); - mid->mutable_data(ctx.GetPlace()); - const int n = ctx.Attr("n"); // MKL-DNN implements LRN in a caffe way: // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html @@ -52,31 +47,32 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { const float alpha = ctx.Attr("alpha") * static_cast(n); const float beta = ctx.Attr("beta"); const float k = ctx.Attr("k"); - - auto e_mid = framework::EigenTensor::From(*mid); - e_mid = e_mid.constant(k); + bool is_test = ctx.Attr("is_test"); auto dims = paddle::framework::vectorize(x->dims()); - // Format and dims are assumed to be the same for dst and src - auto md = paddle::platform::MKLDNNMemDesc( - dims, platform::MKLDNNGetDataType(), x->format()); - - const std::string key = platform::CreateKey( - dims, n, alpha, beta, k, x->format(), ctx.op().Output("Out")); - - platform::LRNMKLDNNHandler handler(ctx.Attr("is_test"), dev_ctx, - mkldnn_engine, key); - auto src_memory = - handler.AcquireSrcMemory(md, platform::to_void_cast(input_data)); - - // TODO(jczaja): Hide getting PD inside of handler for all Acquire API - handler.AcquireLRNPrimitiveDescriptor(md, n, alpha, beta, k); - - auto dst_memory = - handler.AcquireDstMemory(md, platform::to_void_cast(output_data)); - - auto lrn_p = handler.AcquireLRN(dst_memory, src_memory); + platform::LRNMKLDNNHandler handler(dims, n, alpha, beta, k, x->format(), + is_test, dev_ctx, ctx.GetPlace(), + ctx.op().Output("Out")); + + auto src_memory = handler.AcquireSrcMemory(x); + auto dst_memory = handler.AcquireDstMemory(out); + + std::shared_ptr workspace_memory; + std::shared_ptr lrn_p; + if (is_test == false) { + workspace_memory = handler.AcquireWorkspaceMemory(mid); + lrn_p = handler.AcquireForwardPrimitive(*src_memory, *workspace_memory, + *dst_memory); + } else { + // mid has to be allocated and filled + // k to pass LRN unit tests + // TODO(jczaja): Disable checking mid in unit tests (Require API change) + mid->mutable_data(ctx.GetPlace()); + auto e_mid = framework::EigenTensor::From(*mid); + e_mid = e_mid.constant(k); + lrn_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory); + } std::vector pipeline = {*lrn_p}; mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); @@ -104,6 +100,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel { "is_test attribute should be set to False in training phase."); auto x = ctx.Input("X"); + auto mid = ctx.Input("MidOut"); auto out_grad = ctx.Input(framework::GradVarName("Out")); auto x_grad = ctx.Output(framework::GradVarName("X")); @@ -114,42 +111,20 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel { const float k = ctx.Attr("k"); auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - auto x_grad_data = x_grad->mutable_data(ctx.GetPlace()); - auto out_grad_data = out_grad->data(); auto dims = paddle::framework::vectorize(x->dims()); - const std::string key = platform::CreateKey( - dims, n, alpha, beta, k, x->format(), ctx.op().Input("Out")); - - platform::LRNMKLDNNHandler handler(false, dev_ctx, mkldnn_engine, key); - - auto src_md = paddle::platform::MKLDNNMemDesc( - dims, platform::MKLDNNGetDataType(), x->format()); - - // diff_dst and diff_src layouts are assumed to be the same - auto diff_md = paddle::platform::MKLDNNMemDesc( - dims, platform::MKLDNNGetDataType(), out_grad->format()); - - auto workspace = handler.AcquireWorkspaceMemory(); - - auto diff_dst_memory = handler.AcquireDiffDstMemory( - diff_md, platform::to_void_cast(out_grad_data)); - - auto diff_src_memory = handler.AcquireDiffSrcMemory( - diff_md, platform::to_void_cast(x_grad_data)); - - auto src_memory = handler.AcquireSrcMemory( - src_md, platform::to_void_cast(x->data())); + platform::LRNMKLDNNHandler handler( + dims, n, alpha, beta, k, x->format(), out_grad->format(), dev_ctx, + ctx.GetPlace(), ctx.op().Input("Out")); - // TODO(jczaja): Hide this call inside Handler - handler.AcquireLRNBackwardPrimitiveDescriptor(src_md, diff_md, n, alpha, - beta, k); + auto src_memory = handler.AcquireSrcMemory(x); + auto workspace = handler.AcquireBackwardWorkspaceMemory(mid); + auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad); + auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad); - auto lrn_bwd = handler.AcquireLRNBackward(src_memory, diff_dst_memory, - workspace, diff_src_memory); + auto lrn_bwd = handler.AcquireBackwardPrimitive( + *src_memory, *diff_dst_memory, *workspace, *diff_src_memory); std::vector pipeline = {*lrn_bwd}; mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 8f63ba051d07fd9c32d803fcb9aa36faf1105dbf..53697f587e5d04b97b66eda56ba20683ca652342 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -460,141 +460,64 @@ class ActivationMKLDNNHandler } }; -class LRNMKLDNNHandler : public MKLDNNHandler { +template +class LRNMKLDNNHandler + : public MKLDNNHandlerT { public: - LRNMKLDNNHandler(bool is_test, const platform::MKLDNNDeviceContext& dev_ctx, - mkldnn::engine engine, const std::string& base_key) - : platform::MKLDNNHandler(dev_ctx, engine, base_key), is_test_(is_test) {} - - std::shared_ptr - AcquireLRNPrimitiveDescriptor(const mkldnn::memory::desc& src_md, const int n, - const float alpha, const float beta, - const float k) { - // LRN PD has to be passed to Grad op that - // may be executed by diffrent thread, hence - // for that one we use key that does not contain TID - const std::string key_lrn_pd = key_common_ + "@lrn_pd"; - fwd_pd_ = std::static_pointer_cast( - dev_ctx_.GetBlob(key_lrn_pd)); - if (fwd_pd_ == nullptr) { - static std::mutex acquire_barrier; - std::lock_guard block_threads_until_finish_this_job( - acquire_barrier); - fwd_pd_ = std::static_pointer_cast( - dev_ctx_.GetBlob(key_lrn_pd)); - if (fwd_pd_ == nullptr) { - auto forward_desc = mkldnn::lrn_forward::desc{ - is_test_ ? mkldnn::prop_kind::forward_inference - : mkldnn::prop_kind::forward_training, - mkldnn::lrn_across_channels, src_md, n, alpha, beta, k}; - fwd_pd_.reset( - new mkldnn::lrn_forward::primitive_desc(forward_desc, engine_)); - dev_ctx_.SetBlob(key_lrn_pd, fwd_pd_); - } - } - return fwd_pd_; - } + LRNMKLDNNHandler(const std::vector& dims, const int n, const float alpha, + const float beta, const float k, + const MKLDNNMemoryFormat fmt, bool is_test, + const platform::MKLDNNDeviceContext& dev_ctx, + platform::Place cpu_place, const std::string& unique_name) - std::shared_ptr AcquireWorkspaceMemory(void) { - // workspace has to be passed to Grad op that - // may be executed by diffrent thread, hence - // for that one we use key that does not contain TID - auto local_key = key_common_ + "@workspace"; - auto mem_p = - std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); - if (mem_p == nullptr) { - static std::mutex acquire_barrier; - std::lock_guard block_threads_until_finish_this_job( - acquire_barrier); - mem_p = - std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); - if (mem_p == nullptr) { - const std::string key_lrn_pd = key_common_ + "@lrn_pd"; - fwd_pd_ = std::static_pointer_cast( - dev_ctx_.GetBlob(key_lrn_pd)); - // PD from FWD op has to exist. - PADDLE_ENFORCE(fwd_pd_ != nullptr, - "LRN PD MKL-DNN not found in cache!"); - mkldnn::memory::primitive_desc workspace_mpd = - fwd_pd_->workspace_primitive_desc(); - mem_p = std::make_shared(workspace_mpd); - dev_ctx_.SetBlob(local_key, mem_p); - } - } - return mem_p; + : platform::MKLDNNHandlerT( + dev_ctx, dev_ctx.GetEngine(), cpu_place, + platform::CreateKey(dims, n, alpha, beta, k, fmt, unique_name)) { + auto src_md = + mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); + this->AcquireForwardPrimitiveDescriptor( + is_test ? mkldnn::prop_kind::forward_inference + : mkldnn::prop_kind::forward_training, + mkldnn::lrn_across_channels, src_md, n, alpha, beta, k); } - std::shared_ptr AcquireLRN( - std::shared_ptr dst_memory, - std::shared_ptr src_memory) { - auto prim_key = key_ + "@lrn_p"; - - auto lrn_p = std::static_pointer_cast( - dev_ctx_.GetBlob(prim_key)); - if (lrn_p == nullptr) { - if (is_test_) { - lrn_p = std::make_shared(*fwd_pd_, *(src_memory), - *(dst_memory)); - } else { - // For training we need to create workspace - // to store indices from backward - auto workspace_memory = this->AcquireWorkspaceMemory(); + LRNMKLDNNHandler(const std::vector& dims, const int n, const float alpha, + const float beta, const float k, + const MKLDNNMemoryFormat fmt, + const MKLDNNMemoryFormat diff_fmt, + const platform::MKLDNNDeviceContext& dev_ctx, + platform::Place cpu_place, const std::string& unique_name) - lrn_p = std::make_shared( - *fwd_pd_, *src_memory, *workspace_memory, *dst_memory); - } - dev_ctx_.SetBlob(prim_key, lrn_p); - } - return lrn_p; - } - - std::shared_ptr - AcquireLRNBackwardPrimitiveDescriptor(const mkldnn::memory::desc& src_md, - const mkldnn::memory::desc& diff_md, - const int n, const float alpha, - const float beta, const float k) { - const std::string key_lrn_pd = key_common_ + "@lrn_pd"; - const std::string key_lrn_bwd_pd = key_ + "@lrn_bwd_pd"; - bwd_pd_ = std::static_pointer_cast( - dev_ctx_.GetBlob(key_lrn_bwd_pd)); - if (bwd_pd_ == nullptr) { - fwd_pd_ = std::static_pointer_cast( - dev_ctx_.GetBlob(key_lrn_pd)); - // PD from FWD op has to exist. - PADDLE_ENFORCE(fwd_pd_ != nullptr, "LRN MKL-DNN not found in cache!"); + : platform::MKLDNNHandlerT( + dev_ctx, dev_ctx.GetEngine(), cpu_place, + platform::CreateKey(dims, n, alpha, beta, k, fmt, unique_name)) { + auto src_md = + mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); + auto diff_md = + mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), diff_fmt); - auto backward_desc = mkldnn::lrn_backward::desc{ - mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k}; - bwd_pd_.reset(new mkldnn::lrn_backward::primitive_desc( - backward_desc, engine_, *fwd_pd_)); - dev_ctx_.SetBlob(key_lrn_bwd_pd, bwd_pd_); - } - return bwd_pd_; + this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, + mkldnn::lrn_across_channels, src_md, + n, alpha, beta, k); + this->AcquireBackwardPrimitiveDescriptor( + mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k); } - std::shared_ptr AcquireLRNBackward( - std::shared_ptr src_memory, - std::shared_ptr diff_dst_memory, - std::shared_ptr workspace, - std::shared_ptr diff_src_memory) { - auto prim_key = key_ + "@lrn_bwd_p"; - - auto lrn_bwd_p = std::static_pointer_cast( - dev_ctx_.GetBlob(prim_key)); - if (lrn_bwd_p == nullptr) { - lrn_bwd_p = std::make_shared( - *bwd_pd_, *src_memory, *diff_dst_memory, *workspace, - *diff_src_memory); - dev_ctx_.SetBlob(prim_key, lrn_bwd_p); - } - - return lrn_bwd_p; + std::shared_ptr AcquireWorkspaceMemory( + framework::Tensor* workspace) { + T* ptr = workspace->mutable_data( + this->place_, this->fwd_pd_->dst_primitive_desc().get_size()); + return this->AcquireMemoryFromPrimitive( + this->fwd_pd_->workspace_primitive_desc(), ptr, "@wrk_mem_p"); } - private: - bool is_test_; - std::shared_ptr fwd_pd_; - std::shared_ptr bwd_pd_; + std::shared_ptr AcquireBackwardWorkspaceMemory( + const framework::Tensor* workspace) { + const T* workspace_data = workspace->data(); + return this->AcquireMemoryFromPrimitive( + this->fwd_pd_->workspace_primitive_desc(), + to_void_cast(workspace_data), "@bwd-wrk_mem_p"); + } }; class PoolingMKLDNNHandler : public MKLDNNHandler {