multihead_matmul_op.cu 14.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/fluid/platform/device_context.h>
16

17
#include <algorithm>
W
Wilber 已提交
18
#include <type_traits>
19

20 21
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
22
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
W
Wilber 已提交
23
#include "paddle/fluid/platform/float16.h"
24
#include "paddle/phi/kernels/funcs/blas/blas.h"
25 26 27 28 29

namespace paddle {
namespace operators {

// Re-orders rows from a head-major layout to a sequence-major layout:
// src is read as [batch, head, seq, size_per_head] and dst is written as
// [batch, seq, head, size_per_head].
//
// Launch layout: one block per (batch, head, seq) row of src
// (gridDim.x == batch * head_num * seq_len) and one thread per element of the
// row (blockDim.x == size_per_head). The kernel has no bounds check, so the
// launch configuration has to match these sizes exactly. `batch_size` is not
// read by the kernel.
template <typename T>
__global__ void transpose(T *src,
                          T *dst,
                          const int batch_size,
                          const int seq_len,
                          const int head_num,
                          const int size_per_head) {
  const unsigned int row = blockIdx.x;
  const unsigned int lane = threadIdx.x;
  const int rows_per_batch = head_num * seq_len;

  // Decompose the flat source row index into (batch, head, seq).
  const int batch_id = row / rows_per_batch;
  const int head_id = (row % rows_per_batch) / seq_len;
  const int seq_id = row % seq_len;

  // Destination row swaps the head and seq axes.
  const unsigned int dst_index =
      batch_id * (head_num * seq_len * size_per_head) +
      seq_id * head_num * size_per_head + head_id * size_per_head + lane;
  const unsigned int src_index = row * size_per_head + lane;

  dst[dst_index] = src[src_index];
}

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
// Element-wise addition helper used when adding the bias to the QKV data.
// The primary template is only declared; each supported type provides an
// explicit specialization.
template <typename T>
inline __device__ T add_func(T a, T b);

// Scalar float: plain sum.
template <>
__device__ float add_func<float>(float a, float b) {
  const float sum = a + b;
  return sum;
}

// float2: component-wise sum of the two packed floats.
template <>
__device__ float2 add_func<float2>(float2 a, float2 b) {
  return make_float2(a.x + b.x, a.y + b.y);
}

// float4: component-wise sum of the four packed floats.
template <>
__device__ float4 add_func<float4>(float4 a, float4 b) {
  return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
W
Wilber 已提交
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
#if defined(PADDLE_WITH_CUDA)
// half2 specialization: adds the two packed half lanes independently.
// When __CUDA_ARCH__ >= 530 the native __hadd2 intrinsic is used; otherwise
// each lane is widened to float, added, and narrowed back to half.
template <>
__device__ half2 add_func<half2>(half2 a, half2 b) {
#if __CUDA_ARCH__ >= 530
  return __hadd2(a, b);
#else
  // Low lane: a.x + b.x. High lane: a.y + b.y (the second operand pair must
  // come from the .y members of both inputs, not from b alone).
  return half2(__float2half(__half2float(a.x) + __half2float(b.x)),
               __float2half(__half2float(a.y) + __half2float(b.y)));
#endif
}

// half specialization: scalar half addition.
// Without native half arithmetic (__CUDA_ARCH__ below 530, or undefined) the
// sum is formed in float and converted back; otherwise __hadd is used.
template <>
__device__ half add_func<half>(half a, half b) {
#if __CUDA_ARCH__ < 530
  const float wide_sum = __half2float(a) + __half2float(b);
  return __float2half(wide_sum);
#else
  return __hadd(a, b);
#endif
}
#endif
89 90

// Adds the bias to the fused QKV projection and scatters it into separate
// Q/K/V planes.
//   Input:  B x S x 3 x N x H
//   Bias:   3 x N x H
//   Output: 3 x B x N x S x H
// H is counted in elements of T (callers in this file pass a vector type and
// a correspondingly reduced H).
//
// The problem sizes are taken from the launch geometry:
//   gridDim  = (S, B, number of planes)   blockDim = (H, N, 1)
// There is no bounds check; blockDim.x must equal H.
template <typename T>
__global__ void TransposeQkvKernel(const int H,
                                   const T *input,
                                   const T *bias,
                                   T *output) {
  // Extents.
  const int head_count = blockDim.y;  // N
  const int seq_count = gridDim.x;    // S
  const int batch_count = gridDim.y;  // B

  // Coordinates of the element handled by this thread.
  const int elem_idx = threadIdx.x;
  const int head_idx = threadIdx.y;
  const int seq_idx = blockIdx.x;
  const int batch_idx = blockIdx.y;
  const int plane_idx = blockIdx.z;

  // Row starts, written as nested products of the layouts above.
  const int src_row =
      (((batch_idx * seq_count + seq_idx) * 3 + plane_idx) * head_count +
       head_idx) *
      H;
  const int bias_row = (plane_idx * head_count + head_idx) * H;
  const int dst_row =
      ((((plane_idx * batch_count + batch_idx) * head_count + head_idx) *
            seq_count +
        seq_idx)) *
      H;

  output[dst_row + elem_idx] =
      add_func(input[src_row + elem_idx], bias[bias_row + elem_idx]);
}
117

W
Wilber 已提交
118 119 120 121 122 123 124 125 126 127 128
// Adds the bias to the fused QKV projection and re-lays it out so that Q, K
// and V become three separate planes:
//   input BxSx3xNxH  +  bias 3xNxH  ->  output 3xBxNxSxH
// (B = batch, S = seq_len, N = head_num, H = head_size).
// Only explicit specializations are defined. The kernel is enqueued on
// `stream`; no synchronization is performed here.
template <typename T>
void TransQKVWithBias(const int batch,
                      const int seq_len,
                      const int head_size,
                      const int head_num,
                      const T *input,
                      const T *bias,
                      T *output,
                      gpuStream_t stream);

// float specialization.
// Chooses the widest element type (float4, then float2, then float) that both
// head_size and scratch_size allow. scratch_size equals the number of elements
// the caller in this file places in front of `output` inside its scratch
// buffer, so its divisibility is checked to keep the reinterpret-cast output
// pointer aligned.
// Raises InvalidArgument if the thread block would exceed 1024 threads.
template <>
void TransQKVWithBias(const int batch,
                      const int seq_len,
                      const int head_size,
                      const int head_num,
                      const float *input,
                      const float *bias,
                      float *output,
                      gpuStream_t stream) {
  // One block per (sequence position, batch entry, Q/K/V plane).
  const dim3 grid(seq_len, batch, 3);
  const int scratch_size = batch * head_num * seq_len * seq_len;
  const bool vec4_ok = (head_size % 4 == 0) && (scratch_size % 4 == 0);
  const bool vec2_ok = (head_size % 2 == 0) && (scratch_size % 2 == 0);

  if (vec4_ok) {
    // Four floats per thread.
    const int h = head_size / 4;
    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num,
                      1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num,
                          head_size,
                          1024 * 4));
    const dim3 block(h, head_num, 1);
    TransposeQkvKernel<float4><<<grid, block, 0, stream>>>(
        h,
        reinterpret_cast<const float4 *>(input),
        reinterpret_cast<const float4 *>(bias),
        reinterpret_cast<float4 *>(output));
    return;
  }

