未验证 提交 9ea279a4 编写于 作者: C carryyu 提交者: GitHub

make fused_multi_transformer support dynamically set the cache_kvs' shape and...

make fused_multi_transformer support dynamically set the cache_kvs' shape and support input prefix_caches. (#46777)

* make fused_multi_transformer support dynamically set the cache_kvs' shape and support input prefix_caches.
上级 af6d80fb
......@@ -143,12 +143,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
"head %d, but got %d",
trans_qkvw ? y_dim[1] : y_dim[2],
c_dim[2])); // num_head
PADDLE_ENFORCE_GT(
c_dim[3],
0,
paddle::platform::errors::InvalidArgument(
"The forth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
PADDLE_ENFORCE_EQ(c_dim[4],
trans_qkvw ? y_dim[2] : y_dim[3],
paddle::platform::errors::InvalidArgument(
......@@ -199,6 +193,10 @@ class FusedMultiTransformerOpOpMaker
AddInput("CacheKV", "(optional) The cached KV for generation inference.")
.AsDispensable()
.AsDuplicable();
AddInput("PreCaches",
"(optional) The prefix caches for generation inference.")
.AsDispensable()
.AsDuplicable();
AddInput("TimeStep",
"(optional, int) The time step for generation inference.")
.AsDispensable();
......
......@@ -80,6 +80,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto cache_kvs = ctx.MultiInput<phi::DenseTensor>("CacheKV");
auto cache_kv_outs = ctx.MultiOutput<phi::DenseTensor>("CacheKVOut");
// auto *time_step = ctx.Input<phi::DenseTensor>("TimeStep");
auto pre_caches = ctx.MultiInput<phi::DenseTensor>("PreCaches");
int cache_offset = 0;
if (pre_caches.size() > 0) {
cache_offset = pre_caches[0]->dims()[3];
}
auto out_seq_len = seq_len;
if (time_step) {
......@@ -101,6 +106,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
"In decode stage, the seq_len of input must be 1, but now is %d",
seq_len));
out_seq_len += time_step_value;
} else {
out_seq_len += cache_offset;
}
Tensor transpose_out_2, qk_out;
......@@ -110,6 +117,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));
Tensor src_mask_out;
if (cache_offset > 0) {
src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *src_mask_out_data =
dev_ctx.Alloc<T>(&src_mask_out, src_mask_out.numel() * sizeof(T));
}
// [2, bs, num_head, cache_seq_len + seq_len, head_dim]
Tensor pre_cache_kv_out;
if (cache_offset > 0) {
pre_cache_kv_out.Resize(
{{2, bsz, num_head, seq_len + cache_offset, dim_head}});
auto *pre_cache_kv_out_data = dev_ctx.Alloc<T>(
&pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T));
}
Tensor softmax_out;
Tensor attn_dropout_mask_out, attn_dropout_out;
Tensor qktv_out, fmha_out;
......@@ -277,26 +300,42 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
time_step->data<int>()[0],
1. / sqrt(dim_head));
} else if (cache_kv_out) { // generation context stage
// TODO(wangxi): can remove dropout in inference
const Tensor *pre_cache_kv_tensor =
pre_caches.size() > 0 ? pre_caches[i] : nullptr;
Tensor *pre_cache_kv_out_tmp =
cache_offset > 0 ? &pre_cache_kv_out : nullptr;
Tensor *src_mask_tmp = cache_offset > 0 ? &src_mask_out : nullptr;
fmha_compute.ComputeForward(qkv_out,
nullptr,
pre_cache_kv_tensor,
src_mask,
&transpose_out_2,
nullptr,
pre_cache_kv_out_tmp,
&qk_out,
nullptr,
src_mask_tmp,
&softmax_out,
&attn_dropout_mask_out,
&attn_dropout_out,
&qktv_out,
&fmha_out);
// [3, bsz, num_head, seq_len, head_dim]
T *qkv_data = transpose_out_2_data;
int64_t q_size = bsz * seq_len * num_head * dim_head;
int64_t k_size = q_size;
const T *q_ptr = qkv_data;
const T *k_ptr = q_ptr + q_size;
const T *v_ptr = k_ptr + k_size;
const T *k_ptr = nullptr;
const T *v_ptr = nullptr;
if (cache_offset > 0) {
// [2, bsz, num_head, cache_offset + seq_len, head_dim]
const T *kv_data = pre_cache_kv_out.data<T>();
k_ptr = kv_data;
int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head;
v_ptr = k_ptr + k_size;
} else {
// [3, bsz, num_head, seq_len, head_dim]
T *qkv_data = transpose_out_2_data;
int64_t q_size = bsz * seq_len * num_head * dim_head;
int64_t k_size = q_size;
const T *q_ptr = qkv_data;
k_ptr = q_ptr + q_size;
v_ptr = k_ptr + k_size;
}
// [2, bsz, num_head, max_seq_len, head_dim]
int max_seq_len = cache_kv_out->dims()[3];
......@@ -306,6 +345,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
T *cache_k_ptr = cache_kv_data;
T *cache_v_ptr = cache_kv_data + cache_k_size;
const int seq_len_tmp = seq_len + cache_offset;
write_cache_kv<T>(dev_ctx,
cache_k_ptr,
cache_v_ptr,
......@@ -313,7 +353,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
v_ptr,
bsz,
num_head,
seq_len,
seq_len_tmp,
max_seq_len,
dim_head);
} else { // not generation
......
......@@ -61,6 +61,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"QKVW",
"QKVBias",
"CacheKV",
"PreCaches",
"TimeStep",
"SrcMask",
"OutLinearW",
......
......@@ -34,6 +34,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.framework import _non_static_mode, default_main_program
from paddle import _C_ops, _legacy_C_ops
from paddle.incubate.nn.functional import fused_multi_transformer
from paddle.incubate.nn import FusedMultiTransformer
default_main_program().random_seed = 42
......@@ -114,6 +115,7 @@ class TestFusedMultiTransformerOp(OpTest):
# True, False, generation decoder stage
self.has_cache_kv = False
self.gen_cache_kv = False
self.has_pre_cache = False
self.training = False
......@@ -121,6 +123,7 @@ class TestFusedMultiTransformerOp(OpTest):
self.batch_size = 8
self.query_length = 128
self.cache_length = 128
self.pre_cache_num = 64
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
......@@ -150,6 +153,12 @@ class TestFusedMultiTransformerOp(OpTest):
else:
self.cache_kv = None
if self.has_pre_cache:
out_seq_len += self.pre_cache_num
self.pre_cache_kv = np.random.rand(
2, self.batch_size, self.num_heads, self.pre_cache_num,
self.head_dim).astype(self.x_type)
if self.has_attn_mask:
# [B, n_head, seq_len, out_seq_len]
self.attn_mask = np.ones(
......@@ -188,6 +197,10 @@ class TestFusedMultiTransformerOp(OpTest):
if self.has_cache_kv:
cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
if self.has_pre_cache:
pre_cache_kv = paddle.to_tensor(self.pre_cache_kv,
stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
......@@ -227,6 +240,13 @@ class TestFusedMultiTransformerOp(OpTest):
k_out = paddle.concat([cache_k, k_out], axis=-2)
v_out = paddle.concat([cache_v, v_out], axis=-2)
if self.has_pre_cache:
pre_cache_k, pre_cache_v = paddle.split(pre_cache_kv, 2)
pre_cache_k = paddle.squeeze(pre_cache_k, axis=0)
pre_cache_v = paddle.squeeze(pre_cache_v, axis=0)
k_out = paddle.concat([pre_cache_k, k_out], axis=-2)
v_out = paddle.concat([pre_cache_v, v_out], axis=-2)
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out = layers.matmul(x=q_out,
......@@ -348,6 +368,7 @@ class TestFusedMultiTransformerOp(OpTest):
x = paddle.to_tensor(self.query, stop_gradient=False)
cache_kvs, cache_kv = None, None
time_step = None
pre_caches, pre_cache = None, None
if self.has_cache_kv:
cache_kvs = []
......@@ -387,6 +408,18 @@ class TestFusedMultiTransformerOp(OpTest):
time_step = paddle.to_tensor([self.cache_length],
dtype='int32',
place=paddle.CPUPlace())
if self.has_pre_cache:
cache_kvs = []
max_seq_length = (self.cache_length +
128) // 128 * 128 + self.pre_cache_num
cache_kv = np.zeros([
2, self.batch_size, self.num_heads, max_seq_length,
self.head_dim
],
dtype=self.x_type)
pre_caches = []
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
......@@ -420,6 +453,11 @@ class TestFusedMultiTransformerOp(OpTest):
if self.has_cache_kv:
cache_kvs.append(paddle.to_tensor(cache_kv,
stop_gradient=False))
if self.has_pre_cache:
cache_kvs.append(paddle.to_tensor(cache_kv,
stop_gradient=False))
pre_caches.append(
paddle.to_tensor(self.pre_cache_kv, stop_gradient=False))
final_out = fused_multi_transformer(x,
ln_scales,
......@@ -437,6 +475,7 @@ class TestFusedMultiTransformerOp(OpTest):
pre_layer_norm=self.pre_layer_norm,
epsilon=epsilon,
cache_kvs=cache_kvs,
pre_caches=pre_caches,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_prob,
......@@ -445,8 +484,159 @@ class TestFusedMultiTransformerOp(OpTest):
if self.has_cache_kv:
return final_out[0], final_out[1]
if self.has_pre_cache:
return final_out[0]
return final_out
def GetFusedMultiTransformerOutStatic(self):
paddle.enable_static()
x = paddle.fluid.data('x', self.query.shape, self.query.dtype)
cache_kvs, cache_kv = None, None
time_step = None
time_step_feed = None
pre_caches, pre_cache = None, None
if self.has_cache_kv:
cache_kvs = []
max_seq_length = (self.cache_length + 128) // 128 * 128
cache_kv = np.zeros([
2, self.batch_size, self.num_heads, max_seq_length,
self.head_dim
],
dtype=self.x_type)
elems = 4
if self.x_type is np.float16:
elems = 8
assert self.head_dim % elems == 0
v_elems = self.head_dim // elems
cache_k_tmp = self.cache_kv[0].reshape([
self.batch_size, self.num_heads, self.cache_length, v_elems,
elems
])
# [B, num_head, head_dim / 4, 128, 4]
cache_k_tmp = cache_k_tmp.transpose([0, 1, 3, 2, 4])
cache_kv[0, :].reshape([
self.batch_size, self.num_heads, v_elems, max_seq_length, elems
])[:, :, :, :self.cache_length, :] = cache_k_tmp
cache_kv[1, :, :, :self.cache_length, :] = self.cache_kv[1]
if self.gen_cache_kv:
assert self.query_length == self.cache_length
cache_kv[:] = 0
else:
time_step = layers.fill_constant(shape=[1],
dtype="int32",
value=0,
force_cpu=True)
time_step_feed = self.cache_length
if self.has_pre_cache:
cache_kvs = []
max_seq_length = (self.cache_length +
128) // 128 * 128 + self.pre_cache_num
cache_kv = np.zeros([
2, self.batch_size, self.num_heads, max_seq_length,
self.head_dim
],
dtype=self.x_type)
pre_caches = []
attn_mask = None
epsilon = 1e-05
ln2_epsilon = 1e-05
qkv_weights_attr, qkv_biases_attr = [], []
out_weights_attr, out_biases_attr = [], []
ln_scales_attr, ln_biases_attr = [], []
ffn1_weights_attr, ffn1_biases_attr = [], []
ffn2_weights_attr, ffn2_biases_attr = [], []
ffn_ln_scales_attr, ffn_ln_biases_attr = [], []
if self.has_cache_kv:
cache_kvs_feed = []
if self.has_pre_cache:
cache_kvs_feed = []
pre_caches_feed = []
for i in range(self.layers):
qkv_weights_attr.append(self.weight_attr)
qkv_biases_attr.append(self.bias_attr)
out_weights_attr.append(self.weight_attr)
out_biases_attr.append(self.bias_attr)
ln_scales_attr.append(self.ln_w_attr)
ln_biases_attr.append(self.ln_b_attr)
ffn1_weights_attr.append(self.weight_attr)
ffn1_biases_attr.append(self.bias_attr)
ffn2_weights_attr.append(self.weight_attr)
ffn2_biases_attr.append(self.bias_attr)
ffn_ln_scales_attr.append(self.ln_w_attr)
ffn_ln_biases_attr.append(self.ln_b_attr)
transformer = FusedMultiTransformer(
self.embed_dim,
self.num_heads,
4 * self.embed_dim,
self.dropout_prob,
normalize_before=self.pre_layer_norm,
ln_scale_attrs=ln_scales_attr,
ln_bias_attrs=ln_biases_attr,
qkv_weight_attrs=qkv_weights_attr,
qkv_bias_attrs=qkv_biases_attr,
linear_weight_attrs=out_weights_attr,
linear_bias_attrs=out_biases_attr,
ffn_ln_scale_attrs=ffn_ln_scales_attr,
ffn_ln_bias_attrs=ffn_ln_biases_attr,
ffn1_weight_attrs=ffn1_weights_attr,
ffn1_bias_attrs=ffn1_biases_attr,
ffn2_weight_attrs=ffn2_weights_attr,
ffn2_bias_attrs=ffn2_biases_attr)
transformer.eval()
for i in range(self.layers):
if self.has_cache_kv:
cache_kvs.append(
layers.fill_constant(shape=cache_kv.shape,
dtype=cache_kv.dtype,
value=0))
cache_kvs_feed.append(cache_kv)
if self.has_pre_cache:
cache_kvs.append(
layers.fill_constant(shape=cache_kv.shape,
dtype=cache_kv.dtype,
value=0))
cache_kvs_feed.append(cache_kv)
pre_caches.append(
layers.fill_constant(shape=self.pre_cache_kv.shape,
dtype=self.pre_cache_kv.dtype,
value=0))
pre_caches_feed.append(self.pre_cache_kv)
final_out = transformer(x,
attn_mask=attn_mask,
caches=cache_kvs,
pre_caches=pre_caches,
time_step=time_step)[0]
exe = paddle.static.Executor(place=paddle.CUDAPlace(0))
exe.run(paddle.static.default_startup_program())
feed_data = {
'x': self.query,
'cache_kvs': cache_kvs_feed,
'pre_caches': pre_caches_feed,
'time_step': time_step_feed,
'attn_mask': attn_mask
}
out = exe.run(paddle.fluid.default_main_program(),
feed=feed_data,
fetch_list=[final_out])
paddle.disable_static()
return out[0]
def test_fused_multi_transformer_op(self):
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOut()
......@@ -603,5 +793,39 @@ class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16(
self.pre_layer_norm = False
class TestFusedMultiTransformerOpPreCache(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_pre_cache = True
self.x_type = np.float16
class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_pre_cache = True
self.has_attn_mask = False
self.x_type = np.float32
self.weight_attr = paddle.ParamAttr(
initializer=paddle.fluid.initializer.Constant(0.))
self.bias_attr = paddle.ParamAttr(
initializer=paddle.fluid.initializer.Constant(0.0005))
self.ln_w_attr = paddle.ParamAttr(
initializer=paddle.fluid.initializer.Constant(1.))
self.ln_b_attr = paddle.ParamAttr(
initializer=paddle.fluid.initializer.Constant(0.))
def test_fused_multi_transformer_op(self):
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOutStatic()
np.testing.assert_allclose(final_out_ref,
final_out,
rtol=self.rtol,
atol=self.atol)
if __name__ == "__main__":
unittest.main()
......@@ -674,6 +674,7 @@ def fused_multi_transformer(x,
pre_layer_norm=True,
epsilon=1e-05,
cache_kvs=None,
pre_caches=None,
time_step=None,
attn_mask=None,
dropout_rate=0.0,
......@@ -739,6 +740,7 @@ def fused_multi_transformer(x,
pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True.
epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
......@@ -826,12 +828,13 @@ def fused_multi_transformer(x,
if _non_static_mode():
cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer(
x, ln_scales, ln_biases, qkv_weights, qkv_biases, cache_kvs,
time_step, attn_mask, linear_weights, linear_biases, ffn_ln_scales,
ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases,
cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon,
'dropout_rate', dropout_rate, 'is_test', not training,
'dropout_implementation', mode, 'act_method', activation,
'trans_qkvw', trans_qkvw, 'ring_id', ring_id)
pre_caches, time_step, attn_mask, linear_weights, linear_biases,
ffn_ln_scales, ffn_ln_biases, ffn1_weights, ffn1_biases,
ffn2_weights, ffn2_biases, cache_kvs, 'pre_layer_norm',
pre_layer_norm, 'epsilon', epsilon, 'dropout_rate', dropout_rate,
'is_test', not training, 'dropout_implementation', mode,
'act_method', activation, 'trans_qkvw', trans_qkvw, 'ring_id',
ring_id)
if cache_kvs is not None:
return final_out, cache_kv_out
return final_out
......@@ -857,6 +860,8 @@ def fused_multi_transformer(x,
inputs['CacheKV'] = cache_kvs
if time_step is not None:
inputs['TimeStep'] = time_step
if pre_caches is not None:
inputs['PreCaches'] = pre_caches
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = linear_weights
if linear_biases is not None:
......
......@@ -1242,7 +1242,12 @@ class FusedMultiTransformer(Layer):
self.activation = activation
self.name = name
def forward(self, src, attn_mask=None, caches=None, time_step=None):
def forward(self,
src,
attn_mask=None,
caches=None,
pre_caches=None,
time_step=None):
"""
Applies multi transformer layers on the input.
......@@ -1260,6 +1265,8 @@ class FusedMultiTransformer(Layer):
tensors for the inference generation model. It is only used for
inference and should be None for training. The shape is
`[2, batch_size, num_head, max_seq_len, head_dim]`. Default None.
pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches
for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation
model. Which used in decode stage, to represent the time step,
that is, the real seq_len of CacheKV. The shape is `[1]`, must be
......@@ -1292,6 +1299,7 @@ class FusedMultiTransformer(Layer):
pre_layer_norm=self.normalize_before,
epsilon=self._epsilon,
cache_kvs=caches,
pre_caches=pre_caches,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册