From 5cf2d385948acbdf8dee91068313d33a36a8abe7 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Thu, 1 Aug 2019 05:51:15 +0200 Subject: [PATCH] - Removed passing X from FWD to GRAD via device context (#18911) test=develop - Extracted key generation from FWD and GRAD into separate function test=develop - Compilation fix test=develop - another compilation test=develop --- .../operators/mkldnn/activation_mkldnn_op.cc | 60 ++++++------------- paddle/fluid/platform/mkldnn_reuse.h | 16 +++++ 2 files changed, 34 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc index e9ffe8ecfd8..35334186704 100644 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc @@ -103,24 +103,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx, bool is_test = ctx.Attr("is_test"); - std::string key = platform::MKLDNNHandler::GetHash( - src_tz, std::to_string(algorithm) + std::to_string(alpha) + - std::to_string(beta) + ctx.op().Input("X")); - - // TODO(jczaja): Make it Thread safe - // save input data and layout to be referred in backward path - const std::string key_src_data = key + "@eltwise_fwd_src_data"; - const std::string key_src_layout = key + "@eltwise_fwd_src_layout"; - // Just in case some int8 models are run interchangebly - // with float models then format maybe diffrent - key += std::to_string(src_format); - const std::string key_src_mem = key + "@eltwise_fwd_src_mem"; - auto p_src_data = std::make_shared(x_data); - auto p_src_layout = std::make_shared(src_format); - if (!is_test) { - dev_ctx.SetBlob(key_src_data, p_src_data); - dev_ctx.SetBlob(key_src_layout, p_src_layout); - } + std::string key = platform::ActivationMKLDNNHandler::GetHash( + src_tz, algorithm, src_format, alpha, beta, ctx.op().Input("X")); platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); @@ -133,11 +117,6 @@ void eltwise_forward(const framework::ExecutionContext &ctx, algorithm, md, alpha, beta); auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast(x_data)); - // jczaja: Workaround, src_memory_p is needed in BWD so it has - // to be accessible under key not dependant on TID - if (!is_test) { - dev_ctx.SetBlob(key_src_mem, src_memory_p); - } auto dst_memory_p = handler.AcquireDstMemoryFromPrimitive(to_void_cast(y_data)); @@ -158,6 +137,9 @@ void eltwise_grad(const framework::ExecutionContext &ctx, auto &dev_ctx = ctx.template device_context(); const auto &mkldnn_engine = dev_ctx.GetEngine(); + const auto *x = ctx.Input("X"); + const T *x_data = x->data(); + const auto *diff_y = ctx.Input(framework::GradVarName("Out")); auto *diff_x = ctx.Output(framework::GradVarName("X")); @@ -169,47 +151,41 @@ void eltwise_grad(const framework::ExecutionContext &ctx, std::vector diff_dst_tz = framework::vectorize2int(diff_y->dims()); + // diff_dst and src dims should be the same + auto src_format = + diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : x->format(); + auto diff_y_format = diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format(); auto diff_dst_md = platform::MKLDNNMemDesc( diff_dst_tz, platform::MKLDNNGetDataType(), diff_y_format); - std::string key = platform::MKLDNNHandler::GetHash( - diff_dst_tz, std::to_string(algorithm) + std::to_string(alpha) + - std::to_string(beta) + ctx.op().Input("X")); + std::string key = platform::ActivationMKLDNNHandler::GetHash( + diff_dst_tz, algorithm, src_format, alpha, beta, ctx.op().Input("X")); const std::string key_src_data = key + "@eltwise_fwd_src_data"; - const std::string key_src_layout = key + "@eltwise_fwd_src_layout"; - - // Get Data from FWD op - const auto p_src_layout = - std::static_pointer_cast(dev_ctx.GetBlob(key_src_layout)); - const auto p_src_data = - std::static_pointer_cast(dev_ctx.GetBlob(key_src_data)); - key += std::to_string(*p_src_layout); - const std::string key_src_mem = key + "@eltwise_fwd_src_mem"; - auto src_memory = - std::static_pointer_cast(dev_ctx.GetBlob(key_src_mem)); - PADDLE_ENFORCE(src_memory != nullptr, - "Fail to find src_memory in device context"); - src_memory->set_data_handle(*p_src_data); + + auto src_md = platform::MKLDNNMemDesc( + diff_dst_tz, platform::MKLDNNGetDataType(), src_format); platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); + auto src_memory_p = handler.AcquireSrcMemory(src_md, to_void_cast(x_data)); + auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_dst_md, to_void_cast(diff_y_data)); auto activation_backward_pd = handler.AcquireActivationBackwardPrimitiveDescriptor( - algorithm, diff_dst_md, src_memory->get_primitive_desc().desc(), + algorithm, diff_dst_md, src_memory_p->get_primitive_desc().desc(), alpha, beta); auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromPrimitive(diff_x_data); auto activation_backward_p = handler.AcquireActivationBackward( - diff_src_memory_p, diff_dst_memory_p, src_memory); + diff_src_memory_p, diff_dst_memory_p, src_memory_p); // push primitive to stream and wait until it's executed std::vector pipeline; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 998509ea050..935c4f734f4 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -433,6 +433,22 @@ class ActivationMKLDNNHandler : public MKLDNNHandler { return eltwise_bwd_p; } + static std::string GetHash(const memory::dims& input_dims, + const mkldnn::algorithm algorithm, + const mkldnn::memory::format fmt, + const float alpha, const float beta, + const std::string& suffix) { + std::string key; + key.reserve(platform::MKLDNNHandler::MaxKeyLength); + platform::MKLDNNHandler::AppendKeyDims(&key, input_dims); + platform::MKLDNNHandler::AppendKey(&key, std::to_string(algorithm)); + platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); + platform::MKLDNNHandler::AppendKey(&key, std::to_string(alpha)); + platform::MKLDNNHandler::AppendKey(&key, std::to_string(beta)); + platform::MKLDNNHandler::AppendKey(&key, suffix); + return key; + } + private: std::shared_ptr activation_pd_; std::shared_ptr activation_bwd_pd_; -- GitLab