From 2b4f44d5c8c1848da7b7448068d880d76d41202c Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 9 Sep 2022 10:34:37 +0800 Subject: [PATCH] [Phi] Add fusion kernel dir and migrate fused_softmax_mask op (#45802) * add fusion dir and fuse_softmax_mask kernel * remove fusion kernel dir * migrate infershape * fix code errror --- paddle/fluid/framework/operator.cc | 2 +- .../fluid/operators/fused_softmax_mask_op.cc | 57 +- .../fluid/operators/fused_softmax_mask_op.cu | 595 ------------------ .../fluid/operators/fused_softmax_mask_op.h | 33 - paddle/phi/infermeta/binary.cc | 22 + paddle/phi/infermeta/binary.h | 4 + paddle/phi/kernels/CMakeLists.txt | 14 +- paddle/phi/kernels/fusion/README.md | 13 + .../fusion/fused_softmax_mask_grad_kernel.h | 27 + .../fusion/fused_softmax_mask_kernel.h | 27 + .../gpu/fused_softmax_mask_grad_kernel.cu | 201 ++++++ .../fusion/gpu/fused_softmax_mask_kernel.cu | 280 +++++++++ .../fusion/gpu/fused_softmax_mask_utils.h | 95 +++ .../phi/ops/compat/fused_softmax_mask_sig.cc | 28 + 14 files changed, 723 insertions(+), 675 deletions(-) delete mode 100644 paddle/fluid/operators/fused_softmax_mask_op.cu delete mode 100644 paddle/fluid/operators/fused_softmax_mask_op.h create mode 100644 paddle/phi/kernels/fusion/README.md create mode 100644 paddle/phi/kernels/fusion/fused_softmax_mask_grad_kernel.h create mode 100644 paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h create mode 100644 paddle/phi/kernels/fusion/gpu/fused_softmax_mask_grad_kernel.cu create mode 100644 paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu create mode 100644 paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h create mode 100644 paddle/phi/ops/compat/fused_softmax_mask_sig.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index fe64f81ddf0..4c28a9b5953 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1890,7 +1890,7 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { PADDLE_ENFORCE_NE( kernels_iter, all_op_kernels.end(), - platform::errors::Unavailable( + platform::errors::Unimplemented( "There are no kernels which are registered in the %s operator.", type_)); diff --git a/paddle/fluid/operators/fused_softmax_mask_op.cc b/paddle/fluid/operators/fused_softmax_mask_op.cc index 11c1fa4af85..604eaaaf3fc 100644 --- a/paddle/fluid/operators/fused_softmax_mask_op.cc +++ b/paddle/fluid/operators/fused_softmax_mask_op.cc @@ -11,10 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused_softmax_mask_op.h" #include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/binary.h" + namespace paddle { namespace operators { @@ -23,30 +27,6 @@ using framework::Tensor; class SoftmaxMaskFuseOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SoftmaxMaskFuse"); - OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "SoftmaxMaskFuse"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SoftmaxMaskFuse"); - auto x_dims = ctx->GetInputDim("X"); - auto mask_dims = ctx->GetInputDim("Mask"); - - PADDLE_ENFORCE_EQ( - x_dims.size(), - 4, - platform::errors::InvalidArgument("Input x must be in 4D dimension but " - "received the dimension of X is %d", - x_dims.size())); - PADDLE_ENFORCE_EQ(mask_dims.size(), - 4, - platform::errors::InvalidArgument( - "Input mask must be in 4D dimension but " - "received the dimension of mask is %d", - mask_dims.size())); - - ctx->SetOutputDim("Out", x_dims); - ctx->ShareLoD("X", "Out"); - } }; class SoftmaxMaskFuseOpMaker : public framework::OpProtoAndCheckerMaker { @@ -80,17 +60,6 @@ By doing this fusion, we can optimize the training by class SoftmaxMaskFuseOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - framework::GradVarName("Out"), - "SoftmaxMaskFuseGrad"); - - auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - ctx->SetOutputDim(framework::GradVarName("X"), out_dims); - ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X")); - } }; template @@ -111,12 +80,18 @@ class SoftmaxMaskFuseGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(fused_softmax_mask, + SoftmaxMaskFuseInferShapeFunctor, + PD_INFER_META(phi::SoftmaxMaskFuseInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(fused_softmax_mask_grad, + SoftmaxMaskFuseGradInferShapeFunctor, + PD_INFER_META(phi::GeneralUnaryGradInferMeta)); REGISTER_OPERATOR(fused_softmax_mask, ops::SoftmaxMaskFuseOp, ops::SoftmaxMaskFuseOpMaker, ops::SoftmaxMaskFuseGradOpMaker, - ops::SoftmaxMaskFuseGradOpMaker); -REGISTER_OPERATOR(fused_softmax_mask_grad, ops::SoftmaxMaskFuseOpGrad); -REGISTER_OP_CPU_KERNEL(fused_softmax_mask, - ops::SoftmaxMaskFuseCPUKernel, - ops::SoftmaxMaskFuseCPUKernel); + ops::SoftmaxMaskFuseGradOpMaker, + SoftmaxMaskFuseInferShapeFunctor); +REGISTER_OPERATOR(fused_softmax_mask_grad, + ops::SoftmaxMaskFuseOpGrad, + SoftmaxMaskFuseGradInferShapeFunctor); diff --git a/paddle/fluid/operators/fused_softmax_mask_op.cu b/paddle/fluid/operators/fused_softmax_mask_op.cu deleted file mode 100644 index c259d0efb49..00000000000 --- a/paddle/fluid/operators/fused_softmax_mask_op.cu +++ /dev/null @@ -1,595 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -// this file is inspired by: -// https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/fused_kernels/scaled_masked_softmax.h -/* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef PADDLE_WITH_CUDA -#include -#include -#endif -#ifdef PADDLE_WITH_HIP -#include -#include -#endif -#include -#include -#include -#include - -#include -#include - -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/fused_softmax_mask_op.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace operators { - -using framework::Tensor; - -#ifdef PADDLE_WITH_HIP -#define WARP_SIZE 64 -#else -#define WARP_SIZE 32 -#endif - -#define MASK 0xffffffff - -namespace plat = paddle::platform; - -__device__ __inline__ void load_data(plat::float16* dst, - const plat::float16* src) { - *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); -} - -__device__ __inline__ void load_data(float* dst, const float* src) { - *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); -} - -int get_pow2(int value) { - // get next pow2 index - int pow2_index = 0; - while ((1 << pow2_index) < value) { - ++pow2_index; - } - return pow2_index; -} - -template -struct AddOP { - __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } -}; - -template -struct MaxOP { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T -warp_shfl_xor(T value, int laneMask, int width, unsigned int mask = MASK) { -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(T* sum) { - ReduceOp r; -#pragma unroll - for (int offset = width / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < batch; ++i) { - T b = warp_shfl_xor(sum[i], offset, width); - sum[i] = r(sum[i], b); - } - } -} - -// T == fp16 -template -__global__ void SoftmaxMaskFuseGPUKernel(const T* x_data, - const T* mask_data, - T* y_data, - int batch_count, - int key_seq_len) { - // the forward gpu kernel - constexpr int next_pow2 = 1 << pow2_index; - constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; - constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); - constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; - constexpr int kOneLoadingCounts = 4; - - int data_first_idx = - (blockDim.y * - (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + - threadIdx.y) * - kLocalBatchSize; - - int mask_fist_idx = - (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * - kLocalBatchSize; - - // batch_count might not be a multiple of kLocalBatchSize. Check how - // many batches have to computed within this WARP. - int local_batches = batch_count - data_first_idx; - if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; - - // might be many batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - int x_offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx; - int mask_offset = mask_fist_idx * key_seq_len + kOneLoadingCounts * local_idx; - x_data += x_offset; - mask_data += mask_offset; - y_data += x_offset; - - // using float for all inter compute - float data[kLocalBatchSize][kLocalIterations]; - T temp_data[kOneLoadingCounts]; - T temp_mask[kOneLoadingCounts]; - -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - int batch_data = (i >= local_batches) ? 0 : key_seq_len; - -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { - int data_index = kOneLoadingCounts * local_idx + ii * warp_size; - - if (data_index < batch_data) { - int itr_idx = i * key_seq_len + ii * warp_size; - - // efficiently load data from global memory - load_data(temp_data, x_data + itr_idx); - load_data(temp_mask, mask_data + itr_idx); - -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - data[i][ii + counter] = static_cast(temp_data[counter]) + - static_cast(temp_mask[counter]); - } - } else { -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - data[i][ii + counter] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - // max value for each batch for current warp - float samples_max_value[kLocalBatchSize]; -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - samples_max_value[i] = data[i][0]; -#pragma unroll - for (int ii = 1; ii < kLocalIterations; ++ii) { - samples_max_value[i] = (samples_max_value[i] > data[i][ii]) - ? samples_max_value[i] - : data[i][ii]; - } - } - // max value for each batch for all warp - warp_reduce(samples_max_value); - - // compute the sum for each batch for current warp - float samples_sum[kLocalBatchSize]{0.0f}; -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ++ii) { - data[i][ii] = std::exp((data[i][ii] - samples_max_value[i])); - samples_sum[i] += data[i][ii]; - } - } - // samples_sum for each batch for all warp - warp_reduce(samples_sum); - - // load the result from device back to host - T samples_out[kOneLoadingCounts]; -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { - int idx = kOneLoadingCounts * local_idx + ii * warp_size; - if (idx < key_seq_len) { -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - samples_out[counter] = data[i][ii + counter] / samples_sum[i]; - } - load_data(y_data + i * key_seq_len + ii * warp_size, samples_out); - } else { - break; - } - } - } -} - -template -__global__ void SoftmaxMaskFuseGradGPUKernel(const T* grad_input, - T* grad_output, - const T* softmax_rst, - int batch_count, - int key_seq_len) { - constexpr int next_pow2 = 1 << pow2_index; - constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; - constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); - constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; - constexpr int kOneLoadingCounts = 4; - - int data_first_idx = - (blockDim.y * blockIdx.x + threadIdx.y) * kLocalBatchSize; - - // batch_count might not be a multiple of kLocalBatchSize. Check how - // many batches have to computed within this WARP. - int local_batches = batch_count - data_first_idx; - if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; - - // might be many batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx; - grad_input += offset; - grad_output += offset; - softmax_rst += offset; - - // using float for all inter compute - float grad_input_reg[kLocalBatchSize][kLocalIterations]{0.0f}; - float softmax_rst_reg[kLocalBatchSize][kLocalIterations]{0.0f}; - T temp_grad_input[kOneLoadingCounts]; - T temp_softmax_rst[kOneLoadingCounts]; - -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - int batch_data = (i >= local_batches) ? 0 : key_seq_len; - -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { - int data_index = kOneLoadingCounts * local_idx + ii * WARP_SIZE; - if (data_index < batch_data) { - load_data(temp_grad_input, - grad_input + i * key_seq_len + ii * warp_size); - load_data(temp_softmax_rst, - softmax_rst + i * key_seq_len + ii * warp_size); - -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - softmax_rst_reg[i][ii + counter] = - static_cast(temp_softmax_rst[counter]); - } -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - grad_input_reg[i][ii + counter] = - static_cast(temp_grad_input[counter]) * - softmax_rst_reg[i][ii + counter]; - } - } - } - } - - float samples_sum[kLocalBatchSize]; -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - samples_sum[i] = grad_input_reg[i][0]; -#pragma unroll - for (int ii = 1; ii < kLocalIterations; ++ii) { - samples_sum[i] += grad_input_reg[i][ii]; - } - } - warp_reduce(samples_sum); - -#pragma unroll - for (int i = 0; i < kLocalBatchSize; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { - int data_index = kOneLoadingCounts * local_idx + ii * warp_size; - if (data_index < key_seq_len) { - // compute gradients - T samples_out[kOneLoadingCounts]; -#pragma unroll - for (int counter = 0; counter < kOneLoadingCounts; ++counter) { - samples_out[counter] = - grad_input_reg[i][ii + counter] - - softmax_rst_reg[i][ii + counter] * samples_sum[i]; - } - load_data(grad_output + i * key_seq_len + ii * warp_size, samples_out); - } - } - } -} - -// T only supports fp16 -// leave as template only for future update -template -class SoftmaxMaskFuseKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* mask = context.Input("Mask"); - auto* y = context.Output("Out"); - - auto* x_data = x->data(); - auto* mask_data = mask->data(); - auto* y_data = y->mutable_data(context.GetPlace()); - - auto x_dim = x->dims(); - auto mask_dim = mask->dims(); - auto batches = x_dim[0]; - auto attn_heads = x_dim[1]; - auto query_seq_len = x_dim[2]; - auto key_seq_len = x_dim[3]; - - PADDLE_ENFORCE_GT(query_seq_len, - 1, - platform::errors::InvalidArgument( - "Input x's second last dim must be large than 1 but " - "received the second last dimension of x is %d", - query_seq_len)); - - PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192, - true, - platform::errors::InvalidArgument( - "Input x's last dim must be between [32, 8192) " - "received the last dimension of x is %d", - key_seq_len)); - - PADDLE_ENFORCE_EQ(mask_dim[1], - 1, - platform::errors::InvalidArgument( - "Input mask's second dim must be 1 " - "received the second dimension of mask is %d", - mask_dim[1])); - - // dim of x and mask must be equal - for (size_t idx = 0; idx < 4; ++idx) { - if (idx == 1) continue; - PADDLE_ENFORCE_EQ( - x_dim[idx], - mask_dim[idx], - platform::errors::InvalidArgument( - "Input x's %dth dim should be equal with input mask's %dth dim " - "but " - "received the %dth dimension of x and mask are not equal " - "the %dth dim of x is %d, while the %dth dim of mask is %d.", - idx, - idx, - idx, - idx, - x_dim[idx], - idx, - mask_dim[idx])); - } - - auto& place = *context.template device_context().eigen_device(); - auto stream = context.cuda_device_context().stream(); - - int pow2_index = get_pow2(key_seq_len); - const int next_pow2 = 1 << pow2_index; - int batch_count = batches * attn_heads * query_seq_len; - int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; - int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; - // use 128 threads per block to maximum gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - PADDLE_ENFORCE_EQ( - query_seq_len % batches_per_block, - 0, - platform::errors::InvalidArgument( - "The query seq len (third dim of input X) must can divide the " - "number of batches per block. The query seq len is %d, while " - "the number of batches per block is %d.", - query_seq_len, - batches_per_block)); - dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - - // launch the kernel based on the pow2_index - switch (pow2_index) { - case 5: // 32 - SoftmaxMaskFuseGPUKernel<<>>( - x_data, mask_data, y_data, batch_count, key_seq_len); - break; - case 6: // 64 - SoftmaxMaskFuseGPUKernel<<>>( - x_data, mask_data, y_data, batch_count, key_seq_len); - break; - case 7: // 128 - SoftmaxMaskFuseGPUKernel<<>>( - x_data, mask_data, y_data, batch_count, key_seq_len); - break; - case 8: // 256 - SoftmaxMaskFuseGPUKernel<<>>( - x_data, mask_data, y_data, batch_count, key_seq_len); - break; - case 9: // 512 - SoftmaxMaskFuseGPUKernel<<>>( - x_data, mask_data, y_data, batch_count, key_seq_len); - break; - case 10: // 1024 - SoftmaxMaskFuseGPUKernel<<>>( - x_data, mask_data, y_data, batch_count, key_seq_len); - break; - case 11: // 2048 - SoftmaxMaskFuseGPUKernel<<>>( - x_data, mask_data, y_data, batch_count, key_seq_len); - break; - case 12: // 4096 - SoftmaxMaskFuseGPUKernel<<>>( - x_data, mask_data, y_data, batch_count, key_seq_len); - break; - case 13: // 8192 - SoftmaxMaskFuseGPUKernel<<>>( - x_data, mask_data, y_data, batch_count, key_seq_len); - break; - default: - break; - } - } -}; - -template -class SoftmaxMaskFuseGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* grad_x = context.Output(framework::GradVarName("X")); - auto* grad_y = context.Input(framework::GradVarName("Out")); - auto* softmax_rst = context.Input("Softmax"); - - auto* grad_x_data = grad_x->mutable_data(context.GetPlace()); - auto* grad_y_data = grad_y->data(); - auto* softmax_rst_data = softmax_rst->data(); - - auto y_dim = grad_y->dims(); - auto batches = y_dim[0]; - auto attn_heads = y_dim[1]; - auto query_seq_len = y_dim[2]; - auto key_seq_len = y_dim[3]; - - auto& place = *context.template device_context().eigen_device(); - auto stream = context.cuda_device_context().stream(); - - int pow2_index = get_pow2(key_seq_len); - const int next_pow2 = 1 << pow2_index; - int batch_count = batches * attn_heads * query_seq_len; - int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; - int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; - // use 128 threads per block to maximum gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - // launch the kernel based on the pow2_index - switch (pow2_index) { - case 5: // 32 - SoftmaxMaskFuseGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 6: // 64 - SoftmaxMaskFuseGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 7: // 128 - SoftmaxMaskFuseGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 8: // 256 - SoftmaxMaskFuseGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 9: // 512 - SoftmaxMaskFuseGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 10: // 1024 - SoftmaxMaskFuseGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 11: // 2048 - SoftmaxMaskFuseGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 12: // 4096 - SoftmaxMaskFuseGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - case 13: // 8192 - SoftmaxMaskFuseGradGPUKernel - <<>>(grad_y_data, - grad_x_data, - softmax_rst_data, - batch_count, - key_seq_len); - break; - default: - break; - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - fused_softmax_mask, - ops::SoftmaxMaskFuseKernel, - ops::SoftmaxMaskFuseKernel); -REGISTER_OP_CUDA_KERNEL( - fused_softmax_mask_grad, - ops::SoftmaxMaskFuseGradKernel, - ops::SoftmaxMaskFuseGradKernel); diff --git a/paddle/fluid/operators/fused_softmax_mask_op.h b/paddle/fluid/operators/fused_softmax_mask_op.h deleted file mode 100644 index 137dfb830de..00000000000 --- a/paddle/fluid/operators/fused_softmax_mask_op.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class SoftmaxMaskFuseCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::Unimplemented( - "Softmax mask fuse op only supports GPU now.")); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 7f3c91181aa..957d942afaa 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2342,6 +2342,28 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence, } } +void SoftmaxMaskFuseInferMeta(const MetaTensor& x, + const MetaTensor& mask, + MetaTensor* out) { + auto x_dims = x.dims(); + auto mask_dims = mask.dims(); + + PADDLE_ENFORCE_EQ( + x_dims.size(), + 4, + phi::errors::InvalidArgument("Input x must be in 4D dimension but " + "received the dimension of X is %d", + x_dims.size())); + PADDLE_ENFORCE_EQ( + mask_dims.size(), + 4, + phi::errors::InvalidArgument("Input mask must be in 4D dimension but " + "received the dimension of mask is %d", + mask_dims.size())); + + out->share_meta(x); +} + void SegmentPoolInferMeta(const MetaTensor& x, const MetaTensor& segment_ids, const std::string& pooltype, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index e91470d32b6..59fedfe2550 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -358,6 +358,10 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence, bool right, MetaTensor* out); +void SoftmaxMaskFuseInferMeta(const MetaTensor& x, + const MetaTensor& mask, + MetaTensor* out); + void SegmentPoolInferMeta(const MetaTensor& x, const MetaTensor& segment_ids, const std::string& pooltype, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 275b9ef031b..d60584f77dc 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -95,8 +95,8 @@ file( "kps/*.cu" "selected_rows/gpu/*.cu" "sparse/gpu/*.cu" - "strings/*.cu" - "strings/gpu/*.cu") + "strings/gpu/*.cu" + "fusion/gpu/*.cu") if(WITH_MKLDNN) file( @@ -110,7 +110,9 @@ if(WITH_MKLDNN) "sparse/cpu/*.cc" "strings/*.cc" "strings/cpu/*.cc" - "onednn/*.cc") + "onednn/*.cc" + "fusion/*.cc" + "fusion/cpu/*.cc") else() file( GLOB @@ -122,10 +124,12 @@ else() "sparse/*.cc" "sparse/cpu/*.cc" "strings/*.cc" - "strings/cpu/*.cc") + "strings/cpu/*.cc" + "fusion/*.cc" + "fusion/cpu/*.cc") endif() -file(GLOB kernel_xpu "xpu/*.cc" "selected_rows/xpu/*.cc") +file(GLOB kernel_xpu "xpu/*.cc" "selected_rows/xpu/*.cc" "fusion/xpu/*.cc") add_library(phi_cpu ${kernel_cc}) kernel_declare("${kernel_cc}") diff --git a/paddle/phi/kernels/fusion/README.md b/paddle/phi/kernels/fusion/README.md new file mode 100644 index 00000000000..2080a37dd0f --- /dev/null +++ b/paddle/phi/kernels/fusion/README.md @@ -0,0 +1,13 @@ +# What's difference for fusion kernel? + +1. We don't recommend to implement Python API for fusion kernel + + - We don't recommend to implement Python API for fusion kernel, because it contains many inputs or outputs arguments generally, it is difficult to use and understand as an Python API, we recommend to call fusion kernel by pass optimization in dy2static mode or static mode. + - We also don't recommend to reuse fusion kernel in other kernel implementation, but recommended that the fusion kernel be implemented by reusing other kernels. + +2. We don't require fusion kernel to have implementations for all devices + + - Fusion Kernel is generally used to accelerate the combined operation on a certain device. If all devices need to be implemented, the cost is relatively high. + - We don't recommend implementing a pseudo kernel that just throws exception, if not required, it can be not implemented. + +3. Fusion Kernel needs to be in the `phi/fusion` namespace diff --git a/paddle/phi/kernels/fusion/fused_softmax_mask_grad_kernel.h b/paddle/phi/kernels/fusion/fused_softmax_mask_grad_kernel.h new file mode 100644 index 00000000000..391c614801f --- /dev/null +++ b/paddle/phi/kernels/fusion/fused_softmax_mask_grad_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SoftmaxMaskFuseGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& out_grad, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h b/paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h new file mode 100644 index 00000000000..dd08373f428 --- /dev/null +++ b/paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SoftmaxMaskFuseKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& mask, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_grad_kernel.cu new file mode 100644 index 00000000000..ab731f8f239 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_grad_kernel.cu @@ -0,0 +1,201 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/kernels/fusion/fused_softmax_mask_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h" + +namespace phi { +namespace fusion { + +template +__global__ void SoftmaxMaskFuseGradGPUKernel(const T* grad_input, + T* grad_output, + const T* softmax_rst, + int batch_count, + int key_seq_len) { + constexpr int next_pow2 = 1 << pow2_index; + constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); + constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; + constexpr int kOneLoadingCounts = 4; + + int data_first_idx = + (blockDim.y * blockIdx.x + threadIdx.y) * kLocalBatchSize; + + // batch_count might not be a multiple of kLocalBatchSize. Check how + // many batches have to computed within this WARP. + int local_batches = batch_count - data_first_idx; + if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; + + // might be many batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx; + grad_input += offset; + grad_output += offset; + softmax_rst += offset; + + // using float for all inter compute + float grad_input_reg[kLocalBatchSize][kLocalIterations]{0.0f}; + float softmax_rst_reg[kLocalBatchSize][kLocalIterations]{0.0f}; + T temp_grad_input[kOneLoadingCounts]; + T temp_softmax_rst[kOneLoadingCounts]; + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + int batch_data = (i >= local_batches) ? 0 : key_seq_len; + +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int data_index = kOneLoadingCounts * local_idx + ii * WARP_SIZE; + if (data_index < batch_data) { + load_data(temp_grad_input, + grad_input + i * key_seq_len + ii * warp_size); + load_data(temp_softmax_rst, + softmax_rst + i * key_seq_len + ii * warp_size); + +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + softmax_rst_reg[i][ii + counter] = + static_cast(temp_softmax_rst[counter]); + } +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + grad_input_reg[i][ii + counter] = + static_cast(temp_grad_input[counter]) * + softmax_rst_reg[i][ii + counter]; + } + } + } + } + + float samples_sum[kLocalBatchSize]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + samples_sum[i] = grad_input_reg[i][0]; +#pragma unroll + for (int ii = 1; ii < kLocalIterations; ++ii) { + samples_sum[i] += grad_input_reg[i][ii]; + } + } + warp_reduce(samples_sum); + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int data_index = kOneLoadingCounts * local_idx + ii * warp_size; + if (data_index < key_seq_len) { + // compute gradients + T samples_out[kOneLoadingCounts]; +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + samples_out[counter] = + grad_input_reg[i][ii + counter] - + softmax_rst_reg[i][ii + counter] * samples_sum[i]; + } + load_data(grad_output + i * key_seq_len + ii * warp_size, samples_out); + } + } + } +} + +template +void SoftmaxMaskFuseGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + auto* grad_x_data = dev_ctx.template Alloc(x_grad); + auto* grad_y_data = out_grad.data(); + auto* softmax_rst_data = out.data(); + + auto y_dim = out_grad.dims(); + auto batches = y_dim[0]; + auto attn_heads = y_dim[1]; + auto query_seq_len = y_dim[2]; + auto key_seq_len = y_dim[3]; + + auto stream = dev_ctx.stream(); + + int pow2_index = get_pow2(key_seq_len); + const int next_pow2 = 1 << pow2_index; + int batch_count = batches * attn_heads * query_seq_len; + int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; + // use 128 threads per block to maximum gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + + // launch the kernel based on the pow2_index + switch (pow2_index) { + case 5: // 32 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len); + break; + case 6: // 64 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len); + break; + case 7: // 128 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len); + break; + case 8: // 256 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len); + break; + case 9: // 512 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len); + break; + case 10: // 1024 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len); + break; + case 11: // 2048 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len); + break; + case 12: // 4096 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len); + break; + case 13: // 8192 + SoftmaxMaskFuseGradGPUKernel<<>>( + grad_y_data, grad_x_data, softmax_rst_data, batch_count, key_seq_len); + break; + default: + break; + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_softmax_mask_grad, + GPU, + ALL_LAYOUT, + phi::fusion::SoftmaxMaskFuseGradKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu new file mode 100644 index 00000000000..e86b4841e92 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu @@ -0,0 +1,280 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h" + +namespace phi { +namespace fusion { + +// T == fp16 +template +__global__ void SoftmaxMaskFuseGPUKernel(const T* x_data, + const T* mask_data, + T* y_data, + int batch_count, + int key_seq_len) { + // the forward gpu kernel + constexpr int next_pow2 = 1 << pow2_index; + constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4); + constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1; + constexpr int kOneLoadingCounts = 4; + + int data_first_idx = + (blockDim.y * + (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + kLocalBatchSize; + + int mask_fist_idx = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * + kLocalBatchSize; + + // batch_count might not be a multiple of kLocalBatchSize. Check how + // many batches have to computed within this WARP. + int local_batches = batch_count - data_first_idx; + if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize; + + // might be many batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + int x_offset = data_first_idx * key_seq_len + kOneLoadingCounts * local_idx; + int mask_offset = mask_fist_idx * key_seq_len + kOneLoadingCounts * local_idx; + x_data += x_offset; + mask_data += mask_offset; + y_data += x_offset; + + // using float for all inter compute + float data[kLocalBatchSize][kLocalIterations]; + T temp_data[kOneLoadingCounts]; + T temp_mask[kOneLoadingCounts]; + +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + int batch_data = (i >= local_batches) ? 0 : key_seq_len; + +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int data_index = kOneLoadingCounts * local_idx + ii * warp_size; + + if (data_index < batch_data) { + int itr_idx = i * key_seq_len + ii * warp_size; + + // efficiently load data from global memory + load_data(temp_data, x_data + itr_idx); + load_data(temp_mask, mask_data + itr_idx); + +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + data[i][ii + counter] = static_cast(temp_data[counter]) + + static_cast(temp_mask[counter]); + } + } else { +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + data[i][ii + counter] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + // max value for each batch for current warp + float samples_max_value[kLocalBatchSize]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + samples_max_value[i] = data[i][0]; +#pragma unroll + for (int ii = 1; ii < kLocalIterations; ++ii) { + samples_max_value[i] = (samples_max_value[i] > data[i][ii]) + ? samples_max_value[i] + : data[i][ii]; + } + } + // max value for each batch for all warp + warp_reduce(samples_max_value); + + // compute the sum for each batch for current warp + float samples_sum[kLocalBatchSize]{0.0f}; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ++ii) { + data[i][ii] = std::exp((data[i][ii] - samples_max_value[i])); + samples_sum[i] += data[i][ii]; + } + } + // samples_sum for each batch for all warp + warp_reduce(samples_sum); + + // load the result from device back to host + T samples_out[kOneLoadingCounts]; +#pragma unroll + for (int i = 0; i < kLocalBatchSize; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) { + int idx = kOneLoadingCounts * local_idx + ii * warp_size; + if (idx < key_seq_len) { +#pragma unroll + for (int counter = 0; counter < kOneLoadingCounts; ++counter) { + samples_out[counter] = data[i][ii + counter] / samples_sum[i]; + } + load_data(y_data + i * key_seq_len + ii * warp_size, samples_out); + } else { + break; + } + } + } +} + +// T only supports fp16 +// leave as template only for future update +template +void SoftmaxMaskFuseKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& mask, + DenseTensor* out) { + auto* x_data = x.data(); + auto* mask_data = mask.data(); + auto* y_data = dev_ctx.template Alloc(out); + + auto x_dim = x.dims(); + auto mask_dim = mask.dims(); + auto batches = x_dim[0]; + auto attn_heads = x_dim[1]; + auto query_seq_len = x_dim[2]; + auto key_seq_len = x_dim[3]; + + PADDLE_ENFORCE_GT(query_seq_len, + 1, + phi::errors::InvalidArgument( + "Input x's second last dim must be large than 1 but " + "received the second last dimension of x is %d", + query_seq_len)); + + PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192, + true, + phi::errors::InvalidArgument( + "Input x's last dim must be between [32, 8192) " + "received the last dimension of x is %d", + key_seq_len)); + + PADDLE_ENFORCE_EQ(mask_dim[1], + 1, + phi::errors::InvalidArgument( + "Input mask's second dim must be 1 " + "received the second dimension of mask is %d", + mask_dim[1])); + + // dim of x and mask must be equal + for (size_t idx = 0; idx < 4; ++idx) { + if (idx == 1) continue; + PADDLE_ENFORCE_EQ( + x_dim[idx], + mask_dim[idx], + phi::errors::InvalidArgument( + "Input x's %dth dim should be equal with input mask's %dth dim " + "but " + "received the %dth dimension of x and mask are not equal " + "the %dth dim of x is %d, while the %dth dim of mask is %d.", + idx, + idx, + idx, + idx, + x_dim[idx], + idx, + mask_dim[idx])); + } + + auto stream = dev_ctx.stream(); + + int pow2_index = get_pow2(key_seq_len); + const int next_pow2 = 1 << pow2_index; + int batch_count = batches * attn_heads * query_seq_len; + int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE; + int batches_per_warp = (next_pow2 <= 128) ? 2 : 1; + // use 128 threads per block to maximum gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + PADDLE_ENFORCE_EQ( + query_seq_len % batches_per_block, + 0, + phi::errors::InvalidArgument( + "The query seq len (third dim of input X) must can divide the " + "number of batches per block. The query seq len is %d, while " + "the number of batches per block is %d.", + query_seq_len, + batches_per_block)); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + + // launch the kernel based on the pow2_index + switch (pow2_index) { + case 5: // 32 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 6: // 64 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 7: // 128 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 8: // 256 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 9: // 512 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 10: // 1024 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 11: // 2048 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 12: // 4096 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + case 13: // 8192 + SoftmaxMaskFuseGPUKernel<<>>( + x_data, mask_data, y_data, batch_count, key_seq_len); + break; + default: + break; + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_softmax_mask, + GPU, + ALL_LAYOUT, + phi::fusion::SoftmaxMaskFuseKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h new file mode 100644 index 00000000000..2847a4df839 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h @@ -0,0 +1,95 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif +#ifdef PADDLE_WITH_HIP +#include +#include +#endif + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + +#ifdef PADDLE_WITH_HIP +#define WARP_SIZE 64 +#else +#define WARP_SIZE 32 +#endif + +#define MASK 0xffffffff + +namespace phi { +namespace fusion { + +__device__ __inline__ void load_data(dtype::float16* dst, + const dtype::float16* src) { + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +__device__ __inline__ void load_data(float* dst, const float* src) { + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +inline int get_pow2(int value) { + // get next pow2 index + int pow2_index = 0; + while ((1 << pow2_index) < value) { + ++pow2_index; + } + return pow2_index; +} + +template +struct AddOP { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct MaxOP { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T +warp_shfl_xor(T value, int laneMask, int width, unsigned int mask = MASK) { +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(T* sum) { + ReduceOp r; +#pragma unroll + for (int offset = width / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < batch; ++i) { + T b = warp_shfl_xor(sum[i], offset, width); + sum[i] = r(sum[i], b); + } + } +} + +} // namespace fusion +} // namespace phi + +#endif diff --git a/paddle/phi/ops/compat/fused_softmax_mask_sig.cc b/paddle/phi/ops/compat/fused_softmax_mask_sig.cc new file mode 100644 index 00000000000..415df81763a --- /dev/null +++ b/paddle/phi/ops/compat/fused_softmax_mask_sig.cc @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SoftmaxMaskFuseGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "fused_softmax_mask_grad", {"Softmax", "Out@GRAD"}, {}, {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(fused_softmax_mask_grad, + phi::SoftmaxMaskFuseGradOpArgumentMapping); -- GitLab