diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
index a01dd512a378217df6f528665a46d50f319e16f7..66be0cd2e320452157f145411416a59a9ddbd7d1 100644
--- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
@@ -32,49 +32,58 @@ using mkldnn::softmax_forward;
 using mkldnn::stream;
 using platform::to_void_cast;
 
+template <typename T>
 class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
  public:
-  SoftmaxMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
+  SoftmaxMKLDNNHandler(const std::vector<int>& dims,
+                       const mkldnn::memory::format fmt,
+                       const platform::MKLDNNDeviceContext& dev_ctx,
                        mkldnn::engine engine, const std::string& base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        dims_(dims),
+        fmt_(fmt) {}
 
-  SoftmaxMKLDNNHandler(
-      std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
-      std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd,
-      const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
-      const std::string& base_key)
+  SoftmaxMKLDNNHandler(const std::vector<int>& dims,
+                       const mkldnn::memory::format fmt,
+                       const mkldnn::memory::format diff_fmt,
+                       const platform::MKLDNNDeviceContext& dev_ctx,
+                       mkldnn::engine engine, const std::string& base_key)
       : platform::MKLDNNHandler(dev_ctx, engine, base_key),
-        softmax_pd_(softmax_pd),
-        softmax_bwd_pd_(softmax_bwd_pd) {
+        dims_(dims),
+        fmt_(fmt),
+        diff_fmt_(diff_fmt) {
     // If we are in Grad operatgor then update a key with BWD suffix to
     // distinguish from FWD memory primitives
+    // Key_common will allow to access FWD_PD from cache
     key_ += "-BWD";
   }
 
-  std::shared_ptr<mkldnn::softmax_forward::primitive_desc>
-  AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
-                                    const mkldnn::engine& engine) {
-    // Softmax PD has to be passed to Grad op that
-    // may be executed by diffrent thread, hence
-    // for that one we use key that does not contain TID
-    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
+  // TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this function
+  // should be moved as well eg. SoftmaxMKLDNNHandler -> MKLDNNHandler
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
+                               ptr, "@user_src_mem_p");
+  }
 
-    softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-        dev_ctx_.GetBlob(key_softmax_pd));
-    if (softmax_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
-      softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-          dev_ctx_.GetBlob(key_softmax_pd));
-      if (softmax_pd_ == nullptr) {
-        softmax_pd_.reset(
-            new softmax_forward::primitive_desc(softmax_desc, engine));
-        dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
-      }
-    }
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
+                               ptr, "@user_dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
+                               diff_fmt_, ptr, "@user_diff_dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
+                               diff_fmt_, ptr, "@user_diff_src_mem_p");
+  }
 
-    return softmax_pd_;
+  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
+    this->AcquireSoftmaxPrimitiveDescriptor();
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
+                                            "@dst_mem_p");
   }
 
   std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
@@ -86,8 +95,9 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
         dev_ctx_.GetBlob(prim_key));
     if (softmax_p == nullptr) {
+      this->AcquireSoftmaxPrimitiveDescriptor();
       softmax_p = std::make_shared<mkldnn::softmax_forward>(
-          *softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
+          *fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
           *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
       dev_ctx_.SetBlob(prim_key, softmax_p);
     }
@@ -103,8 +113,19 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
         dev_ctx_.GetBlob(prim_key));
     if (softmax_bwd_p == nullptr) {
+      auto data_softmax_md =
+          mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
+      auto diff_softmax_md = mkldnn::memory::desc(
+          dims_, platform::MKLDNNGetDataType<T>(), diff_fmt_);
+      // TODO(jczaja): Add support for other axes
+      auto softmax_bwd_desc = softmax_backward::desc(
+          diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
+      this->AcquireSoftmaxPrimitiveDescriptor();
+      auto softmax_bwd_pd = mkldnn::softmax_backward::primitive_desc(
+          softmax_bwd_desc, engine_, *fwd_pd_);
+
       softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
-          *softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
+          softmax_bwd_pd, *dst_memory_p, *diff_dst_memory_p,
           *diff_src_memory_p);
       dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
     }
@@ -112,9 +133,41 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     return softmax_bwd_p;
   }
 
+ protected:
+  void AcquireSoftmaxPrimitiveDescriptor(void) {
+    // Softmax PD has to be passed to Grad op that
+    // may be executed by different thread, hence
+    // for that one we use key that does not contain TID
+    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
+
+    fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+        dev_ctx_.GetBlob(key_softmax_pd));
+    if (fwd_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+          dev_ctx_.GetBlob(key_softmax_pd));
+      if (fwd_pd_ == nullptr) {
+        // TODO(jczaja): Make it work along a chosen axis and for
+        // forward_training
+        // Normalization is made after innermost dimension eg. C out of NC
+        auto md =
+            mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
+        auto softmax_desc =
+            softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
+        fwd_pd_.reset(
+            new softmax_forward::primitive_desc(softmax_desc, engine_));
+        dev_ctx_.SetBlob(key_softmax_pd, fwd_pd_);
+      }
+    }
+  }
+
  private:
-  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd_;
-  std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd_;
+  std::vector<int> dims_;
+  mkldnn::memory::format fmt_;
+  mkldnn::memory::format diff_fmt_;
+  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
 };
 
 template <typename T>
@@ -154,21 +207,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     const std::string key =
         platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
 
-    SoftmaxMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-    // Currently only NC data format is supported
-    auto softmax_md = MKLDNNMemDesc(
-        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
-    // Normalization is made after innermost dimension eg. C out of NC
-    auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
-                                              softmax_md, 1 /*dim: C*/);
-
-    auto softmax_pd =
-        handler.AcquireSoftmaxPrimitiveDescriptor(softmax_desc, mkldnn_engine);
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc,
+                                    dev_ctx, mkldnn_engine, key);
+    // Currently only NC data format is supported
     auto softmax_src_memory_p =
-        handler.AcquireSrcMemory(softmax_md, to_void_cast<T>(input_data));
+        handler.AcquireSrcMemory(to_void_cast<T>(input_data));
    auto softmax_dst_memory_p =
-        handler.AcquireDstMemory(softmax_md, to_void_cast<T>(output_data));
+        handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
     auto softmax_p =
         handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
@@ -241,25 +287,16 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
 
     // TODO(jczaja): Add layouts support when there is a need to do so
     // Two dimensional softmax does support NC format
-    auto data_softmax_md = MKLDNNMemDesc(
-        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
-    auto diff_softmax_md = MKLDNNMemDesc(
-        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
     // Normalization is made after innermost dimension eg. C out of NC
-    auto softmax_bwd_desc =
-        softmax_backward::desc(diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
-    auto softmax_bwd_pd =
-        std::make_shared<mkldnn::softmax_backward::primitive_desc>(
-            softmax_bwd_desc, mkldnn_engine, *softmax_pd);
-
-    SoftmaxMKLDNNHandler handler(softmax_pd, softmax_bwd_pd, dev_ctx,
-                                 mkldnn_engine, key);
-    auto dst_memory_p =
-        handler.AcquireDstMemory(data_softmax_md, to_void_cast<T>(dst_data));
-    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(
-        diff_softmax_md, to_void_cast<T>(diff_dst_ptr));
-    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(
-        diff_softmax_md, to_void_cast<T>(diff_src_ptr));
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc,
+                                    mkldnn::memory::format::nc, dev_ctx,
+                                    mkldnn_engine, key);
+
+    auto dst_memory_p = handler.AcquireDstMemory(to_void_cast<T>(dst_data));
+    auto diff_dst_memory_p =
+        handler.AcquireDiffDstMemory(to_void_cast<T>(diff_dst_ptr));
+    auto diff_src_memory_p =
+        handler.AcquireDiffSrcMemory(to_void_cast<T>(diff_src_ptr));
 
     // Get primitve from device context
     auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index cec6bb792efb7d048665f63d68b554282a89e3ba..8285e61a069faccd9f76f9b75a965a85bd30d114 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -119,6 +119,25 @@ class MKLDNNHandler {
     return mem_p;
   }
 
+  std::shared_ptr<mkldnn::memory> AcquireMemory(
+      const std::vector<int>& dims, const mkldnn::memory::data_type dtype,
+      const mkldnn::memory::format& fmt, void* ptr, const std::string& suffix) {
+    /*Generate key*/
+    auto local_key = key_ + suffix;
+    auto mem_p =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+    if (mem_p == nullptr) {
+      auto md = mkldnn::memory::desc(dims, dtype, fmt);
+
+      mem_p = std::make_shared<mkldnn::memory>(
+          mkldnn::memory::primitive_desc{md, engine_}, ptr);
+      dev_ctx_.SetBlob(local_key, mem_p);
+    } else {
+      mem_p->set_data_handle(ptr);
+    }
+    return mem_p;
+  }
+
   std::shared_ptr<mkldnn::memory> AcquireMemory(
       const mkldnn::memory::primitive_desc& mpd, const std::string& suffix) {
     auto local_key = key_ + suffix;
@@ -949,18 +968,7 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
 
   std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
       const mkldnn::memory::format& fmt, void* ptr) {
-    auto local_key = key_ + "@user_src_mem_p";
-    auto mem_p =
-        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    if (mem_p == nullptr) {
-      auto src_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
-      mem_p = std::make_shared<mkldnn::memory>(
-          mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
-      dev_ctx_.SetBlob(local_key, mem_p);
-    } else {
-      mem_p->set_data_handle(ptr);
-    }
-    return mem_p;
+    return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p");
   }
 
   std::shared_ptr<mkldnn::memory> AcquireDstMemory(
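
Reviewer note (illustration, not part of the patch): the new protected
AcquireSoftmaxPrimitiveDescriptor() caches the forward primitive descriptor
under key_common_, a key that deliberately omits the thread id, so a grad
kernel scheduled on a different thread can still find the FWD PD. Creation is
guarded with double-checked locking so only one thread ever builds the
descriptor. Below is a minimal, self-contained C++ sketch of that caching
scheme; BlobCache and FakePrimitiveDesc are made-up stand-ins for the real
MKLDNNDeviceContext blob map and mkldnn::softmax_forward::primitive_desc.

#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include <unordered_map>

// Stand-in for MKLDNNDeviceContext::GetBlob/SetBlob (both thread-safe).
class BlobCache {
 public:
  std::shared_ptr<void> Get(const std::string& key) {
    std::lock_guard<std::mutex> g(mtx_);
    auto it = blobs_.find(key);
    if (it == blobs_.end()) return nullptr;
    return it->second;
  }
  void Set(const std::string& key, std::shared_ptr<void> v) {
    std::lock_guard<std::mutex> g(mtx_);
    blobs_[key] = std::move(v);
  }

 private:
  std::mutex mtx_;
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

// Stand-in for softmax_forward::primitive_desc.
struct FakePrimitiveDesc {
  explicit FakePrimitiveDesc(int axis) : axis(axis) {}
  int axis;
};

std::shared_ptr<FakePrimitiveDesc> AcquirePD(BlobCache* ctx,
                                             const std::string& key_common) {
  const std::string key_pd = key_common + "@softmax_pd";
  auto pd = std::static_pointer_cast<FakePrimitiveDesc>(ctx->Get(key_pd));
  if (pd == nullptr) {
    // Double-checked locking: re-read under the lock before creating,
    // so concurrent threads never build two descriptors for one key.
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> guard(acquire_barrier);
    pd = std::static_pointer_cast<FakePrimitiveDesc>(ctx->Get(key_pd));
    if (pd == nullptr) {
      pd = std::make_shared<FakePrimitiveDesc>(1 /*dim: C*/);
      ctx->Set(key_pd, pd);
    }
  }
  return pd;
}

int main() {
  BlobCache ctx;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i)
    workers.emplace_back([&ctx] { AcquirePD(&ctx, "softmax-2x3-nc"); });
  for (auto& t : workers) t.join();
  std::cout << "cached axis: " << AcquirePD(&ctx, "softmax-2x3-nc")->axis
            << "\n";
  return 0;
}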
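
A second illustrative sketch, also not from the patch: the new generic
MKLDNNHandler::AcquireMemory() overload creates the mkldnn::memory object only
on the first call for a given cache key; every later call merely swaps the
data pointer via set_data_handle(), which is what makes re-running the cached
handler cheap across iterations. MemCache and FakeMemory below are
hypothetical stand-ins (the real code keys into the device-context blob map
and builds mkldnn::memory from a memory::desc).

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for mkldnn::memory: holds only a raw data pointer.
struct FakeMemory {
  explicit FakeMemory(void* p) : handle(p) {}
  void set_data_handle(void* p) { handle = p; }
  void* handle;
};

class MemCache {
 public:
  std::shared_ptr<FakeMemory> Acquire(const std::string& key, void* ptr) {
    auto it = mems_.find(key);
    if (it == mems_.end()) {  // first use: create and cache the object
      auto mem = std::make_shared<FakeMemory>(ptr);
      mems_[key] = mem;
      return mem;
    }
    it->second->set_data_handle(ptr);  // later uses: rebind the pointer only
    return it->second;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<FakeMemory>> mems_;
};

int main() {
  MemCache cache;
  float batch1[6] = {0};
  float batch2[6] = {0};
  auto m1 = cache.Acquire("@user_src_mem_p", batch1);
  auto m2 = cache.Acquire("@user_src_mem_p", batch2);
  assert(m1 == m2);              // same cached object is reused...
  assert(m2->handle == batch2);  // ...now pointing at the new batch
  return 0;
}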