Unverified commit 444a7358, authored by Linjie Chen, committed by GitHub

Optimize Matmul_v2 (#37037)

Optimize dot product of Matmul_v2 
Parent 6b0cc2b1
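In the 1-D × 1-D case, MatMulFunction reduces to a dot product. The change below drops the Eigen elementwise-multiply-and-sum path and instead issues a single GEMM call that treats the two vectors as 1×M matrices, computing Out = flag * Out + y (1×M) · xᵀ (M×1); the old `if (flag)` accumulation branch collapses into GEMM's beta argument. A standalone sketch of the same trick follows the diff.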
@@ -17,9 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/complex_functors.h"
-#include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/functions/eigen/common.h"

 namespace pten {
 namespace math {
@@ -105,34 +103,34 @@ void MatMulFunction(const DeviceContext& dev_ctx,
   const T* x_data = X.data<T>();
   const T* y_data = Y.data<T>();
+  auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);
   if (x_ndim == 1 && y_ndim == 1) {
+    const int M = X.numel();
+    const int N = Y.numel();
     PADDLE_ENFORCE_EQ(
-        X.numel(),
-        Y.numel(),
+        M,
+        N,
         paddle::platform::errors::InvalidArgument(
             "X's numbers must be equal to Y's numbers,"
             "when X/Y's dims =1. But received X has [%d] elements,"
             "received Y has [%d] elements",
-            X.numel(),
-            Y.numel()));
+            M,
+            N));
     VLOG(3) << "MatMul's case 1";
     Out->Resize({1});
     Out->mutable_data<T>();
-    auto out_eigen = EigenScalar<T>::From(*Out);
-    auto x_eigen = EigenVector<T>::Flatten(X);
-    auto y_eigen = EigenVector<T>::Flatten(Y);
-    auto& dev = *dev_ctx.eigen_device();
-    if (flag) {
-      out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen;
-    } else {
-      out_eigen.device(dev) = (x_eigen * y_eigen).sum();
-    }
+    blas.GEMM(CblasNoTrans,
+              CblasTrans,
+              1,
+              1,
+              M,
+              static_cast<T>(1),
+              y_data,
+              x_data,
+              static_cast<T>(flag),
+              Out->mutable_data<T>());
     return;
   }
-  auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);
   if (x_ndim == 1) {
     const int N = X.numel();
     if (trans_y) {
......
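For reference, here is a minimal standalone sketch of the dot-product-as-GEMM trick using the plain CBLAS interface rather than Paddle's blas wrapper (the example names, test values, and the assumption that a CBLAS implementation such as OpenBLAS is linked are mine, not part of the commit). The argument roles correspond to the blas.GEMM call in the diff: M = N = 1, K = vector length, beta = flag.

// Hypothetical standalone example (not part of the commit): computes
// dot(x, y) via a 1x1 GEMM, as the new Matmul_v2 code path does.
#include <cblas.h>

#include <cstdio>

int main() {
  const int K = 4;  // vector length (called M in the diff)
  const float x[K] = {1.f, 2.f, 3.f, 4.f};
  const float y[K] = {5.f, 6.f, 7.f, 8.f};
  float out = 0.f;         // the 1x1 output "matrix"
  const float beta = 0.f;  // plays the role of static_cast<T>(flag)

  // out = 1 * (y as 1xK) * (x as 1xK)^T + beta * out  ==  dot(x, y)
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              /*M=*/1, /*N=*/1, /*K=*/K,
              /*alpha=*/1.f, /*A=*/y, /*lda=*/K,
              /*B=*/x, /*ldb=*/K,
              /*beta=*/beta, /*C=*/&out, /*ldc=*/1);

  std::printf("dot = %g\n", out);  // prints 70 for the values above
  return 0;
}

With beta = 0 the previous contents of out are discarded; with beta = 1 the dot product is accumulated into out, which is exactly what the removed `if (flag)` branch expressed through two separate Eigen assignments.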