From 9e4c958552e840619a27111cf613caba86daf362 Mon Sep 17 00:00:00 2001
From: Jacek Czaja
Date: Thu, 12 Sep 2019 05:02:23 +0200
Subject: [PATCH] Refactoring activation mkldnn op (#19748)

test=develop

- fix to BWD

test=develop
---
 .../operators/mkldnn/activation_mkldnn_op.cc |  61 +----
 paddle/fluid/platform/mkldnn_reuse.h         | 210 ++++++++++++------
 2 files changed, 155 insertions(+), 116 deletions(-)

diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
index dec64ba08e..264ca81def 100644
--- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -83,13 +83,10 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                  "It must use CPUPlace.");
   auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-  const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
   auto *y = ctx.Output<Tensor>("Out");
 
-  const T *x_data = x->data<T>();
-
   const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
   const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
 
@@ -103,23 +100,12 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
 
   bool is_test = ctx.Attr<bool>("is_test");
 
-  std::string key = platform::ActivationMKLDNNHandler::GetHash(
-      src_tz, algorithm, src_format, alpha, beta, ctx.op().Input("X"));
-
-  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-
-  auto md = platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
-                                    src_format);
-
-  auto activation_pd = handler.AcquireActivationPrimitiveDescriptor(
-      is_test ? mkldnn::prop_kind::forward_inference
-              : mkldnn::prop_kind::forward_training,
-      algorithm, md, alpha, beta);
+  platform::ActivationMKLDNNHandler<T> handler(
+      src_tz, algorithm, alpha, beta, src_format, is_test, dev_ctx,
+      ctx.GetPlace(), ctx.op().Input("X"));
 
-  auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast<T>(x_data));
-
-  auto dst_memory_p =
-      handler.AcquireDstMemoryFromPrimitive<T>(y, ctx.GetPlace());
+  auto src_memory_p = handler.AcquireSrcMemory(x);
+  auto dst_memory_p = handler.AcquireDstMemory(y);
 
   auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);
 
   // push primitive to stream and wait until it's executed
@@ -135,17 +121,11 @@ template <typename T>
 void eltwise_grad(const framework::ExecutionContext &ctx,
                   mkldnn::algorithm algorithm) {
   auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-  const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
-  const T *x_data = x->data<T>();
-
   const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
   auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-  const T *diff_y_data = diff_y->data<T>();
-  T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
-
   const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
   const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
 
@@ -158,32 +138,13 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   auto diff_y_format = diff_dst_tz.size() == 2 ?
       MKLDNNMemoryFormat::nc : diff_y->format();
 
-  auto diff_dst_md = platform::MKLDNNMemDesc(
-      diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
-
-  std::string key = platform::ActivationMKLDNNHandler::GetHash(
-      diff_dst_tz, algorithm, src_format, alpha, beta, ctx.op().Input("X"));
-
-  const std::string key_src_data = key + "@eltwise_fwd_src_data";
-
-  auto src_md = platform::MKLDNNMemDesc(
-      diff_dst_tz, platform::MKLDNNGetDataType<T>(), src_format);
-
-  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-
-  auto src_memory_p = handler.AcquireSrcMemory(src_md, to_void_cast<T>(x_data));
-
-  auto diff_dst_memory_p =
-      handler.AcquireDiffDstMemory(diff_dst_md, to_void_cast<T>(diff_y_data));
-
-  auto activation_backward_pd =
-      handler.AcquireActivationBackwardPrimitiveDescriptor(
-          algorithm, diff_dst_md, src_memory_p->get_primitive_desc().desc(),
-          alpha, beta);
-
-  auto diff_src_memory_p =
-      handler.AcquireDiffSrcMemoryFromPrimitive(diff_x_data);
+  platform::ActivationMKLDNNHandler<T> handler(
+      diff_dst_tz, algorithm, alpha, beta, src_format, diff_y_format, dev_ctx,
+      ctx.GetPlace(), ctx.op().Input("X"));
 
+  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
+  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
+  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
   auto activation_backward_p = handler.AcquireActivationBackward(
       diff_src_memory_p, diff_dst_memory_p, src_memory_p);
 
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 7da73adf13..d474b234a9 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -257,65 +257,94 @@ class SumMKLDNNHandler : public MKLDNNHandler {
   std::shared_ptr<mkldnn::sum::primitive_desc> sum_pd_;
 };
 
+template <typename T>
 class ActivationMKLDNNHandler : public MKLDNNHandler {
  public:
-  ActivationMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
-                          mkldnn::engine engine, const std::string& base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
+  ActivationMKLDNNHandler(const std::vector<int>& dims,
+                          mkldnn::algorithm algorithm, float alpha, float beta,
+                          const MKLDNNMemoryFormat fmt, bool is_test,
+                          const platform::MKLDNNDeviceContext& dev_ctx,
+                          platform::Place cpu_place,
+                          const std::string& unique_name)
+
+      : platform::MKLDNNHandler(
+            dev_ctx, dev_ctx.GetEngine(),
+            platform::ActivationMKLDNNHandler<T>::GetHash(
+                dims, algorithm, fmt, alpha, beta, unique_name)),
+        place_(cpu_place),
+        fwd_pd_(nullptr),
+        bwd_pd_(nullptr) {
+    AcquireActivationPrimitiveDescriptor(
+        is_test ? mkldnn::prop_kind::forward_inference
+                : mkldnn::prop_kind::forward_training,
+        algorithm, dims, fmt, alpha, beta);
+  }
+
+  ActivationMKLDNNHandler(const std::vector<int>& dims,
+                          mkldnn::algorithm algorithm, float alpha, float beta,
+                          const MKLDNNMemoryFormat fmt,
+                          const MKLDNNMemoryFormat diff_fmt,
+                          const platform::MKLDNNDeviceContext& dev_ctx,
+                          platform::Place cpu_place,
+                          const std::string& unique_name)
+
+      : platform::MKLDNNHandler(
+            dev_ctx, dev_ctx.GetEngine(),
+            platform::ActivationMKLDNNHandler<T>::GetHash(
+                dims, algorithm, fmt, alpha, beta, unique_name)),
+        place_(cpu_place),
+        fwd_pd_(nullptr),
+        bwd_pd_(nullptr) {
+    AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
+                                         algorithm, dims, fmt, alpha, beta);
+    AcquireActivationBackwardPrimitiveDescriptor(algorithm, dims, fmt, diff_fmt,
+                                                 alpha, beta);
+  }
+
+  // TODO(jczaja): Once fwd_pd_ and bwd_pd_ are moved to MKLDNNHandler, this
+  // function should be moved as well, e.g. ActivationMKLDNNHandler ->
+  // MKLDNNHandler
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
+      const framework::Tensor* input) {
+    const T* input_data = input->data<T>();
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
+                                            to_void_cast<T>(input_data),
+                                            "@src_mem_p");
+  }
 
-  std::shared_ptr<mkldnn::eltwise_forward::primitive_desc>
-  AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
-                                       mkldnn::algorithm algorithm,
-                                       const mkldnn::memory::desc& md,
-                                       float alpha, float beta) {
-    // Activation PD has to be passed to Grad op that
-    // may be executed by different thread, hence
-    // for that one we use key that does not contain TID
-    const std::string key_activation_pd = key_common_ + "@activation_pd";
-    fwd_pd_ = std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-        dev_ctx_.GetBlob(key_activation_pd));
-    if (fwd_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
+  std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
+      const framework::Tensor* input) {
+    const T* input_data = input->data<T>();
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->src_primitive_desc(),
+                                            to_void_cast<T>(input_data),
+                                            "@bwd-src_mem_p");
+  }
 
-      fwd_pd_ =
-          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-              dev_ctx_.GetBlob(key_activation_pd));
-      if (fwd_pd_ == nullptr) {
-        auto activation_desc = mkldnn::eltwise_forward::desc(
-            prop_kind, algorithm, md, alpha, beta);
+  // TODO(jczaja): Move to MKLDNNHandler as common code
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
+    T* ptr = output->mutable_data<T>(place_,
+                                     fwd_pd_->dst_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
+                                            "@dst_mem_p");
+  }
 
-        fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
-            activation_desc, engine_));
-        dev_ctx_.SetBlob(key_activation_pd, fwd_pd_);
-      }
-    }
-    return fwd_pd_;
+  // TODO(jczaja): Move to MKLDNNHandler as common code
+  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
+      const framework::Tensor* diffdst) {
+    const T* ptr = diffdst->data<T>();
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
+                                            to_void_cast<T>(ptr),
+                                            "@diff_dst_mem_p");
   }
 
-  std::shared_ptr<mkldnn::eltwise_backward::primitive_desc>
-  AcquireActivationBackwardPrimitiveDescriptor(
-      mkldnn::algorithm algorithm, const mkldnn::memory::desc& diff_dst_md,
-      const mkldnn::memory::desc& src_md, float alpha, float beta) {
-    const std::string key_activation_pd = key_common_ + "@activation_pd";
-    const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
-    bwd_pd_ =
-        std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
-            dev_ctx_.GetBlob(key_activation_bwd_pd));
-    if (bwd_pd_ == nullptr) {
-      fwd_pd_ =
-          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-              dev_ctx_.GetBlob(key_activation_pd));
-      // PD from FWD op has to exist.
-      PADDLE_ENFORCE_NOT_NULL(fwd_pd_, "Eltwise MKL-DNN not found in cache!");
-      auto backward_desc = mkldnn::eltwise_backward::desc(
-          algorithm, diff_dst_md, src_md, alpha, beta);
-      bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
-          backward_desc, engine_, *fwd_pd_));
-      dev_ctx_.SetBlob(key_activation_bwd_pd, bwd_pd_);
-    }
-    return bwd_pd_;
+  // TODO(jczaja): Move to MKLDNNHandler as common code
+  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
+      framework::Tensor* diffsrc) {
+    T* ptr = diffsrc->mutable_data<T>(
+        place_, bwd_pd_->diff_src_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
+                                            ptr, "@diff_src_mem_p");
   }
 
   std::shared_ptr<mkldnn::eltwise_forward> AcquireActivation(
@@ -335,20 +364,6 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
     return eltwise_p;
   }
 
-  template <typename T>
-  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(
-      framework::Tensor* output, platform::Place place) {
-    T* ptr = output->mutable_data<T>(place,
-                                     fwd_pd_->dst_primitive_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
-                                            "@dst_mem_p");
-  }
-
-  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromPrimitive(void* ptr) {
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
-                                            ptr, "@diff_src_mem_p");
-  }
-
   std::shared_ptr<mkldnn::eltwise_backward> AcquireActivationBackward(
       std::shared_ptr<mkldnn::memory> diff_src_memory_p,
       std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
@@ -383,7 +398,70 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
     return key;
   }
 
+ protected:
+  void AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
+                                            mkldnn::algorithm algorithm,
+                                            const std::vector<int>& dims,
+                                            const MKLDNNMemoryFormat fmt,
+                                            float alpha, float beta) {
+    // Activation PD has to be passed to Grad op that
+    // may be executed by different thread, hence
+    // for that one we use key that does not contain TID
+    const std::string key_activation_pd = key_common_ + "@activation_pd";
+    fwd_pd_ = std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+        dev_ctx_.GetBlob(key_activation_pd));
+    if (fwd_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+
+      fwd_pd_ =
+          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+              dev_ctx_.GetBlob(key_activation_pd));
+      if (fwd_pd_ == nullptr) {
+        auto md = platform::MKLDNNMemDesc(
+            dims, platform::MKLDNNGetDataType<T>(), fmt);
+        auto activation_desc = mkldnn::eltwise_forward::desc(
+            prop_kind, algorithm, md, alpha, beta);
+
+        fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
+            activation_desc, engine_));
+        dev_ctx_.SetBlob(key_activation_pd, fwd_pd_);
+      }
+    }
+  }
+
+  void AcquireActivationBackwardPrimitiveDescriptor(
+      mkldnn::algorithm algorithm, const std::vector<int>& dims,
+      const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_fmt,
+      float alpha, float beta) {
+    const std::string key_activation_pd = key_common_ + "@activation_pd";
+    const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
+    bwd_pd_ =
+        std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
+            dev_ctx_.GetBlob(key_activation_bwd_pd));
+    if (bwd_pd_ == nullptr) {
+      fwd_pd_ =
+          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+              dev_ctx_.GetBlob(key_activation_pd));
+      // PD from FWD op has to exist.
+      PADDLE_ENFORCE_NOT_NULL(fwd_pd_, "Eltwise MKL-DNN not found in cache!");
+
+      auto diff_dst_md = platform::MKLDNNMemDesc(
+          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
+      auto src_md =
+          platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
+
+      auto backward_desc = mkldnn::eltwise_backward::desc(
+          algorithm, diff_dst_md, src_md, alpha, beta);
+      bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
+          backward_desc, engine_, *fwd_pd_));
+      dev_ctx_.SetBlob(key_activation_bwd_pd, bwd_pd_);
+    }
+  }
+
  private:
+  platform::Place place_;
   std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd_;
   std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd_;
 };
-- 
GitLab
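
Usage sketch (not part of the commit above): after this refactoring a kernel no
longer assembles memory descriptors, cache keys, or primitive descriptors by
hand; the handler constructor hashes the key from dims, algorithm, format, and
unique_name, creates the forward (and, in the backward constructor, backward)
primitive descriptors or fetches them from the device-context cache, and the
Acquire* helpers return cached mkldnn memories. A minimal forward-path sketch,
assuming float data and relu as the concrete algorithm; src_tz, src_format,
alpha, beta, is_test, dev_ctx, ctx and the tensors x/y are the locals of
eltwise_forward shown above, and the stream calls use the MKL-DNN 0.x API this
code targets:

  // Key is derived internally; no manual GetHash/MKLDNNMemDesc calls needed.
  platform::ActivationMKLDNNHandler<float> handler(
      src_tz, mkldnn::algorithm::eltwise_relu, alpha, beta, src_format,
      is_test, dev_ctx, ctx.GetPlace(), ctx.op().Input("X"));

  // Cached memories: src wraps x's buffer, dst allocates y via mutable_data.
  auto src_memory_p = handler.AcquireSrcMemory(x);
  auto dst_memory_p = handler.AcquireDstMemory(y);
  auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);

  // Push the primitive to a stream and wait until it is executed.
  std::vector<mkldnn::primitive> pipeline = {*activation_p};
  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();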