From 212b51ef0e041e3b263964ebc4e6bdaf8da43aac Mon Sep 17 00:00:00 2001
From: cambriconhsq <106155938+cambriconhsq@users.noreply.github.com>
Date: Mon, 29 Aug 2022 14:48:31 +0800
Subject: [PATCH] [MLU] optimize matmul_grad_v2 dy (B,M,K)*(K,N) for better performance (#45336)

---
 paddle/fluid/operators/matmul_v2_op_mlu.cc    | 71 +++++++++++++++----
 .../unittests/mlu/test_matmul_v2_op_mlu.py    | 13 ++++
 2 files changed, 69 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/operators/matmul_v2_op_mlu.cc b/paddle/fluid/operators/matmul_v2_op_mlu.cc
index 18105fa8998..1ea29500ddc 100644
--- a/paddle/fluid/operators/matmul_v2_op_mlu.cc
+++ b/paddle/fluid/operators/matmul_v2_op_mlu.cc
@@ -68,6 +68,37 @@ static void MatMul2D(const framework::ExecutionContext& ctx,
                   GetBasePtr(Out));
 }
 
+template <typename T>
+static void MatMul2DwithReduceBatch(const framework::ExecutionContext& ctx,
+                                    const Tensor& X,
+                                    const Tensor& Y,
+                                    Tensor* Out,
+                                    const bool trans_x,
+                                    const bool trans_y) {
+  if (!Out->initialized()) {
+    Out->mutable_data<T>(ctx.GetPlace());
+  }
+  // reshape to 2D matmul
+  std::vector<int64_t> x_dims = phi::vectorize(X.dims());
+  std::vector<int64_t> y_dims = phi::vectorize(Y.dims());
+  std::vector<int> realx_dims(
+      {static_cast<int>(x_dims[0] * x_dims[1]), static_cast<int>(x_dims[2])});
+  std::vector<int> realy_dims(
+      {static_cast<int>(y_dims[0] * y_dims[1]), static_cast<int>(y_dims[2])});
+  MLUCnnlTensorDesc x_desc(2, realx_dims.data(), ToCnnlDataType<T>());
+  MLUCnnlTensorDesc y_desc(2, realy_dims.data(), ToCnnlDataType<T>());
+  MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnl::Matmul(ctx,
+                  trans_x,
+                  trans_y,
+                  x_desc.get(),
+                  GetBasePtr(&X),
+                  y_desc.get(),
+                  GetBasePtr(&Y),
+                  out_desc.get(),
+                  GetBasePtr(Out));
+}
+
 template <typename T>
 static void MatMulND(const framework::ExecutionContext& ctx,
                      const Tensor& X,
@@ -333,22 +364,32 @@ class MatMulGradV2MLUKernel : public framework::OpKernel<T> {
     }
 
     if (dY) {
-      Tensor dy_temp(Y->type());
-      if (y_dims != y_bcast_dims) {
-        dy_temp.Resize(phi::make_ddim(y_bcast_dims));
-      } else {
-        dY->mutable_data<T>(ctx.GetPlace());
-        dy_temp.ShareDataWith(*dY);
-      }
-
-      if (trans_y) {
-        MatMulND<T>(ctx, dout_temp, x_temp, &dy_temp, true, trans_x);
+      // Case 3: [B, M, K] x [K, N] = [B, M, N] better performance
+      // otherwise, tensor dy_temp in else branch might encounter
+      // numel overflow due to cnnlTensorDescriptor limitation
+      if (x_dims.size() == 3 && phi::vectorize(Y->dims()).size() == 2) {
+        if (trans_y) {
+          MatMul2DwithReduceBatch<T>(ctx, dout_temp, x_temp, dY, true, trans_x);
+        } else {
+          MatMul2DwithReduceBatch<T>(
+              ctx, x_temp, dout_temp, dY, !trans_x, false);
+        }
       } else {
-        MatMulND<T>(ctx, x_temp, dout_temp, &dy_temp, !trans_x, false);
-      }
-
-      if (y_dims != y_bcast_dims) {
-        ReduceDims<T>(ctx, y_dims, y_bcast_dims, dy_temp, dY);
+        Tensor dy_temp(Y->type());
+        if (y_dims != y_bcast_dims) {
+          dy_temp.Resize(phi::make_ddim(y_bcast_dims));
+        } else {
+          dY->mutable_data<T>(ctx.GetPlace());
+          dy_temp.ShareDataWith(*dY);
+        }
+        if (trans_y) {
+          MatMulND<T>(ctx, dout_temp, x_temp, &dy_temp, true, trans_x);
+        } else {
+          MatMulND<T>(ctx, x_temp, dout_temp, &dy_temp, !trans_x, false);
+        }
+        if (y_dims != y_bcast_dims) {
+          ReduceDims<T>(ctx, y_dims, y_bcast_dims, dy_temp, dY);
+        }
       }
     }
   }
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py
index 85c73aa78ce..7c0612cefa0 100644
--- a/python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py
+++ b/python/paddle/fluid/tests/unittests/mlu/test_matmul_v2_op_mlu.py
@@ -264,6 +264,18 @@ class TestMatMuklOp17(TestMatMulV2Op):
         self.trans_y = False
 
 
+class TestMatMuklOp18(TestMatMulV2Op):
+    """
+    case 18 : to check the gradient for special case
+    """
+
+    def config(self):
+        self.x_shape = (2, 32, 100)
+        self.y_shape = (100, 10)
+        self.trans_x = False
+        self.trans_y = False
+
+
 class TestMatMuklOpBroadcast1(TestMatMulV2Op):
     """
     case 14_3
@@ -328,6 +340,7 @@ create_test_fp16_class(TestMatMuklOp14)
 create_test_fp16_class(TestMatMuklOp15)
 create_test_fp16_class(TestMatMuklOp16)
 create_test_fp16_class(TestMatMuklOp17)
+create_test_fp16_class(TestMatMuklOp18)
 
 
 class TestMatMulV2API(unittest.TestCase):
-- 
GitLab