Unverified commit cf9eae4c, authored by F feng_shuai, committed by GitHub

broadcast qkv_op (#35780)

* broadcast qkv_op

* use PADDLE_ENFORCE_GT to replace assert
Parent 7975dfcf
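The heart of this change is a small CUDA kernel that expands a [batch, 1, 1, seq_len] attention bias (qk mask) into the [batch, head_num, seq_len, seq_len] layout the multihead matmul expects. The following is a minimal, self-contained sketch of that idea: the kernel body matches the one added in the diff below, while the host harness, its toy sizes, and the printf check are assumptions made purely for illustration and are not part of the patch.

// Illustrative sketch only (not part of the commit): a toy harness around the
// broadcast kernel introduced below.
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>

template <typename T>
__global__ void broadcast(const T *src, T *dst, const int seq_len,
                          const int head_num) {
  // One block per destination row; every row of a batch reads that batch's
  // single seq_len-long source row.
  int batch_id = blockIdx.x / (head_num * seq_len);
  int dst_offset = blockIdx.x * seq_len;
  if (threadIdx.x < seq_len) {
    dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
  }
}

int main() {
  const int batch = 2, head_num = 3, seq_len = 5;  // toy sizes, illustration only
  const size_t dst_count =
      static_cast<size_t>(batch) * head_num * seq_len * seq_len;

  std::vector<float> h_src(batch * seq_len);
  for (size_t i = 0; i < h_src.size(); ++i) h_src[i] = static_cast<float>(i);

  float *d_src = nullptr, *d_dst = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&d_src), h_src.size() * sizeof(float));
  cudaMalloc(reinterpret_cast<void **>(&d_dst), dst_count * sizeof(float));
  cudaMemcpy(d_src, h_src.data(), h_src.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  // Same launch shape as the patch: one block per destination row,
  // block size rounded up to the next multiple of 32 (assumes seq_len <= 1024).
  int grid = batch * head_num * seq_len;
  int block = ((seq_len + 31) / 32) * 32;
  broadcast<<<grid, block>>>(d_src, d_dst, seq_len, head_num);

  std::vector<float> h_dst(dst_count);
  cudaMemcpy(h_dst.data(), d_dst, dst_count * sizeof(float),
             cudaMemcpyDeviceToHost);

  // First row belonging to batch 1 starts at block head_num * seq_len.
  for (int j = 0; j < seq_len; ++j) {
    printf("%.0f ", h_dst[head_num * seq_len * seq_len + j]);
  }
  printf("\n");  // expected: 5 6 7 8 9

  cudaFree(d_src);
  cudaFree(d_dst);
  return 0;
}

With these toy sizes the printed row is 5 6 7 8 9, i.e. the second batch's bias row repeated for every (head, query-position) pair.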
@@ -233,6 +233,24 @@ __global__ void apply_scale(T *data, T scale, int n) {
#endif
}
inline int round_up(int seq_len, int multiple = 32) {
  PADDLE_ENFORCE_GT(
      multiple, 0,
      platform::errors::InvalidArgument(
          "multiple should be a positive number, but it's (%d)", multiple));
  return ((seq_len + multiple - 1) / multiple) * multiple;
}
// Broadcast a [batch, 1, 1, seq_len] bias to [batch, head_num, seq_len, seq_len]:
// one block per destination row; all rows of a batch read that batch's source row.
template <typename T>
__global__ void broadcast(const T *src, T *dst, const int seq_len,
                          const int head_num) {
  int batch_id = blockIdx.x / (head_num * seq_len);
  int dst_offset = blockIdx.x * seq_len;
  if (threadIdx.x < seq_len) {
    dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
  }
}
int QkvToContextPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
@@ -258,7 +276,21 @@ int QkvToContextPluginDynamic::enqueue(
    auto *tptr = multihead_temp_data + scratch_size;
    const float *input0_data = static_cast<const float *>(inputs[0]);
    // fit to [batch, head_num, length, length] + [batch, 1, 1, length]
    framework::Tensor temp_qk_bias_tensor;
    float *qk_bias = const_cast<float *>(static_cast<const float *>(inputs[1]));
    if (ProductDim(input_desc[1].dims) == (batch * seq_len)) {
      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
      auto *temp_qk_bias = temp_qk_bias_tensor.mutable_data<float>(
          platform::CUDAPlace(device_id));
      int grid = batch * head_number_ * seq_len;
      int block = round_up(seq_len);
      broadcast<<<grid, block, 0, stream>>>(
          static_cast<const float *>(inputs[1]), temp_qk_bias, seq_len,
          head_number_);
      qk_bias = temp_qk_bias;
    }
    const float *input1_data = static_cast<const float *>(qk_bias);
    // BxSx3xNxH => tptr: 3xBxNxSxH.
    TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
                 stream);
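For example (figures purely illustrative): a mask supplied as [batch, 1, 1, seq_len] has ProductDim(input_desc[1].dims) == batch * seq_len, so the branch above fires and the kernel expands it; a mask already shaped [batch, head_number_, seq_len, seq_len] has batch * head_number_ * seq_len * seq_len elements, fails the check, and inputs[1] is used directly as qk_bias.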
@@ -290,7 +322,22 @@ int QkvToContextPluginDynamic::enqueue(
    half *tptr = qkptr + scratch_size;
    const half *input0_data = static_cast<const half *>(inputs[0]);
    // fit to [batch, head_num, length, length] + [batch, 1, 1, length]
    framework::Tensor temp_qk_bias_tensor;
    half *qk_bias = const_cast<half *>(static_cast<const half *>(inputs[1]));
    if (ProductDim(input_desc[1].dims) == (batch * seq_len)) {
      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
      auto *temp_qk_bias =
          reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
              platform::CUDAPlace(device_id)));
      int grid = batch * head_number_ * seq_len;
      int block = round_up(seq_len);
      broadcast<<<grid, block, 0, stream>>>(
          static_cast<const half *>(inputs[1]), temp_qk_bias, seq_len,
          head_number_);
      qk_bias = temp_qk_bias;
    }
    const half *input1_data = static_cast<const half *>(qk_bias);
    // BxSx3xNxH => tptr: 3xBxNxSxH.
    TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
                 stream);
......
@@ -132,6 +132,24 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
}
}
inline int round_up(int seq_len, int multiple = 32) {
  PADDLE_ENFORCE_GT(
      multiple, 0,
      platform::errors::InvalidArgument(
          "multiple should be a positive number, but it's (%d)", multiple));
  return ((seq_len + multiple - 1) / multiple) * multiple;
}
// Broadcast a [batch, 1, 1, seq_len] bias to [batch, head_num, seq_len, seq_len]:
// one block per destination row; all rows of a batch read that batch's source row.
template <typename T>
__global__ void broadcast(const T *src, T *dst, const int seq_len,
                          const int head_num) {
  int batch_id = blockIdx.x / (head_num * seq_len);
  int dst_offset = blockIdx.x * seq_len;
  if (threadIdx.x < seq_len) {
    dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
  }
}
template <typename DeviceContext, typename T>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
 public:
@@ -152,6 +170,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
    int head_number = context.Attr<int>("head_number");
    // compute q*k with eltadd
    auto &device_ctx = context.template device_context<DeviceContext>();
    auto stream = device_ctx.stream();
    // should be (B * S * hidden)
    auto input_dims = input->dims();
    // should be (hidden * 3 * all_head_size)
@@ -159,7 +178,17 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
    int batch = input_dims[0];
    int seq_len = input_dims[1];
    int hidden = input_dims[2];
    Tensor temp_bias_tensor;
    // if bias_qk is [batch, 1, 1, seq_len], bias_qk_d needs to be broadcast
    if (bias_qk.numel() == (batch * seq_len)) {
      temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
      auto *temp_qk_bias = temp_bias_tensor.mutable_data<T>(context.GetPlace());
      int grid = batch * head_number * seq_len;
      int block = round_up(seq_len);
      broadcast<<<grid, block, 0, stream>>>(bias_qk_d, temp_qk_bias, seq_len,
                                            head_number);
      bias_qk_d = static_cast<const T *>(temp_qk_bias);
    }
    int all_head_size = w_dims[2];
    int head_size = all_head_size / head_number;
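As a concrete, purely hypothetical sizing: with batch = 8, head_number = 12, and seq_len = 128, bias_qk holds 8 × 128 = 1024 elements, so the broadcast launch uses grid = 8 × 12 × 128 = 12288 blocks of block = round_up(128) = 128 threads, and each block fills one 128-element row of the 8 × 12 × 128 × 128 temporary bias.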
@@ -196,7 +225,6 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
    auto *qkptr = multihead_temp_data;
    auto *tptr = multihead_temp_data + scratch_size;
    // Do the transpose with bias.
    // BxSx3xNxH => tptr: 3xBxNxSxH.
    TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data,
......