未验证 提交 f154d586 编写于 作者: W wawltor 提交者: GitHub

Speed up the matmul op: use a single gemm to replace the batch gemm (#22926)

In the matmul op, we use a single gemm call in place of the batch gemm (when the left operand is 3-D, the right operand is at most 2-D, and transpose_X is false), which speeds up the matmul op.
上级 8f541027
...@@ -72,8 +72,21 @@ class MatMulKernel : public framework::OpKernel<T> { ...@@ -72,8 +72,21 @@ class MatMulKernel : public framework::OpKernel<T> {
ColumnMatrixFromVector(y.dims()), 0, context.Attr<bool>("transpose_Y")); ColumnMatrixFromVector(y.dims()), 0, context.Attr<bool>("transpose_Y"));
auto scale = static_cast<T>(context.Attr<float>("alpha")); auto scale = static_cast<T>(context.Attr<float>("alpha"));
int head_number = 1;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
head_number = context.Attr<int>("head_number");
#endif
const auto &x_dims = x.dims();
const auto &y_dims = y.dims();
if (head_number <= 1 && x_dims.size() == 3 && y_dims.size() <= 2) {
// the transpose_X must be false, if is true, the transpose cost much time
if (!context.Attr<bool>("transpose_X")) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
int head_number = context.Attr<int>("head_number");
bool split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_); bool split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_);
if (head_number > 1) { if (head_number > 1) {
...@@ -210,6 +223,19 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -210,6 +223,19 @@ class MatMulGradKernel : public framework::OpKernel<T> {
auto blas = math::GetBlas<DeviceContext, T>(context); auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
int head_number = 1;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
head_number = context.Attr<int>("head_number");
#endif
if (head_number <= 1 && a.dims().size() == 3 && b.dims().size() <= 2) {
// the transpose_X must be false, if is true, the transpose cost much time
if (!trans_a) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
blas.MatMul(a, mat_dim_a, b, mat_dim_b, blas.MatMul(a, mat_dim_a, b, mat_dim_b,
static_cast<T>(context.Attr<float>("alpha")), out, T(0)); static_cast<T>(context.Attr<float>("alpha")), out, T(0));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册