diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc index bdff665f0f62604be2b21b0150d6c06efc41406e..3d3738d922f77b067341b8f68e3d70a040832d3a 100644 --- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc @@ -25,7 +25,7 @@ limitations under the License. */ #include "paddle/fluid/operators/sum_op.h" -#include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace framework { @@ -51,6 +51,95 @@ using paddle::platform::CPUDeviceContext; using paddle::platform::MKLDNNDeviceContext; using platform::to_void_cast; +template +class SumMKLDNNHandler : public platform::MKLDNNHandlerT { + public: + SumMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, + platform::Place cpu_place, + const std::vector& in_vars, + framework::LoDTensor* z, const std::string& uniq_name) + + : platform::MKLDNNHandlerT( + dev_ctx, dev_ctx.GetEngine(), cpu_place, + platform::CreateKey(framework::vectorize(z->dims()), uniq_name)), + num_inputs_(0) { + for (size_t i = 0; i < in_vars.size(); i++) { + srcs_suffix_.push_back(std::string("-") + std::to_string(i)); + } + + if (!this->isCached()) { + auto dst_tz = framework::vectorize(z->dims()); + auto src_tz = dst_tz; + + std::vector srcs_md; + for (size_t i = 0; i < in_vars.size(); i++) { + auto& input_it = in_vars[i]->Get(); + if (input_it.numel() == 0) { + continue; + } + MKLDNNMemoryFormat input_format = input_it.format(); + srcs_md.push_back(memory::desc(src_tz, platform::MKLDNNGetDataType(), + input_format)); + ++num_inputs_; + } + std::vector scales(num_inputs_, 1.0); + + auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType(), + MKLDNNMemoryFormat::any); + + this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md); + } + } + + // (jczaja) sum oneDNN prim is not having .desc attribute so + // we cannot use base AcquireForwardPrimitiveDescriptor + void AcquireForwardPrimitiveDescriptor( + const memory::desc& dst_md, const std::vector& scales, + const std::vector& srcs_md) { + // Sum op does not have backward so no passing from FWD to BWD is needed + const std::string key_pd = this->key_ + "@fwd_pd"; + this->fwd_pd_ = std::static_pointer_cast( + this->dev_ctx_.GetBlob(key_pd)); + if (this->fwd_pd_ == nullptr) { + this->fwd_pd_.reset(new mkldnn::sum::primitive_desc( + dst_md, scales, srcs_md, this->engine_)); + this->dev_ctx_.SetBlob(key_pd, this->fwd_pd_); + } + } + + std::shared_ptr AcquireSrcMemory( + const framework::Tensor& input, int i) { + const T* input_data = input.data(); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i), + to_void_cast(input_data), + "@src_mem_p" + srcs_suffix_[i]); + } + + using platform::MKLDNNHandlerT::AcquireDstMemory; + + std::shared_ptr AcquireDstMemory(void) { + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), + "@dst_mem_p"); + } + + inline int GetNumInputs(void) { return num_inputs_; } + + protected: + // isCached need to be overloaded as base one works on key_common + bool isCached() { + const std::string key_pd = this->key_ + "@fwd_pd"; + this->fwd_pd_ = std::static_pointer_cast( + this->dev_ctx_.GetBlob(key_pd)); + + const std::string key_p = this->key_ + "@fwd_p"; + return (this->dev_ctx_.GetBlob(key_p) != nullptr); + } + + private: + int num_inputs_; + std::vector srcs_suffix_; +}; + template class SumMKLDNNOpKernel : public paddle::framework::OpKernel { public: @@ -59,85 +148,67 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { paddle::platform::errors::PreconditionNotMet( "Operator DNNL Sum must use CPUPlace")); auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); auto in_vars = ctx.MultiInputVar("X"); - auto out_var = ctx.OutputVar("Out"); PADDLE_ENFORCE_NE(in_vars.empty(), true, platform::errors::InvalidArgument( "Input variable is empty.")); - bool in_place = out_var == in_vars[0]; - + auto& input0 = in_vars[0]->Get(); LoDTensor* output = ctx.Output("Out"); - T* output_data = output->mutable_data(ctx.GetPlace()); - auto dst_tz = framework::vectorize(output->dims()); - auto src_tz = dst_tz; - MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::undef}; - std::vector scales; - std::vector srcs_md; - std::vector srcs_mem; + bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output); - auto& input0 = in_vars[0]->Get(); - in_place = (input0.numel() > 0) && (input0.data() == output_data); + SumMKLDNNHandler handler(dev_ctx, ctx.GetPlace(), in_vars, output, + ctx.OutputName("Out")); + // Create list of SRC MEMs + std::vector> srcs_mem; + srcs_mem.reserve(handler.GetNumInputs()); + int input_index = 0; for (size_t i = 0; i < in_vars.size(); i++) { - auto& input_it = in_vars[i]->Get(); + auto& input_it = in_vars[i]->Get(); if (input_it.numel() == 0) { continue; } - - const T* input_data = input_it.data(); - MKLDNNMemoryFormat input_format = input_it.format(); - - auto src_md = memory::desc(src_tz, memory::data_type::f32, input_format); - auto src_mem = memory(src_md, mkldnn_engine, to_void_cast(input_data)); - srcs_md.push_back(src_md); - srcs_mem.push_back(src_mem); - scales.push_back(1.0); - } - - auto dst_md = - memory::desc(dst_tz, memory::data_type::f32, MKLDNNMemoryFormat::any); - - auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_md, mkldnn_engine); - - std::shared_ptr dst_mem; - if (in_place) { - dst_mem.reset(new memory(sum_pd.dst_desc(), mkldnn_engine)); - } else { - dst_mem.reset(new memory(sum_pd.dst_desc(), mkldnn_engine, output_data)); + srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index)); + ++input_index; } - auto sum_prim = mkldnn::sum(sum_pd); - output_format = platform::GetMKLDNNFormat(sum_pd.dst_desc()); + auto dst_mem = in_place ? handler.AcquireDstMemory() + : handler.AcquireDstMemory(output); - std::shared_ptr reorder_p; - std::shared_ptr target_mem; - if (in_place) { - output_format = input0.format(); - target_mem.reset( - new memory({{src_tz}, memory::data_type::f32, output_format}, - mkldnn_engine, output_data)); - reorder_p = std::make_shared(*dst_mem, *target_mem); - } + auto sum_p = handler.AcquireForwardPrimitive(); - mkldnn::stream astream(mkldnn_engine); std::unordered_map args; for (size_t i = 0; i < srcs_mem.size(); ++i) { - args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, srcs_mem.at(i)}); + args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])}); } args.insert({MKLDNN_ARG_DST, *dst_mem}); - sum_prim.execute(astream, args); + mkldnn::stream astream(dev_ctx.GetEngine()); + sum_p->execute(astream, args); astream.wait(); + // For in-place execution which sum does not have we need to fake it + // so from oneDNN dst memory we reorder data into input if (in_place) { + const std::string reorder_key = platform::CreateKey( + framework::vectorize(output->dims()), ctx.OutputName("Out") + "-I"); + + auto& in_out = in_vars[0]->Get(); + auto output_tz = framework::vectorize(output->dims()); + platform::ReorderMKLDNNHandler reorder_handler( + output_tz, output->type(), framework::ToMKLDNNDataType(in_out.type()), + dev_ctx, dev_ctx.GetEngine(), reorder_key); + + auto target_mem = reorder_handler.AcquireDstMemory( + output, in_out.format(), ctx.GetPlace()); + + auto reorder_p = reorder_handler.AcquireReorder(target_mem, dst_mem); reorder_p->execute(astream, *dst_mem, *target_mem); astream.wait(); } - - output->set_layout(DataLayout::kMKLDNN); - output->set_format(output_format); + output->set_layout(framework::DataLayout::kMKLDNN); + output->set_format(platform::GetMKLDNNFormat(*dst_mem)); } }; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 2d9e4333ac95e85b791e7f32001d9ddc859bf2c3..54f8cb1dc88428e95d7d87b9caeebcef8fd1de23 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -591,59 +591,6 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { } }; -class SumMKLDNNHandler : public MKLDNNHandler { - public: - SumMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx, - mkldnn::engine engine, const std::string& base_key) - : platform::MKLDNNHandler(dev_ctx, engine, base_key) {} - - std::shared_ptr AcquireSumPrimitiveDescriptor( - const std::vector>& src_mems, - const std::vector& scales, const mkldnn::memory::desc& dst_md) { - const std::string key_sum_pd = key_ + "@sum_pd"; - - sum_pd_ = std::static_pointer_cast( - dev_ctx_.GetBlob(key_sum_pd)); - if (sum_pd_ == nullptr) { - // Get vector of inputs primitive descriptors - std::vector src_ds; - for (auto& input_mem : src_mems) { - src_ds.push_back(input_mem->get_desc()); - } - - sum_pd_.reset( - new mkldnn::sum::primitive_desc(dst_md, scales, src_ds, engine_)); - dev_ctx_.SetBlob(key_sum_pd, sum_pd_); - } - - return sum_pd_; - } - - std::shared_ptr AcquireDstMemoryFromPrimitive(void* ptr) { - return this->AcquireMemoryFromPrimitive(sum_pd_->dst_desc(), ptr, - "@dst_mem_p"); - } - - std::shared_ptr AcquireSecondSrcMemory( - const mkldnn::memory::desc& md, void* ptr) { - return this->AcquireMemory(md, ptr, "@user_src2_mem_p"); - } - - std::shared_ptr AcquireSum() { - auto prim_key = key_ + "@sum_p"; - auto sum_p = - std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); - if (sum_p == nullptr) { - sum_p = std::make_shared(*sum_pd_); - dev_ctx_.SetBlob(prim_key, sum_p); - } - return sum_p; - } - - private: - std::shared_ptr sum_pd_; -}; - template class ActivationMKLDNNHandler : public MKLDNNHandlerT