diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc index 4862b029d6e33b3bfefc524827cb2851ab2fae06..83e9cfd90a8515c8fd15842c114e6b2c59f45d18 100644 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc @@ -37,7 +37,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { "It must use CPUPlace."); auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); const Tensor* input = ctx.Input("X"); Tensor* output = ctx.Output("Out"); @@ -66,52 +65,37 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { PADDLE_ENFORCE(input->dims().size() == 4, "Input dim must be with 4, i.e. NCHW"); - const T* input_data = input->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); - auto src_tz = paddle::framework::vectorize(input->dims()); auto dst_tz = paddle::framework::vectorize(output->dims()); - auto input_format = input->format(); - MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::format_undef}; - - mkldnn::memory::data_type dt = - paddle::framework::ToMKLDNNDataType(input->type()); - auto fmt = input->format(); - - const std::string key = - platform::CreateKey(src_tz, pooling_type, ksize, strides, paddings, dt, - fmt, ctx.op().Output("Out")); - - platform::PoolingMKLDNNHandler handler(pooling_type, dt, - ctx.Attr("is_test"), dev_ctx, - mkldnn_engine, key); - - auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format); - - auto src_memory = - handler.AcquireSrcMemory(src_md, to_void_cast(input_data)); - - /* create memory descriptor for pooling without specified format - * ('any') which lets a primitive (pooling in this case) choose - * the memory format preferred for best performance - */ - auto dst_md = platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any); - - auto pooling_pd = handler.AcquirePoolingPrimitiveDescriptor( - src_tz, dst_tz, src_md, dst_md, ksize, strides, paddings, - ctx.Attr("ceil_mode")); - - auto dst_memory = - handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); - - auto pool_p = handler.AcquirePooling(dst_memory, src_memory); + auto is_test = ctx.Attr("is_test"); + + platform::PoolingMKLDNNHandler handler( + src_tz, dst_tz, ksize, strides, paddings, pooling_type, + ctx.Attr("ceil_mode"), input->format(), + paddle::framework::ToMKLDNNDataType(input->type()), is_test, dev_ctx, + ctx.GetPlace(), ctx.op().Output("Out")); + + auto src_memory = handler.AcquireSrcMemory(input); + auto dst_memory = handler.AcquireDstMemory(output); + + std::shared_ptr pool_p; + std::shared_ptr workspace_memory; + if ((is_test == false) && (pooling_type == "max")) { + // Training + workspace_memory = handler.AcquireWorkspaceMemory(); + pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory, + *workspace_memory); + } else { + // Inference + pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory); + } // push primitive to stream and wait until it's executed std::vector pipeline{*pool_p}; stream(stream::kind::eager).submit(pipeline).wait(); - output_format = + auto output_format = (MKLDNNMemoryFormat)dst_memory->get_primitive_desc().desc().data.format; output->set_layout(DataLayout::kMKLDNN); @@ -158,14 +142,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto& dev_ctx = ctx.template device_context(); - const mkldnn::engine& mkldnn_engine = dev_ctx.GetEngine(); std::vector pipeline; - const T* out_grad_data = out_grad->data(); - T* in_x_grad_data = in_x_grad->mutable_data(ctx.GetPlace()); - MKLDNNMemoryFormat in_x_grad_format{MKLDNNMemoryFormat::format_undef}; - auto diff_src_tz = paddle::framework::vectorize(in_x_grad->dims()); auto diff_dst_tz = paddle::framework::vectorize(out_grad->dims()); @@ -175,36 +154,35 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { diff_src_tz, pooling_type, ksize, strides, paddings, memory::data_type::f32, in_x->format(), ctx.op().Input("Out")); - platform::PoolingMKLDNNHandler handler( - pooling_type, paddle::framework::ToMKLDNNDataType(in_x_grad->type()), - false, dev_ctx, mkldnn_engine, key); - - auto workspace = handler.AcquireWorkspaceMemory(); - - auto diff_dst_md = platform::MKLDNNMemDesc( - {diff_dst_tz}, platform::MKLDNNGetDataType(), out_grad->format()); - - auto diff_dst_memory = handler.AcquireDiffDstMemory( - diff_dst_md, to_void_cast(out_grad_data)); - - auto diff_src_md = platform::MKLDNNMemDesc( - diff_src_tz, platform::MKLDNNGetDataType(), MKLDNNMemoryFormat::any); - - auto bwd_pd = handler.AcquirePoolingBackwardPrimitiveDescriptor( - diff_dst_md, diff_src_md, ksize, strides, paddings); - - auto diff_src_memory = handler.AcquireDiffSrcMemoryFromPrimitive( - reinterpret_cast(in_x_grad_data)); - - auto pool_bwd_p = handler.AcquirePoolingBackward(diff_dst_memory, workspace, - diff_src_memory); + platform::PoolingMKLDNNHandler handler( + diff_dst_tz, diff_src_tz, ksize, strides, paddings, pooling_type, + ctx.Attr("ceil_mode"), in_x->format(), out_grad->format(), + paddle::framework::ToMKLDNNDataType(out_grad->type()), dev_ctx, + ctx.GetPlace(), ctx.op().Input("Out")); + + auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad); + auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad); + + std::shared_ptr pool_bwd_p; + std::shared_ptr workspace_memory; + if (pooling_type == "max") { + // Max - pooling needs Workspace + workspace_memory = handler.AcquireWorkspaceMemory(); + pool_bwd_p = handler.AcquireBackwardPrimitive( + *diff_dst_memory, *workspace_memory, *diff_src_memory); + } else { + // Average Pooling + pool_bwd_p = + handler.AcquireBackwardPrimitive(*diff_dst_memory, *diff_src_memory); + } pipeline.push_back(*pool_bwd_p); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); - in_x_grad_format = (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc() - .desc() - .data.format; + auto in_x_grad_format = + (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc() + .desc() + .data.format; in_x_grad->set_layout(DataLayout::kMKLDNN); in_x_grad->set_format(in_x_grad_format); } // Compute() diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index 2a73458bfeb0e2808cc1a910318679e6a6cde231..690f9271fb7cb17032ef56d4904855a0ec115e6a 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -66,8 +66,6 @@ class SoftmaxMKLDNNHandler auto diff_softmax_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), diff_fmt); - this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, - data_softmax_md, axis); this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md, axis); } diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 53697f587e5d04b97b66eda56ba20683ca652342..7396b90ea3d0728a6c63069e8cb3089cc3c47f98 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -140,6 +140,9 @@ class MKLDNNHandlerT { template void AcquireBackwardPrimitiveDescriptor(Args&&... args) { + const std::string key_fwd_pd = key_common_ + "@forward_pd"; + fwd_pd_ = std::static_pointer_cast( + dev_ctx_.GetBlob(key_fwd_pd)); PADDLE_ENFORCE_NOT_NULL(fwd_pd_); const std::string key_pd = key_ + "@backward_pd"; bwd_pd_ = std::static_pointer_cast( @@ -445,8 +448,6 @@ class ActivationMKLDNNHandler auto src_md = platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType(), fmt); - this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, - algorithm, src_md, alpha, beta); this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md, alpha, beta); } @@ -496,9 +497,6 @@ class LRNMKLDNNHandler auto diff_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), diff_fmt); - this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, - mkldnn::lrn_across_channels, src_md, - n, alpha, beta, k); this->AcquireBackwardPrimitiveDescriptor( mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k); } @@ -520,177 +518,97 @@ class LRNMKLDNNHandler } }; -class PoolingMKLDNNHandler : public MKLDNNHandler { +template +class PoolingMKLDNNHandler : public MKLDNNHandlerT { public: - PoolingMKLDNNHandler(const std::string& pooling_type, - mkldnn::memory::data_type dt, bool is_test, - const platform::MKLDNNDeviceContext& dev_ctx, - mkldnn::engine engine, const std::string& base_key) - : platform::MKLDNNHandler(dev_ctx, engine, base_key), - dt_(dt), - pooling_type_(pooling_type), - is_test_(is_test) {} - - std::shared_ptr - AcquirePoolingPrimitiveDescriptor( - const std::vector& src_tz, const std::vector& dst_tz, - const mkldnn::memory::desc& src_md, const mkldnn::memory::desc& dst_md, + PoolingMKLDNNHandler( + const std::vector& src_dims, const std::vector& dst_dims, const std::vector& ksize, const std::vector& strides, - const std::vector& paddings, bool ceil_mode) { - // Pooling PD has to be passed to Grad op that - // may be executed by diffrent thread, hence - // for that one we use key that does not contain TID - const std::string key_pooling_pd = key_common_ + "@pooling_pd"; - fwd_pd_ = std::static_pointer_cast( - dev_ctx_.GetBlob(key_pooling_pd)); - if (fwd_pd_ == nullptr) { - static std::mutex acquire_barrier; - std::lock_guard block_threads_until_finish_this_job( - acquire_barrier); - fwd_pd_ = - std::static_pointer_cast( - dev_ctx_.GetBlob(key_pooling_pd)); - if (fwd_pd_ == nullptr) { - std::vector padding_left_top(paddings); - std::vector padding_right_bottom(paddings); - if (ceil_mode) { - CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, - padding_right_bottom); - } - auto mkldnn_forward_prop_kind = - is_test_ ? mkldnn::prop_kind::forward_inference - : mkldnn::prop_kind::forward_training; - auto pooling_desc = mkldnn::pooling_forward::desc( - mkldnn_forward_prop_kind, - pooling_type_ == "max" ? mkldnn::algorithm::pooling_max - : mkldnn::algorithm::pooling_avg, - src_md, dst_md, strides, ksize, padding_left_top, - padding_right_bottom, mkldnn::padding_kind::zero); - - fwd_pd_.reset( - new mkldnn::pooling_forward::primitive_desc(pooling_desc, engine_)); - dev_ctx_.SetBlob(key_pooling_pd, fwd_pd_); - } + const std::vector& paddings, const std::string& pooling_type, + bool ceil_mode, const MKLDNNMemoryFormat fmt, + mkldnn::memory::data_type dt, bool is_test, + const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place, + const std::string& unique_name) + : platform::MKLDNNHandlerT( + dev_ctx, dev_ctx.GetEngine(), cpu_place, + platform::CreateKey(src_dims, pooling_type, ksize, strides, + paddings, dt, fmt, unique_name)) { + auto src_md = mkldnn::memory::desc(src_dims, dt, fmt); + /* create memory descriptor for pooling without specified format + * ('any') which lets a primitive (pooling in this case) choose + * the memory format preferred for best performance + */ + auto dst_md = + platform::MKLDNNMemDesc(dst_dims, dt, MKLDNNMemoryFormat::any); + + std::vector padding_left_top(paddings); + std::vector padding_right_bottom(paddings); + if (ceil_mode) { + CorrectOutputSize(src_dims, dst_dims, ksize, paddings, strides, + padding_right_bottom); } - return fwd_pd_; - } - std::shared_ptr AcquireDstMemoryFromPrimitive(void* ptr) { - return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr, - "@dst_mem_p"); + this->AcquireForwardPrimitiveDescriptor( + is_test ? mkldnn::prop_kind::forward_inference + : mkldnn::prop_kind::forward_training, + pooling_type == "max" ? mkldnn::algorithm::pooling_max + : mkldnn::algorithm::pooling_avg, + src_md, dst_md, strides, ksize, padding_left_top, padding_right_bottom, + mkldnn::padding_kind::zero); + } + + PoolingMKLDNNHandler( + const std::vector& diff_dst_dims, + const std::vector& diff_src_dims, const std::vector& ksize, + const std::vector& strides, const std::vector& paddings, + const std::string& pooling_type, bool ceil_mode, + const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_dst_fmt, + mkldnn::memory::data_type dt, + const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place, + const std::string& unique_name) + : platform::MKLDNNHandlerT( + dev_ctx, dev_ctx.GetEngine(), cpu_place, + platform::CreateKey(diff_src_dims, pooling_type, ksize, strides, + paddings, dt, fmt, unique_name)) { + auto diff_dst_md = mkldnn::memory::desc( + diff_dst_dims, platform::MKLDNNGetDataType(), diff_dst_fmt); + auto diff_src_md = + mkldnn::memory::desc(diff_src_dims, platform::MKLDNNGetDataType(), + MKLDNNMemoryFormat::any); + + this->AcquireBackwardPrimitiveDescriptor( + pooling_type == "max" ? mkldnn::algorithm::pooling_max + : mkldnn::algorithm::pooling_avg, + diff_src_md, diff_dst_md, strides, ksize, paddings, paddings, + mkldnn::padding_kind::zero); } std::shared_ptr AcquireWorkspaceMemory(void) { mkldnn::memory::primitive_desc workspace_mpd = - pooling_type_ == "max" - ? fwd_pd_->workspace_primitive_desc() - : mkldnn::memory::primitive_desc( - {{}, dt_, MKLDNNMemoryFormat::nchw}, engine_); + this->fwd_pd_->workspace_primitive_desc(); // Pooling PD has to be passed to Grad op that // may be executed by diffrent thread, hence // for that one we use key that does not contain TID - auto local_key = key_common_ + "@workspace"; - auto mem_p = - std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); + auto local_key = this->key_common_ + "@workspace"; + auto mem_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(local_key)); if (mem_p == nullptr) { static std::mutex acquire_barrier; std::lock_guard block_threads_until_finish_this_job( acquire_barrier); - mem_p = - std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); + mem_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(local_key)); if (mem_p == nullptr) { mem_p = std::make_shared(workspace_mpd); - dev_ctx_.SetBlob(local_key, mem_p); + this->dev_ctx_.SetBlob(local_key, mem_p); } } return mem_p; } - std::shared_ptr AcquirePooling( - std::shared_ptr dst_memory, - std::shared_ptr src_memory) { - auto prim_key = key_ + "@pooling_p"; - - auto pooling_p = std::static_pointer_cast( - dev_ctx_.GetBlob(prim_key)); - if (pooling_p == nullptr) { - if (is_test_) { - pooling_p = std::make_shared( - *fwd_pd_, *(src_memory), *(dst_memory)); - } else { - // For training we need to create workspace - // to store indices from backward - auto workspace_memory = this->AcquireWorkspaceMemory(); - - pooling_p = std::make_shared( - *fwd_pd_, *src_memory, *dst_memory, *workspace_memory); - } - dev_ctx_.SetBlob(prim_key, pooling_p); - } - return pooling_p; - } - - std::shared_ptr - AcquirePoolingBackwardPrimitiveDescriptor( - const mkldnn::memory::desc& diff_dst_md, - const mkldnn::memory::desc& diff_src_md, const std::vector& ksize, - const std::vector& strides, const std::vector& paddings) { - const std::string key_pooling_pd = key_common_ + "@pooling_pd"; - const std::string key_pooling_bwd_pd = key_ + "@pooling_bwd_pd"; - bwd_pd_ = - std::static_pointer_cast( - dev_ctx_.GetBlob(key_pooling_bwd_pd)); - if (bwd_pd_ == nullptr) { - fwd_pd_ = - std::static_pointer_cast( - dev_ctx_.GetBlob(key_pooling_pd)); - // PD from FWD op has to exist. - PADDLE_ENFORCE(fwd_pd_ != nullptr, "Pooling MKL-DNN not found in cache!"); - - auto backward_desc = mkldnn::pooling_backward::desc( - pooling_type_ == "max" ? mkldnn::algorithm::pooling_max - : mkldnn::algorithm::pooling_avg, - diff_src_md, diff_dst_md, strides, ksize, paddings, paddings, - mkldnn::padding_kind::zero); - bwd_pd_.reset(new mkldnn::pooling_backward::primitive_desc( - backward_desc, engine_, *fwd_pd_)); - - dev_ctx_.SetBlob(key_pooling_bwd_pd, bwd_pd_); - } - return bwd_pd_; - } - - std::shared_ptr AcquireDiffDstMemoryFromDataPrimitive( - const std::shared_ptr user_memory_p, - std::vector& pipeline) { // NOLINT - auto diff_dst_pd = bwd_pd_->diff_dst_primitive_desc(); - auto user_pd = user_memory_p->get_primitive_desc(); - return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p, - "@diff_dst_mem_p", pipeline); - } - - std::shared_ptr AcquireDiffSrcMemoryFromPrimitive(void* ptr) { - return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(), - ptr, "@diff_src_mem_p"); - } - - std::shared_ptr AcquirePoolingBackward( - std::shared_ptr diff_dst_memory, - std::shared_ptr workspace, - std::shared_ptr diff_src_memory) { - auto prim_key = key_ + "@pooling_bwd_p"; - - auto pooling_bwd_p = std::static_pointer_cast( - dev_ctx_.GetBlob(prim_key)); - if (pooling_bwd_p == nullptr) { - pooling_bwd_p = std::make_shared( - *bwd_pd_, *diff_dst_memory, *workspace, *diff_src_memory); - dev_ctx_.SetBlob(prim_key, pooling_bwd_p); - } - - return pooling_bwd_p; - } - private: static inline int ComputeCeiledOutput(int input_size, int kernel_size, int padding, int stride) { @@ -710,13 +628,6 @@ class PoolingMKLDNNHandler : public MKLDNNHandler { } } } - - private: - mkldnn::memory::data_type dt_; - std::string pooling_type_; - bool is_test_; - std::shared_ptr fwd_pd_; - std::shared_ptr bwd_pd_; }; class TransposeMKLDNNHandler : public MKLDNNHandler {