diff --git a/paddle/fluid/eager/amp_auto_cast.h b/paddle/fluid/eager/amp_auto_cast.h
index c9cf3e2ee282398d75cf25b7e70825c2d398195f..66080ecef6a6b2497d37f2825a103dc67a0f752d 100644
--- a/paddle/fluid/eager/amp_auto_cast.h
+++ b/paddle/fluid/eager/amp_auto_cast.h
@@ -75,6 +75,10 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
       input_name != "X") {
     return input;
   }
+  if (op_name == "fused_softmax_mask" && input_name == "Mask" &&
+      input.dtype() == phi::DataType::FLOAT32) {
+    return input;
+  }
   if (dst_dtype == phi::DataType::FLOAT16) {
     if (op_name == "run_program") {
       return input;
diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
index 0902b9448eca6c377e18a34d2f550462ed22ec10..1adadb6c3f62773b15f36d720e45d599f4fb37d9 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
@@ -22,9 +22,9 @@ namespace phi {
 namespace fusion {
 
 // T == fp16
-template <typename T, int pow2_index>
+template <typename T, typename MT, int pow2_index>
 __global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
-                                         const T* mask_data,
+                                         const MT* mask_data,
                                          T* y_data,
                                          int batch_count,
                                          int key_seq_len) {
@@ -62,7 +62,7 @@ __global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
   // using float for all inter compute
   float data[kLocalBatchSize][kLocalIterations];
   T temp_data[kOneLoadingCounts];
-  T temp_mask[kOneLoadingCounts];
+  MT temp_mask[kOneLoadingCounts];
 
 #pragma unroll
   for (int i = 0; i < kLocalBatchSize; ++i) {
@@ -151,7 +151,6 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
                             const DenseTensor& mask,
                             DenseTensor* out) {
   auto* x_data = x.data<T>();
-  auto* mask_data = mask.data<T>();
   auto* y_data = dev_ctx.template Alloc<T>(out);
 
   auto x_dim = x.dims();
@@ -226,46 +225,90 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
   dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
   dim3 threads(warp_size, warps_per_block, 1);
 
-  // launch the kernel based on the pow2_index
-  switch (pow2_index) {
-    case 5:  // 32
-      SoftmaxMaskFuseGPUKernel<T, 5><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 6:  // 64
-      SoftmaxMaskFuseGPUKernel<T, 6><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 7:  // 128
-      SoftmaxMaskFuseGPUKernel<T, 7><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 8:  // 256
-      SoftmaxMaskFuseGPUKernel<T, 8><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 9:  // 512
-      SoftmaxMaskFuseGPUKernel<T, 9><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 10:  // 1024
-      SoftmaxMaskFuseGPUKernel<T, 10><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 11:  // 2048
-      SoftmaxMaskFuseGPUKernel<T, 11><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 12:  // 4096
-      SoftmaxMaskFuseGPUKernel<T, 12><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    case 13:  // 8192
-      SoftmaxMaskFuseGPUKernel<T, 13><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          x_data, mask_data, y_data, batch_count, key_seq_len);
-      break;
-    default:
-      break;
+  if (mask.dtype() == x.dtype()) {
+    auto* mask_data = mask.data<T>();
+    switch (pow2_index) {
+      case 5:  // 32
+        SoftmaxMaskFuseGPUKernel<T, T, 5><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 6:  // 64
+        SoftmaxMaskFuseGPUKernel<T, T, 6><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 7:  // 128
+        SoftmaxMaskFuseGPUKernel<T, T, 7><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 8:  // 256
+        SoftmaxMaskFuseGPUKernel<T, T, 8><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 9:  // 512
+        SoftmaxMaskFuseGPUKernel<T, T, 9><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 10:  // 1024
+        SoftmaxMaskFuseGPUKernel<T, T, 10><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 11:  // 2048
+        SoftmaxMaskFuseGPUKernel<T, T, 11><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 12:  // 4096
+        SoftmaxMaskFuseGPUKernel<T, T, 12><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 13:  // 8192
+        SoftmaxMaskFuseGPUKernel<T, T, 13><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      default:
+        break;
+    }
+  } else if (mask.dtype() == phi::DataType::FLOAT32) {
+    auto* mask_data = mask.data<float>();
+    switch (pow2_index) {
+      case 5:  // 32
+        SoftmaxMaskFuseGPUKernel<T, float, 5><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 6:  // 64
+        SoftmaxMaskFuseGPUKernel<T, float, 6><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 7:  // 128
+        SoftmaxMaskFuseGPUKernel<T, float, 7><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 8:  // 256
+        SoftmaxMaskFuseGPUKernel<T, float, 8><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 9:  // 512
+        SoftmaxMaskFuseGPUKernel<T, float, 9><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 10:  // 1024
+        SoftmaxMaskFuseGPUKernel<T, float, 10><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 11:  // 2048
+        SoftmaxMaskFuseGPUKernel<T, float, 11><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 12:  // 4096
+        SoftmaxMaskFuseGPUKernel<T, float, 12><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      case 13:  // 8192
+        SoftmaxMaskFuseGPUKernel<T, float, 13><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            x_data, mask_data, y_data, batch_count, key_seq_len);
+        break;
+      default:
+        break;
+    }
   }
 }
 
diff --git a/test/legacy_test/test_softmax_mask_fuse_op.py b/test/legacy_test/test_softmax_mask_fuse_op.py
index 79c6ad8c9352544475304df4399bf185dbfc61ec..56a4ba24a6862d9971c35fc194b691f2b67f3013 100644
--- a/test/legacy_test/test_softmax_mask_fuse_op.py
+++ b/test/legacy_test/test_softmax_mask_fuse_op.py
@@ -78,6 +78,27 @@ class TestSoftmaxMaskFuseOp0(OpTest):
         self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestSoftmaxMaskFuseOp01(OpTest):
+    def setUp(self):
+        self.op_type = "fused_softmax_mask"
+        self.python_api = paddle.incubate.softmax_mask_fuse
+        x = np.random.random((1, 1, 8, 32)).astype("float16")
+        mask = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float32")
+        mask_input = np.where(mask == 1, -10000.0, mask)
+        self.inputs = {'X': x, 'Mask': mask_input}
+        rst = _get_softmax(x, mask_input)
+        self.outputs = {'Out': rst}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
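Note: the snippet below is a minimal usage sketch and is not part of the diff. It shows how the new float32-mask path could be exercised from Python, assuming a CUDA build of Paddle; the shapes and dtypes mirror the new TestSoftmaxMaskFuseOp01 case, and wrapping the call in `paddle.amp.auto_cast` only illustrates the amp_auto_cast.h change, which now leaves a float32 `Mask` input uncast.

```python
# Sketch only (assumes a CUDA build): float16 attention scores with a
# float32 additive mask, matching the dtypes covered by TestSoftmaxMaskFuseOp01.
import numpy as np
import paddle

paddle.set_device("gpu")

x = paddle.to_tensor(np.random.random((1, 1, 8, 32)).astype("float16"))
mask_np = np.random.randint(0, 2, (1, 1, 8, 32)).astype("float32")
mask = paddle.to_tensor(
    np.where(mask_np == 1, -10000.0, mask_np).astype("float32")
)

# Under AMP, the amp_auto_cast.h change keeps the float32 "Mask" input of
# fused_softmax_mask in float32; the kernel then dispatches on mask.dtype().
with paddle.amp.auto_cast(enable=True, level="O1"):
    out = paddle.incubate.softmax_mask_fuse(x, mask)

print(out.dtype, out.shape)  # expected: paddle.float16, [1, 1, 8, 32]
```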