  if (vec2_ok) {
    // Two floats per thread.
    const int h = head_size / 2;
    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num,
                      1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num,
                          head_size,
                          1024 * 2));
    const dim3 block(h, head_num, 1);
    TransposeQkvKernel<float2><<<grid, block, 0, stream>>>(
        h,
        reinterpret_cast<const float2 *>(input),
        reinterpret_cast<const float2 *>(bias),
        reinterpret_cast<float2 *>(output));
    return;
  }

  // Scalar fallback: one float per thread.
  // limit head_size * head_num to max block size(1024).
  PADDLE_ENFORCE_LE(head_size * head_num,
                    1024,
                    platform::errors::InvalidArgument(
                        "head_num (%d) * head_size (%d) should <= %d",
                        head_num,
                        head_size,
                        1024));
  const dim3 block(head_size, head_num, 1);
  TransposeQkvKernel<float>
      <<<grid, block, 0, stream>>>(head_size, input, bias, output);
}
188

W
Wilber 已提交
189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
#if defined(PADDLE_WITH_CUDA)
// float16 specialization of TransQKVWithBias:
//   input BxSx3xNxH  +  bias 3xNxH  ->  output 3xBxNxSxH
// Processes two halves per thread (half2) when both head_size and
// scratch_size are even, otherwise one half per thread. The
// platform::float16 pointers are reinterpreted as CUDA half / half2.
// Raises InvalidArgument if the thread block would exceed 1024 threads.
template <>
void TransQKVWithBias(const int batch,
                      const int seq_len,
                      const int head_size,
                      const int head_num,
                      const platform::float16 *input,
                      const platform::float16 *bias,
                      platform::float16 *output,
                      gpuStream_t stream) {
  // One block per (sequence position, batch entry, Q/K/V plane).
  const dim3 grid(seq_len, batch, 3);
  const int scratch_size = batch * head_num * seq_len * seq_len;
  const bool vec2_ok = (head_size % 2 == 0) && (scratch_size % 2 == 0);

  if (!vec2_ok) {
    // Scalar path: one half per thread.
    // limit head_size * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(head_size * head_num,
                      1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num,
                          head_size,
                          1024));
    const dim3 block(head_size, head_num, 1);
    TransposeQkvKernel<half><<<grid, block, 0, stream>>>(
        head_size,
        reinterpret_cast<const half *>(input),
        reinterpret_cast<const half *>(bias),
        reinterpret_cast<half *>(output));
    return;
  }

  // Vector path: two halves per thread.
  const int h = head_size / 2;
  // limit h * head_num to max block size(1024).
  PADDLE_ENFORCE_LE(h * head_num,
                    1024,
                    platform::errors::InvalidArgument(
                        "head_num (%d) * head_size (%d) should <= %d",
                        head_num,
                        head_size,
                        1024 * 2));
  const dim3 block(h, head_num, 1);
  TransposeQkvKernel<half2><<<grid, block, 0, stream>>>(
      h,
      reinterpret_cast<const half2 *>(input),
      reinterpret_cast<const half2 *>(bias),
      reinterpret_cast<half2 *>(output));
}
#endif

