diff --git a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
index 097ba01d401dbc7969e30f576cac2567c874ed99..97ffb385a0e87f82d04d1e3b8e27b38959476d12 100644
--- a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/lrn_op.h"
-#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 
 namespace paddle {
 namespace operators {
@@ -22,30 +22,6 @@ namespace operators {
 using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
 
-namespace {
-template <typename T, typename... Args>
-std::shared_ptr<T> insert_to_context(const std::string& key,
-                                     const MKLDNNDeviceContext& dev_ctx,
-                                     Args&&... args) {
-  auto p = std::static_pointer_cast<T, void>(dev_ctx.GetBlob(key));
-
-  if (!p) {
-    p = std::make_shared<T>(args...);
-    dev_ctx.SetBlob(key, std::static_pointer_cast<void, T>(p));
-  }
-
-  return p;
-}
-
-template <typename... Args>
-void run_primitive(Args&&... args) {
-  auto forward_op = mkldnn::lrn_forward{args...};
-
-  std::vector<mkldnn::primitive> pipeline = {forward_op};
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
-}
-}  // namespace
-
 template <typename T>
 class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -76,66 +52,42 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
-    const bool is_test = ctx.Attr<bool>("is_test");
 
     auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
     e_mid = e_mid.constant(k);
 
     auto dims = paddle::framework::vectorize2int(x->dims());
 
-    auto src_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, x->format());
-
-    auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
-                                                  mkldnn::lrn_across_channels,
-                                                  src_md,
-                                                  n,
-                                                  alpha,
-                                                  beta,
-                                                  k};
-
-    auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
-
-    if (!is_test) {
-      const std::string key = ctx.op().Output("Out");
-      const std::string key_src_memory = key + "@lrn_src_memory";
-      const std::string key_pd = key + "@lrn_pd";
-      const std::string key_workspace_memory = key + "@lrn_workspace_memory";
-
-      auto forward_pd = insert_to_context<mkldnn::lrn_forward::primitive_desc>(
-          key_pd, dev_ctx, forward_desc, mkldnn_engine);
-
-      auto src_memory = insert_to_context<mkldnn::memory>(
-          key_src_memory, dev_ctx, src_memory_pd);
-
-      src_memory->set_data_handle(
-          static_cast<void*>(const_cast<T*>(input_data)));
-
-      auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
-                                       static_cast<void*>(output_data));
-      auto workspace_memory = insert_to_context<mkldnn::memory>(
-          key_workspace_memory, dev_ctx,
-          forward_pd->workspace_primitive_desc());
-
-      run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
-
-      out->set_layout(framework::DataLayout::kMKLDNN);
-      out->set_format(platform::GetMKLDNNFormat(dst_memory));
-    } else {
-      auto forward_pd =
-          mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
-      auto src_memory = mkldnn::memory{
-          src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
-      auto workspace_memory =
-          mkldnn::memory{forward_pd.workspace_primitive_desc()};
-      auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
-                                       static_cast<void*>(output_data));
-
-      run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
-
-      out->set_layout(framework::DataLayout::kMKLDNN);
-      out->set_format(platform::GetMKLDNNFormat(dst_memory));
-    }
+    // Format and dims are assumed to be the same for dst and src
+    auto md = paddle::platform::MKLDNNMemDesc(
+        dims, platform::MKLDNNGetDataType<T>(), x->format());
+
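+    // The hash key below encodes dims, the LRN attributes, the input format
+    // and the output variable name, so the memory objects and primitives
+    // acquired through the handler are cached in the device context and
+    // reused on subsequent iterations.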
+    const std::string key = platform::LRNMKLDNNHandler::GetHash(
+        dims, n, alpha, beta, k, x->format(), ctx.op().Output("Out"));
+
+    platform::LRNMKLDNNHandler handler(ctx.Attr<bool>("is_test"), dev_ctx,
+                                       mkldnn_engine, key);
+    auto src_memory =
+        handler.AcquireSrcMemory(md, platform::to_void_cast<T>(input_data));
+
+    // TODO(jczaja): Hide getting PD inside of handler for all Acquire API
+    handler.AcquireLRNPrimitiveDescriptor(md, n, alpha, beta, k);
+
+    auto dst_memory =
+        handler.AcquireDstMemory(md, platform::to_void_cast<T>(output_data));
+
+    auto lrn_p = handler.AcquireLRN(dst_memory, src_memory);
+
+    std::vector<mkldnn::primitive> pipeline = {*lrn_p};
+    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+
+    auto output_format =
+        (mkldnn::memory::format)dst_memory->get_primitive_desc()
+            .desc()
+            .data.format;
+
+    out->set_layout(framework::DataLayout::kMKLDNN);
+    out->set_format(output_format);
   }
 };
 
@@ -156,11 +108,6 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-    const std::string key = ctx.op().Input("Out");
-    const std::string key_src_memory = key + "@lrn_src_memory";
-    const std::string key_pd = key + "@lrn_pd";
-    const std::string key_workspace_memory = key + "@lrn_workspace_memory";
-
     const int n = ctx.Attr<int>("n");
     const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
     const float beta = ctx.Attr<float>("beta");
@@ -174,42 +121,46 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
     auto dims = paddle::framework::vectorize2int(x->dims());
 
-    auto src_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+    const std::string key = platform::LRNMKLDNNHandler::GetHash(
+        dims, n, alpha, beta, k, x->format(), ctx.op().Input("Out"));
 
-    auto diff_src_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+    platform::LRNMKLDNNHandler handler(false, dev_ctx, mkldnn_engine, key);
 
-    auto diff_dst_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+    auto src_md = paddle::platform::MKLDNNMemDesc(
+        dims, platform::MKLDNNGetDataType<T>(), x->format());
 
-    auto diff_dst_memory =
-        mkldnn::memory{{diff_dst_md, mkldnn_engine},
-                       static_cast<void*>(const_cast<float*>(out_grad_data))};
+    // diff_dst and diff_src layouts are assumed to be the same
+    auto diff_md = paddle::platform::MKLDNNMemDesc(
+        dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
 
-    auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine},
-                                          static_cast<void*>(x_grad_data)};
+    auto workspace = handler.AcquireWorkspaceMemory();
 
-    auto backward_desc = mkldnn::lrn_backward::desc{
-        mkldnn::lrn_across_channels, src_md, diff_src_md, n, alpha, beta, k};
+    auto diff_dst_memory = handler.AcquireDiffDstMemory(
+        diff_md, platform::to_void_cast<T>(out_grad_data));
 
-    auto forward_pd = dev_ctx.GetBlob(key_pd);
+    auto diff_src_memory = handler.AcquireDiffSrcMemory(
+        diff_md, platform::to_void_cast<T>(x_grad_data));
 
-    auto backward_pd = mkldnn::lrn_backward::primitive_desc{
-        backward_desc, mkldnn_engine,
-        *static_cast<mkldnn::lrn_forward::primitive_desc*>(forward_pd.get())};
+    auto src_memory = handler.AcquireSrcMemory(
+        src_md, platform::to_void_cast<T>(x->data<T>()));
 
-    std::shared_ptr<void> workspace_memory =
-        dev_ctx.GetBlob(key_workspace_memory);
+    // TODO(jczaja): Hide this call inside Handler
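+    // Note: the backward PD created by this call is derived from the forward
+    // PD that the forward kernel cached under a thread-id-free key, so the
+    // forward pass must already have run.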
+    handler.AcquireLRNBackwardPrimitiveDescriptor(src_md, diff_md, n, alpha,
+                                                  beta, k);
 
-    auto src_memory = dev_ctx.GetBlob(key_src_memory);
-    auto backward_op = mkldnn::lrn_backward{
-        backward_pd, *static_cast<mkldnn::memory*>(src_memory.get()),
-        diff_dst_memory, *static_cast<mkldnn::memory*>(workspace_memory.get()),
-        diff_src_memory};
+    auto lrn_bwd = handler.AcquireLRNBackward(src_memory, diff_dst_memory,
+                                              workspace, diff_src_memory);
 
-    std::vector<mkldnn::primitive> pipeline = {backward_op};
+    std::vector<mkldnn::primitive> pipeline = {*lrn_bwd};
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+
+    auto output_format =
+        (mkldnn::memory::format)diff_src_memory->get_primitive_desc()
+            .desc()
+            .data.format;
+
+    x_grad->set_layout(framework::DataLayout::kMKLDNN);
+    x_grad->set_format(output_format);
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 4571a7ea13ea31af3405005c45bf9861c19bd921..9f277d682b4ad767b67b1f83298a75688aadc74b 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -436,6 +436,159 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
   std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> activation_bwd_pd_;
 };
 
+class LRNMKLDNNHandler : public MKLDNNHandler {
+ public:
+  LRNMKLDNNHandler(bool is_test, const platform::MKLDNNDeviceContext& dev_ctx,
+                   mkldnn::engine engine, const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key), is_test_(is_test) {}
+
+  std::shared_ptr<mkldnn::lrn_forward::primitive_desc>
+  AcquireLRNPrimitiveDescriptor(const mkldnn::memory::desc& src_md, const int n,
+                                const float alpha, const float beta,
+                                const float k) {
+    // LRN PD has to be passed to Grad op that
+    // may be executed by a different thread, hence
+    // for that one we use a key that does not contain TID
+    const std::string key_lrn_pd = key_common_ + "@lrn_pd";
+    fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
+        dev_ctx_.GetBlob(key_lrn_pd));
+    if (fwd_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
+          dev_ctx_.GetBlob(key_lrn_pd));
+      if (fwd_pd_ == nullptr) {
+        auto forward_desc = mkldnn::lrn_forward::desc{
+            is_test_ ? mkldnn::prop_kind::forward_inference
+                     : mkldnn::prop_kind::forward_training,
+            mkldnn::lrn_across_channels, src_md, n, alpha, beta, k};
+        fwd_pd_.reset(
+            new mkldnn::lrn_forward::primitive_desc(forward_desc, engine_));
+        dev_ctx_.SetBlob(key_lrn_pd, fwd_pd_);
+      }
+    }
+    return fwd_pd_;
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(void) {
+    // workspace has to be passed to Grad op that
+    // may be executed by a different thread, hence
+    // for that one we use a key that does not contain TID
+    auto local_key = key_common_ + "@workspace";
+    auto mem_p =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+    if (mem_p == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      mem_p =
+          std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+      if (mem_p == nullptr) {
+        const std::string key_lrn_pd = key_common_ + "@lrn_pd";
+        fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
+            dev_ctx_.GetBlob(key_lrn_pd));
+        // PD from FWD op has to exist.
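+        // AcquireLRNPrimitiveDescriptor stores it under key_common_ (no
+        // thread id in the key), so it is visible here even when the grad
+        // kernel runs on a different thread.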
+        PADDLE_ENFORCE(fwd_pd_ != nullptr,
+                       "LRN PD MKL-DNN not found in cache!");
+        mkldnn::memory::primitive_desc workspace_mpd =
+            fwd_pd_->workspace_primitive_desc();
+        mem_p = std::make_shared<mkldnn::memory>(workspace_mpd);
+        dev_ctx_.SetBlob(local_key, mem_p);
+      }
+    }
+    return mem_p;
+  }
+
+  std::shared_ptr<mkldnn::lrn_forward> AcquireLRN(
+      std::shared_ptr<mkldnn::memory> dst_memory,
+      std::shared_ptr<mkldnn::memory> src_memory) {
+    auto prim_key = key_ + "@lrn_p";
+
+    auto lrn_p = std::static_pointer_cast<mkldnn::lrn_forward>(
+        dev_ctx_.GetBlob(prim_key));
+    if (lrn_p == nullptr) {
+      if (is_test_) {
+        lrn_p = std::make_shared<mkldnn::lrn_forward>(*fwd_pd_, *(src_memory),
+                                                      *(dst_memory));
+      } else {
+        // For training we need to create a workspace
+        // to store data needed by the backward pass
+        auto workspace_memory = this->AcquireWorkspaceMemory();
+
+        lrn_p = std::make_shared<mkldnn::lrn_forward>(
+            *fwd_pd_, *src_memory, *workspace_memory, *dst_memory);
+      }
+      dev_ctx_.SetBlob(prim_key, lrn_p);
+    }
+    return lrn_p;
+  }
+
+  std::shared_ptr<mkldnn::lrn_backward::primitive_desc>
+  AcquireLRNBackwardPrimitiveDescriptor(const mkldnn::memory::desc& src_md,
+                                        const mkldnn::memory::desc& diff_md,
+                                        const int n, const float alpha,
+                                        const float beta, const float k) {
+    const std::string key_lrn_pd = key_common_ + "@lrn_pd";
+    const std::string key_lrn_bwd_pd = key_ + "@lrn_bwd_pd";
+    bwd_pd_ = std::static_pointer_cast<mkldnn::lrn_backward::primitive_desc>(
+        dev_ctx_.GetBlob(key_lrn_bwd_pd));
+    if (bwd_pd_ == nullptr) {
+      fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
+          dev_ctx_.GetBlob(key_lrn_pd));
+      // PD from FWD op has to exist.
+      PADDLE_ENFORCE(fwd_pd_ != nullptr,
+                     "LRN PD MKL-DNN not found in cache!");
+
+      auto backward_desc = mkldnn::lrn_backward::desc{
+          mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k};
+      bwd_pd_.reset(new mkldnn::lrn_backward::primitive_desc(
+          backward_desc, engine_, *fwd_pd_));
+      dev_ctx_.SetBlob(key_lrn_bwd_pd, bwd_pd_);
+    }
+    return bwd_pd_;
+  }
+
+  std::shared_ptr<mkldnn::lrn_backward> AcquireLRNBackward(
+      std::shared_ptr<mkldnn::memory> src_memory,
+      std::shared_ptr<mkldnn::memory> diff_dst_memory,
+      std::shared_ptr<mkldnn::memory> workspace,
+      std::shared_ptr<mkldnn::memory> diff_src_memory) {
+    auto prim_key = key_ + "@lrn_bwd_p";
+
+    auto lrn_bwd_p = std::static_pointer_cast<mkldnn::lrn_backward>(
+        dev_ctx_.GetBlob(prim_key));
+    if (lrn_bwd_p == nullptr) {
+      lrn_bwd_p = std::make_shared<mkldnn::lrn_backward>(
+          *bwd_pd_, *src_memory, *diff_dst_memory, *workspace,
+          *diff_src_memory);
+      dev_ctx_.SetBlob(prim_key, lrn_bwd_p);
+    }
+
+    return lrn_bwd_p;
+  }
+
+  static std::string GetHash(const memory::dims& input_dims, const int n,
+                             const float alpha, const float beta, const float k,
+                             const memory::format& fmt,
+                             const std::string& suffix) {
+    std::string key;
+    key.reserve(platform::MKLDNNHandler::MaxKeyLength);
+    platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
+    platform::MKLDNNHandler::AppendKey(&key, std::to_string(n));
+    platform::MKLDNNHandler::AppendKey(&key, std::to_string(alpha));
+    platform::MKLDNNHandler::AppendKey(&key, std::to_string(beta));
+    platform::MKLDNNHandler::AppendKey(&key, std::to_string(k));
+    platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
+    platform::MKLDNNHandler::AppendKey(&key, suffix);
+    return key;
+  }
+
+ private:
+  bool is_test_;
+  std::shared_ptr<mkldnn::lrn_forward::primitive_desc> fwd_pd_;
+  std::shared_ptr<mkldnn::lrn_backward::primitive_desc> bwd_pd_;
+};
+
 class PoolingMKLDNNHandler : public MKLDNNHandler {
  public:
   PoolingMKLDNNHandler(const std::string& pooling_type,
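
Note: every Acquire* method above follows the same cache-or-create pattern:
look the object up in the device context under a string key and construct it
only when it is missing, with double-checked locking because the forward and
gradient kernels may run on different threads. Below is a minimal standalone
sketch of that pattern; BlobCache and AcquireOrCreate are hypothetical
stand-ins for MKLDNNDeviceContext's GetBlob/SetBlob storage and the handler's
Acquire* methods, not part of the patch.

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for MKLDNNDeviceContext's GetBlob/SetBlob storage.
class BlobCache {
 public:
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = blobs_.find(key);
    if (it == blobs_.end()) return nullptr;
    return it->second;
  }

  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    std::lock_guard<std::mutex> guard(mutex_);
    blobs_[key] = std::move(blob);
  }

 private:
  mutable std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

// Cache-or-create with double-checked locking: re-checking the cache under
// the lock ensures only one thread constructs the object for a given key,
// while the common cache-hit path never takes the acquire barrier at all.
template <typename T, typename... Args>
std::shared_ptr<T> AcquireOrCreate(BlobCache* cache, const std::string& key,
                                   Args&&... args) {
  auto p = std::static_pointer_cast<T>(cache->GetBlob(key));
  if (p == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> block_threads(acquire_barrier);
    p = std::static_pointer_cast<T>(cache->GetBlob(key));
    if (p == nullptr) {
      p = std::make_shared<T>(std::forward<Args>(args)...);
      cache->SetBlob(key, p);
    }
  }
  return p;
}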