diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc index cb5d5b17dfeb6f5aee1b51ec5f49bdb394325309..09c3dfe24c13eba6649a4563d2ebe266ae9550fe 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -143,12 +143,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { "head %d, but got %d", trans_qkvw ? y_dim[1] : y_dim[2], c_dim[2])); // num_head - PADDLE_ENFORCE_GT( - c_dim[3], - 0, - paddle::platform::errors::InvalidArgument( - "The forth dim of CacheKV must be greater than 0, but got %d", - c_dim[3])); // cache_seq_len PADDLE_ENFORCE_EQ(c_dim[4], trans_qkvw ? y_dim[2] : y_dim[3], paddle::platform::errors::InvalidArgument( @@ -199,6 +193,10 @@ class FusedMultiTransformerOpOpMaker AddInput("CacheKV", "(optional) The cached KV for generation inference.") .AsDispensable() .AsDuplicable(); + AddInput("PreCaches", + "(optional) The prefix caches for generation inference.") + .AsDispensable() + .AsDuplicable(); AddInput("TimeStep", "(optional, int) The time step for generation inference.") .AsDispensable(); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 01464b7241655f7710947785265e760a7c8d0f5a..1274e247e696b35bd7d754eeed8ef3d783fc8e66 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -80,6 +80,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto cache_kvs = ctx.MultiInput("CacheKV"); auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); // auto *time_step = ctx.Input("TimeStep"); + auto pre_caches = ctx.MultiInput("PreCaches"); + int cache_offset = 0; + if (pre_caches.size() > 0) { + cache_offset = pre_caches[0]->dims()[3]; + } auto out_seq_len = seq_len; if (time_step) { @@ -101,6 +106,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { "In decode stage, the seq_len of input must be 1, but now is %d", seq_len)); out_seq_len += time_step_value; + } else { + out_seq_len += cache_offset; } Tensor transpose_out_2, qk_out; @@ -110,6 +117,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + Tensor src_mask_out; + if (cache_offset > 0) { + src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *src_mask_out_data = + dev_ctx.Alloc(&src_mask_out, src_mask_out.numel() * sizeof(T)); + } + + // [2, bs, num_head, cache_seq_len + seq_len, head_dim] + Tensor pre_cache_kv_out; + if (cache_offset > 0) { + pre_cache_kv_out.Resize( + {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); + auto *pre_cache_kv_out_data = dev_ctx.Alloc( + &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); + } + Tensor softmax_out; Tensor attn_dropout_mask_out, attn_dropout_out; Tensor qktv_out, fmha_out; @@ -277,26 +300,42 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { time_step->data()[0], 1. / sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage - // TODO(wangxi): can remove dropout in inference + const Tensor *pre_cache_kv_tensor = + pre_caches.size() > 0 ? pre_caches[i] : nullptr; + Tensor *pre_cache_kv_out_tmp = + cache_offset > 0 ? &pre_cache_kv_out : nullptr; + Tensor *src_mask_tmp = cache_offset > 0 ? 
&src_mask_out : nullptr; fmha_compute.ComputeForward(qkv_out, - nullptr, + pre_cache_kv_tensor, src_mask, &transpose_out_2, - nullptr, + pre_cache_kv_out_tmp, &qk_out, - nullptr, + src_mask_tmp, &softmax_out, &attn_dropout_mask_out, &attn_dropout_out, &qktv_out, &fmha_out); - // [3, bsz, num_head, seq_len, head_dim] - T *qkv_data = transpose_out_2_data; - int64_t q_size = bsz * seq_len * num_head * dim_head; - int64_t k_size = q_size; - const T *q_ptr = qkv_data; - const T *k_ptr = q_ptr + q_size; - const T *v_ptr = k_ptr + k_size; + + const T *k_ptr = nullptr; + const T *v_ptr = nullptr; + + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out.data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; + } else { + // [3, bsz, num_head, seq_len, head_dim] + T *qkv_data = transpose_out_2_data; + int64_t q_size = bsz * seq_len * num_head * dim_head; + int64_t k_size = q_size; + const T *q_ptr = qkv_data; + k_ptr = q_ptr + q_size; + v_ptr = k_ptr + k_size; + } // [2, bsz, num_head, max_seq_len, head_dim] int max_seq_len = cache_kv_out->dims()[3]; @@ -306,6 +345,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { T *cache_k_ptr = cache_kv_data; T *cache_v_ptr = cache_kv_data + cache_k_size; + const int seq_len_tmp = seq_len + cache_offset; write_cache_kv(dev_ctx, cache_k_ptr, cache_v_ptr, @@ -313,7 +353,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { v_ptr, bsz, num_head, - seq_len, + seq_len_tmp, max_seq_len, dim_head); } else { // not generation diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index af080bd0b3431c36d1ded54508008f4c78cea71f..21f97ecb48ca4dd3d24e33a43d81a83ef2609c36 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -61,6 +61,7 @@ std::map> op_ins_map = { "QKVW", "QKVBias", "CacheKV", + "PreCaches", "TimeStep", "SrcMask", "OutLinearW", diff --git a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py index 47851c895fa63d661a9457aa3d6e7ad1182b2963..b91081aa89a6f3edc9d32faa83006c1387131b4e 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py @@ -34,6 +34,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.framework import _non_static_mode, default_main_program from paddle import _C_ops, _legacy_C_ops from paddle.incubate.nn.functional import fused_multi_transformer +from paddle.incubate.nn import FusedMultiTransformer default_main_program().random_seed = 42 @@ -114,6 +115,7 @@ class TestFusedMultiTransformerOp(OpTest): # True, False, generation decoder stage self.has_cache_kv = False self.gen_cache_kv = False + self.has_pre_cache = False self.training = False @@ -121,6 +123,7 @@ class TestFusedMultiTransformerOp(OpTest): self.batch_size = 8 self.query_length = 128 self.cache_length = 128 + self.pre_cache_num = 64 self.head_dim = 64 self.num_heads = 16 self.embed_dim = self.head_dim * self.num_heads @@ -150,6 +153,12 @@ class TestFusedMultiTransformerOp(OpTest): else: self.cache_kv = None + if self.has_pre_cache: + out_seq_len += self.pre_cache_num + self.pre_cache_kv = np.random.rand( + 2, self.batch_size, self.num_heads, self.pre_cache_num, + 
self.head_dim).astype(self.x_type) + if self.has_attn_mask: # [B, n_head, seq_len, out_seq_len] self.attn_mask = np.ones( @@ -188,6 +197,10 @@ class TestFusedMultiTransformerOp(OpTest): if self.has_cache_kv: cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False) + if self.has_pre_cache: + pre_cache_kv = paddle.to_tensor(self.pre_cache_kv, + stop_gradient=False) + if self.has_attn_mask: attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) else: @@ -227,6 +240,13 @@ class TestFusedMultiTransformerOp(OpTest): k_out = paddle.concat([cache_k, k_out], axis=-2) v_out = paddle.concat([cache_v, v_out], axis=-2) + if self.has_pre_cache: + pre_cache_k, pre_cache_v = paddle.split(pre_cache_kv, 2) + pre_cache_k = paddle.squeeze(pre_cache_k, axis=0) + pre_cache_v = paddle.squeeze(pre_cache_v, axis=0) + k_out = paddle.concat([pre_cache_k, k_out], axis=-2) + v_out = paddle.concat([pre_cache_v, v_out], axis=-2) + # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim] # --> [B, n_head, seq_len, out_seq_len] qk_out = layers.matmul(x=q_out, @@ -348,6 +368,7 @@ class TestFusedMultiTransformerOp(OpTest): x = paddle.to_tensor(self.query, stop_gradient=False) cache_kvs, cache_kv = None, None time_step = None + pre_caches, pre_cache = None, None if self.has_cache_kv: cache_kvs = [] @@ -387,6 +408,18 @@ class TestFusedMultiTransformerOp(OpTest): time_step = paddle.to_tensor([self.cache_length], dtype='int32', place=paddle.CPUPlace()) + + if self.has_pre_cache: + cache_kvs = [] + max_seq_length = (self.cache_length + + 128) // 128 * 128 + self.pre_cache_num + cache_kv = np.zeros([ + 2, self.batch_size, self.num_heads, max_seq_length, + self.head_dim + ], + dtype=self.x_type) + pre_caches = [] + if self.has_attn_mask: attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) else: @@ -420,6 +453,11 @@ class TestFusedMultiTransformerOp(OpTest): if self.has_cache_kv: cache_kvs.append(paddle.to_tensor(cache_kv, stop_gradient=False)) + if self.has_pre_cache: + cache_kvs.append(paddle.to_tensor(cache_kv, + stop_gradient=False)) + pre_caches.append( + paddle.to_tensor(self.pre_cache_kv, stop_gradient=False)) final_out = fused_multi_transformer(x, ln_scales, @@ -437,6 +475,7 @@ class TestFusedMultiTransformerOp(OpTest): pre_layer_norm=self.pre_layer_norm, epsilon=epsilon, cache_kvs=cache_kvs, + pre_caches=pre_caches, time_step=time_step, attn_mask=attn_mask, dropout_rate=self.dropout_prob, @@ -445,8 +484,159 @@ class TestFusedMultiTransformerOp(OpTest): if self.has_cache_kv: return final_out[0], final_out[1] + if self.has_pre_cache: + return final_out[0] + return final_out + def GetFusedMultiTransformerOutStatic(self): + paddle.enable_static() + x = paddle.fluid.data('x', self.query.shape, self.query.dtype) + cache_kvs, cache_kv = None, None + time_step = None + time_step_feed = None + pre_caches, pre_cache = None, None + if self.has_cache_kv: + cache_kvs = [] + + max_seq_length = (self.cache_length + 128) // 128 * 128 + cache_kv = np.zeros([ + 2, self.batch_size, self.num_heads, max_seq_length, + self.head_dim + ], + dtype=self.x_type) + + elems = 4 + if self.x_type is np.float16: + elems = 8 + + assert self.head_dim % elems == 0 + v_elems = self.head_dim // elems + cache_k_tmp = self.cache_kv[0].reshape([ + self.batch_size, self.num_heads, self.cache_length, v_elems, + elems + ]) + # [B, num_head, head_dim / 4, 128, 4] + cache_k_tmp = cache_k_tmp.transpose([0, 1, 3, 2, 4]) + + cache_kv[0, :].reshape([ + self.batch_size, self.num_heads, v_elems, max_seq_length, elems + ])[:, :, 
:, :self.cache_length, :] = cache_k_tmp + + cache_kv[1, :, :, :self.cache_length, :] = self.cache_kv[1] + if self.gen_cache_kv: + assert self.query_length == self.cache_length + cache_kv[:] = 0 + else: + time_step = layers.fill_constant(shape=[1], + dtype="int32", + value=0, + force_cpu=True) + time_step_feed = self.cache_length + + if self.has_pre_cache: + cache_kvs = [] + max_seq_length = (self.cache_length + + 128) // 128 * 128 + self.pre_cache_num + cache_kv = np.zeros([ + 2, self.batch_size, self.num_heads, max_seq_length, + self.head_dim + ], + dtype=self.x_type) + pre_caches = [] + + attn_mask = None + epsilon = 1e-05 + ln2_epsilon = 1e-05 + + qkv_weights_attr, qkv_biases_attr = [], [] + out_weights_attr, out_biases_attr = [], [] + ln_scales_attr, ln_biases_attr = [], [] + ffn1_weights_attr, ffn1_biases_attr = [], [] + ffn2_weights_attr, ffn2_biases_attr = [], [] + ffn_ln_scales_attr, ffn_ln_biases_attr = [], [] + + if self.has_cache_kv: + cache_kvs_feed = [] + if self.has_pre_cache: + cache_kvs_feed = [] + pre_caches_feed = [] + + for i in range(self.layers): + qkv_weights_attr.append(self.weight_attr) + qkv_biases_attr.append(self.bias_attr) + out_weights_attr.append(self.weight_attr) + out_biases_attr.append(self.bias_attr) + ln_scales_attr.append(self.ln_w_attr) + ln_biases_attr.append(self.ln_b_attr) + ffn1_weights_attr.append(self.weight_attr) + ffn1_biases_attr.append(self.bias_attr) + ffn2_weights_attr.append(self.weight_attr) + ffn2_biases_attr.append(self.bias_attr) + ffn_ln_scales_attr.append(self.ln_w_attr) + ffn_ln_biases_attr.append(self.ln_b_attr) + + transformer = FusedMultiTransformer( + self.embed_dim, + self.num_heads, + 4 * self.embed_dim, + self.dropout_prob, + normalize_before=self.pre_layer_norm, + ln_scale_attrs=ln_scales_attr, + ln_bias_attrs=ln_biases_attr, + qkv_weight_attrs=qkv_weights_attr, + qkv_bias_attrs=qkv_biases_attr, + linear_weight_attrs=out_weights_attr, + linear_bias_attrs=out_biases_attr, + ffn_ln_scale_attrs=ffn_ln_scales_attr, + ffn_ln_bias_attrs=ffn_ln_biases_attr, + ffn1_weight_attrs=ffn1_weights_attr, + ffn1_bias_attrs=ffn1_biases_attr, + ffn2_weight_attrs=ffn2_weights_attr, + ffn2_bias_attrs=ffn2_biases_attr) + + transformer.eval() + + for i in range(self.layers): + if self.has_cache_kv: + cache_kvs.append( + layers.fill_constant(shape=cache_kv.shape, + dtype=cache_kv.dtype, + value=0)) + cache_kvs_feed.append(cache_kv) + + if self.has_pre_cache: + cache_kvs.append( + layers.fill_constant(shape=cache_kv.shape, + dtype=cache_kv.dtype, + value=0)) + cache_kvs_feed.append(cache_kv) + pre_caches.append( + layers.fill_constant(shape=self.pre_cache_kv.shape, + dtype=self.pre_cache_kv.dtype, + value=0)) + pre_caches_feed.append(self.pre_cache_kv) + + final_out = transformer(x, + attn_mask=attn_mask, + caches=cache_kvs, + pre_caches=pre_caches, + time_step=time_step)[0] + exe = paddle.static.Executor(place=paddle.CUDAPlace(0)) + exe.run(paddle.static.default_startup_program()) + feed_data = { + 'x': self.query, + 'cache_kvs': cache_kvs_feed, + 'pre_caches': pre_caches_feed, + 'time_step': time_step_feed, + 'attn_mask': attn_mask + } + out = exe.run(paddle.fluid.default_main_program(), + feed=feed_data, + fetch_list=[final_out]) + paddle.disable_static() + return out[0] + def test_fused_multi_transformer_op(self): final_out_ref = self.GetBaselineOut() final_out = self.GetFusedMultiTransformerOut() @@ -603,5 +793,39 @@ class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16( self.pre_layer_norm = False +class 
TestFusedMultiTransformerOpPreCache(TestFusedMultiTransformerOp): + + def config(self): + super().config() + self.has_pre_cache = True + self.x_type = np.float16 + + +class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp): + + def config(self): + super().config() + self.has_pre_cache = True + self.has_attn_mask = False + self.x_type = np.float32 + self.weight_attr = paddle.ParamAttr( + initializer=paddle.fluid.initializer.Constant(0.)) + self.bias_attr = paddle.ParamAttr( + initializer=paddle.fluid.initializer.Constant(0.0005)) + self.ln_w_attr = paddle.ParamAttr( + initializer=paddle.fluid.initializer.Constant(1.)) + self.ln_b_attr = paddle.ParamAttr( + initializer=paddle.fluid.initializer.Constant(0.)) + + def test_fused_multi_transformer_op(self): + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedMultiTransformerOutStatic() + + np.testing.assert_allclose(final_out_ref, + final_out, + rtol=self.rtol, + atol=self.atol) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index b1d759b6953c362e8171e624f734b6a2bbacab46..26ac7f349246e58a4e6cdf545e9e34b84d273ded 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -674,6 +674,7 @@ def fused_multi_transformer(x, pre_layer_norm=True, epsilon=1e-05, cache_kvs=None, + pre_caches=None, time_step=None, attn_mask=None, dropout_rate=0.0, @@ -739,6 +740,7 @@ def fused_multi_transformer(x, pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True. epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5. cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None. + pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None. attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. 
It is a tensor @@ -826,12 +828,13 @@ def fused_multi_transformer(x, if _non_static_mode(): cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( x, ln_scales, ln_biases, qkv_weights, qkv_biases, cache_kvs, - time_step, attn_mask, linear_weights, linear_biases, ffn_ln_scales, - ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases, - cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, - 'dropout_rate', dropout_rate, 'is_test', not training, - 'dropout_implementation', mode, 'act_method', activation, - 'trans_qkvw', trans_qkvw, 'ring_id', ring_id) + pre_caches, time_step, attn_mask, linear_weights, linear_biases, + ffn_ln_scales, ffn_ln_biases, ffn1_weights, ffn1_biases, + ffn2_weights, ffn2_biases, cache_kvs, 'pre_layer_norm', + pre_layer_norm, 'epsilon', epsilon, 'dropout_rate', dropout_rate, + 'is_test', not training, 'dropout_implementation', mode, + 'act_method', activation, 'trans_qkvw', trans_qkvw, 'ring_id', + ring_id) if cache_kvs is not None: return final_out, cache_kv_out return final_out @@ -857,6 +860,8 @@ def fused_multi_transformer(x, inputs['CacheKV'] = cache_kvs if time_step is not None: inputs['TimeStep'] = time_step + if pre_caches is not None: + inputs['PreCaches'] = pre_caches inputs['SrcMask'] = attn_mask inputs['OutLinearW'] = linear_weights if linear_biases is not None: diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 3af26db37a029fa227fdca53e3b62ebf9259a1a2..fefef4f37c7fd0a4e69662dd1c4cbf8d0efc2f04 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1242,7 +1242,12 @@ class FusedMultiTransformer(Layer): self.activation = activation self.name = name - def forward(self, src, attn_mask=None, caches=None, time_step=None): + def forward(self, + src, + attn_mask=None, + caches=None, + pre_caches=None, + time_step=None): """ Applies multi transformer layers on the input. @@ -1260,6 +1265,8 @@ class FusedMultiTransformer(Layer): tensors for the inference generation model. It is only used for inference and should be None for training. The shape is `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None. + pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches + for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be @@ -1292,6 +1299,7 @@ class FusedMultiTransformer(Layer): pre_layer_norm=self.normalize_before, epsilon=self._epsilon, cache_kvs=caches, + pre_caches=pre_caches, time_step=time_step, attn_mask=attn_mask, dropout_rate=self.dropout_rate,
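
Note (not part of the patch): the prefix-cache semantics introduced by the new PreCaches input can be stated compactly outside the fused kernel. The Python sketch below is a minimal restatement of what the updated baseline test does with pre_cache_kv (split into K and V, then concatenate along the sequence axis) and of how the CUDA kernel derives cache_offset and out_seq_len. The helper name and the NumPy usage are illustrative assumptions; the shapes follow the comments in the diff.

    import numpy as np

    def attach_prefix_cache(q, k, v, pre_cache_kv):
        # q, k, v:      [bsz, num_head, seq_len, head_dim]
        # pre_cache_kv: [2, bsz, num_head, cache_offset, head_dim]  (one PreCaches entry)
        # cache_offset = pre_caches[0]->dims()[3] in the CUDA kernel
        cache_offset = pre_cache_kv.shape[3]
        pre_k, pre_v = pre_cache_kv[0], pre_cache_kv[1]

        # mirrors paddle.concat([pre_cache_k, k_out], axis=-2) in the test baseline
        k = np.concatenate([pre_k, k], axis=-2)
        v = np.concatenate([pre_v, v], axis=-2)

        # context (encoder) stage: out_seq_len = seq_len + cache_offset, so qk_out and
        # the optional src_mask_out are both [bsz, num_head, seq_len, out_seq_len]
        out_seq_len = q.shape[2] + cache_offset
        return k, v, out_seq_len

In the decode stage (time_step is not None) this branch is skipped and out_seq_len is extended by the time step instead; in the context stage write_cache_kv then stores seq_len + cache_offset rows per head into CacheKVOut, which is why the test sizes max_seq_length as (cache_length + 128) // 128 * 128 + pre_cache_num.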
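
The extended forward signature of FusedMultiTransformer (src, attn_mask, caches, pre_caches, time_step) can be exercised in dygraph the same way the new static-graph test does. The following sketch is a hedged usage example, not part of the patch: the per-layer ParamAttr lists and the constructor argument order copy the static test above, all sizes are made up, the cache shapes come from the docstrings added in this patch, and a CUDA device is assumed since the op only has a GPU kernel.

    import paddle
    from paddle.incubate.nn import FusedMultiTransformer

    bsz, num_head, head_dim, layers = 2, 8, 64, 2
    seq_len, pre_cache_len, max_seq_len = 16, 8, 128
    embed_dim = num_head * head_dim

    def attrs():
        # one ParamAttr per layer, as in the static-graph unit test
        return [paddle.ParamAttr() for _ in range(layers)]

    transformer = FusedMultiTransformer(embed_dim, num_head, 4 * embed_dim, 0.0,
                                        ln_scale_attrs=attrs(), ln_bias_attrs=attrs(),
                                        qkv_weight_attrs=attrs(), qkv_bias_attrs=attrs(),
                                        linear_weight_attrs=attrs(), linear_bias_attrs=attrs(),
                                        ffn_ln_scale_attrs=attrs(), ffn_ln_bias_attrs=attrs(),
                                        ffn1_weight_attrs=attrs(), ffn1_bias_attrs=attrs(),
                                        ffn2_weight_attrs=attrs(), ffn2_bias_attrs=attrs())
    transformer.eval()

    src = paddle.rand([bsz, seq_len, embed_dim])
    # CacheKV:   [2, bsz, num_head, max_seq_len,   head_dim]
    # PreCaches: [2, bsz, num_head, pre_cache_len, head_dim]
    caches = [paddle.zeros([2, bsz, num_head, max_seq_len, head_dim]) for _ in range(layers)]
    pre_caches = [paddle.rand([2, bsz, num_head, pre_cache_len, head_dim]) for _ in range(layers)]

    # context stage: time_step stays None and the prefix length extends out_seq_len
    final_out = transformer(src, attn_mask=None, caches=caches,
                            pre_caches=pre_caches, time_step=None)[0]

Indexing [0] matches the static test: when caches is passed, the functional API returns (final_out, cache_kv_out).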