// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/add_n_kernel.h" #include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { namespace funcs { template class SumOneDNNHandler : public OneDNNHandlerNoCachingT { public: SumOneDNNHandler(dnnl::engine engine, const Place& cpu_place, const std::vector& x, DenseTensor* out) : OneDNNHandlerNoCachingT(engine, cpu_place), num_inputs_(0) { auto dst_tz = vectorize(out->dims()); auto src_tz = dst_tz; std::vector srcs_md; srcs_md.reserve(x.size()); for (size_t i = 0; i < x.size(); i++) { auto* input_it = (static_cast(x[i])); if (input_it->numel() == 0) { continue; } srcs_md.push_back(input_it->mem_desc()); ++num_inputs_; } std::vector scales(num_inputs_, 1.0f); auto dst_md = dnnl::memory::desc( dst_tz, OneDNNGetDataType(), OneDNNMemoryFormat::any); this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md); } // (jczaja) sum oneDNN prim is not having .desc attribute so // we cannot use base AcquireForwardPrimitiveDescriptor void AcquireForwardPrimitiveDescriptor( const dnnl::memory::desc& dst_md, const std::vector& scales, const std::vector& srcs_md) { this->fwd_pd_.reset( new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_)); } std::shared_ptr AcquireSrcMemory(const DenseTensor* input, int i) { const T* input_data = input->data(); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i), to_void_cast(input_data)); } using OneDNNHandlerNoCachingT::AcquireDstMemory; std::shared_ptr AcquireDstMemory(void) { return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc()); } inline int GetNumInputs(void) { return num_inputs_; } private: int num_inputs_; }; } // namespace funcs template void AddNKernel(const Context& dev_ctx, const std::vector& x, DenseTensor* out) { PADDLE_ENFORCE_EQ( dev_ctx.GetPlace().GetType() == AllocationType::CPU, true, errors::PreconditionNotMet("oneDNN AddN kernel must use CPUPlace")); const auto& onednn_engine = dev_ctx.GetEngine(); PADDLE_ENFORCE_NE( x.empty(), true, errors::InvalidArgument("Input variable is empty.")); auto* input0 = (static_cast(x[0])); bool in_place = (input0->numel() > 0) && input0->IsSharedBufferWith(*out); funcs::SumOneDNNHandler handler(onednn_engine, dev_ctx.GetPlace(), x, out); // Create list of SRC MEMs std::vector> srcs_mem; srcs_mem.reserve(handler.GetNumInputs()); int input_index = 0; for (size_t i = 0; i < x.size(); i++) { auto* input_it = (static_cast(x[i])); if (input_it->numel() == 0) { continue; } srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index)); ++input_index; } std::unordered_map args; for (size_t i = 0; i < srcs_mem.size(); ++i) { args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])}); } auto dst_mem = in_place ? srcs_mem[0] : handler.AcquireDstMemory(out); args.insert({DNNL_ARG_DST, *dst_mem}); auto sum_p = handler.AcquireForwardPrimitive(); auto& astream = OneDNNContext::tls().get_stream(); sum_p->execute(astream, args); astream.wait(); out->set_mem_desc(dst_mem->get_desc()); } } // namespace phi PD_REGISTER_KERNEL( add_n, OneDNN, ONEDNN, phi::AddNKernel, float, phi::dtype::bfloat16) {}