diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 88cda1cd66868933de99b3b864ea98627df1e304..49cfe0a0ab0d57fc149f5cb66dbeca0da34bc989 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/framework/data_layout_transform.h" -#include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { @@ -65,21 +65,27 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { (src_x_tz.size() == 5 && x->format() != (format = memory::format::ncdhw))) { _x.Resize(x_dims); - auto user_x_memory_pd = memory::primitive_desc( - {{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine); - auto x_memory_pd = memory::primitive_desc( - {{src_x_tz}, memory::data_type::f32, format}, mkldnn_engine); - auto size = x_memory_pd.get_size(); - _x.mutable_data(ctx.GetPlace(), size); - auto user_x_memory = - memory(user_x_memory_pd, paddle::platform::to_void_cast(x_data)); - auto x_memory = memory(x_memory_pd, - paddle::platform::to_void_cast(_x.data())); - - auto x_reorder = reorder(user_x_memory, x_memory); + + mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType(); + auto out_format = platform::MKLDNNFormatForSize( + x_dims.size(), mkldnn::memory::format::nchw); + + const std::string key = platform::ReorderMKLDNNHandler::GetHash( + src_x_tz, x->format(), out_format, std::to_string(in_type)); + + platform::ReorderMKLDNNHandler handler(src_x_tz, x->type(), in_type, + dev_ctx, mkldnn_engine, key); + + auto user_x_memory_p = handler.AcquireSrcMemory( + x->format(), paddle::platform::to_void_cast(x_data)); + + auto x_memory_p = + handler.AcquireDstMemory(&_x, out_format, ctx.GetPlace()); + + auto x_reorder = handler.AcquireReorder(x_memory_p, user_x_memory_p); std::vector pipeline; - pipeline.push_back(x_reorder); + pipeline.push_back(*x_reorder); stream(stream::kind::eager).submit(pipeline).wait(); } else { format = x->format(); @@ -125,46 +131,41 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { std::vector dst_tz = framework::vectorize2int(z_dims); std::vector srcs_pd; - std::vector srcs; std::vector scales = {1.0f, 1.0f}; - auto src_x_pd = memory::primitive_desc( - {{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine); - auto src_y_pd = memory::primitive_desc( - {{src_y_tz}, memory::data_type::f32, y->format()}, mkldnn_engine); - auto src_x_memory = - memory(src_x_pd, paddle::platform::to_void_cast(x_data)); - auto src_y_memory = - memory(src_y_pd, paddle::platform::to_void_cast(y_data)); + const std::string key = platform::MKLDNNHandler::GetHash( + src_x_tz, ctx.op().Output("Out") + std::to_string(x->format()) + + std::to_string(y->format())); + + platform::SumMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); + + auto src_x_memory = handler.AcquireSrcMemory( + {{src_x_tz}, platform::MKLDNNGetDataType(), x->format()}, + paddle::platform::to_void_cast(x_data)); - srcs_pd.push_back(src_x_pd); - srcs_pd.push_back(src_y_pd); - srcs.push_back(src_x_memory); - srcs.push_back(src_y_memory); + auto src_y_memory = handler.AcquireSecondSrcMemory( + {{src_y_tz}, platform::MKLDNNGetDataType(), y->format()}, + paddle::platform::to_void_cast(y_data)); - auto dst_md = - memory::desc({dst_tz}, memory::data_type::f32, memory::format::any); + auto dst_md = memory::desc({dst_tz}, platform::MKLDNNGetDataType(), + memory::format::any); - // create primitive descriptor for sum - auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd); + auto sum_pd = handler.AcquireSumPrimitiveDescriptor( + {src_x_memory, src_y_memory}, scales, dst_md); - // create mkldnn memory for dst - memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data); + auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data); - std::vector inputs; - inputs.push_back(srcs[0]); - inputs.push_back(srcs[1]); + std::vector inputs({*src_x_memory, *src_y_memory}); - // create sum primitive - auto sum_prim = sum(sum_pd, inputs, dst_memory); + auto sum_prim = handler.AcquireSum(dst_memory, &inputs); std::vector pipeline; - pipeline.push_back(sum_prim); + pipeline.push_back(*sum_prim); stream(stream::kind::eager).submit(pipeline).wait(); z->set_layout(DataLayout::kMKLDNN); z->set_format( - (memory::format)dst_memory.get_primitive_desc().desc().data.format); + (memory::format)dst_memory->get_primitive_desc().desc().data.format); } } }; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index f1fb6b156aedcbf4d834d53ebe4d443fd5f780d3..ad4ddc5a3b6ae66595a8810737bda7dfe0bad37e 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -45,6 +45,11 @@ class MKLDNNHandler { return this->AcquireMemory(md, ptr, "@user_src_mem_p"); } + std::shared_ptr AcquireSecondSrcMemory( + const mkldnn::memory::desc& md, void* ptr) { + return this->AcquireMemory(md, ptr, "@user_src2_mem_p"); + } + std::shared_ptr AcquireWeightsMemory( const mkldnn::memory::desc& md, void* ptr, user_function custom_func = {}) { @@ -265,6 +270,55 @@ class MKLDNNHandler { static constexpr int MaxKeyLength = 256; }; +class SumMKLDNNHandler : public MKLDNNHandler { + public: + SumMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx, + mkldnn::engine engine, const std::string& base_key) + : platform::MKLDNNHandler(dev_ctx, engine, base_key) {} + + std::shared_ptr AcquireSumPrimitiveDescriptor( + const std::vector>& src_mems, + const std::vector& scales, const mkldnn::memory::desc& dst_md) { + const std::string key_sum_pd = key_ + "@sum_pd"; + + sum_pd_ = std::static_pointer_cast( + dev_ctx_.GetBlob(key_sum_pd)); + if (sum_pd_ == nullptr) { + // Get vector of inputs primitive descriptors + std::vector src_pds; + for (auto& input_mem : src_mems) { + src_pds.push_back(input_mem->get_primitive_desc()); + } + + sum_pd_.reset(new mkldnn::sum::primitive_desc(dst_md, scales, src_pds)); + dev_ctx_.SetBlob(key_sum_pd, sum_pd_); + } + + return sum_pd_; + } + + std::shared_ptr AcquireDstMemoryFromPrimitive(void* ptr) { + return this->AcquireMemoryFromPrimitive(sum_pd_->dst_primitive_desc(), ptr, + "@dst_mem_p"); + } + + std::shared_ptr AcquireSum( + std::shared_ptr dst_memory, + std::vector* inputs) { + auto prim_key = key_ + "@sum_p"; + auto sum_p = + std::static_pointer_cast(dev_ctx_.GetBlob(prim_key)); + if (sum_p == nullptr) { + sum_p = std::make_shared(*(sum_pd_), *inputs, *(dst_memory)); + dev_ctx_.SetBlob(prim_key, sum_p); + } + return sum_p; + } + + private: + std::shared_ptr sum_pd_; +}; + class TransposeMKLDNNHandler : public MKLDNNHandler { public: TransposeMKLDNNHandler(std::vector& dims, // NOLINT