From f154d5860f37dc3201dcd3c0d3a59a9771154999 Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 11 Mar 2020 15:46:52 +0800 Subject: [PATCH] Speed up the matmul op, use the gemm replace the batch gemm (#22926) In the matmul op, we use gemm in place of batch gemm, which speeds up the matmul op --- paddle/fluid/operators/matmul_op.cc | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 96346e6d1cf..af9e0644c92 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -72,8 +72,21 @@ class MatMulKernel : public framework::OpKernel { ColumnMatrixFromVector(y.dims()), 0, context.Attr("transpose_Y")); auto scale = static_cast(context.Attr("alpha")); + int head_number = 1; +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) + head_number = context.Attr("head_number"); +#endif + + const auto &x_dims = x.dims(); + const auto &y_dims = y.dims(); + if (head_number <= 1 && x_dims.size() == 3 && y_dims.size() <= 2) { + // transpose_X must be false; if it is true, the transpose costs too much time + if (!context.Attr("transpose_X")) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) - int head_number = context.Attr("head_number"); bool split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_); if (head_number > 1) { @@ -210,6 +223,19 @@ class MatMulGradKernel : public framework::OpKernel { auto blas = math::GetBlas(context); auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + + int head_number = 1; +#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) + head_number = context.Attr("head_number"); +#endif + + if (head_number <= 1 && a.dims().size() == 3 && b.dims().size() <= 2) { + // transpose_X must be false; if it is true, the 
transpose costs too much time + if (!trans_a) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(context.Attr("alpha")), out, T(0)); } -- GitLab