diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index 96346e6d1cfef5d3afcc626a824f4e14977e55ca..af9e0644c9269a234a6a01153ef1b194ca63cbb0 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -72,8 +72,21 @@ class MatMulKernel : public framework::OpKernel<T> {
         ColumnMatrixFromVector(y.dims()), 0, context.Attr<bool>("transpose_Y"));
     auto scale = static_cast<T>(context.Attr<float>("alpha"));
 
+    int head_number = 1;
+#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
+    head_number = context.Attr<int>("head_number");
+#endif
+
+    const auto &x_dims = x.dims();
+    const auto &y_dims = y.dims();
+    if (head_number <= 1 && x_dims.size() == 3 && y_dims.size() <= 2) {
+      // transpose_X must be false; if it is true, the transpose costs too much time
+      if (!context.Attr<bool>("transpose_X")) {
+        mat_dim_a.height_ *= mat_dim_a.batch_size_;
+        mat_dim_a.batch_size_ = 0;
+      }
+    }
 #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
-    int head_number = context.Attr<int>("head_number");
     bool split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_);
 
     if (head_number > 1) {
@@ -210,6 +223,19 @@ class MatMulGradKernel : public framework::OpKernel<T> {
     auto blas = math::GetBlas<DeviceContext, T>(context);
     auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
     auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
+
+    int head_number = 1;
+#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
+    head_number = context.Attr<int>("head_number");
+#endif
+
+    if (head_number <= 1 && a.dims().size() == 3 && b.dims().size() <= 2) {
+      // transpose_X must be false; if it is true, the transpose costs too much time
+      if (!trans_a) {
+        mat_dim_a.height_ *= mat_dim_a.batch_size_;
+        mat_dim_a.batch_size_ = 0;
+      }
+    }
     blas.MatMul(a, mat_dim_a, b, mat_dim_b,
                 static_cast<T>(context.Attr<float>("alpha")), out, T(0));
   }
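
Note on the change (reviewer sketch, not part of the patch): when y is at most 2-D, every batch slice of the 3-D x multiplies the same y, so folding the batch dimension into the row dimension (mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0) replaces `batch` small GEMMs with one large GEMM that produces identical results. The standalone C++ sketch below demonstrates the equivalence with a naive matmul; naive_matmul and the sizes are hypothetical, and no Paddle APIs are used.

// Standalone sketch: folding [batch, M, K] x [K, N] into [batch*M, K] x [K, N].
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

// Plain row-major GEMM: C[m x n] = A[m x k] * B[k x n].
static void naive_matmul(const std::vector<float> &a, const std::vector<float> &b,
                         std::vector<float> *c, int m, int k, int n) {
  c->assign(static_cast<size_t>(m) * n, 0.f);
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p)
      for (int j = 0; j < n; ++j)
        (*c)[i * n + j] += a[i * k + p] * b[p * n + j];
}

int main() {
  const int batch = 2, M = 3, K = 4, N = 5;
  std::vector<float> x(batch * M * K), y(K * N);
  for (size_t i = 0; i < x.size(); ++i) x[i] = 0.1f * static_cast<float>(i);
  for (size_t i = 0; i < y.size(); ++i) y[i] = 0.2f * static_cast<float>(i);

  // Batched path: one GEMM per batch slice (what a nonzero batch_size_ implies).
  std::vector<float> batched(batch * M * N);
  for (int b = 0; b < batch; ++b) {
    std::vector<float> slice(x.begin() + b * M * K, x.begin() + (b + 1) * M * K);
    std::vector<float> slice_out;
    naive_matmul(slice, y, &slice_out, M, K, N);
    std::copy(slice_out.begin(), slice_out.end(), batched.begin() + b * M * N);
  }

  // Folded path: treat x as a single [batch*M, K] matrix, mirroring
  // mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0;
  std::vector<float> folded;
  naive_matmul(x, y, &folded, batch * M, K, N);

  for (size_t i = 0; i < folded.size(); ++i) assert(folded[i] == batched[i]);
  std::printf("batched and folded GEMM outputs match\n");
  return 0;
}

The exact-equality assert holds because both paths accumulate the same products in the same order; this also shows why the fold requires transpose_X to be false, since a transposed x would interleave rows across batch slices.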