diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
index b4d782da78f024880d489670d423dfa6ffa85dab..be965c4abb89564a3060ec855f81cac024f9cce0 100644
--- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
@@ -381,7 +381,7 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
 }
 
 template <typename T>
-class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
+class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const ExecutionContext &ctx) const override {
     if (ctx.HasAttr("head_number")) {
@@ -696,21 +696,13 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
 REGISTER_OP_KERNEL(matmul,
                    MKLDNN,
                    ::paddle::platform::CPUPlace,
-                   MatMulV2MKLDNNKernel<float>,
-                   MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
-                   MatMulV2MKLDNNKernel<int8_t>,
-                   MatMulV2MKLDNNKernel<uint8_t>);
+                   MatMulMKLDNNKernel<float>,
+                   MatMulMKLDNNKernel<paddle::platform::bfloat16>,
+                   MatMulMKLDNNKernel<int8_t>,
+                   MatMulMKLDNNKernel<uint8_t>);
 
 REGISTER_OP_KERNEL(matmul_grad,
                    MKLDNN,
                    ::paddle::platform::CPUPlace,
                    MatMulGradMKLDNNKernel<float>,
                    MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);
-
-REGISTER_OP_KERNEL(matmul_v2,
-                   MKLDNN,
-                   ::paddle::platform::CPUPlace,
-                   MatMulV2MKLDNNKernel<float>,
-                   MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
-                   MatMulV2MKLDNNKernel<int8_t>,
-                   MatMulV2MKLDNNKernel<uint8_t>);
diff --git a/paddle/fluid/operators/ops_extra_info.h b/paddle/fluid/operators/ops_extra_info.h
index b16e4ed58f3fe24057325b1520cf6c342c52f180..77c0aa7a33fb3ce0fc75167d78b652fe646cd990 100644
--- a/paddle/fluid/operators/ops_extra_info.h
+++ b/paddle/fluid/operators/ops_extra_info.h
@@ -98,6 +98,7 @@ const std::unordered_map<std::string, ExtraAttrProperty>
         {"fuse_alpha", ExtraAttrProperty::ONEDNN},
         {"fuse_beta", ExtraAttrProperty::ONEDNN},
         {"fuse_relu", ExtraAttrProperty::ONEDNN},
+        {"fused_output_scale", ExtraAttrProperty::ONEDNN},
         {"fuse_residual_connection", ExtraAttrProperty::ONEDNN},
         {"fuse_with_relu", ExtraAttrProperty::ONEDNN},
         {"fused_reshape_Out", ExtraAttrProperty::ONEDNN},
@@ -221,7 +222,8 @@ class ExtraInfoUtils {
   std::unordered_map<std::string, std::vector<std::string>>
       g_extra_input_names_map_ = {{"conv2d", {"Bias", "ResidualData"}},
                                   {"conv2d_transpose", {"Bias"}},
-                                  {"conv2d_grad", {"Bias"}}};
+                                  {"conv2d_grad", {"Bias"}},
+                                  {"matmul_v2", {"ResidualData"}}};
   std::vector<std::string> empty_extra_input_names_;
 };
 
diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h
index f4577dab5aa476e4c1d34a55417b4ae661e35588..7f64f8668c91bdfcdfe2fd47762a8826117dbc4b 100644
--- a/paddle/phi/backends/onednn/onednn_reuse.h
+++ b/paddle/phi/backends/onednn/onednn_reuse.h
@@ -1874,9 +1874,11 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
     if (scale_out != 1.0f) {
       matmul_attrs.set_output_scales(0, {scale_out});
     }
+    const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
+                                    ? dev_ctx.GetDnnInput("ResidualData")
+                                    : nullptr;
 
