multihead_matmul_op.cu 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <cstdint>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {

// Copies `src`, indexed as [batch][head][seq][size_per_head], into `dst`,
// indexed as [batch][seq][head][size_per_head].
//
// Launch layout: a 1-D grid with one block per (batch, head, seq) row, i.e.
// gridDim.x == batch * head_num * seq_len. The threads of a block cooperate
// on the `size_per_head` elements of that row using a block-stride loop, so
// any blockDim.x >= 1 is valid (the existing launch uses
// blockDim.x == size_per_head, for which each thread copies one element).
//
// `batch_size` is not read by the kernel; the batch index is derived from
// blockIdx.x.
template <typename T>
__global__ void transpose(T *src, T *dst, const int batch_size,
                          const int seq_len, const int head_num,
                          const int size_per_head) {
  const int rows_per_batch = head_num * seq_len;
  const int row = blockIdx.x;

  const int batch_id = row / rows_per_batch;
  const int row_in_batch = row % rows_per_batch;
  const int head_id = row_in_batch / seq_len;
  const int seq_id = row_in_batch % seq_len;

  // Row offsets are computed in 64 bits so that large tensors cannot
  // overflow a 32-bit int.
  const size_t src_row = static_cast<size_t>(row) * size_per_head;
  const size_t dst_row =
      ((static_cast<size_t>(batch_id) * seq_len + seq_id) * head_num +
       head_id) *
      size_per_head;

  // The loop bound guards against blockDim.x > size_per_head, and the stride
  // covers blockDim.x < size_per_head.
  for (int i = threadIdx.x; i < size_per_head; i += blockDim.x) {
    dst[dst_row + i] = src[src_row + i];
  }
}

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
// Device-side addition helper: returns a + b.
// Only declared here; the behaviour for each supported type is supplied by
// explicit specializations (float, float2 and float4 in this file), so using
// it with any other T has no definition to link against.
template <typename T>
inline __device__ T add_func(T a, T b);

// Scalar specialization: ordinary float addition.
template <>
__device__ float add_func<float>(float a, float b) {
  const float sum = a + b;
  return sum;
}

// float2 specialization: adds the x and y components independently.
template <>
__device__ float2 add_func<float2>(float2 a, float2 b) {
  return make_float2(a.x + b.x, a.y + b.y);
}

// float4 specialization: adds the x, y, z and w components independently.
template <>
__device__ float4 add_func<float4>(float4 a, float4 b) {
  return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}

// Adds the bias to a packed Q/K/V tensor and transposes it in one pass.
//
//   input : B x S x 3 x N x H
//   bias  : 3 x N x H
//   output: 3 x B x N x S x H
//
// H is the number of T elements per head and is passed explicitly; every
// other extent is taken from the launch configuration:
//   gridDim  = (S, B, 3)
//   blockDim = (H, N, 1)
// Each thread handles exactly one element of type T.
template <typename T>
__global__ void TransposeQkvKernel(const int H, const T *input, const T *bias,
                                   T *output) {
  // Extents implied by the launch configuration.
  const int head_count = blockDim.y;
  const int seq_count = gridDim.x;
  const int batch_count = gridDim.y;

  // Coordinates of the element owned by this thread.
  const int elem = threadIdx.x;
  const int head = threadIdx.y;
  const int seq = blockIdx.x;
  const int batch = blockIdx.y;
  const int qkv = blockIdx.z;

  // Row-major flat indices, written in nested form.
  const int src_idx =
      (((batch * seq_count + seq) * 3 + qkv) * head_count + head) * H + elem;
  const int bias_idx = (qkv * head_count + head) * H + elem;
  const int dst_idx =
      (((qkv * batch_count + batch) * head_count + head) * seq_count + seq) *
          H +
      elem;

  output[dst_idx] = add_func(input[src_idx], bias[bias_idx]);
}
90

91 92 93 94
// Launches TransposeQkvKernel on `stream` to add `bias` to the packed Q/K/V
// projection and transpose it:
//
//   input (BxSx3xNxH) + bias (3xNxH) -> output (3xBxNxSxH)
//
// with B = batch, S = seq_len, N = head_num and H = head_size.
//
// The widest vector type (float4, then float2, then float) is chosen for
// which head_size is a multiple of the vector width AND all three pointers
// are suitably aligned. Alignment is checked on the pointers themselves
// rather than inferred from the caller's buffer layout, because the kernel
// reinterprets `input`, `bias` and `output` as the vector type.
//
// Raises InvalidArgument (via PADDLE_ENFORCE_LE) when the resulting thread
// block would exceed 1024 threads.
void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
                      const int head_num, const float *input, const float *bias,
                      float *output, cudaStream_t stream) {
  // BxSx3xNxH + 3xNxH -> 3xBxNxSxH
  const dim3 grid(seq_len, batch, 3);

  // True when input, bias and output are all aligned to `bytes`.
  const auto pointers_aligned_to = [input, bias,
                                    output](std::uintptr_t bytes) {
    return reinterpret_cast<std::uintptr_t>(input) % bytes == 0 &&
           reinterpret_cast<std::uintptr_t>(bias) % bytes == 0 &&
           reinterpret_cast<std::uintptr_t>(output) % bytes == 0;
  };

  if (head_size % 4 == 0 && pointers_aligned_to(sizeof(float4))) {
    const int h = head_size / 4;
    const float4 *input4 = reinterpret_cast<const float4 *>(input);
    const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
    float4 *output4 = reinterpret_cast<float4 *>(output);
    const dim3 block(h, head_num, 1);

    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024 * 4));
    TransposeQkvKernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4,
                                                           output4);
  } else if (head_size % 2 == 0 && pointers_aligned_to(sizeof(float2))) {
    const int h = head_size / 2;
    const float2 *input2 = reinterpret_cast<const float2 *>(input);
    const float2 *bias2 = reinterpret_cast<const float2 *>(bias);
    float2 *output2 = reinterpret_cast<float2 *>(output);
    const dim3 block(h, head_num, 1);

    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024 * 2));
    TransposeQkvKernel<float2><<<grid, block, 0, stream>>>(h, input2, bias2,
                                                           output2);
  } else {
    const dim3 block(head_size, head_num, 1);

    // limit head_size * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(head_size * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024));
    TransposeQkvKernel<float><<<grid, block, 0, stream>>>(head_size, input,
                                                          bias, output);
  }
}
136 137

