diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc index 9b0cb0a60902f2c2625f5cfb6927b54a9528daaa..673f7cd88d6caeac5b9f55833feee3d609ee4d35 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc @@ -23,7 +23,7 @@ #include "paddle/phi/core/kernel_registry.h" USE_OP_ITSELF(softmax); -USE_OP_DEVICE_KERNEL(softmax, MKLDNN); +PD_DECLARE_KERNEL(softmax, OneDNN, ONEDNN); USE_OP_ITSELF(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_ITSELF(leaky_relu); diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc deleted file mode 100644 index 57935a1a1c1aaaa23bbebcd632dcfa60289e75aa..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ /dev/null @@ -1,111 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" -#include "paddle/phi/kernels/funcs/axis_utils.h" - -namespace paddle { -namespace operators { - -using paddle::platform::MKLDNNDeviceContext; -using paddle::platform::MKLDNNMemDesc; - -using dnnl::memory; // Note: paddle has also "memory" namespace -using dnnl::primitive; -using dnnl::prop_kind; -using dnnl::softmax_backward; -using dnnl::softmax_forward; -using dnnl::stream; -using platform::to_void_cast; - -template -class SoftmaxMKLDNNHandler - : public platform::MKLDNNHandlerNoCachingT { - public: - SoftmaxMKLDNNHandler(const dnnl::engine mkldnn_engine, - platform::Place cpu_place, - const phi::DenseTensor* input, - phi::DenseTensor* output, - const int axis) - : platform::MKLDNNHandlerNoCachingT(mkldnn_engine, - cpu_place) { - PADDLE_ENFORCE_EQ( - input->dims(), - output->dims(), - platform::errors::InvalidArgument( - "The shape of input and output tensor must be identical.")); - - this->AcquireForwardPrimitiveDescriptor( - prop_kind::forward_scoring, input->mem_desc(), axis); - } -}; - -template -class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - const phi::DenseTensor* input = ctx.Input("X"); - phi::DenseTensor* output = ctx.Output("Out"); - bool is_inplaced = input->IsSharedBufferWith(*output); - - const int axis = - phi::funcs::CanonicalAxis(ctx.Attr("axis"), input->dims().size()); - - SoftmaxMKLDNNHandler handler( - mkldnn_engine, ctx.GetPlace(), input, output, axis); - - auto softmax_src_memory_p = handler.AcquireSrcMemory(input); - // For Inplace src and and dst are the same memory object - std::shared_ptr softmax_dst_memory_p = nullptr; - if (is_inplaced) { - softmax_dst_memory_p = softmax_src_memory_p; - output->mutable_data(ctx.GetPlace()); - } else { - softmax_dst_memory_p = handler.AcquireDstMemory(output); - } - auto softmax_p = handler.AcquireForwardPrimitive(); - - auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - softmax_p->execute(astream, - {{DNNL_ARG_SRC, *softmax_src_memory_p}, - {DNNL_ARG_DST, *softmax_dst_memory_p}}); - astream.wait(); - - const bool is_test = ctx.Attr("is_test"); - if (!is_test) { - T* output_data = output->mutable_data(ctx.GetPlace()); - std::for_each(output_data, &output_data[output->numel()], [](T& val) { - val = std::max(val, static_cast(exp(-64))); - }); - } - - output->set_mem_desc(softmax_dst_memory_p->get_desc()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_KERNEL(softmax, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::SoftmaxMKLDNNKernel, - ops::SoftmaxMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc index 9a4237230e86f5ad062340d0b56c4a1a3ea2f444..63b7cfd51fb55aedd7fcc1bc840eee81d643ded5 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc @@ -34,7 +34,7 @@ USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN); USE_OP_ITSELF(relu); PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN); USE_OP_ITSELF(softmax); -USE_OP_DEVICE_KERNEL(softmax, MKLDNN); +PD_DECLARE_KERNEL(softmax, OneDNN, ONEDNN); USE_OP_ITSELF(conv2d); PD_DECLARE_KERNEL(conv2d, OneDNN, ONEDNN); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc index ba6d5e4d3f2c14346cc66a7188d5ba2d0225689a..8aa299570443b2bff1804883110898e88169ffbb 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc @@ -32,8 +32,7 @@ USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_ITSELF(relu); PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN); USE_OP_ITSELF(softmax); -USE_OP_DEVICE_KERNEL(softmax, MKLDNN); - +PD_DECLARE_KERNEL(softmax, OneDNN, ONEDNN); PD_DECLARE_KERNEL(softmax, CPU, ALL_LAYOUT); namespace paddle { diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 4a28c4262f32d761eb9fdfa8e6c6e74582d01070..536ac86496d3ff6d78081d182257a871f15d431c 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -753,12 +753,19 @@ class SoftmaxOneDNNHandler public: SoftmaxOneDNNHandler(const dnnl::engine onednn_engine, Place cpu_place, + int axis, const DenseTensor* x, - int axis) + DenseTensor* out) : OneDNNHandlerNoCachingT(onednn_engine, cpu_place) { + PADDLE_ENFORCE_EQ( + x->dims(), + out->dims(), + phi::errors::InvalidArgument( + "The shape of input and output tensor must be identical.")); + const int canonical_axis = funcs::CanonicalAxis(axis, x->dims().size()); this->AcquireForwardPrimitiveDescriptor( dnnl::prop_kind::forward_scoring, x->mem_desc(), canonical_axis); diff --git a/paddle/phi/kernels/onednn/softmax_kernel.cc b/paddle/phi/kernels/onednn/softmax_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..06709aa0fd1582874d4f666458cd1ee55ceb4d1b --- /dev/null +++ b/paddle/phi/kernels/onednn/softmax_kernel.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/softmax_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void SoftmaxKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + DenseTensor* out) { + funcs::SoftmaxOneDNNHandler handler( + dev_ctx.GetEngine(), dev_ctx.GetPlace(), axis, &x, out); + + auto src_memory_p = handler.AcquireSrcMemory(&x); + std::shared_ptr dst_memory_p = nullptr; + if (x.IsSharedBufferWith(*out)) { + dst_memory_p = src_memory_p; + dev_ctx.template Alloc(out); + } else { + dst_memory_p = handler.AcquireDstMemory(out); + } + auto softmax_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + softmax_p->execute( + astream, {{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}); + astream.wait(); + + bool is_test = dev_ctx.HasDnnAttr("is_test") + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test")) + : false; + if (!is_test) { + T* out_data = dev_ctx.template Alloc(out); + std::for_each(out_data, &out_data[out->numel()], [](T& val) { + val = std::max(val, static_cast(exp(-64))); + }); + } + + out->set_mem_desc(dst_memory_p->get_desc()); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + softmax, OneDNN, ONEDNN, phi::SoftmaxKernel, float, phi::dtype::bfloat16) {}