// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { using framework::DataLayout; using framework::Tensor; using mkldnn::memory; using mkldnn::primitive; using mkldnn::stream; template class EltwiseMKLDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { const auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto* x = ctx.Input("X"); const auto* y = ctx.Input("Y"); auto* z = ctx.Output("Out"); float scale_x = ctx.Attr("Scale_x"); float scale_y = ctx.Attr("Scale_y"); float scale_o = ctx.Attr("Scale_out"); int axis = ctx.Attr("axis"); platform::BinaryMKLDNNHandler handler( BINARY_OP, axis, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z, scale_x, scale_y, scale_o, ctx.OutputName("Out")); const auto src_x_memory = handler.AcquireSrcMemory(x); const auto src_y_memory = handler.AcquireSecondSrcMemory(y); const auto dst_memory = handler.AcquireDstMemory(z); const auto binary_prim = handler.AcquireForwardPrimitive(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); const std::unordered_map args = { {DNNL_ARG_SRC_0, *src_x_memory}, {DNNL_ARG_SRC_1, *src_y_memory}, {DNNL_ARG_DST, *dst_memory}}; binary_prim->execute(astream, args); astream.wait(); z->set_layout(DataLayout::kMKLDNN); z->set_format(platform::GetMKLDNNFormat(*dst_memory)); } }; inline std::vector CalculateBroadcastedDims(const Tensor* x, const Tensor* y) { const auto src_tz = framework::vectorize(x->dims()); const auto dst_tz = framework::vectorize(y->dims()); size_t j = 0; std::vector dst_tz_ex(src_tz.size(), 1); for (size_t i = 0; i < src_tz.size(); ++i) { dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; if (j == dst_tz.size()) break; } return dst_tz_ex; } } // namespace operators } // namespace paddle