diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
deleted file mode 100644
index 072016d729cdb6aae6e6ea10155d29adfa2afb2f..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
+++ /dev/null
@@ -1,174 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-/*Licensed under the Apache License, Version 2.0(the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
-
-  http://www.apache.org/licenses/LICENSE-2.0
-
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License. */
-
-#include "paddle/fluid/framework/lod_tensor_array.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/mkldnn_reuse.h"
-
-namespace phi {
-class DenseTensor;
-}  // namespace phi
-
-namespace paddle {
-namespace operators {
-
-using paddle::platform::MKLDNNDeviceContext;
-using phi::CPUContext;
-using platform::to_void_cast;
-using Tensor = framework::Tensor;
-using SelectedRows = phi::SelectedRows;
-using LoDTensor = framework::LoDTensor;
-
-template <typename T>
-class SumMKLDNNHandler
-    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
- public:
-  SumMKLDNNHandler(dnnl::engine engine,
-                   platform::Place cpu_place,
-                   const std::vector<framework::Variable*>& in_vars,
-                   framework::LoDTensor* z)
-
-      : platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
-        num_inputs_(0) {
-    auto dst_tz = phi::vectorize(z->dims());
-    auto src_tz = dst_tz;
-
-    std::vector<dnnl::memory::desc> srcs_md;
-    srcs_md.reserve(in_vars.size());
-    for (size_t i = 0; i < in_vars.size(); i++) {
-      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
-      if (input_it.numel() == 0) {
-        continue;
-      }
-      srcs_md.push_back(input_it.mem_desc());
-      ++num_inputs_;
-    }
-    std::vector<float> scales(num_inputs_, 1.0f);
-
-    auto dst_md = dnnl::memory::desc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
-
-    this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
-  }
-
-  // (jczaja) sum oneDNN prim is not having .desc attribute so
-  // we cannot use base AcquireForwardPrimitiveDescriptor
-  void AcquireForwardPrimitiveDescriptor(
-      const dnnl::memory::desc& dst_md,
-      const std::vector<float>& scales,
-      const std::vector<dnnl::memory::desc>& srcs_md) {
-    this->fwd_pd_.reset(
-        new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
-  }
-
-  std::shared_ptr<dnnl::memory> AcquireSrcMemory(
-      const framework::Tensor& input, int i) {
-    const T* input_data = input.data<T>();
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
-                                            to_void_cast<T>(input_data));
-  }
-
-  using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
-
-  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
-    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
-  }
-
-  inline int GetNumInputs(void) { return num_inputs_; }
-
- private:
-  int num_inputs_;
-};
-
-template <typename T>
-class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
-                      true,
-                      paddle::platform::errors::PreconditionNotMet(
-                          "Operator DNNL Sum must use CPUPlace"));
-    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-    auto in_vars = ctx.MultiInputVar("X");
-
-    PADDLE_ENFORCE_NE(
-        in_vars.empty(),
-        true,
-        platform::errors::InvalidArgument("Input variable is empty."));
-    auto& input0 = in_vars[0]->Get<LoDTensor>();
-    LoDTensor* output = ctx.Output<LoDTensor>("Out");
-
-    bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);
-
-    SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);
-
-    // Create list of SRC MEMs
-    std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
-    srcs_mem.reserve(handler.GetNumInputs());
-    int input_index = 0;
-    for (size_t i = 0; i < in_vars.size(); i++) {
-      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
-      if (input_it.numel() == 0) {
-        continue;
-      }
-      srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
-      ++input_index;
-    }
-
-    std::unordered_map<int, dnnl::memory> args;
-    std::shared_ptr<dnnl::memory> dst_mem;
-
-    for (size_t i = 0; i < srcs_mem.size(); ++i) {
-      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
-    }
-
-    if (in_place) {
-      dst_mem = srcs_mem[0];
-    } else {
-      dst_mem = handler.AcquireDstMemory(output);
-    }
-    args.insert({DNNL_ARG_DST, *dst_mem});
-
-    auto sum_p = handler.AcquireForwardPrimitive();
-
-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    sum_p->execute(astream, args);
-    astream.wait();
-
-    output->set_mem_desc(dst_mem->get_desc());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OP_KERNEL(
-    sum,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    paddle::operators::SumMKLDNNOpKernel<float>,
-    paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>);
diff --git a/paddle/phi/kernels/onednn/add_n_kernel.cc b/paddle/phi/kernels/onednn/add_n_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..421e60504be950f618e5f663c7328a190a35f353
--- /dev/null
+++ b/paddle/phi/kernels/onednn/add_n_kernel.cc
@@ -0,0 +1,135 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/add_n_kernel.h"
+
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace funcs {
+template <typename T>
+class SumOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::sum> {
+ public:
+  SumOneDNNHandler(dnnl::engine engine,
+                   const Place& cpu_place,
+                   const std::vector<const TensorBase*>& x,
+                   DenseTensor* out)
+
+      : OneDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
+        num_inputs_(0) {
+    auto dst_tz = vectorize(out->dims());
+    auto src_tz = dst_tz;
+
+    std::vector<dnnl::memory::desc> srcs_md;
+    srcs_md.reserve(x.size());
+    for (size_t i = 0; i < x.size(); i++) {
+      auto* input_it = (static_cast<const DenseTensor*>(x[i]));
+      if (input_it->numel() == 0) {
+        continue;
+      }
+      srcs_md.push_back(input_it->mem_desc());
+      ++num_inputs_;
+    }
+    std::vector<float> scales(num_inputs_, 1.0f);
+
+    auto dst_md = dnnl::memory::desc(
+        dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
+
+    this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
+  }
+
+  // (jczaja) The oneDNN sum primitive has no .desc attribute, so we
+  // cannot use the base AcquireForwardPrimitiveDescriptor here.
+  void AcquireForwardPrimitiveDescriptor(
+      const dnnl::memory::desc& dst_md,
+      const std::vector<float>& scales,
+      const std::vector<dnnl::memory::desc>& srcs_md) {
+    this->fwd_pd_.reset(
+        new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
+  }
+
+  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor* input,
+                                                 int i) {
+    const T* input_data = input->data<T>();
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
+                                            to_void_cast<T>(input_data));
+  }
+
+  using OneDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
+
+  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
+  }
+
+  inline int GetNumInputs(void) { return num_inputs_; }
+
+ private:
+  int num_inputs_;
+};
+}  // namespace funcs
+
+template <typename T, typename Context>
+void AddNKernel(const Context& dev_ctx,
+                const std::vector<const TensorBase*>& x,
+                DenseTensor* out) {
+  PADDLE_ENFORCE_EQ(
+      dev_ctx.GetPlace().GetType() == AllocationType::CPU,
+      true,
+      errors::PreconditionNotMet("oneDNN AddN kernel must use CPUPlace"));
+
+  const auto& onednn_engine = dev_ctx.GetEngine();
+
+  PADDLE_ENFORCE_NE(
+      x.empty(), true, errors::InvalidArgument("Input variable is empty."));
+  auto* input0 = (static_cast<const DenseTensor*>(x[0]));
+
+  bool in_place = (input0->numel() > 0) && input0->IsSharedBufferWith(*out);
+
+  funcs::SumOneDNNHandler<T> handler(onednn_engine, dev_ctx.GetPlace(), x, out);
+
+  // Create list of SRC MEMs
+  std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
+  srcs_mem.reserve(handler.GetNumInputs());
+  int input_index = 0;
+  for (size_t i = 0; i < x.size(); i++) {
+    auto* input_it = (static_cast<const DenseTensor*>(x[i]));
+    if (input_it->numel() == 0) {
+      continue;
+    }
+    srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
+    ++input_index;
+  }
+
+  std::unordered_map<int, dnnl::memory> args;
+
+  for (size_t i = 0; i < srcs_mem.size(); ++i) {
+    args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
+  }
+
+  auto dst_mem = in_place ? srcs_mem[0] : handler.AcquireDstMemory(out);
+
+  args.insert({DNNL_ARG_DST, *dst_mem});
+
+  auto sum_p = handler.AcquireForwardPrimitive();
+
+  auto& astream = OneDNNContext::tls().get_stream();
+  sum_p->execute(astream, args);
+  astream.wait();
+
+  out->set_mem_desc(dst_mem->get_desc());
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    add_n, OneDNN, ALL_LAYOUT, phi::AddNKernel, float, phi::dtype::bfloat16) {}
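
Reviewer note: below is a minimal standalone sketch of the raw oneDNN sum-primitive sequence that both the removed SumMKLDNNHandler and the new funcs::SumOneDNNHandler wrap. It assumes the oneDNN 2.x C++ API (dnnl.hpp) that this code builds against; the {2, 3} shape, the buffer contents, and the freestanding main() are illustrative only and not part of the patch.

#include <unordered_map>
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine engine(dnnl::engine::kind::cpu, 0);
  dnnl::stream stream(engine);

  // Two f32 sources and one destination, all of shape {2, 3}. The kernel
  // builds one such memory::desc per non-empty input (srcs_md above).
  auto md = dnnl::memory::desc({2, 3},
                               dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::ab);
  std::vector<float> a(6, 1.0f), b(6, 2.0f), out(6, 0.0f);
  dnnl::memory src0(md, engine, a.data());
  dnnl::memory src1(md, engine, b.data());

  // Equal scales of 1.0f, as in the handler: a plain elementwise sum.
  std::vector<float> scales = {1.0f, 1.0f};
  dnnl::sum::primitive_desc pd(md, scales, {md, md}, engine);
  dnnl::memory dst(pd.dst_desc(), engine, out.data());

  // DNNL_ARG_MULTIPLE_SRC + i is the same argument convention the kernel
  // uses when filling its unordered_map of execution args.
  std::unordered_map<int, dnnl::memory> args = {
      {DNNL_ARG_MULTIPLE_SRC + 0, src0},
      {DNNL_ARG_MULTIPLE_SRC + 1, src1},
      {DNNL_ARG_DST, dst}};

  dnnl::sum(pd).execute(stream, args);
  stream.wait();  // out now holds 3.0f in every element
  return 0;
}

The in-place branch in AddNKernel corresponds to binding the first source memory object to DNNL_ARG_DST instead of a separately acquired destination, which is why dst_mem can simply alias srcs_mem[0] when the output shares input0's buffer.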