From 653885a56b567cb83b03113e48d606a94c12ee18 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Fri, 29 Jul 2022 09:30:20 +0200 Subject: [PATCH] [WIP] Matmul v1 & v2 unification -- part 1 (#44640) * - Unit tests to be debugged - fix - refactor - diagnostic - more diagnostic - fix - Fix number two - fix - fix - fix - alpha added - more fixes - compilation fix - removed diagnostic code - cosmetic fixes * lint --- .../operators/mkldnn/matmul_mkldnn_op.cc | 734 --------------- .../operators/mkldnn/matmul_v2_mkldnn_op.cc | 857 ++++++++++++++++-- .../fluid/operators/mkldnn/mul_mkldnn_op.cc | 3 +- paddle/fluid/platform/mkldnn_reuse.h | 130 +-- 4 files changed, 863 insertions(+), 861 deletions(-) delete mode 100644 paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc deleted file mode 100644 index 912b1be813a..00000000000 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ /dev/null @@ -1,734 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h" - -#include - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -using dnnl::memory; -using dnnl::primitive; -using paddle::framework::DataLayout; -using paddle::framework::ExecutionContext; -using paddle::platform::GetMKLDNNFormat; -using paddle::platform::MKLDNNDeviceContext; -using paddle::platform::MKLDNNFormatForSize; -using paddle::platform::MKLDNNGetDataType; -using paddle::platform::to_void_cast; -using phi::vectorize; -using Tensor = paddle::framework::Tensor; - -namespace { - -// Reshape a rank-3 tensor from P x M x N to (P * M) x N. -// Identity op if the tensor is not of rank 3. -static Tensor FoldOuterDims(const Tensor& input) { - auto output = input; - auto in_dims = input.dims(); - if (in_dims.size() == 3) { - output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); - } - return output; -} - -// Reshape a rank-3 tensor from P x M x N to M x (P * N). -// (Warning: This requires transposing data and writes into new memory.) -// Identity op if the tensor is not of rank 3. -template -static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx, - const Tensor* input) { - auto input_dims = vectorize(input->dims()); - if (input_dims.size() != 3) { - return *input; - } - - Tensor output; - output.Resize({input_dims[1], input_dims[0], input_dims[2]}); - - auto output_dims = vectorize(output.dims()); - - memory::data_type input_type = paddle::framework::ToMKLDNNDataType( - paddle::framework::TransToProtoVarType(input->dtype())); - paddle::platform::ReorderMKLDNNHandler reorder_handler( - output_dims, - paddle::framework::TransToProtoVarType(input->dtype()), - input_type, - dev_ctx.GetEngine()); - - auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( - memory::format_tag::abc, - paddle::platform::to_void_cast(input->data())); - auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - &output, memory::format_tag::bac, dev_ctx.GetPlace()); - auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, - reorder_dst_memory_p); - - auto& astream = MKLDNNDeviceContext::tls().get_stream(); - reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); - astream.wait(); - - output.Resize({input_dims[1], input_dims[0] * input_dims[2]}); - return output; -} - -template -constexpr bool IsInt8() { - return std::is_same::value || std::is_same::value; -} - -template -constexpr bool IsBfloat16() { - return std::is_same::value; -} - -// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the -// original x_dim is returned. -static paddle::framework::DDim RowMatrixDimsFromVector( - const paddle::framework::DDim& x_dim) { - return x_dim.size() > 1 ? x_dim : phi::make_ddim({1, x_dim[0]}); -} - -// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the -// original y_dim is returned. -static paddle::framework::DDim ColumnMatrixDimsFromVector( - const paddle::framework::DDim& y_dim) { - return y_dim.size() > 1 ? y_dim : phi::make_ddim({y_dim[0], 1}); -} - -template -class MatMulMKLDNNHandler - : public paddle::platform::MKLDNNHandlerNoCachingT { - public: - MatMulMKLDNNHandler(const dnnl::engine engine, - paddle::platform::Place cpu_place, - Tensor* x, - bool trans_x, - Tensor* y, - bool trans_y, - Tensor* out, - float scale) - : paddle::platform::MKLDNNHandlerNoCachingT(engine, - cpu_place) { - auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x->dims(), 0, trans_x); - auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y->dims(), 0, trans_y); - - memory::dim x_bs = mat_dim_x.batch_size_; - memory::dim y_bs = mat_dim_y.batch_size_; - - memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1; - const memory::dim M = mat_dim_x.height_; - const memory::dim N = mat_dim_y.width_; - const memory::dim K = mat_dim_x.width_; - - memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K}; - memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N}; - memory::dims out_dims = {out_bs, M, N}; - - memory::dims x_strides = - !trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M}; - - memory::dims y_strides = - !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K}; - memory::dims out_strides = memory::dims{M * N, N, 1}; - - auto x_md = memory::desc(x_dims, MKLDNNGetDataType(), x_strides); - auto y_md = memory::desc(y_dims, MKLDNNGetDataType(), y_strides); - auto out_md = memory::desc(out_dims, MKLDNNGetDataType(), out_strides); - - dnnl::primitive_attr attrs; - if (scale != 1.0f) attrs.set_output_scales(0, {scale}); - - this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md); - } - // Constructor for FWD MatMul - MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext& ctx) - : paddle::platform::MKLDNNHandlerNoCachingT( - engine, ctx.GetPlace()) { - const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx); - - auto matmul_dims_ = GetMatmulDims(ctx); - auto x_md = memory::desc( - matmul_dims_.x_dims, MKLDNNGetDataType(), matmul_dims_.x_strides); - auto y_md = memory::desc( - matmul_dims_.y_dims, MKLDNNGetDataType(), matmul_dims_.y_strides); - auto out_md = memory::desc(matmul_dims_.out_dims, - MKLDNNGetDataType(), - matmul_dims_.out_strides); - this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md); - } - - std::shared_ptr AcquireWeightsMemory(const Tensor* input) { - const YT* input_data = input->data(); - return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), - to_void_cast(input_data)); - } - - public: - void Execute(const paddle::framework::Tensor* x, - const paddle::framework::Tensor* y, - paddle::framework::Tensor* out) { - const auto src_memory_p = this->AcquireSrcMemory(x); - const auto weights_memory_p = this->AcquireWeightsMemory(y); - const auto dst_memory_p = this->AcquireDstMemory(out); - - auto matmul_p = this->AcquireForwardPrimitive(); - - std::unordered_map matmul_args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - - // Simulate batch matmul by processing in loop - void* x_ptr = src_memory_p->get_data_handle(); - void* y_ptr = weights_memory_p->get_data_handle(); - void* out_ptr = dst_memory_p->get_data_handle(); - auto offsets = this->GetOffsets(); - for (uint16_t i = 0; i < this->GetBatchSize(); ++i) { - src_memory_p->set_data_handle(x_ptr); - weights_memory_p->set_data_handle(y_ptr); - dst_memory_p->set_data_handle(out_ptr); - matmul_p->execute(astream, - { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}, - }); - x_ptr = static_cast(x_ptr) + std::get<0>(offsets); - y_ptr = static_cast(y_ptr) + std::get<1>(offsets); - out_ptr = static_cast(out_ptr) + std::get<2>(offsets); - } - astream.wait(); - - auto format = - MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw); - out->set_format(format); - out->set_layout(DataLayout::kMKLDNN); - } - - std::shared_ptr AcquireDstMemory( - paddle::framework::Tensor* output) { - // We cannot use base AcquireDstMemory as it makes an allocation request - // base on DST memory primitive size. This is fine in general, but in MatMul - // we have primitive that covers only one batch of Data and then shift - // pointer for every new batch. Hence Tensor size is bigger that dst memory - // primitive size. So would we request less memory that is there and it - // triggers an - // assertion. So as there is no 'any' format here we can leave default size - // of Tensor as computed in ComputeInferShape - OT* ptr = output->mutable_data(this->place_); - return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); - } - - private: - struct MatMulDims { - const memory::dims x_dims, y_dims, out_dims, x_strides, y_strides, - out_strides; - }; - - phi::DDim GetDimForInput(const ExecutionContext& ctx, - std::string input_name) { - auto shape = ctx.Attr>("fused_reshape_" + input_name); - auto axis = ctx.Attr>("fused_transpose_" + input_name); - auto input_dims = ctx.Input(input_name)->dims(); - if (!shape.empty() && !axis.empty()) { - auto it_zero = std::find(shape.begin(), shape.end(), 0); - if (it_zero != shape.end()) { - for (uint64_t i = 0; i < shape.size(); i++) { - if (shape[i] == 0) { - PADDLE_ENFORCE_LT( - i, - input_dims.size(), - paddle::platform::errors::InvalidArgument( - "The index of 0 in fused_reshape_%s ", - "should be less than output dim size, ", - "but the index is %d and output dim size is %d", - input_name, - i, - input_dims.size())); - shape[i] = input_dims.at(i); - } - } - } - - return input_dims.reshape(shape).transpose(axis); - } - return input_dims; - } - - std::pair GetInputDimsAndStrides( - const ExecutionContext& ctx, std::string input_name) { - auto shape = ctx.Attr>("fused_reshape_" + input_name); - auto axis = ctx.Attr>("fused_transpose_" + input_name); - auto input_dims = ctx.Input(input_name)->dims(); - auto new_dims = input_dims; - if (!shape.empty() && !axis.empty()) { - auto it_zero = std::find(shape.begin(), shape.end(), 0); - if (it_zero != shape.end()) { - for (uint64_t i = 0; i < shape.size(); i++) { - if (shape[i] == 0) { - PADDLE_ENFORCE_LT( - i, - input_dims.size(), - paddle::platform::errors::InvalidArgument( - "The index of 0 in fused_reshape_%s ", - "should be less than output dim size, ", - "but the index is %d and output dim size is %d", - input_name, - i, - input_dims.size())); - shape[i] = input_dims.at(i); - } - } - } - - new_dims = input_dims.reshape(shape).transpose(axis); - } - - auto& MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector - : ColumnMatrixDimsFromVector; - phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor( - MatrixDimsFromVector(new_dims), - 0, - ctx.Attr("transpose_" + input_name)); - - memory::dims strides; - if (!shape.empty()) { - auto shape2 = input_dims.reshape(shape); - strides.push_back(1); - for (auto i = shape2.size() - 1; i > 0; --i) { - strides.insert(strides.begin(), strides.front() * shape2[i]); - } - strides = Transpose(strides, axis); - if (shape.size() == 4) - strides.erase(strides.begin()); - else if (shape.size() == 2) - strides.insert(strides.begin(), shape[0] * shape[1]); - mat_dim.stride_ = strides[0]; - if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin())); - } - return std::make_pair(mat_dim, strides); - } - - float ComputeOutputScale(const ExecutionContext& ctx) { - float scale_x = ctx.Attr("Scale_x"); - float scale_y = ctx.Attr("Scale_y"); - bool force_fp32_out = ctx.Attr("force_fp32_output"); - float scale_out = force_fp32_out ? 1.f : ctx.Attr("Scale_out"); - float alpha = ctx.Attr("alpha"); - return alpha * scale_out / (scale_x * scale_y); - } - - bool IsInputFused(const ExecutionContext& ctx) const { - return !(ctx.Attr>("fused_reshape_X").empty() && - ctx.Attr>("fused_reshape_Y").empty()); - } - - bool IsOutputFused(const ExecutionContext& ctx) const { - auto& fused_reshape_Out = ctx.Attr>("fused_reshape_Out"); - auto& fused_transpose_Out = - ctx.Attr>("fused_transpose_Out"); - return !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); - } - - MatMulDims GetMatmulDims(const ExecutionContext& ctx) { - phi::funcs::MatDescriptor mat_dim_x; - memory::dims strides_x; - std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X"); - phi::funcs::MatDescriptor mat_dim_y; - memory::dims strides_y; - std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y"); - - auto x_bs = mat_dim_x.batch_size_; - auto y_bs = mat_dim_y.batch_size_; - PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, - false, - paddle::platform::errors::InvalidArgument( - "If batch sizes of X and Y are positive," - "they have to be equal.")); - - memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1; - const memory::dim M = mat_dim_x.height_; - const memory::dim N = mat_dim_y.width_; - const memory::dim K = mat_dim_x.width_; - - batch_size_ = 1; - if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) { - auto x_dims = GetDimForInput(ctx, "X"); - auto y_dims = GetDimForInput(ctx, "Y"); - batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0]; - x_bs /= batch_size_; - y_bs /= batch_size_; - out_bs /= batch_size_; - } - memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K}; - memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N}; - memory::dims out_dims = {out_bs, M, N}; - - x_offset_ = x_bs * M * K * sizeof(XT); - y_offset_ = y_bs * K * N * sizeof(YT); - out_offset_ = out_bs * M * N * sizeof(OT); - - // Translate transA and transB - if (strides_x.empty()) - strides_x = !ctx.Attr("transpose_X") ? memory::dims{M * K, K, 1} - : memory::dims{M * K, 1, M}; - if (strides_y.empty()) - strides_y = !ctx.Attr("transpose_Y") ? memory::dims{N * K, N, 1} - : memory::dims{N * K, 1, K}; - memory::dims out_strides = memory::dims{M * N, N, 1}; - - CorrectStridesWhenFloatOutputFused(ctx, N, out_bs, &out_strides); - - return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides}; - } - - std::vector Transpose(const std::vector& x, - const std::vector& axis) { - size_t in_rank = x.size(); - size_t axis_size = axis.size(); - - auto axis_set = std::set(axis.begin(), axis.end()); - PADDLE_ENFORCE_EQ(axis_set.size(), - axis_size, - paddle::platform::errors::InvalidArgument( - "In an axis array, elements must be unique.")); - - PADDLE_ENFORCE_EQ(in_rank, - axis_size, - paddle::platform::errors::InvalidArgument( - "The input dimension's size " - "should be equal to the axis's size. " - "But received dimension is %d, " - "axis's size is %d", - in_rank, - axis_size)); - - PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), - axis_size, - paddle::platform::errors::InvalidArgument( - "Axis values must be ranging from 0 to (dims - 1).")); - - std::vector new_x(x.size()); - for (size_t i = 0; i < x.size(); i++) { - new_x[i] = x[axis[i]]; - } - return new_x; - } - - void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx, - const memory::dim N, - memory::dim b, - memory::dims* out_strides) const { - if (!IsInt8() && !IsBfloat16() && IsOutputFused(ctx)) { - *out_strides = {N, b * N, 1}; - } - } - - uint16_t GetBatchSize(void) const { return batch_size_; } - - std::tuple GetOffsets() const { - return std::make_tuple(x_offset_, y_offset_, out_offset_); - } - - dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext& ctx) { - dnnl::primitive_attr matmul_attrs; - dnnl::post_ops post_operations; - - float scale_out = ComputeOutputScale(ctx); - if (scale_out != 1.0f) { - matmul_attrs.set_output_scales(0, {scale_out}); - } - - paddle::platform::AppendActivation(ctx, post_operations); - - matmul_attrs.set_post_ops(post_operations); - return matmul_attrs; - } - - private: - uint32_t x_offset_; - uint32_t y_offset_; - uint32_t out_offset_; - uint16_t batch_size_; -}; - -/** - * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. - * - * The shape would be [BatchSize, H, W] or [H, W]. - * If transposed, `H,W` will be swapped. - */ -static void ReshapeTensorToMatrixSequence( - Tensor* x, const phi::funcs::MatDescriptor& descriptor) { - int64_t h, w; - h = descriptor.height_; - w = descriptor.width_; - if (descriptor.trans_) { - std::swap(w, h); - } - if (descriptor.batch_size_) { - x->Resize({descriptor.batch_size_, h, w}); - } else { - x->Resize({h, w}); - } -} - -/** - * Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor - * Out = matmul(x, y) - * - * This method will first calculate X,Y matrix sequence, and then calculate - * the out shape. - * - * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2] - * The out = [BatchSize, H1, W2] - * - * If there is no batch size in `X` and `Y`, the out will be [H1, W2] - * If any of `X` and `Y` has batch size BatchSize, the out will have the - * BatchSize. - */ -static void ReshapeXYOutToMatrixSequence( - Tensor* x, Tensor* y, Tensor* out, bool trans_x, bool trans_y) { - auto x_dim = RowMatrixDimsFromVector(x->dims()); - auto y_dim = ColumnMatrixDimsFromVector(y->dims()); - auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); - if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { - out->Resize({mat_dim_x.height_, mat_dim_y.width_}); - } else { - out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), - mat_dim_x.height_, - mat_dim_y.width_}); - } - - ReshapeTensorToMatrixSequence(x, mat_dim_x); - ReshapeTensorToMatrixSequence(y, mat_dim_y); -} - -// Choose appropriate Handler instances based on inferred -// output type (uint8, int8 or float). -template -static void ExecuteMatMul(const ExecutionContext& ctx) { - constexpr bool is_int8 = IsInt8(); - constexpr bool is_bfloat16 = IsBfloat16(); - const bool force_fp32_output = ctx.Attr("force_fp32_output"); - constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - const auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { - MatMulMKLDNNHandler(onednn_engine, ctx).Execute(x, y, out); - } else if (is_bfloat16) { - MatMulMKLDNNHandler(onednn_engine, ctx) - .Execute(x, y, out); - } else if (fuse_relu) { - MatMulMKLDNNHandler(onednn_engine, ctx).Execute(x, y, out); - } else { - MatMulMKLDNNHandler(onednn_engine, ctx).Execute(x, y, out); - } -} - -template -class MatMulMKLDNNKernel : public paddle::framework::OpKernel { - public: - void Compute(const ExecutionContext& ctx) const override { - if (ctx.HasAttr("head_number")) { - PADDLE_ENFORCE_EQ( - ctx.Attr("head_number"), - 1, - paddle::platform::errors::Unimplemented( - "oneDNN matmul doesn't support multiple heads. Expected " - "head_number=1. But received `head_number` is %d", - ctx.Attr("head_number"))); - } - ExecuteMatMul(ctx); - } -}; - -} // anonymous namespace - -namespace paddle { -namespace operators { - -template -void MatMulGradMKLDNNKernel::Compute(const ExecutionContext& ctx) const { - if (ctx.HasAttr("head_number")) { - PADDLE_ENFORCE_EQ( - ctx.Attr("head_number"), - 1, - platform::errors::Unimplemented( - "oneDNN matmul doesn't support multiple heads. Expected " - "head_number=1. But received `head_number` is %d", - ctx.Attr("head_number"))); - } - RunKernel(ctx); -} - -template -void MatMulGradMKLDNNKernel::ExecuteMatMulGrad( - const ExecutionContext& ctx, - const MKLDNNDeviceContext& dev_ctx, - const dnnl::engine& engine, - Tensor* x, - bool trans_x, - bool is_fold_init_dims_x, - Tensor* y, - bool trans_y, - bool is_fold_init_dims_y, - Tensor* out) const { - // gradient is calculated in a different way when broadcasting is used - bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) && - out->dims().size() == 2; - - Tensor x_combined, y_combined; - if (!need_combine) { - x_combined = *x; - y_combined = *y; - } else { - x_combined = is_fold_init_dims_x ? FoldOuterDims(*x) - : FoldFirstAndLastDims(dev_ctx, x); - y_combined = is_fold_init_dims_y ? FoldOuterDims(*y) - : FoldFirstAndLastDims(dev_ctx, y); - } - - float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; - - MatMulMKLDNNHandler handler(engine, - ctx.GetPlace(), - &x_combined, - trans_x, - &y_combined, - trans_y, - out, - alpha); - - const auto src_memory_p = handler.AcquireSrcMemory(&x_combined); - const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined); - const auto dst_memory_p = handler.AcquireDstMemory(out); - - auto matmul_p = handler.AcquireForwardPrimitive(); - - std::unordered_map matmul_args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - matmul_p->execute(astream, matmul_args); - astream.wait(); - - out->set_layout(framework::DataLayout::kMKLDNN); - out->set_format(platform::GetMKLDNNFormat( - dst_memory_p->get_desc().reshape(vectorize(out->dims())))); -} - -template -void MatMulGradMKLDNNKernel::RunKernel(const ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto x = *ctx.Input("X"); - auto y = *ctx.Input("Y"); - auto dout = *ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - bool transpose_x = ctx.HasAttr("transpose_X") ? ctx.Attr("transpose_X") - : ctx.Attr("trans_x"); - bool transpose_y = ctx.HasAttr("transpose_Y") ? ctx.Attr("transpose_Y") - : ctx.Attr("trans_y"); - - ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - - framework::DDim dx_dims; - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } - } - - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - } - - if (transpose_x && transpose_y) { - this->ExecuteMatMulGrad( - ctx, dev_ctx, onednn_engine, &y, true, true, &dout, true, false, dx); - this->ExecuteMatMulGrad( - ctx, dev_ctx, onednn_engine, &dout, true, true, &x, true, false, dy); - } else if (transpose_x) { - this->ExecuteMatMulGrad( - ctx, dev_ctx, onednn_engine, &y, false, false, &dout, true, false, dx); - this->ExecuteMatMulGrad( - ctx, dev_ctx, onednn_engine, &x, false, false, &dout, false, true, dy); - } else if (transpose_y) { - this->ExecuteMatMulGrad( - ctx, dev_ctx, onednn_engine, &dout, false, false, &y, false, true, dx); - this->ExecuteMatMulGrad( - ctx, dev_ctx, onednn_engine, &dout, true, true, &x, false, true, dy); - } else { - this->ExecuteMatMulGrad( - ctx, dev_ctx, onednn_engine, &dout, false, false, &y, true, false, dx); - this->ExecuteMatMulGrad( - ctx, dev_ctx, onednn_engine, &x, true, true, &dout, false, true, dy); - } - - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - dx->set_format(x.format()); - } - } - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - dy->set_format(y.format()); - } - } -} - -template class MatMulGradMKLDNNKernel; -template class MatMulGradMKLDNNKernel; - -} // namespace operators -} // namespace paddle -namespace ops = paddle::operators; - -REGISTER_OP_KERNEL(matmul, - MKLDNN, - ::paddle::platform::CPUPlace, - MatMulMKLDNNKernel, - MatMulMKLDNNKernel, - MatMulMKLDNNKernel, - MatMulMKLDNNKernel); - -REGISTER_OP_KERNEL(matmul_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::MatMulGradMKLDNNKernel, - ops::MatMulGradMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 24e3a5aa69d..b4c18c42a76 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -14,36 +14,553 @@ limitations under the License. */ #include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h" namespace { - using dnnl::memory; -using dnnl::primitive; using paddle::framework::DataLayout; using paddle::framework::ExecutionContext; using paddle::platform::GetMKLDNNFormat; using paddle::platform::MatMulV2MKLDNNHandler; using paddle::platform::MKLDNNDeviceContext; +using paddle::platform::MKLDNNFormatForSize; using paddle::platform::MKLDNNGetDataType; using paddle::platform::to_void_cast; +using phi::vectorize; using Tensor = paddle::framework::Tensor; -using paddle::framework::DDim; using paddle::framework::GradVarName; using phi::make_ddim; -using phi::vectorize; + +// Reshape a rank-3 tensor from P x M x N to (P * M) x N. +// Identity op if the tensor is not of rank 3. +static Tensor FoldOuterDims(const Tensor &input) { + auto output = input; + auto in_dims = input.dims(); + if (in_dims.size() == 3) { + output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); + } + return output; +} + +// Reshape a rank-3 tensor from P x M x N to M x (P * N). +// (Warning: This requires transposing data and writes into new memory.) +// Identity op if the tensor is not of rank 3. +template +static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx, + const Tensor *input) { + auto input_dims = vectorize(input->dims()); + if (input_dims.size() != 3) { + return *input; + } + + Tensor output; + output.Resize({input_dims[1], input_dims[0], input_dims[2]}); + + auto output_dims = vectorize(output.dims()); + + memory::data_type input_type = paddle::framework::ToMKLDNNDataType( + paddle::framework::TransToProtoVarType(input->dtype())); + paddle::platform::ReorderMKLDNNHandler reorder_handler( + output_dims, + paddle::framework::TransToProtoVarType(input->dtype()), + input_type, + dev_ctx.GetEngine()); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + memory::format_tag::abc, + paddle::platform::to_void_cast(input->data())); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + &output, memory::format_tag::bac, dev_ctx.GetPlace()); + auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, + reorder_dst_memory_p); + + auto &astream = MKLDNNDeviceContext::tls().get_stream(); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + output.Resize({input_dims[1], input_dims[0] * input_dims[2]}); + return output; +} + +template +constexpr bool IsInt8() { + return std::is_same::value || std::is_same::value; +} + +template +constexpr bool IsBfloat16() { + return std::is_same::value; +} // Get row matrix shape from a vector shape. If the rank of x_dim > 1, the // original x_dim is returned. -static DDim RowMatrixDimsFromVector(const DDim& x_dim) { +static paddle::framework::DDim RowMatrixDimsFromVector( + const paddle::framework::DDim &x_dim) { return x_dim.size() > 1 ? x_dim : phi::make_ddim({1, x_dim[0]}); } // Get column matrix shape from a vector shape. If the ran of y_dim > 1, the // original y_dim is returned. -static DDim ColumnMatrixDimsFromVector(const DDim& y_dim) { +static paddle::framework::DDim ColumnMatrixDimsFromVector( + const paddle::framework::DDim &y_dim) { return y_dim.size() > 1 ? y_dim : phi::make_ddim({y_dim[0], 1}); } -static std::vector Transpose(const std::vector& x, - const std::vector& axis) { +template +class MatMulMKLDNNHandler + : public paddle::platform::MKLDNNHandlerNoCachingT { + public: + MatMulMKLDNNHandler(const dnnl::engine engine, + paddle::platform::Place cpu_place, + Tensor *x, + bool trans_x, + Tensor *y, + bool trans_y, + Tensor *out, + float scale) + : paddle::platform::MKLDNNHandlerNoCachingT(engine, + cpu_place) { + auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x->dims(), 0, trans_x); + auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y->dims(), 0, trans_y); + + memory::dim x_bs = mat_dim_x.batch_size_; + memory::dim y_bs = mat_dim_y.batch_size_; + + memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1; + const memory::dim M = mat_dim_x.height_; + const memory::dim N = mat_dim_y.width_; + const memory::dim K = mat_dim_x.width_; + + memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K}; + memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N}; + memory::dims out_dims = {out_bs, M, N}; + + memory::dims x_strides = + !trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M}; + + memory::dims y_strides = + !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K}; + memory::dims out_strides = memory::dims{M * N, N, 1}; + + auto x_md = memory::desc(x_dims, MKLDNNGetDataType(), x_strides); + auto y_md = memory::desc(y_dims, MKLDNNGetDataType(), y_strides); + auto out_md = memory::desc(out_dims, MKLDNNGetDataType(), out_strides); + + dnnl::primitive_attr attrs; + if (scale != 1.0f) attrs.set_output_scales(0, {scale}); + + this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md); + } + // Constructor for FWD MatMul + MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext &ctx) + : paddle::platform::MKLDNNHandlerNoCachingT( + engine, ctx.GetPlace()) { + const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx); + + auto matmul_dims_ = GetMatmulDims(ctx); + auto x_md = memory::desc( + matmul_dims_.x_dims, MKLDNNGetDataType(), matmul_dims_.x_strides); + auto y_md = memory::desc( + matmul_dims_.y_dims, MKLDNNGetDataType(), matmul_dims_.y_strides); + auto out_md = memory::desc(matmul_dims_.out_dims, + MKLDNNGetDataType(), + matmul_dims_.out_strides); + this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md); + } + + std::shared_ptr AcquireWeightsMemory(const Tensor *input) { + const YT *input_data = input->data(); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), + to_void_cast(input_data)); + } + + public: + void Execute(const paddle::framework::Tensor *x, + const paddle::framework::Tensor *y, + paddle::framework::Tensor *out) { + const auto src_memory_p = this->AcquireSrcMemory(x); + const auto weights_memory_p = this->AcquireWeightsMemory(y); + const auto dst_memory_p = this->AcquireDstMemory(out); + + auto matmul_p = this->AcquireForwardPrimitive(); + + std::unordered_map matmul_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + + // Simulate batch matmul by processing in loop + void *x_ptr = src_memory_p->get_data_handle(); + void *y_ptr = weights_memory_p->get_data_handle(); + void *out_ptr = dst_memory_p->get_data_handle(); + auto offsets = this->GetOffsets(); + for (uint16_t i = 0; i < this->GetBatchSize(); ++i) { + src_memory_p->set_data_handle(x_ptr); + weights_memory_p->set_data_handle(y_ptr); + dst_memory_p->set_data_handle(out_ptr); + matmul_p->execute(astream, + { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}, + }); + x_ptr = static_cast(x_ptr) + std::get<0>(offsets); + y_ptr = static_cast(y_ptr) + std::get<1>(offsets); + out_ptr = static_cast(out_ptr) + std::get<2>(offsets); + } + astream.wait(); + + auto format = + MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw); + out->set_format(format); + out->set_layout(DataLayout::kMKLDNN); + } + + std::shared_ptr AcquireDstMemory( + paddle::framework::Tensor *output) { + // We cannot use base AcquireDstMemory as it makes an allocation request + // base on DST memory primitive size. This is fine in general, but in MatMul + // we have primitive that covers only one batch of Data and then shift + // pointer for every new batch. Hence Tensor size is bigger that dst memory + // primitive size. So would we request less memory that is there and it + // triggers an + // assertion. So as there is no 'any' format here we can leave default size + // of Tensor as computed in ComputeInferShape + OT *ptr = output->mutable_data(this->place_); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); + } + + private: + struct MatMulDims { + const memory::dims x_dims, y_dims, out_dims, x_strides, y_strides, + out_strides; + }; + + phi::DDim GetDimForInput(const ExecutionContext &ctx, + std::string input_name) { + auto shape = ctx.Attr>("fused_reshape_" + input_name); + auto axis = ctx.Attr>("fused_transpose_" + input_name); + auto input_dims = ctx.Input(input_name)->dims(); + if (!shape.empty() && !axis.empty()) { + auto it_zero = std::find(shape.begin(), shape.end(), 0); + if (it_zero != shape.end()) { + for (uint64_t i = 0; i < shape.size(); i++) { + if (shape[i] == 0) { + PADDLE_ENFORCE_LT( + i, + input_dims.size(), + paddle::platform::errors::InvalidArgument( + "The index of 0 in fused_reshape_%s ", + "should be less than output dim size, ", + "but the index is %d and output dim size is %d", + input_name, + i, + input_dims.size())); + shape[i] = input_dims.at(i); + } + } + } + + return input_dims.reshape(shape).transpose(axis); + } + return input_dims; + } + + std::pair GetInputDimsAndStrides( + const ExecutionContext &ctx, std::string input_name) { + auto shape = ctx.Attr>("fused_reshape_" + input_name); + auto axis = ctx.Attr>("fused_transpose_" + input_name); + auto input_dims = ctx.Input(input_name)->dims(); + auto new_dims = input_dims; + if (!shape.empty() && !axis.empty()) { + auto it_zero = std::find(shape.begin(), shape.end(), 0); + if (it_zero != shape.end()) { + for (uint64_t i = 0; i < shape.size(); i++) { + if (shape[i] == 0) { + PADDLE_ENFORCE_LT( + i, + input_dims.size(), + paddle::platform::errors::InvalidArgument( + "The index of 0 in fused_reshape_%s ", + "should be less than output dim size, ", + "but the index is %d and output dim size is %d", + input_name, + i, + input_dims.size())); + shape[i] = input_dims.at(i); + } + } + } + + new_dims = input_dims.reshape(shape).transpose(axis); + } + + auto &MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector + : ColumnMatrixDimsFromVector; + phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor( + MatrixDimsFromVector(new_dims), + 0, + ctx.Attr("transpose_" + input_name)); + + memory::dims strides; + if (!shape.empty()) { + auto shape2 = input_dims.reshape(shape); + strides.push_back(1); + for (auto i = shape2.size() - 1; i > 0; --i) { + strides.insert(strides.begin(), strides.front() * shape2[i]); + } + strides = Transpose(strides, axis); + if (shape.size() == 4) + strides.erase(strides.begin()); + else if (shape.size() == 2) + strides.insert(strides.begin(), shape[0] * shape[1]); + mat_dim.stride_ = strides[0]; + if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin())); + } + return std::make_pair(mat_dim, strides); + } + + float ComputeOutputScale(const ExecutionContext &ctx) { + float scale_x = ctx.Attr("Scale_x"); + float scale_y = ctx.Attr("Scale_y"); + bool force_fp32_out = ctx.Attr("force_fp32_output"); + float scale_out = force_fp32_out ? 1.f : ctx.Attr("Scale_out"); + float alpha = ctx.Attr("alpha"); + return alpha * scale_out / (scale_x * scale_y); + } + + bool IsInputFused(const ExecutionContext &ctx) const { + return !(ctx.Attr>("fused_reshape_X").empty() && + ctx.Attr>("fused_reshape_Y").empty()); + } + + bool IsOutputFused(const ExecutionContext &ctx) const { + auto &fused_reshape_Out = ctx.Attr>("fused_reshape_Out"); + auto &fused_transpose_Out = + ctx.Attr>("fused_transpose_Out"); + return !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); + } + + MatMulDims GetMatmulDims(const ExecutionContext &ctx) { + phi::funcs::MatDescriptor mat_dim_x; + memory::dims strides_x; + std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X"); + phi::funcs::MatDescriptor mat_dim_y; + memory::dims strides_y; + std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y"); + + auto x_bs = mat_dim_x.batch_size_; + auto y_bs = mat_dim_y.batch_size_; + PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, + false, + paddle::platform::errors::InvalidArgument( + "If batch sizes of X and Y are positive," + "they have to be equal.")); + + memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1; + const memory::dim M = mat_dim_x.height_; + const memory::dim N = mat_dim_y.width_; + const memory::dim K = mat_dim_x.width_; + + batch_size_ = 1; + if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) { + auto x_dims = GetDimForInput(ctx, "X"); + auto y_dims = GetDimForInput(ctx, "Y"); + batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0]; + x_bs /= batch_size_; + y_bs /= batch_size_; + out_bs /= batch_size_; + } + memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K}; + memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N}; + memory::dims out_dims = {out_bs, M, N}; + + x_offset_ = x_bs * M * K * sizeof(XT); + y_offset_ = y_bs * K * N * sizeof(YT); + out_offset_ = out_bs * M * N * sizeof(OT); + + // Translate transA and transB + if (strides_x.empty()) + strides_x = !ctx.Attr("transpose_X") ? memory::dims{M * K, K, 1} + : memory::dims{M * K, 1, M}; + if (strides_y.empty()) + strides_y = !ctx.Attr("transpose_Y") ? memory::dims{N * K, N, 1} + : memory::dims{N * K, 1, K}; + memory::dims out_strides = memory::dims{M * N, N, 1}; + + CorrectStridesWhenFloatOutputFused(ctx, N, out_bs, &out_strides); + + return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides}; + } + + std::vector Transpose(const std::vector &x, + const std::vector &axis) { + size_t in_rank = x.size(); + size_t axis_size = axis.size(); + + auto axis_set = std::set(axis.begin(), axis.end()); + PADDLE_ENFORCE_EQ(axis_set.size(), + axis_size, + paddle::platform::errors::InvalidArgument( + "In an axis array, elements must be unique.")); + + PADDLE_ENFORCE_EQ(in_rank, + axis_size, + paddle::platform::errors::InvalidArgument( + "The input dimension's size " + "should be equal to the axis's size. " + "But received dimension is %d, " + "axis's size is %d", + in_rank, + axis_size)); + + PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), + axis_size, + paddle::platform::errors::InvalidArgument( + "Axis values must be ranging from 0 to (dims - 1).")); + + std::vector new_x(x.size()); + for (size_t i = 0; i < x.size(); i++) { + new_x[i] = x[axis[i]]; + } + return new_x; + } + + void CorrectStridesWhenFloatOutputFused(const ExecutionContext &ctx, + const memory::dim N, + memory::dim b, + memory::dims *out_strides) const { + if (!IsInt8() && !IsBfloat16() && IsOutputFused(ctx)) { + *out_strides = {N, b * N, 1}; + } + } + + uint16_t GetBatchSize(void) const { return batch_size_; } + + std::tuple GetOffsets() const { + return std::make_tuple(x_offset_, y_offset_, out_offset_); + } + + dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext &ctx) { + dnnl::primitive_attr matmul_attrs; + dnnl::post_ops post_operations; + + float scale_out = ComputeOutputScale(ctx); + if (scale_out != 1.0f) { + matmul_attrs.set_output_scales(0, {scale_out}); + } + + paddle::platform::AppendActivation(ctx, post_operations); + + matmul_attrs.set_post_ops(post_operations); + return matmul_attrs; + } + + private: + uint32_t x_offset_; + uint32_t y_offset_; + uint32_t out_offset_; + uint16_t batch_size_; +}; + +/** + * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. + * + * The shape would be [BatchSize, H, W] or [H, W]. + * If transposed, `H,W` will be swapped. + */ +static void ReshapeTensorToMatrixSequence( + Tensor *x, const phi::funcs::MatDescriptor &descriptor) { + int64_t h, w; + h = descriptor.height_; + w = descriptor.width_; + if (descriptor.trans_) { + std::swap(w, h); + } + if (descriptor.batch_size_) { + x->Resize({descriptor.batch_size_, h, w}); + } else { + x->Resize({h, w}); + } +} + +/** + * Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor + * Out = matmul(x, y) + * + * This method will first calculate X,Y matrix sequence, and then calculate + * the out shape. + * + * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2] + * The out = [BatchSize, H1, W2] + * + * If there is no batch size in `X` and `Y`, the out will be [H1, W2] + * If any of `X` and `Y` has batch size BatchSize, the out will have the + * BatchSize. + */ +static void ReshapeXYOutToMatrixSequence( + Tensor *x, Tensor *y, Tensor *out, bool trans_x, bool trans_y) { + auto x_dim = RowMatrixDimsFromVector(x->dims()); + auto y_dim = ColumnMatrixDimsFromVector(y->dims()); + auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); + if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { + out->Resize({mat_dim_x.height_, mat_dim_y.width_}); + } else { + out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, + mat_dim_y.width_}); + } + + ReshapeTensorToMatrixSequence(x, mat_dim_x); + ReshapeTensorToMatrixSequence(y, mat_dim_y); +} + +// Choose appropriate Handler instances based on inferred +// output type (uint8, int8 or float). +template +static void ExecuteMatMul(const ExecutionContext &ctx) { + constexpr bool is_int8 = IsInt8(); + constexpr bool is_bfloat16 = IsBfloat16(); + const bool force_fp32_output = ctx.Attr("force_fp32_output"); + constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses + auto *x = ctx.Input("X"); + auto *y = ctx.Input("Y"); + auto *out = ctx.Output("Out"); + const auto &dev_ctx = + ctx.template device_context(); + const auto &onednn_engine = dev_ctx.GetEngine(); + + if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { + MatMulMKLDNNHandler(onednn_engine, ctx).Execute(x, y, out); + } else if (is_bfloat16) { + MatMulMKLDNNHandler(onednn_engine, ctx) + .Execute(x, y, out); + } else if (fuse_relu) { + MatMulMKLDNNHandler(onednn_engine, ctx).Execute(x, y, out); + } else { + MatMulMKLDNNHandler(onednn_engine, ctx).Execute(x, y, out); + } +} + +template +class MatMulMKLDNNKernel : public paddle::framework::OpKernel { + public: + void Compute(const ExecutionContext &ctx) const override { + if (ctx.HasAttr("head_number")) { + PADDLE_ENFORCE_EQ( + ctx.Attr("head_number"), + 1, + paddle::platform::errors::Unimplemented( + "oneDNN matmul doesn't support multiple heads. Expected " + "head_number=1. But received `head_number` is %d", + ctx.Attr("head_number"))); + } + ExecuteMatMul(ctx); + } +}; + +static std::vector Transpose(const std::vector &x, + const std::vector &axis) { size_t in_rank = x.size(); size_t axis_size = axis.size(); @@ -75,7 +592,7 @@ static std::vector Transpose(const std::vector& x, return new_x; } -std::vector GetInputStrides(const ExecutionContext& ctx, +std::vector GetInputStrides(const ExecutionContext &ctx, const std::string input_name) { auto shape = ctx.Attr>("fused_reshape_" + input_name); auto axis = ctx.Attr>("fused_transpose_" + input_name); @@ -85,13 +602,15 @@ std::vector GetInputStrides(const ExecutionContext& ctx, new_dims = input_dims.reshape(shape).transpose(axis); } - auto& MatrixDimsFromVector = + auto &MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor( MatrixDimsFromVector(new_dims), 0, - ctx.Attr(std::string("trans_") + - static_cast(std::tolower(input_name[0])))); + ctx.HasAttr("trans_x") + ? ctx.Attr(std::string("trans_") + + static_cast(std::tolower(input_name[0]))) + : ctx.Attr(std::string("transpose_") + input_name[0])); std::vector strides; if (!shape.empty()) { @@ -111,37 +630,39 @@ std::vector GetInputStrides(const ExecutionContext& ctx, return strides; } -bool IsOutputFused(const ExecutionContext& ctx) { - auto& fused_reshape_Out = ctx.Attr>("fused_reshape_Out"); - auto& fused_transpose_Out = ctx.Attr>("fused_transpose_Out"); +bool IsOutputFused(const ExecutionContext &ctx) { + auto &fused_reshape_Out = ctx.Attr>("fused_reshape_Out"); + auto &fused_transpose_Out = ctx.Attr>("fused_transpose_Out"); return !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); } -float ComputeOutputScale(const ExecutionContext& ctx) { +float ComputeOutputScale(const ExecutionContext &ctx) { float scale_x = ctx.Attr("Scale_x"); float scale_y = ctx.Attr("Scale_y"); bool force_fp32_out = ctx.Attr("force_fp32_output"); float scale_out = force_fp32_out ? 1.f : ctx.Attr("Scale_out"); - return scale_out / (scale_x * scale_y); + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + return alpha * scale_out / (scale_x * scale_y); } template -void ExecuteMatMulV2(const ExecutionContext& ctx, - const MKLDNNDeviceContext& dev_ctx, +void ExecuteMatMulV2(const ExecutionContext &ctx, + const MKLDNNDeviceContext &dev_ctx, const dnnl::engine onednn_engine, paddle::platform::Place cpu_place, - const Tensor* x, - const std::vector& x_dims, + const Tensor *x, + const std::vector &x_dims, bool trans_x, - const Tensor* y, - const std::vector& y_dims, + const Tensor *y, + const std::vector &y_dims, bool trans_y, - Tensor* out, - const std::vector& out_dims, + Tensor *out, + const std::vector &out_dims, int execution_number = 0) { std::vector x_strides_override = GetInputStrides(ctx, "X"); std::vector y_strides_override = GetInputStrides(ctx, "Y"); - MatMulV2MKLDNNHandler handler(onednn_engine, + MatMulV2MKLDNNHandler handler(ctx, + onednn_engine, ctx.GetPlace(), x_dims, trans_x, @@ -162,7 +683,7 @@ void ExecuteMatMulV2(const ExecutionContext& ctx, {DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; - auto& astream = MKLDNNDeviceContext::tls().get_stream(); + auto &astream = MKLDNNDeviceContext::tls().get_stream(); matmul_p->execute(astream, matmul_args); astream.wait(); @@ -172,8 +693,9 @@ void ExecuteMatMulV2(const ExecutionContext& ctx, out->set_format(format); } -DDim GetDimForInput(const paddle::framework::ExecutionContext& ctx, - const std::string& input_name) { +paddle::framework::DDim GetDimForInput( + const paddle::framework::ExecutionContext &ctx, + const std::string &input_name) { auto shape = ctx.Attr>("fused_reshape_" + input_name); auto axis = ctx.Attr>("fused_transpose_" + input_name); auto dim = ctx.Input(input_name)->dims(); @@ -186,16 +708,16 @@ DDim GetDimForInput(const paddle::framework::ExecutionContext& ctx, template class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { public: - void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } + void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } private: - void CalculateMatrixDims(const ExecutionContext& ctx, - const std::vector& x_dims, - const std::vector& y_dims, - std::vector* x_bd_dims, - std::vector* y_bd_dims, - std::vector* out_dims, - Tensor* out) const { + void CalculateMatrixDims(const ExecutionContext &ctx, + const std::vector &x_dims, + const std::vector &y_dims, + std::vector *x_bd_dims, + std::vector *y_bd_dims, + std::vector *out_dims, + Tensor *out) const { if (x_dims.size() == 1) { (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0]; } else if (x_dims.size() == 2) { @@ -237,15 +759,17 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { } } - void RunKernel(const ExecutionContext& ctx) const { - const auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); + void RunKernel(const ExecutionContext &ctx) const { + const auto &dev_ctx = ctx.template device_context(); + const auto &onednn_engine = dev_ctx.GetEngine(); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - bool trans_x = ctx.Attr("trans_x"); - bool trans_y = ctx.Attr("trans_y"); + auto *x = ctx.Input("X"); + auto *y = ctx.Input("Y"); + auto *out = ctx.Output("Out"); + bool trans_x = ctx.HasAttr("trans_x") ? ctx.Attr("trans_x") + : ctx.Attr("transpose_X"); + bool trans_y = ctx.HasAttr("trans_y") ? ctx.Attr("trans_y") + : ctx.Attr("transpose_Y"); auto x_dims = vectorize(GetDimForInput(ctx, "X")); auto y_dims = vectorize(GetDimForInput(ctx, "Y")); @@ -278,16 +802,16 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { template class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { public: - void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } + void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } private: - void CalculateGradMatrixDims(const ExecutionContext& ctx, - Tensor* dx_tmp, - Tensor* dy_tmp, - const std::vector& dx_dims, - const std::vector& dy_dims, - std::vector* dx_bd_dims, - std::vector* dy_bd_dims) const { + void CalculateGradMatrixDims(const ExecutionContext &ctx, + Tensor *dx_tmp, + Tensor *dy_tmp, + const std::vector &dx_dims, + const std::vector &dy_dims, + std::vector *dx_bd_dims, + std::vector *dy_bd_dims) const { for (size_t i = 0; i < dx_dims.size() - 2; ++i) { if (dx_dims[i] != dy_dims[i]) { if (dx_dims[i] == 1) { @@ -305,13 +829,13 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { } void ReduceSumForMatmulGradOutput( - const ExecutionContext& ctx, - const MKLDNNDeviceContext& dev_ctx, + const ExecutionContext &ctx, + const MKLDNNDeviceContext &dev_ctx, const dnnl::engine onednn_engine, - const Tensor* dx_tmp, - Tensor* dx, - const std::vector& dx_dims, - const std::vector& squeezed_dims) const { + const Tensor *dx_tmp, + Tensor *dx, + const std::vector &dx_dims, + const std::vector &squeezed_dims) const { paddle::platform::ReductionMKLDNNHandler handler( dnnl::algorithm::reduction_sum, 0.0f, @@ -328,7 +852,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { std::unordered_map reduction_args = { {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; - auto& astream = MKLDNNDeviceContext::tls().get_stream(); + auto &astream = MKLDNNDeviceContext::tls().get_stream(); auto reduction_p = handler.AcquireForwardPrimitive(); reduction_p->execute(astream, reduction_args); @@ -338,7 +862,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { dst_memory_p->get_desc().reshape(squeezed_dims))); } - std::vector ExtendDimsWithOnes(const std::vector& dims, + std::vector ExtendDimsWithOnes(const std::vector &dims, int new_size) const { std::vector new_dims(new_size, 1); for (size_t i = 0; i < dims.size(); ++i) { @@ -348,12 +872,12 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { return new_dims; } - void RunKernel(const ExecutionContext& ctx) const { - const auto& dev_ctx = ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); + void RunKernel(const ExecutionContext &ctx) const { + const auto &dev_ctx = ctx.template device_context(); + const auto &onednn_engine = dev_ctx.GetEngine(); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); + auto *x = ctx.Input("X"); + auto *y = ctx.Input("Y"); auto x_dims = vectorize(x->dims()); auto y_dims = vectorize(y->dims()); @@ -376,12 +900,14 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { return; } - auto* dout = ctx.Input(GradVarName("Out")); - auto* dx = ctx.Output(GradVarName("X")); - auto* dy = ctx.Output(GradVarName("Y")); + auto *dout = ctx.Input(GradVarName("Out")); + auto *dx = ctx.Output(GradVarName("X")); + auto *dy = ctx.Output(GradVarName("Y")); - bool trans_x = ctx.Attr("trans_x"); - bool trans_y = ctx.Attr("trans_y"); + bool trans_x = ctx.HasAttr("trans_x") ? ctx.Attr("trans_x") + : ctx.Attr("transpose_X"); + bool trans_y = ctx.HasAttr("trans_y") ? ctx.Attr("trans_y") + : ctx.Attr("transpose_Y"); auto dout_dims = vectorize(dout->dims()); size_t ndims = std::max(x->dims().size(), y->dims().size()); @@ -545,6 +1071,195 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { }; } // anonymous namespace +namespace paddle { +namespace operators { + +template +void MatMulGradMKLDNNKernel::Compute(const ExecutionContext &ctx) const { + if (ctx.HasAttr("head_number")) { + PADDLE_ENFORCE_EQ( + ctx.Attr("head_number"), + 1, + platform::errors::Unimplemented( + "oneDNN matmul doesn't support multiple heads. Expected " + "head_number=1. But received `head_number` is %d", + ctx.Attr("head_number"))); + } + RunKernel(ctx); +} + +template +void MatMulGradMKLDNNKernel::ExecuteMatMulGrad( + const ExecutionContext &ctx, + const MKLDNNDeviceContext &dev_ctx, + const dnnl::engine &engine, + Tensor *x, + bool trans_x, + bool is_fold_init_dims_x, + Tensor *y, + bool trans_y, + bool is_fold_init_dims_y, + Tensor *out) const { + // gradient is calculated in a different way when broadcasting is used + bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) && + out->dims().size() == 2; + + Tensor x_combined, y_combined; + if (!need_combine) { + x_combined = *x; + y_combined = *y; + } else { + x_combined = is_fold_init_dims_x ? FoldOuterDims(*x) + : FoldFirstAndLastDims(dev_ctx, x); + y_combined = is_fold_init_dims_y ? FoldOuterDims(*y) + : FoldFirstAndLastDims(dev_ctx, y); + } + + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + + MatMulMKLDNNHandler handler(engine, + ctx.GetPlace(), + &x_combined, + trans_x, + &y_combined, + trans_y, + out, + alpha); + + const auto src_memory_p = handler.AcquireSrcMemory(&x_combined); + const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined); + const auto dst_memory_p = handler.AcquireDstMemory(out); + + auto matmul_p = handler.AcquireForwardPrimitive(); + + std::unordered_map matmul_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto &astream = platform::MKLDNNDeviceContext::tls().get_stream(); + matmul_p->execute(astream, matmul_args); + astream.wait(); + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(platform::GetMKLDNNFormat( + dst_memory_p->get_desc().reshape(vectorize(out->dims())))); +} + +template +void MatMulGradMKLDNNKernel::RunKernel(const ExecutionContext &ctx) const { + const auto &dev_ctx = + ctx.template device_context(); + const auto &onednn_engine = dev_ctx.GetEngine(); + + auto x = *ctx.Input("X"); + auto y = *ctx.Input("Y"); + auto dout = *ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *dy = ctx.Output(framework::GradVarName("Y")); + + bool transpose_x = ctx.HasAttr("transpose_X") ? ctx.Attr("transpose_X") + : ctx.Attr("trans_x"); + bool transpose_y = ctx.HasAttr("transpose_Y") ? ctx.Attr("transpose_Y") + : ctx.Attr("trans_y"); + + ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); + + framework::DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x.dims()) { + dx->Resize(x.dims()); + } + } + + framework::DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y.dims()) { + dy->Resize(y.dims()); + } + } + + if (transpose_x && transpose_y) { + this->ExecuteMatMulGrad( + ctx, dev_ctx, onednn_engine, &y, true, true, &dout, true, false, dx); + this->ExecuteMatMulGrad( + ctx, dev_ctx, onednn_engine, &dout, true, true, &x, true, false, dy); + } else if (transpose_x) { + this->ExecuteMatMulGrad( + ctx, dev_ctx, onednn_engine, &y, false, false, &dout, true, false, dx); + this->ExecuteMatMulGrad( + ctx, dev_ctx, onednn_engine, &x, false, false, &dout, false, true, dy); + } else if (transpose_y) { + this->ExecuteMatMulGrad( + ctx, dev_ctx, onednn_engine, &dout, false, false, &y, false, true, dx); + this->ExecuteMatMulGrad( + ctx, dev_ctx, onednn_engine, &dout, true, true, &x, false, true, dy); + } else { + this->ExecuteMatMulGrad( + ctx, dev_ctx, onednn_engine, &dout, false, false, &y, true, false, dx); + this->ExecuteMatMulGrad( + ctx, dev_ctx, onednn_engine, &x, true, true, &dout, false, true, dy); + } + + if (dx) { + if (dx_dims != x.dims()) { + dx->Resize(dx_dims); + dx->set_format(x.format()); + } + } + if (dy) { + if (dy_dims != y.dims()) { + dy->Resize(dy_dims); + dy->set_format(y.format()); + } + } +} + +template class MatMulGradMKLDNNKernel; +template class MatMulGradMKLDNNKernel; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul, + MKLDNN, + ::paddle::platform::CPUPlace, + S8, + 0, + MatMulMKLDNNKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul, + MKLDNN, + ::paddle::platform::CPUPlace, + U8, + 0, + MatMulMKLDNNKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul, + MKLDNN, + ::paddle::platform::CPUPlace, + FP32, + 0, + MatMulV2MKLDNNKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( + matmul, + MKLDNN, + ::paddle::platform::CPUPlace, + BF16, + 0, + MatMulV2MKLDNNKernel); + +REGISTER_OP_KERNEL(matmul_grad, + MKLDNN, + ::paddle::platform::CPUPlace, + ops::MatMulGradMKLDNNKernel, + ops::MatMulGradMKLDNNKernel); + REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace, diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index ec341c30773..956dbc810fa 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -416,7 +416,8 @@ class MulMKLDNNKernel : public framework::OpKernel { bool trans_y, Tensor *out) const { static const std::vector vec_placeholder; - MatMulV2MKLDNNHandler handler(onednn_engine, + MatMulV2MKLDNNHandler handler(ctx, + onednn_engine, ctx.GetPlace(), x_dims, trans_x, diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index f1963a75b17..85b0775c751 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -778,6 +778,59 @@ class BroadcastDataMKLDNNHandler } }; +static void AppendActivation(const framework::ExecutionContext& ctx, + dnnl::post_ops& post_ops, + float activation_scale = 1.0f) { + const auto invalid_attribute = + ctx.HasAttr("fuse_activation") + ? ctx.Attr("fuse_activation").empty() + : true; + if (invalid_attribute) return; + + const auto fuse_activation = ctx.Attr("fuse_activation"); + const auto fuse_alpha = + ctx.HasAttr("fuse_alpha") ? ctx.Attr("fuse_alpha") : 0.0f; + const auto fuse_beta = + ctx.HasAttr("fuse_beta") ? ctx.Attr("fuse_beta") : 0.0f; + + if (fuse_activation == "hard_sigmoid") { + post_ops.append_eltwise(activation_scale, + dnnl::algorithm::eltwise_linear, + fuse_alpha, + fuse_beta); + post_ops.append_eltwise( + activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f); + } else { + const std::unordered_map activation_map = { + {"abs", dnnl::algorithm::eltwise_abs}, + {"clip", dnnl::algorithm::eltwise_clip}, + {"gelu", dnnl::algorithm::eltwise_gelu_erf}, + {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf}, + {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh}, + {"hard_swish", dnnl::algorithm::eltwise_hardswish}, + {"leaky_relu", dnnl::algorithm::eltwise_relu}, + {"mish", dnnl::algorithm::eltwise_mish}, + {"relu", dnnl::algorithm::eltwise_relu}, + {"relu6", dnnl::algorithm::eltwise_bounded_relu}, + {"sigmoid", dnnl::algorithm::eltwise_logistic}, + {"sqrt", dnnl::algorithm::eltwise_sqrt}, + {"swish", dnnl::algorithm::eltwise_swish}, + {"tanh", dnnl::algorithm::eltwise_tanh}}; + + const auto& activation_type = activation_map.find(fuse_activation); + + PADDLE_ENFORCE_NE( + activation_type, + activation_map.end(), + platform::errors::InvalidArgument( + "Activation '%s' not found in oneDNN algorithms mapper", + fuse_activation)); + + post_ops.append_eltwise( + activation_scale, activation_type->second, fuse_alpha, fuse_beta); + } +} + template class ReductionMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT { @@ -810,7 +863,8 @@ template class MatMulV2MKLDNNHandler : public paddle::platform::MKLDNNHandlerNoCachingT { public: - MatMulV2MKLDNNHandler(const dnnl::engine engine, + MatMulV2MKLDNNHandler(const framework::ExecutionContext& ctx, + const dnnl::engine engine, paddle::platform::Place cpu_place, const std::vector& x_org_dims, bool trans_x, @@ -888,7 +942,26 @@ class MatMulV2MKLDNNHandler auto y_md = memory::desc(y_dims, MKLDNNGetDataType(), y_strides); auto out_md = memory::desc(out_ddims, MKLDNNGetDataType(), out_strides); - this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md); + const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx); + + this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md); + } + + // TODO(jczaja) : Adapt to int8 + dnnl::primitive_attr CreateMatmulAttrs( + const framework::ExecutionContext& ctx) { + dnnl::primitive_attr matmul_attrs; + dnnl::post_ops post_operations; + + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + if (alpha != 1.0f) { + matmul_attrs.set_output_scales(0, {alpha}); + } + + AppendActivation(ctx, post_operations); + + matmul_attrs.set_post_ops(post_operations); + return matmul_attrs; } std::vector FakeTransposeStrides( @@ -1013,59 +1086,6 @@ class ActivationMKLDNNHandler } }; -static void AppendActivation(const framework::ExecutionContext& ctx, - dnnl::post_ops& post_ops, - float activation_scale = 1.0f) { - const auto invalid_attribute = - ctx.HasAttr("fuse_activation") - ? ctx.Attr("fuse_activation").empty() - : true; - if (invalid_attribute) return; - - const auto fuse_activation = ctx.Attr("fuse_activation"); - const auto fuse_alpha = - ctx.HasAttr("fuse_alpha") ? ctx.Attr("fuse_alpha") : 0.0f; - const auto fuse_beta = - ctx.HasAttr("fuse_beta") ? ctx.Attr("fuse_beta") : 0.0f; - - if (fuse_activation == "hard_sigmoid") { - post_ops.append_eltwise(activation_scale, - dnnl::algorithm::eltwise_linear, - fuse_alpha, - fuse_beta); - post_ops.append_eltwise( - activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f); - } else { - const std::unordered_map activation_map = { - {"abs", dnnl::algorithm::eltwise_abs}, - {"clip", dnnl::algorithm::eltwise_clip}, - {"gelu", dnnl::algorithm::eltwise_gelu_erf}, - {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf}, - {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh}, - {"hard_swish", dnnl::algorithm::eltwise_hardswish}, - {"leaky_relu", dnnl::algorithm::eltwise_relu}, - {"mish", dnnl::algorithm::eltwise_mish}, - {"relu", dnnl::algorithm::eltwise_relu}, - {"relu6", dnnl::algorithm::eltwise_bounded_relu}, - {"sigmoid", dnnl::algorithm::eltwise_logistic}, - {"sqrt", dnnl::algorithm::eltwise_sqrt}, - {"swish", dnnl::algorithm::eltwise_swish}, - {"tanh", dnnl::algorithm::eltwise_tanh}}; - - const auto& activation_type = activation_map.find(fuse_activation); - - PADDLE_ENFORCE_NE( - activation_type, - activation_map.end(), - platform::errors::InvalidArgument( - "Activation '%s' not found in oneDNN algorithms mapper", - fuse_activation)); - - post_ops.append_eltwise( - activation_scale, activation_type->second, fuse_alpha, fuse_beta); - } -} - static std::unordered_map GetAttributeMap( std::string act_type) { std::unordered_map attr_map; -- GitLab