Unverified commit 11b9d85f, authored by Wang Bojun, committed by GitHub

fix: multihead matmul biasqk broadcast support for [1,1,seq,seq] shape (#47975)

* add trt support
Parent 57e22f58
@@ -1744,13 +1744,19 @@ struct SimpleOpTypeSetTeller : public Teller {
input_shape[1] == biasqk_shape[3];
bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 &&
input_shape[1] == biasqk_shape[3];
is_broadcastable =
is_broadcastable || (biasqk_shape[0] == 1 && biasqk_shape[1] == 1 &&
input_shape[1] == biasqk_shape[2] &&
input_shape[1] == biasqk_shape[3]);
if (!(has_same_shape || is_broadcastable)) {
VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0]
<< ", 1, 1, " << input_shape[1] << "] or [" << input_shape[0]
<< ", " << head_number << ", " << input_shape[1] << ", "
<< input_shape[1] << "] but [" << biasqk_shape[0] << ", "
<< biasqk_shape[1] << ", " << biasqk_shape[2] << ", "
<< biasqk_shape[3] << "].";
<< ", 1, 1, " << input_shape[1] << "] "
<< "or [" << input_shape[0] << ", " << head_number << ", "
<< input_shape[1] << ", " << input_shape[1] << "] "
<< "or [" << input_shape[0] << "/1, " << 1 << ", "
<< input_shape[1] << ", " << input_shape[1] << "] "
<< "but got [" << biasqk_shape[0] << ", " << biasqk_shape[1]
<< ", " << biasqk_shape[2] << ", " << biasqk_shape[3] << "].";
return false;
}
} else {
......
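For clarity, the acceptance logic in the teller hunk above can be summarized as a small Python sketch (illustrative only, not part of the commit; the helper name is hypothetical). Here input_shape is the shape of the first multihead_matmul input, so input_shape[1] is the sequence length, and head_number is the number of attention heads:

import numpy as np  # not strictly needed; shapes are plain lists here

def biasqk_is_accepted(input_shape, biasqk_shape, head_number):
    seq_len = input_shape[1]
    # has_same_shape: BiasQK is exactly [batch, head_number, seq_len, seq_len]
    # (inferred from the log message; its full condition is truncated in the diff).
    has_same_shape = list(biasqk_shape) == [input_shape[0], head_number, seq_len, seq_len]
    # Existing broadcast: [*, 1, 1, seq_len], one bias row reused for every head and query row.
    is_broadcastable = (
        biasqk_shape[1] == 1 and biasqk_shape[2] == 1 and biasqk_shape[3] == seq_len
    )
    # New in this commit: [1, 1, seq_len, seq_len], one full mask reused for every batch and head.
    is_broadcastable = is_broadcastable or (
        biasqk_shape[0] == 1 and biasqk_shape[1] == 1
        and biasqk_shape[2] == seq_len and biasqk_shape[3] == seq_len
    )
    return has_same_shape or is_broadcastable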
@@ -309,6 +309,19 @@ __global__ void broadcast(const T *src,
}
}
template <typename T>
__global__ void broadcast_batch_head_number(const T *src,
T *dst,
const int batch_size,
const int seq_len,
const int head_num) {
int batch_id = blockIdx.x % seq_len;
int dst_offset = blockIdx.x * seq_len;
if (threadIdx.x < seq_len) {
dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
}
}
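A numpy sketch of what this kernel computes may help (illustrative only, not part of the commit). Each CUDA block writes one seq_len-long destination row, and the source row is picked by blockIdx.x % seq_len, which is exactly a tile of the [1, 1, seq_len, seq_len] bias over the batch and head dimensions:

import numpy as np

def broadcast_batch_head_number_ref(src, batch_size, seq_len, head_num):
    # src holds the [1, 1, seq_len, seq_len] bias flattened to seq_len * seq_len elements.
    src = np.asarray(src, dtype=np.float32).reshape(seq_len, seq_len)
    dst = np.empty((batch_size * head_num * seq_len, seq_len), dtype=src.dtype)
    for block_id in range(dst.shape[0]):               # grid = batch * head_num * seq_len blocks
        dst[block_id, :] = src[block_id % seq_len, :]  # source row chosen by blockIdx.x % seq_len
    return dst.reshape(batch_size, head_num, seq_len, seq_len)

# The per-row copy is equivalent to tiling the bias over batch and head dimensions.
bias = np.random.rand(1, 1, 4, 4).astype(np.float32)
out = broadcast_batch_head_number_ref(bias.ravel(), batch_size=2, seq_len=4, head_num=3)
assert np.array_equal(out, np.tile(bias, (2, 3, 1, 1)))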
int QkvToContextPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
@@ -353,6 +366,22 @@ int QkvToContextPluginDynamic::enqueue(
head_number_);
qk_bias = temp_qk_bias;
}
// fit to [batch, head_num, length, length] + [1, 1, length, length]
if (ProductDim(input_desc[1].dims) == (seq_len * seq_len)) {
temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
auto *temp_qk_bias =
reinterpret_cast<float *>(temp_qk_bias_tensor.mutable_data<float>(
platform::CUDAPlace(device_id)));
int grid = batch * head_number_ * seq_len;
int block = round_up(seq_len);
broadcast_batch_head_number<<<grid, block, 0, stream>>>(
static_cast<const float *>(inputs[1]),
temp_qk_bias,
batch,
seq_len,
head_number_);
qk_bias = temp_qk_bias;
}
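A quick arithmetic sketch of the launch parameters computed just above, assuming round_up here pads seq_len up to the next multiple of the 32-thread warp size (illustrative values, not taken from the commit):

# Hypothetical sizes for illustration only.
batch, head_number, seq_len = 8, 12, 128
grid = batch * head_number * seq_len   # 12288 blocks, one per output row
block = ((seq_len + 31) // 32) * 32    # round_up(128) -> 128 threads, each copying one element
print(grid, block)                     # 12288 128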
// fake qk_bias
if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
qk_bias = fake_qk_bias_;
@@ -424,6 +453,22 @@ int QkvToContextPluginDynamic::enqueue(
head_number_);
qk_bias = temp_qk_bias;
}
// fit to [batch, head_num, length, length] + [1, 1, length, length]
if (ProductDim(input_desc[1].dims) == (seq_len * seq_len)) {
temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
auto *temp_qk_bias =
reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
platform::CUDAPlace(device_id)));
int grid = batch * head_number_ * seq_len;
int block = round_up(seq_len);
broadcast_batch_head_number<<<grid, block, 0, stream>>>(
static_cast<const half *>(inputs[1]),
temp_qk_bias,
batch,
seq_len,
head_number_);
qk_bias = temp_qk_bias;
}
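The FP16 branch mirrors the FP32 one; the temporary tensor is allocated as int16_t storage and reinterpreted as half, presumably because no native half allocation is available on this path (an assumption, not stated in the diff). Since the broadcast is a pure copy, the tiled values stay bit-exact in half precision; a minimal numpy check:

import numpy as np

src = np.random.rand(1, 1, 4, 4).astype(np.float16)   # half-precision bias mask
dst = np.tile(src, (2, 3, 1, 1))                       # broadcast over batch and heads
assert dst.dtype == np.float16 and np.array_equal(dst[1, 2], src[0, 0])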
// padding: mask_half_ = [1.0,....1.0...1.0....,0.0f]
// no_padding: mask_half_ = [1.0,....1.0,.........,1.0f]
bool bias_is_mask = false;
......
@@ -256,6 +256,19 @@ __global__ void broadcast(const T *src,
}
}
template <typename T>
__global__ void broadcast_batch_head_number(const T *src,
T *dst,
const int batch_size,
const int seq_len,
const int head_num) {
int src_seq_id = blockIdx.x % seq_len;
int dst_offset = blockIdx.x * seq_len;
if (threadIdx.x < seq_len) {
dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_seq_id * seq_len];
}
}
template <typename DeviceContext, typename T>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
public:
@@ -286,6 +299,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
Tensor temp_bias_tensor;
// if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted
if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = device_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
@@ -295,6 +309,19 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
bias_qk_d, temp_qk_bias, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
// if bias_qk is[1, 1, seq_len, seq_len], the bias_qk_d need to be
// broadcasted
if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) {
VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]";
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = device_ctx.template Alloc<T>(
&temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast_batch_head_number<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, batch, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
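The two broadcast branches above differ only in which axes of BiasQK are expanded before the elementwise add; an illustrative numpy comparison (not part of the commit):

import numpy as np

batch, head_number, seq_len = 2, 3, 4
scores = np.zeros((batch, head_number, seq_len, seq_len), np.float32)

# Existing branch: BiasQK is [batch, 1, 1, seq_len], the same row for every head and query row.
bias_b11s = np.random.rand(batch, 1, 1, seq_len).astype(np.float32)
out1 = scores + np.tile(bias_b11s, (1, head_number, seq_len, 1))

# New branch: BiasQK is [1, 1, seq_len, seq_len], the same full mask for every batch and head.
bias_11ss = np.random.rand(1, 1, seq_len, seq_len).astype(np.float32)
out2 = scores + np.tile(bias_11ss, (batch, head_number, 1, 1))

# Both explicit tilings match what plain numpy broadcasting would do.
assert np.array_equal(out1, scores + bias_b11s)
assert np.array_equal(out2, scores + bias_11ss)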
if (!bias_qk) {
int size = batch * head_number * seq_len * seq_len;
temp_bias_tensor.Resize({size});
@@ -333,7 +360,8 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
// (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(device_ctx);
blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
VLOG(2) << temp_out_tensor;
// temp_out_tensor.Resize(temp_out_dims);
Tensor multihead_temp_tensor;
......
@@ -1081,5 +1081,419 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
self.run_test()
class TrtConvertMultiHeadMatmulTest_biasqk_seqseq(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input1(batch, dim1):
return np.random.random((batch, dim1, 768)).astype(np.float32)
def generate_input2(shape):
return np.random.random(shape).astype(np.float32)
def generate_weight1():
return np.random.random((768, 768)).astype(np.float32)
def generate_weight2():
return np.random.random(768).astype(np.float32)
def generate_weight3():
return np.random.random((768, 768)).astype(np.float32)
for batch in [2]:
self.batch = batch
for reshape_shape in [[0, 0, 12, 64]]:
for dim1 in [128]:
input2_shapes = [
[batch, reshape_shape[2], dim1, dim1],
[batch, 1, 1, dim1],
]
for input2_shape in input2_shapes:
for axis in [0]:
dics = [
{"x_num_col_dims": 2, "y_num_col_dims": 1},
{"axis": 2},
{"shape": reshape_shape},
{"axis": [0, 2, 1, 3]},
{"x_num_col_dims": 2, "y_num_col_dims": 1},
{"axis": 2},
{"shape": reshape_shape},
{"axis": [0, 2, 1, 3]},
{"x_num_col_dims": 2, "y_num_col_dims": 1},
{"axis": 2},
{"shape": reshape_shape},
{"axis": [0, 2, 1, 3]},
{
"scale": 0.125,
"bias": 0.0,
"bias_after_scale": True,
},
{
"alpha": 1.0,
"transpose_X": False,
"transpose_Y": True,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
},
{"axis": axis},
{"axis": -1, "is_test": True},
{
"seed": 0,
"dropout_prob": 0.10000000149011612,
"dropout_implementation": "upscale_in_train",
"fix_seed": False,
"is_test": True,
},
{
"alpha": 1.0,
"transpose_X": False,
"transpose_Y": False,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
},
{"axis": [0, 2, 1, 3]},
{"shape": [0, 0, 768]},
{"x_num_col_dims": 2, "y_num_col_dims": 1},
]
ops_config = [
{
"op_type": "mul",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul1_weight"],
},
"op_outputs": {"Out": ["mul1_output"]},
"op_attrs": dics[0],
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["mul1_output"],
"Y": ["elementwise_add1_weight"],
},
"op_outputs": {
"Out": ["elementwise_add1_output"]
},
"op_attrs": dics[1],
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_add1_output"],
},
"op_outputs": {
"Out": ["reshape21_output"],
"XShape": ["reshape21_output_xshape"],
},
"op_attrs": dics[2],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape21_output"]},
"op_outputs": {
"Out": ["transpose21_output"],
"XShape": ["transpose21_output_xshape"],
},
"op_attrs": dics[3],
},
{
"op_type": "mul",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul2_weight"],
},
"op_outputs": {"Out": ["mul2_output"]},
"op_attrs": dics[4],
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["mul2_output"],
"Y": ["elementwise_add2_weight"],
},
"op_outputs": {
"Out": ["elementwise_add2_output"]
},
"op_attrs": dics[5],
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_add2_output"]
},
"op_outputs": {
"Out": ["reshape22_output"],
"XShape": ["reshape22_output_xshape"],
},
"op_attrs": dics[6],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape22_output"]},
"op_outputs": {
"Out": ["transpose22_output"],
"XShape": ["transpose22_output_xshape"],
},
"op_attrs": dics[7],
},
{
"op_type": "mul",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul3_weight"],
},
"op_outputs": {"Out": ["mul3_output"]},
"op_attrs": dics[8],
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["mul3_output"],
"Y": ["elementwise_add3_weight"],
},
"op_outputs": {
"Out": ["elementwise_add3_output"]
},
"op_attrs": dics[9],
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_add3_output"]
},
"op_outputs": {
"Out": ["reshape23_output"],
"XShape": ["reshape23_output_xshape"],
},
"op_attrs": dics[10],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape23_output"]},
"op_outputs": {
"Out": ["transpose23_output"],
"XShape": ["transpose23_output_xshape"],
},
"op_attrs": dics[11],
},
{
"op_type": "scale",
"op_inputs": {
"X": ["transpose23_output"],
},
"op_outputs": {"Out": ["scale_output"]},
"op_attrs": dics[12],
},
{
"op_type": "matmul",
"op_inputs": {
"X": ["scale_output"],
"Y": ["transpose22_output"],
},
"op_outputs": {"Out": ["matmul1_output"]},
"op_attrs": dics[13],
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["matmul1_output"],
"Y": ["input_data2"],
},
"op_outputs": {
"Out": ["elementwise_add4_output"]
},
"op_attrs": dics[14],
},
{
"op_type": "softmax",
"op_inputs": {
"X": ["elementwise_add4_output"]
},
"op_outputs": {"Out": ["softmax_output"]},
"op_attrs": dics[15],
},
{
"op_type": "dropout",
"op_inputs": {
"X": ["softmax_output"],
},
"op_outputs": {"Out": ["dropout3_output"]},
"op_attrs": dics[16],
},
{
"op_type": "matmul",
"op_inputs": {
"X": ["dropout3_output"],
"Y": ["transpose21_output"],
},
"op_outputs": {"Out": ["matmul2_output"]},
"op_attrs": dics[17],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["matmul2_output"]},
"op_outputs": {
"Out": ["transpose24_output"],
"XShape": ["transpose24_output_xshape"],
},
"op_attrs": dics[18],
},
{
"op_type": "reshape2",
"op_inputs": {"X": ["transpose24_output"]},
"op_outputs": {
"Out": ["reshape24_output"],
"XShape": ["reshape24_output_xshape"],
},
"op_attrs": dics[19],
},
# In order to fuse ops with
# multihead_matmul_fuse_pass_v2, the last op
# must be mul.
{
"op_type": "mul",
"op_inputs": {
"X": ["reshape24_output"],
"Y": ["mul4_weight"],
},
"op_outputs": {"Out": ["mul4_output"]},
"op_attrs": dics[20],
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"mul1_weight": TensorConfig(
data_gen=partial(generate_weight1)
),
"mul2_weight": TensorConfig(
data_gen=partial(generate_weight1)
),
"mul3_weight": TensorConfig(
data_gen=partial(generate_weight1)
),
"mul4_weight": TensorConfig(
data_gen=partial(generate_weight1)
),
"elementwise_add1_weight": TensorConfig(
data_gen=partial(generate_weight2)
),
"elementwise_add2_weight": TensorConfig(
data_gen=partial(generate_weight3)
),
"elementwise_add3_weight": TensorConfig(
data_gen=partial(generate_weight2)
),
},
inputs={
"input_data1": TensorConfig(
data_gen=partial(
generate_input1, batch, dim1
)
),
"input_data2": TensorConfig(
data_gen=partial(
generate_input2, input2_shape
)
),
},
outputs=["mul4_output"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The last dim of input1 and input2 should be static.
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 8, 768],
"input_data2": [1, 1, 1, 128],
"reshape24_output": [1, 128, 768],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [16, 512, 768],
"input_data2": [16, 256, 512, 128],
"reshape24_output": [1, 128, 768],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [8, 128, 768],
"input_data2": [8, 32, 64, 128],
"reshape24_output": [1, 128, 768],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), (1, 3), (1e-5, 1e-4)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), (1e-3, 1e-3)
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Half:
return True
return False
self.add_skip_case(
teller1,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and trt in fp16 mode.",
)
def teller2(program_config, predictor_config):
if (
self.trt_param.precision == paddle_infer.PrecisionType.Float32
and len(self.dynamic_shape.min_input_shape) != 0
and self.batch > 2
):
return True
return False
self.add_skip_case(
teller2,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and trt when dynamic fp32 mode and batch size > 2.",
)
def teller3(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Int8:
return True
return False
self.add_skip_case(
teller3,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and trt in int8 mode.",
)
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
@@ -29,6 +29,113 @@ def stable_softmax(x):
return exps / np.sum(exps)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA"
)
class TestFusedMultiHeadMatmulOp_biasqk2(OpTest):
def config(self):
self.seq_len = 128
self.size_per_head = 64
self.head_number = 12
self.batch_size = 8
self.scale = 0.125
def setUp(self):
self.op_type = "multihead_matmul"
self.config()
h = self.seq_len
w = self.head_number * self.size_per_head
self.Input = (
np.random.random((self.batch_size, h, w)).astype("float32") - 0.5
)
self.WQ = np.random.random((w, w)).astype("float32")
self.KQ = np.random.random((w, w)).astype("float32")
self.VQ = np.random.random((w, w)).astype("float32")
self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape(
(w, 3, w)
)
self.Q = np.dot(self.Input, self.WQ)
self.K = np.dot(self.Input, self.KQ)
self.V = np.dot(self.Input, self.VQ)
self.BiasQ = np.random.random((1, w)).astype("float32")
self.BiasK = np.random.random((1, w)).astype("float32")
self.BiasV = np.random.random((1, w)).astype("float32")
self.CombinedB = np.vstack((self.BiasQ, self.BiasK, self.BiasV))
self.BiasQK = np.random.random(
(1, 1, self.seq_len, self.seq_len)
).astype("float32")
# Compute Q path
fc_q = self.Q + self.BiasQ
reshape_q = np.reshape(
fc_q,
(
self.batch_size,
self.seq_len,
self.head_number,
self.size_per_head,
),
)
transpose_q = np.transpose(reshape_q, (0, 2, 1, 3))
scale_q = self.scale * transpose_q
# Compute K path
fc_k = self.K + self.BiasK
reshape_k = np.reshape(
fc_k,
(
self.batch_size,
self.seq_len,
self.head_number,
self.size_per_head,
),
)
transpose_k = np.transpose(reshape_k, (0, 2, 3, 1))
# Compute Q*K
q_k = np.matmul(scale_q, transpose_k)
eltadd_qk = q_k + np.tile(
self.BiasQK, [self.batch_size, self.head_number, 1, 1]
)
softmax_qk = np.apply_along_axis(stable_softmax, 3, eltadd_qk)
# Compute V path
fc_v = self.V + self.BiasV
reshape_v = np.reshape(
fc_v,
(
self.batch_size,
self.seq_len,
self.head_number,
self.size_per_head,
),
)
transpose_v = np.transpose(reshape_v, (0, 2, 1, 3))
# Compute QK*V
qkv = np.matmul(softmax_qk, transpose_v)
transpose_qkv = np.transpose(qkv, (0, 2, 1, 3))
reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w))
print("biasqk shape")
print(self.BiasQK.shape)
self.inputs = {
"Input": self.Input,
"W": self.CombinedW,
"Bias": self.CombinedB,
"BiasQK": self.BiasQK,
}
self.attrs = {
"transpose_Q": False,
"transpose_K": True,
"transpose_V": False,
"head_number": self.head_number,
"alpha": self.scale,
}
self.outputs = {"Out": reshape_qkv}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=2e-3)
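As a shape walkthrough for the configuration used in this test (batch 8, seq_len 128, 12 heads of size 64), the Q/K reshape and transpose steps produce scores of shape [batch, head_number, seq_len, seq_len], which is exactly what the [1, 1, seq_len, seq_len] BiasQK broadcasts against; a minimal standalone sketch (illustrative, not part of the test):

import numpy as np

b, s, n, h = 8, 128, 12, 64
x = np.zeros((b, s, n * h), np.float32)           # Input: (8, 128, 768)
q = x.reshape(b, s, n, h).transpose(0, 2, 1, 3)   # (8, 12, 128, 64)
k = x.reshape(b, s, n, h).transpose(0, 2, 3, 1)   # (8, 12, 64, 128)
scores = np.matmul(q, k)                          # (8, 12, 128, 128)
bias_qk = np.zeros((1, 1, s, s), np.float32)      # broadcasts over batch and heads
print((scores + bias_qk).shape)                   # (8, 12, 128, 128)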
@unittest.skipIf(
not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA"
)
......