未验证 提交 636dc2ff 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Add bias input of mmha and simplify mmha. (#56411)

* add_bias_and_simplify_mmha
上级 e99b3cb2
......@@ -1607,14 +1607,14 @@
backward : margin_cross_entropy_grad
- op : masked_multihead_attention_
args : (Tensor x, Tensor cache_kv, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
args : (Tensor x, Tensor cache_kv, Tensor bias, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, str compute_dtype = "default", float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
output : Tensor(out), Tensor(cache_kv_out), Tensor(beam_cache_offset_out)
infer_meta :
func : MaskedMultiheadAttentionInferMeta
kernel :
func : masked_multihead_attention
data_type : cache_kv
optional : src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
data_type : x
optional : bias, src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)
- op : masked_select
......
......@@ -4094,6 +4094,7 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& bias,
const MetaTensor& src_mask,
const MetaTensor& cum_offsets,
const MetaTensor& sequence_lengths,
......@@ -4105,6 +4106,7 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const std::string& compute_dtype,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
......@@ -4113,7 +4115,6 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
MetaTensor* cache_kv_out,
MetaTensor* beam_cache_offset_out) {
int bsz = x.dims()[0];
auto x_dtype = x.dtype();
auto cache_kv_dims = cache_kv.dims();
int num_head = cache_kv.dims()[2];
int dim_head = cache_kv.dims()[4];
......@@ -4141,10 +4142,86 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
out->set_dims({bsz, num_head * dim_head});
auto FBADtypeCheck = [](const MetaTensor& check_tensor,
const std::string& tensor_name,
const std::string& compute_dtype) {
if (compute_dtype == "bf16") {
PADDLE_ENFORCE_EQ(
check_tensor.dtype(),
phi::DataType::BFLOAT16,
phi::errors::InvalidArgument(
"Input(%s) dtype must be the same with Attr(compute_dtype)",
tensor_name));
} else if (compute_dtype == "fp16") {
PADDLE_ENFORCE_EQ(
check_tensor.dtype(),
phi::DataType::FLOAT16,
phi::errors::InvalidArgument(
"Input(%s) dtype must be the same with Attr(compute_dtype)",
tensor_name));
} else if (compute_dtype == "fp32") {
PADDLE_ENFORCE_EQ(
check_tensor.dtype(),
phi::DataType::FLOAT32,
phi::errors::InvalidArgument(
"Input(%s) dtype must be the same with Attr(compute_dtype)",
tensor_name));
}
};
// In the case of quantization enabled, the dtype for computation is
// determined based on compute_dtype.
if (x.dtype() == phi::DataType::INT32) {
PADDLE_ENFORCE_NE(
compute_dtype,
"default",
phi::errors::InvalidArgument(
"If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
if (bias) {
FBADtypeCheck(bias, "bias", compute_dtype);
}
if (out_scale > 0) {
out->set_dtype(phi::DataType::INT8);
} else {
if (compute_dtype == "bf16") {
out->set_dtype(phi::DataType::BFLOAT16);
} else if (compute_dtype == "fp16") {
out->set_dtype(phi::DataType::FLOAT16);
} else if (compute_dtype == "fp32") {
out->set_dtype(phi::DataType::FLOAT32);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"In the case of quantization enabled with Input(x) INT32, "
"Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
"but get compute_dtype (%s)",
compute_dtype));
}
}
} else {
if (bias) {
if (compute_dtype != "default") {
FBADtypeCheck(bias, "bias", compute_dtype);
FBADtypeCheck(x, "x", compute_dtype);
} else {
PADDLE_ENFORCE_EQ(
x.dtype(),
bias.dtype(),
phi::errors::InvalidArgument("Input(x) and Input(bias) must be the "
"same dtype in this situation"));
}
} else {
// bias not exist
if (compute_dtype != "default") {
FBADtypeCheck(x, "x", compute_dtype);
}
}
if (out_scale > 0) {
out->set_dtype(DataType::INT8);
out->set_dtype(phi::DataType::INT8);
} else {
out->set_dtype(x_dtype);
out->set_dtype(x.dtype());
}
}
cache_kv_out->set_dims(cache_kv_dims);
......
......@@ -801,6 +801,7 @@ void FusedRopeInferMeta(const MetaTensor& q,
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& bias,
const MetaTensor& src_mask,
const MetaTensor& cum_offsets,
const MetaTensor& sequence_lengths,
......@@ -812,6 +813,7 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const std::string& compute_dtype,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/fusion/gpu/masked_multihead_attention_utils.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void MMHAKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& cache_kv,
const paddle::optional<DenseTensor>& src_mask,
const paddle::optional<DenseTensor>& cum_offsets,
const paddle::optional<DenseTensor>& sequence_lengths,
const paddle::optional<DenseTensor>& rotary_tensor,
const paddle::optional<DenseTensor>& beam_cache_offset,
const paddle::optional<DenseTensor>& qkv_out_scale,
const paddle::optional<DenseTensor>& out_shift,
const paddle::optional<DenseTensor>& out_smooth,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* cache_kv_out,
DenseTensor* beam_cache_offset_out);
} // namespace fusion
} // namespace phi
......@@ -48,7 +48,6 @@
*/
#ifndef PADDLE_WITH_HIP
#pragma once
#if defined(__CUDACC__) && CUDA_VERSION >= 11000
......@@ -66,8 +65,6 @@
namespace phi {
namespace fusion {
namespace { // NOLINT
struct Float8_ {
float2 x;
float2 y;
......@@ -1712,8 +1709,6 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, // NOLINT
}
#endif // ENABLE_BF16
} // namespace
} // namespace fusion
} // namespace phi
......
......@@ -19,6 +19,7 @@ from paddle.framework import LayerHelper, in_dynamic_mode
def masked_multihead_attention(
x,
cache_kv=None,
bias=None,
src_mask=None,
cum_offsets=None,
sequence_lengths=None,
......@@ -30,6 +31,7 @@ def masked_multihead_attention(
seq_len=1,
rotary_emb_dims=0,
use_neox_rotary_style=False,
compute_dtype='default',
out_scale=-1,
quant_round_type=1,
quant_max_bound=127.0,
......@@ -43,6 +45,7 @@ def masked_multihead_attention(
Args:
x (Tensor): The input tensor could be 2-D tensor. Its shape is [batch_size, 3 * num_head * head_dim].
cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
bias (Tensor, optional): The bias tensor. Its shape is [3, num_head, head_dim].
src_mask (Tensor, optional): The src_mask tensor. Its shape is [batch_size, 1, 1, sequence_length].
sequence_lengths (Tensor, optional): The sequence_lengths tensor, used to index input. Its shape is [batch_size, 1].
rotary_tensor (Tensor, optional): The rotary_tensor tensor. The dtype must be float. Its shape is [batch_size, 1, 1, sequence_length, head_dim].
......@@ -53,6 +56,7 @@ def masked_multihead_attention(
seq_len (int, optional): The seq_len, used to get input length. Default 1.
rotary_emb_dims (int, optional): The rotary_emb_dims. Default 1.
use_neox_rotary_style (bool, optional): A flag indicating whether neox_rotary_style is needed or not. Default False.
compute_dtype (string): A compute dtype, used to represent the input data type.
out_scale (float, optional): The out_scale, used in quant.
quant_round_type (int, optional): The quant_round_type, used in quant. Default 1.
quant_max_bound (float, optional): The quant_max_bound, used in quant. Default 127.0.
......@@ -89,6 +93,7 @@ def masked_multihead_attention(
return _C_ops.masked_multihead_attention_(
x,
cache_kv,
bias,
src_mask,
cum_offsets,
sequence_lengths,
......@@ -100,6 +105,7 @@ def masked_multihead_attention(
seq_len,
rotary_emb_dims,
use_neox_rotary_style,
compute_dtype,
out_scale,
quant_round_type,
quant_max_bound,
......@@ -107,11 +113,22 @@ def masked_multihead_attention(
)
helper = LayerHelper('masked_multihead_attention', **locals())
if x.dtype == "int32":
if compute_dtype == "bf16":
dtype = "uint16"
elif compute_dtype == "fp16":
dtype = "float16"
elif compute_dtype == "fp32":
dtype = "float32"
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
inputs = {}
inputs['x'] = x
inputs['cache_kv'] = cache_kv
if bias is not None:
inputs['bias'] = bias
if src_mask is not None:
inputs['src_mask'] = src_mask
if cum_offsets is not None:
......@@ -148,6 +165,7 @@ def masked_multihead_attention(
'seq_len': seq_len,
'rotary_emb_dims': rotary_emb_dims,
'use_neox_rotary_style': use_neox_rotary_style,
'compute_dtype': compute_dtype,
'out_scale': out_scale,
'quant_round_type': quant_round_type,
'quant_max_bound': quant_max_bound,
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
......@@ -43,6 +42,10 @@ class TestMMHAOp(unittest.TestCase):
2, 10, size=(self.bsz, 3, self.num_head, self.dim_head)
).astype("int")
self.bias = np.random.uniform(
-0.05, 0.05, [3, self.num_head, self.dim_head]
)
self.src_mask = np.zeros([self.bsz, 1, 1, self.sequence_length + 1])
self.cum_offsets = None
......@@ -77,7 +80,7 @@ class TestMMHAOp(unittest.TestCase):
self.seq_len = 1
self.rotary_emb_dims = 0
self.use_neox_rotary_style = False
self.compute_dtype = "default"
self.out_scale = 10
self.quant_round_type = 1
self.quant_max_bound = 126
......@@ -100,6 +103,7 @@ class TestMMHAOp(unittest.TestCase):
self,
x,
cache_kv_out,
bias,
src_mask,
qkv_out_scale,
seq_len,
......@@ -110,9 +114,9 @@ class TestMMHAOp(unittest.TestCase):
bsz,
):
if qkv_out_scale is not None:
x = x.cast(cache_kv_out.dtype) * qkv_out_scale
x = x.cast(cache_kv_out.dtype) * qkv_out_scale + bias
else:
x = x
x = x + bias
x = paddle.transpose(
x, [0, 2, 1, 3]
......@@ -145,6 +149,7 @@ class TestMMHAOp(unittest.TestCase):
x,
cache_kv_out,
cache_kv_mmha_out,
bias,
src_mask,
qkv_out_scale,
out_scale,
......@@ -157,12 +162,14 @@ class TestMMHAOp(unittest.TestCase):
else:
x = paddle.to_tensor(x).cast(dtype)
src_mask = paddle.to_tensor(src_mask).cast(dtype)
bias = paddle.to_tensor(bias).cast(dtype)
cache_kv_out = paddle.to_tensor(cache_kv_out).cast(dtype)
cache_kv_mmha_out = paddle.to_tensor(cache_kv_mmha_out).cast(dtype)
paddle_naive_mmha_out = 0
paddle_naive_mmha_out = self.mmha_naive(
x,
cache_kv_out,
bias,
src_mask,
qkv_out_scale,
self.seq_len,
......@@ -174,9 +181,14 @@ class TestMMHAOp(unittest.TestCase):
)
x = x.reshape([self.bsz, -1])
if x.dtype == paddle.float16:
dtype = self.compute_dtype
else:
dtype = "fp16"
paddle_mmha_out = masked_multihead_attention(
x,
cache_kv_mmha_out,
bias,
src_mask,
None,
None,
......@@ -188,6 +200,7 @@ class TestMMHAOp(unittest.TestCase):
self.seq_len,
self.rotary_emb_dims,
self.use_neox_rotary_style,
dtype,
out_scale,
self.quant_round_type,
self.quant_max_bound,
......@@ -204,6 +217,7 @@ class TestMMHAOp(unittest.TestCase):
self.x,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.bias,
self.src_mask,
None,
-1,
......@@ -224,6 +238,7 @@ class TestMMHAOp(unittest.TestCase):
self.x_int,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.bias,
self.src_mask,
self.qkv_out_scale,
-1,
......@@ -244,6 +259,7 @@ class TestMMHAOp(unittest.TestCase):
self.x,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.bias,
self.src_mask,
None,
self.out_scale,
......@@ -274,6 +290,9 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
self.x = np.random.uniform(
-0.05, 0.05, [self.bsz, 3, self.num_head, self.dim_head]
)
self.bias = np.random.uniform(
-0.05, 0.05, [3, self.num_head, self.dim_head]
)
self.src_mask = np.zeros([self.bsz, 1, 1, self.sequence_length + 1])
self.cum_offsets = None
......@@ -317,6 +336,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
self,
x,
cache_kv_out,
bias,
src_mask,
qkv_out_scale,
seq_len,
......@@ -327,7 +347,9 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
bsz,
):
if qkv_out_scale is not None:
x = x.cast(cache_kv_out.dtype) * qkv_out_scale
x = x.cast(cache_kv_out.dtype) * qkv_out_scale + bias
else:
x = x + bias
x = paddle.transpose(
x, [0, 2, 1, 3]
......@@ -351,6 +373,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
def check_main(
self,
x,
bias,
src_mask,
cache_kv_out,
cache_kv_mmha_out,
......@@ -361,11 +384,13 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
paddle.disable_static()
x_tensor = paddle.to_tensor(x).cast(dtype)
src_mask_tensor = paddle.to_tensor(src_mask).cast(dtype)
bias_tensor = paddle.to_tensor(bias).cast(dtype)
cache_kv_out = paddle.to_tensor(cache_kv_out).cast(dtype)
paddle_naive_mmha_out = self.mmha_naive(
x_tensor,
cache_kv_out,
bias_tensor,
src_mask_tensor,
None,
self.seq_len,
......@@ -383,6 +408,11 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
shape=[self.bsz, 3 * self.num_head * self.dim_head],
dtype=dtype,
)
bias_static = paddle.static.data(
name="bias_static",
shape=[3, self.num_head, self.dim_head],
dtype=dtype,
)
src_mask_static = paddle.static.data(
name="src_mask_static",
shape=[self.bsz, 1, 1, self.sequence_length + 1],
......@@ -403,6 +433,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
outs = masked_multihead_attention(
x_static,
cache_kv_mmha_out_static,
bias_static,
src_mask_static,
None,
None,
......@@ -414,6 +445,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
32,
0,
False,
"fp16",
-1,
1,
127.0,
......@@ -424,6 +456,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
feed={
"x_static": x.reshape(self.bsz, -1).astype(dtype),
"cache_kv_mmha_out_static": cache_kv_mmha_out.astype(dtype),
"bias_static": bias.astype(dtype),
"src_mask_static": src_mask.astype(dtype),
},
fetch_list=[outs],
......@@ -437,6 +470,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase):
paddle_naive_mmha_out, paddle_mmha_out = self.check_main(
self.x,
self.bias,
self.src_mask,
self.cache_kv_out,
self.cache_kv_mmha_out,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册