/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace {
using dnnl::memory;
using paddle::framework::ExecutionContext;
using paddle::framework::GradVarName;
using phi::DenseTensor;
using phi::OneDNNContext;
using phi::vectorize;
using phi::funcs::OneDNNGetDataType;

// Collapse the two leading dimensions of a rank-3 tensor, i.e. turn a
// P x M x N tensor into a (P * M) x N one.
//
// The result is a DenseTensor copy-constructed from `input`; only its
// dimensions are changed (via Resize). Tensors whose rank is not 3 are
// returned as an unmodified copy.
static DenseTensor FoldOuterDims(const DenseTensor &input) {
  DenseTensor folded = input;
  const auto dims = input.dims();
  if (dims.size() != 3) {
    return folded;
  }
  folded.Resize({dims[0] * dims[1], dims[2]});
  return folded;
}

// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template static DenseTensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx, const DenseTensor *input) { auto input_dims = vectorize(input->dims()); if (input_dims.size() != 3) { return *input; } DenseTensor output; output.Resize({input_dims[1], input_dims[0], input_dims[2]}); auto output_dims = vectorize(output.dims()); memory::data_type input_type = phi::funcs::ToOneDNNDataType(input->dtype()); phi::funcs::ReorderOneDNNHandler reorder_handler( output_dims, input->dtype(), input_type, dev_ctx.GetEngine()); auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( memory::format_tag::abc, phi::funcs::to_void_cast(input->data())); auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( &output, memory::format_tag::bac, dev_ctx.GetPlace()); auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, reorder_dst_memory_p); auto &astream = OneDNNContext::tls().get_stream(); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); output.Resize({input_dims[1], input_dims[0] * input_dims[2]}); return output; } template class MatMulV1OneDNNHandler : public phi::funcs::OneDNNHandlerNoCachingT { public: MatMulV1OneDNNHandler(const ExecutionContext &ctx, const dnnl::engine engine, phi::Place cpu_place, const std::vector &x_org_dims, const std::vector &y_org_dims) : phi::funcs::OneDNNHandlerNoCachingT(engine, cpu_place) { // M X K * K X N std::vector x_dims(x_org_dims); std::vector y_dims(y_org_dims); const int MB_idx = x_dims.size() - 3; const int H_idx = x_dims.size() - 2; const int W_idx = x_dims.size() - 1; auto trans_x = ctx.Attr("transpose_X"); auto trans_y = ctx.Attr("transpose_Y"); if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]); if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]); const memory::dim M = x_dims[H_idx]; const memory::dim K = x_dims[W_idx]; const memory::dim N = y_dims[W_idx]; std::vector x_strides(x_dims.size() - 3, 1); std::vector y_strides(x_dims.size() - 3, 1); std::vector 
out_strides(x_dims.size() - 3, 1); std::vector out_ddims(x_dims.size() - 3, 1); x_strides.reserve(x_dims.size()); y_strides.reserve(x_dims.size()); out_strides.reserve(x_dims.size()); if (trans_x) { x_strides.insert(x_strides.end(), {M * K, 1, M}); } else { x_strides.insert(x_strides.end(), {M * K, K, 1}); } if (trans_y) { y_strides.insert(y_strides.end(), {N * K, 1, K}); } else { y_strides.insert(y_strides.end(), {N * K, N, 1}); } out_strides.insert(out_strides.end(), {M * N, N, 1}); out_ddims.insert(out_ddims.end(), {std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N}); for (int i = x_dims.size() - 4; i >= 0; --i) { out_ddims[i] = std::max(x_dims[i], y_dims[i]); x_strides[i] = x_dims[i + 1] * x_strides[i + 1]; y_strides[i] = y_dims[i + 1] * y_strides[i + 1]; out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; } auto x_md = memory::desc(x_dims, phi::funcs::OneDNNGetDataType(), x_strides); auto y_md = memory::desc(y_dims, phi::funcs::OneDNNGetDataType(), y_strides); auto out_md = memory::desc( out_ddims, phi::funcs::OneDNNGetDataType(), out_strides); dnnl::primitive_attr matmul_attrs; dnnl::post_ops post_operations; float scale_out = ComputeOutputScale(ctx); if (scale_out != 1.0f) { matmul_attrs.set_output_scales(0, {scale_out}); } matmul_attrs.set_post_ops(post_operations); this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md); } MatMulV1OneDNNHandler(const dnnl::engine engine, phi::Place cpu_place, DenseTensor *x, bool trans_x, DenseTensor *y, bool trans_y, DenseTensor *out, float scale) : phi::funcs::OneDNNHandlerNoCachingT(engine, cpu_place) { auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x->dims(), 0, trans_x); auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y->dims(), 0, trans_y); memory::dim x_bs = mat_dim_x.batch_size_; memory::dim y_bs = mat_dim_y.batch_size_; memory::dim out_bs = x_bs || y_bs ? 
std::max(x_bs, y_bs) : 1; const memory::dim M = mat_dim_x.height_; const memory::dim N = mat_dim_y.width_; const memory::dim K = mat_dim_x.width_; memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K}; memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N}; memory::dims out_dims = {out_bs, M, N}; memory::dims x_strides = trans_x ? memory::dims{M * K, 1, M} : memory::dims{M * K, K, 1}; memory::dims y_strides = trans_y ? memory::dims{N * K, 1, K} : memory::dims{N * K, N, 1}; memory::dims out_strides = memory::dims{M * N, N, 1}; auto x_md = memory::desc(x_dims, OneDNNGetDataType(), x_strides); auto y_md = memory::desc(y_dims, OneDNNGetDataType(), y_strides); auto out_md = memory::desc(out_dims, OneDNNGetDataType(), out_strides); dnnl::primitive_attr attrs; if (scale != 1.0f) attrs.set_output_scales(0, {scale}); this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md); } float ComputeOutputScale(const ExecutionContext &ctx) { float alpha = ctx.Attr("alpha"); if (ctx.HasAttr("Scale_x") && ctx.HasAttr("Scale_y") && ctx.HasAttr("Scale_out")) { float scale_x = ctx.Attr("Scale_x"); float scale_y = ctx.Attr("Scale_y"); bool force_fp32_out = ctx.HasAttr("force_fp32_output") ? ctx.Attr("force_fp32_output") : false; float scale_out = force_fp32_out ? 1.f : ctx.Attr("Scale_out"); alpha *= scale_out / (scale_x * scale_y); } return alpha; } std::shared_ptr AcquireWeightsMemory(const DenseTensor *input) { const YT *input_data = input->data(); return this->AcquireMemoryFromPrimitive( this->fwd_pd_->weights_desc(), phi::funcs::to_void_cast(input_data)); } std::shared_ptr AcquireDstMemory(DenseTensor *output) { // We cannot use base AcquireDstMemory as it makes an allocation request // base on DST memory primitive size. This is fine in general, but in MatMul // we have primitive that covers only one batch of Data and then shift // pointer for every new batch. Hence DenseTensor size is bigger that // dst memory primitive size. 
So would we request less memory that is there // and it triggers an assertion. So as there is no 'any' format here we can // leave default size of DenseTensor as computed in ComputeInferShape OT *ptr = output->mutable_data(this->place_); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); } private: uint16_t batch_size_; }; /** * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. * * The shape would be [BatchSize, H, W] or [H, W]. * If transposed, `H,W` will be swapped. */ static void ReshapeTensorToMatrixSequence( DenseTensor *x, const phi::funcs::MatDescriptor &descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; if (descriptor.trans_) { std::swap(w, h); } if (descriptor.batch_size_) { x->Resize({descriptor.batch_size_, h, w}); } else { x->Resize({h, w}); } } /** * Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor * Out = matmul(x, y) * * This method will first calculate X,Y matrix sequence, and then calculate * the out shape. * * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2] * The out = [BatchSize, H1, W2] * * If there is no batch size in `X` and `Y`, the out will be [H1, W2] * If any of `X` and `Y` has batch size BatchSize, the out will have the * BatchSize. 
*/ static void ReshapeXYOutToMatrixSequence(DenseTensor *x, DenseTensor *y, DenseTensor *out, bool trans_x, bool trans_y) { auto x_dim = phi::funcs::RowMatrixDimsFromVector(x->dims()); auto y_dim = phi::funcs::ColumnMatrixDimsFromVector(y->dims()); auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { out->Resize({mat_dim_x.height_, mat_dim_y.width_}); } else { out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), mat_dim_x.height_, mat_dim_y.width_}); } ReshapeTensorToMatrixSequence(x, mat_dim_x); ReshapeTensorToMatrixSequence(y, mat_dim_y); } std::vector Transpose(const std::vector &x, const std::vector &axis) { size_t in_rank = x.size(); size_t axis_size = axis.size(); auto axis_set = std::set(axis.begin(), axis.end()); PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, phi::errors::InvalidArgument( "In an axis array, elements must be unique.")); PADDLE_ENFORCE_EQ( in_rank, axis_size, phi::errors::InvalidArgument("The input dimension's size " "should be equal to the axis's size. 
" "But received dimension is %d, " "axis's size is %d", in_rank, axis_size)); PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, phi::errors::InvalidArgument( "Axis values must be ranging from 0 to (dims - 1).")); std::vector new_x(x.size()); for (size_t i = 0; i < x.size(); i++) { new_x[i] = x[axis[i]]; } return new_x; } template void ExecuteMatMul(const ExecutionContext &ctx, const DenseTensor *x, const std::vector &x_dims, const DenseTensor *y, const std::vector &y_dims, DenseTensor *out) { const auto &dev_ctx = ctx.template device_context(); MatMulV1OneDNNHandler handler( ctx, dev_ctx.GetEngine(), ctx.GetPlace(), x_dims, y_dims); const auto src_memory_p = handler.AcquireSrcMemory(x); const auto weights_memory_p = handler.AcquireWeightsMemory(y); const auto dst_memory_p = handler.AcquireDstMemory(out); auto matmul_p = handler.AcquireForwardPrimitive(); std::unordered_map matmul_args = { {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; auto &astream = OneDNNContext::tls().get_stream(); matmul_p->execute(astream, matmul_args); astream.wait(); out->set_mem_desc( dst_memory_p->get_desc().reshape(vectorize(out->dims()))); } template class MatMulV1OneDNNKernel : public paddle::framework::OpKernel { public: void Compute(const ExecutionContext &ctx) const override { if (ctx.HasAttr("head_number")) { PADDLE_ENFORCE_EQ( ctx.Attr("head_number"), 1, phi::errors::Unimplemented( "oneDNN matmul doesn't support multiple heads. Expected " "head_number=1. But received `head_number` is %d", ctx.Attr("head_number"))); } constexpr bool is_int8 = phi::funcs::is_int8(); constexpr bool is_bfloat16 = phi::funcs::is_bfloat16(); const bool force_fp32_output = ctx.HasAttr("force_fp32_output") ? 
ctx.Attr("force_fp32_output") : false; constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses auto *x = ctx.Input("X"); auto *y = ctx.Input("Y"); auto *out = ctx.Output("Out"); auto x_dims = vectorize(x->dims()); auto y_dims = vectorize(y->dims()); int ndims = std::max(x_dims.size(), y_dims.size()); ndims = std::max(ndims, 3); std::vector x_bd_dims(ndims, 1); std::vector y_bd_dims(ndims, 1); CalculateMatrixDims(x_dims, y_dims, &x_bd_dims, &y_bd_dims, out); if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { ExecuteMatMul(ctx, x, x_bd_dims, y, y_bd_dims, out); } else if (is_bfloat16) { ExecuteMatMul( ctx, x, x_bd_dims, y, y_bd_dims, out); } else if (fuse_relu) { ExecuteMatMul(ctx, x, x_bd_dims, y, y_bd_dims, out); } else { ExecuteMatMul(ctx, x, x_bd_dims, y, y_bd_dims, out); } } private: void CalculateMatrixDims(const std::vector &x_dims, const std::vector &y_dims, std::vector *x_bd_dims, std::vector *y_bd_dims, DenseTensor *out) const { if (x_dims.size() == 1) { (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0]; } else if (x_dims.size() == 2) { (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1]; (*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0]; } else { for (size_t i = 0; i < x_dims.size(); ++i) { (*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i]; } } if (y_dims.size() == 1) { (*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0]; } else if (y_dims.size() == 2) { (*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1]; (*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0]; } else { for (size_t i = 0; i < y_dims.size(); ++i) { (*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i]; } } if (x_dims.size() > 2 && y_dims.size() > 2) { auto out_dims = vectorize(out->dims()); for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) { PADDLE_ENFORCE_EQ( (*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 || (*y_bd_dims)[i] == 1, true, phi::errors::InvalidArgument( "DenseTensor dimensions are incorrect for broadcasting." 
"Dimensions in X and Y must be same or equal to 1, but " "received x_dim[%d]=%d and y_dims[%d]= %d", i, (*x_bd_dims)[i], i, (*y_bd_dims)[i])); (out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]); } out->Resize(phi::make_ddim((out_dims))); } } }; template class MatMulV1GradOneDNNKernel : public paddle::framework::OpKernel { public: void Compute(const ExecutionContext &ctx) const override { if (ctx.HasAttr("head_number")) { PADDLE_ENFORCE_EQ( ctx.Attr("head_number"), 1, phi::errors::Unimplemented( "oneDNN matmul doesn't support multiple heads. Expected " "head_number=1. But received `head_number` is %d", ctx.Attr("head_number"))); } const auto &dev_ctx = ctx.template device_context(); const auto &onednn_engine = dev_ctx.GetEngine(); auto x = *ctx.Input("X"); auto y = *ctx.Input("Y"); auto dout = *ctx.Input(paddle::framework::GradVarName("Out")); auto *dx = ctx.Output(paddle::framework::GradVarName("X")); auto *dy = ctx.Output(paddle::framework::GradVarName("Y")); bool transpose_x = ctx.Attr("transpose_X"); bool transpose_y = ctx.Attr("transpose_Y"); ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); phi::DDim dx_dims; if (dx) { dx_dims = dx->dims(); if (dx_dims != x.dims()) { dx->Resize(x.dims()); } } phi::DDim dy_dims; if (dy) { dy_dims = dy->dims(); if (dy_dims != y.dims()) { dy->Resize(y.dims()); } } if (transpose_x && transpose_y) { this->ExecuteMatMulGrad( ctx, dev_ctx, onednn_engine, &y, true, true, &dout, true, false, dx); this->ExecuteMatMulGrad( ctx, dev_ctx, onednn_engine, &dout, true, true, &x, true, false, dy); } else if (transpose_x) { this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false, &dout, true, false, dx); this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false, &dout, false, true, dy); } else if (transpose_y) { this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false, &y, false, true, dx); this->ExecuteMatMulGrad( ctx, dev_ctx, onednn_engine, &dout, true, true, &x, false, 
true, dy); } else { this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false, &y, true, false, dx); this->ExecuteMatMulGrad( ctx, dev_ctx, onednn_engine, &x, true, true, &dout, false, true, dy); } if (dx) { if (dx_dims != x.dims()) { dx->Resize(dx_dims); dx->set_mem_desc(x.mem_desc()); } } if (dy) { if (dy_dims != y.dims()) { dy->Resize(dy_dims); dy->set_mem_desc(y.mem_desc()); } } } private: void ExecuteMatMulGrad(const ExecutionContext &ctx, const OneDNNContext &dev_ctx, const dnnl::engine &engine, DenseTensor *x, bool trans_x, bool is_fold_init_dims_x, DenseTensor *y, bool trans_y, bool is_fold_init_dims_y, DenseTensor *out) const { // gradient is calculated in a different way when broadcasting is used bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) && out->dims().size() == 2; DenseTensor x_combined, y_combined; if (need_combine) { x_combined = is_fold_init_dims_x ? FoldOuterDims(*x) : FoldFirstAndLastDims(dev_ctx, x); y_combined = is_fold_init_dims_y ? 
FoldOuterDims(*y) : FoldFirstAndLastDims(dev_ctx, y); } else { x_combined = *x; y_combined = *y; } float alpha = ctx.Attr("alpha"); MatMulV1OneDNNHandler handler(engine, ctx.GetPlace(), &x_combined, trans_x, &y_combined, trans_y, out, alpha); const auto src_memory_p = handler.AcquireSrcMemory(&x_combined); const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined); const auto dst_memory_p = handler.AcquireDstMemory(out); auto matmul_p = handler.AcquireForwardPrimitive(); std::unordered_map matmul_args = { {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; auto &astream = OneDNNContext::tls().get_stream(); matmul_p->execute(astream, matmul_args); astream.wait(); out->set_mem_desc( dst_memory_p->get_desc().reshape(vectorize(out->dims()))); } }; } // anonymous namespace REGISTER_OP_KERNEL(matmul, MKLDNN, ::phi::CPUPlace, MatMulV1OneDNNKernel, MatMulV1OneDNNKernel, MatMulV1OneDNNKernel, MatMulV1OneDNNKernel); REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::phi::CPUPlace, MatMulV1GradOneDNNKernel, MatMulV1GradOneDNNKernel);