// CUDA kernel class for the fused multi-head attention operator.
//
// Compute() performs the following steps:
//   1. One GEMM of Input (viewed as B*S x hidden) with W (viewed as
//      hidden x 3*N*H) into a packed B x S x 3 x N x H buffer.
//   2. TransQKVWithBias: adds Bias and rearranges the buffer to
//      3 x B x N x S x H.
//   3. math::MultiHeadGPUComputeFunctor on the rearranged buffer, with BiasQK
//      and the `alpha` attribute as the scale.
//   4. A final `transpose` kernel that writes Out (B x S x N*H).
//
// Inputs: "Input", "W", "Bias", "BiasQK". Attributes: "alpha" (float),
// "head_number" (int). Output: "Out".
//
// NOTE(review): TransQKVWithBias takes float pointers, so this class only
// instantiates for T == float.
template <typename DeviceContext, typename T>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    using Tensor = framework::Tensor;

    // detail::Ref reports a missing input with a message instead of letting
    // a null pointer be dereferenced further down.
    auto &input =
        detail::Ref(context.Input<Tensor>("Input"), "Cannot find Input");
    auto &w = detail::Ref(context.Input<Tensor>("W"), "Cannot find W");
    auto &bias = detail::Ref(context.Input<Tensor>("Bias"), "Cannot find Bias");
    auto &bias_qk =
        detail::Ref(context.Input<Tensor>("BiasQK"), "Cannot find QK");

    auto *bias_d = bias.data<T>();
    auto *bias_qk_d = bias_qk.data<T>();
    T scale = static_cast<T>(context.Attr<float>("alpha"));

    int head_number = context.Attr<int>("head_number");
    // head_number is used as a divisor below.
    PADDLE_ENFORCE_GT(head_number, 0,
                      platform::errors::InvalidArgument(
                          "head_number (%d) should be greater than 0",
                          head_number));

    auto &device_ctx = context.template device_context<DeviceContext>();
    // Expected to be (B, S, hidden).
    auto input_dims = input.dims();
    // Expected to be (hidden, 3, all_head_size).
    auto w_dims = w.dims();
    int batch = input_dims[0];
    int seq_len = input_dims[1];

    int all_head_size = w_dims[2];
    // A remainder here would make the packed buffer narrower than the GEMM
    // output.
    PADDLE_ENFORCE_EQ(all_head_size % head_number, 0,
                      platform::errors::InvalidArgument(
                          "all_head_size (%d) should be divisible by "
                          "head_number (%d)",
                          all_head_size, head_number));
    int head_size = all_head_size / head_number;
    // The final transpose launches head_size threads per block; exceeding the
    // maximum block size (1024) would make that launch fail silently.
    PADDLE_ENFORCE_LE(head_size, 1024,
                      platform::errors::InvalidArgument(
                          "head_size (%d) should <= %d", head_size, 1024));

    auto *out = context.Output<framework::Tensor>("Out");
    out->Resize({batch, seq_len, all_head_size});
    auto *output_d = out->mutable_data<T>(context.GetPlace());

    // (B*S, hidden)
    const Tensor input_matrix =
        framework::ReshapeToMatrix(input, 2 /*x_num_col_dims */);
    // (hidden, 3 * all_head_size)
    const Tensor w_matrix = framework::ReshapeToMatrix(w, 1 /*y_num_col_dims*/);

    // Packed Q/K/V projection, laid out as B x S x 3 x N x H. The column
    // count is written directly instead of dividing a product by
    // batch * seq_len, which would be a division by zero for empty input.
    Tensor temp_out_tensor;
    temp_out_tensor.Resize({batch * seq_len, 3 * head_number * head_size});
    auto *temp_out_data = temp_out_tensor.mutable_data<T>(context.GetPlace());

    // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(device_ctx);
    blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);

    // One allocation holds two regions:
    //   [0, scratch_size)       : B * N * S * S scratch (qkptr)
    //   [scratch_size, ...)     : B * S * 3 * N * H transposed Q/K/V (tptr)
    // The size is computed in 64 bits so the product cannot overflow an int.
    Tensor multihead_temp_tensor;
    const int64_t scratch_size =
        static_cast<int64_t>(batch) * head_number * seq_len * seq_len;
    multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
    auto *multihead_temp_data =
        multihead_temp_tensor.mutable_data<T>(context.GetPlace());
    auto *qkptr = multihead_temp_data;
    auto *tptr = multihead_temp_data + scratch_size;

    auto stream = device_ctx.stream();
    // Do the transpose with bias.
    // BxSx3xNxH => tptr: 3xBxNxSxH.
    TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data,
                     bias_d, tptr, stream);

    // NOTE(review): the functor presumably leaves its result at tptr, since
    // that is what the transpose below reads - confirm against
    // bert_encoder_functor.
    math::MultiHeadGPUComputeFunctor<T> multihead_compute_func;
    multihead_compute_func(device_ctx, batch, seq_len, head_number, head_size,
                           qkptr, bias_qk_d, tptr, scale, T(0.0));

    // One block per (batch, head, seq) row, one thread per element of a head.
    int grid = batch * head_number * seq_len;
    int block = head_size;
    transpose<T><<<grid, block, 0, stream>>>(tptr, output_d, batch, seq_len,
                                             head_number, head_size);
  }
};

}  // namespace operators
}  // namespace paddle

// Register the CUDA implementation of the `multihead_matmul` operator.
// Only a float instantiation is registered.
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    multihead_matmul,
    ops::MultiHeadMatMulV2Kernel<plat::CUDADeviceContext, float>);