diff --git a/paddle/pten/kernels/functions/math/matmul_func.h b/paddle/pten/kernels/functions/math/matmul_func.h index b5ddd26a95576f0323e9cd62e178d295e5a3ced0..8aa8750aba4180b401fdbf639e956572dc25de17 100644 --- a/paddle/pten/kernels/functions/math/matmul_func.h +++ b/paddle/pten/kernels/functions/math/matmul_func.h @@ -17,9 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/complex_functors.h" -#include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/pten/core/dense_tensor.h" -#include "paddle/pten/kernels/functions/eigen/common.h" namespace pten { namespace math { @@ -105,34 +103,34 @@ void MatMulFunction(const DeviceContext& dev_ctx, const T* x_data = X.data(); const T* y_data = Y.data(); + auto blas = paddle::operators::math::GetBlas(dev_ctx); + if (x_ndim == 1 && y_ndim == 1) { + const int M = X.numel(); + const int N = Y.numel(); PADDLE_ENFORCE_EQ( - X.numel(), - Y.numel(), + M, + N, paddle::platform::errors::InvalidArgument( "X's numbers must be equal to Y's numbers," "when X/Y's dims =1. But received X has [%d] elements," "received Y has [%d] elements", - X.numel(), - Y.numel())); + M, + N)); VLOG(3) << "MatMul's case 1"; - Out->Resize({1}); - Out->mutable_data(); - auto out_eigen = EigenScalar::From(*Out); - auto x_eigen = EigenVector::Flatten(X); - auto y_eigen = EigenVector::Flatten(Y); - - auto& dev = *dev_ctx.eigen_device(); - if (flag) { - out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen; - } else { - out_eigen.device(dev) = (x_eigen * y_eigen).sum(); - } + blas.GEMM(CblasNoTrans, + CblasTrans, + 1, + 1, + M, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); return; } - auto blas = paddle::operators::math::GetBlas(dev_ctx); - if (x_ndim == 1) { const int N = X.numel(); if (trans_y) {