// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" using dnnl::engine; using dnnl::inner_product_forward; using dnnl::memory; using dnnl::prop_kind; using dnnl::stream; using phi::ReshapeToMatrix; namespace phi { template class FusedMatmulOneDNNHandler : public funcs::OneDNNHandlerNoCachingT { public: FusedMatmulOneDNNHandler(const OneDNNContext &dev_ctx, const DenseTensor *residual_data, const std::vector &x_org_dims, const std::vector &y_org_dims, bool trans_x, bool trans_y, const float matmul_alpha, const std::vector &x_strides_override, const std::vector &y_strides_override, bool is_output_fused, const std::string &fuse_activation, const float fuse_alpha, const float fuse_beta, const float fused_output_scale, const float scale_x, const float scale_y, const float scale_in_eltwise, const float scale_out, const bool force_fp32_output) : funcs::OneDNNHandlerNoCachingT(dev_ctx.GetEngine(), dev_ctx.GetPlace()) { // M X K * K X N std::vector x_dims(x_org_dims); std::vector y_dims(y_org_dims); const int MB_idx = x_dims.size() - 3; const int H_idx = x_dims.size() - 2; const int W_idx = x_dims.size() - 1; if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]); if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]); const memory::dim M = x_dims[H_idx]; const memory::dim K = x_dims[W_idx]; const memory::dim N = y_dims[W_idx]; std::vector x_strides(x_dims.size() - 3, 1); std::vector y_strides(x_dims.size() - 3, 1); std::vector out_strides(x_dims.size() - 3, 1); std::vector out_ddims(x_dims.size() - 3, 1); x_strides.reserve(x_dims.size()); y_strides.reserve(x_dims.size()); out_strides.reserve(x_dims.size()); if (x_strides_override.empty()) { if (trans_x) { x_strides.insert(x_strides.end(), {M * K, 1, M}); } else { x_strides.insert(x_strides.end(), {M * K, K, 1}); } } else { x_strides = x_strides_override; } if (y_strides_override.empty()) { if (trans_y) { y_strides.insert(y_strides.end(), {N * K, 1, K}); } else { y_strides.insert(y_strides.end(), {N * K, N, 1}); } } else { y_strides = y_strides_override; } out_strides.insert(out_strides.end(), {M * N, N, 1}); out_ddims.insert(out_ddims.end(), {std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N}); for (int i = x_dims.size() - 4; i >= 0; --i) { out_ddims[i] = std::max(x_dims[i], y_dims[i]); if (x_strides_override.empty()) { x_strides[i] = x_dims[i + 1] * x_strides[i + 1]; } if (y_strides_override.empty()) { y_strides[i] = y_dims[i + 1] * y_strides[i + 1]; } out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; } // TODO(jczaja): Why not for int8?? if (!funcs::is_int8() && is_output_fused) { out_strides = FakeTransposeStrides(out_ddims); } auto x_md = memory::desc(x_dims, funcs::OneDNNGetDataType(), x_strides); auto y_md = memory::desc(y_dims, funcs::OneDNNGetDataType(), y_strides); auto out_md = memory::desc(out_ddims, funcs::OneDNNGetDataType(), out_strides); const auto matmul_attrs = CreateMatmulAttrs(dev_ctx, residual_data, matmul_alpha, fuse_activation, fuse_alpha, fuse_beta, fused_output_scale, scale_x, scale_y, scale_in_eltwise, scale_out, force_fp32_output); this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md); } float ComputeOutputScale(float matmul_alpha, const float scale_x, const float scale_y, const float scale_in_eltwise, const float scale_out, const bool force_fp32_output) { float f_scale_out = force_fp32_output ? 1.0f : scale_out; matmul_alpha *= f_scale_out / (scale_x * scale_y); return matmul_alpha; } dnnl::primitive_attr CreateMatmulAttrs(const OneDNNContext &dev_ctx, const DenseTensor *residual_data, const float matmul_alpha, const std::string &fuse_activation, const float fuse_alpha, const float fuse_beta, const float fused_output_scale, const float scale_x, const float scale_y, const float scale_in_eltwise, const float scale_out, const bool force_fp32_output) { dnnl::primitive_attr matmul_attrs; dnnl::post_ops post_operations; float computed_scale_out = ComputeOutputScale(matmul_alpha, scale_x, scale_y, scale_in_eltwise, scale_out, force_fp32_output); if (computed_scale_out != 1.0f) { matmul_attrs.set_output_scales(0, {computed_scale_out}); } if (residual_data) { auto residual_data_tz = vectorize(residual_data->dims()); auto residual_data_md = memory::desc(residual_data_tz, funcs::OneDNNGetDataType(), dnnl::memory::format_tag::any); post_operations.append_binary(dnnl::algorithm::binary_add, residual_data_md); if (scale_in_eltwise != 0.0f) { float sum_scale = scale_out / scale_in_eltwise; post_operations.append_sum(sum_scale); } } funcs::AppendActivation( dev_ctx, post_operations, 1.0f, fuse_activation, fuse_alpha, fuse_beta); if (fused_output_scale != 1.0f) { post_operations.append_eltwise( 1.0, dnnl::algorithm::eltwise_linear, fused_output_scale, 0.0f); } matmul_attrs.set_post_ops(post_operations); return matmul_attrs; } std::vector FakeTransposeStrides( const std::vector &matmul_out_dims) const { // fuse matmul_v2 + transpose + reshape guarantees that output is 4D and // transpose axis are: {0, 2, 1, 3} std::vector transpose_axis = {0, 2, 1, 3}; std::vector fake_strides(transpose_axis.size()); int ndims = static_cast(transpose_axis.size()); int total_stride = 1; for (int i = ndims - 1; i >= 0; --i) { fake_strides[transpose_axis[i]] = total_stride; total_stride *= matmul_out_dims[transpose_axis[i]]; } return fake_strides; } std::shared_ptr AcquireWeightsMemory(const DenseTensor *input) { const YT *input_data = input->data(); return this->AcquireMemoryFromPrimitive( this->fwd_pd_->weights_desc(), funcs::to_void_cast(input_data)); } std::shared_ptr AcquireDstMemory(const OneDNNContext &dev_ctx, DenseTensor *output) { // We cannot use base AcquireDstMemory as it makes an allocation request // base on DST memory primitive size. This is fine in general, but in MatMul // we have primitive that covers only one batch of Data and then shift // pointer for every new batch. Hence DenseTensor size is bigger that // dst memory primitive size. So would we request less memory that is there // and it triggers an assertion. So as there is no 'any' format here we can // leave default size of DenseTensor as computed in ComputeInferShape OT *ptr = dev_ctx.template Alloc(output); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); } }; static DDim RowMatrixDimsFromVector(const DDim &x_dim) { return x_dim.size() > 1 ? x_dim : make_ddim({1, x_dim[0]}); } static DDim ColumnMatrixDimsFromVector(const DDim &y_dim) { return y_dim.size() > 1 ? y_dim : make_ddim({y_dim[0], 1}); } static std::vector TransposeAxis(const std::vector &x, const std::vector &axis) { size_t in_rank = x.size(); size_t axis_size = axis.size(); auto axis_set = std::set(axis.begin(), axis.end()); PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, phi::errors::InvalidArgument( "In an axis array, elements must be unique.")); PADDLE_ENFORCE_EQ( in_rank, axis_size, phi::errors::InvalidArgument("The input dimension's size " "should be equal to the axis's size. " "But received dimension is %d, " "axis's size is %d", in_rank, axis_size)); PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, phi::errors::InvalidArgument( "Axis values must be ranging from 0 to (dims - 1).")); std::vector new_x(x.size()); for (size_t i = 0; i < x.size(); i++) { new_x[i] = x[axis[i]]; } return new_x; } static std::vector GetInputStrides(const std::string input_name, const DDim &input_dims, std::vector shape, std::vector axis, const bool transpose_input) { auto new_dims = input_dims; if (!shape.empty() && !axis.empty()) { new_dims = input_dims.reshape(shape).transpose(axis); } auto &MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; funcs::MatDescriptor mat_dim = funcs::CreateMatrixDescriptor( MatrixDimsFromVector(new_dims), 0, transpose_input); std::vector strides; if (!shape.empty()) { auto shape2 = input_dims.reshape(shape); strides.push_back(1); for (auto i = shape2.size() - 1; i > 0; --i) { strides.insert(strides.begin(), strides.front() * static_cast(shape2[i])); } strides = TransposeAxis(strides, axis); if (shape.size() == 2) strides.insert(strides.begin(), static_cast(shape[0] * shape[1])); mat_dim.stride_ = strides[0]; if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin())); } return strides; } template void ExecuteFusedMatmul(const OneDNNContext &dev_ctx, const DenseTensor &x, const DenseTensor &y, const DenseTensor *residual_data, const std::vector &x_dims, const std::vector &y_dims, bool trans_x, bool trans_y, const float matmul_alpha, const std::vector &x_strides_override, const std::vector &y_strides_override, const bool is_output_fused, const std::vector &fused_transpose_Out, const std::string &fuse_activation, const float fuse_alpha, const float fuse_beta, const float fused_output_scale, const float scale_x, const float scale_y, const float scale_in_eltwise, const float scale_out, const bool force_fp32_output, DenseTensor *out) { FusedMatmulOneDNNHandler handler(dev_ctx, residual_data, x_dims, y_dims, trans_x, trans_y, matmul_alpha, x_strides_override, y_strides_override, is_output_fused, fuse_activation, fuse_alpha, fuse_beta, fused_output_scale, scale_x, scale_y, scale_in_eltwise, scale_out, force_fp32_output); const auto src_memory_p = handler.AcquireSrcMemory(&x); const auto weights_memory_p = handler.AcquireWeightsMemory(&y); const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out); auto matmul_p = handler.AcquireForwardPrimitive(); std::unordered_map matmul_args = { {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; if (residual_data) { const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data); matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, *residual_data_memory_p}); } auto &astream = OneDNNContext::tls().get_stream(); matmul_p->execute(astream, matmul_args); astream.wait(); if (is_output_fused && !funcs::is_int8()) { auto permuted_md = dst_memory_p->get_desc().permute_axes(fused_transpose_Out); out->set_mem_desc(permuted_md.reshape(vectorize(out->dims()))); } else { out->set_mem_desc( dst_memory_p->get_desc().reshape(vectorize(out->dims()))); } } std::vector GetInputShape(DDim input_dims, std::vector shape, std::vector axis) { if (!shape.empty() && !axis.empty()) { return vectorize(input_dims.reshape(shape).transpose(axis)); } return vectorize(input_dims); } void CalculateMatrixDims(const std::vector &x_dims, const std::vector &y_dims, std::vector *x_bd_dims, std::vector *y_bd_dims, DenseTensor *out, const bool is_output_fused) { if (x_dims.size() == 1) { (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0]; } else if (x_dims.size() == 2) { (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1]; (*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0]; } else { for (size_t i = 0; i < x_dims.size(); ++i) { (*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i]; } } if (y_dims.size() == 1) { (*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0]; } else if (y_dims.size() == 2) { (*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1]; (*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0]; } else { for (size_t i = 0; i < y_dims.size(); ++i) { (*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i]; } } if (!is_output_fused && x_dims.size() > 2 && y_dims.size() > 2) { auto out_dims = vectorize(out->dims()); for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) { PADDLE_ENFORCE_EQ( (*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 || (*y_bd_dims)[i] == 1, true, errors::InvalidArgument( "Tensor dimensions are incorrect for broadcasting." "Dimensions in X and Y must be same or equal to 1, but " "received x_dim[%d]=%d and y_dims[%d]= %d", i, (*x_bd_dims)[i], i, (*y_bd_dims)[i])); (out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]); } out->Resize(make_ddim((out_dims))); } } template void FusedMatmulKernel(const Context &dev_ctx, const DenseTensor &x, const DenseTensor &y, const paddle::optional &residual_data, bool transpose_x, bool transpose_y, const float matmul_alpha, const std::string &fuse_activation, const float fuse_alpha, const float fuse_beta, const float fused_output_scale, const std::vector &fused_reshape_X, const std::vector &fused_transpose_X, const std::vector &fused_reshape_Y, const std::vector &fused_transpose_Y, const std::vector &fused_reshape_Out, const std::vector &fused_transpose_Out, const std::string &mkldnn_data_type, const float scale_x, const float scale_y, const float scale_in_eltwise, const float scale_out, const bool force_fp32_output, DenseTensor *out) { if (dev_ctx.HasDnnAttr("head_number")) { const auto head_number = PADDLE_GET_CONST(int, dev_ctx.GetDnnAttr("head_number")); PADDLE_ENFORCE_EQ( head_number, 1, errors::Unimplemented( "oneDNN matmul doesn't support multiple heads. Expected " "head_number=1. But received `head_number` is %d", head_number)); } constexpr bool is_int8 = funcs::is_int8(); constexpr bool is_bfloat16 = funcs::is_bfloat16(); bool fuse_relu = false; if (fuse_activation == "relu" || fuse_activation == "relu6") { fuse_relu = true; } auto x_dims = GetInputShape(x.dims(), fused_reshape_X, fused_transpose_X); auto y_dims = GetInputShape(y.dims(), fused_reshape_Y, fused_transpose_Y); auto is_output_fused = !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); auto x_strides_override = GetInputStrides( "X", x.dims(), fused_reshape_X, fused_transpose_X, transpose_x); auto y_strides_override = GetInputStrides( "Y", y.dims(), fused_reshape_Y, fused_transpose_Y, transpose_y); int ndims = std::max(x_dims.size(), y_dims.size()); ndims = std::max(ndims, 3); std::vector x_bd_dims(ndims, 1); std::vector y_bd_dims(ndims, 1); CalculateMatrixDims( x_dims, y_dims, &x_bd_dims, &y_bd_dims, out, is_output_fused); if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { ExecuteFusedMatmul(dev_ctx, x, y, residual_data.get_ptr(), x_bd_dims, y_bd_dims, transpose_x, transpose_y, matmul_alpha, x_strides_override, y_strides_override, is_output_fused, fused_transpose_Out, fuse_activation, fuse_alpha, fuse_beta, fused_output_scale, scale_x, scale_y, scale_in_eltwise, scale_out, force_fp32_output, out); } else if (is_bfloat16) { ExecuteFusedMatmul(dev_ctx, x, y, residual_data.get_ptr(), x_bd_dims, y_bd_dims, transpose_x, transpose_y, matmul_alpha, x_strides_override, y_strides_override, is_output_fused, fused_transpose_Out, fuse_activation, fuse_alpha, fuse_beta, fused_output_scale, scale_x, scale_y, scale_in_eltwise, scale_out, force_fp32_output, out); } else if (fuse_relu) { ExecuteFusedMatmul(dev_ctx, x, y, residual_data.get_ptr(), x_bd_dims, y_bd_dims, transpose_x, transpose_y, matmul_alpha, x_strides_override, y_strides_override, is_output_fused, fused_transpose_Out, fuse_activation, fuse_alpha, fuse_beta, fused_output_scale, scale_x, scale_y, scale_in_eltwise, scale_out, force_fp32_output, out); } else { ExecuteFusedMatmul(dev_ctx, x, y, residual_data.get_ptr(), x_bd_dims, y_bd_dims, transpose_x, transpose_y, matmul_alpha, x_strides_override, y_strides_override, is_output_fused, fused_transpose_Out, fuse_activation, fuse_alpha, fuse_beta, fused_output_scale, scale_x, scale_y, scale_in_eltwise, scale_out, force_fp32_output, out); } } } // namespace phi PD_REGISTER_KERNEL(fused_matmul, OneDNN, ONEDNN, phi::FusedMatmulKernel, float, phi::dtype::bfloat16, int8_t, uint8_t) {}