Unverified commit 989c5e87, authored by xiaoxiaohehe001, committed by GitHub

[Paddle Inference] Add masked multihead attention kernel and export API. (#55344)

* support_mmha
* add_python_api
* add_api_doc
* fix_doc_error
* fix_infermeta
* add_infermeta
* add_bf16_cuda_check
* add_bf16_check
* fix_ci_windows
* fix_ci_windows_kernel_register
* fix_test_mmha
* add_cumoffsets
* remove_bias
* delete_mmha_reshape_input_output
* rename_delete_hfile
* remove_fluid

---------
Co-authored-by: yangjianfengo1 <yangjianfeng01@baidu.com>
Parent 2a5d1d54
...@@ -931,9 +931,8 @@ __global__ void masked_multihead_attention_kernel(
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
   if (bi == 0 && hi == 0 && tid == 0) {
-    printf("=======q_out=======\n");
-    for (int i = 0; i < Dh; ++i) printf("%f ", static_cast<float>(q_smem[i]));
-    printf("\n");
+    VLOG(0) << "=======q_out=======\n";
+    for (int i = 0; i < Dh; ++i) VLOG(0) << static_cast<float>(q_smem[i]);
   }
   __syncthreads();
 #endif
......
...@@ -1616,6 +1616,17 @@
data_type : logits
backward : margin_cross_entropy_grad
- op : masked_multihead_attention_
args : (Tensor x, Tensor cache_kv, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
output : Tensor(out), Tensor(cache_kv_out), Tensor(beam_cache_offset_out)
infer_meta :
func : MaskedMultiheadAttentionInferMeta
kernel :
func : masked_multihead_attention
data_type : cache_kv
optional : src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)
- op : masked_select
args : (Tensor x, Tensor mask)
output : Tensor (out)
......
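For orientation, here is a minimal dynamic-mode call of the Python API that this op definition backs (a sketch with assumed illustrative shapes, not taken from the commit; it requires a CUDA build of Paddle that includes this change):

```python
import paddle
import paddle.incubate.nn.functional as F

# Assumed illustrative sizes: batch_size=2, num_head=32, head_dim=128, max_seq_len=64.
x = paddle.rand([2, 3 * 32 * 128], dtype="float32")           # fused qkv activations
cache_kv = paddle.rand([2, 2, 32, 64, 128], dtype="float32")  # [2, bsz, num_head, max_seq_len, head_dim]
src_mask = paddle.zeros([2, 1, 1, 64], dtype="float32")       # attention mask, last dim = max_seq_len

# Optional inputs default to None; out_scale=-1 keeps the output in the input dtype.
outputs = F.masked_multihead_attention(x, cache_kv=cache_kv, src_mask=src_mask)
# Per the API docstring added in this commit, this returns (out, cache_kv_out)
# when beam_cache_offset is None.
```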
...@@ -3983,5 +3983,69 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& src_mask,
const MetaTensor& cum_offsets,
const MetaTensor& sequence_lengths,
const MetaTensor& rotary_tensor,
const MetaTensor& beam_cache_offset,
const MetaTensor& qkv_out_scale,
const MetaTensor& out_shift,
const MetaTensor& out_smooth,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* cache_kv_out,
MetaTensor* beam_cache_offset_out) {
int bsz = x.dims()[0];
auto x_dtype = x.dtype();
auto cache_kv_dims = cache_kv.dims();
int num_head = cache_kv.dims()[2];
int dim_head = cache_kv.dims()[4];
PADDLE_ENFORCE_EQ(
cache_kv_dims.size(),
5,
errors::InvalidArgument("The cache_kv must be 5 dims, but got %d",
cache_kv_dims.size()));
PADDLE_ENFORCE_EQ(
cache_kv_dims[0],
2,
errors::InvalidArgument("The first dim of cache_kv must be 2, but got %d",
cache_kv_dims[0]));
if (rotary_tensor) {
PADDLE_ENFORCE_EQ(
rotary_tensor.dtype(),
DataType::FLOAT32,
errors::InvalidArgument(
"The dtype of rotary_tensor must be float32, but got %d",
rotary_tensor.dtype()));
}
out->set_dims({bsz, num_head * dim_head});
if (out_scale > 0) {
out->set_dtype(DataType::INT8);
} else {
out->set_dtype(x_dtype);
}
cache_kv_out->set_dims(cache_kv_dims);
cache_kv_out->set_dtype(cache_kv.dtype());
if (beam_cache_offset) {
beam_cache_offset_out->set_dims(beam_cache_offset.dims());
beam_cache_offset_out->set_dtype(beam_cache_offset.dtype());
}
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
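A minimal Python sketch of the shape and dtype rules enforced by MaskedMultiheadAttentionInferMeta above; the helper name is hypothetical and the dimensions are illustrative:

```python
def masked_mha_infer_shapes(x_shape, cache_kv_shape, out_scale):
    """Mirror of the InferMeta logic (sketch only, not the real API)."""
    assert len(cache_kv_shape) == 5, "cache_kv must be 5-D: [2, bsz, num_head, max_seq_len, dim_head]"
    assert cache_kv_shape[0] == 2, "first dim of cache_kv must be 2 (k and v)"
    bsz = x_shape[0]
    num_head, dim_head = cache_kv_shape[2], cache_kv_shape[4]
    out_shape = [bsz, num_head * dim_head]   # out->set_dims({bsz, num_head * dim_head})
    out_is_int8 = out_scale > 0              # out_scale > 0 switches the output dtype to INT8
    return out_shape, list(cache_kv_shape), out_is_int8

# Example: bsz=2, num_head=32, max_seq_len=64, dim_head=128, no quantization.
print(masked_mha_infer_shapes([2, 3 * 32 * 128], [2, 2, 32, 64, 128], out_scale=-1.0))
# ([2, 4096], [2, 2, 32, 64, 128], False)
```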
...@@ -773,4 +773,25 @@ void FusedRopeInferMeta(const MetaTensor& q,
MetaTensor* out_k,
MetaTensor* out_v);
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& src_mask,
const MetaTensor& cum_offsets,
const MetaTensor& sequence_lengths,
const MetaTensor& rotary_tensor,
const MetaTensor& beam_cache_offset,
const MetaTensor& qkv_out_scale,
const MetaTensor& out_shift,
const MetaTensor& out_smooth,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* cache_kv_out,
MetaTensor* beam_cache_offset_out);
} // namespace phi
This diff is collapsed.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/gpu/masked_multihead_attention.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void MMHAKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& cache_kv,
const paddle::optional<DenseTensor>& src_mask,
const paddle::optional<DenseTensor>& cum_offsets,
const paddle::optional<DenseTensor>& sequence_lengths,
const paddle::optional<DenseTensor>& rotary_tensor,
const paddle::optional<DenseTensor>& beam_cache_offset,
const paddle::optional<DenseTensor>& qkv_out_scale,
const paddle::optional<DenseTensor>& out_shift,
const paddle::optional<DenseTensor>& out_smooth,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* cache_kv_out,
DenseTensor* beam_cache_offset_out) {
#ifndef PADDLE_WITH_HIP
const auto& x_dims = x.dims();
int bsz = x_dims[0];
int cache_bsz = cache_kv.dims()[1];
int num_head = cache_kv.dims()[2];
int max_seq_len = cache_kv.dims()[3];
int dim_head = cache_kv.dims()[4];
int timestep = max_seq_len;
float inv_sqrt_dh = 1. / sqrt(dim_head);
Masked_multihead_attention_params<T> params;
bool mask_broadcast_num_heads = true;
if (src_mask) {
if (src_mask->dims()[1] == 1) {
mask_broadcast_num_heads = true;
} else if (src_mask->dims()[1] == num_head) {
mask_broadcast_num_heads = false;
} else {
PADDLE_THROW(errors::InvalidArgument(
"Unknow dimension for attn_mask, the num_head(2nd) "
"dimension is invalid, it should be 1 or num_head(%d), "
"but got %d",
num_head,
src_mask->dims()[1]));
}
params.attn_mask = src_mask->data<T>();
params.mask_length = src_mask->dims()[3];
timestep = src_mask->dims()[3] - 1;
}
if (out_scale > 0) {
dev_ctx.template Alloc<int8_t>(out);
} else {
dev_ctx.template Alloc<T>(out);
}
if (sequence_lengths) {
params.sequence_lengths = sequence_lengths->data<int>();
}
if (cum_offsets) {
params.cum_offsets = cum_offsets->data<int>();
} else {
params.cum_offsets = nullptr;
}
if (rotary_emb_dims > 0) {
params.rotary_emb = rotary_tensor->data<float>();
} else {
params.rotary_emb = nullptr;
}
if (beam_cache_offset) {
params.beam_cache_offset = beam_cache_offset->data<int>();
params.beam_width = beam_cache_offset->dims()[1];
}
params.mask_broadcast_num_heads = mask_broadcast_num_heads;
params.cache_kv = const_cast<T*>(cache_kv_out->data<T>());
params.neox_rotary_style = use_neox_rotary_style;
params.add_qkv_bias = false;
params.batch_size = bsz;
params.cache_batch_size = cache_bsz;
params.num_head = num_head;
params.timestep = timestep;
params.seq_len = seq_len;
params.max_seq_length = max_seq_len;
params.inv_sqrt_dh = inv_sqrt_dh;
params.rotary_emb_dims = rotary_emb_dims;
if (out_shift) {
DispatchFMHA<T>(dev_ctx,
x,
*(out_shift.get_ptr()),
*(out_smooth.get_ptr()),
params,
num_head,
dim_head,
out,
qkv_out_scale.get_ptr(),
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
DispatchFMHA<T>(dev_ctx,
x,
params,
num_head,
dim_head,
out,
qkv_out_scale.get_ptr(),
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
#endif // PADDLE_WITH_HIP
}
} // namespace fusion
} // namespace phi
#if CUDA_VERSION >= 11000
PD_REGISTER_KERNEL(masked_multihead_attention,
GPU,
ALL_LAYOUT,
phi::fusion::MMHAKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(masked_multihead_attention,
GPU,
ALL_LAYOUT,
phi::fusion::MMHAKernel,
float,
phi::dtype::float16) {}
#endif
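A small Python sketch (illustrative only, not the CUDA path) of how the wrapper above derives its launch parameters from the tensor shapes; the helper name derive_mmha_params is hypothetical:

```python
import math

def derive_mmha_params(cache_kv_shape, src_mask_shape=None):
    """Sketch of the parameter derivation in MMHAKernel (not the actual kernel)."""
    _, cache_bsz, num_head, max_seq_len, dim_head = cache_kv_shape
    timestep = max_seq_len
    mask_broadcast_num_heads = True
    if src_mask_shape is not None:
        if src_mask_shape[1] == 1:
            mask_broadcast_num_heads = True    # one mask shared by all heads
        elif src_mask_shape[1] == num_head:
            mask_broadcast_num_heads = False   # a per-head mask
        else:
            raise ValueError("src_mask dim 1 must be 1 or num_head")
        timestep = src_mask_shape[3] - 1       # timestep = src_mask->dims()[3] - 1
    inv_sqrt_dh = 1.0 / math.sqrt(dim_head)    # attention scaling factor
    return timestep, mask_broadcast_num_heads, inv_sqrt_dh

# With the shapes used in the unit test: cache [2, 2, 32, 33, 128], mask [2, 1, 1, 33].
print(derive_mmha_params([2, 2, 32, 33, 128], [2, 1, 1, 33]))  # (32, True, ~0.0884)
```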
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/fusion/gpu/masked_multihead_attention_utils.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void MMHAKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& cache_kv,
const paddle::optional<DenseTensor>& src_mask,
const paddle::optional<DenseTensor>& cum_offsets,
const paddle::optional<DenseTensor>& sequence_lengths,
const paddle::optional<DenseTensor>& rotary_tensor,
const paddle::optional<DenseTensor>& beam_cache_offset,
const paddle::optional<DenseTensor>& qkv_out_scale,
const paddle::optional<DenseTensor>& out_shift,
const paddle::optional<DenseTensor>& out_smooth,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* cache_kv_out,
DenseTensor* beam_cache_offset_out);
} // namespace fusion
} // namespace phi
...@@ -30,6 +30,7 @@ from .variable_length_memory_efficient_attention import (
)
from .fused_rms_norm import fused_rms_norm
from .fused_layer_norm import fused_layer_norm
from .masked_multihead_attention import masked_multihead_attention
__all__ = [
'fused_multi_head_attention',
...@@ -45,4 +46,5 @@ __all__ = [
'variable_length_memory_efficient_attention',
"fused_rms_norm",
"fused_layer_norm",
"masked_multihead_attention",
]
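With this export in place, a quick import check (assuming a Paddle build that includes this commit):

```python
import paddle.incubate.nn.functional as F

# The new fusion API is exported alongside the others listed in __all__ above.
assert "masked_multihead_attention" in F.__all__
print(F.masked_multihead_attention)
```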
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode
def masked_multihead_attention(
x,
cache_kv=None,
src_mask=None,
cum_offsets=None,
sequence_lengths=None,
rotary_tensor=None,
beam_cache_offset=None,
qkv_out_scale=None,
out_shift=None,
out_smooth=None,
seq_len=1,
rotary_emb_dims=0,
use_neox_rotary_style=False,
out_scale=-1,
quant_round_type=1,
quant_max_bound=127.0,
quant_min_bound=-127.0,
):
r"""
Masked Multi-head attention for text summarization.
This is a fused operator that computes masked multi-head attention in the transformer model architecture.
This operator only supports running on GPU.
Args:
x (Tensor): The input tensor. It is a 2-D tensor with shape [batch_size, 3 * num_head * head_dim].
cache_kv (Tensor): The cache tensor of the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
src_mask (Tensor, optional): The src_mask tensor. Its shape is [batch_size, 1, 1, sequence_length].
cum_offsets (Tensor, optional): The cum_offsets tensor, used to index the input. Default None.
sequence_lengths (Tensor, optional): The sequence_lengths tensor, used to index the input. Its shape is [batch_size, 1].
rotary_tensor (Tensor, optional): The rotary_tensor tensor. Its dtype must be float32. Its shape is [batch_size, 1, 1, sequence_length, head_dim].
beam_cache_offset (Tensor, optional): The beam_cache_offset tensor. Its shape is [batch_size, beam_size, max_seq_len + max_dec_len].
qkv_out_scale (Tensor, optional): The qkv_out_scale tensor, used in quantization. Its shape is [3, num_head, head_dim].
out_shift (Tensor, optional): The out_shift tensor, used in quantization.
out_smooth (Tensor, optional): The out_smooth tensor, used in quantization.
seq_len (int, optional): The sequence length, used to get the input length. Default 1.
rotary_emb_dims (int, optional): The rotary_emb_dims. Default 0.
use_neox_rotary_style (bool, optional): Whether the neox rotary style is used. Default False.
out_scale (float, optional): The out_scale, used in quantization; a value greater than 0 produces INT8 output. Default -1.
quant_round_type (int, optional): The quant_round_type, used in quantization. Default 1.
quant_max_bound (float, optional): The quant_max_bound, used in quantization. Default 127.0.
quant_min_bound (float, optional): The quant_min_bound, used in quantization. Default -127.0.
Returns:
Tensor|tuple: If `beam_cache_offset` is not None, returns the tuple (output, cache_kv_out, beam_cache_offset_out),
where output is the result of masked_multihead_attention and cache_kv_out is updated in place with the input `cache_kv`.
If `beam_cache_offset` is None, returns the tuple (output, cache_kv_out).
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.incubate.nn.functional as F
# input: [batch_size, 3 * num_head * dim_head]
x = paddle.rand(shape=(2, 3 * 32 * 128), dtype="float32")
# src_mask: [batch_size, 1, 1, sequence_length]
src_mask = paddle.rand(shape=(2, 1, 1, 10), dtype="float32")
# cache_kv: [2, batch_size, num_head, max_seq_len, dim_head]
cache_kv = paddle.rand(shape=(2, 2, 32, 64, 128), dtype="float32")
output = F.masked_multihead_attention(
x, src_mask=src_mask, cache_kv=cache_kv)
"""
if in_dynamic_mode():
return _C_ops.masked_multihead_attention_(
x,
cache_kv,
src_mask,
cum_offsets,
sequence_lengths,
rotary_tensor,
beam_cache_offset,
qkv_out_scale,
out_shift,
out_smooth,
seq_len,
rotary_emb_dims,
use_neox_rotary_style,
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
)
helper = LayerHelper('masked_multihead_attention', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
inputs = {}
inputs['x'] = x
inputs['cache_kv'] = cache_kv
if src_mask is not None:
inputs['src_mask'] = src_mask
if cum_offsets is not None:
inputs['cum_offsets'] = cum_offsets
if sequence_lengths is not None:
inputs['sequence_lengths'] = sequence_lengths
if rotary_tensor is not None:
inputs['rotary_tensor'] = rotary_tensor
beam_cache_offset_flag = False
if beam_cache_offset is not None:
inputs['beam_cache_offset'] = beam_cache_offset
beam_cache_offset_flag = True
else:
beam_cache_offset = helper.create_variable_for_type_inference(
dtype="int"
)
if qkv_out_scale is not None:
inputs['qkv_out_scale'] = qkv_out_scale
if out_shift is not None:
inputs['out_shift'] = out_shift
if out_smooth is not None:
inputs['out_smooth'] = out_smooth
outputs = {
'out': out,
'cache_kv_out': cache_kv,
'beam_cache_offset_out': beam_cache_offset,
}
helper.append_op(
type='masked_multihead_attention',
inputs=inputs,
outputs=outputs,
attrs={
'seq_len': seq_len,
'rotary_emb_dims': rotary_emb_dims,
'use_neox_rotary_style': use_neox_rotary_style,
'out_scale': out_scale,
'quant_round_type': quant_round_type,
'quant_max_bound': quant_max_bound,
'quant_min_bound': quant_min_bound,
},
)
return (
(out, cache_kv, beam_cache_offset)
if beam_cache_offset_flag
else (out, cache_kv)
)
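When out_scale > 0 the fused op produces INT8 output; a minimal sketch of the reference quantization rule used by the unit test below, assuming quant_round_type == 1 (the helper name quant_reference is hypothetical):

```python
import paddle

def quant_reference(out, out_scale, quant_max_bound=127.0, quant_min_bound=-127.0):
    # Scale, round, clip, and cast to int8 -- the same sequence the test's
    # quant_helper applies for quant_round_type == 1.
    q = paddle.round(quant_max_bound * out_scale * out)
    return paddle.cast(paddle.clip(q, quant_min_bound, quant_max_bound), paddle.int8)
```

The non-quantized path (out_scale <= 0) simply keeps the input dtype, matching the dtype rule in MaskedMultiheadAttentionInferMeta.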
...@@ -155,6 +155,7 @@ if(WIN32) ...@@ -155,6 +155,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_ops_nms) list(REMOVE_ITEM TEST_OPS test_ops_nms)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias) list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op) list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
list(REMOVE_ITEM TEST_OPS test_masked_multihead_attention_op)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op) list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op) list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.framework import core
from paddle.incubate.nn.functional import masked_multihead_attention
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMMHAOp(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.bsz = 2
self.cache_bsz = 2
self.num_head = 32
self.dim_head = 128
self.beam_size = 1
self.max_seq_len = 33
self.sequence_length = 32
self.x = np.random.uniform(
-0.05, 0.05, [self.bsz, 3, self.num_head, self.dim_head]
)
self.x_int = np.random.randint(
2, 10, size=(self.bsz, 3, self.num_head, self.dim_head)
).astype("int")
self.src_mask = np.zeros([self.bsz, 1, 1, self.sequence_length + 1])
self.cum_offsets = None
self.sequence_lengths = None
self.rotary_tensor = None
self.beam_cache_offset = None
self.cache_kv_out = np.random.uniform(
-0.05,
0.05,
[
2,
self.cache_bsz,
self.num_head,
self.sequence_length,
self.dim_head,
],
)
numpy_ones = np.zeros(
[2, self.cache_bsz, self.num_head, 1, self.dim_head]
)
self.cache_kv_mmha_out = np.concatenate(
(self.cache_kv_out, numpy_ones), axis=3
)
self.qkv_out_scale = np.random.uniform(
-0.05, 0.05, [3, self.num_head, self.dim_head]
)
self.out_shift = None
self.out_smooth = None
self.seq_len = 1
self.rotary_emb_dims = 0
self.use_neox_rotary_style = False
self.out_scale = 10
self.quant_round_type = 1
self.quant_max_bound = 126
self.quant_min_bound = -126
def quant_helper(
self, x, quant_scale, quant_round_type, quant_max_bound, quant_min_bound
):
quant_value = quant_max_bound * quant_scale * x
if quant_round_type == 0:
quant_value = paddle.to_tensor(np.rint(quant_value.numpy()))
else:
quant_value = paddle.round(quant_value)
return paddle.cast(
paddle.clip(quant_value, quant_min_bound, quant_max_bound),
paddle.int8,
)
def mmha_naive(
self,
x,
cache_kv_out,
src_mask,
qkv_out_scale,
seq_len,
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
bsz,
):
if qkv_out_scale is not None:
x = x.cast(cache_kv_out.dtype) * qkv_out_scale
else:
x = x
x = paddle.transpose(
x, [0, 2, 1, 3]
) # [bz, seqlen, nhead, head_dim] --> [bz, nhead, seqlen, head_dim]
q, k, v = paddle.split(x, 3, axis=2)
cache_k, cache_v = paddle.split(cache_kv_out, 2, axis=0)
k = paddle.concat([cache_k.squeeze(0), k], axis=2)
v = paddle.concat([cache_v.squeeze(0), v], axis=2)
product = paddle.matmul(
x=q * (x.shape[3] ** -0.5), y=k, transpose_y=True
)
product = product + src_mask
product = paddle.nn.functional.softmax(product)
out = (
paddle.matmul(product, v).transpose([0, 2, 1, 3]).reshape([bsz, -1])
)
normalized_out = self.quant_helper(
out,
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
).reshape([bsz, -1])
return out, normalized_out
def check_main(
self,
x,
cache_kv_out,
cache_kv_mmha_out,
src_mask,
qkv_out_scale,
out_scale,
dtype,
):
paddle.disable_static()
if qkv_out_scale is not None:
x = paddle.to_tensor(x).cast("int32")
qkv_out_scale = paddle.to_tensor(qkv_out_scale).cast("float32")
else:
x = paddle.to_tensor(x).cast(dtype)
src_mask = paddle.to_tensor(src_mask).cast(dtype)
cache_kv_out = paddle.to_tensor(cache_kv_out).cast(dtype)
cache_kv_mmha_out = paddle.to_tensor(cache_kv_mmha_out).cast(dtype)
paddle_naive_mmha_out = self.mmha_naive(
x,
cache_kv_out,
src_mask,
qkv_out_scale,
self.seq_len,
out_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
self.bsz,
)
x = x.reshape([self.bsz, -1])
paddle_mmha_out = masked_multihead_attention(
x,
cache_kv_mmha_out,
src_mask,
None,
None,
None,
None,
qkv_out_scale,
None,
None,
self.seq_len,
self.rotary_emb_dims,
self.use_neox_rotary_style,
out_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
return paddle_naive_mmha_out, paddle_mmha_out
def test_mmha_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_naive_mmha, paddle_mmha_out = self.check_main(
self.x,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.src_mask,
None,
-1,
'float16',
)
np.testing.assert_allclose(
paddle_mmha_out[0].numpy(),
paddle_naive_mmha[0].numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_mmha_qkv_out_scale(self):
if not paddle.is_compiled_with_cuda():
return
paddle_naive_mmha, paddle_mmha_out = self.check_main(
self.x_int,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.src_mask,
self.qkv_out_scale,
-1,
'float16',
)
np.testing.assert_allclose(
paddle_mmha_out[0].numpy(),
paddle_naive_mmha[0].numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_mmha_outlinear_in_scale(self):
if not paddle.is_compiled_with_cuda():
return
paddle_naive_mmha, paddle_mmha_out = self.check_main(
self.x,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.src_mask,
None,
self.out_scale,
'float16',
)
np.testing.assert_allclose(
paddle_mmha_out[0].numpy(),
paddle_naive_mmha[1].numpy(),
rtol=1,
atol=1,
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMMHAOpStatic(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.bsz = 2
self.cache_bsz = 2
self.num_head = 32
self.dim_head = 128
self.beam_size = 1
self.max_seq_len = 33
self.sequence_length = 32
self.x = np.random.uniform(
-0.05, 0.05, [self.bsz, 3, self.num_head, self.dim_head]
)
self.src_mask = np.zeros([self.bsz, 1, 1, self.sequence_length + 1])
self.cum_offsets = None
self.sequence_lengths = None
self.rotary_tensor = None
self.beam_cache_offset = None
self.cache_kv_out = np.random.uniform(
-0.05,
0.05,
[
2,
self.cache_bsz,
self.num_head,
self.sequence_length,
self.dim_head,
],
)
numpy_ones = np.zeros(
[2, self.cache_bsz, self.num_head, 1, self.dim_head]
)
self.cache_kv_mmha_out = np.concatenate(
(self.cache_kv_out, numpy_ones), axis=3
)
self.qkv_out_scale = None
self.out_shift = None
self.out_smooth = None
self.seq_len = 1
self.rotary_emb_dims = 0
self.use_neox_rotary_style = False
self.out_scale = -1
self.quant_round_type = 1
self.quant_max_bound = 127
self.quant_min_bound = -127
self.place = paddle.CUDAPlace(0)
def mmha_naive(
self,
x,
cache_kv_out,
src_mask,
qkv_out_scale,
seq_len,
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
bsz,
):
if qkv_out_scale is not None:
x = x.cast(cache_kv_out.dtype) * qkv_out_scale
x = paddle.transpose(
x, [0, 2, 1, 3]
) # [bz, seqlen, nhead, head_dim] --> [bz, nhead, seqlen, head_dim]
q, k, v = paddle.split(x, 3, axis=2)
cache_k, cache_v = paddle.split(cache_kv_out, 2, axis=0)
k = paddle.concat([cache_k.squeeze(0), k], axis=2)
v = paddle.concat([cache_v.squeeze(0), v], axis=2)
product = paddle.matmul(
x=q * (x.shape[3] ** -0.5), y=k, transpose_y=True
)
product = product + src_mask
product = paddle.nn.functional.softmax(product)
out = (
paddle.matmul(product, v).transpose([0, 2, 1, 3]).reshape([bsz, -1])
)
return out
def check_main(
self,
x,
src_mask,
cache_kv_out,
cache_kv_mmha_out,
qkv_out_scale,
out_scale,
dtype,
):
paddle.disable_static()
x_tensor = paddle.to_tensor(x).cast(dtype)
src_mask_tensor = paddle.to_tensor(src_mask).cast(dtype)
cache_kv_out = paddle.to_tensor(cache_kv_out).cast(dtype)
paddle_naive_mmha_out = self.mmha_naive(
x_tensor,
cache_kv_out,
src_mask_tensor,
None,
self.seq_len,
out_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
self.bsz,
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static",
shape=[self.bsz, 3 * self.num_head * self.dim_head],
dtype=dtype,
)
src_mask_static = paddle.static.data(
name="src_mask_static",
shape=[self.bsz, 1, 1, self.sequence_length + 1],
dtype=dtype,
)
cache_kv_mmha_out_static = paddle.static.data(
name="cache_kv_mmha_out_static",
shape=[
2,
self.cache_bsz,
self.num_head,
self.sequence_length + 1,
self.dim_head,
],
dtype=dtype,
)
outs = masked_multihead_attention(
x_static,
cache_kv_mmha_out_static,
src_mask_static,
None,
None,
None,
None,
None,
None,
None,
32,
0,
False,
-1,
1,
127.0,
-127.0,
)
exe = paddle.static.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x.reshape(self.bsz, -1).astype(dtype),
"cache_kv_mmha_out_static": cache_kv_mmha_out.astype(dtype),
"src_mask_static": src_mask.astype(dtype),
},
fetch_list=[outs],
)
return paddle_naive_mmha_out, out_s
def test_mmha_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_naive_mmha_out, paddle_mmha_out = self.check_main(
self.x,
self.src_mask,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.qkv_out_scale,
self.out_scale,
'float16',
)
np.testing.assert_allclose(
paddle_mmha_out[0],
paddle_naive_mmha_out.numpy(),
rtol=1e-3,
atol=1e-3,
)
if __name__ == '__main__':
unittest.main()