Unverified commit 79bfb184, authored by Yuanle Liu, committed by GitHub

multihead_matmul op support codegen and kernel remove to phi (#56846)

Parent 7fd6ffb8
......@@ -22,9 +22,9 @@
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
namespace paddle {
namespace inference {
......@@ -254,7 +254,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
platform::CUDAPlace(device_id)));
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
......@@ -341,7 +341,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
tptr, static_cast<half>(scale_), n_q);
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
......
......@@ -24,9 +24,9 @@
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
namespace paddle {
namespace inference {
......@@ -396,7 +396,7 @@ int QkvToContextPluginDynamic::enqueue(
platform::CUDAPlace(device_id)));
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
......@@ -506,7 +506,7 @@ int QkvToContextPluginDynamic::enqueue(
tptr, static_cast<half>(scale_), n_q);
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
......
......@@ -10,7 +10,6 @@ register_operators(
fusion_transpose_flatten_concat_op
fusion_conv_inception_op
fused_fc_elementwise_layernorm_op
multihead_matmul_op
self_dp_attention_op
skip_layernorm_op
yolo_box_head_op
......@@ -74,8 +73,6 @@ if(WITH_GPU OR WITH_ROCM)
endif()
# fused_fc_elementwise_layernorm_op
op_library(fused_fc_elementwise_layernorm_op)
# multihead_matmul_op
op_library(multihead_matmul_op)
op_library(skip_layernorm_op)
op_library(yolo_box_head_op)
op_library(yolo_box_post_op)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace operators {
class MultiHeadMatMulV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(
context->HasInput("Input"),
true,
platform::errors::InvalidArgument(
"Input(Input) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ(context->HasInput("W"),
true,
platform::errors::InvalidArgument(
"Input(W) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ(
context->HasInput("Bias"),
true,
platform::errors::InvalidArgument(
"Input(Bias) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ(
context->HasOutput("Out"),
true,
platform::errors::InvalidArgument(
"Output(Out) of MultiHeadMatMul should not be null."));
auto dim_w = context->GetInputDim("W");
PADDLE_ENFORCE_GT(
dim_w.size(),
2,
platform::errors::InvalidArgument(
"Multihead input is expected at least a 3-D tensor, but "
"it's %d-D tensor now.",
dim_w.size()));
auto dim_bias_q = context->GetInputDim("Bias");
PADDLE_ENFORCE_GT(
dim_bias_q.size(),
1,
platform::errors::InvalidArgument(
"Multihead input should be at least 2-D tensor, but it's "
"%d-D tensor now.",
dim_bias_q.size()));
auto dim_input = context->GetInputDim("Input");
context->SetOutputDim("Out", dim_input);
context->ShareLoD("Input", /*->*/ "Out");
}
};
class MultiHeadMatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The input of MultiHeadMatMul op");
AddInput("W", "The weight input of MultiHeadMatMul op");
AddInput("Bias", "The bias input of MultiHeadMatMul op");
AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op")
.AsDispensable();
AddOutput("Out", "The output of MultiHeadMatMul op");
AddAttr<bool>("transpose_Q",
R"DOC(If true, use the transpose of `Q`.
)DOC")
.SetDefault(false);
AddAttr<bool>("transpose_K",
R"DOC(If true, use the transpose of `K`.
)DOC")
.SetDefault(true);
AddAttr<bool>("transpose_V",
R"DOC(If true, use the transpose of `V`.
)DOC")
.SetDefault(false);
AddAttr<float>("alpha", "The scale of Out").SetDefault(1.0f);
AddAttr<int>("head_number", "The number of heads of the matrix")
.SetDefault(1);
AddComment(R"DOC(
MultiHeadMatMul Operator.
This op is used to optimize the multi-head attention calculation in the Ernie model.
It is not suggested for use in other cases unless the model has the same structure as Ernie.
Example of matrix multiplication with head_number of B
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul,
ops::MultiHeadMatMulV2Op,
ops::MultiHeadMatMulV2OpMaker);
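For orientation, the fused computation implemented by this operator (and by the phi kernel added later in this diff) can be summarized as follows; the shapes are read from the kernel code, and B, S, N, H (batch, sequence length, head number, head size) are introduced here only for illustration:

\[
[Q, K, V] = \mathrm{reshape}(\mathrm{Input}\,W + \mathrm{Bias}), \qquad
\mathrm{Out} = \mathrm{softmax}\big(\alpha\, Q K^{\top} + \mathrm{BiasQK}\big)\, V
\]

with Input of shape [B, S, hidden], W of shape [hidden, 3, N*H], Bias of shape [3, N*H], and Out of the same shape as Input.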
......@@ -255,704 +255,6 @@ template class EmbEltwiseLayerNormFunctor<float>;
template class EmbEltwiseLayerNormFunctor<half>;
#endif
template <typename T>
__global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
const T *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float tmp = threadIdx.x < seq_len
? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + qk_offset])
: -1e20f;
float max_val = phi::funcs::BlockReduceMax<float>(tmp, mask);
float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
float sum_val = phi::funcs::BlockReduceSum<float>(qk_tmp, mask);
if (threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
}
// HIP defined __HIP_NO_HALF_CONVERSIONS__
#ifndef __HIPCC__ // @{ Half kernel: SoftmaxKernelWithEltadd
template <>
__global__ void SoftmaxKernelWithEltadd<half>(
half *qk_buf_,
const half *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float tmp = threadIdx.x < seq_len
? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + qk_offset])
: -1e20f;
float max_val = phi::funcs::BlockReduceMax<float>(tmp, mask);
float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
float sum_val = phi::funcs::BlockReduceSum<float>(qk_tmp, mask);
if (threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val);
#endif
}
#endif // @} End Half kernel: SoftmaxKernelWithEltadd
template <typename T>
__global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
const T *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
int qk_offset = blockIdx.x * seq_len;
int idx = threadIdx.x;
assert(blockDim.x % WARP_SIZE == 0);
float2 tmp = idx < seq_len
? phi::funcs::ToFloat2<T>(qk_buf_[idx + qk_offset] +
bias_qk_[idx + qk_offset])
: make_float2(-1e20f, -1e20f);
float max_val = phi::funcs::BlockReduceMax<float>(max(tmp.x, tmp.y), mask);
float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
__expf(tmp.y - max_val))
: make_float2(0.f, 0.f);
float sum_val =
phi::funcs::BlockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
if (idx < seq_len) {
qk_buf_[idx + qk_offset] =
phi::funcs::FloatsToPair<T>(qk_tmp.x / sum_val, qk_tmp.y / sum_val);
}
}
template <>
__global__ void SoftmaxKernelWithEltadd2<half2>(
half2 *qk_buf_,
const half2 *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
// operator "+" of half is only supported after CUDA version 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && \
(CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
int qk_offset = blockIdx.x * seq_len;
int idx = threadIdx.x;
assert(blockDim.x % WARP_SIZE == 0);
float2 tmp = idx < seq_len
? phi::funcs::ToFloat2<half2>(qk_buf_[idx + qk_offset] +
bias_qk_[idx + qk_offset])
: make_float2(-1e20f, -1e20f);
float max_val = phi::funcs::BlockReduceMax<float>(max(tmp.x, tmp.y), mask);
float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
__expf(tmp.y - max_val))
: make_float2(0.f, 0.f);
float sum_val =
phi::funcs::BlockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
if (idx < seq_len) {
qk_buf_[idx + qk_offset] =
phi::funcs::FloatsToPair<half2>(qk_tmp.x / sum_val, qk_tmp.y / sum_val);
}
#endif
}
template <typename T>
__global__ void SoftmaxKernelWithEltaddForLarge(
T *qk_buf,
const T *bias_qk,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
T stride_max = -1e20f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
stride_max = qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset] >
stride_max
? qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset]
: stride_max;
}
T max_val = phi::funcs::BlockReduceMax<T>(stride_max, mask);
T stride_sum = 0.f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
stride_sum += __expf(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset] - max_val);
}
T sum_val = phi::funcs::BlockReduceSum<T>(stride_sum, mask);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
qk_buf[threadIdx.x + i + qk_offset] =
(T)(__expf(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset] - max_val) /
sum_val);
}
}
// HIP defined __HIP_NO_HALF_CONVERSIONS__
#ifndef __HIPCC__ // @{ Half kernel: SoftmaxKernelWithEltadd
template <>
__global__ void SoftmaxKernelWithEltaddForLarge(
half *qk_buf,
const half *bias_qk,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float stride_max = -1e20f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float tmp = static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset]);
stride_max = tmp > stride_max ? tmp : stride_max;
}
float max_val = phi::funcs::BlockReduceMax<float>(stride_max, mask);
float stride_sum = 0.f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float tmp = static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset]);
stride_sum += __expf(tmp - max_val);
}
float sum_val = phi::funcs::BlockReduceSum<float>(stride_sum, mask);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float tmp =
__expf(static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset]) -
max_val);
qk_buf[threadIdx.x + i + qk_offset] = (half)(tmp / sum_val);
}
#endif
}
#endif // @} End Half kernel: SoftmaxKernelWithEltadd
template <typename T>
__global__ void SoftmaxKernelWithEltaddForLarge2(
T *qk_buf_,
const T *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float2 stride_max = make_float2(-1e20f, -1e20f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
stride_max.x = max(stride_max.x, cur.x);
stride_max.y = max(stride_max.y, cur.y);
}
float max_val =
phi::funcs::BlockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
float2 stride_sum = make_float2(0.f, 0.f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
stride_sum.x += __expf(cur.x - max_val);
stride_sum.y += __expf(cur.y - max_val);
}
float sum_val =
phi::funcs::BlockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
1e-6f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
qk_buf_[threadIdx.x + i + qk_offset] = phi::funcs::FloatsToPair<T>(
__expf(cur.x - max_val) / sum_val, __expf(cur.y - max_val) / sum_val);
}
}
template <>
__global__ void SoftmaxKernelWithEltaddForLarge2(
half2 *qk_buf_,
const half2 *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
// operator "+" of half is only supported after CUDA version 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && \
(CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float2 stride_max = make_float2(-1e20f, -1e20f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur =
phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
stride_max.x = max(stride_max.x, cur.x);
stride_max.y = max(stride_max.y, cur.y);
}
float max_val =
phi::funcs::BlockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
float2 stride_sum = make_float2(0.f, 0.f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur =
phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
stride_sum.x += __expf(cur.x - max_val);
stride_sum.y += __expf(cur.y - max_val);
}
float sum_val =
phi::funcs::BlockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
1e-6f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur =
phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
qk_buf_[threadIdx.x + i + qk_offset] = phi::funcs::FloatsToPair<half2>(
__expf(cur.x - max_val) / sum_val, __expf(cur.y - max_val) / sum_val);
}
#endif
}
template <typename T>
inline __device__ T ldg(const T *val) {
return __ldg(val);
}
template <typename T>
inline __device__ T hexp2(T a) {
return h2exp(a);
}
template <typename T_IN, typename T_OUT>
inline __device__ T_OUT type2type2(T_IN a);
template <>
inline __device__ half2 type2type2(half a) {
return __half2half2(a);
}
template <typename T>
inline __device__ T float2type2(float a);
template <>
inline __device__ half2 float2type2(float a) {
return __float2half2_rn(a);
}
template <typename T>
inline __device__ T hmul2(T a, T b) {
return __hmul2(a, b);
}
template <typename T>
inline __device__ T hsub2(T a, T b) {
return __hsub2(a, b);
}
template <typename T>
inline __device__ T hadd2(T a, T b) {
return __hadd2(a, b);
}
template <typename T, int ITEMS_PER_THREAD, int NUM>
__global__ void softmax_kernel_with_mask(T *qk_buf_,
const T *attr_mask,
const int batch_size,
const int head_num,
const int seq_len) {
using T2 = half2;
T2 *qk_buf_half2 = reinterpret_cast<T2 *>(qk_buf_);
const T2 *attr_mask_half2 = (const T2 *)attr_mask;
for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x * NUM) {
T2 data[NUM][ITEMS_PER_THREAD];
int qk_offset[NUM];
__shared__ float s_sum[NUM], s_max[NUM];
float local_max[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_max[j] = -1e20f;
}
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
int mask_offset[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len +
seq_id + j * gridDim.x) *
(seq_len / 2) +
blockDim.x * i + threadIdx.x;
mask_offset[j] =
(blockIdx.y * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) +
blockDim.x * i + threadIdx.x;
}
T2 mask_val[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
mask_val[j] = ldg(&attr_mask_half2[mask_offset[j]]);
}
T2 qk[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk[j] = qk_buf_half2[qk_offset[j]];
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
mask_val[j] = hmul2<T2>(hsub2<T2>(float2type2<T2>(1.0f), mask_val[j]),
float2type2<T2>(-10000.0f));
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
data[j][i] = hadd2<T2>(qk[j], mask_val[j]);
local_max[j] = fmax(local_max[j],
fmax(static_cast<float>(data[j][i].x),
static_cast<float>(data[j][i].y)));
}
}
if (blockDim.x <= WARP_SIZE) {
phi::funcs::WarpReduceMaxV2<float, NUM>(local_max);
} else {
phi::funcs::BlockReduceMaxV2<float, NUM>(local_max);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_max[j] = local_max[j];
}
}
__syncthreads();
float local_sum[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_sum[j] = {0.f};
}
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
data[j][i] =
hexp2<T2>(hsub2<T2>(data[j][i], float2type2<T2>(s_max[j])));
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_sum[j] += static_cast<float>(data[j][i].x + data[j][i].y);
}
}
if (blockDim.x <= WARP_SIZE) {
phi::funcs::WarpReduceSumV2<float, NUM>(local_sum);
} else {
phi::funcs::BlockReduceSumV2<float, NUM>(local_sum);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
}
}
__syncthreads();
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len +
seq_id + j * gridDim.x) *
(seq_len / 2) +
blockDim.x * i + threadIdx.x;
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk_buf_half2[qk_offset[j]] =
hmul2<T2>(data[j][i], float2type2<T2>(s_sum[j]));
}
}
}
}
#define SOFTMAX_KERNEL_WITH_MASK(REPEAT_THREAD) \
do { \
block.x /= REPEAT_THREAD; \
grid.x /= 4; \
constexpr int NUM = 4; \
softmax_kernel_with_mask<half, REPEAT_THREAD, NUM> \
<<<grid, block, 0, stream>>>(reinterpret_cast<half *>(qk_buf_), \
(const half *)bias_qk, \
batch_size, \
head_num, \
seq_len); \
} while (0)
template <typename T>
inline void MatMulWithHeadQK(const phi::GPUContext &context,
int head_num,
int seq_len,
int size_per_head,
int batch_size,
bool q_trans,
bool k_trans,
T *q_buf_,
T *k_buf_,
T *qk_buf_,
const T *bias_qk,
bool bias_is_mask,
T alpha,
T beta) {
CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !k_trans ? CblasNoTrans : CblasTrans;
typedef typename CUDATypeTraits<T>::TYPE run_type;
auto blas = phi::funcs::GetBlas<phi::GPUContext, run_type>(context);
auto stream = context.stream();
blas.BatchedGEMM(transA,
transB,
seq_len,
seq_len,
size_per_head,
static_cast<run_type>(alpha),
reinterpret_cast<run_type *>(q_buf_),
reinterpret_cast<run_type *>(k_buf_),
static_cast<run_type>(beta),
reinterpret_cast<run_type *>(qk_buf_),
batch_size * head_num,
seq_len * size_per_head,
seq_len * size_per_head);
if (seq_len <= 1024) {
int grid = batch_size * head_num * seq_len;
int block = seq_len;
// Align block to 32, also limit seq_len to max block size.
if (seq_len % 2 == 0) {
block =
(seq_len <= (2 * WARP_SIZE))
? WARP_SIZE
: ((seq_len + (2 * WARP_SIZE - 1)) / (2 * WARP_SIZE)) * WARP_SIZE;
if (std::is_same<T, float>::value) {
SoftmaxKernelWithEltadd2<float2><<<grid, block, 0, stream>>>(
reinterpret_cast<float2 *>(qk_buf_),
reinterpret_cast<const float2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
} else {
if (bias_is_mask) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700)
PADDLE_ENFORCE_EQ(bias_is_mask,
false,
platform::errors::InvalidArgument(
"QK_bias is mask can't be supported on rocm or "
"cuda_arch<700"));
#else
dim3 grid(seq_len, batch_size, head_num);
dim3 block((seq_len / 2 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
SOFTMAX_KERNEL_WITH_MASK(1);
#endif
} else {
SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
reinterpret_cast<__half2 *>(qk_buf_),
reinterpret_cast<const __half2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
}
}
} else {
block = (seq_len <= WARP_SIZE)
? WARP_SIZE
: ((seq_len + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
}
} else {
int grid = batch_size * head_num * seq_len;
int block = 512;
if (seq_len % 2 == 0) {
if (std::is_same<T, float>::value) {
SoftmaxKernelWithEltaddForLarge2<float2><<<grid, block, 0, stream>>>(
reinterpret_cast<float2 *>(qk_buf_),
reinterpret_cast<const float2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
} else {
if (bias_is_mask) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700)
PADDLE_ENFORCE_EQ(bias_is_mask,
false,
platform::errors::InvalidArgument(
"QK_bias is mask can't be supported on rocm or "
"cuda_arch<700"));
#else
dim3 grid(seq_len, batch_size, head_num);
dim3 block((seq_len / 2 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
if (block.x > 0 && block.x <= 1024) {
SOFTMAX_KERNEL_WITH_MASK(1);
} else if (block.x <= 2048) {
SOFTMAX_KERNEL_WITH_MASK(2);
} else if (block.x <= 4096) {
SOFTMAX_KERNEL_WITH_MASK(4);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot support the length of attention > 8192."));
}
#endif
} else {
SoftmaxKernelWithEltaddForLarge2<__half2><<<grid, block, 0, stream>>>(
reinterpret_cast<__half2 *>(qk_buf_),
reinterpret_cast<const __half2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
}
}
} else {
SoftmaxKernelWithEltaddForLarge<T><<<grid, block, 0, stream>>>(
qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
}
}
}
template <typename T>
inline void MatMulWithHeadQKV(const phi::GPUContext &context,
int head_num,
int seq_len,
int size_per_head,
int batch_size,
bool qk_trans,
bool v_trans,
T *v_buf_,
const T *qk_buf_,
T *dst,
T alpha,
T beta) {
int m = batch_size * seq_len;
int k = head_num * size_per_head;
typedef typename CUDATypeTraits<T>::TYPE run_type;
auto blas = phi::funcs::GetBlas<phi::GPUContext, run_type>(context);
auto stream = context.stream();
CBLAS_TRANSPOSE transA = !qk_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !v_trans ? CblasNoTrans : CblasTrans;
blas.BatchedGEMM(transA,
transB,
seq_len,
size_per_head,
seq_len,
static_cast<run_type>(alpha),
reinterpret_cast<const run_type *>(qk_buf_),
reinterpret_cast<run_type *>(v_buf_),
static_cast<run_type>(beta),
reinterpret_cast<run_type *>(dst),
batch_size * head_num,
seq_len * seq_len,
seq_len * size_per_head);
}
template <typename T>
void MultiHeadGPUComputeFunctor<T>::operator()(const phi::GPUContext &dev_ctx,
int batch,
int seq_len,
int head_num,
int head_size,
T *qkptr,
const T *bias_qk_ptr,
bool bias_is_mask,
T *tptr,
T alpha,
T beta) {
auto stream = dev_ctx.stream();
const int tsize = batch * head_num * seq_len * head_size;
T *qptr = tptr;
T *kptr = qptr + tsize;
T *vptr = kptr + tsize;
// batch gemm stride, softmaxwithscale.
MatMulWithHeadQK<T>(dev_ctx,
head_num,
seq_len,
head_size,
batch,
false,
true,
qptr,
kptr,
qkptr,
bias_qk_ptr,
bias_is_mask,
alpha,
beta);
// batch gemm stride, transpose.
MatMulWithHeadQKV<T>(dev_ctx,
head_num,
seq_len,
head_size,
batch,
false,
false,
vptr,
qkptr,
tptr,
T(1.0),
beta);
}
template class MultiHeadGPUComputeFunctor<float>;
// device function 'operator()' is not supported until CUDA 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template class MultiHeadGPUComputeFunctor<half>;
#endif
template <typename T, unsigned TPB>
__global__ void SkipLayerNormSmallKernel(int num,
int hidden,
......
......@@ -77,35 +77,6 @@ class EmbEltwiseLayerNormFunctor {
gpuStream_t stream);
};
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
//         | |
//        matmul
//          |
//     eltwise_add
//          |
//       softmax   /
//          \     /
//           matmul
//             |
template <typename T>
class MultiHeadGPUComputeFunctor {
public:
void operator()(const phi::GPUContext &dev_ctx,
int batch,
int seq_len,
int head_num,
int head_size,
T *qkptr,
const T *bias_qk_ptr,
bool bias_is_mask,
T *tptr,
T alpha,
T beta);
};
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
......
......@@ -189,6 +189,16 @@
data_type : x
optional : mask, seq_lod, max_seq_len, x_fp16, out_fp16
- op : multihead_matmul
args : (Tensor input, Tensor w, Tensor bias, Tensor bias_qk, bool transpose_q = false, bool transpose_k = true, bool transpose_v = false, float alpha = 1.0f, int head_number = 1)
output : Tensor(out)
infer_meta :
func : MultiheadMatmulInferMeta
kernel :
func : multihead_matmul
data_type : input
optional : bias_qk
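For reference, the args and kernel entries above correspond one-to-one to the signature of the phi kernel added later in this diff; a condensed restatement (not the code generated from the YAML):

template <typename T, typename Context>
void MultiheadMatmulKernel(const Context &dev_ctx,
                           const DenseTensor &input,
                           const DenseTensor &w,
                           const DenseTensor &bias,
                           const paddle::optional<DenseTensor> &bias_qk,
                           const bool transpose_q,
                           const bool transpose_k,
                           const bool transpose_v,
                           const float alpha,
                           const int head_number,
                           DenseTensor *out);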
- op : yolo_box_xpu
args : (Tensor x, Tensor x_max, Tensor grid, Tensor stride, Tensor anchor_grid, float offset)
output : Tensor(out), Tensor(out_max)
......
......@@ -1935,6 +1935,14 @@
outputs :
{out : Out, index : Index, nms_rois_num : NmsRoisNum}
- op : multihead_matmul
inputs :
{input : Input, w : W, bias : Bias, bias_qk : BiasQK}
outputs :
out : Out
attrs :
{transpose_q : transpose_Q, transpose_k : transpose_K, transpose_v : transpose_V}
- op : multinomial
inputs :
{x : X}
......
......@@ -4126,6 +4126,39 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dtype(DataType::INT32);
}
void MultiheadMatmulInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
const MetaTensor& bias_qk,
const bool transpose_q,
const bool transpose_k,
const bool transpose_v,
const float alpha,
const int head_number,
MetaTensor* out) {
auto w_dims = w.dims();
PADDLE_ENFORCE_GT(
w_dims.size(),
2,
errors::InvalidArgument(
"MultiheadMatmul's w is expected at least a 3-D tensor, but "
"it's %d-D tensor now.",
w_dims.size()));
auto bias_dims = bias.dims();
PADDLE_ENFORCE_GT(
bias_dims.size(),
1,
errors::InvalidArgument(
"MultiheadMatmul's bias should be at least 2-D tensor, but it's "
"%d-D tensor now.",
bias_dims.size()));
out->set_dims(input.dims());
out->set_dtype(input.dtype());
out->share_lod(input);
}
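As a concrete illustration of the checks above, a typical shape configuration accepted by this InferMeta (the numbers are illustrative only; the kernel additionally assumes hidden == head_number * head_size):

// input   : [batch, seq_len, hidden]               e.g. [8, 128, 768]
// w       : [hidden, 3, head_number * head_size]   e.g. [768, 3, 768]  (>= 3-D)
// bias    : [3, head_number * head_size]           e.g. [3, 768]       (>= 2-D)
// bias_qk : optional, e.g. [batch, head_number, seq_len, seq_len],
//           or a broadcastable form such as [batch, 1, 1, seq_len]
// out     : same dims, dtype and LoD as input      e.g. [8, 128, 768]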
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& bias,
......
......@@ -811,6 +811,17 @@ void FusedRopeInferMeta(const MetaTensor& q,
MetaTensor* out_k,
MetaTensor* out_v);
void MultiheadMatmulInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
const MetaTensor& bias_qk,
const bool transpose_q,
const bool transpose_k,
const bool transpose_v,
const float alpha,
const int head_number,
MetaTensor* out);
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& bias,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh> // NOLINT
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
namespace phi {
namespace funcs {
template <typename T>
struct CUDATypeTraits;
template <>
struct CUDATypeTraits<half> {
typedef phi::dtype::float16 TYPE;
};
template <>
struct CUDATypeTraits<float> {
typedef float TYPE;
};
using phi::funcs::operator+;
template <typename T>
__global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
const T *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float tmp = threadIdx.x < seq_len
? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + qk_offset])
: -1e20f;
float max_val = phi::funcs::BlockReduceMax<float>(tmp, mask);
float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
float sum_val = phi::funcs::BlockReduceSum<float>(qk_tmp, mask);
if (threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
}
template <>
__global__ void SoftmaxKernelWithEltadd<half>(
half *qk_buf_,
const half *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
#if defined(PADDLE_WITH_CUDA) && CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float tmp = threadIdx.x < seq_len
? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + qk_offset])
: -1e20f;
float max_val = phi::funcs::BlockReduceMax<float>(tmp, mask);
float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
float sum_val = phi::funcs::BlockReduceSum<float>(qk_tmp, mask);
if (threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val);
#endif
}
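All of the SoftmaxKernelWithEltadd* variants in this file compute the same numerically stable softmax over the seq_len axis after adding the QK bias elementwise; restated as a formula (the *2 variants process two packed values per thread, and the *ForLarge variants stride over rows longer than one block):

\[
\mathrm{qk}_i \leftarrow
\frac{\exp\!\big((\mathrm{qk}_i + \mathrm{bias}_i) - m\big)}
     {\sum_{j < \mathrm{seq\_len}} \exp\!\big((\mathrm{qk}_j + \mathrm{bias}_j) - m\big)},
\qquad
m = \max_{j < \mathrm{seq\_len}}\big(\mathrm{qk}_j + \mathrm{bias}_j\big)
\]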
template <typename T>
__global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
const T *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
int qk_offset = blockIdx.x * seq_len;
int idx = threadIdx.x;
assert(blockDim.x % WARP_SIZE == 0);
float2 tmp = idx < seq_len
? phi::funcs::ToFloat2<T>(qk_buf_[idx + qk_offset] +
bias_qk_[idx + qk_offset])
: make_float2(-1e20f, -1e20f);
float max_val = phi::funcs::BlockReduceMax<float>(max(tmp.x, tmp.y), mask);
float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
__expf(tmp.y - max_val))
: make_float2(0.f, 0.f);
float sum_val =
phi::funcs::BlockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
if (idx < seq_len) {
qk_buf_[idx + qk_offset] =
phi::funcs::FloatsToPair<T>(qk_tmp.x / sum_val, qk_tmp.y / sum_val);
}
}
template <>
__global__ void SoftmaxKernelWithEltadd2<half2>(
half2 *qk_buf_,
const half2 *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
// operator "+" of half is only supported after CUDA version 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int qk_offset = blockIdx.x * seq_len;
int idx = threadIdx.x;
assert(blockDim.x % WARP_SIZE == 0);
float2 tmp = idx < seq_len
? phi::funcs::ToFloat2<half2>(qk_buf_[idx + qk_offset] +
bias_qk_[idx + qk_offset])
: make_float2(-1e20f, -1e20f);
float max_val = phi::funcs::BlockReduceMax<float>(max(tmp.x, tmp.y), mask);
float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
__expf(tmp.y - max_val))
: make_float2(0.f, 0.f);
float sum_val =
phi::funcs::BlockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
if (idx < seq_len) {
qk_buf_[idx + qk_offset] =
phi::funcs::FloatsToPair<half2>(qk_tmp.x / sum_val, qk_tmp.y / sum_val);
}
#endif
}
template <typename T>
__global__ void SoftmaxKernelWithEltaddForLarge(
T *qk_buf,
const T *bias_qk,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
T stride_max = -1e20f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
stride_max = qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset] >
stride_max
? qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset]
: stride_max;
}
T max_val = phi::funcs::BlockReduceMax<T>(stride_max, mask);
T stride_sum = 0.f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
stride_sum += __expf(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset] - max_val);
}
T sum_val = phi::funcs::BlockReduceSum<T>(stride_sum, mask);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
qk_buf[threadIdx.x + i + qk_offset] =
(T)(__expf(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset] - max_val) /
sum_val);
}
}
template <>
__global__ void SoftmaxKernelWithEltaddForLarge(
half *qk_buf,
const half *bias_qk,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
#if defined(PADDLE_WITH_CUDA) && \
(CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float stride_max = -1e20f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float tmp = static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset]);
stride_max = tmp > stride_max ? tmp : stride_max;
}
float max_val = phi::funcs::BlockReduceMax<float>(stride_max, mask);
float stride_sum = 0.f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float tmp = static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset]);
stride_sum += __expf(tmp - max_val);
}
float sum_val = phi::funcs::BlockReduceSum<float>(stride_sum, mask);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float tmp =
__expf(static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset]) -
max_val);
qk_buf[threadIdx.x + i + qk_offset] = (half)(tmp / sum_val);
}
#endif
}
template <typename T>
__global__ void SoftmaxKernelWithEltaddForLarge2(
T *qk_buf_,
const T *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float2 stride_max = make_float2(-1e20f, -1e20f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
stride_max.x = max(stride_max.x, cur.x);
stride_max.y = max(stride_max.y, cur.y);
}
float max_val =
phi::funcs::BlockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
float2 stride_sum = make_float2(0.f, 0.f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
stride_sum.x += __expf(cur.x - max_val);
stride_sum.y += __expf(cur.y - max_val);
}
float sum_val =
phi::funcs::BlockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
1e-6f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
qk_buf_[threadIdx.x + i + qk_offset] = phi::funcs::FloatsToPair<T>(
__expf(cur.x - max_val) / sum_val, __expf(cur.y - max_val) / sum_val);
}
}
template <>
__global__ void SoftmaxKernelWithEltaddForLarge2(
half2 *qk_buf_,
const half2 *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const phi::funcs::warp_mask_t mask) {
// operator "+" of half is only supported after CUDA version 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && \
(CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % WARP_SIZE == 0);
float2 stride_max = make_float2(-1e20f, -1e20f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur =
phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
stride_max.x = max(stride_max.x, cur.x);
stride_max.y = max(stride_max.y, cur.y);
}
float max_val =
phi::funcs::BlockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
float2 stride_sum = make_float2(0.f, 0.f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur =
phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
stride_sum.x += __expf(cur.x - max_val);
stride_sum.y += __expf(cur.y - max_val);
}
float sum_val =
phi::funcs::BlockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
1e-6f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float2 cur =
phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
bias_qk_[threadIdx.x + i + qk_offset]);
qk_buf_[threadIdx.x + i + qk_offset] = phi::funcs::FloatsToPair<half2>(
__expf(cur.x - max_val) / sum_val, __expf(cur.y - max_val) / sum_val);
}
#endif
}
template <typename T>
inline __device__ T ldg(const T *val) {
return __ldg(val);
}
template <typename T>
inline __device__ T hexp2(T a) {
return h2exp(a);
}
template <typename T_IN, typename T_OUT>
inline __device__ T_OUT type2type2(T_IN a);
template <>
inline __device__ half2 type2type2(half a) {
return __half2half2(a);
}
template <typename T>
inline __device__ T float2type2(float a);
template <>
inline __device__ half2 float2type2(float a) {
return __float2half2_rn(a);
}
template <typename T>
inline __device__ T hmul2(T a, T b) {
return __hmul2(a, b);
}
template <typename T>
inline __device__ T hsub2(T a, T b) {
return __hsub2(a, b);
}
template <typename T>
inline __device__ T hadd2(T a, T b) {
return __hadd2(a, b);
}
template <typename T, int ITEMS_PER_THREAD, int NUM>
__global__ void softmax_kernel_with_mask(T *qk_buf_,
const T *attr_mask,
const int batch_size,
const int head_num,
const int seq_len) {
using T2 = half2;
T2 *qk_buf_half2 = reinterpret_cast<T2 *>(qk_buf_);
const T2 *attr_mask_half2 = (const T2 *)attr_mask;
for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x * NUM) {
T2 data[NUM][ITEMS_PER_THREAD];
int qk_offset[NUM];
__shared__ float s_sum[NUM], s_max[NUM];
float local_max[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_max[j] = -1e20f;
}
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
int mask_offset[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len +
seq_id + j * gridDim.x) *
(seq_len / 2) +
blockDim.x * i + threadIdx.x;
mask_offset[j] =
(blockIdx.y * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) +
blockDim.x * i + threadIdx.x;
}
T2 mask_val[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
mask_val[j] = ldg(&attr_mask_half2[mask_offset[j]]);
}
T2 qk[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk[j] = qk_buf_half2[qk_offset[j]];
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
mask_val[j] = hmul2<T2>(hsub2<T2>(float2type2<T2>(1.0f), mask_val[j]),
float2type2<T2>(-10000.0f));
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
data[j][i] = hadd2<T2>(qk[j], mask_val[j]);
local_max[j] = fmax(local_max[j],
fmax(static_cast<float>(data[j][i].x),
static_cast<float>(data[j][i].y)));
}
}
if (blockDim.x <= WARP_SIZE) {
phi::funcs::WarpReduceMaxV2<float, NUM>(local_max);
} else {
phi::funcs::BlockReduceMaxV2<float, NUM>(local_max);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_max[j] = local_max[j];
}
}
__syncthreads();
float local_sum[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_sum[j] = {0.f};
}
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
data[j][i] =
hexp2<T2>(hsub2<T2>(data[j][i], float2type2<T2>(s_max[j])));
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_sum[j] += static_cast<float>(data[j][i].x + data[j][i].y);
}
}
if (blockDim.x <= WARP_SIZE) {
phi::funcs::WarpReduceSumV2<float, NUM>(local_sum);
} else {
phi::funcs::BlockReduceSumV2<float, NUM>(local_sum);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
}
}
__syncthreads();
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len +
seq_id + j * gridDim.x) *
(seq_len / 2) +
blockDim.x * i + threadIdx.x;
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk_buf_half2[qk_offset[j]] =
hmul2<T2>(data[j][i], float2type2<T2>(s_sum[j]));
}
}
}
}
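In softmax_kernel_with_mask above, the attention mask m (1 for positions to attend to, 0 for masked ones, as implied by the (1 - m) * -10000 transform) is first turned into an additive bias, and then the same row-wise softmax is applied; in formula form:

\[
\mathrm{logit} = \mathrm{qk} + (1 - m)\cdot(-10000), \qquad
\mathrm{out} = \mathrm{softmax}(\mathrm{logit}) \ \text{over the seq\_len axis}
\]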
#define SOFTMAX_KERNEL_WITH_MASK(REPEAT_THREAD) \
do { \
block.x /= REPEAT_THREAD; \
grid.x /= 4; \
constexpr int NUM = 4; \
softmax_kernel_with_mask<half, REPEAT_THREAD, NUM> \
<<<grid, block, 0, stream>>>(reinterpret_cast<half *>(qk_buf_), \
(const half *)bias_qk, \
batch_size, \
head_num, \
seq_len); \
} while (0)
template <typename T>
inline void MatmulWithHeadQK(const phi::GPUContext &context,
int head_num,
int seq_len,
int size_per_head,
int batch_size,
bool q_trans,
bool k_trans,
T *q_buf_,
T *k_buf_,
T *qk_buf_,
const T *bias_qk,
bool bias_is_mask,
T alpha,
T beta) {
CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !k_trans ? CblasNoTrans : CblasTrans;
typedef typename CUDATypeTraits<T>::TYPE run_type;
auto blas = phi::funcs::GetBlas<phi::GPUContext, run_type>(context);
auto stream = context.stream();
blas.BatchedGEMM(transA,
transB,
seq_len,
seq_len,
size_per_head,
static_cast<run_type>(alpha),
reinterpret_cast<run_type *>(q_buf_),
reinterpret_cast<run_type *>(k_buf_),
static_cast<run_type>(beta),
reinterpret_cast<run_type *>(qk_buf_),
batch_size * head_num,
seq_len * size_per_head,
seq_len * size_per_head);
if (seq_len <= 1024) {
int grid = batch_size * head_num * seq_len;
int block = seq_len;
// Align block to 32, also limit seq_len to max block size.
if (seq_len % 2 == 0) {
block =
(seq_len <= (2 * WARP_SIZE))
? WARP_SIZE
: ((seq_len + (2 * WARP_SIZE - 1)) / (2 * WARP_SIZE)) * WARP_SIZE;
if (std::is_same<T, float>::value) {
SoftmaxKernelWithEltadd2<float2><<<grid, block, 0, stream>>>(
reinterpret_cast<float2 *>(qk_buf_),
reinterpret_cast<const float2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
} else {
if (bias_is_mask) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700)
PADDLE_ENFORCE_EQ(bias_is_mask,
false,
phi::errors::InvalidArgument(
"QK_bias is mask can't be supported on rocm or "
"cuda_arch<700"));
#else
dim3 grid(seq_len, batch_size, head_num);
dim3 block((seq_len / 2 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
SOFTMAX_KERNEL_WITH_MASK(1);
#endif
} else {
SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
reinterpret_cast<__half2 *>(qk_buf_),
reinterpret_cast<const __half2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
}
}
} else {
block = (seq_len <= WARP_SIZE)
? WARP_SIZE
: ((seq_len + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
}
} else {
int grid = batch_size * head_num * seq_len;
int block = 512;
if (seq_len % 2 == 0) {
if (std::is_same<T, float>::value) {
SoftmaxKernelWithEltaddForLarge2<float2><<<grid, block, 0, stream>>>(
reinterpret_cast<float2 *>(qk_buf_),
reinterpret_cast<const float2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
} else {
if (bias_is_mask) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700)
PADDLE_ENFORCE_EQ(bias_is_mask,
false,
phi::errors::InvalidArgument(
"QK_bias is mask can't be supported on rocm or "
"cuda_arch<700"));
#else
dim3 grid(seq_len, batch_size, head_num);
dim3 block((seq_len / 2 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
if (block.x > 0 && block.x <= 1024) {
SOFTMAX_KERNEL_WITH_MASK(1);
} else if (block.x <= 2048) {
SOFTMAX_KERNEL_WITH_MASK(2);
} else if (block.x <= 4096) {
SOFTMAX_KERNEL_WITH_MASK(4);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Cannot support the length of attention > 8192."));
}
#endif
} else {
SoftmaxKernelWithEltaddForLarge2<__half2><<<grid, block, 0, stream>>>(
reinterpret_cast<__half2 *>(qk_buf_),
reinterpret_cast<const __half2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
}
}
} else {
SoftmaxKernelWithEltaddForLarge<T><<<grid, block, 0, stream>>>(
qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
}
}
}
template <typename T>
inline void MatmulWithHeadQKV(const phi::GPUContext &context,
int head_num,
int seq_len,
int size_per_head,
int batch_size,
bool qk_trans,
bool v_trans,
T *v_buf_,
const T *qk_buf_,
T *dst,
T alpha,
T beta) {
int m = batch_size * seq_len;
int k = head_num * size_per_head;
typedef typename CUDATypeTraits<T>::TYPE run_type;
auto blas = phi::funcs::GetBlas<phi::GPUContext, run_type>(context);
auto stream = context.stream();
CBLAS_TRANSPOSE transA = !qk_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !v_trans ? CblasNoTrans : CblasTrans;
blas.BatchedGEMM(transA,
transB,
seq_len,
size_per_head,
seq_len,
static_cast<run_type>(alpha),
reinterpret_cast<const run_type *>(qk_buf_),
reinterpret_cast<run_type *>(v_buf_),
static_cast<run_type>(beta),
reinterpret_cast<run_type *>(dst),
batch_size * head_num,
seq_len * seq_len,
seq_len * size_per_head);
}
template <typename T>
void MultiheadGPUComputeFunctor<T>::operator()(const phi::GPUContext &dev_ctx,
int batch,
int seq_len,
int head_num,
int head_size,
T *qkptr,
const T *bias_qk_ptr,
bool bias_is_mask,
T *tptr,
T alpha,
T beta) {
auto stream = dev_ctx.stream();
const int tsize = batch * head_num * seq_len * head_size;
T *qptr = tptr;
T *kptr = qptr + tsize;
T *vptr = kptr + tsize;
// batch gemm stride, softmaxwithscale.
MatmulWithHeadQK<T>(dev_ctx,
head_num,
seq_len,
head_size,
batch,
false,
true,
qptr,
kptr,
qkptr,
bias_qk_ptr,
bias_is_mask,
alpha,
beta);
// batch gemm stride, transpose.
MatmulWithHeadQKV<T>(dev_ctx,
head_num,
seq_len,
head_size,
batch,
false,
false,
vptr,
qkptr,
tptr,
T(1.0),
beta);
}
template class MultiheadGPUComputeFunctor<float>;
// device function 'operator()' is not supported until CUDA 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template class MultiheadGPUComputeFunctor<half>;
#endif
} // namespace funcs
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace funcs {
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
//         | |
//        matmul
//          |
//     eltwise_add
//          |
//       softmax   /
//          \     /
//           matmul
//             |
template <typename T>
class MultiheadGPUComputeFunctor {
public:
void operator()(const phi::GPUContext &dev_ctx,
int batch,
int seq_len,
int head_num,
int head_size,
T *qkptr,
const T *bias_qk_ptr,
bool bias_is_mask,
T *tptr,
T alpha,
T beta);
};
} // namespace funcs
} // namespace phi
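A minimal usage sketch of the functor declared above, mirroring the call sites in the TensorRT plugins and the phi kernel elsewhere in this diff (buffer preparation, e.g. the QKV transpose, is assumed to have been done already; the variable names are illustrative):

// dev_ctx     : a phi::GPUContext
// tptr        : device buffer holding transposed Q, K, V as 3 x B x N x S x H
// qkptr       : scratch buffer of B * N * S * S elements for attention scores
// bias_qk_ptr : elementwise QK bias (or a 0/1 mask when bias_is_mask is true)
phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
multihead_compute_func(dev_ctx,
                       batch,
                       seq_len,
                       head_num,
                       head_size,
                       qkptr,
                       bias_qk_ptr,
                       /*bias_is_mask=*/false,
                       tptr,
                       scale,   // alpha applied to Q * K^T
                       0.0f);   // beta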
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
namespace paddle {
namespace operators {
namespace phi {
namespace fusion {
template <typename T>
__global__ void transpose(T *src,
......@@ -149,7 +148,7 @@ void TransQKVWithBias(const int batch,
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
......@@ -165,7 +164,7 @@ void TransQKVWithBias(const int batch,
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
......@@ -177,7 +176,7 @@ void TransQKVWithBias(const int batch,
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
......@@ -193,9 +192,9 @@ void TransQKVWithBias(const int batch,
const int seq_len,
const int head_size,
const int head_num,
const platform::float16 *input,
const platform::float16 *bias,
platform::float16 *output,
const phi::dtype::float16 *input,
const phi::dtype::float16 *bias,
phi::dtype::float16 *output,
gpuStream_t stream) {
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH
int scratch_size = batch * head_num * seq_len * seq_len;
......@@ -209,7 +208,7 @@ void TransQKVWithBias(const int batch,
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
......@@ -225,7 +224,7 @@ void TransQKVWithBias(const int batch,
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num,
1024,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
......@@ -240,7 +239,7 @@ inline int round_up(int seq_len, int multiple = 32) {
PADDLE_ENFORCE_GT(
multiple,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"multiple should be a positive number, but it's (%d)", multiple));
return ((seq_len + multiple - 1) / multiple) * multiple;
}
......@@ -270,168 +269,166 @@ __global__ void broadcast_batch_head_number(const T *src,
}
}
template <typename T, typename DeviceContext>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *input = context.Input<phi::DenseTensor>("Input");
auto *w = context.Input<phi::DenseTensor>("W");
auto *bias = context.Input<phi::DenseTensor>("Bias");
auto *bias_qk = context.Input<phi::DenseTensor>("BiasQK");
auto *input_d = input->data<T>();
auto *w_d = w->data<T>();
auto *bias_d = bias->data<T>();
auto *bias_qk_d = bias_qk ? bias_qk->data<T>() : nullptr;
T scale = static_cast<T>(context.Attr<float>("alpha"));
int head_number = context.Attr<int>("head_number");
// compute q*k with eltadd
auto &device_ctx = context.template device_context<DeviceContext>();
auto stream = device_ctx.stream();
// should be (B * S * hidden)
auto input_dims = input->dims();
// should be (hidden * 3 * all_head_size)
auto w_dims = w->dims();
int batch = input_dims[0];
int seq_len = input_dims[1];
int hidden = input_dims[2];
phi::DenseTensor temp_bias_tensor;
// if bias_qk is [batch, 1, 1, seq_len], bias_qk_d needs to be broadcast
if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = device_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
// if bias_qk is [1, 1, seq_len, seq_len], bias_qk_d needs to be
// broadcast
if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) {
VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = device_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast_batch_head_number<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, batch, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
if (!bias_qk) {
int size = batch * head_number * seq_len * seq_len;
temp_bias_tensor.Resize({size});
auto *temp_qk_bias = device_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
template <typename T, typename Context>
void MultiheadMatmulKernel(const Context &dev_ctx,
const DenseTensor &input,
const DenseTensor &w,
const DenseTensor &bias,
const paddle::optional<DenseTensor> &bias_qk,
const bool transpose_q,
const bool transpose_k,
const bool transpose_v,
const float alpha,
const int head_number,
DenseTensor *out) {
auto *input_d = input.data<T>();
auto *w_d = w.data<T>();
auto *bias_d = bias.data<T>();
auto *bias_qk_d = bias_qk ? bias_qk->data<T>() : nullptr;
T scale = static_cast<T>(alpha);
// compute q*k with eltadd
auto stream = dev_ctx.stream();
// should be (B * S * hidden)
auto input_dims = input.dims();
// should be (hidden * 3 * all_head_size)
auto w_dims = w.dims();
int batch = input_dims[0];
int seq_len = input_dims[1];
int hidden = input_dims[2];
phi::DenseTensor temp_bias_tensor;
// if bias_qk is [batch, 1, 1, seq_len], bias_qk_d needs to be broadcast
if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = dev_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
// if bias_qk is [1, 1, seq_len, seq_len], bias_qk_d needs to be
// broadcast
if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) {
VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = dev_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast_batch_head_number<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, batch, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
if (!bias_qk) {
int size = batch * head_number * seq_len * seq_len;
temp_bias_tensor.Resize({size});
auto *temp_qk_bias = dev_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
#ifdef PADDLE_WITH_HIP
hipMemset(temp_qk_bias, 0, sizeof(float) * size);
hipMemset(temp_qk_bias, 0, sizeof(float) * size);
#else
cudaMemset(temp_qk_bias, 0, sizeof(float) * size);
cudaMemset(temp_qk_bias, 0, sizeof(float) * size);
#endif
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
int all_head_size = w_dims[2];
int head_size = all_head_size / head_number;
auto *out = context.Output<phi::DenseTensor>("Out");
out->Resize({batch, seq_len, all_head_size});
auto *output_d =
device_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
// (B*S, hidden)
const phi::DenseTensor input_matrix =
phi::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
// (hidden, 3 * all_head_size)
const phi::DenseTensor w_matrix =
phi::ReshapeToMatrix(*w, 1 /*y_num_col_dims*/);
phi::DenseTensor temp_out_tensor;
auto temp_out_dims =
phi::make_ddim({batch, seq_len, 3, head_number, head_size});
temp_out_tensor.Resize(
{batch * seq_len, phi::product(temp_out_dims) / (batch * seq_len)});
auto *temp_out_data = device_ctx.template Alloc<T>(
&temp_out_tensor, temp_out_tensor.numel() * sizeof(T));
// (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(device_ctx);
blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
VLOG(2) << temp_out_tensor;
// temp_out_tensor.Resize(temp_out_dims);
phi::DenseTensor multihead_temp_tensor;
// B * head_number * S * S * 1 + B * S * 3 * N * H
int scratch_size = batch * head_number * seq_len * seq_len * 1;
multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
auto *multihead_temp_data = device_ctx.template Alloc<T>(
&multihead_temp_tensor, multihead_temp_tensor.numel() * sizeof(T));
auto *qkptr = multihead_temp_data;
auto *tptr = multihead_temp_data + scratch_size;
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias(batch,
seq_len,
head_size,
head_number,
temp_out_data,
bias_d,
tptr,
stream);
if (std::is_same<T, platform::float16>::value) {
math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(device_ctx,
batch,
seq_len,
head_number,
head_size,
reinterpret_cast<half *>(qkptr),
reinterpret_cast<const half *>(bias_qk_d),
false,
reinterpret_cast<half *>(tptr),
__float2half(static_cast<float>(scale)),
__float2half(0.0));
} else {
math::MultiHeadGPUComputeFunctor<T> multihead_compute_func;
multihead_compute_func(device_ctx,
batch,
seq_len,
head_number,
head_size,
qkptr,
bias_qk_d,
false,
tptr,
scale,
T(0.0));
}
int grid = batch * head_number * seq_len;
int block = head_size;
transpose<T><<<grid, block, 0, stream>>>(
tptr, output_d, batch, seq_len, head_number, head_size);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
int all_head_size = w_dims[2];
int head_size = all_head_size / head_number;
out->Resize({batch, seq_len, all_head_size});
auto *output_d = dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
// (B*S, hidden)
const phi::DenseTensor input_matrix =
phi::ReshapeToMatrix(input, 2 /*x_num_col_dims */);
// (hidden, 3 * all_head_size)
const phi::DenseTensor w_matrix =
phi::ReshapeToMatrix(w, 1 /*y_num_col_dims*/);
phi::DenseTensor temp_out_tensor;
auto temp_out_dims =
phi::make_ddim({batch, seq_len, 3, head_number, head_size});
temp_out_tensor.Resize(
{batch * seq_len, phi::product(temp_out_dims) / (batch * seq_len)});
auto *temp_out_data = dev_ctx.template Alloc<T>(
&temp_out_tensor, temp_out_tensor.numel() * sizeof(T));
// (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
// temp_out_tensor.Resize(temp_out_dims);
phi::DenseTensor multihead_temp_tensor;
// B * head_number * S * S * 1 + B * S * 3 * N * H
int scratch_size = batch * head_number * seq_len * seq_len * 1;
multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
auto *multihead_temp_data = dev_ctx.template Alloc<T>(
&multihead_temp_tensor, multihead_temp_tensor.numel() * sizeof(T));
auto *qkptr = multihead_temp_data;
auto *tptr = multihead_temp_data + scratch_size;
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias(batch,
seq_len,
head_size,
head_number,
temp_out_data,
bias_d,
tptr,
stream);
if (std::is_same<T, phi::dtype::float16>::value) {
phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
head_number,
head_size,
reinterpret_cast<half *>(qkptr),
reinterpret_cast<const half *>(bias_qk_d),
false,
reinterpret_cast<half *>(tptr),
__float2half(static_cast<float>(scale)),
__float2half(0.0));
} else {
phi::funcs::MultiheadGPUComputeFunctor<T> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
head_number,
head_size,
qkptr,
bias_qk_d,
false,
tptr,
scale,
T(0.0));
}
};
} // namespace operators
} // namespace paddle
int grid = batch * head_number * seq_len;
int block = head_size;
transpose<T><<<grid, block, 0, stream>>>(
tptr, output_d, batch, seq_len, head_number, head_size);
}
} // namespace fusion
} // namespace phi
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
PD_REGISTER_STRUCT_KERNEL(multihead_matmul,
GPU,
ALL_LAYOUT,
ops::MultiHeadMatMulV2Kernel,
float,
plat::float16) {}
PD_REGISTER_KERNEL(multihead_matmul,
GPU,
ALL_LAYOUT,
phi::fusion::MultiheadMatmulKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_STRUCT_KERNEL(
multihead_matmul, GPU, ALL_LAYOUT, ops::MultiHeadMatMulV2Kernel, float) {}
PD_REGISTER_KERNEL(multihead_matmul,
GPU,
ALL_LAYOUT,
phi::fusion::MultiheadMatmulKernel,
float) {}
#endif