From 6e40fc1da002db85debae403741458ba5968cc15 Mon Sep 17 00:00:00 2001 From: Sonder <55493212+AndSonder@users.noreply.github.com> Date: Mon, 14 Aug 2023 10:58:58 +0800 Subject: [PATCH] [Fluid] Move fused_softmax_mask_upper_triangle to phi (#55769) --- .../fused_softmax_mask_upper_triangle_op.cc | 8 - .../fused_softmax_mask_upper_triangle_op.cu | 621 ------------------ .../fused_softmax_mask_upper_triangle_op.h | 31 - ...fused_softmax_mask_upper_triangle_kernel.h | 32 + ...used_softmax_mask_upper_triangle_kernel.cc | 41 ++ ...softmax_mask_upper_triangle_grad_kernel.cu | 266 ++++++++ ...used_softmax_mask_upper_triangle_kernel.cu | 261 ++++++++ .../fused_softmax_mask_upper_triangle_utils.h | 117 ++++ .../fused_softmax_mask_upper_triangle_sig.cc | 40 ++ ...est_softmax_mask_fuse_upper_triangle_op.py | 2 + 10 files changed, 759 insertions(+), 660 deletions(-) delete mode 100644 paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu delete mode 100644 paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.h create mode 100644 paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h create mode 100644 paddle/phi/kernels/fusion/cpu/fused_softmax_mask_upper_triangle_kernel.cc create mode 100644 paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_grad_kernel.cu create mode 100644 paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_kernel.cu create mode 100644 paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h create mode 100644 paddle/phi/ops/compat/fused_softmax_mask_upper_triangle_sig.cc diff --git a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cc b/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cc index 0ed8ac2a54b..12c8ec9b81d 100644 --- a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cc +++ b/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cc @@ -10,7 +10,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/phi/core/generator.h" @@ -102,10 +101,3 @@ REGISTER_OPERATOR( ops::SoftmaxMaskFuseUpperTriangleGradOpMaker); REGISTER_OPERATOR(fused_softmax_mask_upper_triangle_grad, ops::SoftmaxMaskFuseUpperTriangleOpGrad); - -PD_REGISTER_STRUCT_KERNEL(fused_softmax_mask_upper_triangle, - CPU, - ALL_LAYOUT, - ops::SoftmaxMaskFuseUpperTriangleCPUKernel, - float, - double) {} diff --git a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu b/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu deleted file mode 100644 index 779ee234071..00000000000 --- a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu +++ /dev/null @@ -1,621 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -// this file is inspired by: -// https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h -/* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef PADDLE_WITH_CUDA -#include -#include -#endif -#ifdef PADDLE_WITH_HIP -#include -#include -#endif -#include -#include -#include -#include - -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/core/generator.h" - -namespace paddle { -namespace operators { - -#ifdef PADDLE_WITH_HIP -#define WARP_SIZE 64 -#else -#define WARP_SIZE 32 -#endif - -#define MASK 0xffffffff - -namespace plat = paddle::platform; - -__device__ __inline__ void load_data_upper_tri(plat::float16* dst, - const plat::float16* src) { - *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); -} - -__device__ __inline__ void load_data_upper_tri(plat::bfloat16* dst, - const plat::bfloat16* src) { - *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); -} - -__device__ __inline__ void load_data_upper_tri(float* dst, const float* src) { - *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); -} - -__device__ __inline__ void load_zero_vector_upper_tri(plat::float16* dst) { - *(reinterpret_cast(dst)) = make_float2(0.0f, 0.0f); -} - -__device__ __inline__ void load_zero_vector_upper_tri(plat::bfloat16* dst) { - *(reinterpret_cast(dst)) = make_float2(0.0f, 0.0f); -} - -__device__ __inline__ void load_zero_vector_upper_tri(float* dst) { - *(reinterpret_cast(dst)) = make_float4(0.0f, 0.0f, 0.0f, 0.0f); -} - -int get_pow2_index_value(int value) { - int pow2_index = 0; - while ((1 << pow2_index) < value) { - ++pow2_index; - } - return pow2_index; -} - -template -struct AddOP_upper_tri { - __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } -}; - -template -struct MaxOP_upper_tri { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T warp_shfl_xor_upper_tri(T value, - int laneMask, - int width, - unsigned int mask = MASK) { -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce_upper_tri(T* sum) { - ReduceOp r; -#pragma unroll - for (int offset = width / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < batch; ++i) { - T b = warp_shfl_xor_upper_tri(sum[i], offset, width); - sum[i] = r(sum[i], b); - } - } -} - -template -__global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src, - T* dst, - int64_t batch_count, - int64_t key_seq_len) { - constexpr int next_pow2 = 1 << pow2_index; - constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; - constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); - constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; - constexpr int kOneLoadingCounts = 4; - int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len; - - int64_t first_idx = - (static_cast(blockDim.y) * blockIdx.y + threadIdx.y) * - gridDim.x * kLocalBatchSize + - blockIdx.x; - int64_t local_block_idx = blockIdx.x + 1; - int64_t warp_iter_upper_bound = - (local_block_idx + kOneLoadingCounts * warp_size - 1) / warp_size; - - int64_t local_batches = batch_count - first_idx; - if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; - - int64_t local_idx = threadIdx.x; - - src += first_idx * key_seq_len + kOneLoadingCounts * local_idx; - dst += first_idx * key_seq_len + kOneLoadingCounts * local_idx; - - float data[kLocalBatchSize][kLocalIterations]; - T temp_in[kOneLoadingCounts]; - -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx; - -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { - auto element_index = kOneLoadingCounts * local_idx + ii * warp_size; - - if (element_index < batch_total_number) { - load_data_upper_tri(temp_in, - src + i * key_seq_len_pow_2 + ii * warp_size); - -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - if ((element_index + counter) < batch_total_number) { - data[i][ii + counter] = static_cast(temp_in[counter]); - } else { - data[i][ii + counter] = -std::numeric_limits::infinity(); - } - } - } else { -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - data[i][ii + counter] = -std::numeric_limits::infinity(); - } - } - } - } - - float max_value[kLocalBatchSize]; -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - max_value[i] = data[i][0]; -#pragma unroll - for (int ii = 1; ii < kLocalIterations; ++ii) { - max_value[i] = (max_value[i] > data[i][ii]) ? max_value[i] : data[i][ii]; - } - } - warp_reduce_upper_tri( - max_value); - - float sum[kLocalBatchSize]{0.0f}; -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ++ii) { - if (ii < warp_iter_upper_bound) { - data[i][ii] = std::exp((data[i][ii] - max_value[i])); - sum[i] += data[i][ii]; - } - } - } - warp_reduce_upper_tri( - sum); - - T out[kOneLoadingCounts]; -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { - auto element_index = kOneLoadingCounts * local_idx + ii * warp_size; - - if (element_index < local_block_idx) { -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - if (element_index + counter < local_block_idx) { - out[counter] = data[i][ii + counter] / sum[i]; - } else { - out[counter] = 0; - } - } - load_data_upper_tri(dst + i * key_seq_len_pow_2 + ii * warp_size, out); - } else if (element_index < key_seq_len) { - load_zero_vector_upper_tri(dst + i * key_seq_len_pow_2 + - ii * warp_size); - } else { - break; - } - } - } -} - -template -__global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input, - T* grad_output, - const T* softmax_rst, - int64_t batch_count, - int64_t key_seq_len) { - constexpr int next_pow2 = 1 << pow2_index; - constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; - constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); - constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; - constexpr int kOneLoadingCounts = 4; - int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len; - - int64_t first_idx = - (static_cast(blockDim.y) * blockIdx.y + threadIdx.y) * - gridDim.x * kLocalBatchSize + - blockIdx.x; - int64_t local_block_idx = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int64_t local_batches = batch_count - first_idx; - if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; - - // there might be multiple batches per warp. compute the index within the - // batch - int64_t local_idx = threadIdx.x; - - // the first element to process by the current thread - int64_t offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx; - grad_input += offset; - grad_output += offset; - softmax_rst += offset; - - // load data from global memory - float grad_input_reg[kLocalBatchSize][kLocalIterations]{0.0f}; - float softmax_rst_reg[kLocalBatchSize][kLocalIterations]{0.0f}; - T temp_grad_input[kOneLoadingCounts]; - T temp_softmax_rst[kOneLoadingCounts]; - -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx; - -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { - auto element_index = kOneLoadingCounts * local_idx + ii * warp_size; - if (element_index < batch_total_number) { - load_data_upper_tri( - temp_grad_input, - grad_input + i * key_seq_len_pow_2 + ii * warp_size); - load_data_upper_tri( - temp_softmax_rst, - softmax_rst + i * key_seq_len_pow_2 + ii * warp_size); - -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - if (element_index + counter < batch_total_number) { - softmax_rst_reg[i][ii + counter] = - static_cast(temp_softmax_rst[counter]); - } - } -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - if (element_index + counter < batch_total_number) { - grad_input_reg[i][ii + counter] = - static_cast(temp_grad_input[counter]) * - softmax_rst_reg[i][ii + counter]; - } - } - } - } - } - - float sum[kLocalBatchSize]; -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - sum[i] = grad_input_reg[i][0]; -#pragma unroll - for (int ii = 1; ii < kLocalIterations; ++ii) { - sum[i] += grad_input_reg[i][ii]; - } - } - warp_reduce_upper_tri( - sum); - -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { - auto element_index = kOneLoadingCounts * local_idx + ii * warp_size; - if (element_index < key_seq_len) { - // compute gradients - T samples_out[kOneLoadingCounts]; -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - samples_out[counter] = grad_input_reg[i][ii + counter] - - softmax_rst_reg[i][ii + counter] * sum[i]; - } - load_data_upper_tri( - grad_output + i * key_seq_len_pow_2 + ii * warp_size, samples_out); - } - } - } -} - -template -class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* y = context.Output("Out"); - - auto* x_data = x->data(); - auto* y_data = y->mutable_data(context.GetPlace()); - - auto x_dim = x->dims(); - auto batches = x_dim[0]; - auto attn_heads = x_dim[1]; - auto attn_mul_batch = batches * attn_heads; - auto query_seq_len = x_dim[2]; - auto key_seq_len = x_dim[3]; - - PADDLE_ENFORCE_EQ(key_seq_len, - query_seq_len, - platform::errors::InvalidArgument( - "Key seq len must be equal with query seq len " - "received key len: %d, query len: %d", - key_seq_len, - query_seq_len)); - - PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len <= 16384, - true, - platform::errors::InvalidArgument( - "Input x's last dim must be between [32, 16384] " - "received the last dimension of x is %d", - key_seq_len)); - - auto& place = - *context.template device_context().eigen_device(); - auto stream = context.cuda_device_context().stream(); - - int pow2_index = get_pow2_index_value(key_seq_len); - const int next_pow2 = 1 << pow2_index; - int64_t batch_count = attn_mul_batch * query_seq_len; - int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; - int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - PADDLE_ENFORCE_EQ( - query_seq_len % batches_per_block, - 0, - platform::errors::InvalidArgument( - "The query seq len (third dim of input X) must can divide the " - "number of batches per block. The query seq len is %d, while " - "the number of batches per block is %d.", - query_seq_len, - batches_per_block)); - dim3 blocks(query_seq_len, - (attn_mul_batch + batches_per_block) / batches_per_block, - 1); - dim3 threads(warp_size, warps_per_block, 1); - - switch (pow2_index) { - case 5: // 32 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - case 6: // 64 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - case 7: // 128 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - case 8: // 256 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - case 9: // 512 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - case 10: // 1024 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - case 11: // 2048 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - case 12: // 4096 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - case 13: // 8192 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - case 14: // 16384 - SoftmaxMaskFuseUpperTriangleGPUKernel - <<>>( - x_data, y_data, batch_count, key_seq_len); - break; - default: - PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length.")); - break; - } - } -}; - -template -class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* grad_x = - context.Output(framework::GradVarName("X")); - auto* grad_y = - context.Input(framework::GradVarName("Out")); - auto* softmax_rst = context.Input("Out"); - - auto* grad_x_data = grad_x->mutable_data(context.GetPlace()); - auto* grad_y_data = grad_y->data(); - auto* softmax_rst_data = softmax_rst->data(); - - auto y_dim = grad_y->dims(); - auto batches = y_dim[0]; - auto attn_heads = y_dim[1]; - auto attn_mul_batch = batches * attn_heads; - auto query_seq_len = y_dim[2]; - auto key_seq_len = y_dim[3]; - - auto& place = - *context.template device_context().eigen_device(); - auto stream = context.cuda_device_context().stream(); - - int pow2_index = get_pow2_index_value(key_seq_len); - const int next_pow2 = 1 << pow2_index; - int64_t batch_count = attn_mul_batch * query_seq_len; - int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; - int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; - // use 128 threads per block to maximum gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - dim3 blocks(query_seq_len, - (attn_mul_batch + batches_per_block) / batches_per_block, - 1); - dim3 threads(warp_size, warps_per_block, 1); - - switch (pow2_index) { - case 5: // 32 - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 6: // 64 - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 7: // 128 - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 8: // 256 - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 9: // 512 - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 10: // 1024 - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 11: // 2048 - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 12: // 4096 - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 13: // 8192 - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 14: - SoftmaxMaskFuseUpperTriangleGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - default: - PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length.")); - break; - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -PD_REGISTER_STRUCT_KERNEL(fused_softmax_mask_upper_triangle, - GPU, - ALL_LAYOUT, - ops::SoftmaxMaskFuseUpperTriangleKernel, - float, - plat::float16, - plat::bfloat16) {} -PD_REGISTER_STRUCT_KERNEL(fused_softmax_mask_upper_triangle_grad, - GPU, - ALL_LAYOUT, - ops::SoftmaxMaskFuseUpperTriangleGradKernel, - float, - plat::float16, - plat::bfloat16) {} diff --git a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.h b/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.h deleted file mode 100644 index c0495b7006b..00000000000 --- a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { -template -class SoftmaxMaskFuseUpperTriangleCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::Unimplemented( - "Softmax mask fuse op only supports GPU now.")); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h b/paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h new file mode 100644 index 00000000000..921c21b0aaa --- /dev/null +++ b/paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void FusedSoftmaxMaskFuseUpperTriangleKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); + +template +void FusedSoftmaxMaskFuseUpperTriangleGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& out_grad, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cpu/fused_softmax_mask_upper_triangle_kernel.cc b/paddle/phi/kernels/fusion/cpu/fused_softmax_mask_upper_triangle_kernel.cc new file mode 100644 index 00000000000..b9ded16d1b0 --- /dev/null +++ b/paddle/phi/kernels/fusion/cpu/fused_softmax_mask_upper_triangle_kernel.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +void FusedSoftmaxMaskFuseUpperTriangleKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU; + PADDLE_ENFORCE_EQ(is_gpu_place, + true, + phi::errors::Unimplemented( + "Softmax mask fuse op only supports GPU now.")); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_softmax_mask_upper_triangle, + CPU, + ALL_LAYOUT, + phi::fusion::FusedSoftmaxMaskFuseUpperTriangleKernel, + float, + double) {} diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_grad_kernel.cu new file mode 100644 index 00000000000..6c7fe36d364 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_grad_kernel.cu @@ -0,0 +1,266 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/generator.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h" +#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h" + +namespace phi { +namespace fusion { + +template +__global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input, + T* grad_output, + const T* softmax_rst, + int64_t batch_count, + int64_t key_seq_len) { + constexpr int next_pow2 = 1 << pow2_index; + constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); + constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; + constexpr int kOneLoadingCounts = 4; + int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len; + + int64_t first_idx = + (static_cast(blockDim.y) * blockIdx.y + threadIdx.y) * + gridDim.x * kLocalBatchSize + + blockIdx.x; + int64_t local_block_idx = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int64_t local_batches = batch_count - first_idx; + if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; + + // there might be multiple batches per warp. compute the index within the + // batch + int64_t local_idx = threadIdx.x; + + // the first element to process by the current thread + int64_t offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx; + grad_input += offset; + grad_output += offset; + softmax_rst += offset; + + // load data from global memory + float grad_input_reg[kLocalBatchSize][kLocalIterations]{0.0f}; + float softmax_rst_reg[kLocalBatchSize][kLocalIterations]{0.0f}; + T temp_grad_input[kOneLoadingCounts]; + T temp_softmax_rst[kOneLoadingCounts]; + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx; + +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + auto element_index = kOneLoadingCounts * local_idx + ii * warp_size; + if (element_index < batch_total_number) { + load_data_upper_tri( + temp_grad_input, + grad_input + i * key_seq_len_pow_2 + ii * warp_size); + load_data_upper_tri( + temp_softmax_rst, + softmax_rst + i * key_seq_len_pow_2 + ii * warp_size); + +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + if (element_index + counter < batch_total_number) { + softmax_rst_reg[i][ii + counter] = + static_cast(temp_softmax_rst[counter]); + } + } +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + if (element_index + counter < batch_total_number) { + grad_input_reg[i][ii + counter] = + static_cast(temp_grad_input[counter]) * + softmax_rst_reg[i][ii + counter]; + } + } + } + } + } + + float sum[kLocalBatchSize]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + sum[i] = grad_input_reg[i][0]; +#pragma unroll + for (int ii = 1; ii < kLocalIterations; ++ii) { + sum[i] += grad_input_reg[i][ii]; + } + } + warp_reduce_upper_tri( + sum); + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + auto element_index = kOneLoadingCounts * local_idx + ii * warp_size; + if (element_index < key_seq_len) { + // compute gradients + T samples_out[kOneLoadingCounts]; +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + samples_out[counter] = grad_input_reg[i][ii + counter] - + softmax_rst_reg[i][ii + counter] * sum[i]; + } + load_data_upper_tri( + grad_output + i * key_seq_len_pow_2 + ii * warp_size, samples_out); + } + } + } +} + +template +void FusedSoftmaxMaskFuseUpperTriangleGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + auto* grad_y = &out_grad; + auto* softmax_rst = &out; + + auto* x_grad_data = dev_ctx.template Alloc(x_grad); + + auto* grad_y_data = grad_y->data(); + auto* softmax_rst_data = softmax_rst->data(); + + auto y_dim = grad_y->dims(); + auto batches = y_dim[0]; + auto attn_heads = y_dim[1]; + auto attn_mul_batch = batches * attn_heads; + auto query_seq_len = y_dim[2]; + auto key_seq_len = y_dim[3]; + + auto stream = dev_ctx.stream(); + + int pow2_index = get_pow2_index_value(key_seq_len); + const int next_pow2 = 1 << pow2_index; + int64_t batch_count = attn_mul_batch * query_seq_len; + int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; + // use 128 threads per block to maximum gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + dim3 blocks(query_seq_len, + (attn_mul_batch + batches_per_block) / batches_per_block, + 1); + dim3 threads(warp_size, warps_per_block, 1); + + switch (pow2_index) { + case 5: // 32 + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + case 6: // 64 + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + case 7: // 128 + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + case 8: // 256 + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + case 9: // 512 + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + case 10: // 1024 + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + case 11: // 2048 + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + case 12: // 4096 + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + case 13: // 8192 + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + case 14: + SoftmaxMaskFuseUpperTriangleGradGPUKernel + <<>>(grad_y_data, + x_grad_data, + softmax_rst_data, + batch_count, + key_seq_len); + break; + default: + PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length.")); + break; + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_softmax_mask_upper_triangle_grad, + GPU, + ALL_LAYOUT, + phi::fusion::FusedSoftmaxMaskFuseUpperTriangleGradKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_kernel.cu new file mode 100644 index 00000000000..30e5599aac2 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_kernel.cu @@ -0,0 +1,261 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/generator.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h" +#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h" + +namespace phi { +namespace fusion { + +template +__global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src, + T* dst, + int64_t batch_count, + int64_t key_seq_len) { + constexpr int next_pow2 = 1 << pow2_index; + constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); + constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; + constexpr int kOneLoadingCounts = 4; + int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len; + + int64_t first_idx = + (static_cast(blockDim.y) * blockIdx.y + threadIdx.y) * + gridDim.x * kLocalBatchSize + + blockIdx.x; + int64_t local_block_idx = blockIdx.x + 1; + int64_t warp_iter_upper_bound = + (local_block_idx + kOneLoadingCounts * warp_size - 1) / warp_size; + + int64_t local_batches = batch_count - first_idx; + if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; + + int64_t local_idx = threadIdx.x; + + src += first_idx * key_seq_len + kOneLoadingCounts * local_idx; + dst += first_idx * key_seq_len + kOneLoadingCounts * local_idx; + + float data[kLocalBatchSize][kLocalIterations]; + T temp_in[kOneLoadingCounts]; + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx; + +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + auto element_index = kOneLoadingCounts * local_idx + ii * warp_size; + + if (element_index < batch_total_number) { + load_data_upper_tri(temp_in, + src + i * key_seq_len_pow_2 + ii * warp_size); + +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + if ((element_index + counter) < batch_total_number) { + data[i][ii + counter] = static_cast(temp_in[counter]); + } else { + data[i][ii + counter] = -std::numeric_limits::infinity(); + } + } + } else { +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + data[i][ii + counter] = -std::numeric_limits::infinity(); + } + } + } + } + + float max_value[kLocalBatchSize]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + max_value[i] = data[i][0]; +#pragma unroll + for (int ii = 1; ii < kLocalIterations; ++ii) { + max_value[i] = (max_value[i] > data[i][ii]) ? max_value[i] : data[i][ii]; + } + } + warp_reduce_upper_tri( + max_value); + + float sum[kLocalBatchSize]{0.0f}; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ++ii) { + if (ii < warp_iter_upper_bound) { + data[i][ii] = std::exp((data[i][ii] - max_value[i])); + sum[i] += data[i][ii]; + } + } + } + warp_reduce_upper_tri( + sum); + + T out[kOneLoadingCounts]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + auto element_index = kOneLoadingCounts * local_idx + ii * warp_size; + + if (element_index < local_block_idx) { +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + if (element_index + counter < local_block_idx) { + out[counter] = data[i][ii + counter] / sum[i]; + } else { + out[counter] = 0; + } + } + load_data_upper_tri(dst + i * key_seq_len_pow_2 + ii * warp_size, out); + } else if (element_index < key_seq_len) { + load_zero_vector_upper_tri(dst + i * key_seq_len_pow_2 + + ii * warp_size); + } else { + break; + } + } + } +} + +template +void FusedSoftmaxMaskFuseUpperTriangleKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto* x_ptr = &x; + + auto* x_data = x_ptr->data(); + auto* y_data = dev_ctx.template Alloc(out); + + auto x_dim = x_ptr->dims(); + auto batches = x_dim[0]; + auto attn_heads = x_dim[1]; + auto attn_mul_batch = batches * attn_heads; + auto query_seq_len = x_dim[2]; + auto key_seq_len = x_dim[3]; + + PADDLE_ENFORCE_EQ(key_seq_len, + query_seq_len, + phi::errors::InvalidArgument( + "Key seq len must be equal with query seq len " + "received key len: %d, query len: %d", + key_seq_len, + query_seq_len)); + + PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len <= 16384, + true, + phi::errors::InvalidArgument( + "Input x's last dim must be between [32, 16384] " + "received the last dimension of x is %d", + key_seq_len)); + + auto stream = dev_ctx.stream(); + + int pow2_index = get_pow2_index_value(key_seq_len); + const int next_pow2 = 1 << pow2_index; + int64_t batch_count = attn_mul_batch * query_seq_len; + int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + PADDLE_ENFORCE_EQ( + query_seq_len % batches_per_block, + 0, + phi::errors::InvalidArgument( + "The query seq len (third dim of input X) must can divide the " + "number of batches per block. The query seq len is %d, while " + "the number of batches per block is %d.", + query_seq_len, + batches_per_block)); + dim3 blocks(query_seq_len, + (attn_mul_batch + batches_per_block) / batches_per_block, + 1); + dim3 threads(warp_size, warps_per_block, 1); + + switch (pow2_index) { + case 5: // 32 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + case 6: // 64 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + case 7: // 128 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + case 8: // 256 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + case 9: // 512 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + case 10: // 1024 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + case 11: // 2048 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + case 12: // 4096 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + case 13: // 8192 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + case 14: // 16384 + SoftmaxMaskFuseUpperTriangleGPUKernel + <<>>( + x_data, y_data, batch_count, key_seq_len); + break; + default: + PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length.")); + break; + } +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_softmax_mask_upper_triangle, + GPU, + ALL_LAYOUT, + phi::fusion::FusedSoftmaxMaskFuseUpperTriangleKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h new file mode 100644 index 00000000000..32dc8aa07de --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h @@ -0,0 +1,117 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif +#ifdef PADDLE_WITH_HIP +#include +#include +#endif +#include +#include +#include +#include + +#include +#include + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace fusion { + +#ifdef PADDLE_WITH_HIP +#define WARP_SIZE 64 +#else +#define WARP_SIZE 32 +#endif + +#define MASK 0xffffffff + +__device__ __inline__ void load_data_upper_tri(phi::float16* dst, + const phi::float16* src) { + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +__device__ __inline__ void load_data_upper_tri(phi::bfloat16* dst, + const phi::bfloat16* src) { + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +__device__ __inline__ void load_data_upper_tri(float* dst, const float* src) { + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +__device__ __inline__ void load_zero_vector_upper_tri(phi::float16* dst) { + *(reinterpret_cast(dst)) = make_float2(0.0f, 0.0f); +} + +__device__ __inline__ void load_zero_vector_upper_tri(phi::bfloat16* dst) { + *(reinterpret_cast(dst)) = make_float2(0.0f, 0.0f); +} + +__device__ __inline__ void load_zero_vector_upper_tri(float* dst) { + *(reinterpret_cast(dst)) = make_float4(0.0f, 0.0f, 0.0f, 0.0f); +} + +__inline__ int get_pow2_index_value(int value) { + int pow2_index = 0; + while ((1 << pow2_index) < value) { + ++pow2_index; + } + return pow2_index; +} + +template +struct AddOP_upper_tri { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct MaxOP_upper_tri { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T warp_shfl_xor_upper_tri(T value, + int laneMask, + int width, + unsigned int mask = MASK) { +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce_upper_tri(T* sum) { + ReduceOp r; +#pragma unroll + for (int offset = width / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < batch; ++i) { + T b = warp_shfl_xor_upper_tri(sum[i], offset, width); + sum[i] = r(sum[i], b); + } + } +} +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/ops/compat/fused_softmax_mask_upper_triangle_sig.cc b/paddle/phi/ops/compat/fused_softmax_mask_upper_triangle_sig.cc new file mode 100644 index 00000000000..d3d9857c887 --- /dev/null +++ b/paddle/phi/ops/compat/fused_softmax_mask_upper_triangle_sig.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FusedSoftmaxMaskUpperTriangleOpArgumentMapping( + const ArgumentMappingContext& ctx UNUSED) { + return KernelSignature( + "fused_softmax_mask_upper_triangle", {"X"}, {}, {"Out"}); +} + +KernelSignature FusedSoftmaxMaskUpperTriangleGradOpArgumentMapping( + const ArgumentMappingContext& ctx UNUSED) { + return KernelSignature("fused_softmax_mask_upper_triangle_grad", + {"Out", "Out@GRAD"}, + {}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(fused_softmax_mask_upper_triangle, + phi::FusedSoftmaxMaskUpperTriangleOpArgumentMapping); + +PD_REGISTER_ARG_MAPPING_FN( + fused_softmax_mask_upper_triangle_grad, + phi::FusedSoftmaxMaskUpperTriangleGradOpArgumentMapping); diff --git a/test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py b/test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py index 82dbaaf0e78..cf1efa779dc 100644 --- a/test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py +++ b/test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py @@ -43,6 +43,7 @@ def _get_softmax_upper(x, fp16=True): class TestSoftmaxMaskFuseOp(OpTest): def setUp(self): self.op_type = "fused_softmax_mask_upper_triangle" + self.python_api = paddle.incubate.softmax_mask_fuse_upper_triangle x = np.random.random((1, 4, 32, 32)).astype("float16") self.inputs = {'X': x} rst = _get_softmax_upper(x) @@ -61,6 +62,7 @@ class TestSoftmaxMaskFuseOp(OpTest): class TestSoftmaxMaskFuseOp1(OpTest): def setUp(self): self.op_type = "fused_softmax_mask_upper_triangle" + self.python_api = paddle.incubate.softmax_mask_fuse_upper_triangle x = np.random.random((1, 4, 32, 32)) self.inputs = {'X': x} rst = _get_softmax_upper(x) -- GitLab