From c7e688921bf4c34036bb15d55e3bfa530a63b2f5 Mon Sep 17 00:00:00 2001 From: Adam <38704900+grygielski@users.noreply.github.com> Date: Thu, 19 Sep 2019 12:15:49 +0200 Subject: [PATCH] Add template functions for Acquire primitive/primitive_desc (#19867) * Add template functions for Acquire primitive/primitive_desc test=develop * Move acquire primitive descriptor to protected section test=develop --- .../operators/mkldnn/activation_mkldnn_op.cc | 7 +- .../operators/mkldnn/softmax_mkldnn_op.cc | 119 ++--------- paddle/fluid/platform/mkldnn_reuse.h | 185 ++++++++---------- 3 files changed, 100 insertions(+), 211 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc index b706982c3b..414576f1a2 100644 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc @@ -92,7 +92,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx, auto src_memory_p = handler.AcquireSrcMemory(x); auto dst_memory_p = handler.AcquireDstMemory(y); - auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p); + auto activation_p = + handler.AcquireForwardPrimitive(*src_memory_p, *dst_memory_p); // push primitive to stream and wait until it's executed std::vector pipeline; @@ -131,8 +132,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx, auto src_memory_p = handler.AcquireBackwardSrcMemory(x); auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y); auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x); - auto activation_backward_p = handler.AcquireActivationBackward( - diff_src_memory_p, diff_dst_memory_p, src_memory_p); + auto activation_backward_p = handler.AcquireBackwardPrimitive( + *src_memory_p, *diff_dst_memory_p, *diff_src_memory_p); // push primitive to stream and wait until it's executed std::vector pipeline; diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index 8426348a11..cd53a07aca 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -45,7 +45,10 @@ class SoftmaxMKLDNNHandler mkldnn::softmax_backward>( dev_ctx, dev_ctx.GetEngine(), cpu_place, platform::CreateKey(dims, uniq_name)) { - this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt); + auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); + + this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md, + 1 /*dim: C*/); } SoftmaxMKLDNNHandler(const std::vector& dims, @@ -57,100 +60,15 @@ class SoftmaxMKLDNNHandler mkldnn::softmax_backward>( dev_ctx, dev_ctx.GetEngine(), cpu_place, platform::CreateKey(dims, uniq_name)) { - // If we are in Grad operatgor then update a key with BWD suffix to - // distinguish from FWD memory primitives - // Key_common will allow to access FWD_PD from cache - this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt); - this->AcquireSoftmaxBackwardPrimitiveDescriptor(dims, fmt, diff_fmt); - } - - std::shared_ptr AcquireSoftmax( - std::shared_ptr dst_memory_p, - std::shared_ptr src_memory_p) { - /*Generate key*/ - auto prim_key = this->key_ + "@softmax_p"; - - auto softmax_p = std::static_pointer_cast( - this->dev_ctx_.GetBlob(prim_key)); - if (softmax_p == nullptr) { - softmax_p = std::make_shared( - *this->fwd_pd_, *(static_cast(src_memory_p.get())), - *(static_cast(dst_memory_p.get()))); - this->dev_ctx_.SetBlob(prim_key, softmax_p); - } - - return softmax_p; - } - - 
std::shared_ptr AcquireSoftmaxBackward( - std::shared_ptr dst_memory_p, - std::shared_ptr diff_dst_memory_p, - std::shared_ptr diff_src_memory_p) { - auto prim_key = this->key_ + "@softmax_bwd_p"; - auto softmax_bwd_p = std::static_pointer_cast( - this->dev_ctx_.GetBlob(prim_key)); - if (softmax_bwd_p == nullptr) { - softmax_bwd_p = std::make_shared( - *this->bwd_pd_, *dst_memory_p, *diff_dst_memory_p, - *diff_src_memory_p); - this->dev_ctx_.SetBlob(prim_key, softmax_bwd_p); - } - - return softmax_bwd_p; - } - - protected: - void AcquireSoftmaxPrimitiveDescriptor(const std::vector& dims, - const mkldnn::memory::format fmt) { - // Softmax PD has to be passed to Grad op that - // may be executed by diffrent thread, hence - // for that one we use key that does not contain TID - const std::string key_softmax_pd = this->key_common_ + "@softmax_pd"; - - this->fwd_pd_ = std::static_pointer_cast( - this->dev_ctx_.GetBlob(key_softmax_pd)); - if (this->fwd_pd_ == nullptr) { - static std::mutex acquire_barrier; - std::lock_guard block_threads_until_finish_this_job( - acquire_barrier); - this->fwd_pd_ = std::static_pointer_cast( - this->dev_ctx_.GetBlob(key_softmax_pd)); - if (this->fwd_pd_ == nullptr) { - // TODO(jczaja): Make it working along chosen axis and for - // forward_training - // Normalization is made after innermost dimension eg. C out of NC - auto md = - mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); - auto softmax_desc = - softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/); - this->fwd_pd_.reset( - new softmax_forward::primitive_desc(softmax_desc, this->engine_)); - this->dev_ctx_.SetBlob(key_softmax_pd, this->fwd_pd_); - } - } - } - - void AcquireSoftmaxBackwardPrimitiveDescriptor( - const std::vector& dims, const mkldnn::memory::format fmt, - const mkldnn::memory::format diff_fmt) { - // Fwd_PD_ has to exists when to create BWD_PD_ - PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_); - const std::string key_bwd_pd = this->key_ + "@softmax_bwd_pd"; - this->bwd_pd_ = - std::static_pointer_cast( - this->dev_ctx_.GetBlob(key_bwd_pd)); - if (this->bwd_pd_ == nullptr) { - auto data_softmax_md = - mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); - auto diff_softmax_md = mkldnn::memory::desc( - dims, platform::MKLDNNGetDataType(), diff_fmt); - // TODO(jczaja): Add support for other axes - auto backward_desc = softmax_backward::desc( - diff_softmax_md, data_softmax_md, 1 /* dim: C*/); - this->bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc( - backward_desc, this->engine_, *this->fwd_pd_)); - this->dev_ctx_.SetBlob(key_bwd_pd, this->bwd_pd_); - } + auto data_softmax_md = + mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); + auto diff_softmax_md = + mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), diff_fmt); + + this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, + data_softmax_md, 1 /*dim: C*/); + this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md, + 1 /* dim: C*/); } }; @@ -181,11 +99,10 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { // Currently only NC data format is supported auto softmax_src_memory_p = handler.AcquireSrcMemory(input); auto softmax_dst_memory_p = handler.AcquireDstMemory(output); - auto softmax_p = - handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p); + auto softmax_p = handler.AcquireForwardPrimitive(*softmax_src_memory_p, + *softmax_dst_memory_p); - std::vector pipeline{ - *(static_cast(softmax_p.get()))}; + std::vector 
pipeline{*softmax_p}; stream(stream::kind::eager).submit(pipeline).wait(); T* output_data = output->mutable_data(ctx.GetPlace()); @@ -242,8 +159,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel { auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); // Get primitve from device context - auto softmax_bwd_p = handler.AcquireSoftmaxBackward( - dst_memory_p, diff_dst_memory_p, diff_src_memory_p); + auto softmax_bwd_p = handler.AcquireBackwardPrimitive( + *dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p); std::vector pipeline{*softmax_bwd_p}; stream(stream::kind::eager).submit(pipeline).wait(); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 1ebc89a8af..c6b58fc093 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include +#include #include #include "boost/optional.hpp" #include "paddle/fluid/framework/data_layout_transform.h" @@ -48,6 +49,32 @@ class MKLDNNHandlerT { } } + template + std::shared_ptr AcquireForwardPrimitive(Args&&... args) { + const std::string key_p = key_ + "@forward_p"; + auto forward_p = + std::static_pointer_cast(dev_ctx_.GetBlob(key_p)); + if (forward_p == nullptr) { + forward_p = + std::make_shared(*fwd_pd_, std::forward(args)...); + dev_ctx_.SetBlob(key_p, forward_p); + } + return forward_p; + } + + template + std::shared_ptr AcquireBackwardPrimitive(Args&&... args) { + const std::string key_p = key_ + "@backward_p"; + auto backward_p = + std::static_pointer_cast(dev_ctx_.GetBlob(key_p)); + if (backward_p == nullptr) { + backward_p = + std::make_shared(*bwd_pd_, std::forward(args)...); + dev_ctx_.SetBlob(key_p, backward_p); + } + return backward_p; + } + std::shared_ptr AcquireSrcMemory( const framework::Tensor* input) { const T* input_data = input->data(); @@ -87,6 +114,44 @@ class MKLDNNHandlerT { ptr, "@diff_src_mem_p"); } + protected: + template + void AcquireForwardPrimitiveDescriptor(Args&&... args) { + // Forward PD has to be passed to Grad op that + // may be executed by diffrent thread, hence + // for that one we use key that does not contain TID + const std::string key_pd = key_common_ + "@forward_pd"; + fwd_pd_ = std::static_pointer_cast( + dev_ctx_.GetBlob(key_pd)); + if (fwd_pd_ == nullptr) { + static std::mutex acquire_barrier; + std::lock_guard block_threads_until_finish_this_job( + acquire_barrier); + fwd_pd_ = std::static_pointer_cast( + dev_ctx_.GetBlob(key_pd)); + if (fwd_pd_ == nullptr) { + auto fwd_desc = typename TForward::desc(std::forward(args)...); + fwd_pd_ = std::make_shared(fwd_desc, + engine_); + dev_ctx_.SetBlob(key_pd, fwd_pd_); + } + } + } + + template + void AcquireBackwardPrimitiveDescriptor(Args&&... 
args) { + PADDLE_ENFORCE_NOT_NULL(fwd_pd_); + const std::string key_pd = key_ + "@backward_pd"; + bwd_pd_ = std::static_pointer_cast( + dev_ctx_.GetBlob(key_pd)); + if (bwd_pd_ == nullptr) { + auto bwd_desc = typename TBackward::desc(std::forward(args)...); + bwd_pd_ = std::make_shared( + bwd_desc, engine_, *fwd_pd_); + dev_ctx_.SetBlob(key_pd, bwd_pd_); + } + } + std::shared_ptr AcquireMemoryFromPrimitive( mkldnn::memory::primitive_desc mdp, void* ptr, const std::string& suffix) { @@ -102,7 +167,6 @@ class MKLDNNHandlerT { return mem_p; } - protected: const MKLDNNDeviceContext& dev_ctx_; mkldnn::engine engine_; platform::Place place_; @@ -355,10 +419,12 @@ class ActivationMKLDNNHandler dev_ctx, dev_ctx.GetEngine(), cpu_place, platform::CreateKey(dims, algorithm, fmt, alpha, beta, unique_name)) { - AcquireActivationPrimitiveDescriptor( + auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); + + this->AcquireForwardPrimitiveDescriptor( is_test ? mkldnn::prop_kind::forward_inference : mkldnn::prop_kind::forward_training, - algorithm, dims, fmt, alpha, beta); + algorithm, md, alpha, beta); } ActivationMKLDNNHandler(const std::vector& dims, @@ -374,10 +440,15 @@ class ActivationMKLDNNHandler dev_ctx, dev_ctx.GetEngine(), cpu_place, platform::CreateKey(dims, algorithm, fmt, alpha, beta, unique_name)) { - AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind::forward_training, - algorithm, dims, fmt, alpha, beta); - AcquireActivationBackwardPrimitiveDescriptor(algorithm, dims, fmt, diff_fmt, - alpha, beta); + auto diff_dst_md = platform::MKLDNNMemDesc( + dims, platform::MKLDNNGetDataType(), diff_fmt); + auto src_md = + platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType(), fmt); + + this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, + algorithm, src_md, alpha, beta); + this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md, + alpha, beta); } std::shared_ptr AcquireBackwardSrcMemory( @@ -387,106 +458,6 @@ class ActivationMKLDNNHandler to_void_cast(input_data), "@bwd-src_mem_p"); } - - std::shared_ptr AcquireActivation( - std::shared_ptr dst_memory_p, - std::shared_ptr src_memory_p) { - /*Generate key*/ - auto prim_key = this->key_ + "@eltwise_p"; - - auto eltwise_p = std::static_pointer_cast( - this->dev_ctx_.GetBlob(prim_key)); - if (eltwise_p == nullptr) { - eltwise_p = std::make_shared( - *this->fwd_pd_, *(src_memory_p), *(dst_memory_p)); - this->dev_ctx_.SetBlob(prim_key, eltwise_p); - } - - return eltwise_p; - } - - std::shared_ptr AcquireActivationBackward( - std::shared_ptr diff_src_memory_p, - std::shared_ptr diff_dst_memory_p, - std::shared_ptr src_memory_p) { - /*Generate key*/ - auto prim_key = this->key_ + "@eltwise_bwd_p"; - - auto eltwise_bwd_p = std::static_pointer_cast( - this->dev_ctx_.GetBlob(prim_key)); - if (eltwise_bwd_p == nullptr) { - eltwise_bwd_p = std::make_shared( - *this->bwd_pd_, *(src_memory_p), *(diff_dst_memory_p), - *(diff_src_memory_p)); - this->dev_ctx_.SetBlob(prim_key, eltwise_bwd_p); - } - - return eltwise_bwd_p; - } - - protected: - void AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind, - mkldnn::algorithm algorithm, - const std::vector& dims, - const MKLDNNMemoryFormat fmt, - float alpha, float beta) { - // Activation PD has to be passed to Grad op that - // may be executed by diffrent thread, hence - // for that one we use key that does not contain TID - const std::string key_activation_pd = this->key_common_ + "@activation_pd"; - this->fwd_pd_ = - 
        std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-            this->dev_ctx_.GetBlob(key_activation_pd));
-    if (this->fwd_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
-
-      this->fwd_pd_ =
-          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-              this->dev_ctx_.GetBlob(key_activation_pd));
-      if (this->fwd_pd_ == nullptr) {
-        auto md = platform::MKLDNNMemDesc(
-            dims, platform::MKLDNNGetDataType<T>(), fmt);
-        auto activation_desc = mkldnn::eltwise_forward::desc(
-            prop_kind, algorithm, md, alpha, beta);
-
-        this->fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
-            activation_desc, this->engine_));
-        this->dev_ctx_.SetBlob(key_activation_pd, this->fwd_pd_);
-      }
-    }
-  }
-
-  void AcquireActivationBackwardPrimitiveDescriptor(
-      mkldnn::algorithm algorithm, const std::vector<int>& dims,
-      const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_fmt,
-      float alpha, float beta) {
-    const std::string key_activation_pd = this->key_common_ + "@activation_pd";
-    const std::string key_activation_bwd_pd = this->key_ + "@activation_bwd_pd";
-    this->bwd_pd_ =
-        std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
-            this->dev_ctx_.GetBlob(key_activation_bwd_pd));
-    if (this->bwd_pd_ == nullptr) {
-      this->fwd_pd_ =
-          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-              this->dev_ctx_.GetBlob(key_activation_pd));
-      // PD from FWD op has to exist.
-      PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_,
-                              "Eltwise MKL-DNN not found in cache!");
-
-      auto diff_dst_md = platform::MKLDNNMemDesc(
-          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
-      auto src_md =
-          platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-
-      auto backward_desc = mkldnn::eltwise_backward::desc(
-          algorithm, diff_dst_md, src_md, alpha, beta);
-      this->bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
-          backward_desc, this->engine_, *this->fwd_pd_));
-      this->dev_ctx_.SetBlob(key_activation_bwd_pd, this->bwd_pd_);
-    }
-  }
 };
 
 class LRNMKLDNNHandler : public MKLDNNHandler {
-- 
GitLab
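
The core of this change is the set of variadic template helpers on MKLDNNHandlerT (AcquireForwardPrimitive, AcquireBackwardPrimitive, and the protected Acquire*PrimitiveDescriptor methods), which replace the hand-written per-operator caching code (AcquireSoftmax, AcquireActivation, and their primitive-descriptor counterparts) with a single create-or-reuse path through the device-context blob map. Below is a standalone, compilable sketch of that pattern. BlobCache and HandlerT are toy stand-ins written only for this illustration, not Paddle or MKL-DNN types, and unlike the real helper the toy primitive_desc is built without an mkldnn::engine.

// Standalone sketch of the create-or-reuse pattern the patch introduces.
// Toy types only; not Paddle/MKL-DNN code.
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Toy stand-in for the device context's blob map (dev_ctx_.GetBlob/SetBlob).
class BlobCache {
 public:
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }
  void SetBlob(const std::string& key, std::shared_ptr<void> value) {
    blobs_[key] = std::move(value);
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

// TForward mirrors an MKL-DNN forward primitive: it exposes nested
// desc/primitive_desc types and is constructed from a primitive_desc.
template <typename TForward>
class HandlerT {
 public:
  HandlerT(BlobCache* cache, std::string key)
      : cache_(cache), key_(std::move(key)) {}

  // One generic Acquire for every operator: build the primitive only on a
  // cache miss, otherwise hand back the cached instance.
  template <typename... Args>
  std::shared_ptr<TForward> AcquireForwardPrimitive(Args&&... args) {
    const std::string key_p = key_ + "@forward_p";
    auto forward_p =
        std::static_pointer_cast<TForward>(cache_->GetBlob(key_p));
    if (forward_p == nullptr) {
      forward_p =
          std::make_shared<TForward>(*fwd_pd_, std::forward<Args>(args)...);
      cache_->SetBlob(key_p, forward_p);
    }
    return forward_p;
  }

 protected:
  // Double-checked locking: the forward primitive_desc is shared with the
  // grad op (possibly another thread), so it is created at most once and
  // cached under a key that does not contain a thread id.
  template <typename... Args>
  void AcquireForwardPrimitiveDescriptor(Args&&... args) {
    const std::string key_pd = key_ + "@forward_pd";
    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
        cache_->GetBlob(key_pd));
    if (fwd_pd_ == nullptr) {
      static std::mutex acquire_barrier;
      std::lock_guard<std::mutex> guard(acquire_barrier);
      fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
          cache_->GetBlob(key_pd));
      if (fwd_pd_ == nullptr) {
        auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
        // The real helper also passes the mkldnn::engine here.
        fwd_pd_ =
            std::make_shared<typename TForward::primitive_desc>(fwd_desc);
        cache_->SetBlob(key_pd, fwd_pd_);
      }
    }
  }

  std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
  BlobCache* cache_;
  std::string key_;
};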
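
Building on the sketch above, the following shows the call pattern the operators move to after this patch: the derived handler's constructor acquires (or reuses) the cached forward primitive_desc, and the op kernel then acquires the cached primitive itself. FakeSoftmaxFwd and ToySoftmaxHandler are invented for this example; the real SoftmaxMKLDNNHandler builds an mkldnn::memory::desc from dims and format and calls AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md, 1 /*dim: C*/) instead.

// Usage sketch; depends on BlobCache/HandlerT from the previous snippet.
#include <cassert>
#include <iostream>
#include <vector>

// Toy "forward primitive" satisfying the TForward interface HandlerT expects.
struct FakeSoftmaxFwd {
  struct desc {
    explicit desc(int axis) : axis(axis) {}
    int axis;
  };
  struct primitive_desc {
    explicit primitive_desc(const desc& d) : axis(d.axis) {}
    int axis;
  };
  FakeSoftmaxFwd(const primitive_desc& pd, const std::vector<float>& src,
                 std::vector<float>& dst)
      : axis(pd.axis), src(&src), dst(&dst) {}
  int axis;
  const std::vector<float>* src;
  std::vector<float>* dst;
};

class ToySoftmaxHandler : public HandlerT<FakeSoftmaxFwd> {
 public:
  ToySoftmaxHandler(BlobCache* cache, const std::string& key, int axis)
      : HandlerT<FakeSoftmaxFwd>(cache, key) {
    // Mirrors the new constructors: one call builds the forward descriptor
    // and caches (or reuses) the primitive_desc.
    this->AcquireForwardPrimitiveDescriptor(axis);
  }
};

int main() {
  BlobCache cache;
  std::vector<float> src(8, 1.0f), dst(8, 0.0f);

  ToySoftmaxHandler handler(&cache, "softmax-2x4", /*axis=*/1);
  auto p1 = handler.AcquireForwardPrimitive(src, dst);

  // A second handler built with the same key reuses both cached entries.
  ToySoftmaxHandler handler2(&cache, "softmax-2x4", /*axis=*/1);
  auto p2 = handler2.AcquireForwardPrimitive(src, dst);
  assert(p1 == p2);
  std::cout << "cached primitive reused, axis = " << p1->axis << "\n";
  return 0;
}

The second handler hitting both cache entries mirrors how the grad op picks up the forward primitive_desc cached by the forward op: the pd is keyed without a thread id (key_common_), hence the double-checked lock, while the primitives themselves are keyed per handler instance.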