未验证 提交 444a7358 编写于 作者: L Linjie Chen 提交者: GitHub

Optimize Matmul_v2 (#37037)

Optimize dot product of Matmul_v2 
上级 6b0cc2b1
...@@ -17,9 +17,7 @@ limitations under the License. */ ...@@ -17,9 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/eigen/common.h"
namespace pten { namespace pten {
namespace math { namespace math {
...@@ -105,34 +103,34 @@ void MatMulFunction(const DeviceContext& dev_ctx, ...@@ -105,34 +103,34 @@ void MatMulFunction(const DeviceContext& dev_ctx,
const T* x_data = X.data<T>(); const T* x_data = X.data<T>();
const T* y_data = Y.data<T>(); const T* y_data = Y.data<T>();
auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);
if (x_ndim == 1 && y_ndim == 1) { if (x_ndim == 1 && y_ndim == 1) {
const int M = X.numel();
const int N = Y.numel();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
X.numel(), M,
Y.numel(), N,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"X's numbers must be equal to Y's numbers," "X's numbers must be equal to Y's numbers,"
"when X/Y's dims =1. But received X has [%d] elements," "when X/Y's dims =1. But received X has [%d] elements,"
"received Y has [%d] elements", "received Y has [%d] elements",
X.numel(), M,
Y.numel())); N));
VLOG(3) << "MatMul's case 1"; VLOG(3) << "MatMul's case 1";
Out->Resize({1}); blas.GEMM(CblasNoTrans,
Out->mutable_data<T>(); CblasTrans,
auto out_eigen = EigenScalar<T>::From(*Out); 1,
auto x_eigen = EigenVector<T>::Flatten(X); 1,
auto y_eigen = EigenVector<T>::Flatten(Y); M,
static_cast<T>(1),
auto& dev = *dev_ctx.eigen_device(); y_data,
if (flag) { x_data,
out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen; static_cast<T>(flag),
} else { Out->mutable_data<T>());
out_eigen.device(dev) = (x_eigen * y_eigen).sum();
}
return; return;
} }
auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(dev_ctx);
if (x_ndim == 1) { if (x_ndim == 1) {
const int N = X.numel(); const int N = X.numel();
if (trans_y) { if (trans_y) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册