/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/onednn/mkldnn_helper.h"

namespace phi {
namespace funcs {

using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = dnnl::memory;
using Place = phi::Place;
using MKLDNNMemoryFormat = dnnl::memory::format_tag;

template <typename T,
          typename TForward,
          typename TBackward = mkldnn_dummy_primitive,
          typename TBackward_params = mkldnn_dummy_primitive>
class MKLDNNHandlerNoCachingT {
 public:
  MKLDNNHandlerNoCachingT(dnnl::engine engine, Place cpu_place)
      : engine_(engine), place_(cpu_place), fwd_pd_(nullptr), bwd_pd_(nullptr) {
    phi::OneDNNContext::tls().log_lib_version();
  }

  std::shared_ptr<TForward> AcquireForwardPrimitive() {
    return std::make_shared<TForward>(*fwd_pd_);
  }

  std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
    return std::make_shared<TBackward>(*bwd_pd_);
  }

  std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
    PADDLE_ENFORCE_NOT_NULL(
        bwd_w_pd_,
        phi::errors::Unavailable("BWD_PD should be set when "
                                 "getting BWD prim."));
    return std::make_shared<TBackward_params>(*bwd_w_pd_);
  }

  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor* input) {
    const T* input_data = input->data<T>();
    return this->AcquireMemoryFromPrimitive(fwd_pd_->src_desc(),
                                            to_void_cast<T>(input_data));
  }

  template <typename T_out = T>
  std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output) {
    T_out* ptr =
        output->mutable_data<T_out>(place_, fwd_pd_->dst_desc().get_size());
    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), ptr);
  }

  template <typename T_out = T>
  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc());
  }

  template <typename T_out = T>
  std::shared_ptr<dnnl::memory> AcquireDstMemory(const DenseTensor* output) {
    const T_out* output_data = output->data<T_out>();
    return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_desc(),
                                            to_void_cast<T_out>(output_data));
  }

  std::shared_ptr<dnnl::memory> AcquireDiffDstMemory(
      const DenseTensor* diffdst) {
    const T* ptr = diffdst->data<T>();
    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_desc(),
                                            to_void_cast<T>(ptr));
  }

  std::shared_ptr<dnnl::memory> AcquireDiffSrcMemory(DenseTensor* diffsrc) {
    T* ptr =
        diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_desc(), ptr);
  }

  // Buffer of given Tensor is used for oneDNN computation
  std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(
      DenseTensor* diff_weights) {
    PADDLE_ENFORCE_NOT_NULL(
        bwd_w_pd_,
        phi::errors::Unavailable(
            "BWD_W_PD should be set when getting BWD grad of weights."));
    T* ptr = diff_weights->mutable_data<T>(
        place_, bwd_w_pd_->diff_weights_desc().get_size());
    return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(),
                                            ptr);
  }

  // Buffer is allocated by oneDNN to store computation results
  std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(void) {
    PADDLE_ENFORCE_NOT_NULL(
        bwd_w_pd_,
        phi::errors::Unavailable(
            "BWD_W_PD should be set when getting BWD grad of weights."));
    return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc());
  }
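  // Summary of the memory-acquire contract above: AcquireSrcMemory and the
  // AcquireDstMemory overloads wrap the forward primitive descriptor's
  // src/dst descriptors, AcquireDiffDstMemory/AcquireDiffSrcMemory wrap the
  // backward descriptor's diff_dst/diff_src, and the AcquireDiffWeightsMemory
  // overloads wrap the backward-weights descriptor. Overloads taking a
  // tensor reuse that tensor's buffer; the parameterless ones let oneDNN
  // allocate its own storage.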
 protected:
  // If your primitive descriptor requires attributes, pass them as a
  // first argument and parameters to descriptor constructor in the following
  // arguments. Otherwise, all arguments will be forwarded to descriptor
  // constructor, including the first one.
  template <typename Arg, typename... Args>
  void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
    CreateForwardPrimitiveDescriptor(first_arg, std::forward<Args>(args)...);
  }

  // Using SFINAE to specialise the variadic function. Workaround for not
  // having `if constexpr` in C++11.
  template <typename First, typename... Args>
  typename std::enable_if<std::is_same<typename std::decay<First>::type,
                                       dnnl::primitive_attr>::value>::type
  CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
    auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
    fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
        fwd_desc, first, engine_);
  }

  template <typename First, typename... Args>
  typename std::enable_if<!std::is_same<typename std::decay<First>::type,
                                        dnnl::primitive_attr>::value>::type
  CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
    auto fwd_desc = typename TForward::desc(std::forward<First>(first),
                                            std::forward<Args>(args)...);
    fwd_pd_ =
        std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
  }

  template <typename... Args>
  void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
    // fwd_pd_ is set during grad by calling
    // AcquireForwardPrimitiveDescriptor
    PADDLE_ENFORCE_NOT_NULL(
        fwd_pd_,
        phi::errors::Unavailable("Get MKLDNN Forward primitive %s failed."));
    auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
    bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
        bwd_desc, engine_, *fwd_pd_);
  }

  template <typename... Args>
  void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
    // fwd_pd_ is set during grad by calling
    // AcquireForwardPrimitiveDescriptor
    PADDLE_ENFORCE_NOT_NULL(
        fwd_pd_,
        phi::errors::Unavailable("Get MKLDNN Forward primitive %s failed."));
    auto bwd_desc =
        typename TBackward_params::desc(std::forward<Args>(args)...);
    bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
        bwd_desc, engine_, *fwd_pd_);
  }

  std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
      dnnl::memory::desc md, void* ptr) {
    return std::make_shared<dnnl::memory>(md, engine_, ptr);
  }

  std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
      dnnl::memory::desc md) {
    return std::make_shared<dnnl::memory>(md, engine_);
  }

  void AcquireReorder(const std::shared_ptr<dnnl::memory>& user_memory_p,
                      const std::shared_ptr<dnnl::memory>& target_memory_p) {
    auto reorder_p =
        std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);

    auto& astream = phi::OneDNNContext::tls().get_stream();

    paddle::platform::RecordEvent record_reorder(
        "int_reorder",
        paddle::platform::TracerEventType::UserDefined,
        2,
        paddle::platform::EventRole::kUniqueOp);
    reorder_p->execute(
        astream,
        {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
    astream.wait();
  }

  template <typename F = T>
  std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
      const dnnl::memory::desc& user_md,
      const dnnl::memory::desc& target_md,
      void* ptr,
      bool is_persistent = false,
      std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
    std::shared_ptr<dnnl::memory> target_memory_p;
    if (custom_reorder_func) {
      auto reordered_data =
          custom_reorder_func(reinterpret_cast<const F*>(ptr));
      ptr = reinterpret_cast<void*>(reordered_data.get());
    }
    auto user_memory_p = std::make_shared<dnnl::memory>(user_md, engine_, ptr);
    if (user_md != target_md) {
      target_memory_p = std::make_shared<dnnl::memory>(target_md, engine_);
      auto reorder_p =
          std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);

      auto& astream = phi::OneDNNContext::tls().get_stream();
      paddle::platform::RecordEvent record_reorder(
          "int_reorder",
          paddle::platform::TracerEventType::UserDefined,
          2,
          paddle::platform::EventRole::kUniqueOp);
      reorder_p->execute(
          astream,
          {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
      astream.wait();
    } else {
      target_memory_p = user_memory_p;
    }
    return target_memory_p;
  }

  dnnl::engine engine_;
  Place place_;
  std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
  std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
  std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
};
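// Illustrative sketch of the two call forms that the SFINAE dispatch above
// distinguishes. The arguments (`src_md`, `alpha`, `beta`, `scale`) are
// hypothetical and not defined in this file:
//
//   // No attributes: every argument is forwarded to TForward::desc.
//   this->AcquireForwardPrimitiveDescriptor(
//       dnnl::prop_kind::forward_training,
//       dnnl::algorithm::eltwise_relu, src_md, alpha, beta);
//
//   // Leading dnnl::primitive_attr: it is peeled off and handed to the
//   // primitive_desc constructor; the remaining arguments build the desc.
//   dnnl::primitive_attr attrs;
//   attrs.set_output_scales(0, {scale});
//   this->AcquireForwardPrimitiveDescriptor(
//       attrs, dnnl::prop_kind::forward_training,
//       dnnl::algorithm::eltwise_relu, src_md, alpha, beta);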
template <typename T>
class ActivationMKLDNNHandler
    : public MKLDNNHandlerNoCachingT<T,
                                     dnnl::eltwise_forward,
                                     dnnl::eltwise_backward> {
 public:
  ActivationMKLDNNHandler(dnnl::algorithm algorithm,
                          float alpha,
                          float beta,
                          const dnnl::engine engine,
                          Place cpu_place,
                          const DenseTensor* x)
      : MKLDNNHandlerNoCachingT<T,
                                dnnl::eltwise_forward,
                                dnnl::eltwise_backward>(engine, cpu_place) {
    this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
                                            algorithm,
                                            x->mem_desc(),
                                            alpha,
                                            beta);
  }

  ActivationMKLDNNHandler(dnnl::algorithm algorithm,
                          float alpha,
                          float beta,
                          const dnnl::engine engine,
                          Place cpu_place,
                          const DenseTensor* x,
                          const DenseTensor* dout)
      : MKLDNNHandlerNoCachingT<T,
                                dnnl::eltwise_forward,
                                dnnl::eltwise_backward>(engine, cpu_place) {
    this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
                                            algorithm,
                                            x->mem_desc(),
                                            alpha,
                                            beta);
    this->AcquireBackwardPrimitiveDescriptor(
        algorithm, dout->mem_desc(), x->mem_desc(), alpha, beta);
  }

  std::shared_ptr<dnnl::memory> AcquireBackwardSrcMemory(
      const DenseTensor* input) {
    const T* input_data = input->data<T>();
    return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(),
                                            to_void_cast<T>(input_data));
  }
};
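// Minimal usage sketch for ActivationMKLDNNHandler (a ReLU forward pass).
// The names `engine`, `place`, `x` and `out` are hypothetical kernel-side
// variables, not declared in this header:
//
//   ActivationMKLDNNHandler<float> handler(dnnl::algorithm::eltwise_relu,
//                                          0.0f, 0.0f, engine, place, x);
//   auto src_memory_p = handler.AcquireSrcMemory(x);
//   auto dst_memory_p = handler.AcquireDstMemory(out);
//   auto activation_p = handler.AcquireForwardPrimitive();
//   auto& astream = phi::OneDNNContext::tls().get_stream();
//   activation_p->execute(astream, {{DNNL_ARG_FROM, *src_memory_p},
//                                   {DNNL_ARG_TO, *dst_memory_p}});
//   astream.wait();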
class ReorderMKLDNNHandler {
 public:
  ReorderMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
                       DataType ptype,
                       dnnl::memory::data_type dtype,
                       dnnl::engine engine)
      : dims_(dims),
        ptype_(ptype),
        ptype_dst_(ptype),
        dtype_(dtype),
        dtype_dst_(dtype),
        engine_(engine) {}

  ReorderMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
                       DataType ptype,
                       dnnl::memory::data_type dtype,
                       DataType ptype_dst,
                       dnnl::memory::data_type dtype_dst,
                       dnnl::engine engine)
      : dims_(dims),
        ptype_(ptype),
        ptype_dst_(ptype_dst),
        dtype_(dtype),
        dtype_dst_(dtype_dst),
        engine_(engine) {}

  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const dnnl::memory::desc& md,
                                                 void* ptr) {
    return std::make_shared<dnnl::memory>(md, engine_, ptr);
  }

  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
                                                 void* ptr) {
    auto md = dnnl::memory::desc(dims_, dtype_, fmt);
    return std::make_shared<dnnl::memory>(md, engine_, ptr);
  }

  std::shared_ptr<dnnl::memory> AcquireSubmemory(
      const std::vector<int64_t>& dims,
      const std::vector<int64_t>& offset,
      const std::shared_ptr<dnnl::memory>& mem_p) {
    auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset});
    auto sub_mem_p = std::make_shared<dnnl::memory>(
        sub_md, engine_, mem_p->get_data_handle());
    return sub_mem_p;
  }

  std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output,
                                                 const MKLDNNMemoryFormat& fmt,
                                                 Place place) {
    auto dst_md = MKLDNNMemDesc(dims_, dtype_dst_, fmt);
    auto dst_data = output->mutable_data(place, ptype_dst_, dst_md.get_size());
    return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
  }

  std::shared_ptr<dnnl::memory> AcquireDstMemory(
      DenseTensor* output, const dnnl::memory::desc& src_md, Place place) {
    if (ptype_dst_ == ptype_) {
      auto dst_data =
          output->mutable_data(place, ptype_dst_, src_md.get_size());
      return std::make_shared<dnnl::memory>(src_md, engine_, dst_data);
    } else {
      auto dst_md = src_md;
      dst_md.data.data_type = static_cast<dnnl_data_type_t>(dtype_dst_);
      auto dst_data =
          output->mutable_data(place, ptype_dst_, dst_md.get_size());
      return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
    }
  }

  std::shared_ptr<dnnl::memory> AcquireDstMemory(
      DenseTensor* output,
      const std::vector<int64_t>& dims,
      const MKLDNNMemoryFormat& fmt,
      Place place) {
    auto dst_md = MKLDNNMemDesc(dims, dtype_dst_, fmt);
    auto dst_data = output->mutable_data(place, ptype_dst_, dst_md.get_size());
    return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
  }

  std::shared_ptr<dnnl::reorder> AcquireReorder(
      std::shared_ptr<dnnl::memory> dst_memory_p,
      std::shared_ptr<dnnl::memory> src_memory_p) {
    return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
  }

  std::shared_ptr<dnnl::reorder> AcquireReorder(
      std::shared_ptr<dnnl::memory> dst_memory_p,
      std::shared_ptr<dnnl::memory> src_memory_p,
      const dnnl::primitive_attr& attrs) {
    return std::make_shared<dnnl::reorder>(
        *(src_memory_p), *(dst_memory_p), attrs);
  }

 private:
  std::vector<int64_t> dims_;
  DataType ptype_, ptype_dst_;
  dnnl::memory::data_type dtype_, dtype_dst_;
  dnnl::engine engine_;
};

}  // namespace funcs
}  // namespace phi
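// Minimal usage sketch for ReorderMKLDNNHandler: converting a float tensor
// to a plain nchw layout. The names `in`, `out`, `place` and `engine` are
// hypothetical kernel-side variables, and the f32/nchw choices are
// illustrative only:
//
//   std::vector<int64_t> dims = phi::vectorize(in->dims());
//   phi::funcs::ReorderMKLDNNHandler handler(
//       dims, in->dtype(), dnnl::memory::data_type::f32, engine);
//   auto src_memory_p = handler.AcquireSrcMemory(
//       in->mem_desc(), phi::funcs::to_void_cast<float>(in->data<float>()));
//   auto dst_memory_p = handler.AcquireDstMemory(
//       out, phi::funcs::MKLDNNMemoryFormat::nchw, place);
//   auto reorder_p = handler.AcquireReorder(dst_memory_p, src_memory_p);
//   auto& astream = phi::OneDNNContext::tls().get_stream();
//   reorder_p->execute(astream, *src_memory_p, *dst_memory_p);
//   astream.wait();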