-    if (dev_ctx.HasDnnInput("ResidualData")) {
-      auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
+    if (residual_data) {
       auto residual_data_tz = vectorize(residual_data->dims());
       auto residual_data_md = memory::desc(residual_data_tz,
                                            OneDNNGetDataType<OT>(),
@@ -1893,9 +1895,11 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
 
     AppendActivation(dev_ctx, post_operations);
 
-    if (dev_ctx.HasDnnAttr("fused_output_scale")) {
-      float scale_alpha =
-          PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"));
+    const float scale_alpha =
+        dev_ctx.HasDnnAttr("fused_output_scale")
+            ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"))
+            : 1.0f;
+    if (scale_alpha != 1.0f) {
       post_operations.append_eltwise(
           1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
     }
@@ -2014,8 +2018,11 @@ void ExecuteMatmul(const OneDNNContext& dev_ctx,
       {DNNL_ARG_WEIGHTS, *weights_memory_p},
       {DNNL_ARG_DST, *dst_memory_p}};
 
-  if (dev_ctx.HasDnnInput("ResidualData")) {
-    auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
+  const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
+                                  ? dev_ctx.GetDnnInput("ResidualData")
+                                  : nullptr;
+
+  if (residual_data) {
     const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
     matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
                         *residual_data_memory_p});
diff --git a/paddle/phi/kernels/onednn/matmul_kernel.cc b/paddle/phi/kernels/onednn/matmul_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..30a1735c5184aadb381e294d748a0aa5711b5541
--- /dev/null
+++ b/paddle/phi/kernels/onednn/matmul_kernel.cc
@@ -0,0 +1,164 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/matmul_kernel.h"
+
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+DDim GetDimsForInput(const OneDNNContext &dev_ctx,
+                     DDim input_dims,
+                     std::string input_name) {
+  auto shape =
+      dev_ctx.HasDnnAttr("fused_reshape_" + input_name)
+          ? PADDLE_GET_CONST(std::vector<int>,
+                             dev_ctx.GetDnnAttr("fused_reshape_" + input_name))
+          : std::vector<int>();
+  auto axis = dev_ctx.HasDnnAttr("fused_transpose_" + input_name)
+                  ? PADDLE_GET_CONST(
+                        std::vector<int>,
+                        dev_ctx.GetDnnAttr("fused_transpose_" + input_name))
+                  : std::vector<int>();
+  if (!shape.empty() && !axis.empty()) {
+    return input_dims.reshape(shape).transpose(axis);
+  }
+  return input_dims;
+}
+
+void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
+                         const std::vector<int64_t> &y_dims,
+                         std::vector<int64_t> *x_bd_dims,
+                         std::vector<int64_t> *y_bd_dims,
+                         DenseTensor *out,
+                         const bool is_output_fused) {
+  if (x_dims.size() == 1) {
+    (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
+  } else if (x_dims.size() == 2) {
+    (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1];
+    (*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0];
+  } else {
+    for (size_t i = 0; i < x_dims.size(); ++i) {
+      (*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i];
+    }
+  }
+  if (y_dims.size() == 1) {
+    (*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0];
+  } else if (y_dims.size() == 2) {
+    (*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1];
+    (*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
+  } else {
+    for (size_t i = 0; i < y_dims.size(); ++i) {
+      (*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i];
+    }
+  }
+
+  if (!is_output_fused && x_dims.size() > 2 && y_dims.size() > 2) {
+    auto out_dims = vectorize(out->dims());
+    for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
+      PADDLE_ENFORCE_EQ(
+          (*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
+              (*y_bd_dims)[i] == 1,
+          true,
+          errors::InvalidArgument(
+              "Tensor dimensions are incorrect for broadcasting."
+              "Dimensions in X and Y must be same or equal to 1, but "
+              "received x_dim[%d]=%d and y_dims[%d]= %d",
+              i,
+              (*x_bd_dims)[i],
+              i,
+              (*y_bd_dims)[i]));
+      (out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
+    }
+    out->Resize(make_ddim((out_dims)));
+  }
+}
+
+template <typename T, typename Context>
+void MatmulKernel(const Context &dev_ctx,
+                  const DenseTensor &x,
+                  const DenseTensor &y,
+                  bool transpose_x,
+                  bool transpose_y,
+                  DenseTensor *out) {
+  if (dev_ctx.HasDnnAttr("head_number")) {
+    const auto head_number =
+        PADDLE_GET_CONST(int, dev_ctx.GetDnnAttr("head_number"));
+    PADDLE_ENFORCE_EQ(
+        head_number,
+        1,
+        errors::Unimplemented(
+            "oneDNN matmul doesn't support multiple heads. Expected "
+            "head_number=1. But received `head_number` is %d",
+            head_number));
+  }
+
+  constexpr bool is_int8 = funcs::is_int8<T>();
+  constexpr bool is_bfloat16 = funcs::is_bfloat16<T>();
+  const bool force_fp32_output =
+      dev_ctx.HasDnnAttr("force_fp32_output")
+          ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
+          : false;
+
+  bool fuse_relu = false;
+  if (dev_ctx.HasDnnAttr("fuse_activation")) {
+    auto act_type =
+        PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"));
+    if (act_type == "relu" || act_type == "relu6") {
+      fuse_relu = true;
+    }
+  }
+
+  auto x_dims = vectorize(GetDimsForInput(dev_ctx, x.dims(), "X"));
+  auto y_dims = vectorize(GetDimsForInput(dev_ctx, y.dims(), "Y"));
+
+  int ndims = std::max(x_dims.size(), y_dims.size());
+  ndims = std::max(ndims, 3);
+
+  std::vector<int64_t> x_bd_dims(ndims, 1);
+  std::vector<int64_t> y_bd_dims(ndims, 1);
+
+  CalculateMatrixDims(x_dims,
+                      y_dims,
+                      &x_bd_dims,
+                      &y_bd_dims,
+                      out,
+                      funcs::IsOutputFused(dev_ctx));
+
+  if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
+    funcs::ExecuteMatmul<T, float>(
+        dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
+  } else if (is_bfloat16) {
+    funcs::ExecuteMatmul<T, phi::dtype::bfloat16>(
+        dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
+  } else if (fuse_relu) {
+    funcs::ExecuteMatmul<T, uint8_t>(
+        dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
+  } else {
+    funcs::ExecuteMatmul<T, int8_t>(
+        dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(matmul,
+                   OneDNN,
+                   ONEDNN,
+                   phi::MatmulKernel,
+                   float,
+                   phi::dtype::bfloat16,
+                   int8_t,
+                   uint8_t) {}