diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc index 1762f638e7de83c149be66850c594046ae351b71..7911b125b12154e4a832d0b72e718a06565b339a 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc @@ -36,7 +36,7 @@ PD_DECLARE_KERNEL(batch_norm, OneDNN, ONEDNN); USE_OP_ITSELF(conv2d_transpose); PD_DECLARE_KERNEL(conv2d_transpose, OneDNN, ONEDNN); USE_OP_ITSELF(elementwise_add); -USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); +PD_DECLARE_KERNEL(add_raw, OneDNN, ONEDNN); USE_OP_ITSELF(gelu); PD_DECLARE_KERNEL(gelu, OneDNN, ONEDNN); PD_DECLARE_ARG_MAPPING_FN(gelu); diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc index 673f7cd88d6caeac5b9f55833feee3d609ee4d35..e8116fc47c8a6d1de3a8b19de93a88adaefec2df 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc @@ -25,7 +25,7 @@ USE_OP_ITSELF(softmax); PD_DECLARE_KERNEL(softmax, OneDNN, ONEDNN); USE_OP_ITSELF(elementwise_add); -USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); +PD_DECLARE_KERNEL(add_raw, OneDNN, ONEDNN); USE_OP_ITSELF(leaky_relu); PD_DECLARE_KERNEL(leaky_relu, OneDNN, ONEDNN); USE_OP_ITSELF(gelu); diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc deleted file mode 100644 index 57996477e38a972b2e2d7472c6ac60173b35a7b5..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_KERNEL( - elementwise_add, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::EltwiseMKLDNNKernel, - ops::EltwiseMKLDNNKernel, - ops::EltwiseMKLDNNKernel, - ops::EltwiseMKLDNNKernel) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h deleted file mode 100644 index 6c7a8a7a66cf5194bc54b9b1d91dbb3898be08a1..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h +++ /dev/null @@ -1,415 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include - -#include "paddle/fluid/framework/data_layout_transform.h" -#include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace paddle { -namespace operators { - -using dnnl::memory; -using dnnl::primitive; -using dnnl::stream; -using phi::DataLayout; -using phi::OneDNNContext; -using phi::funcs::BinaryOneDNNHandler; - -inline std::vector CalculateBroadcastedDims( - const phi::DenseTensor* x, const phi::DenseTensor* y) { - const auto src_tz = phi::vectorize(x->dims()); - const auto dst_tz = phi::vectorize(y->dims()); - - std::vector dst_tz_ex(src_tz.size(), 1); - - if (src_tz.size() == dst_tz.size()) { - for (size_t i = 0; i < src_tz.size(); i++) { - dst_tz_ex[i] = (src_tz[i] == dst_tz[i]) ? dst_tz[i] : 1; - } - } else { - size_t j = 0; - for (size_t i = 0; i < src_tz.size(); i++) { - dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; - if (j == dst_tz.size()) break; - } - } - - return dst_tz_ex; -} - -inline void AddSubNonBroadcast( - phi::funcs::ReorderOneDNNHandler* reorder_handler, - phi::DenseTensor* grad_tensor, - const std::shared_ptr& src_memory, - const std::shared_ptr& dst_memory, - const std::vector& scales) { - dnnl::primitive_attr reorder_attr; - reorder_attr.set_output_scales(0, scales); - auto reorder_p = - reorder_handler->AcquireReorder(dst_memory, src_memory, reorder_attr); - - reorder_p->execute( - OneDNNContext::tls().get_stream(), *src_memory, *dst_memory); -} - -template -inline void BroadcastReduction(const framework::ExecutionContext& ctx, - const dnnl::engine& onednn_engine, - phi::DenseTensor* grad_tensor, - const phi::DenseTensor* dout, - const std::shared_ptr& src_memory, - std::shared_ptr dst_memory, - const std::vector& scales, - const bool is_sub) { - dnnl::primitive_attr broadcast_reduction_attr; - - // Broadcasting - if (is_sub) { - dnnl::post_ops po; - po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, scales[0], 0); - broadcast_reduction_attr.set_post_ops(po); - } - - phi::funcs::ReductionOneDNNHandler reduction_handler( - dnnl::algorithm::reduction_sum, - 0.0f, - 0.0f, - onednn_engine, - ctx.GetPlace(), - dout, - grad_tensor, - CalculateBroadcastedDims(dout, grad_tensor), - broadcast_reduction_attr); - dst_memory = reduction_handler.AcquireDstMemory(grad_tensor); - - auto reduction_p = reduction_handler.AcquireForwardPrimitive(); - auto astream = OneDNNContext::tls().get_stream(); - reduction_p->execute(astream, - { - {DNNL_ARG_SRC, *src_memory}, - {DNNL_ARG_DST, *dst_memory}, - }); - astream.wait(); - grad_tensor->set_mem_desc(dst_memory->get_desc().reshape( - phi::vectorize(grad_tensor->dims()))); -} - -template -class EltwiseMKLDNNKernel : public framework::OpKernel { - private: - dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const { - dnnl::post_ops post_operations; - platform::AppendActivation(ctx, post_operations); - if (ctx.HasAttr("fused_output_scale")) { - float scale_alpha = ctx.Attr("fused_output_scale"); - post_operations.append_eltwise( - 1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f); - } - return post_operations; - } - - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - - float scale_x = ctx.Attr("Scale_x"); - float scale_y = ctx.Attr("Scale_y"); - float scale_o = ctx.Attr("Scale_out"); - int axis = ctx.Attr("axis"); - - BinaryOneDNNHandler handler(BINARY_OP, - axis, - mkldnn_engine, - ctx.GetPlace(), - x, - y, - z, - scale_x, - scale_y, - scale_o, - true, - get_post_ops(ctx)); - - // oneDNN's binary is optimized for broadcasting y into x, so in other case - // we have to swap tensors to achieve optimal performance - if (x->numel() < y->numel()) { - std::swap(x, y); - } - - const auto src_x_memory = handler.AcquireSrcMemory(x); - const auto src_y_memory = handler.AcquireSecondSrcMemory(y); - // (jczaja) For Inplace src and dst should be the same memory object. - // So x should share buffer with z. But UT mechanics is testing inplace - // execution for this op not checking that x can be bradcasted to match in - // shape y tensor. - // This is wrong as when x is to be broadcasted then z(out) will match the - // shape of y which is bigger than x. Hence if x is smaller in shape than z - // and they share a buffer (of - // shape x) then this buffer is not big enough to hold result of elementwise - // operation. - const bool reuse_x_memopry = - x->numel() == z->numel() && x->IsSharedBufferWith(*z); - std::shared_ptr dst_memory; - if (reuse_x_memopry) { - dst_memory = src_x_memory; - // NOTE(chenfeiyu): when the output reuses memory from other tensor rather - // than allocate its own, it's still need to take care of its data type. - // Unfortunately, paddle's operator only infers the output' shape, but not - // the data type. mutable_data takes care of allocation and data type - // normally, but if the memory is already allocated and there is no need - // to re-allocate, it just set the data type. So this it added there to - // get the right data type. - z->mutable_data(ctx.GetPlace()); - } else { - dst_memory = handler.AcquireDstMemory(z); - } - - const auto binary_prim = handler.AcquireForwardPrimitive(); - - auto& astream = OneDNNContext::tls().get_stream(); - - const std::unordered_map args = { - {DNNL_ARG_SRC_0, *src_x_memory}, - {DNNL_ARG_SRC_1, *src_y_memory}, - {DNNL_ARG_DST, *dst_memory}}; - - binary_prim->execute(astream, args); - astream.wait(); - - if (handler.use_broadcasting_hack == false) { - platform::SetOutMemDescWithLogicalLayoutFusesSupport( - ctx, z, dst_memory->get_desc()); - } else { - auto dims = dst_memory->get_desc().dims(); - dims.insert(dims.begin(), x->dims()[0]); - dims[1] /= dims[0]; - platform::SetOutMemDescWithLogicalLayoutFusesSupport( - ctx, z, dst_memory->get_desc().reshape(dims)); - } - } -}; - -template -class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - - auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Input("Out"); - - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - auto* dout = ctx.Input(framework::GradVarName("Out")); - - // oneDNN's binary is optimized for broadcasting y into x, so in other case - // we have to swap tensors to achieve optimal performance - bool swap_x_y = false; - if (x->numel() < y->numel()) { - std::swap(x, y); - std::swap(dx, dy); - swap_x_y = true; - } - - std::vector scales{1.0}; - if (swap_x_y) { - scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1; - } - - int axis = ctx.Attr("axis"); - - auto tz = phi::vectorize(dout->dims()); - auto dout_type = phi::funcs::ToOneDNNDataType(dout->dtype()); - - phi::funcs::ReorderOneDNNHandler reorder_handler( - tz, dout->dtype(), dout_type, onednn_engine); - - auto reorder_src_memory = reorder_handler.AcquireSrcMemory( - dout->mem_desc(), phi::funcs::to_void_cast(dout->data())); - - std::shared_ptr dst_memory; - std::shared_ptr broadcast_src_memory = reorder_src_memory; - - auto& astream = OneDNNContext::tls().get_stream(); - if (dx) { - // elementwise_add & elementwise_sub - if (BINARY_OP == dnnl::algorithm::binary_add || - BINARY_OP == dnnl::algorithm::binary_sub) { - if (dout->dims() == dx->dims()) { - dst_memory = reorder_handler.AcquireDstMemory( - dx, dout->mem_desc(), ctx.GetPlace()); - AddSubNonBroadcast( - &reorder_handler, dx, reorder_src_memory, dst_memory, scales); - } - } else { // elementwise_mul & elementwise_div - BinaryOneDNNHandler binary_handler(BINARY_OP, - axis, - onednn_engine, - ctx.GetPlace(), - dout, - y, - dx, - 1.0f, - 1.0f, - 1.0f, - false); - - const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout); - const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y); - dst_memory = binary_handler.AcquireDstMemory(dx); - - const auto binary_prim = binary_handler.AcquireForwardPrimitive(); - - const std::unordered_map args = { - {DNNL_ARG_SRC_0, *src_dout_memory}, - {DNNL_ARG_SRC_1, *src_y_memory}, - {DNNL_ARG_DST, *dst_memory}}; - - binary_prim->execute(astream, args); - } - astream.wait(); - - if (dout->dims() != dx->dims()) { - BroadcastReduction(ctx, - onednn_engine, - dx, - dout, - broadcast_src_memory, - dst_memory, - scales, - BINARY_OP == dnnl::algorithm::binary_sub); - } else { - dx->set_mem_desc(dst_memory->get_desc()); - } - } - - if (dy) { - // elementwise_add & elementwise_sub - if (BINARY_OP == dnnl::algorithm::binary_add || - BINARY_OP == dnnl::algorithm::binary_sub) { - if (dout->dims() == dy->dims()) { - dst_memory = reorder_handler.AcquireDstMemory( - dy, dout->mem_desc(), ctx.GetPlace()); - AddSubNonBroadcast( - &reorder_handler, dy, reorder_src_memory, dst_memory, scales); - } - } else { // elementwise_mul & elementwise_div - std::unordered_map args; - std::shared_ptr binary_prim; - std::shared_ptr post_op_memory; - std::shared_ptr src_0_memory; - std::shared_ptr src_1_memory; - - BinaryOneDNNHandler binary_handler(dnnl::algorithm::binary_mul, - axis, - onednn_engine, - ctx.GetPlace(), - dout, - x, - nullptr, - 1.0f, - 1.0f, - 1.0f, - false); - - src_1_memory = binary_handler.AcquireSecondSrcMemory(x); - - if (BINARY_OP == dnnl::algorithm::binary_div) { - BinaryOneDNNHandler post_op_binary_handler( - dnnl::algorithm::binary_div, - axis, - onednn_engine, - ctx.GetPlace(), - y, - y, - nullptr, - 1.0f, - 1.0f, - 1.0f, - false); - - post_op_memory = post_op_binary_handler.AcquireSrcMemory(y); - - dnnl::post_ops po; - po.append_binary(dnnl::algorithm::binary_div, - post_op_memory->get_desc()); - - binary_handler = BinaryOneDNNHandler(dnnl::algorithm::binary_mul, - axis, - onednn_engine, - ctx.GetPlace(), - dout, - out, - nullptr, - -1.0f, - 1.0f, - 1.0f, - false, - po); - - src_1_memory = binary_handler.AcquireSecondSrcMemory(out); - } - - src_0_memory = binary_handler.AcquireSrcMemory(dout); - - const auto dst_dy_memory = (dout->dims() == dy->dims()) - ? binary_handler.AcquireDstMemory(dy) - : binary_handler.AcquireDstMemory(); - - binary_prim = binary_handler.AcquireForwardPrimitive(); - args = {{DNNL_ARG_SRC_0, *src_0_memory}, - {DNNL_ARG_SRC_1, *src_1_memory}, - {DNNL_ARG_DST, *dst_dy_memory}}; - - if (BINARY_OP == dnnl::algorithm::binary_div) - args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, - *post_op_memory}); - - binary_prim->execute(astream, args); - broadcast_src_memory = dst_dy_memory; - dst_memory = dst_dy_memory; - } - astream.wait(); - - if (dout->dims() != dy->dims()) { - BroadcastReduction(ctx, - onednn_engine, - dy, - dout, - broadcast_src_memory, - dst_memory, - scales, - BINARY_OP == dnnl::algorithm::binary_sub); - } else { - dy->set_mem_desc(dst_memory->get_desc()); - } - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc deleted file mode 100644 index ba3a0d87f6cf74156be6491550707d46688f8956..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_KERNEL( - elementwise_mul, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::EltwiseMKLDNNKernel, - ops::EltwiseMKLDNNKernel, - ops::EltwiseMKLDNNKernel, - ops::EltwiseMKLDNNKernel) diff --git a/paddle/fluid/operators/elementwise/unity_build_rule.cmake b/paddle/fluid/operators/elementwise/unity_build_rule.cmake index 060c990ea8712517c6eeac8ab754c18b07b94d19..057e8d700ecbcb7e684ee80f2aa05238060777ba 100644 --- a/paddle/fluid/operators/elementwise/unity_build_rule.cmake +++ b/paddle/fluid/operators/elementwise/unity_build_rule.cmake @@ -7,14 +7,12 @@ register_unity_group( cc elementwise_add_op.cc - mkldnn/elementwise_add_mkldnn_op.cc elementwise_div_op.cc elementwise_floordiv_op.cc elementwise_max_op.cc elementwise_min_op.cc elementwise_mod_op.cc elementwise_mul_op.cc - mkldnn/elementwise_mul_mkldnn_op.cc elementwise_pow_op.cc elementwise_sub_op.cc) register_unity_group( diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc index e005683e242288868d864fa7958138369db1d232..cbf0b918e6d727706a403d290c5def54e5dc8704 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc @@ -28,9 +28,9 @@ #include "paddle/phi/core/kernel_registry.h" USE_OP_ITSELF(elementwise_add); -USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); +PD_DECLARE_KERNEL(add_raw, OneDNN, ONEDNN); USE_OP_ITSELF(elementwise_mul); -USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN); +PD_DECLARE_KERNEL(multiply_raw, OneDNN, ONEDNN); USE_OP_ITSELF(relu); PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN); USE_OP_ITSELF(softmax); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc index 8aa299570443b2bff1804883110898e88169ffbb..2c8ef7f0981dd1d107a402e497d8edd7a7b787d7 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc @@ -28,7 +28,7 @@ #include "paddle/phi/core/kernel_registry.h" USE_OP_ITSELF(elementwise_add); -USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); +PD_DECLARE_KERNEL(add_raw, OneDNN, ONEDNN); USE_OP_ITSELF(relu); PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN); USE_OP_ITSELF(softmax); diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc index d4ffd49b5fc48b7b67d070a4afaef8ffa12dee60..c6031b34af249c6e054a27d22d0f726ea0ea91cb 100644 --- a/paddle/phi/kernels/elementwise_kernel.cc +++ b/paddle/phi/kernels/elementwise_kernel.cc @@ -414,17 +414,3 @@ PD_REGISTER_KERNEL(elementwise_pow, float, phi::dtype::float16) {} #endif - -#if defined PADDLE_WITH_MKLDNN -PD_REGISTER_KERNEL(subtract, - OneDNN, - ONEDNN, - phi::SubtractKernel, - float, - phi::dtype::bfloat16, - int8_t, - uint8_t) {} - -PD_REGISTER_KERNEL( - divide, OneDNN, ONEDNN, phi::DivideKernel, float, phi::dtype::bfloat16) {} -#endif diff --git a/paddle/phi/kernels/onednn/elementwise_kernel.cc b/paddle/phi/kernels/onednn/elementwise_kernel.cc index 29d527a523fbfe4df6d2c78d5de24f42629d733a..b786da7a319156d7a119b4994f296df91650dae4 100644 --- a/paddle/phi/kernels/onednn/elementwise_kernel.cc +++ b/paddle/phi/kernels/onednn/elementwise_kernel.cc @@ -32,14 +32,14 @@ void ElementwiseKernel(const OneDNNContext& dev_ctx, float scale_x = dev_ctx.HasDnnAttr("Scale_x") ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_x")) - : 1; + : 1.0f; float scale_y = dev_ctx.HasDnnAttr("Scale_y") ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_y")) - : 1; + : 1.0f; float scale_out = dev_ctx.HasDnnAttr("Scale_out") ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_out")) - : 1; + : 1.0f; dnnl::post_ops post_operations; funcs::AppendActivation(dev_ctx, post_operations); @@ -114,12 +114,14 @@ void ElementwiseKernel(const OneDNNContext& dev_ctx, astream.wait(); if (handler.use_broadcasting_hack == false) { - out->set_mem_desc(dst_memory->get_desc()); + funcs::SetOutMemDescWithLogicalLayoutFusesSupport( + dev_ctx, out, dst_memory->get_desc()); } else { auto dims = dst_memory->get_desc().dims(); dims.insert(dims.begin(), non_const_x->dims()[0]); dims[1] /= dims[0]; - out->set_mem_desc(dst_memory->get_desc().reshape(dims)); + funcs::SetOutMemDescWithLogicalLayoutFusesSupport( + dev_ctx, out, dst_memory->get_desc().reshape(dims)); } } @@ -131,13 +133,40 @@ void ElementwiseKernel(const OneDNNContext& dev_ctx, int axis, \ DenseTensor* out) { \ ElementwiseKernel(dev_ctx, x, y, axis, out); \ + } \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + ElementwiseKernel(dev_ctx, x, y, -1, out); \ } +DEFINE_ONEDNN_ELEMENTWISE_KERNEL(Add, dnnl::algorithm::binary_add) DEFINE_ONEDNN_ELEMENTWISE_KERNEL(Subtract, dnnl::algorithm::binary_sub) +DEFINE_ONEDNN_ELEMENTWISE_KERNEL(Multiply, dnnl::algorithm::binary_mul) DEFINE_ONEDNN_ELEMENTWISE_KERNEL(Divide, dnnl::algorithm::binary_div) } // namespace phi +PD_REGISTER_KERNEL(add_raw, + OneDNN, + ONEDNN, + phi::AddRawKernel, + float, + phi::dtype::bfloat16, + int8_t, + uint8_t) {} + +PD_REGISTER_KERNEL(add, + OneDNN, + ONEDNN, + phi::AddKernel, + float, + phi::dtype::bfloat16, + int8_t, + uint8_t) {} + PD_REGISTER_KERNEL(subtract_raw, OneDNN, ONEDNN, @@ -147,9 +176,39 @@ PD_REGISTER_KERNEL(subtract_raw, int8_t, uint8_t) {} +PD_REGISTER_KERNEL(subtract, + OneDNN, + ONEDNN, + phi::SubtractKernel, + float, + phi::dtype::bfloat16, + int8_t, + uint8_t) {} + +PD_REGISTER_KERNEL(multiply_raw, + OneDNN, + ONEDNN, + phi::MultiplyRawKernel, + float, + phi::dtype::bfloat16, + int8_t, + uint8_t) {} + +PD_REGISTER_KERNEL(multiply, + OneDNN, + ONEDNN, + phi::MultiplyKernel, + float, + phi::dtype::bfloat16, + int8_t, + uint8_t) {} + PD_REGISTER_KERNEL(divide_raw, OneDNN, ONEDNN, phi::DivideRawKernel, float, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL( + divide, OneDNN, ONEDNN, phi::DivideKernel, float, phi::dtype::bfloat16) {}