diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
index 5b7505f3c4acdef94fead04efd00b47825274117..1767ebaf8c39d4eca40b03d8bdd4f6778f088de4 100644
--- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -13,7 +13,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/activation_op.h"
-#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 
 namespace paddle {
 namespace operators {
@@ -99,20 +99,21 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   auto src_format =
       src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
 
-  const std::string key = gethash(src_tz, algorithm);
-  const std::string key_src_data =
-      key + ctx.op().Output("Out") + "@eltwise_fwd_src_data";
-  const std::string key_src_layout =
-      key + ctx.op().Output("Out") + "@eltwise_fwd_src_layout";
-  const std::string key_with_layout = key + std::to_string(src_format);
-  const std::string key_src_mem = key_with_layout + "@eltwise_fwd_src_mem";
-  const std::string key_dst_mem = key_with_layout + "@eltwise_fwd_dst_mem";
-  const std::string key_fwd = key_with_layout + "@eltwise_fwd";
-  const std::string key_fwd_pd = key_with_layout + "@eltwise_fwd_pd";
-
   bool is_test = ctx.Attr<bool>("is_test");
 
+  // TODO(jczaja): When adding leaky-relu, swish and elu, make sure to extend
+  // the key with alpha and beta
+  std::string key = platform::MKLDNNHandler::GetHash(
+      src_tz, std::to_string(algorithm) + ctx.op().Output("Out"));
+
+  // TODO(jczaja): Make it thread safe
   // save input data and layout to be referred in backward path
+  const std::string key_src_data = key + "@eltwise_fwd_src_data";
+  const std::string key_src_layout = key + "@eltwise_fwd_src_layout";
+  // Just in case some int8 models are run interchangeably with float
+  // models, the format may differ, so append it to the key
+  key += std::to_string(src_format);
+  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
+
   auto p_src_data = std::make_shared<const T *>(x_data);
   auto p_src_layout = std::make_shared<memory::format>(src_format);
   if (!is_test) {
@@ -120,65 +121,34 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
     dev_ctx.SetBlob(key_src_layout, p_src_layout);
   }
 
-  auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
-      dev_ctx.GetBlob(key_fwd));
-
-  std::shared_ptr<memory> dst_memory;
-
-  if (p_fwd == nullptr) {
-    // create mkldnn memory for input X
-    auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), src_format);
-    auto src_memory = std::shared_ptr<memory>(
-        new memory({src_md, mkldnn_engine}, to_void_cast<T>(x_data)));
-    // save src_memory to be referred in backward path
-    dev_ctx.SetBlob(key_src_mem, src_memory);
-
-    // create primitive descriptor for activation forward and save it
-    auto mkldnn_forward_prop_kind = is_test
-                                        ? mkldnn::prop_kind::forward_inference
-                                        : mkldnn::prop_kind::forward_training;
-    auto forward_desc = mkldnn::eltwise_forward::desc(
-        mkldnn_forward_prop_kind, algorithm,
-        src_memory->get_primitive_desc().desc(), alpha, beta);
-    auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
-        forward_desc, mkldnn_engine);
-
-    // save prim desc into global device context to be referred in backward path
-    if (!is_test) dev_ctx.SetBlob(key_fwd_pd, forward_pd);
-
-    // create mkldnn memory for output y
-    dst_memory =
-        std::make_shared<memory>(forward_pd->dst_primitive_desc(), y_data);
-
-    dev_ctx.SetBlob(key_dst_mem, dst_memory);
-
-    // create activation primitive
-    p_fwd = std::make_shared<mkldnn::eltwise_forward>(*forward_pd, *src_memory,
-                                                      *dst_memory);
-    dev_ctx.SetBlob(key_fwd, p_fwd);
-  } else {
-    // primitives already exist
-    auto src_memory =
-        std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_src_mem));
-    PADDLE_ENFORCE(src_memory != nullptr,
-                   "Fail to find eltwise src_memory in device context.");
-    dst_memory =
-        std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_dst_mem));
-    PADDLE_ENFORCE(dst_memory != nullptr,
-                   "Fail to find eltwise dst_memory in device context.");
-
-    src_memory->set_data_handle(platform::to_void_cast<T>(x_data));
-    dst_memory->set_data_handle(y_data);
+  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
+
+  auto md = platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
+                                    src_format);
+
+  auto activation_pd = handler.AcquireActivationPrimitiveDescriptor(
+      is_test ? mkldnn::prop_kind::forward_inference
              : mkldnn::prop_kind::forward_training,
+      algorithm, md, alpha, beta);
+
+  auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast<T>(x_data));
+  // jczaja: Workaround, src_memory_p is needed in BWD so it has
+  // to be accessible under a key that does not depend on the TID
+  if (!is_test) {
+    dev_ctx.SetBlob(key_src_mem, src_memory_p);
   }
+
+  auto dst_memory_p =
+      handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(y_data));
+  auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);
 
   // push primitive to stream and wait until it's executed
   std::vector<primitive> pipeline;
-  pipeline.push_back(*p_fwd);
+  pipeline.push_back(*activation_p);
   stream(stream::kind::eager).submit(pipeline).wait();
 
   y->set_layout(DataLayout::kMKLDNN);
-  y->set_format(GetMKLDNNFormat(*dst_memory));
+  y->set_format(GetMKLDNNFormat(*dst_memory_p));
 }
 
 template <typename T>
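Note on the new key scheme in eltwise_forward: the hand-rolled gethash plus six per-blob key strings are replaced by one platform::MKLDNNHandler::GetHash base key from which every blob name is derived. A minimal sketch of the idea follows, with a hypothetical MakeKey helper standing in for the real GetHash (name and string layout are illustrative only). It also shows why the src-data and src-layout keys are built before the memory format is appended: the backward op must be able to find those two blobs without knowing which format the forward pass picked.

#include <string>
#include <vector>

// Hypothetical stand-in for platform::MKLDNNHandler::GetHash: the base key
// combines the input shape with an op-specific suffix (algorithm id plus
// output variable name).
std::string MakeKey(const std::vector<int>& dims, const std::string& suffix) {
  std::string key;
  for (int d : dims) key += std::to_string(d) + "-";
  return key + suffix;
}

// Usage mirroring the hunk above:
//   std::string key = MakeKey(src_tz, std::to_string(algorithm) + out_name);
//   const std::string key_src_data = key + "@eltwise_fwd_src_data";  // format-free
//   key += std::to_string(src_format);  // int8 vs. float runs stay distinct
//   const std::string key_src_mem = key + "@eltwise_fwd_src_mem";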
@@ -199,90 +169,51 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   auto diff_y_format =
       diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
 
-  const std::string key = gethash(diff_dst_tz, algorithm);
-  const std::string key_src_data =
-      key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
-  const std::string key_src_layout =
-      key + ctx.op().Input("Out") + "@eltwise_fwd_src_layout";
+  auto diff_dst_md = platform::MKLDNNMemDesc(
+      diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
+
+  std::string key = platform::MKLDNNHandler::GetHash(
+      diff_dst_tz, std::to_string(algorithm) + ctx.op().Input("Out"));
+
+  const std::string key_src_data = key + "@eltwise_fwd_src_data";
+  const std::string key_src_layout = key + "@eltwise_fwd_src_layout";
+
+  // Get data from the FWD op
   const auto p_src_layout = std::static_pointer_cast<memory::format>(
       dev_ctx.GetBlob(key_src_layout));
-  const std::string key_src_mem =
-      key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
-  const std::string key_fwd_pd =
-      key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
-  const std::string key_with_layouts =
-      key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
-  const std::string key_diff_src_mem =
-      key_with_layouts + "@eltwise_diff_src_mem";
-  const std::string key_diff_dst_mem =
-      key_with_layouts + "@eltwise_diff_dst_mem";
-  const std::string key_grad = key_with_layouts + "@eltwise_grad";
-
   const auto p_src_data =
       std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
-
+  key += std::to_string(*p_src_layout);
+  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
   auto src_memory =
       std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
   PADDLE_ENFORCE(src_memory != nullptr,
                  "Fail to find src_memory in device context");
   src_memory->set_data_handle(*p_src_data);
 
-  std::shared_ptr<memory> diff_src_memory;
-
-  auto p_grad = std::static_pointer_cast<mkldnn::eltwise_backward>(
-      dev_ctx.GetBlob(key_grad));
-
-  if (p_grad == nullptr) {
-    // create mkldnn memory for input diff_y
-    auto diff_dst_md = platform::MKLDNNMemDesc(
-        diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
-    auto diff_dst_memory = std::shared_ptr<memory>(
-        new memory({diff_dst_md, mkldnn_engine}, to_void_cast<T>(diff_y_data)));
-    dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
-
-    // retrieve eltwise primitive desc from device context
-    auto forward_pd =
-        std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-            dev_ctx.GetBlob(key_fwd_pd));
-    PADDLE_ENFORCE(forward_pd != nullptr,
-                   "Fail to find eltwise_fwd_pd in device context");
-
-    // create primitive descriptor for activation backward
-    auto backward_desc = mkldnn::eltwise_backward::desc(
-        algorithm, diff_dst_memory->get_primitive_desc().desc(),
-        src_memory->get_primitive_desc().desc(), alpha, beta);
-    auto backward_pd = mkldnn::eltwise_backward::primitive_desc(
-        backward_desc, mkldnn_engine, *forward_pd);
-
-    // create mkldnn memory for output diff_src
-    diff_src_memory = std::make_shared<memory>(
-        backward_pd.diff_src_primitive_desc(), diff_x_data);
-    dev_ctx.SetBlob(key_diff_src_mem, diff_src_memory);
-
-    // create activation backward primitive
-    p_grad = std::make_shared<mkldnn::eltwise_backward>(
-        backward_pd, *src_memory, *diff_dst_memory, *diff_src_memory);
-    dev_ctx.SetBlob(key_grad, p_grad);
-  } else {
-    // primitives already exist
-    diff_src_memory = std::static_pointer_cast<memory>(
-        dev_ctx.GetBlob(key_diff_src_mem));
-    auto diff_dst_memory = std::static_pointer_cast<memory>(
-        dev_ctx.GetBlob(key_diff_dst_mem));
-
-    diff_src_memory->set_data_handle(
-        platform::to_void_reinterpret_cast(diff_x_data));
-    diff_dst_memory->set_data_handle(
-        platform::to_void_reinterpret_cast(diff_y_data));
-  }
+  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
+
+  auto diff_dst_memory_p =
+      handler.AcquireDiffDstMemory(diff_dst_md, to_void_cast<T>(diff_y_data));
+
+  auto activation_backward_pd =
+      handler.AcquireActivationBackwardPrimitiveDescriptor(
+          algorithm, diff_dst_md, src_memory->get_primitive_desc().desc(),
+          alpha, beta);
+
+  auto diff_src_memory_p =
+      handler.AcquireDiffSrcMemoryFromPrimitive(diff_x_data);
+
+  auto activation_backward_p = handler.AcquireActivationBackward(
+      diff_src_memory_p, diff_dst_memory_p, src_memory);
 
   // push primitive to stream and wait until it's executed
   std::vector<primitive> pipeline;
-  pipeline.push_back(*p_grad);
+  pipeline.push_back(*activation_backward_p);
   stream(stream::kind::eager).submit(pipeline).wait();
 
   diff_x->set_layout(DataLayout::kMKLDNN);
-  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
+  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
 }
 
 template <typename T>
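The backward kernel rebuilds the same base key on its own: eltwise_forward hashes ctx.op().Output("Out") while eltwise_grad hashes ctx.op().Input("Out"), and for a given op instance both name the same variable, so the blobs saved in the forward pass are found again here. A self-contained sketch of that symmetry (the tensor name and algorithm id below are made up for illustration):

#include <cassert>
#include <string>
#include <vector>

// Illustrative key builder: shape, algorithm id, then the "Out" variable
// name, echoing the GetHash calls in both kernels above.
std::string MakeKey(const std::vector<int>& dims, int algorithm,
                    const std::string& out_name) {
  std::string key;
  for (int d : dims) key += std::to_string(d) + "-";
  return key + std::to_string(algorithm) + out_name;
}

int main() {
  std::vector<int> tz{8, 16, 32, 32};
  // FWD: key built from ctx.op().Output("Out").
  std::string fwd_key = MakeKey(tz, /*algorithm=*/1, "relu_0.tmp_0");
  // BWD: the same variable is now ctx.op().Input("Out"); src_tz and
  // diff_dst_tz agree for an element-wise op, so the keys coincide.
  std::string bwd_key = MakeKey(tz, /*algorithm=*/1, "relu_0.tmp_0");
  assert(fwd_key == bwd_key);  // blobs stored by FWD are visible to BWD
  return 0;
}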
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 4697f8c916177bb6bc1bb9ccea32cd73269aeb5d..d478d66fc5617bed9d67d53b436fa8c1456537bb 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -309,6 +309,121 @@ class SumMKLDNNHandler : public MKLDNNHandler {
   std::shared_ptr<mkldnn::sum::primitive_desc> sum_pd_;
 };
 
+class ActivationMKLDNNHandler : public MKLDNNHandler {
+ public:
+  ActivationMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
+                          mkldnn::engine engine, const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
+
+  std::shared_ptr<mkldnn::eltwise_forward::primitive_desc>
+  AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
+                                       mkldnn::algorithm algorithm,
+                                       const mkldnn::memory::desc& md,
+                                       float alpha, float beta) {
+    // The activation PD has to be passed to the Grad op, which
+    // may be executed by a different thread, hence for that one
+    // we use a key that does not contain the TID
+    const std::string key_activation_pd = key_common_ + "@activation_pd";
+    activation_pd_ =
+        std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+            dev_ctx_.GetBlob(key_activation_pd));
+    if (activation_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+
+      activation_pd_ =
+          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+              dev_ctx_.GetBlob(key_activation_pd));
+      if (activation_pd_ == nullptr) {
+        auto activation_desc = mkldnn::eltwise_forward::desc(
+            prop_kind, algorithm, md, alpha, beta);
+
+        activation_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
+            activation_desc, engine_));
+        dev_ctx_.SetBlob(key_activation_pd, activation_pd_);
+      }
+    }
+    return activation_pd_;
+  }
+
+  std::shared_ptr<mkldnn::eltwise_backward::primitive_desc>
+  AcquireActivationBackwardPrimitiveDescriptor(
+      mkldnn::algorithm algorithm, const mkldnn::memory::desc& diff_dst_md,
+      const mkldnn::memory::desc& src_md, float alpha, float beta) {
+    const std::string key_activation_pd = key_common_ + "@activation_pd";
+    const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
+    activation_bwd_pd_ =
+        std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
+            dev_ctx_.GetBlob(key_activation_bwd_pd));
+    if (activation_bwd_pd_ == nullptr) {
+      activation_pd_ =
+          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+              dev_ctx_.GetBlob(key_activation_pd));
+      // The PD from the FWD op has to exist
+      PADDLE_ENFORCE(activation_pd_ != nullptr,
+                     "Eltwise MKL-DNN not found in cache!");
+      auto backward_desc = mkldnn::eltwise_backward::desc(
+          algorithm, diff_dst_md, src_md, alpha, beta);
+      activation_bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
+          backward_desc, engine_, *activation_pd_));
+      dev_ctx_.SetBlob(key_activation_bwd_pd, activation_bwd_pd_);
+    }
+    return activation_bwd_pd_;
+  }
+
+  std::shared_ptr<mkldnn::eltwise_forward> AcquireActivation(
+      std::shared_ptr<mkldnn::memory> dst_memory_p,
+      std::shared_ptr<mkldnn::memory> src_memory_p) {
+    /* Generate key */
+    auto prim_key = key_ + "@eltwise_p";
+
+    auto eltwise_p = std::static_pointer_cast<mkldnn::eltwise_forward>(
+        dev_ctx_.GetBlob(prim_key));
+    if (eltwise_p == nullptr) {
+      eltwise_p = std::make_shared<mkldnn::eltwise_forward>(
+          *activation_pd_, *(src_memory_p), *(dst_memory_p));
+      dev_ctx_.SetBlob(prim_key, eltwise_p);
+    }
+
+    return eltwise_p;
+  }
+
+  // TODO(jczaja): Merge all AcquireDstMemoryFromPrimitive into one
+  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
+    return this->AcquireMemoryFromPrimitive(
+        activation_pd_->dst_primitive_desc(), ptr, "@dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromPrimitive(
+      void* ptr) {
+    return this->AcquireMemoryFromPrimitive(
+        activation_bwd_pd_->diff_src_primitive_desc(), ptr, "@diff_src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::eltwise_backward> AcquireActivationBackward(
+      std::shared_ptr<mkldnn::memory> diff_src_memory_p,
+      std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
+      std::shared_ptr<mkldnn::memory> src_memory_p) {
+    /* Generate key */
+    auto prim_key = key_ + "@eltwise_bwd_p";
+
+    auto eltwise_bwd_p = std::static_pointer_cast<mkldnn::eltwise_backward>(
+        dev_ctx_.GetBlob(prim_key));
+    if (eltwise_bwd_p == nullptr) {
+      eltwise_bwd_p = std::make_shared<mkldnn::eltwise_backward>(
+          *activation_bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
+          *(diff_src_memory_p));
+      dev_ctx_.SetBlob(prim_key, eltwise_bwd_p);
+    }
+
+    return eltwise_bwd_p;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> activation_pd_;
+  std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> activation_bwd_pd_;
+};
+
 class TransposeMKLDNNHandler : public MKLDNNHandler {
  public:
   TransposeMKLDNNHandler(std::vector<int>& dims,  // NOLINT
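AcquireActivationPrimitiveDescriptor above publishes the forward primitive descriptor under key_common_ (a key without a thread id) and wraps creation in a check, lock, re-check sequence so concurrent threads build it only once; the other Acquire* methods follow the same create-or-fetch shape without the lock. Below is a generic sketch of that pattern, with a plain map standing in for MKLDNNDeviceContext's blob store (AcquireOrCreate and BlobMap are assumed names, not Paddle APIs):

#include <map>
#include <memory>
#include <mutex>
#include <string>

using BlobMap = std::map<std::string, std::shared_ptr<void>>;

// Create-or-reuse with double-checked locking, mirroring
// AcquireActivationPrimitiveDescriptor: only a cache miss takes the lock,
// and the second lookup under the lock keeps two threads from creating and
// publishing the object twice. NB: the real blob store must make the first,
// unlocked lookup thread safe itself; a bare std::map does not.
template <typename T, typename Factory>
std::shared_ptr<T> AcquireOrCreate(BlobMap* cache, const std::string& key,
                                   Factory make) {
  auto obj = std::static_pointer_cast<T>((*cache)[key]);
  if (obj == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> guard(acquire_barrier);
    obj = std::static_pointer_cast<T>((*cache)[key]);
    if (obj == nullptr) {
      obj = make();           // e.g. build the primitive descriptor
      (*cache)[key] = obj;    // publish for later Acquire calls
    }
  }
  return obj;
}

The backward descriptor can skip the lock because it is cached under key_, which (per the TID comments in the diff) already encodes the calling thread, so no two threads ever race on the same backward entry.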