From cdd8c8abbecfe2ba449382710767c28d0d8da10f Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Thu, 10 Nov 2022 13:18:16 +0100 Subject: [PATCH] [phi] migrate prelu (#47422) * migrate prelu * remove cache * review fixes --- .../fluid/operators/mkldnn/prelu_mkldnn_op.cc | 208 ------------------ paddle/phi/backends/onednn/onednn_reuse.h | 61 +++++ .../phi/kernels/onednn/prelu_grad_kernel.cc | 69 ++++++ paddle/phi/kernels/onednn/prelu_kernel.cc | 59 +++++ 4 files changed, 189 insertions(+), 208 deletions(-) delete mode 100644 paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc create mode 100644 paddle/phi/kernels/onednn/prelu_grad_kernel.cc create mode 100644 paddle/phi/kernels/onednn/prelu_kernel.cc diff --git a/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc deleted file mode 100644 index 4c517a6228..0000000000 --- a/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/platform/mkldnn_reuse.h" -#include "paddle/phi/core/expect.h" - -namespace paddle { -namespace operators { - -using dnnl::memory; - -using platform::MKLDNNDeviceContext; -using platform::MKLDNNGetDataType; -using platform::to_void_cast; - -namespace { -template -class PReluMKLDNNHandler - : public platform:: - MKLDNNHandlerT { - public: - PReluMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, - const dnnl::engine engine, - platform::Place cpu_place, - const phi::DenseTensor* x, - const phi::DenseTensor* weights, - const std::string& uniq_name, - const std::string& mode, - const std::string& data_format, - bool is_test = false) - : platform::MKLDNNHandlerT( - dev_ctx, - engine, - cpu_place, - platform::CreateKey( - dev_ctx, phi::vectorize(x->dims()), uniq_name)) { - if (unlikely(!this->isCached())) { - auto weights_dims = phi::vectorize(weights->dims()); - - // weights must have same size as X only for "element" case - if (weights->dims().size() != x->dims().size()) { - auto new_weights_dims = std::vector(x->dims().size(), 1); - if (mode == "channel") { - new_weights_dims[1] = - *std::max_element(weights_dims.begin(), weights_dims.end()); - } - weights_dims = std::move(new_weights_dims); - } - auto weights_md = memory::desc( - weights_dims, MKLDNNGetDataType(), memory::format_tag::any); - - this->AcquireForwardPrimitiveDescriptor( - dnnl::prop_kind::forward_training, x->mem_desc(), weights_md); - if (!is_test) - this->AcquireBackwardPrimitiveDescriptor( - x->mem_desc(), weights_md, x->mem_desc(), weights_md); - } - } - - std::shared_ptr AcquireWeightsMemoryPossiblyWithReorder( - const phi::DenseTensor* weights, const bool is_test) { - const T* weights_data = weights->data(); - - // if weights are 1D, every format tag is correct, so we accept - // format_tag::any's output and no reorder is needed - if (weights->dims().size() == 1) { - return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), - to_void_cast(weights_data), - "@alpha_mem_p"); - } - - return this->AcquireMemoryWithReorder(weights->mem_desc(), - this->fwd_pd_->weights_desc(), - to_void_cast(weights_data), - "@alpha_mem_p", - is_test); - } - - std::shared_ptr AcquireDiffWeightsMemory(phi::DenseTensor* output) { - T* output_data = output->mutable_data( - this->place_, this->bwd_pd_->diff_weights_desc().get_size()); - return this->AcquireMemoryFromPrimitive( - this->bwd_pd_->diff_weights_desc(), output_data, "@diff_weights_mem_p"); - } -}; -} // anonymous namespace - -template -class PReluMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const framework::ExecutionContext& ctx) const { - const auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - const auto* x = ctx.Input("X"); - const auto* alpha = ctx.Input("Alpha"); - auto* out = ctx.Output("Out"); - const bool is_test = ctx.Attr("is_test"); - const auto mode = ctx.Attr("mode"); - const auto data_format = ctx.Attr("data_format"); - - PReluMKLDNNHandler handler(dev_ctx, - onednn_engine, - ctx.GetPlace(), - x, - alpha, - ctx.InputName("X"), - mode, - data_format, - is_test); - - auto src_memory_p = handler.AcquireSrcMemory(x); - auto weights_memory_p = - handler.AcquireWeightsMemoryPossiblyWithReorder(alpha, is_test); - auto dst_memory_p = handler.AcquireDstMemory(out); - auto prelu_p = handler.AcquireForwardPrimitive(); - - auto& astream = MKLDNNDeviceContext::tls().get_stream(); - prelu_p->execute(astream, - {{DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}); - astream.wait(); - - out->set_mem_desc(dst_memory_p->get_desc()); - } -}; - -template -class PReluGradMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const framework::ExecutionContext& ctx) const { - const auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto* x = ctx.Input("X"); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dalpha = - ctx.Output(framework::GradVarName("Alpha")); - auto* alpha = ctx.Input("Alpha"); - const bool is_test = ctx.Attr("is_test"); - const auto mode = ctx.Attr("mode"); - const auto data_format = ctx.Attr("data_format"); - - PReluMKLDNNHandler handler(dev_ctx, - onednn_engine, - ctx.GetPlace(), - x, - alpha, - framework::GradVarName("X"), - mode, - data_format); - - auto src_memory_p = handler.AcquireSrcMemory(x); - auto weights_memory_p = - handler.AcquireWeightsMemoryPossiblyWithReorder(alpha, is_test); - auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); - auto diff_weights_memory_p = handler.AcquireDiffWeightsMemory(dalpha); - auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout); - auto prelu_p = handler.AcquireBackwardPrimitive(); - - auto& astream = MKLDNNDeviceContext::tls().get_stream(); - prelu_p->execute(astream, - {{DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}, - {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); - astream.wait(); - - dx->set_mem_desc(diff_src_memory_p->get_desc()); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(prelu, - MKLDNN, - paddle::platform::CPUPlace, - ops::PReluMKLDNNKernel, - ops::PReluMKLDNNKernel); - -REGISTER_OP_KERNEL(prelu_grad, - MKLDNN, - paddle::platform::CPUPlace, - ops::PReluGradMKLDNNKernel, - ops::PReluGradMKLDNNKernel); diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 6ed83e1dec..4e9d9dfc0b 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -1084,6 +1084,67 @@ class BroadcastDataOneDNNHandler } }; +template +class PReluOneDNNHandler + : public OneDNNHandlerNoCachingT { + public: + PReluOneDNNHandler(const dnnl::engine engine, + Place cpu_place, + const DenseTensor& x, + const DenseTensor& weights, + const std::string& mode, + const std::string& data_format, + const bool is_test) + : OneDNNHandlerNoCachingT( + engine, cpu_place) { + auto weights_dims = phi::vectorize(weights.dims()); + // weights must have same size as X only for "element" case + if (weights.dims().size() != x.dims().size()) { + auto new_weights_dims = std::vector(x.dims().size(), 1); + if (mode == "channel") { + new_weights_dims[1] = + *std::max_element(weights_dims.begin(), weights_dims.end()); + } + weights_dims = std::move(new_weights_dims); + } + auto weights_md = memory::desc( + weights_dims, OneDNNGetDataType(), memory::format_tag::any); + + this->AcquireForwardPrimitiveDescriptor( + dnnl::prop_kind::forward_training, x.mem_desc(), weights_md); + if (!is_test) { + this->AcquireBackwardPrimitiveDescriptor( + x.mem_desc(), weights_md, x.mem_desc(), weights_md); + } + } + + std::shared_ptr AcquireWeightsMemoryPossiblyWithReorder( + const DenseTensor* weights, const bool is_test) { + const T* weights_data = weights->data(); + + // if weights are 1D, every format tag is correct, so we accept + // format_tag::any's output and no reorder is needed + if (weights->dims().size() == 1) { + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), + to_void_cast(weights_data)); + } + + return this->AcquireMemoryWithReorder(weights->mem_desc(), + this->fwd_pd_->weights_desc(), + to_void_cast(weights_data), + is_test); + } + + std::shared_ptr AcquireDiffWeightsMemory(DenseTensor* output) { + T* output_data = output->mutable_data( + this->place_, this->bwd_pd_->diff_weights_desc().get_size()); + return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(), + output_data); + } +}; + template class ReductionOneDNNHandler : public OneDNNHandlerNoCachingT { diff --git a/paddle/phi/kernels/onednn/prelu_grad_kernel.cc b/paddle/phi/kernels/onednn/prelu_grad_kernel.cc new file mode 100644 index 0000000000..1b67a4a002 --- /dev/null +++ b/paddle/phi/kernels/onednn/prelu_grad_kernel.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/prelu_grad_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void PReluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& alpha, + const DenseTensor& out_grad, + const std::string& data_format, + const std::string& mode, + DenseTensor* x_grad, + DenseTensor* alpha_grad) { + bool is_test = dev_ctx.HasDnnAttr("is_test") + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test")) + : false; + funcs::PReluOneDNNHandler handler(dev_ctx.GetEngine(), + dev_ctx.GetPlace(), + x, + alpha, + mode, + data_format, + is_test); + + auto src_memory_p = handler.AcquireSrcMemory(&x); + auto weights_memory_p = + handler.AcquireWeightsMemoryPossiblyWithReorder(&alpha, is_test); + auto diff_src_memory_p = handler.AcquireDiffSrcMemory(x_grad); + auto diff_weights_memory_p = handler.AcquireDiffWeightsMemory(alpha_grad); + auto diff_dst_memory_p = handler.AcquireDiffDstMemory(&out_grad); + auto prelu_p = handler.AcquireBackwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + prelu_p->execute(astream, + {{DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, + {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}, + {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); + astream.wait(); + + x_grad->set_mem_desc(diff_src_memory_p->get_desc()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(prelu_grad, + OneDNN, + ONEDNN, + phi::PReluGradKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/prelu_kernel.cc b/paddle/phi/kernels/onednn/prelu_kernel.cc new file mode 100644 index 0000000000..efe814130b --- /dev/null +++ b/paddle/phi/kernels/onednn/prelu_kernel.cc @@ -0,0 +1,59 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/prelu_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void PReluKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& alpha, + const std::string& data_format, + const std::string& mode, + DenseTensor* out) { + PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType(), + AllocationType::CPU, + phi::errors::PreconditionNotMet( + "Operator oneDNN PReLU must use CPUPlace")); + + bool is_test = dev_ctx.HasDnnAttr("is_test") + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test")) + : false; + funcs::PReluOneDNNHandler handler(dev_ctx.GetEngine(), + dev_ctx.GetPlace(), + x, + alpha, + mode, + data_format, + is_test); + + auto src_memory_p = handler.AcquireSrcMemory(&x); + auto weights_memory_p = + handler.AcquireWeightsMemoryPossiblyWithReorder(&alpha, is_test); + auto dst_memory_p = handler.AcquireDstMemory(out); + auto prelu_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + prelu_p->execute(astream, + {{DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}); + astream.wait(); + + out->set_mem_desc(dst_memory_p->get_desc()); +} +} // namespace phi + +PD_REGISTER_KERNEL( + prelu, OneDNN, ONEDNN, phi::PReluKernel, float, phi::dtype::bfloat16) {} -- GitLab