Unverified commit 212b51ef, authored by cambriconhsq, committed by GitHub

[MLU] optimize matmul_grad_v2 dy (B,M,K)*(K,N) for better performance (#45336)

Parent ac0a2e50
@@ -68,6 +68,37 @@ static void MatMul2D(const framework::ExecutionContext& ctx,
                   GetBasePtr(Out));
 }
 
+template <typename T>
+static void MatMul2DwithReduceBatch(const framework::ExecutionContext& ctx,
+                                    const Tensor& X,
+                                    const Tensor& Y,
+                                    Tensor* Out,
+                                    const bool trans_x,
+                                    const bool trans_y) {
+  if (!Out->initialized()) {
+    Out->mutable_data<T>(ctx.GetPlace());
+  }
+  // reshape to 2D matmul
+  std::vector<int64_t> x_dims = phi::vectorize(X.dims());
+  std::vector<int64_t> y_dims = phi::vectorize(Y.dims());
+  std::vector<int> realx_dims(
+      {static_cast<int>(x_dims[0] * x_dims[1]), static_cast<int>(x_dims[2])});
+  std::vector<int> realy_dims(
+      {static_cast<int>(y_dims[0] * y_dims[1]), static_cast<int>(y_dims[2])});
+  MLUCnnlTensorDesc x_desc(2, realx_dims.data(), ToCnnlDataType<T>());
+  MLUCnnlTensorDesc y_desc(2, realy_dims.data(), ToCnnlDataType<T>());
+  MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnl::Matmul(ctx,
+                  trans_x,
+                  trans_y,
+                  x_desc.get(),
+                  GetBasePtr(&X),
+                  y_desc.get(),
+                  GetBasePtr(&Y),
+                  out_desc.get(),
+                  GetBasePtr(Out));
+}
+
 template <typename T>
 static void MatMulND(const framework::ExecutionContext& ctx,
                      const Tensor& X,
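The helper above turns the batched dY gradient into a single 2D GEMM: for X of shape [B, M, K] and dOut of shape [B, M, N], dY = sum_b X_b^T @ dOut_b is exactly reshape(X, [B*M, K])^T @ reshape(dOut, [B*M, N]), so the per-batch products and the reduction over B collapse into one MLUCnnl::Matmul call on the flattened views. A minimal NumPy sketch of that identity (illustration only, not part of the patch):

```python
import numpy as np

# Shapes mirror the targeted case: X: [B, M, K], dOut: [B, M, N].
B, M, K, N = 2, 32, 100, 10
x = np.random.rand(B, M, K).astype(np.float32)
dout = np.random.rand(B, M, N).astype(np.float32)

# Batched form: one [K, N] product per batch, then a sum over B.
dy_batched = np.einsum("bmk,bmn->kn", x, dout)

# Flattened form used by MatMul2DwithReduceBatch: a single 2D matmul.
dy_flat = x.reshape(B * M, K).T @ dout.reshape(B * M, N)

np.testing.assert_allclose(dy_batched, dy_flat, rtol=1e-4)
```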
@@ -333,22 +364,32 @@ class MatMulGradV2MLUKernel : public framework::OpKernel<T> {
     }
 
     if (dY) {
-      Tensor dy_temp(Y->type());
-      if (y_dims != y_bcast_dims) {
-        dy_temp.Resize(phi::make_ddim(y_bcast_dims));
-      } else {
-        dY->mutable_data<T>(ctx.GetPlace());
-        dy_temp.ShareDataWith(*dY);
-      }
-
-      if (trans_y) {
-        MatMulND<T>(ctx, dout_temp, x_temp, &dy_temp, true, trans_x);
-      } else {
-        MatMulND<T>(ctx, x_temp, dout_temp, &dy_temp, !trans_x, false);
-      }
-
-      if (y_dims != y_bcast_dims) {
-        ReduceDims<T>(ctx, y_dims, y_bcast_dims, dy_temp, dY);
+      // Case 3: [B, M, K] x [K, N] = [B, M, N] better performance
+      // otherwise, tensor dy_temp in else branch might encounter
+      // numel overflow due to cnnlTensorDescriptor limitation
+      if (x_dims.size() == 3 && phi::vectorize(Y->dims()).size() == 2) {
+        if (trans_y) {
+          MatMul2DwithReduceBatch<T>(ctx, dout_temp, x_temp, dY, true, trans_x);
+        } else {
+          MatMul2DwithReduceBatch<T>(
+              ctx, x_temp, dout_temp, dY, !trans_x, false);
+        }
+      } else {
+        Tensor dy_temp(Y->type());
+        if (y_dims != y_bcast_dims) {
+          dy_temp.Resize(phi::make_ddim(y_bcast_dims));
+        } else {
+          dY->mutable_data<T>(ctx.GetPlace());
+          dy_temp.ShareDataWith(*dY);
+        }
+        if (trans_y) {
+          MatMulND<T>(ctx, dout_temp, x_temp, &dy_temp, true, trans_x);
+        } else {
+          MatMulND<T>(ctx, x_temp, dout_temp, &dy_temp, !trans_x, false);
+        }
+        if (y_dims != y_bcast_dims) {
+          ReduceDims<T>(ctx, y_dims, y_bcast_dims, dy_temp, dY);
+        }
       }
     }
   }
...
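Besides the speedup, the guarded branch in the hunk above also sidesteps a capacity problem: on the old path, Y is broadcast to y_bcast_dims = [B, K, N], the per-batch gradients are written into dy_temp of that shape, and ReduceDims then sums over B. For large B*K*N that intermediate can exceed the element count a cnnlTensorDescriptor can represent, which is the overflow the code comment warns about. A rough illustration, with hypothetical sizes and assuming a 32-bit element count (the exact CNNL limit is not stated in the patch):

```python
# Hypothetical sizes, chosen only to show the scale of the hazard.
B, K, N = 1024, 4096, 4096
numel_old_path = B * K * N  # dy_temp materialized as [B, K, N]
numel_new_path = K * N      # dY written directly as [K, N]
print(numel_old_path > 2**31 - 1)  # True: exceeds a 32-bit element count
print(numel_new_path > 2**31 - 1)  # False: the fast path stays small
```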
@@ -264,6 +264,18 @@ class TestMatMuklOp17(TestMatMulV2Op):
         self.trans_y = False
 
 
+class TestMatMuklOp18(TestMatMulV2Op):
+    """
+    case 18 : to check the gradient for special case
+    """
+
+    def config(self):
+        self.x_shape = (2, 32, 100)
+        self.y_shape = (100, 10)
+        self.trans_x = False
+        self.trans_y = False
+
+
 class TestMatMuklOpBroadcast1(TestMatMulV2Op):
     """
     case 14_3
@@ -328,6 +340,7 @@ create_test_fp16_class(TestMatMuklOp14)
 create_test_fp16_class(TestMatMuklOp15)
 create_test_fp16_class(TestMatMuklOp16)
 create_test_fp16_class(TestMatMuklOp17)
+create_test_fp16_class(TestMatMuklOp18)
 
 
 class TestMatMulV2API(unittest.TestCase):
...
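TestMatMuklOp18 pins down exactly the shapes the new branch handles: x is 3-D, y is 2-D, and both transpose flags are off, so the MLU grad kernel takes the MatMul2DwithReduceBatch path. A standalone NumPy check for those shapes (a sketch for orientation, not the Paddle test itself):

```python
import numpy as np

x = np.random.rand(2, 32, 100).astype(np.float32)    # x_shape from case 18
dout = np.random.rand(2, 32, 10).astype(np.float32)  # upstream grad: [B, M, N]

# Reference: per-batch x_b^T @ dout_b, reduced over the batch dimension.
dy_ref = sum(x[b].T @ dout[b] for b in range(x.shape[0]))

# Fast path: one 2D matmul on the flattened [B*M, ...] views.
dy_fast = x.reshape(-1, 100).T @ dout.reshape(-1, 10)

assert dy_ref.shape == (100, 10)
np.testing.assert_allclose(dy_ref, dy_fast, rtol=1e-4)
```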