Unverified commit 212b51ef, authored by cambriconhsq, committed by GitHub

[MLU] optimize matmul_grad_v2 dy (B,M,K)*(K,N) for better performance (#45336)

Parent ac0a2e50
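The fast path this commit adds rests on a simple identity: when X is [B, M, K] and Y is [K, N], the gradient dY is a sum of per-batch products, and that sum equals a single 2-D matmul on batch-flattened operands. A minimal NumPy sketch of the identity (shapes borrowed from the new test case; the sketch itself is not part of the commit):

import numpy as np

B, M, K, N = 2, 32, 100, 10                      # shapes from TestMatMuklOp18
x = np.random.rand(B, M, K).astype(np.float32)
dout = np.random.rand(B, M, N).astype(np.float32)

# Reference dY: per-batch X[b]^T @ dOut[b], summed over the batch.
dy_ref = sum(x[b].T @ dout[b] for b in range(B))

# Fast path: fold the batch into the row dimension, then one 2-D matmul.
dy_fast = x.reshape(B * M, K).T @ dout.reshape(B * M, N)

assert np.allclose(dy_ref, dy_fast, atol=1e-4)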
@@ -68,6 +68,37 @@ static void MatMul2D(const framework::ExecutionContext& ctx,
GetBasePtr(Out));
}
+template <typename T>
+static void MatMul2DwithReduceBatch(const framework::ExecutionContext& ctx,
+                                    const Tensor& X,
+                                    const Tensor& Y,
+                                    Tensor* Out,
+                                    const bool trans_x,
+                                    const bool trans_y) {
+  if (!Out->initialized()) {
+    Out->mutable_data<T>(ctx.GetPlace());
+  }
+  // View both 3-D inputs as 2-D by folding the batch dim into the rows:
+  // [B, M, K] -> [B*M, K]. With the left operand transposed (the dY case),
+  // the matmul contracts over B*M, i.e. it also sums over the batch.
+  std::vector<int64_t> x_dims = phi::vectorize(X.dims());
+  std::vector<int64_t> y_dims = phi::vectorize(Y.dims());
+  std::vector<int> realx_dims(
+      {static_cast<int>(x_dims[0] * x_dims[1]), static_cast<int>(x_dims[2])});
+  std::vector<int> realy_dims(
+      {static_cast<int>(y_dims[0] * y_dims[1]), static_cast<int>(y_dims[2])});
+  MLUCnnlTensorDesc x_desc(2, realx_dims.data(), ToCnnlDataType<T>());
+  MLUCnnlTensorDesc y_desc(2, realy_dims.data(), ToCnnlDataType<T>());
+  MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnl::Matmul(ctx,
+                  trans_x,
+                  trans_y,
+                  x_desc.get(),
+                  GetBasePtr(&X),
+                  y_desc.get(),
+                  GetBasePtr(&Y),
+                  out_desc.get(),
+                  GetBasePtr(Out));
+}
template <typename T>
static void MatMulND(const framework::ExecutionContext& ctx,
const Tensor& X,
@@ -333,22 +364,32 @@ class MatMulGradV2MLUKernel : public framework::OpKernel<T> {
}
if (dY) {
-      Tensor dy_temp(Y->type());
-      if (y_dims != y_bcast_dims) {
-        dy_temp.Resize(phi::make_ddim(y_bcast_dims));
-      } else {
-        dY->mutable_data<T>(ctx.GetPlace());
-        dy_temp.ShareDataWith(*dY);
-      }
-      if (trans_y) {
-        MatMulND<T>(ctx, dout_temp, x_temp, &dy_temp, true, trans_x);
+      // Case 3: [B, M, K] x [K, N] = [B, M, N]. Fold the batch for better
+      // performance; otherwise dy_temp in the else branch is broadcast to
+      // [B, K, N], and its numel can overflow the limit a
+      // cnnlTensorDescriptor can represent.
+      if (x_dims.size() == 3 && phi::vectorize(Y->dims()).size() == 2) {
+        if (trans_y) {
+          MatMul2DwithReduceBatch<T>(ctx, dout_temp, x_temp, dY, true, trans_x);
+        } else {
+          MatMul2DwithReduceBatch<T>(
+              ctx, x_temp, dout_temp, dY, !trans_x, false);
+        }
      } else {
-        MatMulND<T>(ctx, x_temp, dout_temp, &dy_temp, !trans_x, false);
-      }
-      if (y_dims != y_bcast_dims) {
-        ReduceDims<T>(ctx, y_dims, y_bcast_dims, dy_temp, dY);
+        Tensor dy_temp(Y->type());
+        if (y_dims != y_bcast_dims) {
+          dy_temp.Resize(phi::make_ddim(y_bcast_dims));
+        } else {
+          dY->mutable_data<T>(ctx.GetPlace());
+          dy_temp.ShareDataWith(*dY);
+        }
+        if (trans_y) {
+          MatMulND<T>(ctx, dout_temp, x_temp, &dy_temp, true, trans_x);
+        } else {
+          MatMulND<T>(ctx, x_temp, dout_temp, &dy_temp, !trans_x, false);
+        }
+        if (y_dims != y_bcast_dims) {
+          ReduceDims<T>(ctx, y_dims, y_bcast_dims, dy_temp, dY);
+        }
      }
}
}
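To make the overflow concern concrete: in the fallback path, dy_temp is materialized at the broadcast shape [B, K, N] and only reduced afterwards, so its element count scales with the batch. A back-of-envelope check in Python (the sizes are hypothetical, and treating the descriptor's numel limit as 2**31 - 1 is an assumption, not something the diff states):

# Hypothetical sizes; 2**31 - 1 as the descriptor numel limit is assumed.
B, K, N = 4096, 1024, 1024
numel_fallback = B * K * N   # dy_temp broadcast to [B, K, N]: 2**32 elements
numel_fast = K * N           # dY written directly as [K, N]: 2**20 elements
print(numel_fallback > 2**31 - 1)  # True  -> risks descriptor overflow
print(numel_fast > 2**31 - 1)      # False -> always representable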
......
@@ -264,6 +264,18 @@ class TestMatMuklOp17(TestMatMulV2Op):
self.trans_y = False
+class TestMatMuklOp18(TestMatMulV2Op):
+    """
+    case 18: check the gradient for the special case
+    [B, M, K] x [K, N], which takes the new reduce-batch 2-D matmul path
+    """
+
+    def config(self):
+        self.x_shape = (2, 32, 100)
+        self.y_shape = (100, 10)
+        self.trans_x = False
+        self.trans_y = False
class TestMatMuklOpBroadcast1(TestMatMulV2Op):
"""
case 14_3
@@ -328,6 +340,7 @@ create_test_fp16_class(TestMatMuklOp14)
create_test_fp16_class(TestMatMuklOp15)
create_test_fp16_class(TestMatMuklOp16)
create_test_fp16_class(TestMatMuklOp17)
+create_test_fp16_class(TestMatMuklOp18)
class TestMatMulV2API(unittest.TestCase):
......
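For orientation, this is roughly what the new test exercises at the Python API level (a hypothetical dygraph snippet using the standard paddle API; MLU device selection is omitted):

import paddle

x = paddle.rand([2, 32, 100])   # [B, M, K]
x.stop_gradient = False
y = paddle.rand([100, 10])      # [K, N]
y.stop_gradient = False

out = paddle.matmul(x, y)       # [2, 32, 10]
out.sum().backward()
print(y.grad.shape)             # [100, 10]: dY is reduced over the batch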