Unverified · Commit 9ea279a4 · Authored by carryyu, committed by GitHub

make fused_multi_transformer support dynamically setting the cache_kvs' shape and accept prefix_caches as input (#46777)

* make fused_multi_transformer support dynamically setting the cache_kvs' shape and accept prefix_caches as input.
Parent af6d80fb
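At a high level, the change removes the compile-time check that CacheKV's sequence-length dimension must be positive (so that dimension can now be set at run time) and adds an optional, per-layer PreCaches input that holds prefix K/V to be prepended in the context stage. A minimal sketch of how the new inputs are shaped; the sizes below are placeholders for illustration, not values from this PR:

    import numpy as np
    import paddle

    # Illustrative sizes (placeholders).
    num_layers, bsz, num_head, head_dim = 2, 4, 16, 64
    max_seq_len = 256    # the cache_kvs' seq-len dim can now be chosen freely at run time
    pre_cache_len = 64   # length of the shared prefix

    # One [2, bsz, num_head, max_seq_len, head_dim] cache tensor per layer.
    cache_kvs = [
        paddle.zeros([2, bsz, num_head, max_seq_len, head_dim], dtype='float32')
        for _ in range(num_layers)
    ]
    # One [2, bsz, num_head, pre_cache_len, head_dim] prefix cache per layer,
    # passed through the new pre_caches argument of fused_multi_transformer.
    pre_caches = [
        paddle.to_tensor(
            np.random.rand(2, bsz, num_head, pre_cache_len,
                           head_dim).astype('float32'))
        for _ in range(num_layers)
    ]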
...@@ -143,12 +143,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
                          "head %d, but got %d",
                          trans_qkvw ? y_dim[1] : y_dim[2],
                          c_dim[2]));  // num_head
-    PADDLE_ENFORCE_GT(
-        c_dim[3],
-        0,
-        paddle::platform::errors::InvalidArgument(
-            "The forth dim of CacheKV must be greater than 0, but got %d",
-            c_dim[3]));  // cache_seq_len
     PADDLE_ENFORCE_EQ(c_dim[4],
                       trans_qkvw ? y_dim[2] : y_dim[3],
                       paddle::platform::errors::InvalidArgument(
...@@ -199,6 +193,10 @@ class FusedMultiTransformerOpOpMaker
     AddInput("CacheKV", "(optional) The cached KV for generation inference.")
         .AsDispensable()
         .AsDuplicable();
+    AddInput("PreCaches",
+             "(optional) The prefix caches for generation inference.")
+        .AsDispensable()
+        .AsDuplicable();
     AddInput("TimeStep",
              "(optional, int) The time step for generation inference.")
         .AsDispensable();
...
...@@ -80,6 +80,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     auto cache_kvs = ctx.MultiInput<phi::DenseTensor>("CacheKV");
     auto cache_kv_outs = ctx.MultiOutput<phi::DenseTensor>("CacheKVOut");
     // auto *time_step = ctx.Input<phi::DenseTensor>("TimeStep");
+    auto pre_caches = ctx.MultiInput<phi::DenseTensor>("PreCaches");
+    int cache_offset = 0;
+    if (pre_caches.size() > 0) {
+      cache_offset = pre_caches[0]->dims()[3];
+    }

     auto out_seq_len = seq_len;
     if (time_step) {
...@@ -101,6 +106,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
               "In decode stage, the seq_len of input must be 1, but now is %d",
               seq_len));
       out_seq_len += time_step_value;
+    } else {
+      out_seq_len += cache_offset;
     }

     Tensor transpose_out_2, qk_out;
...@@ -110,6 +117,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
     auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));

+    Tensor src_mask_out;
+    if (cache_offset > 0) {
+      src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
+      auto *src_mask_out_data =
+          dev_ctx.Alloc<T>(&src_mask_out, src_mask_out.numel() * sizeof(T));
+    }
+
+    // [2, bs, num_head, cache_seq_len + seq_len, head_dim]
+    Tensor pre_cache_kv_out;
+    if (cache_offset > 0) {
+      pre_cache_kv_out.Resize(
+          {{2, bsz, num_head, seq_len + cache_offset, dim_head}});
+      auto *pre_cache_kv_out_data = dev_ctx.Alloc<T>(
+          &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T));
+    }
+
     Tensor softmax_out;
     Tensor attn_dropout_mask_out, attn_dropout_out;
     Tensor qktv_out, fmha_out;
...@@ -277,26 +300,42 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                               time_step->data<int>()[0],
                               1. / sqrt(dim_head));
         } else if (cache_kv_out) {  // generation context stage
-          // TODO(wangxi): can remove dropout in inference
+          const Tensor *pre_cache_kv_tensor =
+              pre_caches.size() > 0 ? pre_caches[i] : nullptr;
+          Tensor *pre_cache_kv_out_tmp =
+              cache_offset > 0 ? &pre_cache_kv_out : nullptr;
+          Tensor *src_mask_tmp = cache_offset > 0 ? &src_mask_out : nullptr;
           fmha_compute.ComputeForward(qkv_out,
-                                      nullptr,
+                                      pre_cache_kv_tensor,
                                       src_mask,
                                       &transpose_out_2,
-                                      nullptr,
+                                      pre_cache_kv_out_tmp,
                                       &qk_out,
-                                      nullptr,
+                                      src_mask_tmp,
                                       &softmax_out,
                                       &attn_dropout_mask_out,
                                       &attn_dropout_out,
                                       &qktv_out,
                                       &fmha_out);
-          // [3, bsz, num_head, seq_len, head_dim]
-          T *qkv_data = transpose_out_2_data;
-          int64_t q_size = bsz * seq_len * num_head * dim_head;
-          int64_t k_size = q_size;
-          const T *q_ptr = qkv_data;
-          const T *k_ptr = q_ptr + q_size;
-          const T *v_ptr = k_ptr + k_size;
+
+          const T *k_ptr = nullptr;
+          const T *v_ptr = nullptr;
+
+          if (cache_offset > 0) {
+            // [2, bsz, num_head, cache_offset + seq_len, head_dim]
+            const T *kv_data = pre_cache_kv_out.data<T>();
+            k_ptr = kv_data;
+            int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head;
+            v_ptr = k_ptr + k_size;
+          } else {
+            // [3, bsz, num_head, seq_len, head_dim]
+            T *qkv_data = transpose_out_2_data;
+            int64_t q_size = bsz * seq_len * num_head * dim_head;
+            int64_t k_size = q_size;
+            const T *q_ptr = qkv_data;
+            k_ptr = q_ptr + q_size;
+            v_ptr = k_ptr + k_size;
+          }

           // [2, bsz, num_head, max_seq_len, head_dim]
           int max_seq_len = cache_kv_out->dims()[3];
...@@ -306,6 +345,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
           T *cache_k_ptr = cache_kv_data;
           T *cache_v_ptr = cache_kv_data + cache_k_size;

+          const int seq_len_tmp = seq_len + cache_offset;
           write_cache_kv<T>(dev_ctx,
                             cache_k_ptr,
                             cache_v_ptr,
...@@ -313,7 +353,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                             v_ptr,
                             bsz,
                             num_head,
-                            seq_len,
+                            seq_len_tmp,
                             max_seq_len,
                             dim_head);
         } else {  // not generation
...
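To make the shape bookkeeping in the kernel change above easier to follow, here is a small Python sketch of the context (prompt) stage when a prefix cache is present; it mirrors the kernel logic but uses illustrative sizes that are not taken from the patch:

    # Sketch of the context-stage shapes when PreCaches is supplied.
    bsz, num_head, seq_len, dim_head = 8, 16, 128, 64
    cache_offset = 64                      # pre_caches[i].dims()[3]

    out_seq_len = seq_len + cache_offset   # keys attended over in the prompt pass
    qk_out_shape = (bsz, num_head, seq_len, out_seq_len)
    pre_cache_kv_out_shape = (2, bsz, num_head, seq_len + cache_offset, dim_head)

    # K and V read from pre_cache_kv_out are each one contiguous slab of this
    # size, and write_cache_kv copies seq_len + cache_offset positions per layer.
    k_size = bsz * num_head * (seq_len + cache_offset) * dim_head
    seq_len_tmp = seq_len + cache_offset
    assert qk_out_shape[-1] == seq_len_tmp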
...@@ -61,6 +61,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
      "QKVW",
      "QKVBias",
      "CacheKV",
+     "PreCaches",
      "TimeStep",
      "SrcMask",
      "OutLinearW",
...
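The position of "PreCaches" in this list matters: op_ins_map fixes the positional argument order of the generated dygraph op call, which is why the Python wrapper later in this diff passes pre_caches between cache_kvs and time_step. A tiny sketch of that ordering constraint, using only the names visible in the excerpt above:

    # Hypothetical check of the ordering implied by op_ins_map (excerpt only).
    op_inputs = ["QKVW", "QKVBias", "CacheKV", "PreCaches", "TimeStep",
                 "SrcMask", "OutLinearW"]
    assert op_inputs.index("PreCaches") == op_inputs.index("CacheKV") + 1
    assert op_inputs.index("PreCaches") == op_inputs.index("TimeStep") - 1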
...@@ -34,6 +34,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
 from paddle.fluid.framework import _non_static_mode, default_main_program
 from paddle import _C_ops, _legacy_C_ops
 from paddle.incubate.nn.functional import fused_multi_transformer
+from paddle.incubate.nn import FusedMultiTransformer

 default_main_program().random_seed = 42
...@@ -114,6 +115,7 @@ class TestFusedMultiTransformerOp(OpTest):
         # True, False, generation decoder stage
         self.has_cache_kv = False
         self.gen_cache_kv = False
+        self.has_pre_cache = False

         self.training = False
...@@ -121,6 +123,7 @@ class TestFusedMultiTransformerOp(OpTest):
         self.batch_size = 8
         self.query_length = 128
         self.cache_length = 128
+        self.pre_cache_num = 64
         self.head_dim = 64
         self.num_heads = 16
         self.embed_dim = self.head_dim * self.num_heads
...@@ -150,6 +153,12 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -150,6 +153,12 @@ class TestFusedMultiTransformerOp(OpTest):
else: else:
self.cache_kv = None self.cache_kv = None
if self.has_pre_cache:
out_seq_len += self.pre_cache_num
self.pre_cache_kv = np.random.rand(
2, self.batch_size, self.num_heads, self.pre_cache_num,
self.head_dim).astype(self.x_type)
if self.has_attn_mask: if self.has_attn_mask:
# [B, n_head, seq_len, out_seq_len] # [B, n_head, seq_len, out_seq_len]
self.attn_mask = np.ones( self.attn_mask = np.ones(
...@@ -188,6 +197,10 @@ class TestFusedMultiTransformerOp(OpTest):
         if self.has_cache_kv:
             cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

+        if self.has_pre_cache:
+            pre_cache_kv = paddle.to_tensor(self.pre_cache_kv,
+                                            stop_gradient=False)
+
         if self.has_attn_mask:
             attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
         else:
...@@ -227,6 +240,13 @@ class TestFusedMultiTransformerOp(OpTest):
                 k_out = paddle.concat([cache_k, k_out], axis=-2)
                 v_out = paddle.concat([cache_v, v_out], axis=-2)

+            if self.has_pre_cache:
+                pre_cache_k, pre_cache_v = paddle.split(pre_cache_kv, 2)
+                pre_cache_k = paddle.squeeze(pre_cache_k, axis=0)
+                pre_cache_v = paddle.squeeze(pre_cache_v, axis=0)
+                k_out = paddle.concat([pre_cache_k, k_out], axis=-2)
+                v_out = paddle.concat([pre_cache_v, v_out], axis=-2)
+
             # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
             # --> [B, n_head, seq_len, out_seq_len]
             qk_out = layers.matmul(x=q_out,
...@@ -348,6 +368,7 @@ class TestFusedMultiTransformerOp(OpTest):
         x = paddle.to_tensor(self.query, stop_gradient=False)
         cache_kvs, cache_kv = None, None
         time_step = None
+        pre_caches, pre_cache = None, None
         if self.has_cache_kv:
             cache_kvs = []
...@@ -387,6 +408,18 @@ class TestFusedMultiTransformerOp(OpTest):
                 time_step = paddle.to_tensor([self.cache_length],
                                              dtype='int32',
                                              place=paddle.CPUPlace())
+
+        if self.has_pre_cache:
+            cache_kvs = []
+            max_seq_length = (self.cache_length +
+                              128) // 128 * 128 + self.pre_cache_num
+            cache_kv = np.zeros([
+                2, self.batch_size, self.num_heads, max_seq_length,
+                self.head_dim
+            ],
+                                dtype=self.x_type)
+            pre_caches = []
+
         if self.has_attn_mask:
             attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
         else:
...@@ -420,6 +453,11 @@ class TestFusedMultiTransformerOp(OpTest):
             if self.has_cache_kv:
                 cache_kvs.append(paddle.to_tensor(cache_kv,
                                                   stop_gradient=False))
+            if self.has_pre_cache:
+                cache_kvs.append(paddle.to_tensor(cache_kv,
+                                                  stop_gradient=False))
+                pre_caches.append(
+                    paddle.to_tensor(self.pre_cache_kv, stop_gradient=False))

         final_out = fused_multi_transformer(x,
                                             ln_scales,
...@@ -437,6 +475,7 @@ class TestFusedMultiTransformerOp(OpTest):
                                             pre_layer_norm=self.pre_layer_norm,
                                             epsilon=epsilon,
                                             cache_kvs=cache_kvs,
+                                            pre_caches=pre_caches,
                                             time_step=time_step,
                                             attn_mask=attn_mask,
                                             dropout_rate=self.dropout_prob,
...@@ -445,8 +484,159 @@ class TestFusedMultiTransformerOp(OpTest):

         if self.has_cache_kv:
             return final_out[0], final_out[1]
+        if self.has_pre_cache:
+            return final_out[0]
         return final_out
+    def GetFusedMultiTransformerOutStatic(self):
+        paddle.enable_static()
+        x = paddle.fluid.data('x', self.query.shape, self.query.dtype)
+        cache_kvs, cache_kv = None, None
+        time_step = None
+        time_step_feed = None
+        pre_caches, pre_cache = None, None
+        if self.has_cache_kv:
+            cache_kvs = []
+            max_seq_length = (self.cache_length + 128) // 128 * 128
+            cache_kv = np.zeros([
+                2, self.batch_size, self.num_heads, max_seq_length,
+                self.head_dim
+            ],
+                                dtype=self.x_type)
+            elems = 4
+            if self.x_type is np.float16:
+                elems = 8
+
+            assert self.head_dim % elems == 0
+            v_elems = self.head_dim // elems
+            cache_k_tmp = self.cache_kv[0].reshape([
+                self.batch_size, self.num_heads, self.cache_length, v_elems,
+                elems
+            ])
+            # [B, num_head, head_dim / 4, 128, 4]
+            cache_k_tmp = cache_k_tmp.transpose([0, 1, 3, 2, 4])
+
+            cache_kv[0, :].reshape([
+                self.batch_size, self.num_heads, v_elems, max_seq_length, elems
+            ])[:, :, :, :self.cache_length, :] = cache_k_tmp
+
+            cache_kv[1, :, :, :self.cache_length, :] = self.cache_kv[1]
+            if self.gen_cache_kv:
+                assert self.query_length == self.cache_length
+                cache_kv[:] = 0
+            else:
+                time_step = layers.fill_constant(shape=[1],
+                                                 dtype="int32",
+                                                 value=0,
+                                                 force_cpu=True)
+                time_step_feed = self.cache_length
+
+        if self.has_pre_cache:
+            cache_kvs = []
+            max_seq_length = (self.cache_length +
+                              128) // 128 * 128 + self.pre_cache_num
+            cache_kv = np.zeros([
+                2, self.batch_size, self.num_heads, max_seq_length,
+                self.head_dim
+            ],
+                                dtype=self.x_type)
+            pre_caches = []
+
+        attn_mask = None
+        epsilon = 1e-05
+        ln2_epsilon = 1e-05
+
+        qkv_weights_attr, qkv_biases_attr = [], []
+        out_weights_attr, out_biases_attr = [], []
+        ln_scales_attr, ln_biases_attr = [], []
+        ffn1_weights_attr, ffn1_biases_attr = [], []
+        ffn2_weights_attr, ffn2_biases_attr = [], []
+        ffn_ln_scales_attr, ffn_ln_biases_attr = [], []
+
+        if self.has_cache_kv:
+            cache_kvs_feed = []
+        if self.has_pre_cache:
+            cache_kvs_feed = []
+            pre_caches_feed = []
+
+        for i in range(self.layers):
+            qkv_weights_attr.append(self.weight_attr)
+            qkv_biases_attr.append(self.bias_attr)
+            out_weights_attr.append(self.weight_attr)
+            out_biases_attr.append(self.bias_attr)
+            ln_scales_attr.append(self.ln_w_attr)
+            ln_biases_attr.append(self.ln_b_attr)
+            ffn1_weights_attr.append(self.weight_attr)
+            ffn1_biases_attr.append(self.bias_attr)
+            ffn2_weights_attr.append(self.weight_attr)
+            ffn2_biases_attr.append(self.bias_attr)
+            ffn_ln_scales_attr.append(self.ln_w_attr)
+            ffn_ln_biases_attr.append(self.ln_b_attr)
+
+        transformer = FusedMultiTransformer(
+            self.embed_dim,
+            self.num_heads,
+            4 * self.embed_dim,
+            self.dropout_prob,
+            normalize_before=self.pre_layer_norm,
+            ln_scale_attrs=ln_scales_attr,
+            ln_bias_attrs=ln_biases_attr,
+            qkv_weight_attrs=qkv_weights_attr,
+            qkv_bias_attrs=qkv_biases_attr,
+            linear_weight_attrs=out_weights_attr,
+            linear_bias_attrs=out_biases_attr,
+            ffn_ln_scale_attrs=ffn_ln_scales_attr,
+            ffn_ln_bias_attrs=ffn_ln_biases_attr,
+            ffn1_weight_attrs=ffn1_weights_attr,
+            ffn1_bias_attrs=ffn1_biases_attr,
+            ffn2_weight_attrs=ffn2_weights_attr,
+            ffn2_bias_attrs=ffn2_biases_attr)
+        transformer.eval()
+
+        for i in range(self.layers):
+            if self.has_cache_kv:
+                cache_kvs.append(
+                    layers.fill_constant(shape=cache_kv.shape,
+                                         dtype=cache_kv.dtype,
+                                         value=0))
+                cache_kvs_feed.append(cache_kv)
+
+            if self.has_pre_cache:
+                cache_kvs.append(
+                    layers.fill_constant(shape=cache_kv.shape,
+                                         dtype=cache_kv.dtype,
+                                         value=0))
+                cache_kvs_feed.append(cache_kv)
+                pre_caches.append(
+                    layers.fill_constant(shape=self.pre_cache_kv.shape,
+                                         dtype=self.pre_cache_kv.dtype,
+                                         value=0))
+                pre_caches_feed.append(self.pre_cache_kv)
+
+        final_out = transformer(x,
+                                attn_mask=attn_mask,
+                                caches=cache_kvs,
+                                pre_caches=pre_caches,
+                                time_step=time_step)[0]
+
+        exe = paddle.static.Executor(place=paddle.CUDAPlace(0))
+        exe.run(paddle.static.default_startup_program())
+        feed_data = {
+            'x': self.query,
+            'cache_kvs': cache_kvs_feed,
+            'pre_caches': pre_caches_feed,
+            'time_step': time_step_feed,
+            'attn_mask': attn_mask
+        }
+        out = exe.run(paddle.fluid.default_main_program(),
+                      feed=feed_data,
+                      fetch_list=[final_out])
+        paddle.disable_static()
+        return out[0]

     def test_fused_multi_transformer_op(self):
         final_out_ref = self.GetBaselineOut()
         final_out = self.GetFusedMultiTransformerOut()
...@@ -603,5 +793,39 @@ class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16(
         self.pre_layer_norm = False

+
+class TestFusedMultiTransformerOpPreCache(TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_pre_cache = True
+        self.x_type = np.float16
+
+
+class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_pre_cache = True
+        self.has_attn_mask = False
+        self.x_type = np.float32
+        self.weight_attr = paddle.ParamAttr(
+            initializer=paddle.fluid.initializer.Constant(0.))
+        self.bias_attr = paddle.ParamAttr(
+            initializer=paddle.fluid.initializer.Constant(0.0005))
+        self.ln_w_attr = paddle.ParamAttr(
+            initializer=paddle.fluid.initializer.Constant(1.))
+        self.ln_b_attr = paddle.ParamAttr(
+            initializer=paddle.fluid.initializer.Constant(0.))
+
+    def test_fused_multi_transformer_op(self):
+        final_out_ref = self.GetBaselineOut()
+        final_out = self.GetFusedMultiTransformerOutStatic()
+        np.testing.assert_allclose(final_out_ref,
+                                   final_out,
+                                   rtol=self.rtol,
+                                   atol=self.atol)
+
+
 if __name__ == "__main__":
     unittest.main()
...
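For orientation, the static-graph test above sizes each per-layer cache by rounding the prompt length up to a multiple of 128 and then adding the prefix length, and it packs the K cache in a vectorized layout so the kernel can read `elems` contiguous values at a time. A standalone sketch using the test's default sizes:

    import numpy as np

    # Cache sizing used by the test when has_pre_cache is True.
    cache_length, pre_cache_num = 128, 64
    max_seq_length = (cache_length + 128) // 128 * 128 + pre_cache_num  # 256 + 64 = 320

    # Vectorized K-cache layout: 8 halves (or 4 floats) are kept contiguous so a
    # thread can load 16 bytes at once; the V cache keeps the plain layout.
    batch_size, num_heads, head_dim = 8, 16, 64
    elems = 8                                  # np.float16; use 4 for np.float32
    v_elems = head_dim // elems
    cache_k = np.zeros(
        [batch_size, num_heads, v_elems, max_seq_length, elems], dtype=np.float16)
    print(max_seq_length, cache_k.shape)       # 320 (8, 16, 8, 320, 8)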
...@@ -674,6 +674,7 @@ def fused_multi_transformer(x,
                             pre_layer_norm=True,
                             epsilon=1e-05,
                             cache_kvs=None,
+                            pre_caches=None,
                             time_step=None,
                             attn_mask=None,
                             dropout_rate=0.0,
...@@ -739,6 +740,7 @@ def fused_multi_transformer(x,
         pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True.
         epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
         cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
+        pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
         time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None.
         attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
             some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
...@@ -826,12 +828,13 @@ def fused_multi_transformer(x,
     if _non_static_mode():
         cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer(
             x, ln_scales, ln_biases, qkv_weights, qkv_biases, cache_kvs,
-            time_step, attn_mask, linear_weights, linear_biases, ffn_ln_scales,
-            ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases,
-            cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon,
-            'dropout_rate', dropout_rate, 'is_test', not training,
-            'dropout_implementation', mode, 'act_method', activation,
-            'trans_qkvw', trans_qkvw, 'ring_id', ring_id)
+            pre_caches, time_step, attn_mask, linear_weights, linear_biases,
+            ffn_ln_scales, ffn_ln_biases, ffn1_weights, ffn1_biases,
+            ffn2_weights, ffn2_biases, cache_kvs, 'pre_layer_norm',
+            pre_layer_norm, 'epsilon', epsilon, 'dropout_rate', dropout_rate,
+            'is_test', not training, 'dropout_implementation', mode,
+            'act_method', activation, 'trans_qkvw', trans_qkvw, 'ring_id',
+            ring_id)
         if cache_kvs is not None:
             return final_out, cache_kv_out
         return final_out
...@@ -857,6 +860,8 @@ def fused_multi_transformer(x,
         inputs['CacheKV'] = cache_kvs
         if time_step is not None:
             inputs['TimeStep'] = time_step
+    if pre_caches is not None:
+        inputs['PreCaches'] = pre_caches
     inputs['SrcMask'] = attn_mask
     inputs['OutLinearW'] = linear_weights
     if linear_biases is not None:
...
...@@ -1242,7 +1242,12 @@ class FusedMultiTransformer(Layer):
         self.activation = activation
         self.name = name

-    def forward(self, src, attn_mask=None, caches=None, time_step=None):
+    def forward(self,
+                src,
+                attn_mask=None,
+                caches=None,
+                pre_caches=None,
+                time_step=None):
         """
         Applies multi transformer layers on the input.
...@@ -1260,6 +1265,8 @@ class FusedMultiTransformer(Layer):
                 tensors for the inference generation model. It is only used for
                 inference and should be None for training. The shape is
                 `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None.
+            pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches
+                for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
             time_step (Tensor, optional): The time step tensor for the generation
                 model. Which used in decode stage, to represent the time step,
                 that is, the real seq_len of CacheKV. The shape is `[1]`, must be
...@@ -1292,6 +1299,7 @@ class FusedMultiTransformer(Layer):
             pre_layer_norm=self.normalize_before,
             epsilon=self._epsilon,
             cache_kvs=caches,
+            pre_caches=pre_caches,
             time_step=time_step,
             attn_mask=attn_mask,
             dropout_rate=self.dropout_rate,
...
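Putting the layer-level API together, here is a hedged usage sketch of FusedMultiTransformer.forward with the new pre_caches argument. It assumes a CUDA build of Paddle (the fused op only has a GPU kernel); the sizes are placeholders and only constructor arguments visible in this diff are used:

    import paddle
    from paddle.incubate.nn import FusedMultiTransformer

    # Placeholder sizes, single layer for brevity.
    bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64
    embed_dim = num_heads * head_dim
    max_seq_len, pre_cache_len = 128, 32

    layer = FusedMultiTransformer(embed_dim, num_heads, 4 * embed_dim, 0.0)
    layer.eval()

    src = paddle.randn([bsz, seq_len, embed_dim])
    caches = [paddle.zeros([2, bsz, num_heads, max_seq_len, head_dim])]
    pre_caches = [paddle.zeros([2, bsz, num_heads, pre_cache_len, head_dim])]

    # Context stage: time_step is None, the prefix K/V are prepended to this
    # layer's keys/values, and the cache tensor is filled in place.
    out = layer(src, caches=caches, pre_caches=pre_caches)[0]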