F
feng_shuai 已提交
238 239
// Rounds seq_len up to a multiple of `multiple` (default 32); for a
// non-negative seq_len this is the smallest such multiple that is >= seq_len.
// Raises InvalidArgument when `multiple` is not positive.
inline int round_up(int seq_len, int multiple = 32) {
  PADDLE_ENFORCE_GT(
      multiple,
      0,
      platform::errors::InvalidArgument(
          "multiple should be a positive number,but it's (%d)", multiple));
  // Ceiling division, then scale back up.
  const int chunk_count = (seq_len + multiple - 1) / multiple;
  return chunk_count * multiple;
}

// Replicates one row of `src` per batch entry into every row of `dst` that
// belongs to that batch entry.
// Each block writes one dst row of seq_len elements; the batch entry is
// blockIdx.x / (head_num * seq_len), so consecutive groups of
// head_num * seq_len blocks read the same src row.
// Threads with threadIdx.x >= seq_len do nothing, which lets the caller round
// the block size up past seq_len.
template <typename T>
__global__ void broadcast(const T *src,
                          T *dst,
                          const int seq_len,
                          const int head_num) {
  const int batch_id = blockIdx.x / (head_num * seq_len);
  const int dst_offset = blockIdx.x * seq_len;
  const int src_offset = batch_id * seq_len;

  // Padding threads beyond the row length have no element to copy.
  if (!(threadIdx.x < seq_len)) {
    return;
  }
  dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset];
}

