// fused_rope_utils.h
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {
namespace fusion {

// Applies rotary position embedding (RoPE) to the tensors in `ins_data` and
// writes the rotated values to the matching entries of `outs_data`.
//
// Every tensor is addressed through one flat element index that the kernel
// decomposes as [batch_size, seq_len, num_heads, head_dim] with head_dim
// innermost. Consecutive (even, odd) elements form a pair (p0, p1) that is
// rotated as
//     out0 = cos0 * p0 - sign * sin1 * p1
//     out1 = cos1 * p1 + sign * sin0 * p0
//
// Template parameters:
//   T        storage type of the inputs, outputs and sin/cos tables.
//   MPType   type used for the intermediate arithmetic.
//   VecSize  elements handled per thread per iteration; must be even.
//
// Parameters:
//   ins_data      input pointers; entries 0..num_inputs are read.
//   sin_cos_data  [0] = sin table, [1] = cos table; only read when
//                 flag_sin_cos is true, indexed as pos_seq * head_dim + pos_head.
//   flag_sin_cos  true: read sin/cos from sin_cos_data;
//                 false: compute them from the position on the fly.
//   sign          multiplier of the sin terms (presumably +1 for forward and
//                 -1 for backward -- confirm against the caller).
//   batch_size, seq_len, num_heads, head_dim
//                 extents used to decompose the flat index.
//   outs_data     output pointers, same indexing as ins_data.
//   num_inputs    index of the LAST tensor processed (see NOTE in the body).
//   div_c         scale applied to the even dimension index in the exponent
//                 of 10000 (presumably 1 / head_dim -- confirm against caller).
//
// Launch layout: 1-D grid and 1-D block (only the .x components are used);
// any grid size is valid because of the grid-stride loop.
//
// Preconditions (not checked here, callers must guarantee them):
//   * batch_size * seq_len * num_heads * head_dim is a multiple of VecSize;
//   * every output pointer is aligned for phi::AlignedVector<T, VecSize>.
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
                                          phi::Array<const T*, 2> sin_cos_data,
                                          bool flag_sin_cos,
                                          int sign,
                                          int64_t batch_size,
                                          int64_t seq_len,
                                          int64_t num_heads,
                                          int64_t head_dim,
                                          phi::Array<T*, 3> outs_data,
                                          int num_inputs,
                                          MPType div_c) {
  // Elements are rotated pair-wise; an odd VecSize would leave the last lane
  // of the store buffer uninitialised before it is written out.
  static_assert(VecSize > 0 && VecSize % 2 == 0,
                "VecSize must be a positive even number");

  using VecType = phi::AlignedVector<T, VecSize>;
  constexpr int kVectorsPerThread = VecSize / 2;

  // 64-bit arithmetic so that large tensors cannot overflow the index.
  int64_t index =
      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
       threadIdx.x) *
      VecSize;
  const int64_t stride = static_cast<int64_t>(gridDim.x) *
                         static_cast<int64_t>(blockDim.x) * VecSize;
  const int64_t size = batch_size * seq_len * num_heads * head_dim;

  // Loop-invariant extents, hoisted out of the per-element index math.
  const int64_t head_stride = num_heads * head_dim;
  const int64_t batch_stride = seq_len * head_stride;

  MPType sin_value[VecSize];
  MPType cos_value[VecSize];
  // The buffer is re-read as a VecType below, so it must carry VecType's
  // alignment rather than the (possibly smaller) alignment of T.
  alignas(VecType) T store[VecSize];

  for (; index < size; index += stride) {
    if (flag_sin_cos) {
      // sin/cos come from the caller-provided tables.
#pragma unroll
      for (int nx = 0; nx < VecSize; ++nx) {
        const int64_t index_wc = (index + nx) % batch_stride;
        const int64_t pos_seq = index_wc / head_stride;
        const int64_t pos_head = index_wc % head_dim;
        const int64_t index_sc = pos_seq * head_dim + pos_head;

        sin_value[nx] = static_cast<MPType>(sin_cos_data[0][index_sc]);
        cos_value[nx] = static_cast<MPType>(sin_cos_data[1][index_sc]);
      }
    } else {
      // sin/cos are computed as sin/cos(pos_seq / 10000^(even_dim * div_c)).
#pragma unroll
      for (int nx = 0; nx < VecSize; ++nx) {
        const int64_t index_wc = (index + nx) % batch_stride;
        const int64_t pos_seq = index_wc / head_stride;
        // Round the dimension index down to the even member of its pair.
        // Integer math keeps this out of double precision.
        const MPType even_dim =
            static_cast<MPType>((index_wc % head_dim) / 2 * 2);
        const MPType inv_freq =
            static_cast<MPType>(1) /
            pow(static_cast<MPType>(10000), even_dim * div_c);
        const MPType angle = pos_seq * inv_freq;
        sin_value[nx] = sin(angle);
        cos_value[nx] = cos(angle);
      }
    }

#pragma unroll
    for (int iter = 0; iter < 3; iter++) {
      // NOTE(review): the bound is inclusive, i.e. tensors 0..num_inputs are
      // processed, so num_inputs behaves as "number of tensors - 1". Kept
      // as-is because callers may rely on it; confirm against the launcher.
      if (iter > num_inputs) break;
      const T* input = ins_data[iter] + index;
      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);

#pragma unroll
      for (int nx = 0; nx < kVectorsPerThread; ++nx) {
        const int pr_index = nx * 2;
        const int ls_index = pr_index + 1;

        const MPType p0 = static_cast<MPType>(input[pr_index]);
        const MPType p1 = static_cast<MPType>(input[ls_index]);

        const MPType rotated_pr =
            cos_value[pr_index] * p0 - sign * sin_value[ls_index] * p1;
        const MPType rotated_ls =
            cos_value[ls_index] * p1 + sign * sin_value[pr_index] * p0;

        store[pr_index] = static_cast<T>(rotated_pr);
        store[ls_index] = static_cast<T>(rotated_ls);
      }
      // Single vectorized write of all VecSize results.
      out[0] = *(reinterpret_cast<VecType*>(store));
    }
  }
}

}  // namespace fusion
}  // namespace phi