From 7fd6ffb84117477893e32a2c0f9fd392a1343b24 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Mon, 4 Sep 2023 16:21:05 +0800
Subject: [PATCH] add num_splits to support deterministic results for
 flash_attn_bwd and FlashAttnUnpaddedGradKernel (#56363)

* add num_splits for flash_attn_bwd and FlashAttnUnpaddedGradKernel

* Add assertTrue

* Update submodule to a specific commit
---
 .../phi/kernels/gpu/flash_attn_grad_kernel.cu |  15 +-
 .../test_flash_attention_deterministic.py     | 208 ++++++++++++++++++
 third_party/flashattn                         |   2 +-
 3 files changed, 219 insertions(+), 6 deletions(-)
 create mode 100644 test/legacy_test/test_flash_attention_deterministic.py

diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index 7b76a5f458d..fae308008b4 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -28,6 +28,11 @@ PD_DECLARE_bool(cudnn_deterministic);
 
 namespace phi {
 
+int get_num_split() {
+  // 0 for an internal heuristic, which is optimal
+  return FLAGS_cudnn_deterministic ? 1 : 0;
+}
+
 template <typename T, typename Context>
 void FlashAttnUnpaddedGradImpl(const Context& ctx,
                                const DenseTensor& q,
@@ -236,11 +241,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   const int64_t total_k = k.dims()[0];
   const int64_t num_heads_k = k.dims()[1];
 
-  // TODO(umiswing): add deterministic in fa2.
-  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
-  // if (FLAGS_cudnn_deterministic) {
-  //   num_splits = 1;
-  // }
+  int num_splits = get_num_split();
 
   // TODO(umiswing): add shape check
   PADDLE_ENFORCE_EQ(
@@ -294,6 +295,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
       params.scale,
       params.causal,
       params.is_bf16,
+      num_splits,
       stream,
      params.seed,
      params.offset);
@@ -401,6 +403,8 @@ void FlashAttnGradKernel(const Context& ctx,
   VLOG(10) << "FlashAttn bwd seed: " << params.seed
            << ", offset: " << params.offset;
 
+  int num_splits = get_num_split();
+
   bool succ = phi::dynload::flash_attn_bwd(dout.data(),
                                            q.data(),
                                            k.data(),
@@ -426,6 +430,7 @@ void FlashAttnGradKernel(const Context& ctx,
      params.scale,
      params.causal,
      params.is_bf16,
+     num_splits,
      stream,
      params.seed,
      params.offset);
diff --git a/test/legacy_test/test_flash_attention_deterministic.py b/test/legacy_test/test_flash_attention_deterministic.py
new file mode 100644
index 00000000000..1d6b37cbf30
--- /dev/null
+++ b/test/legacy_test/test_flash_attention_deterministic.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import unittest
+
+import numpy as np
+
+import paddle
+import paddle.nn.functional as F
+from paddle.device import core
+from paddle.nn.functional.flash_attention import (
+    flash_attention,
+    scaled_dot_product_attention,
+)
+
+
+def get_cuda_version():
+    result = os.popen("nvcc --version").read()
+    regex = r'release (\S+),'
+    match = re.search(regex, result)
+    if match:
+        num = str(match.group(1))
+        integer, decimal = num.split('.')
+        return int(integer) * 1000 + int(float(decimal) * 10)
+    else:
+        return -1
+
+
+def attention_naive(q, k, v, causal=False):
+    qt = paddle.transpose(q, [0, 2, 1, 3])
+    kt = paddle.transpose(k, [0, 2, 1, 3])
+    vt = paddle.transpose(v, [0, 2, 1, 3])
+    scale = 1.0 / np.sqrt(q.shape[-1])
+    s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
+    s = paddle.scale(s, scale)
+    p = (
+        paddle.incubate.softmax_mask_fuse_upper_triangle(s)
+        if causal
+        else F.softmax(s)
+    )
+    o = paddle.matmul(p, vt)
+    return paddle.transpose(o, [0, 2, 1, 3])
+
+
+is_sm8x = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 8
+    and paddle.device.cuda.get_device_capability()[1] >= 0
+)
+
+is_sm90 = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 9
+    and paddle.device.cuda.get_device_capability()[1] == 0
+)
+
+is_sm_supported = is_sm8x or is_sm90
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA, the CUDA version is less than 11.4, "
+    "or the device's compute capability is not 8.x or 9.0",
+)
+class TestFlashAttentionAPIFlag(unittest.TestCase):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (2, 128, 8, 16)
+        self.dtype = 'float16'
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = False
+        self.use_sdp_api = False
+
+    def flash_attn_compute(self, query, key, value):
+        # test dynamic graph mode
+        paddle.disable_static()
+
+        q = paddle.to_tensor(
+            query, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+        k = paddle.to_tensor(
+            key, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+        v = paddle.to_tensor(
+            value, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+
+        q_ = paddle.to_tensor(
+            query, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+        k_ = paddle.to_tensor(
+            key, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+        v_ = paddle.to_tensor(
+            value, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+
+        if self.use_sdp_kernel:
+            with paddle.nn.functional.sdp_kernel(
+                enable_math=self.enable_math,
+                enable_flash=self.enable_flash,
+                enable_mem_efficient=self.enable_mem_efficient,
+            ):
+                if self.use_sdp_api:
+                    out = scaled_dot_product_attention(
+                        q, k, v, None, self.dropout, self.causal
+                    )
+                else:
+                    out, _ = flash_attention(
+                        q, k, v, self.dropout, self.causal, self.return_softmax
+                    )
+
+        else:
+            out, _ = flash_attention(
+                q, k, v, self.dropout, self.causal, self.return_softmax
+            )
+        out_ = attention_naive(q_, k_, v_, self.causal)
+
+        out.backward()
+        out_.backward()
+
+        self.assertEqual(q.grad.shape, q.shape)
+        self.assertEqual(q_.grad.shape, q.shape)
+
+        np.testing.assert_allclose(
+            q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
+        )
+
+        return out, out_, q.grad.numpy(), k.grad.numpy(), v.grad.numpy()
+
+    def test_all_flag(self):
+        paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
+        query = np.random.random(self.shape)
+        key = np.random.random(self.shape)
+        value = np.random.random(self.shape)
+
+        out1, out1_, q_grad1, k_grad1, v_grad1 = self.flash_attn_compute(
+            query, key, value
+        )
+
+        np.testing.assert_allclose(out1.numpy(), out1_, rtol=5e-03, atol=1e-03)
+
+        out2, out2_, q_grad2, k_grad2, v_grad2 = self.flash_attn_compute(
+            query, key, value
+        )
+        self.assertTrue(np.equal(out1.numpy(), out2.numpy()).all())
+        self.assertTrue(np.equal(q_grad1, q_grad2).all())
+        self.assertTrue(np.equal(k_grad1, k_grad2).all())
+        self.assertTrue(np.equal(v_grad1, v_grad2).all())
+        paddle.set_flags({'FLAGS_cudnn_deterministic': 0})
+
+
+class TestFlashAttentionAPIFlagTest1(TestFlashAttentionAPIFlag):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (2, 128, 8, 16)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = False
+
+
+class TestFlashAttentionAPIFlagTest2(TestFlashAttentionAPIFlag):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 256)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = False
+
+
+class TestSDPAttentionAPIFlagTest(TestFlashAttentionAPIFlag):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 128)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = True
+        self.use_sdp_api = True
+        self.enable_math = True
+        self.enable_flash = False
+        self.enable_mem_efficient = False
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/third_party/flashattn b/third_party/flashattn
index b5bdb79d5e1..e6b9d0d48c2 160000
--- a/third_party/flashattn
+++ b/third_party/flashattn
@@ -1 +1 @@
-Subproject commit b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a
+Subproject commit e6b9d0d48c29f8205b440dede6a48ceb8394383f
--
GitLab
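Usage sketch (not part of the patch): with FLAGS_cudnn_deterministic set, flash_attn_bwd is called with num_splits = 1, so repeated forward/backward runs of flash_attention should produce bit-identical outputs and gradients, which is exactly what the new test asserts. The snippet below is a minimal illustration that mirrors the test's shape and dtype; it assumes a CUDA build of Paddle with flash-attention support on an SM 8.x or 9.0 GPU.

import numpy as np
import paddle
from paddle.nn.functional.flash_attention import flash_attention

# Force the deterministic backward path (num_splits = 1) added in this patch.
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})

shape = (2, 128, 8, 16)  # (batch, seq_len, num_heads, head_dim), as in the test
query = np.random.random(shape)
key = np.random.random(shape)
value = np.random.random(shape)


def run_once():
    # Fresh leaf tensors each run so gradients are recomputed from scratch.
    q = paddle.to_tensor(query, dtype='float16', stop_gradient=False)
    k = paddle.to_tensor(key, dtype='float16', stop_gradient=False)
    v = paddle.to_tensor(value, dtype='float16', stop_gradient=False)
    out, _ = flash_attention(q, k, v, 0.0, False, False)
    out.backward()
    return out.numpy(), q.grad.numpy(), k.grad.numpy(), v.grad.numpy()


first = run_once()
second = run_once()
# With the flag on, each pair should be bit-identical, not merely allclose.
assert all(np.array_equal(a, b) for a, b in zip(first, second))

paddle.set_flags({'FLAGS_cudnn_deterministic': 0})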