From f9d5ae4e275a6a055f931a6fa1f2db50ccb8ef95 Mon Sep 17 00:00:00 2001
From: WangXi
Date: Mon, 16 May 2022 11:18:23 +0800
Subject: [PATCH] fused_multi_transformer add fused softmax mask (#42636)

---
 paddle/fluid/operators/fused/fmha_ref.h       |  29 ++-
 .../fused/fused_multi_transformer_op.cu       |  14 +-
 .../operators/fused/fused_softmax_mask.cu.h   | 204 ++++++++++++++++++
 .../test_fused_multi_transformer_op.py        |   8 +-
 4 files changed, 235 insertions(+), 20 deletions(-)
 create mode 100644 paddle/fluid/operators/fused/fused_softmax_mask.cu.h

diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h
index 6eb5881112..3d75d127ab 100644
--- a/paddle/fluid/operators/fused/fmha_ref.h
+++ b/paddle/fluid/operators/fused/fmha_ref.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/dropout_impl.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
+#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h"
 #include "paddle/fluid/operators/transpose_op.cu.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
@@ -148,18 +149,24 @@ class FMHARef {
         stride_b);
     int softmax_axis = -1;
     if (src_mask_tensor != nullptr) {
-      std::vector<const Tensor*> ins;
-      std::vector<Tensor*> outs;
-      ins.emplace_back(qk_out_tensor);
-      ins.emplace_back(src_mask_tensor);
-      outs.emplace_back(src_mask_out_tensor);
-      int elewise_add_axis = -1;
-      paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
-                                                     T, T>(
-          dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
+      if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) {
+        LaunchFusedSoftmaxMaskKernel(qk_out_data, src_mask_tensor->data<T>(),
+                                     softmax_out_data, batch_size_,
+                                     num_head_, seq_len_, dev_ctx_.stream());
+      } else {
+        std::vector<const Tensor*> ins;
+        std::vector<Tensor*> outs;
+        ins.emplace_back(qk_out_tensor);
+        ins.emplace_back(src_mask_tensor);
+        outs.emplace_back(src_mask_out_tensor);
+        int elewise_add_axis = -1;
+        paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
+                                                       T, T>(
+            dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
 
-      phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
-                                             softmax_axis, softmax_out_tensor);
+        phi::SoftmaxForwardCUDAKernelDriver<T>(
+            dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
+      }
     } else {
       phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out_tensor,
                                              softmax_axis, softmax_out_tensor);
diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
index e38ac9a0ad..fdd0208c3d 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -1084,11 +1084,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     auto *qk_out_data =
         qk_out.mutable_data<T>({bsz, num_head, seq_len, out_seq_len}, place);
 
-    Tensor src_mask_out, softmax_out;
+    Tensor softmax_out;
     Tensor attn_dropout_mask_out, attn_dropout_out;
     Tensor qktv_out, fmha_out;
-    auto *src_mask_out_data = src_mask_out.mutable_data<T>(
-        {bsz, num_head, seq_len, out_seq_len}, place);
     auto *softmax_out_data = softmax_out.mutable_data<T>(
         {bsz, num_head, seq_len, out_seq_len}, place);
 
@@ -1219,10 +1217,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                 1. / sqrt(dim_head));
       } else if (cache_kv_out) {  // generation context stage
         // TODO(wangxi): can remove dropout in inference
-        fmha_compute.ComputeForward(
-            qkv_out, nullptr, src_mask, &transpose_out_2, nullptr, &qk_out,
-            &src_mask_out, &softmax_out, &attn_dropout_mask_out,
-            &attn_dropout_out, &qktv_out, &fmha_out);
+        fmha_compute.ComputeForward(qkv_out, nullptr, src_mask,
+                                    &transpose_out_2, nullptr, &qk_out, nullptr,
+                                    &softmax_out, &attn_dropout_mask_out,
+                                    &attn_dropout_out, &qktv_out, &fmha_out);
         // [3, bsz, num_head, seq_len, head_dim]
         T *qkv_data = transpose_out_2_data;
         int64_t q_size = bsz * seq_len * num_head * dim_head;
@@ -1245,7 +1243,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
         // TODO(wangxi): can remove dropout in inference
         fmha_compute.ComputeForward(
             qkv_out, cache_kv, src_mask, &transpose_out_2, cache_kv_out,
-            &qk_out, &src_mask_out, &softmax_out, &attn_dropout_mask_out,
+            &qk_out, nullptr, &softmax_out, &attn_dropout_mask_out,
             &attn_dropout_out, &qktv_out, &fmha_out);
       }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
diff --git a/paddle/fluid/operators/fused/fused_softmax_mask.cu.h b/paddle/fluid/operators/fused/fused_softmax_mask.cu.h
new file mode 100644
index 0000000000..11f1011dec
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_softmax_mask.cu.h
@@ -0,0 +1,204 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+namespace plat = paddle::platform;
+
+#define FINAL_MASK 0xffffffff
+#define DIV_UP(x, y) (((x) + (y)-1) / (y))
+
+template <typename T>
+__inline__ __device__ T warpReduceSum(T val) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1)
+    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
+  return val;
+}
+
+template <typename T>
+__inline__ __device__ T warpReduceMax(T val) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1)
+    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
+  return val;
+}
+
+inline int ElementsCeil(int seq_len) {
+  int elements = 1;
+  while (elements * 32 < seq_len) elements *= 2;
+  return elements;
+}
+
+template <typename T, int VEC_SIZE, int ELEMENTS_PER_THREADS>
+__global__ void FusedSoftmaxMaskVecKernel(T* dst, const T* src, const T* mask,
+                                          int seq_len) {
+  constexpr int block_size = 128;
+  constexpr int warp_size = 32;
+  constexpr int warps_per_block = block_size / warp_size;
+
+  // blockDim/threadIdx = (warp_size, warps_per_block)
+  // gridDim/blockIdx = (DIV_UP(seq_len, warps_per_block), batch_size, head_num)
+  // every block processes 4(warps_per_block) sequences
+  // seq_id = seq_id * 4 + warp_id, eg.seq_len=128, 127=31*4+3
+  int seq_id = blockIdx.x * warps_per_block + threadIdx.y;
+  if (seq_id >= seq_len) return;
+
+  // ((bid*head_num + hid)*seq_len + seq_id) * seq_len
+  int offset =
+      ((blockIdx.y * gridDim.z + blockIdx.z) * seq_len + seq_id) * seq_len;
+  // (bid * seq_len + seq_id) * seq_len
+  int mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len;
+  src += offset;
+  dst += offset;
+  mask += mask_offset;
+
+  static_assert(ELEMENTS_PER_THREADS % VEC_SIZE == 0, "");
+  constexpr int VEC_NUMS = ELEMENTS_PER_THREADS / VEC_SIZE;
+  using VecT = phi::AlignedVector<T, VEC_SIZE>;
+
+  VecT elements[VEC_NUMS];
+  VecT tmp_mask;
+  float max_val = -std::numeric_limits<float>::infinity();
+
+  for (int i = 0; (i * warp_size + threadIdx.x) * VEC_SIZE < seq_len; ++i) {
+    phi::Load(src + (i * warp_size + threadIdx.x) * VEC_SIZE, &elements[i]);
+    phi::Load(mask + (i * warp_size + threadIdx.x) * VEC_SIZE, &tmp_mask);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      // TODO(wangxi): vec add
+      elements[i][j] += tmp_mask[j];
+      max_val = max(max_val, static_cast<float>(elements[i][j]));
+    }
+  }
+  max_val = warpReduceMax(max_val);
+
+  float sum_val = 0;
+  for (int i = 0; (i * warp_size + threadIdx.x) * VEC_SIZE < seq_len; ++i) {
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      float tmp = __expf(static_cast<float>(elements[i][j]) - max_val);
+      sum_val += tmp;
+      elements[i][j] = static_cast<T>(tmp);
+    }
+  }
+  sum_val = warpReduceSum(sum_val);
+  float mean_val = __fdividef(1.0f, sum_val + 1e-6f);
+
+  for (int i = 0; (i * warp_size + threadIdx.x) * VEC_SIZE < seq_len; ++i) {
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      float tmp = static_cast<float>(elements[i][j]) * mean_val;
+      elements[i][j] = static_cast<T>(tmp);
+    }
+    phi::Store(elements[i], dst + (i * warp_size + threadIdx.x) * VEC_SIZE);
+  }
+}
+
+#define SOFTMAX_MASK_KERNEL(VEC_SIZE, ELEMENTS)                                  \
+  FusedSoftmaxMaskVecKernel<T, VEC_SIZE, ELEMENTS><<<grid, block, 0, stream>>>( \
+      dst, src, mask, seq_len)
+
+// FIXME(wangxi): It is found that the performance of VEC_SIZE=2 is better
+// than that of =4 and =8. Further analysis of the kernel is needed later.
+// #define SELECT_SOFTMAX_MASK_KERNEL(ELEMENTS)                        \
+//   do {                                                              \
+//     if (sizeof(T) == 2 && seq_len % 8 == 0) {                       \
+//       FusedSoftmaxMaskVecKernel<plat::float16, 8, ELEMENTS>         \
+//           <<<grid, block, 0, stream>>>(                             \
+//               (plat::float16*)dst, (const plat::float16*)src, mask, seq_len); \
+//     }                                                               \
+//     else if (seq_len % 4 == 0) SOFTMAX_MASK_KERNEL(4, ELEMENTS);    \
+//     else if (seq_len % 2 == 0) SOFTMAX_MASK_KERNEL(2, ELEMENTS);    \
+//     else SOFTMAX_MASK_KERNEL(1, ELEMENTS);                          \
+//   } while(0)
+
+#define SELECT_SOFTMAX_MASK_KERNEL(ELEMENTS)  \
+  do {                                        \
+    if (seq_len % 2 == 0) {                   \
+      SOFTMAX_MASK_KERNEL(2, ELEMENTS);       \
+    } else {                                  \
+      SOFTMAX_MASK_KERNEL(1, ELEMENTS);       \
+    }                                         \
+  } while (0)
+
+#define CASE_SOFTMAX_MASK_KERNEL(ELEMENTS) \
+  case ELEMENTS: {                         \
+    SELECT_SOFTMAX_MASK_KERNEL(ELEMENTS);  \
+    break;                                 \
+  }
+
+// template
+template <typename T>
+void LaunchFusedSoftmaxMaskKernel(const T* src, const T* mask, T* dst,
+                                  const int batch_size, const int head_num,
+                                  const int seq_len, cudaStream_t stream) {
+  PADDLE_ENFORCE_EQ(
+      seq_len > 0 && seq_len <= 4096, true,
+      platform::errors::InvalidArgument("seq_len must be between (0, 4096] "
+                                        "received the seq_len is %d",
+                                        seq_len));
+
+  constexpr int block_size = 128;
+  constexpr int warp_size = 32;
+  constexpr int warps_per_block = block_size / warp_size;
+
+  // put head_num to the outside for mask
+  dim3 block(warp_size, warps_per_block);
+  dim3 grid(DIV_UP(seq_len, warps_per_block), batch_size, head_num);
+
+  // clang-format off
+  int elements = ElementsCeil(seq_len);
+  switch (elements) {
+    case 1: {  // <=32
+      SOFTMAX_MASK_KERNEL(1, 1);
+      break;
+    }
+    case 2: {  // <=64
+      // if (seq_len % 2 == 0) SOFTMAX_MASK_KERNEL(2, 2);
+      // else SOFTMAX_MASK_KERNEL(1, 2);
+      SELECT_SOFTMAX_MASK_KERNEL(2);
+      break;
+    }
+    case 4: {  // <=128
+      // if (seq_len % 4 == 0) SOFTMAX_MASK_KERNEL(4, 4);
+      // else if (seq_len % 2 == 0) SOFTMAX_MASK_KERNEL(2, 4);
+      // else SOFTMAX_MASK_KERNEL(1, 4);
+      SELECT_SOFTMAX_MASK_KERNEL(4);
+      break;
+    }
+    CASE_SOFTMAX_MASK_KERNEL(8);    // <=256
+    CASE_SOFTMAX_MASK_KERNEL(16);   // <=512
+    CASE_SOFTMAX_MASK_KERNEL(32);   // <=1024
+    CASE_SOFTMAX_MASK_KERNEL(64);   // <=2048
+    CASE_SOFTMAX_MASK_KERNEL(128);  // <=4096
+    default:
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "seq_len must be between (0, 4096], received the seq_len is %d",
+          seq_len));
+  }
+  // clang-format on
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
index 8f77972de8..67f382a439 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
@@ -109,6 +109,7 @@ class TestFusedMultiTransformerOp(OpTest):
 
         self.x_type = np.float32
         self.attn_mask_type = np.float64
+        #self.attn_mask_type = np.bool
         self.pre_layer_norm = True
         self.has_attn_mask = True
 
@@ -168,6 +169,11 @@ class TestFusedMultiTransformerOp(OpTest):
                     self.attn_mask = (self.attn_mask - 1.0) * 1e4
                 else:
                     self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e4
+            elif self.attn_mask_type == np.bool:
+                if self.has_cache_kv and not self.gen_cache_kv:
+                    self.attn_mask[:, :, :, -2] = 0
+                else:
+                    self.attn_mask = np.tril(self.attn_mask)
             else:
                 raise ValueError(
                     "'attn_mask_type' should be 'int64' or 'float64'.")
@@ -394,7 +400,7 @@ class TestFusedMultiTransformerOp(OpTest):
         epsilon = 1e-05
         ln2_epsilon = 1e-05
 
-        if attn_mask is not None:
+        if attn_mask is not None and self.attn_mask_type != np.bool:
             attn_mask = _convert_attention_mask(attn_mask, x.dtype)
 
         qkv_weights, qkv_biases = [], []
-- 
GitLab
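
Reviewer note (not part of the patch): the fused path added in fmha_ref.h is taken only when no src_mask_out buffer is requested and seq_len_ equals out_seq_len (i.e. the non-decoding case), and LaunchFusedSoftmaxMaskKernel additionally enforces seq_len in (0, 4096]. Below is a minimal, hypothetical host-side sketch of how the new launcher can be driven; the wrapper name, the float instantiation, and the assumed shapes (qk and out as [batch, head, seq, seq], mask as [batch, 1, seq, seq] broadcast over heads) are illustrative assumptions, not code from the patch.

// Hypothetical driver for the new launcher (device pointers assumed valid).
#include <cuda_runtime.h>
#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h"

void RunFusedSoftmaxMask(const float* d_qk, const float* d_mask, float* d_out,
                         int batch, int head, int seq, cudaStream_t stream) {
  // One warp handles one (batch, head, row) slice of length seq; the mask is
  // indexed per batch only, so it is broadcast across the head dimension.
  paddle::operators::LaunchFusedSoftmaxMaskKernel<float>(
      d_qk, d_mask, d_out, batch, head, seq, stream);
}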