// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { using framework::DataLayout; using framework::Tensor; using mkldnn::memory; using mkldnn::primitive; using mkldnn::stream; template class EltwiseMKLDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { const auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto* x = ctx.Input("X"); const auto* y = ctx.Input("Y"); auto* z = ctx.Output("Out"); float scale_x = ctx.Attr("Scale_x"); float scale_y = ctx.Attr("Scale_y"); float scale_o = ctx.Attr("Scale_out"); int axis = ctx.Attr("axis"); platform::BinaryMKLDNNHandler handler( BINARY_OP, axis, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z, scale_x, scale_y, scale_o, ctx.OutputName("Out")); const auto src_x_memory = handler.AcquireSrcMemory(x); const auto src_y_memory = handler.AcquireSecondSrcMemory(y); // For Inplace src and and dst are the same memory object const auto dst_memory = x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z); const auto binary_prim = handler.AcquireForwardPrimitive(); mkldnn::stream astream(mkldnn_engine); const std::unordered_map args = { {DNNL_ARG_SRC_0, *src_x_memory}, {DNNL_ARG_SRC_1, *src_y_memory}, {DNNL_ARG_DST, *dst_memory}}; binary_prim->execute(astream, args); astream.wait(); z->set_layout(DataLayout::kMKLDNN); z->set_format(platform::GetMKLDNNFormat(*dst_memory)); } }; } // namespace operators } // namespace paddle