diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 68813fbb5482eaf3491bd00d4ba780c1e8f35a49..810c0eaff186125bf6a41d14d7c1028797f81362 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -75,20 +75,6 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx, return output; } -// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the -// original x_dim is returned. -static paddle::framework::DDim RowMatrixDimsFromVector( - const paddle::framework::DDim &x_dim) { - return x_dim.size() > 1 ? x_dim : phi::make_ddim({1, x_dim[0]}); -} - -// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the -// original y_dim is returned. -static paddle::framework::DDim ColumnMatrixDimsFromVector( - const paddle::framework::DDim &y_dim) { - return y_dim.size() > 1 ? y_dim : phi::make_ddim({y_dim[0], 1}); -} - phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) { auto shape = ctx.Attr>("fused_reshape_" + input_name); auto axis = ctx.Attr>("fused_transpose_" + input_name); @@ -245,8 +231,8 @@ static void ReshapeTensorToMatrixSequence( */ static void ReshapeXYOutToMatrixSequence( Tensor *x, Tensor *y, Tensor *out, bool trans_x, bool trans_y) { - auto x_dim = RowMatrixDimsFromVector(x->dims()); - auto y_dim = ColumnMatrixDimsFromVector(y->dims()); + auto x_dim = phi::funcs::RowMatrixDimsFromVector(x->dims()); + auto y_dim = phi::funcs::ColumnMatrixDimsFromVector(y->dims()); auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x); auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { @@ -304,8 +290,9 @@ std::vector GetInputStrides(const ExecutionContext &ctx, new_dims = input_dims.reshape(shape).transpose(axis); } - auto &MatrixDimsFromVector = - input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; + auto &MatrixDimsFromVector = input_name == "X" + ? phi::funcs::RowMatrixDimsFromVector + : phi::funcs::ColumnMatrixDimsFromVector; phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor( MatrixDimsFromVector(new_dims), 0, @@ -707,199 +694,6 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel { } }; -template -class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { - public: - void Compute(const ExecutionContext &ctx) const override { - const auto &dev_ctx = ctx.template device_context(); - const auto &onednn_engine = dev_ctx.GetEngine(); - - auto *x = ctx.Input("X"); - auto *y = ctx.Input("Y"); - - auto x_dims = vectorize(x->dims()); - auto y_dims = vectorize(y->dims()); - - bool is_broadcast = true; - if (x_dims.size() <= 2 || y_dims.size() <= 2) { - is_broadcast = false; - } else if (x_dims.size() != y_dims.size()) { - is_broadcast = true; - } else { - is_broadcast = !std::equal(x_dims.cbegin(), - x_dims.cbegin() + x_dims.size() - 2, - y_dims.cbegin()); - } - - // if no broadcasting is needed, we can simply use matmul's grad and avoid - // using reduce_sum - if (!is_broadcast) { - matmul_v1_grad_mkldnn_kernel.Compute(ctx); - return; - } - - auto *dout = ctx.Input(GradVarName("Out")); - auto *dx = ctx.Output(GradVarName("X")); - auto *dy = ctx.Output(GradVarName("Y")); - - bool trans_x = ctx.HasAttr("trans_x") ? ctx.Attr("trans_x") - : ctx.Attr("transpose_X"); - bool trans_y = ctx.HasAttr("trans_y") ? 
ctx.Attr("trans_y") - : ctx.Attr("transpose_Y"); - auto dout_dims = vectorize(dout->dims()); - - size_t ndims = std::max(x->dims().size(), y->dims().size()); - ndims = std::max(ndims, 3); - - if (x_dims.size() != ndims) { - x_dims = ExtendDimsWithOnes(x_dims, ndims); - } else if (y_dims.size() != ndims) { - y_dims = ExtendDimsWithOnes(y_dims, ndims); - } - - // in broadcasting scenario new memory is required because - // reduce sum must be calculated upon broadcasted dims - Tensor dx_tmp, dy_tmp; - - std::vector dx_bd_dims(x_dims); - std::vector dy_bd_dims(y_dims); - - CalculateGradMatrixDims( - ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims); - - if (trans_x && trans_y) { - ExecuteMatMulV2( - ctx, onednn_engine, y, y_dims, true, dout, dout_dims, true, &dx_tmp); - ExecuteMatMulV2( - ctx, onednn_engine, dout, dout_dims, true, x, x_dims, true, &dy_tmp); - } else if (trans_x) { - ExecuteMatMulV2( - ctx, onednn_engine, y, y_dims, false, dout, dout_dims, true, &dx_tmp); - ExecuteMatMulV2(ctx, - onednn_engine, - x, - x_dims, - false, - dout, - dout_dims, - false, - &dy_tmp); - } else if (trans_y) { - ExecuteMatMulV2(ctx, - onednn_engine, - dout, - dout_dims, - false, - y, - y_dims, - false, - &dx_tmp); - ExecuteMatMulV2( - ctx, onednn_engine, dout, dout_dims, true, x, x_dims, false, &dy_tmp); - } else { - ExecuteMatMulV2( - ctx, onednn_engine, dout, dout_dims, false, y, y_dims, true, &dx_tmp); - ExecuteMatMulV2( - ctx, onednn_engine, x, x_dims, true, dout, dout_dims, false, &dy_tmp); - } - - if (x_dims != dx_bd_dims) { - ReduceSumForMatmulGradOutput(ctx, - dev_ctx, - onednn_engine, - &dx_tmp, - dx, - x_dims, - vectorize(x->dims())); - } else { - *dx = std::move(dx_tmp); - } - if (y_dims != dy_bd_dims) { - ReduceSumForMatmulGradOutput(ctx, - dev_ctx, - onednn_engine, - &dy_tmp, - dy, - y_dims, - vectorize(y->dims())); - } else { - *dy = std::move(dy_tmp); - } - - dx->Resize(x->dims()); - dy->Resize(y->dims()); - } - - private: - void CalculateGradMatrixDims(const ExecutionContext &ctx, - Tensor *dx_tmp, - Tensor *dy_tmp, - const std::vector &dx_dims, - const std::vector &dy_dims, - std::vector *dx_bd_dims, - std::vector *dy_bd_dims) const { - for (size_t i = 0; i < dx_dims.size() - 2; ++i) { - if (dx_dims[i] != dy_dims[i]) { - if (dx_dims[i] == 1) { - (*dx_bd_dims)[i] = dy_dims[i]; - } else { - (*dy_bd_dims)[i] = dx_dims[i]; - } - } - } - - dx_tmp->Resize(phi::make_ddim((*dx_bd_dims))); - dx_tmp->mutable_data(ctx.GetPlace()); - dy_tmp->Resize(phi::make_ddim((*dy_bd_dims))); - dy_tmp->mutable_data(ctx.GetPlace()); - } - - void ReduceSumForMatmulGradOutput( - const ExecutionContext &ctx, - const MKLDNNDeviceContext &dev_ctx, - const dnnl::engine onednn_engine, - const Tensor *dx_tmp, - Tensor *dx, - const std::vector &dx_dims, - const std::vector &squeezed_dims) const { - phi::funcs::ReductionOneDNNHandler handler( - dnnl::algorithm::reduction_sum, - 0.0f, - 0.0f, - onednn_engine, - ctx.GetPlace(), - dx_tmp, - dx, - dx_dims); - - auto src_memory_p = handler.AcquireSrcMemory(dx_tmp); - auto dst_memory_p = handler.AcquireDstMemory(dx); - - std::unordered_map reduction_args = { - {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; - - auto &astream = MKLDNNDeviceContext::tls().get_stream(); - auto reduction_p = handler.AcquireForwardPrimitive(); - - reduction_p->execute(astream, reduction_args); - astream.wait(); - - dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims)); - } - - std::vector ExtendDimsWithOnes(const std::vector &dims, - int new_size) const { - 
std::vector new_dims(new_size, 1); - for (size_t i = 0; i < dims.size(); ++i) { - new_dims[new_size - dims.size() + i] = dims[i]; - } - - return new_dims; - } - - private: - MatMulGradMKLDNNKernel matmul_v1_grad_mkldnn_kernel; -}; } // anonymous namespace REGISTER_OP_KERNEL(matmul, @@ -923,9 +717,3 @@ REGISTER_OP_KERNEL(matmul_v2, MatMulV2MKLDNNKernel, MatMulV2MKLDNNKernel, MatMulV2MKLDNNKernel); - -REGISTER_OP_KERNEL(matmul_v2_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - MatMulV2GradMKLDNNKernel, - MatMulV2GradMKLDNNKernel); diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 7395138bfd63b30a54e985646ae6f9f236e30e93..bd3d3f30f7a447090a74ee1fddef77bca65e3029 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -29,6 +29,7 @@ limitations under the License. */ #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/data_layout_transform.h" #include "paddle/phi/kernels/funcs/pooling.h" @@ -1331,14 +1332,13 @@ class BatchNormOneDNNHandler diff_scaleshift_data); } - std::shared_ptr AcquireMeanMemory( - const phi::DenseTensor* mean) { + std::shared_ptr AcquireMeanMemory(const DenseTensor* mean) { const T* mean_data = mean->data(); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(), to_void_cast(mean_data)); } - std::shared_ptr AcquireMeanMemory(phi::DenseTensor* mean) { + std::shared_ptr AcquireMeanMemory(DenseTensor* mean) { T* mean_data = mean->mutable_data(this->place_, this->fwd_pd_->mean_desc().get_size()); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(), @@ -1346,14 +1346,13 @@ class BatchNormOneDNNHandler } std::shared_ptr AcquireVarianceMemory( - const phi::DenseTensor* variance) { + const DenseTensor* variance) { const T* variance_data = variance->data(); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), to_void_cast(variance_data)); } - std::shared_ptr AcquireVarianceMemory( - phi::DenseTensor* variance) { + std::shared_ptr AcquireVarianceMemory(DenseTensor* variance) { T* variance_data = variance->mutable_data( this->place_, this->fwd_pd_->variance_desc().get_size()); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), @@ -1630,5 +1629,346 @@ class PoolingOneDNNHandler } }; +static DDim RowMatrixDimsFromVector(const DDim& x_dim) { + return x_dim.size() > 1 ? x_dim : make_ddim({1, x_dim[0]}); +} + +static DDim ColumnMatrixDimsFromVector(const DDim& y_dim) { + return y_dim.size() > 1 ? y_dim : make_ddim({y_dim[0], 1}); +} + +static std::vector TransposeAxis(const std::vector& x, + const std::vector& axis) { + size_t in_rank = x.size(); + size_t axis_size = axis.size(); + + auto axis_set = std::set(axis.begin(), axis.end()); + PADDLE_ENFORCE_EQ(axis_set.size(), + axis_size, + paddle::platform::errors::InvalidArgument( + "In an axis array, elements must be unique.")); + + PADDLE_ENFORCE_EQ(in_rank, + axis_size, + paddle::platform::errors::InvalidArgument( + "The input dimension's size " + "should be equal to the axis's size. 
" + "But received dimension is %d, " + "axis's size is %d", + in_rank, + axis_size)); + + PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), + axis_size, + paddle::platform::errors::InvalidArgument( + "Axis values must be ranging from 0 to (dims - 1).")); + + std::vector new_x(x.size()); + for (size_t i = 0; i < x.size(); i++) { + new_x[i] = x[axis[i]]; + } + return new_x; +} + +static std::vector GetInputStrides(const OneDNNContext& dev_ctx, + const DDim& input_dims, + const std::string input_name, + const bool transpose_input) { + auto new_dims = input_dims; + auto shape = + dev_ctx.HasDnnAttr("fused_reshape_" + input_name) + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("fused_reshape_" + input_name)) + : std::vector(); + auto axis = dev_ctx.HasDnnAttr("fused_transpose_" + input_name) + ? PADDLE_GET_CONST( + std::vector, + dev_ctx.GetDnnAttr("fused_transpose_" + input_name)) + : std::vector(); + + if (!shape.empty() && !axis.empty()) { + new_dims = input_dims.reshape(shape).transpose(axis); + } + + auto& MatrixDimsFromVector = + input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; + phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor( + MatrixDimsFromVector(new_dims), 0, transpose_input); + + std::vector strides; + if (!shape.empty()) { + auto shape2 = input_dims.reshape(shape); + strides.push_back(1); + for (auto i = shape2.size() - 1; i > 0; --i) { + strides.insert(strides.begin(), + strides.front() * static_cast(shape2[i])); + } + strides = TransposeAxis(strides, axis); + if (shape.size() == 2) + strides.insert(strides.begin(), + static_cast(shape[0] * shape[1])); + mat_dim.stride_ = strides[0]; + if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin())); + } + return strides; +} + +static bool IsOutputFused(const OneDNNContext& dev_ctx) { + const auto shape = + dev_ctx.HasDnnAttr("fused_reshape_Out") + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("fused_reshape_Out")) + : std::vector(); + const auto axis = + dev_ctx.HasDnnAttr("fused_transpose_Out") + ? 
PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("fused_transpose_Out")) + : std::vector(); + return !shape.empty() && !axis.empty(); +} + +template +class MatmulOneDNNHandler + : public phi::funcs::OneDNNHandlerNoCachingT { + public: + MatmulOneDNNHandler(const OneDNNContext& dev_ctx, + const std::vector& x_org_dims, + const std::vector& y_org_dims, + bool trans_x, + bool trans_y, + const std::vector& x_strides_override, + const std::vector& y_strides_override, + bool is_output_fused) + : phi::funcs::OneDNNHandlerNoCachingT( + dev_ctx.GetEngine(), dev_ctx.GetPlace()) { + // M X K * K X N + std::vector x_dims(x_org_dims); + std::vector y_dims(y_org_dims); + + const int MB_idx = x_dims.size() - 3; + const int H_idx = x_dims.size() - 2; + const int W_idx = x_dims.size() - 1; + + if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]); + if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]); + + const memory::dim M = x_dims[H_idx]; + const memory::dim K = x_dims[W_idx]; + const memory::dim N = y_dims[W_idx]; + + std::vector x_strides(x_dims.size() - 3, 1); + std::vector y_strides(x_dims.size() - 3, 1); + std::vector out_strides(x_dims.size() - 3, 1); + std::vector out_ddims(x_dims.size() - 3, 1); + + x_strides.reserve(x_dims.size()); + y_strides.reserve(x_dims.size()); + out_strides.reserve(x_dims.size()); + + if (!x_strides_override.empty()) { + x_strides = x_strides_override; + } else { + if (!trans_x) { + x_strides.insert(x_strides.end(), {M * K, K, 1}); + } else { + x_strides.insert(x_strides.end(), {M * K, 1, M}); + } + } + + if (!y_strides_override.empty()) { + y_strides = y_strides_override; + } else { + if (!trans_y) { + y_strides.insert(y_strides.end(), {N * K, N, 1}); + } else { + y_strides.insert(y_strides.end(), {N * K, 1, K}); + } + } + + out_strides.insert(out_strides.end(), {M * N, N, 1}); + out_ddims.insert(out_ddims.end(), + {std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N}); + + for (int i = x_dims.size() - 4; i >= 0; --i) { + out_ddims[i] = std::max(x_dims[i], y_dims[i]); + if (x_strides_override.empty()) { + x_strides[i] = x_dims[i + 1] * x_strides[i + 1]; + } + if (y_strides_override.empty()) { + y_strides[i] = y_dims[i + 1] * y_strides[i + 1]; + } + out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; + } + + // TODO(jczaja): Why not for int8?? + if (!is_int8() && is_output_fused) { + out_strides = FakeTransposeStrides(out_ddims); + } + + auto x_md = memory::desc(x_dims, OneDNNGetDataType(), x_strides); + auto y_md = memory::desc(y_dims, OneDNNGetDataType(), y_strides); + auto out_md = memory::desc(out_ddims, OneDNNGetDataType(), out_strides); + + const auto matmul_attrs = CreateMatmulAttrs(dev_ctx); + + this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md); + } + + float ComputeOutputScale(const OneDNNContext& dev_ctx) { + float alpha = dev_ctx.HasDnnAttr("alpha") + ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("alpha")) + : 1.0f; + + if (dev_ctx.HasDnnAttr("Scale_x") && dev_ctx.HasDnnAttr("Scale_y") && + dev_ctx.HasDnnAttr("Scale_out")) { + float scale_x = PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_x")); + float scale_y = PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_y")); + bool force_fp32_out = + dev_ctx.HasDnnAttr("force_fp32_output") + ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) + : false; + float scale_out = + force_fp32_out + ? 
1.f + : PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_out")); + alpha *= scale_out / (scale_x * scale_y); + } + return alpha; + } + + dnnl::primitive_attr CreateMatmulAttrs(const OneDNNContext& dev_ctx) { + dnnl::primitive_attr matmul_attrs; + dnnl::post_ops post_operations; + + float scale_out = ComputeOutputScale(dev_ctx); + if (scale_out != 1.0f) { + matmul_attrs.set_output_scales(0, {scale_out}); + } + + if (dev_ctx.HasDnnInput("ResidualData")) { + auto* residual_data = dev_ctx.GetDnnInput("ResidualData"); + auto residual_data_tz = vectorize(residual_data->dims()); + auto residual_data_md = memory::desc(residual_data_tz, + OneDNNGetDataType(), + dnnl::memory::format_tag::any); + post_operations.append_binary(dnnl::algorithm::binary_add, + residual_data_md); + if (dev_ctx.HasDnnAttr("Scale_in_eltwise")) { + float scale_in_eltwise = + PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_in_eltwise")); + float sum_scale = scale_out / scale_in_eltwise; + post_operations.append_sum(sum_scale); + } + } + + AppendActivation(dev_ctx, post_operations); + + if (dev_ctx.HasDnnAttr("fused_output_scale")) { + float scale_alpha = + PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale")); + post_operations.append_eltwise( + 1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f); + } + + matmul_attrs.set_post_ops(post_operations); + return matmul_attrs; + } + + std::vector FakeTransposeStrides( + const std::vector& matmul_out_dims) const { + // fuse matmul_v2 + transpose + reshape guarantees that output is 4D and + // transpose axis are: {0, 2, 1, 3} + std::vector transpose_axis = {0, 2, 1, 3}; + std::vector fake_strides(transpose_axis.size()); + int ndims = static_cast(transpose_axis.size()); + + int total_stride = 1; + + for (int i = ndims - 1; i >= 0; --i) { + fake_strides[transpose_axis[i]] = total_stride; + total_stride *= matmul_out_dims[transpose_axis[i]]; + } + + return fake_strides; + } + + std::shared_ptr AcquireWeightsMemory(const DenseTensor* input) { + const YT* input_data = input->data(); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), + to_void_cast(input_data)); + } + + std::shared_ptr AcquireDstMemory(const OneDNNContext& dev_ctx, + DenseTensor* output) { + // We cannot use base AcquireDstMemory as it makes an allocation request + // base on DST memory primitive size. This is fine in general, but in MatMul + // we have primitive that covers only one batch of Data and then shift + // pointer for every new batch. Hence DenseTensor size is bigger that + // dst memory primitive size. So would we request less memory that is there + // and it triggers an assertion. 
So as there is no 'any' format here we can + // leave default size of DenseTensor as computed in ComputeInferShape + OT* ptr = dev_ctx.template Alloc(output); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); + } +}; + +template +void ExecuteMatmul(const OneDNNContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const std::vector& x_dims, + const std::vector& y_dims, + bool trans_x, + bool trans_y, + DenseTensor* out) { + auto x_strides_override = GetInputStrides(dev_ctx, x.dims(), "X", trans_x); + auto y_strides_override = GetInputStrides(dev_ctx, y.dims(), "Y", trans_y); + MatmulOneDNNHandler handler(dev_ctx, + x_dims, + y_dims, + trans_x, + trans_y, + x_strides_override, + y_strides_override, + IsOutputFused(dev_ctx)); + + const auto src_memory_p = handler.AcquireSrcMemory(&x); + const auto weights_memory_p = handler.AcquireWeightsMemory(&y); + const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out); + + auto matmul_p = handler.AcquireForwardPrimitive(); + + std::unordered_map matmul_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + if (dev_ctx.HasDnnInput("ResidualData")) { + auto* residual_data = dev_ctx.GetDnnInput("ResidualData"); + const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data); + matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, + *residual_data_memory_p}); + } + + auto& astream = OneDNNContext::tls().get_stream(); + matmul_p->execute(astream, matmul_args); + astream.wait(); + + // TODO(jczaja): Explain why int8 format of dst is ABCD and do not need + // permute + if (IsOutputFused(dev_ctx) && !is_int8()) { + const auto axis = + dev_ctx.HasDnnAttr("fused_transpose_Out") + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("fused_transpose_Out")) + : std::vector(); + auto permuted_md = dst_memory_p->get_desc().permute_axes(axis); + out->set_mem_desc(permuted_md.reshape(vectorize(out->dims()))); + } else { + out->set_mem_desc( + dst_memory_p->get_desc().reshape(vectorize(out->dims()))); + } +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/onednn/matmul_grad_kernel.cc b/paddle/phi/kernels/onednn/matmul_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..47807f156b18f5dc6db66cb3543c5983e84697c2 --- /dev/null +++ b/paddle/phi/kernels/onednn/matmul_grad_kernel.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/phi/kernels/matmul_grad_kernel.h"
+
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
+                                        int new_size) {
+  std::vector<int64_t> new_dims(new_size, 1);
+  for (size_t i = 0; i < dims.size(); ++i) {
+    new_dims[new_size - dims.size() + i] = dims[i];
+  }
+
+  return new_dims;
+}
+
+template <typename T>
+void CalculateGradMatrixDims(const OneDNNContext &dev_ctx,
+                             DenseTensor *dx_tmp,
+                             DenseTensor *dy_tmp,
+                             const std::vector<int64_t> &dx_dims,
+                             const std::vector<int64_t> &dy_dims,
+                             std::vector<int64_t> *dx_bd_dims,
+                             std::vector<int64_t> *dy_bd_dims) {
+  for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
+    if (dx_dims[i] != dy_dims[i]) {
+      if (dx_dims[i] == 1) {
+        (*dx_bd_dims)[i] = dy_dims[i];
+      } else {
+        (*dy_bd_dims)[i] = dx_dims[i];
+      }
+    }
+  }
+
+  dx_tmp->Resize(make_ddim((*dx_bd_dims)));
+  dev_ctx.template Alloc<T>(dx_tmp);
+  dy_tmp->Resize(make_ddim((*dy_bd_dims)));
+  dev_ctx.template Alloc<T>(dy_tmp);
+}
+
+template <typename T>
+void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
+                                  const DenseTensor *dx_tmp,
+                                  DenseTensor *dx,
+                                  const std::vector<int64_t> &dx_dims,
+                                  const std::vector<int64_t> &squeezed_dims) {
+  funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum,
+                                           0.0f,
+                                           0.0f,
+                                           dev_ctx.GetEngine(),
+                                           dev_ctx.GetPlace(),
+                                           dx_tmp,
+                                           dx,
+                                           dx_dims);
+
+  auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
+  auto dst_memory_p = handler.AcquireDstMemory(dx);
+
+  std::unordered_map<int, dnnl::memory> reduction_args = {
+      {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
+
+  auto &astream = OneDNNContext::tls().get_stream();
+  auto reduction_p = handler.AcquireForwardPrimitive();
+
+  reduction_p->execute(astream, reduction_args);
+  astream.wait();
+
+  dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
+}
+
+template <typename T, typename Context>
+void MatmulGradKernel(const Context &dev_ctx,
+                      const DenseTensor &x,
+                      const DenseTensor &y,
+                      const DenseTensor &dout,
+                      bool transpose_x,
+                      bool transpose_y,
+                      DenseTensor *dx,
+                      DenseTensor *dy) {
+  auto x_dims = vectorize(x.dims());
+  auto y_dims = vectorize(y.dims());
+  auto dout_dims = vectorize(dout.dims());
+
+  size_t ndims = std::max(x_dims.size(), y_dims.size());
+  ndims = std::max(ndims, static_cast<size_t>(3));
+
+  if (x_dims.size() != ndims) {
+    x_dims = ExtendDimsWithOnes(x_dims, ndims);
+  } else if (y_dims.size() != ndims) {
+    y_dims = ExtendDimsWithOnes(y_dims, ndims);
+  }
+
+  // in broadcasting scenario new memory is required because
+  // reduce sum must be calculated upon broadcasted dims
+  DenseTensor dx_tmp, dy_tmp;
+  std::vector<int64_t> dx_bd_dims(x_dims);
+  std::vector<int64_t> dy_bd_dims(y_dims);
+
+  CalculateGradMatrixDims<T>(
+      dev_ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
+
+  if (transpose_x && transpose_y) {
+    funcs::ExecuteMatmul<T>(
+        dev_ctx, y, dout, y_dims, dout_dims, true, true, &dx_tmp);
+    funcs::ExecuteMatmul<T>(
+        dev_ctx, dout, x, dout_dims, x_dims, true, true, &dy_tmp);
+  } else if (transpose_x) {
+    funcs::ExecuteMatmul<T>(
+        dev_ctx, y, dout, y_dims, dout_dims, false, true, &dx_tmp);
+    funcs::ExecuteMatmul<T>(
+        dev_ctx, x, dout, x_dims, dout_dims, false, false, &dy_tmp);
+  } else if (transpose_y) {
+    funcs::ExecuteMatmul<T>(
+        dev_ctx, dout, y, dout_dims, y_dims, false, false, &dx_tmp);
+    funcs::ExecuteMatmul<T>(
+        dev_ctx, dout, x, dout_dims, x_dims, true, false, &dy_tmp);
+  } else {
+    funcs::ExecuteMatmul<T>(
+        dev_ctx, dout, y, dout_dims, y_dims, false, true, &dx_tmp);
+    funcs::ExecuteMatmul<T>(
+        dev_ctx, x, dout, x_dims, dout_dims, true, false, &dy_tmp);
+  }
+
+  if (x_dims != dx_bd_dims) {
+    ReduceSumForMatmulGradOutput<T>(
+        dev_ctx, &dx_tmp, dx, x_dims, vectorize(x.dims()));
+  } else {
+    *dx = std::move(dx_tmp);
+  }
+  if (y_dims != dy_bd_dims) {
+    ReduceSumForMatmulGradOutput<T>(
+        dev_ctx, &dy_tmp, dy, y_dims, vectorize(y.dims()));
+  } else {
+    *dy = std::move(dy_tmp);
+  }
+
+  dx->Resize(x.dims());
+  dy->Resize(y.dims());
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(matmul_grad,
+                   OneDNN,
+                   ONEDNN,
+                   phi::MatmulGradKernel,
+                   float,
+                   phi::dtype::bfloat16) {}
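
Note on the broadcasting path of the new phi kernel: it first pads both operand shapes to a common rank with leading ones (ExtendDimsWithOnes), computes dX and dY at the broadcasted batch shape, and then reduce-sums every axis that was broadcast back to the original input shape. Below is a minimal standalone sketch of just that shape logic (plain C++, no Paddle or oneDNN headers; the example shapes and the print helper are invented for illustration and are not part of this patch), mirroring ExtendDimsWithOnes and the loop in CalculateGradMatrixDims:

// Standalone illustration (no Paddle / oneDNN dependencies) of the shape
// handling in MatmulGradKernel: pad both shapes to a common rank, broadcast
// the batch dims, and note which axes dX/dY must later be reduce-summed over.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Same logic as ExtendDimsWithOnes in matmul_grad_kernel.cc: right-align the
// original dims and fill the leading positions with 1.
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
                                        size_t new_size) {
  std::vector<int64_t> new_dims(new_size, 1);
  for (size_t i = 0; i < dims.size(); ++i) {
    new_dims[new_size - dims.size() + i] = dims[i];
  }
  return new_dims;
}

int main() {
  // Hypothetical inputs: x is [2, 1, 4, 5], y is [3, 5, 6].
  std::vector<int64_t> x_dims = {2, 1, 4, 5};
  std::vector<int64_t> y_dims = {3, 5, 6};

  size_t ndims = std::max(x_dims.size(), y_dims.size());
  ndims = std::max<size_t>(ndims, 3);
  x_dims = ExtendDimsWithOnes(x_dims, ndims);  // {2, 1, 4, 5}
  y_dims = ExtendDimsWithOnes(y_dims, ndims);  // {1, 3, 5, 6}

  // Broadcast every batch (non-matrix) dim, as CalculateGradMatrixDims does.
  // Wherever dx_bd_dims / dy_bd_dims end up larger than the padded input
  // dims, the temporary gradient is reduce-summed back to the input's shape.
  std::vector<int64_t> dx_bd_dims(x_dims);
  std::vector<int64_t> dy_bd_dims(y_dims);
  for (size_t i = 0; i < ndims - 2; ++i) {
    if (x_dims[i] != y_dims[i]) {
      if (x_dims[i] == 1) {
        dx_bd_dims[i] = y_dims[i];
      } else {
        dy_bd_dims[i] = x_dims[i];
      }
    }
  }

  auto print = [](const char *name, const std::vector<int64_t> &v) {
    std::cout << name << ":";
    for (int64_t d : v) std::cout << " " << d;
    std::cout << "\n";
  };
  print("dx_bd_dims", dx_bd_dims);  // 2 3 4 5 -> reduce dX over axis 1
  print("dy_bd_dims", dy_bd_dims);  // 2 3 5 6 -> reduce dY over axis 0
}

Built and run, this prints dx_bd_dims: 2 3 4 5 and dy_bd_dims: 2 3 5 6, i.e. the temporary dX has to be reduce-summed over axis 1 and the temporary dY over axis 0 before the final Resize back to x.dims() and y.dims(); when no dim differs, the kernel skips ReduceSumForMatmulGradOutput and moves the temporaries into dx/dy directly.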