diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
index d0dab34bcebb34dc016416df5145b27dc671c4c2..52c25fb5e827f1d2b8b2ab02b7ac4684e859d0de 100644
--- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -18,10 +18,8 @@ using dnnl::memory;
 using paddle::framework::ExecutionContext;
 using paddle::platform::MatMulV2MKLDNNHandler;
 using paddle::platform::MKLDNNDeviceContext;
-using paddle::platform::MKLDNNFormatForSize;
 using paddle::platform::MKLDNNGetDataType;
 using paddle::platform::to_void_cast;
-using phi::DataLayout;
 using phi::vectorize;
 using Tensor = phi::DenseTensor;
 using paddle::framework::GradVarName;
@@ -157,22 +155,6 @@ class MatMulMKLDNNHandler
     this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
   }
 
-  // Constructor for FWD MatMul
-  MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext &ctx)
-      : paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(
-            engine, ctx.GetPlace()) {
-    const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
-
-    auto matmul_dims_ = GetMatmulDims(ctx);
-    auto x_md = memory::desc(
-        matmul_dims_.x_dims, MKLDNNGetDataType<XT>(), matmul_dims_.x_strides);
-    auto y_md = memory::desc(
-        matmul_dims_.y_dims, MKLDNNGetDataType<YT>(), matmul_dims_.y_strides);
-    auto out_md = memory::desc(matmul_dims_.out_dims,
-                               MKLDNNGetDataType<OT>(),
-                               matmul_dims_.out_strides);
-    this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
-  }
 
   std::shared_ptr<memory> AcquireWeightsMemory(const Tensor *input) {
     const YT *input_data = input->data<YT>();
@@ -201,8 +183,8 @@
     void *x_ptr = src_memory_p->get_data_handle();
     void *y_ptr = weights_memory_p->get_data_handle();
     void *out_ptr = dst_memory_p->get_data_handle();
-    auto offsets = this->GetOffsets();
-    for (uint16_t i = 0; i < this->GetBatchSize(); ++i) {
+    auto offsets = std::make_tuple(x_offset_, y_offset_, out_offset_);
+    for (uint16_t i = 0; i < batch_size_; ++i) {
       src_memory_p->set_data_handle(x_ptr);
       weights_memory_p->set_data_handle(y_ptr);
       dst_memory_p->set_data_handle(out_ptr);
@@ -229,182 +211,6 @@ class MatMulMKLDNNHandler
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
   }
 
- private:
-  struct MatMulDims {
-    const memory::dims x_dims, y_dims, out_dims, x_strides, y_strides,
-        out_strides;
-  };
-
-  std::pair<phi::funcs::MatDescriptor, memory::dims> GetInputDimsAndStrides(
-      const ExecutionContext &ctx, std::string input_name) {
-    auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
-    auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
-    auto input_dims = ctx.Input<Tensor>(input_name)->dims();
-    auto new_dims = input_dims;
-    if (!shape.empty() && !axis.empty()) {
-      new_dims = input_dims.reshape(shape).transpose(axis);
-    }
-
-    auto &MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector
-                                                   : ColumnMatrixDimsFromVector;
-    phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
-        MatrixDimsFromVector(new_dims),
-        0,
-        ctx.Attr<bool>("transpose_" + input_name));
-
-    memory::dims strides;
-    if (!shape.empty()) {
-      auto shape2 = input_dims.reshape(shape);
-      strides.push_back(1);
-      for (auto i = shape2.size() - 1; i > 0; --i) {
-        strides.insert(strides.begin(), strides.front() * shape2[i]);
-      }
-      strides = Transpose(strides, axis);
-      if (shape.size() == 4)
-        strides.erase(strides.begin());
-      else if (shape.size() == 2)
-        strides.insert(strides.begin(), shape[0] * shape[1]);
-      mat_dim.stride_ = strides[0];
-      if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
-    }
-    return std::make_pair(mat_dim, strides);
-  }
-
-  float ComputeOutputScale(const ExecutionContext &ctx) {
-    float scale_x = ctx.Attr<float>("Scale_x");
-    float scale_y = ctx.Attr<float>("Scale_y");
-    bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
-    float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
-    float alpha = ctx.Attr<float>("alpha");
-    return alpha * scale_out / (scale_x * scale_y);
-  }
-
-  bool IsInputFused(const ExecutionContext &ctx) const {
-    return !(ctx.Attr<std::vector<int>>("fused_reshape_X").empty() &&
-             ctx.Attr<std::vector<int>>("fused_reshape_Y").empty());
-  }
-
-  bool IsOutputFused(const ExecutionContext &ctx) const {
-    auto &fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
-    auto &fused_transpose_Out =
-        ctx.Attr<std::vector<int>>("fused_transpose_Out");
-    return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
-  }
-
-  MatMulDims GetMatmulDims(const ExecutionContext &ctx) {
-    phi::funcs::MatDescriptor mat_dim_x;
-    memory::dims strides_x;
-    std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X");
-    phi::funcs::MatDescriptor mat_dim_y;
-    memory::dims strides_y;
-    std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");
-
-    auto x_bs = mat_dim_x.batch_size_;
-    auto y_bs = mat_dim_y.batch_size_;
-    PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs,
-                      false,
-                      paddle::platform::errors::InvalidArgument(
-                          "If batch sizes of X and Y are positive,"
-                          "they have to be equal."));
-
-    memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
-    const memory::dim M = mat_dim_x.height_;
-    const memory::dim N = mat_dim_y.width_;
-    const memory::dim K = mat_dim_x.width_;
-
-    batch_size_ = 1;
-    if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
-      auto x_dims = GetDimForInput(ctx, "X");
-      auto y_dims = GetDimForInput(ctx, "Y");
-      batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
-      x_bs /= batch_size_;
-      y_bs /= batch_size_;
-      out_bs /= batch_size_;
-    }
-    memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
-    memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
-    memory::dims out_dims = {out_bs, M, N};
-
-    x_offset_ = x_bs * M * K * sizeof(XT);
-    y_offset_ = y_bs * K * N * sizeof(YT);
-    out_offset_ = out_bs * M * N * sizeof(OT);
-
-    // Translate transA and transB
-    if (strides_x.empty())
-      strides_x = !ctx.Attr<bool>("transpose_X") ? memory::dims{M * K, K, 1}
-                                                 : memory::dims{M * K, 1, M};
-    if (strides_y.empty())
-      strides_y = !ctx.Attr<bool>("transpose_Y") ? memory::dims{N * K, N, 1}
-                                                 : memory::dims{N * K, 1, K};
-    memory::dims out_strides = memory::dims{M * N, N, 1};
-
-    CorrectStridesWhenFloatOutputFused(ctx, N, out_bs, &out_strides);
-
-    return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
-  }
-
-  std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
-                                 const std::vector<int> &axis) {
-    size_t in_rank = x.size();
-    size_t axis_size = axis.size();
-
-    auto axis_set = std::set<int>(axis.begin(), axis.end());
-    PADDLE_ENFORCE_EQ(axis_set.size(),
-                      axis_size,
-                      paddle::platform::errors::InvalidArgument(
-                          "In an axis array, elements must be unique."));
-
-    PADDLE_ENFORCE_EQ(in_rank,
-                      axis_size,
-                      paddle::platform::errors::InvalidArgument(
-                          "The input dimension's size "
-                          "should be equal to the axis's size. "
-                          "But received dimension is %d, "
-                          "axis's size is %d",
-                          in_rank,
-                          axis_size));
-
-    PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
-                      axis_size,
-                      paddle::platform::errors::InvalidArgument(
-                          "Axis values must be ranging from 0 to (dims - 1)."));
-
-    std::vector<int64_t> new_x(x.size());
-    for (size_t i = 0; i < x.size(); i++) {
-      new_x[i] = x[axis[i]];
-    }
-    return new_x;
-  }
-
-  void CorrectStridesWhenFloatOutputFused(const ExecutionContext &ctx,
-                                          const memory::dim N,
-                                          memory::dim b,
-                                          memory::dims *out_strides) const {
-    if (!IsInt8<OT>() && !IsBfloat16<OT>() && IsOutputFused(ctx)) {
-      *out_strides = {N, b * N, 1};
-    }
-  }
-
-  uint16_t GetBatchSize(void) const { return batch_size_; }
-
-  std::tuple<uint32_t, uint32_t, uint32_t> GetOffsets() const {
-    return std::make_tuple(x_offset_, y_offset_, out_offset_);
-  }
-
-  dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext &ctx) {
-    dnnl::primitive_attr matmul_attrs;
-    dnnl::post_ops post_operations;
-
-    float scale_out = ComputeOutputScale(ctx);
-    if (scale_out != 1.0f) {
-      matmul_attrs.set_output_scales(0, {scale_out});
-    }
-    paddle::platform::AppendActivation(ctx, post_operations);
-
-    matmul_attrs.set_post_ops(post_operations);
-    return matmul_attrs;
-  }
-
  private:
   uint32_t x_offset_;
   uint32_t y_offset_;
