// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fused_rope_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {

template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
                                          int batch_size,
                                          int seq_len,
                                          int num_heads,
                                          int head_dim,
                                          phi::Array<T*, 3> outs_data,
                                          int num_inputs,
                                          MPType div_c) {
  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  int stride = gridDim.x * blockDim.x * VecSize;
  int size = batch_size * seq_len * num_heads * head_dim;
  MPType sin_value[VecSize];
  MPType cos_value[VecSize];
  MPType result[VecSize];
  T store[VecSize];
  using VecType = phi::AlignedVector<T, VecSize>;
  constexpr int kVectorsPerThread = VecSize / 2;

  for (; index < size; index += stride) {
#pragma unroll
    for (int nx = 0; nx < VecSize; ++nx) {
      // Compute the rotation angle for this element: the position within the
      // sequence times 10000^(-d / head_dim), where d is the even-aligned
      // index inside the head.
      int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
      int pos_seq = index_wc / (num_heads * head_dim);
      MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
      MPType indices =
          static_cast<MPType>(1) /
          pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
      MPType value = pos_seq * indices;
      sin_value[nx] = sin(value);
      cos_value[nx] = cos(value);
    }

#pragma unroll
    for (int iter = 0; iter < 3; iter++) {
      if (iter > num_inputs) break;
      const T* input = ins_data[iter] + index;
      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);

      // Rotate each even/odd pair:
      // (p0, p1) -> (cos * p0 - sin * p1, sin * p0 + cos * p1).
#pragma unroll
      for (int nx = 0; nx < kVectorsPerThread; ++nx) {
        int pr_index = nx * 2;
        int ls_index = pr_index + 1;

        MPType p0 = static_cast<MPType>(input[pr_index]);
        MPType p1 = static_cast<MPType>(input[ls_index]);

        result[pr_index] = cos_value[pr_index] * p0;
        result[pr_index] -= sin_value[pr_index] * p1;

        result[ls_index] = sin_value[ls_index] * p0;
        result[ls_index] += cos_value[ls_index] * p1;

        store[pr_index] = static_cast<T>(result[pr_index]);
        store[ls_index] = static_cast<T>(result[ls_index]);
      }
      out[0] = *(reinterpret_cast<VecType*>(store));
    }
  }
}

template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
                     const DenseTensor& q,
                     const paddle::optional<DenseTensor>& k,
                     const paddle::optional<DenseTensor>& v,
                     DenseTensor* out_q,
                     DenseTensor* out_k,
                     DenseTensor* out_v) {
  int numel = q.numel();
  if (numel <= 0) return;
  dev_ctx.template Alloc<T>(out_q);
  out_q->Resize(q.dims());
  // q is laid out as [batch_size, seq_len, num_heads, head_dim].
  auto batch_size = q.dims()[0];
  auto seq_len = q.dims()[1];
  auto num_heads = q.dims()[2];
  auto head_dim = q.dims()[3];
  PADDLE_ENFORCE_NE(head_dim % 2,
                    1,
                    phi::errors::InvalidArgument(
                        "The head_dim of input must be a multiple of 2."));

  constexpr const int vec_size = 2;

  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);

  int grid = config.block_per_grid.x;
  int block = config.thread_per_block.x;
  auto stream = dev_ctx.stream();

  phi::Array<T*, 3> outs_data;
  phi::Array<const T*, 3> ins_data;

  ins_data[0] = q.data<T>();
  outs_data[0] = out_q->data<T>();
  int num_inputs = 0;

  if (k.get_ptr()) {
    dev_ctx.template Alloc<T>(out_k);
    out_k->Resize(q.dims());
    ins_data[1] = k->data<T>();
    outs_data[1] = out_k->data<T>();
    num_inputs++;
  }

  if (v.get_ptr()) {
    dev_ctx.template Alloc<T>(out_v);
    out_v->Resize(q.dims());
    ins_data[2] = v->data<T>();
    outs_data[2] = out_v->data<T>();
    num_inputs++;
  }

  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType div_c = static_cast<MPType>(1.0f / head_dim);

  VectorizedFusedRopeKernel<T, MPType, vec_size>
      <<<grid, block, 0, stream>>>(ins_data,
                                   batch_size,
                                   seq_len,
                                   num_heads,
                                   head_dim,
                                   outs_data,
                                   num_inputs,
                                   div_c);
}
}  // namespace phi

PD_REGISTER_KERNEL(fused_rope,
                   GPU,
                   ALL_LAYOUT,
                   phi::FusedRopeKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
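// -----------------------------------------------------------------------------
// Validation sketch (not part of the shipped kernel): a minimal CPU reference
// of the same rotary-embedding math, useful for checking the output of
// VectorizedFusedRopeKernel on small tensors. The function name RefFusedRope
// and the FUSED_ROPE_REFERENCE_SKETCH guard are hypothetical; it assumes the
// same [batch_size, seq_len, num_heads, head_dim] contiguous layout and an
// even head_dim, as enforced above.
#ifdef FUSED_ROPE_REFERENCE_SKETCH
#include <cmath>
#include <vector>

inline std::vector<float> RefFusedRope(const std::vector<float>& x,
                                       int batch_size,
                                       int seq_len,
                                       int num_heads,
                                       int head_dim) {
  std::vector<float> out(x.size());
  const float div_c = 1.0f / static_cast<float>(head_dim);
  for (int b = 0; b < batch_size; ++b) {
    for (int s = 0; s < seq_len; ++s) {
      for (int h = 0; h < num_heads; ++h) {
        const int base = ((b * seq_len + s) * num_heads + h) * head_dim;
        for (int d = 0; d < head_dim; d += 2) {
          // Same frequency formula as the kernel: 10000^(-d / head_dim),
          // shared by both elements of the (d, d + 1) pair.
          const float freq = 1.0f / std::pow(10000.0f, d * div_c);
          const float angle = s * freq;
          const float sin_v = std::sin(angle);
          const float cos_v = std::cos(angle);
          const float p0 = x[base + d];
          const float p1 = x[base + d + 1];
          // The 2-D rotation applied per pair by the GPU kernel.
          out[base + d] = cos_v * p0 - sin_v * p1;
          out[base + d + 1] = sin_v * p0 + cos_v * p1;
        }
      }
    }
  }
  return out;
}
#endif  // FUSED_ROPE_REFERENCE_SKETCH
// A float32 result from this sketch should match the kernel's output up to
// the precision of T (exactly for float, within rounding for float16 and
// bfloat16, whose intermediate math runs in MPType).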