From 25a0b46dc6155b058f7a1ef04550bd2dcf65dbd1 Mon Sep 17 00:00:00 2001
From: duanyanhui <45005871+YanhuiDua@users.noreply.github.com>
Date: Mon, 4 Sep 2023 10:59:18 +0800
Subject: [PATCH] optimize softmax_mask_fuse (#56877)

---
 paddle/fluid/eager/amp_auto_cast.h            |   4 +
 .../fusion/gpu/fused_softmax_mask_kernel.cu   | 131 ++++++++++++------
 test/legacy_test/test_softmax_mask_fuse_op.py |  21 +++
 3 files changed, 112 insertions(+), 44 deletions(-)

diff --git a/paddle/fluid/eager/amp_auto_cast.h b/paddle/fluid/eager/amp_auto_cast.h
index c9cf3e2ee28..66080ecef6a 100644
--- a/paddle/fluid/eager/amp_auto_cast.h
+++ b/paddle/fluid/eager/amp_auto_cast.h
@@ -75,6 +75,10 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
       input_name != "X") {
     return input;
   }
+  if (op_name == "fused_softmax_mask" && input_name == "Mask" &&
+      input.dtype() == phi::DataType::FLOAT32) {
+    return input;
+  }
   if (dst_dtype == phi::DataType::FLOAT16) {
     if (op_name == "run_program") {
       return input;
diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
index 0902b9448ec..1adadb6c3f6 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
@@ -22,9 +22,9 @@ namespace phi {
 namespace fusion {
 
 // T == fp16
-template <typename T, int pow2_index>
+template <typename T, typename MT, int pow2_index>
 __global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
-                                         const T* mask_data,
+                                         const MT* mask_data,
                                          T* y_data,
                                          int batch_count,
                                          int key_seq_len) {
@@ -62,7 +62,7 @@ __global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
   // using float for all inter compute
   float data[kLocalBatchSize][kLocalIterations];
   T temp_data[kOneLoadingCounts];
-  T temp_mask[kOneLoadingCounts];
+  MT temp_mask[kOneLoadingCounts];
 
 #pragma unroll
   for (int i = 0; i < kLocalBatchSize; ++i) {
@@ -151,7 +151,6 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
                             const DenseTensor& mask,
                             DenseTensor* out) {
   auto* x_data = x.data<T>();
-  auto* mask_data = mask.data<T>();
   auto* y_data = dev_ctx.template Alloc<T>(out);
 
   auto x_dim = x.dims();
@@ -226,46 +225,90 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
 
   dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
   dim3 threads(warp_size, warps_per_block, 1);
-  // launch the kernel based on the pow2_index
-  switch (pow2_index) {
-    case 5:  // 32
-      SoftmaxMaskFuseGPUKernel<T, 5><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 6:  // 64
-      SoftmaxMaskFuseGPUKernel<T, 6><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 7:  // 128
-      SoftmaxMaskFuseGPUKernel<T, 7><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 8:  // 256
-      SoftmaxMaskFuseGPUKernel<T, 8><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 9:  // 512
-      SoftmaxMaskFuseGPUKernel<T, 9><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 10:  // 1024
-      SoftmaxMaskFuseGPUKernel<T, 10><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 11:  // 2048
-      SoftmaxMaskFuseGPUKernel<T, 11><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 12:  // 4096
-      SoftmaxMaskFuseGPUKernel<T, 12><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 13:  // 8192
-      SoftmaxMaskFuseGPUKernel<T, 13><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    default:
-      break;
+  if (mask.dtype() == x.dtype()) {
+    auto* mask_data = mask.data<T>();
+    switch (pow2_index) {
+      case 5:  // 32
+        SoftmaxMaskFuseGPUKernel<T, T, 5><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 6:  // 64
+        SoftmaxMaskFuseGPUKernel<T, T, 6><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 7:  // 128
+        SoftmaxMaskFuseGPUKernel<T, T, 7><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 8:  // 256
+        SoftmaxMaskFuseGPUKernel<T, T, 8><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 9:  // 512
+        SoftmaxMaskFuseGPUKernel<T, T, 9><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 10:  // 1024
+        SoftmaxMaskFuseGPUKernel<T, T, 10><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 11:  // 2048
+        SoftmaxMaskFuseGPUKernel<T, T, 11><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 12:  // 4096
+        SoftmaxMaskFuseGPUKernel<T, T, 12><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 13:  // 8192
+        SoftmaxMaskFuseGPUKernel<T, T, 13><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      default:
+        break;
+    }
+  } else if (mask.dtype() == phi::DataType::FLOAT32) {
+    auto* mask_data = mask.data<float>();
+    switch (pow2_index) {
+      case 5:  // 32
+        SoftmaxMaskFuseGPUKernel<T, float, 5><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 6:  // 64
+        SoftmaxMaskFuseGPUKernel<T, float, 6><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 7:  // 128
+        SoftmaxMaskFuseGPUKernel<T, float, 7><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 8:  // 256
+        SoftmaxMaskFuseGPUKernel<T, float, 8><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 9:  // 512
+        SoftmaxMaskFuseGPUKernel<T, float, 9><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 10:  // 1024
+        SoftmaxMaskFuseGPUKernel<T, float, 10><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 11:  // 2048
+        SoftmaxMaskFuseGPUKernel<T, float, 11><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 12:  // 4096
+        SoftmaxMaskFuseGPUKernel<T, float, 12><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 13:  // 8192
+        SoftmaxMaskFuseGPUKernel<T, float, 13><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      default:
+        break;
+    }
   }
 }
 
diff --git a/test/legacy_test/test_softmax_mask_fuse_op.py b/test/legacy_test/test_softmax_mask_fuse_op.py
index 79c6ad8c935..56a4ba24a68 100644
--- a/test/legacy_test/test_softmax_mask_fuse_op.py
+++ b/test/legacy_test/test_softmax_mask_fuse_op.py
@@ -78,6 +78,27 @@ class TestSoftmaxMaskFuseOp0(OpTest):
         self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestSoftmaxMaskFuseOp01(OpTest):
+    def setUp(self):
+        self.op_type = "fused_softmax_mask"
+        self.python_api = paddle.incubate.softmax_mask_fuse
+        x = np.random.random((1, 1, 8, 32)).astype("float16")
+        mask = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float32")
+        mask_input = np.where(mask == 1, -10000.0, mask)
+        self.inputs = {'X': x, 'Mask': mask_input}
+        rst = _get_softmax(x, mask_input)
+        self.outputs = {'Out': rst}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
-- 
GitLab
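
Editor's note: below is a minimal usage sketch, not part of the patch, showing the mixed-precision case this change enables: a float16 input combined with a float32 Mask passed to paddle.incubate.softmax_mask_fuse (the Python API referenced by the new unit test). It assumes a CUDA-enabled Paddle build, since the fused kernel is GPU-only; shapes follow the new test case.

# Usage sketch only (not from the patch); requires a CUDA build of Paddle.
import numpy as np
import paddle

paddle.set_device("gpu")

# float16 attention scores with a float32 additive mask, the dtype
# combination this patch lets AMP and the fused kernel accept directly.
x = paddle.to_tensor(np.random.random((1, 1, 8, 32)).astype("float16"))
mask_np = np.where(np.random.randint(0, 2, (1, 1, 8, 32)) == 1, -10000.0, 0.0)
mask = paddle.to_tensor(mask_np.astype("float32"))

# Computes softmax(x + mask) over the last axis in a single fused GPU kernel.
out = paddle.incubate.softmax_mask_fuse(x, mask)
print(out.dtype, out.shape)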