259
// GPU kernel of the multihead_matmul op.
//
// Steps performed by Compute, all enqueued on the device context's stream:
//   1. Optionally expand BiasQK when it holds only batch * seq_len values.
//   2. Project Input with W in a single matmul:
//        (B * S, hidden) x (hidden, 3 * N * H) -> B x S x 3 x N x H.
//   3. Add Bias and split into Q/K/V planes (TransQKVWithBias):
//        B x S x 3 x N x H -> 3 x B x N x S x H.
//   4. Run math::MultiHeadGPUComputeFunctor on the planes with the QK bias and
//      the `alpha` scale (presumably the attention computation itself; its
//      implementation is not in this file).
//   5. Transpose the first B x N x S x H plane of the scratch buffer into
//      Out with layout B x S x N x H.
// B = batch, S = seq_len, N = head_number, H = head_size.
template <typename DeviceContext, typename T>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
 public:
  // Reads inputs "Input", "W", "Bias", "BiasQK" and attributes "alpha",
  // "head_number"; writes output "Out" resized to (B, S, all_head_size).
  // Raises InvalidArgument if head_number is not positive, if the BiasQK
  // expansion would need more than 1024 threads per block, or if
  // TransQKVWithBias rejects head_number * head_size.
  void Compute(const framework::ExecutionContext &context) const override {
    using Tensor = framework::Tensor;
    auto *input = context.Input<framework::Tensor>("Input");
    auto *w = context.Input<framework::Tensor>("W");
    auto *bias = context.Input<framework::Tensor>("Bias");
    auto &bias_qk = GET_DATA_SAFELY(context.Input<framework::Tensor>("BiasQK"),
                                    "Input",
                                    "BiasQK",
                                    "MultiHeadMatMulV2");

    auto *input_d = input->data<T>();
    auto *w_d = w->data<T>();
    auto *bias_d = bias->data<T>();
    auto *bias_qk_d = bias_qk.template data<T>();
    T scale = static_cast<T>(context.Attr<float>("alpha"));

    int head_number = context.Attr<int>("head_number");
    // head_number is used as a divisor and as a launch dimension below, so a
    // non-positive value must be rejected before any of that happens.
    PADDLE_ENFORCE_GT(
        head_number,
        0,
        platform::errors::InvalidArgument(
            "head_number should be a positive number, but it's (%d)",
            head_number));

    // compute q*k with eltadd
    auto &device_ctx = context.template device_context<DeviceContext>();
    auto stream = device_ctx.stream();
    // should be (B * S * hidden)
    auto input_dims = input->dims();
    // should be (hidden * 3 * all_head_size)
    auto w_dims = w->dims();
    int batch = input_dims[0];
    int seq_len = input_dims[1];
    int hidden = input_dims[2];

    Tensor temp_bias_tensor;
    // if bias_qk is [batch, 1, 1, seq_len], bias_qk_d needs to be broadcast
    // to batch * head_number * seq_len * seq_len elements.
    if (bias_qk.numel() == (batch * seq_len)) {
      temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
      auto *temp_qk_bias = temp_bias_tensor.mutable_data<T>(context.GetPlace());
      // One block per destination row, one thread per element of the row
      // (rounded up to a multiple of 32).
      int grid = batch * head_number * seq_len;
      int block = round_up(seq_len);
      // A block larger than 1024 threads cannot be launched; without this
      // check the launch would fail and leave temp_qk_bias uninitialized.
      PADDLE_ENFORCE_LE(
          block,
          1024,
          platform::errors::InvalidArgument(
              "seq_len (%d) rounded up to a multiple of 32 should <= %d",
              seq_len,
              1024));
      broadcast<<<grid, block, 0, stream>>>(
          bias_qk_d, temp_qk_bias, seq_len, head_number);
      bias_qk_d = static_cast<const T *>(temp_qk_bias);
    }
    int all_head_size = w_dims[2];
    int head_size = all_head_size / head_number;

    auto *out = context.Output<framework::Tensor>("Out");
    out->Resize({batch, seq_len, all_head_size});
    auto *output_d = out->mutable_data<T>(context.GetPlace());

    // (B*S, hidden)
    const Tensor input_matrix =
        framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
    // (hidden, 3 * all_head_size)
    const Tensor w_matrix =
        framework::ReshapeToMatrix(*w, 1 /*y_num_col_dims*/);

    Tensor temp_out_tensor;
    auto temp_out_dims =
        phi::make_ddim({batch, seq_len, 3, head_number, head_size});
    temp_out_tensor.Resize(
        {batch * seq_len, phi::product(temp_out_dims) / (batch * seq_len)});
    auto *temp_out_data = temp_out_tensor.mutable_data<T>(context.GetPlace());

    // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(device_ctx);
    blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);

    // Single scratch allocation holding two regions back to back:
    //   qkptr: B * head_number * S * S elements
    //   tptr:  B * S * 3 * N * H elements (the transposed Q/K/V planes)
    Tensor multihead_temp_tensor;
    int scratch_size = batch * head_number * seq_len * seq_len * 1;
    multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
    auto *multihead_temp_data =
        multihead_temp_tensor.mutable_data<T>(context.GetPlace());
    auto *qkptr = multihead_temp_data;
    auto *tptr = multihead_temp_data + scratch_size;

    // Do the transpose with bias.
    // BxSx3xNxH => tptr: 3xBxNxSxH.
    TransQKVWithBias(batch,
                     seq_len,
                     head_size,
                     head_number,
                     temp_out_data,
                     bias_d,
                     tptr,
                     stream);

    // Both branches are compiled for every T (this is a run-time `if`); the
    // float16 branch hands the buffers to the functor as CUDA `half`.
    if (std::is_same<T, platform::float16>::value) {
      math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
      multihead_compute_func(device_ctx,
                             batch,
                             seq_len,
                             head_number,
                             head_size,
                             reinterpret_cast<half *>(qkptr),
                             reinterpret_cast<const half *>(bias_qk_d),
                             reinterpret_cast<half *>(tptr),
                             __float2half(static_cast<float>(scale)),
                             __float2half(0.0f));
    } else {
      math::MultiHeadGPUComputeFunctor<T> multihead_compute_func;
      multihead_compute_func(device_ctx,
                             batch,
                             seq_len,
                             head_number,
                             head_size,
                             qkptr,
                             bias_qk_d,
                             tptr,
                             scale,
                             T(0.0));
    }

    // Copy the B x N x S x H data at the start of tptr into Out as
    // B x S x N x H: one block per (batch, head, seq) row, one thread per
    // element of the row.
    int grid = batch * head_number * seq_len;
    int block = head_size;
    transpose<T><<<grid, block, 0, stream>>>(
        tptr, output_d, batch, seq_len, head_number, head_size);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the GPU kernels of the multihead_matmul op. The float16 kernel is
// registered only when building with CUDA and CUDA_VERSION >= 10000; every
// other configuration registers the float kernel alone.
#if !(defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000)
REGISTER_OP_CUDA_KERNEL(multihead_matmul,
                        ops::MultiHeadMatMulV2Kernel<phi::GPUContext, float>);
#else
REGISTER_OP_CUDA_KERNEL(
    multihead_matmul,
    ops::MultiHeadMatMulV2Kernel<phi::GPUContext, paddle::platform::float16>,
    ops::MultiHeadMatMulV2Kernel<phi::GPUContext, float>);
#endif