fused_rope_grad_kernel.cu 4.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
21
#include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
22

namespace phi {
namespace fusion {

// Backward pass of the fused rotary position embedding (RoPE) operator.
//
// Launches one of the vectorized RoPE kernels from fused_rope_utils.h with
// sign = -1 on the incoming gradients dout_q (and optionally dout_k/dout_v),
// writing the results into dq (and dk/dv).
//
// Args:
//   dev_ctx: device context; supplies the stream and allocates the outputs.
//   sin, cos: optional tables read as T. They are only forwarded to the
//       kernel (with flag_sin_cos = true) when BOTH are present.
//   position_ids: optional int64 indices. They are only forwarded when sin
//       and cos are also present.
//   dout_q: gradient of the forward q output. Its dims are read as
//       [batch_size, seq_len, num_heads, head_dim]. It must be 4-D with an
//       even head_dim.
//   dout_k, dout_v: optional gradients for k and v. Each must have exactly
//       the same shape as dout_q, because a single launch sized by
//       dout_q.numel() covers every input.
//   use_neox_rotary_style: true launches
//       VectorizedFusedRopeWithRotateEveryTwoKernel; false launches
//       VectorizedFusedRopeWithRotateHalfKernel.
//   dq, dk, dv: outputs. dk/dv are allocated and written only when the
//       matching dout_k/dout_v is present.
//
// Nothing is launched (and no output is allocated) when dout_q is empty.
template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
                         const paddle::optional<DenseTensor>& sin,
                         const paddle::optional<DenseTensor>& cos,
                         const paddle::optional<DenseTensor>& position_ids,
                         const DenseTensor& dout_q,
                         const paddle::optional<DenseTensor>& dout_k,
                         const paddle::optional<DenseTensor>& dout_v,
                         bool use_neox_rotary_style,
                         DenseTensor* dq,
                         DenseTensor* dk,
                         DenseTensor* dv) {
  int64_t numel = dout_q.numel();
  if (numel <= 0) return;

  const auto& dims = dout_q.dims();
  // dims[0..3] are read unconditionally below, so reject other ranks up front.
  PADDLE_ENFORCE_EQ(dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "The input dout_q must be a 4-D tensor "
                        "[batch_size, seq_len, num_heads, head_dim], "
                        "but received a %d-D tensor.",
                        dims.size()));

  // Layout read from dout_q: [batch_size, seq_len, num_heads, head_dim].
  auto batch_size = dims[0];
  auto seq_len = dims[1];
  auto num_heads = dims[2];
  auto head_dim = dims[3];
  PADDLE_ENFORCE_NE(head_dim % 2,
                    1,
                    phi::errors::InvalidArgument(
                        "The head_dim of input must be a multiple of 2."));

  dev_ctx.template Alloc<T>(dq);

  constexpr const int vec_size = 2;

  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);

  int64_t grid = config.block_per_grid.x;
  int64_t block = config.thread_per_block.x;
  auto stream = dev_ctx.stream();

  phi::Array<T*, 3> outs_data;
  phi::Array<const T*, 3> ins_data;
  phi::Array<const T*, 2> sin_cos_data;
  // Null every slot explicitly so unused entries never carry garbage into
  // the kernel arguments.
  for (int i = 0; i < 3; ++i) {
    outs_data[i] = nullptr;
    ins_data[i] = nullptr;
  }
  sin_cos_data[0] = nullptr;
  sin_cos_data[1] = nullptr;
  const int64_t* position_ids_data = nullptr;

  // Slot 0 always holds dout_q -> dq.
  ins_data[0] = dout_q.data<T>();
  outs_data[0] = dq->data<T>();
  // Number of extra gradients (k and/or v) packed after slot 0.
  int num_inputs = 0;

  // The kernels receive only the count num_inputs, not which of k / v is
  // present. Each present gradient therefore goes into the NEXT free slot,
  // keeping occupied slots contiguous. A fixed slot per gradient would leave
  // a hole when only dout_v is given.
  auto append_grad = [&](const DenseTensor& dout,
                         DenseTensor* dx,
                         const char* name) {
    // One launch sized by dout_q.numel() indexes every input, so any shape
    // mismatch would read/write out of bounds.
    PADDLE_ENFORCE_EQ(dout.dims(),
                      dims,
                      phi::errors::InvalidArgument(
                          "The shape of %s must be the same as dout_q, "
                          "but received %s vs %s.",
                          name,
                          dout.dims(),
                          dims));
    dev_ctx.template Alloc<T>(dx);
    ++num_inputs;
    ins_data[num_inputs] = dout.data<T>();
    outs_data[num_inputs] = dx->data<T>();
  };

  if (dout_k.get_ptr()) {
    append_grad(*dout_k.get_ptr(), dk, "dout_k");
  }
  if (dout_v.get_ptr()) {
    append_grad(*dout_v.get_ptr(), dv, "dout_v");
  }

  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  // Passed to the kernels unchanged; NOTE(review): presumably used to derive
  // the rotary frequencies when no sin/cos tables are given — confirm in
  // fused_rope_utils.h.
  MPType div_c = static_cast<MPType>(1.0f / head_dim);

  bool flag_sin_cos = false;
  if (sin.get_ptr() && cos.get_ptr()) {
    sin_cos_data[0] = sin->data<T>();
    sin_cos_data[1] = cos->data<T>();
    flag_sin_cos = true;

    // position_ids is only honoured together with explicit sin/cos tables.
    if (position_ids.get_ptr()) {
      position_ids_data = position_ids->data<int64_t>();
    }
  }

  // The backward pass applies the rotation with a negated sign.
  int sign = -1;
  if (use_neox_rotary_style) {
    VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
        <<<grid, block, 0, stream>>>(ins_data,
                                     sin_cos_data,
                                     position_ids_data,
                                     flag_sin_cos,
                                     sign,
                                     batch_size,
                                     seq_len,
                                     num_heads,
                                     head_dim,
                                     outs_data,
                                     num_inputs,
                                     div_c);
  } else {
    VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
        <<<grid, block, 0, stream>>>(ins_data,
                                     sin_cos_data,
                                     position_ids_data,
                                     flag_sin_cos,
                                     sign,
                                     batch_size,
                                     seq_len,
                                     num_heads,
                                     head_dim,
                                     outs_data,
                                     num_inputs,
                                     div_c);
  }
}

}  // namespace fusion
}  // namespace phi

// Registers phi::fusion::FusedRopeGradKernel as the GPU implementation of
// fused_rotary_position_embedding_grad for any layout and for the float,
// double, float16 and bfloat16 element types.
PD_REGISTER_KERNEL(fused_rotary_position_embedding_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedRopeGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}