diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc index ea0abf930e7f548b93afca937c27fa8d25a35e94..52554800a30f2c8b666781706a9ad1f6d251b093 100644 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc @@ -29,55 +29,6 @@ using mkldnn::reorder; using mkldnn::stream; using platform::to_void_cast; -// Generate keys for storing/retriving primitives for this operator -std::string CreateKey(const paddle::framework::ExecutionContext& ctx, - const memory::dims& input_dims, - const std::string& pooling_type, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, - const memory::data_type& dt, const memory::format& fmt, - const std::string& suffix) { - std::string key; - key.reserve(platform::MKLDNNHandler::MaxKeyLength); - platform::MKLDNNHandler::AppendKeyDims(&key, input_dims); - platform::MKLDNNHandler::AppendKey(&key, pooling_type); - platform::MKLDNNHandler::AppendKeyVec(&key, ksize); - platform::MKLDNNHandler::AppendKeyVec(&key, strides); - platform::MKLDNNHandler::AppendKeyVec(&key, paddings); - platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); - platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); - platform::MKLDNNHandler::AppendKey(&key, suffix); - if (platform::get_cur_mkldnn_session_id() == - platform::kMKLDNNSessionID_Default) { - auto tid = std::this_thread::get_id(); - std::stringstream ss; - ss << tid; - platform::MKLDNNHandler::AppendKey(&key, "-t:"); - platform::MKLDNNHandler::AppendKey(&key, ss.str()); - } - return key; -} - -static inline int ComputeCeiledOutput(int input_size, int kernel_size, - int padding, int stride) { - return (input_size - kernel_size + 2 * padding) / stride + 1; -} - -static inline void CorrectOutputSize( - const std::vector& src_tz, const std::vector& dst_tz, - const std::vector& kernel_size, const std::vector& paddings, - const std::vector& strides, - std::vector& right_bot_padding) { // NOLINT - for (size_t i = 0; i < right_bot_padding.size(); i++) { - int desired_size = ComputeCeiledOutput(src_tz[i + 2], kernel_size[i], - paddings[i], strides[i]); - if (desired_size != dst_tz[i + 2]) { - right_bot_padding[i] += strides[i]; - } - } -} - template class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { public: @@ -99,7 +50,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector ksize = ctx.Attr>("ksize"); std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); - bool is_test = ctx.Attr("is_test"); + if (ctx.Attr("global_pooling")) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; @@ -126,139 +77,46 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn::memory::data_type dt = paddle::framework::ToMKLDNNDataType(input->type()); auto fmt = input->format(); - const std::string key = - CreateKey(ctx, src_tz, pooling_type, ksize, strides, paddings, dt, fmt, - ctx.op().Output("Out")); - const std::string key_pool_p = key + "@pool_p"; - const std::string key_pool_pd = key + "@pool_pd"; - const std::string key_pool_src_mem_p = key + "@pool_src_mem_p"; - const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p"; - const std::string key_pool_workspace_memory = - key + "@pool_workspace_memory"; - - std::shared_ptr src_memory, dst_memory; - std::shared_ptr pool_pd; - std::shared_ptr pool_src_memory_p, pool_dst_memory_p; - - auto pool_p = - std::static_pointer_cast(dev_ctx.GetBlob(key_pool_p)); - if (pool_p == nullptr) { - const std::vector& padding_left_top(paddings); - std::vector padding_right_bottom(paddings); - bool ceil_mode = ctx.Attr("ceil_mode"); - if (ceil_mode) { - CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, - padding_right_bottom); - } - auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format); - - /* create memory descriptor for pooling without specified format - * ('any') which lets a primitive (pooling in this case) choose - * the memory format preferred for best performance - */ - auto dst_md = - platform::MKLDNNMemDesc(dst_tz, dt, mkldnn::memory::format::any); - auto propagation = src_md.data.data_type == mkldnn_f32 - ? mkldnn::prop_kind::forward_training - : mkldnn::prop_kind::forward_scoring; - std::shared_ptr pool_pd = - CreatePrimitiveDesc(src_md, dst_md, propagation, strides, - padding_left_top, padding_right_bottom, ksize, - pooling_type, mkldnn_engine, ceil_mode, is_test); - - // save pool_pd into global device context to be referred in backward path - if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd); - - src_memory = std::make_shared(pool_pd->src_primitive_desc(), - to_void_cast(input_data)); - dst_memory = - std::make_shared(pool_pd->dst_primitive_desc(), output_data); - - dev_ctx.SetBlob(key_pool_src_mem_p, src_memory); - dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory); - - if (is_test) { - pool_p = std::make_shared(*pool_pd, *src_memory, - *dst_memory); - } else { - std::shared_ptr workspace_memory = - CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine); - - // save pool_workspace_memory to be referred in backward path - dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory); - - pool_p = std::make_shared( - *pool_pd, *src_memory, *dst_memory, *workspace_memory); - } + const std::string key = platform::PoolingMKLDNNHandler::GetHash( + src_tz, pooling_type, ksize, strides, paddings, dt, fmt, + ctx.op().Output("Out")); - dev_ctx.SetBlob(key_pool_p, pool_p); - - output_format = - (memory::format)dst_memory->get_primitive_desc().desc().data.format; - } else { - // Primitives already exist - pool_src_memory_p = - std::static_pointer_cast(dev_ctx.GetBlob(key_pool_src_mem_p)); - PADDLE_ENFORCE(pool_src_memory_p != nullptr, - "Fail to find pooling src mem_p in device context"); - pool_dst_memory_p = - std::static_pointer_cast(dev_ctx.GetBlob(key_pool_dst_mem_p)); - PADDLE_ENFORCE(pool_dst_memory_p != nullptr, - "Fail to find pooling dst mem_p in device context"); - pool_src_memory_p->set_data_handle(to_void_cast(input_data)); - pool_dst_memory_p->set_data_handle(output_data); - - output_format = (memory::format)pool_dst_memory_p->get_primitive_desc() - .desc() - .data.format; - } + platform::PoolingMKLDNNHandler handler(pooling_type, dt, + ctx.Attr("is_test"), dev_ctx, + mkldnn_engine, key); + + auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format); + + auto src_memory = + handler.AcquireSrcMemory(src_md, to_void_cast(input_data)); + + /* create memory descriptor for pooling without specified format + * ('any') which lets a primitive (pooling in this case) choose + * the memory format preferred for best performance + */ + auto dst_md = + platform::MKLDNNMemDesc(dst_tz, dt, mkldnn::memory::format::any); + + auto pooling_pd = handler.AcquirePoolingPrimitiveDescriptor( + src_tz, dst_tz, src_md, dst_md, ksize, strides, paddings, + ctx.Attr("ceil_mode")); + + auto dst_memory = + handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); + + auto pool_p = handler.AcquirePooling(dst_memory, src_memory); // push primitive to stream and wait until it's executed std::vector pipeline{*pool_p}; stream(stream::kind::eager).submit(pipeline).wait(); + output_format = + (memory::format)dst_memory->get_primitive_desc().desc().data.format; + output->set_layout(DataLayout::kMKLDNN); output->set_format(output_format); } - - private: - std::unique_ptr CreatePrimitiveDesc( - const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst, - const mkldnn::prop_kind& propagation, const std::vector& stride, - const std::vector& padding_left_top, - const std::vector& padding_right_bot, const std::vector& kernel, - const std::string& pooling_type, const mkldnn::engine& engine, - bool ceil_mode, bool is_test) const { - auto mkldnn_forward_prop_kind = is_test - ? mkldnn::prop_kind::forward_inference - : mkldnn::prop_kind::forward_training; - auto pool_desc = mkldnn::pooling_forward::desc( - mkldnn_forward_prop_kind, - pooling_type == "max" ? mkldnn::algorithm::pooling_max - : mkldnn::algorithm::pooling_avg, - src, dst, stride, kernel, padding_left_top, padding_right_bot, - mkldnn::padding_kind::zero); - - auto p_pool_pd = - new mkldnn::pooling_forward::primitive_desc(pool_desc, engine); - return std::unique_ptr(p_pool_pd); - } - - std::unique_ptr CreateWorkspaceMemory( - std::shared_ptr pool_pd, - const std::string& pooling_type, const mkldnn::engine& engine) const { - mkldnn::memory::primitive_desc workspace_md = - pooling_type == "max" - ? pool_pd->workspace_primitive_desc() - : mkldnn::memory::primitive_desc({{}, - platform::MKLDNNGetDataType(), - mkldnn::memory::format::nchw}, - engine); - - auto p_workspace_memory = new mkldnn::memory(workspace_md); - return std::unique_ptr(p_workspace_memory); - } }; template @@ -299,6 +157,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { ctx.template device_context(); const mkldnn::engine& mkldnn_engine = dev_ctx.GetEngine(); + std::vector pipeline; + const T* out_grad_data = out_grad->data(); T* in_x_grad_data = in_x_grad->mutable_data(ctx.GetPlace()); memory::format in_x_grad_format{memory::format::format_undef}; @@ -310,119 +170,41 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { // Get an unique name from "argument" name of "Out" variable // This name will be used as key when referring info from device context - const std::string key = CreateKey(ctx, diff_src_tz, pooling_type, ksize, - strides, paddings, memory::data_type::f32, - in_x->format(), ctx.op().Input("Out")); - const std::string key_pool_bwd_p = key + "@pool_bwd_p"; - const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p"; - const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p"; - const std::string key_pool_src_mem_p = key + "@pool_src_mem_p"; - const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p"; - const std::string key_pool_pd = key + "@pool_pd"; - const std::string key_pool_workspace_memory = - key + "@pool_workspace_memory"; - - auto user_diff_dst_memory = - memory({{{diff_dst_tz}, memory::data_type::f32, out_grad->format()}, - mkldnn_engine}, - to_void_cast(out_grad_data)); - - std::shared_ptr diff_src_memory; - std::shared_ptr diff_dst_memory; - auto dst_memory = - std::static_pointer_cast(dev_ctx.GetBlob(key_pool_dst_mem_p)); - PADDLE_ENFORCE(dst_memory != nullptr, - "Fail to find dst_memory in device context"); - - primitive reorder_diff_dst; - bool is_diff_dst_reordered = false; - auto pool_bwd_p = std::static_pointer_cast( - dev_ctx.GetBlob(key_pool_bwd_p)); - if (pool_bwd_p == nullptr) { - // Retrieve src_memory/dst_memory saved in forward pass - auto src_memory = - std::static_pointer_cast(dev_ctx.GetBlob(key_pool_src_mem_p)); - PADDLE_ENFORCE(src_memory != nullptr, - "Fail to find src_memory in device context"); - // Retrieve pool_pd/pool_workspace_memory from device context - auto pool_pd = - std::static_pointer_cast( - dev_ctx.GetBlob(key_pool_pd)); - PADDLE_ENFORCE(pool_pd != nullptr, - "Fail to find pool_pd in device context"); - auto workspace_memory = std::static_pointer_cast( - dev_ctx.GetBlob(key_pool_workspace_memory)); - PADDLE_ENFORCE(workspace_memory != nullptr, - "Fail to find workspace_memory in device context"); - - // create memory descriptors for pooling - auto diff_src_md = src_memory.get()->get_primitive_desc().desc(); - auto diff_dst_md = dst_memory.get()->get_primitive_desc().desc(); - - auto pool_bwd_desc = mkldnn::pooling_backward::desc( - pooling_type == "max" ? mkldnn::algorithm::pooling_max - : mkldnn::algorithm::pooling_avg, - diff_src_md, diff_dst_md, strides, ksize, paddings, paddings, - mkldnn::padding_kind::zero); - auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc( - pool_bwd_desc, mkldnn_engine, *pool_pd); - - // reorder between user_diff_dst and pool diff_dst if needed - diff_dst_memory = std::make_shared(user_diff_dst_memory); - if (memory::primitive_desc(dst_memory->get_primitive_desc()) != - user_diff_dst_memory.get_primitive_desc()) { - diff_dst_memory = - std::make_shared(dst_memory.get()->get_primitive_desc()); - reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory); - is_diff_dst_reordered = true; - } + const std::string key = platform::PoolingMKLDNNHandler::GetHash( + diff_src_tz, pooling_type, ksize, strides, paddings, + memory::data_type::f32, in_x->format(), ctx.op().Input("Out")); - diff_src_memory = std::make_shared( - pool_bwd_pd.diff_src_primitive_desc(), in_x_grad_data); - - dev_ctx.SetBlob(key_pool_diff_src_mem_p, diff_src_memory); - dev_ctx.SetBlob(key_pool_diff_dst_mem_p, diff_dst_memory); - - pool_bwd_p = std::make_shared( - pool_bwd_pd, *diff_dst_memory, *workspace_memory, *diff_src_memory); - dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p); - - } else { - // Primitives already exist - diff_src_memory = std::static_pointer_cast( - dev_ctx.GetBlob(key_pool_diff_src_mem_p)); - PADDLE_ENFORCE(diff_src_memory != nullptr, - "Fail to find pooling src mem_p in device context"); - diff_dst_memory = std::static_pointer_cast( - dev_ctx.GetBlob(key_pool_diff_dst_mem_p)); - PADDLE_ENFORCE(diff_dst_memory != nullptr, - "Fail to find pooling dst mem_p in device context"); - - diff_src_memory->set_data_handle(reinterpret_cast(in_x_grad_data)); - diff_dst_memory->set_data_handle(const_cast(out_grad_data)); - - // reorder between user_diff_dst and pool diff_dst if needed - if (memory::primitive_desc(dst_memory->get_primitive_desc()) != - user_diff_dst_memory.get_primitive_desc()) { - diff_dst_memory = - std::make_shared(dst_memory.get()->get_primitive_desc()); - reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory); - is_diff_dst_reordered = true; - } - } + platform::PoolingMKLDNNHandler handler( + pooling_type, paddle::framework::ToMKLDNNDataType(in_x_grad->type()), + false, dev_ctx, mkldnn_engine, key); - in_x_grad_format = (memory::format)diff_src_memory->get_primitive_desc() - .desc() - .data.format; + auto workspace = handler.AcquireWorkspaceMemory(); + + auto diff_dst_md = platform::MKLDNNMemDesc( + {diff_dst_tz}, platform::MKLDNNGetDataType(), out_grad->format()); + + auto diff_dst_memory = handler.AcquireDiffDstMemory( + diff_dst_md, to_void_cast(out_grad_data)); + + auto diff_src_md = + platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType(), + mkldnn::memory::format::any); + + auto bwd_pd = handler.AcquirePoolingBackwardPrimitiveDescriptor( + diff_dst_md, diff_src_md, ksize, strides, paddings); + + auto diff_src_memory = handler.AcquireDiffSrcMemoryFromPrimitive( + reinterpret_cast(in_x_grad_data)); + + auto pool_bwd_p = handler.AcquirePoolingBackward(diff_dst_memory, workspace, + diff_src_memory); - // push primitive to stream and wait until it's executed - std::vector pipeline; - if (is_diff_dst_reordered) { - pipeline.push_back(reorder_diff_dst); - } pipeline.push_back(*pool_bwd_p); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + in_x_grad_format = (memory::format)diff_src_memory->get_primitive_desc() + .desc() + .data.format; in_x_grad->set_layout(DataLayout::kMKLDNN); in_x_grad->set_format(in_x_grad_format); } // Compute() diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index d478d66fc5617bed9d67d53b436fa8c1456537bb..70f8d9dbd8d623aac53b37d0dd5dc980cee8bfb1 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -122,6 +122,18 @@ class MKLDNNHandler { return mem_p; } + std::shared_ptr AcquireMemory( + const mkldnn::memory::primitive_desc& mpd, const std::string& suffix) { + auto local_key = key_ + suffix; + auto mem_p = + std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); + if (mem_p == nullptr) { + mem_p = std::make_shared(mpd); + dev_ctx_.SetBlob(local_key, mem_p); + } + return mem_p; + } + std::shared_ptr AcquireMemory( const std::shared_ptr& user_memory_p, const std::shared_ptr& target_memory_p, @@ -424,6 +436,223 @@ class ActivationMKLDNNHandler : public MKLDNNHandler { std::shared_ptr activation_bwd_pd_; }; +class PoolingMKLDNNHandler : public MKLDNNHandler { + public: + PoolingMKLDNNHandler(const std::string& pooling_type, + mkldnn::memory::data_type dt, bool is_test, + const platform::MKLDNNDeviceContext& dev_ctx, + mkldnn::engine engine, const std::string& base_key) + : platform::MKLDNNHandler(dev_ctx, engine, base_key), + dt_(dt), + pooling_type_(pooling_type), + is_test_(is_test) {} + + std::shared_ptr + AcquirePoolingPrimitiveDescriptor( + const std::vector& src_tz, const std::vector& dst_tz, + const mkldnn::memory::desc& src_md, const mkldnn::memory::desc& dst_md, + const std::vector& ksize, const std::vector& strides, + const std::vector& paddings, bool ceil_mode) { + // Pooling PD has to be passed to Grad op that + // may be executed by diffrent thread, hence + // for that one we use key that does not contain TID + const std::string key_pooling_pd = key_common_ + "@pooling_pd"; + fwd_pd_ = std::static_pointer_cast( + dev_ctx_.GetBlob(key_pooling_pd)); + if (fwd_pd_ == nullptr) { + static std::mutex acquire_barrier; + std::lock_guard block_threads_until_finish_this_job( + acquire_barrier); + fwd_pd_ = + std::static_pointer_cast( + dev_ctx_.GetBlob(key_pooling_pd)); + if (fwd_pd_ == nullptr) { + std::vector padding_left_top(paddings); + std::vector padding_right_bottom(paddings); + if (ceil_mode) { + CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, + padding_right_bottom); + } + auto mkldnn_forward_prop_kind = + is_test_ ? mkldnn::prop_kind::forward_inference + : mkldnn::prop_kind::forward_training; + auto pooling_desc = mkldnn::pooling_forward::desc( + mkldnn_forward_prop_kind, + pooling_type_ == "max" ? mkldnn::algorithm::pooling_max + : mkldnn::algorithm::pooling_avg, + src_md, dst_md, strides, ksize, padding_left_top, + padding_right_bottom, mkldnn::padding_kind::zero); + + fwd_pd_.reset( + new mkldnn::pooling_forward::primitive_desc(pooling_desc, engine_)); + dev_ctx_.SetBlob(key_pooling_pd, fwd_pd_); + } + } + return fwd_pd_; + } + + std::shared_ptr AcquireDstMemoryFromPrimitive(void* ptr) { + return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr, + "@dst_mem_p"); + } + + std::shared_ptr AcquireWorkspaceMemory(void) { + mkldnn::memory::primitive_desc workspace_mpd = + pooling_type_ == "max" + ? fwd_pd_->workspace_primitive_desc() + : mkldnn::memory::primitive_desc( + {{}, dt_, mkldnn::memory::format::nchw}, engine_); + // Pooling PD has to be passed to Grad op that + // may be executed by diffrent thread, hence + // for that one we use key that does not contain TID + auto local_key = key_common_ + "@workspace"; + auto mem_p = + std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); + if (mem_p == nullptr) { + static std::mutex acquire_barrier; + std::lock_guard block_threads_until_finish_this_job( + acquire_barrier); + mem_p = + std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); + if (mem_p == nullptr) { + mem_p = std::make_shared(workspace_mpd); + dev_ctx_.SetBlob(local_key, mem_p); + } + } + return mem_p; + } + + std::shared_ptr AcquirePooling( + std::shared_ptr dst_memory, + std::shared_ptr src_memory) { + auto prim_key = key_ + "@pooling_p"; + + auto pooling_p = std::static_pointer_cast( + dev_ctx_.GetBlob(prim_key)); + if (pooling_p == nullptr) { + if (is_test_) { + pooling_p = std::make_shared( + *fwd_pd_, *(src_memory), *(dst_memory)); + } else { + // For training we need to create workspace + // to store indices from backward + auto workspace_memory = this->AcquireWorkspaceMemory(); + + pooling_p = std::make_shared( + *fwd_pd_, *src_memory, *dst_memory, *workspace_memory); + } + dev_ctx_.SetBlob(prim_key, pooling_p); + } + return pooling_p; + } + + std::shared_ptr + AcquirePoolingBackwardPrimitiveDescriptor( + const mkldnn::memory::desc& diff_dst_md, + const mkldnn::memory::desc& diff_src_md, const std::vector& ksize, + const std::vector& strides, const std::vector& paddings) { + const std::string key_pooling_pd = key_common_ + "@pooling_pd"; + const std::string key_pooling_bwd_pd = key_ + "@pooling_bwd_pd"; + bwd_pd_ = + std::static_pointer_cast( + dev_ctx_.GetBlob(key_pooling_bwd_pd)); + if (bwd_pd_ == nullptr) { + fwd_pd_ = + std::static_pointer_cast( + dev_ctx_.GetBlob(key_pooling_pd)); + // PD from FWD op has to exist. + PADDLE_ENFORCE(fwd_pd_ != nullptr, "Pooling MKL-DNN not found in cache!"); + + auto backward_desc = mkldnn::pooling_backward::desc( + pooling_type_ == "max" ? mkldnn::algorithm::pooling_max + : mkldnn::algorithm::pooling_avg, + diff_src_md, diff_dst_md, strides, ksize, paddings, paddings, + mkldnn::padding_kind::zero); + bwd_pd_.reset(new mkldnn::pooling_backward::primitive_desc( + backward_desc, engine_, *fwd_pd_)); + + dev_ctx_.SetBlob(key_pooling_bwd_pd, bwd_pd_); + } + return bwd_pd_; + } + + std::shared_ptr AcquireDiffDstMemoryFromDataPrimitive( + const std::shared_ptr user_memory_p, + std::vector& pipeline) { // NOLINT + auto diff_dst_pd = bwd_pd_->diff_dst_primitive_desc(); + auto user_pd = user_memory_p->get_primitive_desc(); + return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p, + "@diff_dst_mem_p", pipeline); + } + + std::shared_ptr AcquireDiffSrcMemoryFromPrimitive(void* ptr) { + return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(), + ptr, "@diff_src_mem_p"); + } + + std::shared_ptr AcquirePoolingBackward( + std::shared_ptr diff_dst_memory, + std::shared_ptr workspace, + std::shared_ptr diff_src_memory) { + auto prim_key = key_ + "@pooling_bwd_p"; + + auto pooling_bwd_p = std::static_pointer_cast( + dev_ctx_.GetBlob(prim_key)); + if (pooling_bwd_p == nullptr) { + pooling_bwd_p = std::make_shared( + *bwd_pd_, *diff_dst_memory, *workspace, *diff_src_memory); + dev_ctx_.SetBlob(prim_key, pooling_bwd_p); + } + + return pooling_bwd_p; + } + + static std::string GetHash( + const memory::dims& input_dims, const std::string& pooling_type, + const std::vector& ksize, const std::vector& strides, + const std::vector& paddings, const memory::data_type& dt, + const memory::format& fmt, const std::string& suffix) { + std::string key; + key.reserve(platform::MKLDNNHandler::MaxKeyLength); + platform::MKLDNNHandler::AppendKeyDims(&key, input_dims); + platform::MKLDNNHandler::AppendKey(&key, pooling_type); + platform::MKLDNNHandler::AppendKeyVec(&key, ksize); + platform::MKLDNNHandler::AppendKeyVec(&key, strides); + platform::MKLDNNHandler::AppendKeyVec(&key, paddings); + platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); + platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); + platform::MKLDNNHandler::AppendKey(&key, suffix); + return key; + } + + private: + static inline int ComputeCeiledOutput(int input_size, int kernel_size, + int padding, int stride) { + return (input_size - kernel_size + 2 * padding) / stride + 1; + } + + static inline void CorrectOutputSize( + const std::vector& src_tz, const std::vector& dst_tz, + const std::vector& kernel_size, const std::vector& paddings, + const std::vector& strides, + std::vector& right_bot_padding) { // NOLINT + for (size_t i = 0; i < right_bot_padding.size(); i++) { + int desired_size = ComputeCeiledOutput(src_tz[i + 2], kernel_size[i], + paddings[i], strides[i]); + if (desired_size != dst_tz[i + 2]) { + right_bot_padding[i] += strides[i]; + } + } + } + + private: + mkldnn::memory::data_type dt_; + std::string pooling_type_; + bool is_test_; + std::shared_ptr fwd_pd_; + std::shared_ptr bwd_pd_; +}; + class TransposeMKLDNNHandler : public MKLDNNHandler { public: TransposeMKLDNNHandler(std::vector& dims, // NOLINT