Unverified commit 11b9d85f, authored by Wang Bojun, committed by GitHub

fix: multihead matmul biasqk broadcast support for [1,1,seq,seq] shape (#47975)

* add trt support
Parent 57e22f58
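This PR lets the multihead_matmul op, its TensorRT op teller, and the QkvToContextPluginDynamic plugin accept a BiasQK of shape [1, 1, seq_len, seq_len] and broadcast it over the batch and head dimensions before it is added to the scaled Q*K^T scores. A minimal NumPy sketch of the intended broadcast semantics (illustrative only; the sizes are made up and nothing below is code from the patch):

import numpy as np

batch, head_number, seq_len = 2, 12, 8  # toy sizes for illustration
qk = np.random.rand(batch, head_number, seq_len, seq_len).astype("float32")
bias_qk = np.random.rand(1, 1, seq_len, seq_len).astype("float32")

# Explicit broadcast to [batch, head_number, seq_len, seq_len], which is what
# the new CUDA kernel materializes ...
tiled = np.tile(bias_qk, [batch, head_number, 1, 1])
# ... and which matches NumPy's implicit broadcasting of the [1, 1, S, S] bias.
assert np.allclose(qk + tiled, qk + bias_qk)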
...@@ -1744,13 +1744,19 @@ struct SimpleOpTypeSetTeller : public Teller {
                          input_shape[1] == biasqk_shape[3];
    bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 &&
                            input_shape[1] == biasqk_shape[3];
    is_broadcastable =
        is_broadcastable || (biasqk_shape[0] == 1 && biasqk_shape[1] == 1 &&
                             input_shape[1] == biasqk_shape[2] &&
                             input_shape[1] == biasqk_shape[3]);
    if (!(has_same_shape || is_broadcastable)) {
      VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0]
              << ", 1, 1, " << input_shape[1] << "] "
              << "or [" << input_shape[0] << ", " << head_number << ", "
              << input_shape[1] << ", " << input_shape[1] << "] "
              << "or [" << input_shape[0] << "/1, " << 1 << ", "
              << input_shape[1] << ", " << input_shape[1] << "] "
              << "but got [" << biasqk_shape[0] << ", " << biasqk_shape[1]
              << ", " << biasqk_shape[2] << ", " << biasqk_shape[3] << "].";
      return false;
    }
  } else {
...
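For reference, the teller's shape check above, restated as a plain Python predicate (an illustrative paraphrase based on this hunk and the expected shapes listed in the log message; the function name is made up). Here input_shape[1] is seq_len:

def biasqk_is_accepted(input_shape, biasqk_shape, head_number):
    seq_len = input_shape[1]
    # [batch, head_number, seq_len, seq_len]: used as-is.
    has_same_shape = (
        biasqk_shape[1] == head_number
        and biasqk_shape[2] == seq_len
        and biasqk_shape[3] == seq_len
    )
    # [batch, 1, 1, seq_len]: broadcast over head and the first seq dimension.
    is_broadcastable = (
        biasqk_shape[1] == 1
        and biasqk_shape[2] == 1
        and biasqk_shape[3] == seq_len
    )
    # New in this PR: [1, 1, seq_len, seq_len] is broadcast over batch and head.
    is_broadcastable = is_broadcastable or (
        biasqk_shape[0] == 1
        and biasqk_shape[1] == 1
        and biasqk_shape[2] == seq_len
        and biasqk_shape[3] == seq_len
    )
    return has_same_shape or is_broadcastable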
...@@ -309,6 +309,19 @@ __global__ void broadcast(const T *src,
  }
}
// Broadcast a [1, 1, seq_len, seq_len] bias to [batch_size, head_num, seq_len,
// seq_len]: each thread block fills one seq_len-long row of dst, and the
// source row index repeats every seq_len blocks.
template <typename T>
__global__ void broadcast_batch_head_number(const T *src,
                                            T *dst,
                                            const int batch_size,
                                            const int seq_len,
                                            const int head_num) {
  int src_seq_id = blockIdx.x % seq_len;
  int dst_offset = blockIdx.x * seq_len;
  if (threadIdx.x < seq_len) {
    dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_seq_id * seq_len];
  }
}
int QkvToContextPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc,
...@@ -353,6 +366,22 @@ int QkvToContextPluginDynamic::enqueue(
          head_number_);
      qk_bias = temp_qk_bias;
    }
    // fit to [batch, head_num, length, length] + [1, 1, length, length]
    if (ProductDim(input_desc[1].dims) == (seq_len * seq_len)) {
      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
      auto *temp_qk_bias =
          reinterpret_cast<float *>(temp_qk_bias_tensor.mutable_data<float>(
              platform::CUDAPlace(device_id)));
      int grid = batch * head_number_ * seq_len;
      int block = round_up(seq_len);
      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
          static_cast<const float *>(inputs[1]),
          temp_qk_bias,
          batch,
          seq_len,
          head_number_);
      qk_bias = temp_qk_bias;
    }
    // fake qk_bias
    if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
      qk_bias = fake_qk_bias_;
...@@ -424,6 +453,22 @@ int QkvToContextPluginDynamic::enqueue(
          head_number_);
      qk_bias = temp_qk_bias;
    }
    // fit to [batch, head_num, length, length] + [1, 1, length, length]
    if (ProductDim(input_desc[1].dims) == (seq_len * seq_len)) {
      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
      auto *temp_qk_bias =
          reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
              platform::CUDAPlace(device_id)));
      int grid = batch * head_number_ * seq_len;
      int block = round_up(seq_len);
      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
          static_cast<const half *>(inputs[1]),
          temp_qk_bias,
          batch,
          seq_len,
          head_number_);
      qk_bias = temp_qk_bias;
    }
    // padding: mask_half_ = [1.0,....1.0...1.0....,0.0f]
    // no_padding: mask_half_ = [1.0,....1.0,.........,1.0f]
    bool bias_is_mask = false;
...
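The broadcast_batch_head_number kernel above is launched with grid = batch * head_number_ * seq_len blocks and block = round_up(seq_len) threads; each block writes one seq_len-long row of the destination, and the threadIdx.x < seq_len guard masks the rounded-up extra threads. A small NumPy emulation of that indexing (illustrative only, toy sizes; not code from the patch):

import numpy as np

batch, head_num, seq_len = 2, 3, 4
src = np.arange(seq_len * seq_len, dtype=np.float32)  # [1, 1, S, S] flattened
dst = np.empty(batch * head_num * seq_len * seq_len, dtype=np.float32)

for block_id in range(batch * head_num * seq_len):  # one CUDA block per dst row
    src_seq_id = block_id % seq_len                  # source row repeats every S rows
    dst_offset = block_id * seq_len
    for tid in range(seq_len):                       # threads with threadIdx.x < seq_len
        dst[dst_offset + tid] = src[src_seq_id * seq_len + tid]

expected = np.tile(src.reshape(1, 1, seq_len, seq_len), (batch, head_num, 1, 1))
assert np.array_equal(dst.reshape(batch, head_num, seq_len, seq_len), expected)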
...@@ -256,6 +256,19 @@ __global__ void broadcast(const T *src,
  }
}
template <typename T>
__global__ void broadcast_batch_head_number(const T *src,
                                            T *dst,
                                            const int batch_size,
                                            const int seq_len,
                                            const int head_num) {
  int src_seq_id = blockIdx.x % seq_len;
  int dst_offset = blockIdx.x * seq_len;
  if (threadIdx.x < seq_len) {
    dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_seq_id * seq_len];
  }
}
template <typename DeviceContext, typename T>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
 public:
...@@ -286,6 +299,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
    Tensor temp_bias_tensor;
    // if bias_qk is [batch, 1, 1, seq_len], bias_qk_d needs to be broadcast
    if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
      temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
      auto *temp_qk_bias = device_ctx.template Alloc<T>(
          &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
...@@ -295,6 +309,19 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
          bias_qk_d, temp_qk_bias, seq_len, head_number);
      bias_qk_d = static_cast<const T *>(temp_qk_bias);
    }
    // if bias_qk is [1, 1, seq_len, seq_len], bias_qk_d needs to be
    // broadcast
    if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) {
      VLOG(4) << "Broadcast bias_qk from [1, 1, seq_len, seq_len]";
      temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
      auto *temp_qk_bias = device_ctx.template Alloc<T>(
          &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
      int grid = batch * head_number * seq_len;
      int block = round_up(seq_len);
      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
          bias_qk_d, temp_qk_bias, batch, seq_len, head_number);
      bias_qk_d = static_cast<const T *>(temp_qk_bias);
    }
    if (!bias_qk) {
      int size = batch * head_number * seq_len * seq_len;
      temp_bias_tensor.Resize({size});
...@@ -333,7 +360,8 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
    // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(device_ctx);
    blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
    VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
    VLOG(2) << temp_out_tensor;
    // temp_out_tensor.Resize(temp_out_dims);
    Tensor multihead_temp_tensor;
...
...@@ -29,6 +29,113 @@ def stable_softmax(x):
    return exps / np.sum(exps)
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA"
)
class TestFusedMultiHeadMatmulOp_biasqk2(OpTest):
    def config(self):
        self.seq_len = 128
        self.size_per_head = 64
        self.head_number = 12
        self.batch_size = 8
        self.scale = 0.125

    def setUp(self):
        self.op_type = "multihead_matmul"
        self.config()
        h = self.seq_len
        w = self.head_number * self.size_per_head
        self.Input = (
            np.random.random((self.batch_size, h, w)).astype("float32") - 0.5
        )
        self.WQ = np.random.random((w, w)).astype("float32")
        self.KQ = np.random.random((w, w)).astype("float32")
        self.VQ = np.random.random((w, w)).astype("float32")
        self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape(
            (w, 3, w)
        )
        self.Q = np.dot(self.Input, self.WQ)
        self.K = np.dot(self.Input, self.KQ)
        self.V = np.dot(self.Input, self.VQ)
        self.BiasQ = np.random.random((1, w)).astype("float32")
        self.BiasK = np.random.random((1, w)).astype("float32")
        self.BiasV = np.random.random((1, w)).astype("float32")
        self.CombinedB = np.vstack((self.BiasQ, self.BiasK, self.BiasV))
        self.BiasQK = np.random.random(
            (1, 1, self.seq_len, self.seq_len)
        ).astype("float32")
        # Compute Q path
        fc_q = self.Q + self.BiasQ
        reshape_q = np.reshape(
            fc_q,
            (
                self.batch_size,
                self.seq_len,
                self.head_number,
                self.size_per_head,
            ),
        )
        transpose_q = np.transpose(reshape_q, (0, 2, 1, 3))
        scale_q = self.scale * transpose_q
        # Compute K path
        fc_k = self.K + self.BiasK
        reshape_k = np.reshape(
            fc_k,
            (
                self.batch_size,
                self.seq_len,
                self.head_number,
                self.size_per_head,
            ),
        )
        transpose_k = np.transpose(reshape_k, (0, 2, 3, 1))
        # Compute Q*K
        q_k = np.matmul(scale_q, transpose_k)
        eltadd_qk = q_k + np.tile(
            self.BiasQK, [self.batch_size, self.head_number, 1, 1]
        )
        softmax_qk = np.apply_along_axis(stable_softmax, 3, eltadd_qk)
        # Compute V path
        fc_v = self.V + self.BiasV
        reshape_v = np.reshape(
            fc_v,
            (
                self.batch_size,
                self.seq_len,
                self.head_number,
                self.size_per_head,
            ),
        )
        transpose_v = np.transpose(reshape_v, (0, 2, 1, 3))
        # Compute QK*V
        qkv = np.matmul(softmax_qk, transpose_v)
        transpose_qkv = np.transpose(qkv, (0, 2, 1, 3))
        reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w))
        print("biasqk shape")
        print(self.BiasQK.shape)
        self.inputs = {
            "Input": self.Input,
            "W": self.CombinedW,
            "Bias": self.CombinedB,
            "BiasQK": self.BiasQK,
        }
        self.attrs = {
            "transpose_Q": False,
            "transpose_K": True,
            "transpose_V": False,
            "head_number": self.head_number,
            "alpha": self.scale,
        }
        self.outputs = {"Out": reshape_qkv}

    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place, atol=2e-3)


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA"
)
...