Unverified commit 520adc0e authored by feng_shuai, committed by GitHub

optimize: vit 384 (#47432)

* optimize: vit 384

* fix:bug

* fix:bug

* fix: support rocm compile

* refactor:name

* fix:support rocm

* fix:__HIP_NO_HALF_CONVERSIONS__

* optimize: delete scalar

* fix:rocm can't support

* fix:ernie error
Parent b03b4a3c
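This change reworks padding handling for ViT-style inputs: instead of pre-filling a per-head additive QK bias with 0 / -1e20f values, the plugin now builds a single 1.0 / 0.0 padding mask shared by all heads and converts it into an additive bias inside a fused masked-softmax kernel. A minimal sketch of that conversion, written as a hypothetical helper that is not part of the patch (the real kernel below works on packed half2 pairs):

// Hypothetical illustration only; the patched kernel does this on half2 pairs.
__device__ half mask_to_additive_bias(half qk_val, half mask_val /* 1.0 = real token, 0.0 = padding */) {
  half bias = __hmul(__hsub(__float2half(1.0f), mask_val), __float2half(-10000.0f));
  return __hadd(qk_val, bias);  // padded columns sit near -10000 and vanish after exp()
}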
@@ -63,7 +63,7 @@ template <typename T>
__global__ void reset_qk_bias(T *input, int real_seq_len, int seq_len) {
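// One block fills one seq_len-long row of the qk-bias buffer. After this change it
// stores a 1.0 / 0.0 padding mask (1.0 inside real_seq_len, 0.0 for padded positions)
// instead of the previous additive 0 / -1e20f bias.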
if (threadIdx.x < seq_len) {
int id = threadIdx.x + blockIdx.x * seq_len;
input[id] = threadIdx.x >= real_seq_len ? (T)-1e20f : (T)0.0f;
input[id] = threadIdx.x >= real_seq_len ? (T)0.0f : (T)1.0f;
}
}
@@ -292,8 +292,9 @@ void QkvToContextPluginDynamic::configurePlugin(
const phi::GPUContext &dev_ctx = *device_ctx;
auto stream = dev_ctx.stream();
tensor_.Resize({batch, seq_len, seq_len, head_number_});
int blocks = batch * head_number_ * seq_len;
if (in[0].desc.type == nvinfer1::DataType::kHALF) {
tensor_.Resize({batch, seq_len, seq_len, 1});
int blocks = batch * 1 * seq_len;
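// The mask no longer varies per head, so one (batch, seq_len, seq_len, 1) buffer is
// built once and shared by all heads.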
mask_half_ = reinterpret_cast<half *>(
tensor_.mutable_data<int16_t>(platform::CUDAPlace(device_id)));
reset_qk_bias<<<blocks, 1024, 0, stream>>>(
@@ -462,6 +463,7 @@ int QkvToContextPluginDynamic::enqueue(
head_size_,
qkptr,
input1_data,
false,
tptr,
scale_,
static_cast<float>(0.0));
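// This float-path call keeps the ordinary additive qk bias, so bias_is_mask stays
// false; only the half path further down can switch to the 1/0 mask encoding.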
@@ -510,10 +512,12 @@ int QkvToContextPluginDynamic::enqueue(
head_number_);
qk_bias = temp_qk_bias;
}
// padding: mask_half_ = [0,0,...-1e20f,-1e20f]
// no_padding: mask_half_ = [0,.....0,.........,0]
// padding: mask_half_ = [1.0,....1.0...1.0....,0.0f]
// no_padding: mask_half_ = [1.0,....1.0,.........,1.0f]
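// The 1/0 mask is turned back into an additive bias inside the fused masked softmax
// ((1 - mask) * -10000), so padded columns still vanish after exp().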
bool bias_is_mask = false;
if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
qk_bias = mask_half_;
bias_is_mask = true;
}
const half *input1_data = static_cast<const half *>(qk_bias);
// BxSx3xNxH => tptr: 3xBxNxSxH.
@@ -552,6 +556,7 @@ int QkvToContextPluginDynamic::enqueue(
head_size_,
qkptr,
input1_data,
bias_is_mask,
tptr,
half(1.),
half(0.0));
@@ -365,6 +365,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
head_size,
reinterpret_cast<half *>(qkptr),
reinterpret_cast<const half *>(bias_qk_d),
false,
reinterpret_cast<half *>(tptr),
__float2half(static_cast<float>(scale)),
__float2half(0.0));
@@ -377,6 +378,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
head_size,
qkptr,
bias_qk_d,
false,
tptr,
scale,
T(0.0));
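// MultiHeadMatMulV2Kernel never encodes bias_qk as a 1/0 mask, so both of its call
// sites pass bias_is_mask = false.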
@@ -532,6 +532,257 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
#endif
}
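// Device-side wrappers over the half/half2 intrinsics (__ldg, h2exp, __half2half2,
// __float2half2_rn, __hmul2, __hsub2, __hadd2) so the fused masked softmax below can
// be written against a generic packed type T2.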
template <typename T>
inline __device__ T ldg(const T *val) {
return __ldg(val);
}
template <typename T>
inline __device__ T hexp2(T a) {
return h2exp(a);
}
template <typename T_IN, typename T_OUT>
inline __device__ T_OUT type2type2(T_IN a);
template <>
inline __device__ half2 type2type2(half a) {
return __half2half2(a);
}
template <typename T>
inline __device__ T float2type2(float a);
template <>
inline __device__ half2 float2type2(float a) {
return __float2half2_rn(a);
}
template <typename T>
inline __device__ T hmul2(T a, T b) {
return __hmul2(a, b);
}
template <typename T>
inline __device__ T hsub2(T a, T b) {
return __hsub2(a, b);
}
template <typename T>
inline __device__ T hadd2(T a, T b) {
return __hadd2(a, b);
}
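// Warp-level butterfly reductions based on __shfl_xor_sync: after the loop every lane
// holds the warp-wide sum (or max) of each of the NUM values. The block-level variants
// stage per-warp results through shared memory and reduce once more.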
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T *val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T *val) {
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T)0.0f;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceMaxV2(T *val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceMaxV2(T *val) {
static __shared__ T shared[32][NUM];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
warpReduceMaxV2<T, NUM>(val); // get maxx in each warp
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[wid][i] = val[i];
}
}
__syncthreads();
// Modified from blockDim.x << 5 to blockDim.x / 32. to handle the case where
// blockDim.x is not divisible by 32
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
}
warpReduceMaxV2<T, NUM>(val);
return (T)0.0f;
}
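// Fused masked softmax over the attention scores, processed as half2 pairs. Launched
// with grid = (seq_len / NUM, batch, head); each block handles NUM query rows, converts
// the 1/0 padding mask into an additive bias ((1 - m) * -10000), then applies the usual
// max-subtract, exp and sum-normalize steps in place.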
template <typename T, int ITEMS_PER_THREAD, int NUM>
__global__ void softmax_kernel_with_mask(T *qk_buf_,
const T *attr_mask,
const int batch_size,
const int head_num,
const int seq_len) {
using T2 = half2;
T2 *qk_buf_half2 = reinterpret_cast<T2 *>(qk_buf_);
const T2 *attr_mask_half2 = (const T2 *)attr_mask;
for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x * NUM) {
T2 data[NUM][ITEMS_PER_THREAD];
int qk_offset[NUM];
__shared__ float s_sum[NUM], s_max[NUM];
float local_max[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_max[j] = -1e20f;
}
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
int mask_offset[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len +
seq_id + j * gridDim.x) *
(seq_len / 2) +
blockDim.x * i + threadIdx.x;
mask_offset[j] =
(blockIdx.y * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) +
blockDim.x * i + threadIdx.x;
}
T2 mask_val[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
mask_val[j] = ldg(&attr_mask_half2[mask_offset[j]]);
}
T2 qk[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk[j] = qk_buf_half2[qk_offset[j]];
}
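// Real tokens (mask == 1) keep a bias of 0; padded tokens (mask == 0) get -10000,
// which effectively removes them after the exp() below.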
#pragma unroll
for (int j = 0; j < NUM; j++) {
mask_val[j] = hmul2<T2>(hsub2<T2>(float2type2<T2>(1.0f), mask_val[j]),
float2type2<T2>(-10000.0f));
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
data[j][i] = hadd2<T2>(qk[j], mask_val[j]);
local_max[j] = fmax(local_max[j],
fmax(static_cast<float>(data[j][i].x),
static_cast<float>(data[j][i].y)));
}
}
if (blockDim.x <= 32) {
warpReduceMaxV2<float, NUM>(local_max);
} else {
blockReduceMaxV2<float, NUM>(local_max);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_max[j] = local_max[j];
}
}
__syncthreads();
float local_sum[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_sum[j] = {0.f};
}
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
data[j][i] =
hexp2<T2>(hsub2<T2>(data[j][i], float2type2<T2>(s_max[j])));
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_sum[j] += static_cast<float>(data[j][i].x + data[j][i].y);
}
}
if (blockDim.x <= 32) {
warpReduceSumV2<float, NUM>(local_sum);
} else {
blockReduceSumV2<float, NUM>(local_sum);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
}
}
__syncthreads();
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len +
seq_id + j * gridDim.x) *
(seq_len / 2) +
blockDim.x * i + threadIdx.x;
}
#pragma unroll
for (int j = 0; j < NUM; j++) {
qk_buf_half2[qk_offset[j]] =
hmul2<T2>(data[j][i], float2type2<T2>(s_sum[j]));
}
}
}
}
template <typename T>
inline void MatMulWithHeadQK(const phi::GPUContext &context,
int head_num,
@@ -544,6 +795,7 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
T *k_buf_,
T *qk_buf_,
const T *bias_qk,
bool bias_is_mask,
T alpha,
T beta) {
CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
@@ -583,13 +835,39 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
seq_len / 2,
FINAL_MASK);
} else {
SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
reinterpret_cast<__half2 *>(qk_buf_),
reinterpret_cast<const __half2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
if (bias_is_mask) {
#ifndef __HIPCC__
constexpr int ITEMS_PER_THREAD = 1;
bool is_half2 = true;
dim3 grid(seq_len, batch_size, head_num);
dim3 block((seq_len / 2 + 31) / 32 * 32);
block.x /= ITEMS_PER_THREAD;
assert(block.x <= 1024);
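// Each block covers NUM (= 4) query rows working on half2 pairs, so seq_len must be a
// multiple of 4 and the row dimension of the grid shrinks accordingly.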
assert(grid.x % 4 == 0);
grid.x /= 4;
constexpr int NUM = 4;
softmax_kernel_with_mask<half, ITEMS_PER_THREAD, NUM>
<<<grid, block, 0, stream>>>(reinterpret_cast<half *>(qk_buf_),
(const half *)bias_qk,
batch_size,
head_num,
seq_len);
#else
PADDLE_ENFORCE_EQ(bias_is_mask,
false,
platform::errors::InvalidArgument(
"ROCm does not support using QK_bias as a mask"));
#endif
} else {
SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
reinterpret_cast<__half2 *>(qk_buf_),
reinterpret_cast<const __half2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
}
}
} else {
block = (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32;
@@ -669,6 +947,7 @@ void MultiHeadGPUComputeFunctor<T>::operator()(const phi::GPUContext &dev_ctx,
int head_size,
T *qkptr,
const T *bias_qk_ptr,
bool bias_is_mask,
T *tptr,
T alpha,
T beta) {
@@ -690,6 +969,7 @@ void MultiHeadGPUComputeFunctor<T>::operator()(const phi::GPUContext &dev_ctx,
kptr,
qkptr,
bias_qk_ptr,
bias_is_mask,
alpha,
beta);
// batch gemm stride, transpose.
@@ -100,6 +100,7 @@ class MultiHeadGPUComputeFunctor {
int head_size,
T *qkptr,
const T *bias_qk_ptr,
bool bias_is_mask,
T *tptr,
T alpha,
T beta);