// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {
namespace fusion {

// Applies rotary position embedding (RoPE) to up to three inputs of shape
// [batch_size, seq_len, num_heads, head_dim]. Each thread handles VecSize
// contiguous elements (VecSize must be even). If flag_sin_cos is set, sin/cos
// values are read from precomputed tables of shape [seq_len, head_dim] in
// sin_cos_data; otherwise the angles are computed on the fly from the
// 1/10000^(idx * div_c) frequency schedule. `sign` selects the rotation
// direction (+1 for the forward rotation, -1 for its inverse). Inputs with
// index > num_inputs are skipped.
template <typename T, typename MPType, int VecSize>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
                                          phi::Array<const T*, 2> sin_cos_data,
                                          bool flag_sin_cos,
                                          int sign,
                                          int64_t batch_size,
                                          int64_t seq_len,
                                          int64_t num_heads,
                                          int64_t head_dim,
                                          phi::Array<T*, 3> outs_data,
                                          int num_inputs,
                                          MPType div_c) {
  int64_t index =
      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
       threadIdx.x) *
      VecSize;
  int64_t stride = static_cast<int64_t>(gridDim.x) *
                   static_cast<int64_t>(blockDim.x) * VecSize;
  int64_t size = batch_size * seq_len * num_heads * head_dim;
  MPType sin_value[VecSize];
  MPType cos_value[VecSize];
  MPType result[VecSize];
  T store[VecSize];
  using VecType = phi::AlignedVector<T, VecSize>;
  constexpr int kVectorsPerThread = VecSize / 2;

  for (; index < size; index += stride) {
    if (flag_sin_cos) {
#pragma unroll
      for (int64_t nx = 0; nx < VecSize; ++nx) {
        // Offset within one batch element, then position in the sequence and
        // within the head; index_sc addresses the [seq_len, head_dim] table.
        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
        int64_t pos_seq = index_wc / (num_heads * head_dim);
        int64_t pos_head = index_wc % head_dim;
        int64_t index_sc = pos_seq * head_dim + pos_head;
        const T* sin_input = sin_cos_data[0] + index_sc;
        const T* cos_input = sin_cos_data[1] + index_sc;

        sin_value[nx] = static_cast<MPType>(sin_input[0]);
        cos_value[nx] = static_cast<MPType>(cos_input[0]);
      }
    } else {
#pragma unroll
      for (int nx = 0; nx < VecSize; ++nx) {
        // Compute sin/cos on the fly: both elements of an (even, odd) pair
        // share the frequency of the even index.
        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
        int64_t pos_seq = index_wc / (num_heads * head_dim);
        MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
        MPType indices =
            static_cast<MPType>(1) /
            pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
        MPType value = pos_seq * indices;
        sin_value[nx] = sin(value);
        cos_value[nx] = cos(value);
      }
    }

#pragma unroll
    for (int iter = 0; iter < 3; iter++) {
      if (iter > num_inputs) break;
      const T* input = ins_data[iter] + index;
      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);

#pragma unroll
      for (int nx = 0; nx < kVectorsPerThread; ++nx) {
        int pr_index = nx * 2;
        int ls_index = pr_index + 1;

        MPType p0 = static_cast<MPType>(input[pr_index]);
        MPType p1 = static_cast<MPType>(input[ls_index]);

        // Rotate the (even, odd) pair:
        //   out_even = cos * x_even - sign * sin * x_odd
        //   out_odd  = cos * x_odd  + sign * sin * x_even
        result[pr_index] =
            cos_value[pr_index] * p0 - sign * sin_value[ls_index] * p1;
        result[ls_index] =
            cos_value[ls_index] * p1 + sign * sin_value[pr_index] * p0;

        store[pr_index] = static_cast<T>(result[pr_index]);
        store[ls_index] = static_cast<T>(result[ls_index]);
      }
      // Vectorized store of the VecSize rotated elements.
      out[0] = *(reinterpret_cast<VecType*>(store));
    }
  }
}

}  // namespace fusion
}  // namespace phi
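
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the original header): a minimal
// launch of the kernel above, assuming the caller has already packed the
// input/output pointers into phi::Array objects and chosen a grid/block size.
// The names `grid`, `block`, and `stream`, the use of float as the compute
// type, and the 1/head_dim value for div_c (the standard RoPE frequency
// schedule) are assumptions for this example; the real launch sites are the
// fused rotary position embedding forward/backward kernels.
//
//   constexpr int kVecSize = 2;   // must be even: one (even, odd) pair/thread
//   using MPType = float;         // higher-precision compute type
//   MPType div_c = static_cast<MPType>(1.0 / head_dim);
//   VectorizedFusedRopeKernel<T, MPType, kVecSize>
//       <<<grid, block, 0, stream>>>(ins_data,
//                                    sin_cos_data,
//                                    /*flag_sin_cos=*/false,
//                                    /*sign=*/1,
//                                    batch_size,
//                                    seq_len,
//                                    num_heads,
//                                    head_dim,
//                                    outs_data,
//                                    /*num_inputs=*/0,  // rotate ins_data[0]
//                                    div_c);
// ---------------------------------------------------------------------------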