@@ -465,55 +271,8 @@ static void ReshapeXYOutToMatrixSequence(
   ReshapeTensorToMatrixSequence(y, mat_dim_y);
 }
 
-// Choose appropriate Handler instances based on inferred
-// output type (uint8, int8 or float).
-template <typename T>
-static void ExecuteMatMul(const ExecutionContext &ctx) {
-  constexpr bool is_int8 = IsInt8<T>();
-  constexpr bool is_bfloat16 = IsBfloat16<T>();
-  const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-  const bool fuse_relu =
-      ctx.HasAttr("fuse_activation")
-          ? ctx.Attr<std::string>("fuse_activation") == "relu"
-          : false;
-  auto *x = ctx.Input<Tensor>("X");
-  auto *y = ctx.Input<Tensor>("Y");
-  auto *out = ctx.Output<Tensor>("Out");
-  const auto &dev_ctx =
-      ctx.template device_context<MKLDNNDeviceContext>();
-  const auto &onednn_engine = dev_ctx.GetEngine();
-
-  if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
-    MatMulMKLDNNHandler<T, T, float>(onednn_engine, ctx).Execute(x, y, out);
-  } else if (is_bfloat16) {
-    MatMulMKLDNNHandler<T, T, paddle::platform::bfloat16>(onednn_engine, ctx)
-        .Execute(x, y, out);
-  } else if (fuse_relu) {
-    MatMulMKLDNNHandler<T, T, uint8_t>(onednn_engine, ctx).Execute(x, y, out);
-  } else {
-    MatMulMKLDNNHandler<T, T, int8_t>(onednn_engine, ctx).Execute(x, y, out);
-  }
-}
-
-template <typename T>
-class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const ExecutionContext &ctx) const override {
-    if (ctx.HasAttr("head_number")) {
-      PADDLE_ENFORCE_EQ(
-          ctx.Attr<int>("head_number"),
-          1,
-          paddle::platform::errors::Unimplemented(
-              "oneDNN matmul doesn't support multiple heads. Expected "
-              "head_number=1. But received `head_number` is %d",
-              ctx.Attr<int>("head_number")));
-    }
-    ExecuteMatMul<T>(ctx);
-  }
-};
-
-static std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
-                                      const std::vector<int> &axis) {
+std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
+                               const std::vector<int> &axis) {
   size_t in_rank = x.size();
   size_t axis_size = axis.size();
 
@@ -589,15 +348,6 @@ bool IsOutputFused(const ExecutionContext &ctx) {
   return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
 }
 
-float ComputeOutputScale(const ExecutionContext &ctx) {
-  float scale_x = ctx.Attr<float>("Scale_x");
-  float scale_y = ctx.Attr<float>("Scale_y");
-  bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
-  float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
-  float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
-  return alpha * scale_out / (scale_x * scale_y);
-}
-
 template <typename T>
 void ExecuteMatMulV2(const ExecutionContext &ctx,
                      const MKLDNNDeviceContext &dev_ctx,