diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 3c9e16785eac814ffba34455d635d798042cdf43..54e4cbdc1624921e6946210a6a192d10fcbdb7dd 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/transpose_op.cu.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" namespace paddle { @@ -69,20 +70,21 @@ class FMHARef { ~FMHARef() {} void ComputeForward(const Tensor& qkv_input_tensor, + const Tensor* cache_kv_tensor, const Tensor* src_mask_tensor, - Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor, + Tensor* transpose_2_out_tensor, + Tensor* cache_kv_out_tensor, Tensor* qk_out_tensor, Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor, Tensor* dropout_mask_out_tensor, Tensor* dropout_out_tensor, Tensor* qktv_out_tensor, Tensor* fmha_out_tensor) { // input shape: [bs, seq_len, 3, num_head, head_dim] - // transpose with perm [2, 0, 1, 3, 4], + // transpose with perm [2, 0, 3, 1, 4], // output_shape: [3, bs, num_head, seq_len, head_dim] int ndims = 5; std::vector perm_1 = {2, 0, 3, 1, 4}; TransposeGPUKernelDriver(dev_ctx_, ndims, qkv_input_tensor, perm_1, transpose_2_out_tensor); - T* qkv_data = transpose_2_out_tensor->data(); T* qk_out_data = qk_out_tensor->data(); T* qktv_out_data = qktv_out_tensor->data(); @@ -90,11 +92,30 @@ class FMHARef { T* dropout_out_data = dropout_out_tensor->data(); T* fmha_out_data = fmha_out_tensor->data(); - int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - int k_size = q_size; + auto out_seq_len = seq_len_; + if (cache_kv_tensor) { + // kv [2, bs, num_head, seq_len, head_dim] + auto kv_tensor = transpose_2_out_tensor->Slice(1, 3); + phi::funcs::ConcatFunctor concat; + // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] + concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor); + out_seq_len = cache_kv_out_tensor->dims()[3]; + } + + int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; T* q_ptr = qkv_data; - T* k_ptr = q_ptr + q_size; - T* v_ptr = k_ptr + k_size; + T* k_ptr = nullptr; + T* v_ptr = nullptr; + + if (cache_kv_tensor) { + int64_t k_size = cache_kv_out_tensor->numel() / 2; + k_ptr = cache_kv_out_tensor->data(); + v_ptr = k_ptr + k_size; + } else { + int64_t k_size = q_size; + k_ptr = q_ptr + q_size; + v_ptr = k_ptr + k_size; + } // q*k^t, batched_gemm CBLAS_TRANSPOSE transA = CblasNoTrans; @@ -102,7 +123,7 @@ class FMHARef { auto blas = phi::funcs::GetBlas(dev_ctx_); int gemm_batch_size = batch_size_ * num_head_; int gemm_m = seq_len_; - int gemm_n = seq_len_; + int gemm_n = out_seq_len; int gemm_k = head_dim_; T alpha = static_cast(1.0 / sqrt(head_dim_)); T beta = static_cast(0.0); @@ -133,7 +154,7 @@ class FMHARef { transB = CblasNoTrans; gemm_m = seq_len_; gemm_n = head_dim_; - gemm_k = seq_len_; + gemm_k = out_seq_len; alpha = static_cast(1.0); stride_a = gemm_m * gemm_k; stride_b = gemm_k * gemm_n; diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index d141800d61c0ec0b73fe2cc3c8d00dbf1de44cf2..e473f8ff0662cfc3fd7bdc5010bfa1dc08fba85f 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ 
b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -61,6 +61,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut", "FusedAttentionOp"); + if (ctx->HasInput("CacheKV")) { + OP_INOUT_CHECK(ctx->HasOutput("CacheKVOut"), "Output", "CacheKVOut", + "FusedAttentionOp"); + } if (ctx->HasInput("SrcMask")) { OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut", "FusedAttentionOp"); @@ -105,12 +109,14 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "input qkv_weight = [%s]", x_dim, y_dim)); - PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); + if (ctx->Attrs().Get("ring_id") == -1) { + PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3], + platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "and must satisfy the limitations: " + "(num_head * dim_head == dim_embed)")); + } if (ctx->Attrs().Get("pre_layer_norm") == true) { ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); @@ -132,20 +138,64 @@ class FusedAttentionOp : public framework::OperatorWithKernel { // [3, batch_size, num_head, seq_len, head_size] ctx->SetOutputDim("TransposeOut2", {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); - // [batch, num_head, seq_len, seq_len] - ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + + // cache_seq_len + seq_len if cache else seq_len + auto out_seq_len = x_dim[1]; + if (ctx->HasInput("CacheKV")) { + // [2, batch_size, num_head, cache_seq_len, head_size] + auto c_dim = ctx->GetInputDim("CacheKV"); + + PADDLE_ENFORCE_EQ( + c_dim.size(), 5, + paddle::platform::errors::InvalidArgument( + "The CacheKV must be 5 dims, but got %d", c_dim.size())); + PADDLE_ENFORCE_EQ(c_dim[0], 2, + paddle::platform::errors::InvalidArgument( + "The first dim of CacheKV must be 2, but got %d", + c_dim[0])); // 2 + PADDLE_ENFORCE_EQ(c_dim[1], x_dim[0], + paddle::platform::errors::InvalidArgument( + "The second dim of CacheKV must be equal to " + "batch size %d, but got %d", + x_dim[0], c_dim[1])); // batch_size + PADDLE_ENFORCE_EQ(c_dim[2], y_dim[1], + paddle::platform::errors::InvalidArgument( + "The third dim of CacheKV must be equal to num " + "head %d, but got %d", + y_dim[1], c_dim[2])); // num_head + PADDLE_ENFORCE_GE( + c_dim[3], 0, + paddle::platform::errors::InvalidArgument( + "The fourth dim of CacheKV must be greater than 0, but got %d", + c_dim[3])); // cache_seq_len + PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2], + paddle::platform::errors::InvalidArgument( + "The fifth dim of CacheKV must be equal to head " + "size %d, but got %d", + y_dim[2], c_dim[4])); // head_size + + out_seq_len += c_dim[3]; + // [2, batch_size, num_head, cache_seq_len + seq_len, head_size] + ctx->SetOutputDim("CacheKVOut", + {c_dim[0], c_dim[1], c_dim[2], out_seq_len, c_dim[4]}); + } + + // [batch, num_head, seq_len, out_seq_len] + ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); if (ctx->HasInput("SrcMask")) { - ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + ctx->SetOutputDim("SrcMaskOut", + {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); } // the same as QKOut's shape.
ctx->SetOutputDim("AttnDropoutOut", - {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); if (ctx->Attrs().Get("attn_dropout_is_test") == false) { ctx->SetOutputDim("AttnDropoutMaskOut", - {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); } - ctx->SetOutputDim("SoftmaxOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + ctx->SetOutputDim("SoftmaxOut", + {x_dim[0], y_dim[1], x_dim[1], out_seq_len}); // [batch_size, num_heads, seq_len, head_dim] ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); // [batch_size, seq_len, number of heads*head size] @@ -182,6 +232,8 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddInput("QKVW", "The qkv weight tensor."); AddInput("QKVBias", "The qkv bias tensor.").AsDispensable(); + AddInput("CacheKV", "(optional) The cached KV for generation inference.") + .AsDispensable(); AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") .AsDispensable(); AddInput("OutLinearW", "The out_linear weight tensor."); @@ -217,6 +269,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("BiasDropoutResidualOut", "Result of residual + dropout(src + bias).") .AsIntermediate(); + AddOutput("CacheKVOut", "The udpated cache KV."); AddOutput("Y", "Result after attention."); AddAttr("pre_layer_norm", @@ -324,6 +377,10 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "0.0 and 0.001, But received [%s].", ln_epsilon)); }); + AddAttr( + "ring_id", + "ring id for tensor model parallel. distributed training and inference") + .SetDefault(-1); AddComment(R"DOC( Add fused attention op whose logic is as follows: diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 03f51fc5857985902c21ad12fefbdc9cdec6ef04..d26577f06fe683fb1528c61b4401b9e578c90c9f 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -27,11 +27,39 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/fused/fmha_ref.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#endif + namespace paddle { namespace operators { using Tensor = framework::Tensor; +template +static void AllReduce(framework::Tensor &tensor, // NOLINT + const int ring_id, + const platform::CUDADeviceContext &ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(tensor.dtype())); + int64_t numel = tensor.numel(); + const void *sendbuff = tensor.data(); + auto place = ctx.GetPlace(); + void *recvbuff = tensor.mutable_data(place); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + auto stream = ctx.stream(); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream)); +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should be compiled with NCCL or RCCL when using the " + "tensor model parallel op.")); +#endif +} + template class FusedAttentionOpKernel : public framework::OpKernel { public: @@ -56,6 +84,8 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *src_mask = ctx.Input("SrcMask"); auto *transpose_out_2 = ctx.Output("TransposeOut2"); + auto *cache_kv = ctx.Input("CacheKV"); + auto *cache_kv_out = ctx.Output("CacheKVOut"); auto *qk_out = ctx.Output("QKOut"); auto *qktv_out = ctx.Output("QKTVOut"); auto *softmax_out = ctx.Output("SoftmaxOut"); @@ -86,6 +116,7 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed"); int seed_val_1 = ctx.Attr("attn_dropout_seed"); + int ring_id = ctx.Attr("ring_id"); // final output. auto *out = ctx.Output("Y"); @@ -105,6 +136,10 @@ class FusedAttentionOpKernel : public framework::OpKernel { // get data ptr for FMHA. auto *transpose_out_2_data = transpose_out_2->mutable_data(ctx.GetPlace()); + auto *cache_kv_out_data = + (cache_kv_out == nullptr) + ? nullptr + : cache_kv_out->mutable_data(ctx.GetPlace()); auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); auto *qktv_out_data = qktv_out->mutable_data(ctx.GetPlace()); auto *src_mask_out_data = @@ -161,9 +196,14 @@ class FusedAttentionOpKernel : public framework::OpKernel { output_size = hidden_size; // (transA, transB, compute_bias) = (false, false, false) + // NOTE(Yuang Liu): In the general case input size == output size, so + // switching them has no effect. For mp, the output size is mp_head * dkey, + // which is actually the input size, while the input size is hidden size, + // which is actually the output size. So for the out linear, switch the + // input size and output size.
auto out_linear_compute = AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq, - output_size, input_size, false); + input_size, output_size, false); DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, @@ -186,15 +226,15 @@ class FusedAttentionOpKernel : public framework::OpKernel { qkv_bias_out); } if (qkv_bias == nullptr) { - fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2, - qk_out, src_mask_out, softmax_out, - attn_dropout_mask_out, attn_dropout_out, - qktv_out, fmha_out); + fmha_ref_compute.ComputeForward( + *qkv_out, cache_kv, src_mask, transpose_out_2, cache_kv_out, qk_out, + src_mask_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, + qktv_out, fmha_out); } else { - fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2, - qk_out, src_mask_out, softmax_out, - attn_dropout_mask_out, attn_dropout_out, - qktv_out, fmha_out); + fmha_ref_compute.ComputeForward( + *qkv_bias_out, cache_kv, src_mask, transpose_out_2, cache_kv_out, + qk_out, src_mask_out, softmax_out, attn_dropout_mask_out, + attn_dropout_out, qktv_out, fmha_out); } // fmha_out: [batch_size, seq_len, num_head, head_dim] @@ -202,6 +242,9 @@ class FusedAttentionOpKernel : public framework::OpKernel { // out_linear_out: [batch_size, seq_len, embed_dim] out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr, out_linear_out, nullptr); + // tensor model parallel + AllReduce(*out_linear_out, ring_id, ctx.cuda_device_context()); + if (pre_layer_norm) { // output = (residual + dropout(input + bias)) fused_dropout_layernorm_helper.ResidualDropoutBias( @@ -244,6 +287,7 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed"); int seed_val_1 = ctx.Attr("attn_dropout_seed"); + int ring_id = ctx.Attr("ring_id"); // get inputs. 
auto *d_y = ctx.Input(framework::GradVarName("Y")); @@ -399,9 +443,10 @@ class FusedAttentionGradKernel : public framework::OpKernel { transA = false; transB = false; bool compute_bias = false; + // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed) auto out_linear_compute = AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, - output_size, input_size, compute_bias); + input_size, output_size, compute_bias); DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, @@ -475,6 +520,8 @@ class FusedAttentionGradKernel : public framework::OpKernel { qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out, d_qkv_weight, d_qkv_bias); } + // tensor model parallel + AllReduce(*d_ln_out, ring_id, ctx.cuda_device_context()); layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data, ln_mean_data, ln_var_data, d_x_data, d_ln_scale_data, d_ln_bias_data); @@ -486,6 +533,8 @@ class FusedAttentionGradKernel : public framework::OpKernel { qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x, d_qkv_weight, d_qkv_bias); } + // tensor model parallel + AllReduce(*d_x, ring_id, ctx.cuda_device_context()); } // gradient accumulation std::vector ins; diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index d23b3dd64ab05cf10d8096a84e317645972211d1..9e86e3df8a6884ec1b75b8525ad858ff8f2e233c 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -30,8 +30,8 @@ std::map> op_ins_map = { {"layer_norm", {"X", "Scale", "Bias"}}, {"bincount", {"X", "Weights"}}, {"fused_attention", - {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW", - "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, + {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "SrcMask", + "OutLinearW", "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, @@ -104,11 +104,16 @@ std::map> op_outs_map = { {"batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, - {"fused_attention", - {"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2", - "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut", - "SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean", - "Ln2Variance", "BiasDropoutResidualOut", "Y"}}, + {"fused_attention", {"LnMean", "LnVariance", + "LnOut", "QKVOut", + "QKVBiasOut", "TransposeOut2", + "QKOut", "QKTVOut", + "SoftmaxOut", "AttnDropoutMaskOut", + "AttnDropoutOut", "SrcMaskOut", + "FMHAOut", "OutLinearOut", + "DropoutMaskOut", "Ln2Mean", + "Ln2Variance", "BiasDropoutResidualOut", + "CacheKVOut", "Y"}}, {"sync_batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 9b0c857576b8acc0d33cabc2525b56545cd3169e..e75b8d1f60bf7dbbfb500a464a3b591a0d1f7ed3 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -24,6 +24,7 @@ list(APPEND DIST_TEST_OPS test_pipeline) list(APPEND DIST_TEST_OPS test_ir_pass_pipeline) list(APPEND DIST_TEST_OPS test_static_model_parallel) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward) 
+list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention) list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height) @@ -1155,6 +1156,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) set_tests_properties(test_ir_pass_pipeline PROPERTIES TIMEOUT 120) set_tests_properties(test_static_model_parallel PROPERTIES TIMEOUT 240) set_tests_properties(test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120) + set_tests_properties(test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_split_embedding test_collective_split_embedding_none_divisible test_collective_split_row_linear diff --git a/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py b/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b57f26776234eb65a57cc65df2ccd5a6a38a2144 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py @@ -0,0 +1,297 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import numpy as np + +import paddle +import paddle.fluid as fluid +from test_dist_base import TestDistRunnerBase, runtime_main +import paddle.distributed.fleet as fleet +import paddle.incubate.nn.functional as incubate_f + +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.fluid.dygraph.layers import Layer +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid import core +from paddle.nn.initializer import Constant + +paddle.enable_static() + + +def _set_var_distributed(var): + if var is None: + return + + var.is_distributed = True + + # NOTE: use current_block and find_var_recursive to support while_loop + startup_block = paddle.static.default_startup_program().current_block() + main_block = paddle.static.default_main_program().current_block() + startup_block._find_var_recursive(var.name).is_distributed = True + main_block._find_var_recursive(var.name).is_distributed = True + + +class ParallelFusedMultiHeadAttention(Layer): + def __init__(self, + embed_dim, + num_heads, + dropout_rate=0.5, + attn_dropout_rate=0.5, + kdim=None, + vdim=None, + normalize_before=False, + need_weights=False, + qkv_weight_attr=None, + qkv_bias_attr=None, + linear_weight_attr=None, + linear_bias_attr=None, + pre_ln_scale_attr=None, + pre_ln_bias_attr=None, + ln_scale_attr=None, + ln_bias_attr=None, + epsilon=1e-5, + nranks=1, + ring_id=-1, + name=None): + super(ParallelFusedMultiHeadAttention, self).__init__() + + assert embed_dim > 0, ("Expected embed_dim to be greater than 0, " + "but received {}".format(embed_dim)) + assert num_heads > 0, ("Expected num_heads to be greater than 0, " + "but received {}".format(num_heads)) + + self.normalize_before = normalize_before + self._dtype = self._helper.get_default_dtype() + self._epsilon = epsilon + self._ring_id = ring_id + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.kdim = kdim + self.vdim = vdim + self.need_weights = need_weights + assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + assert need_weights == False, "Only need_weights=False is supported now."
+ + # tensor model parallel + assert num_heads % nranks == 0 + num_heads = num_heads // nranks + + self.qkv_weight = self.create_parameter( + shape=[3, num_heads, self.head_dim, embed_dim], + attr=qkv_weight_attr, + dtype=self._dtype, + is_bias=False) + self.qkv_bias = self.create_parameter( + shape=[3, num_heads, self.head_dim], + attr=qkv_bias_attr, + dtype=self._dtype, + is_bias=True) + self.linear_weight = self.create_parameter( + shape=[num_heads * self.head_dim, embed_dim], + attr=linear_weight_attr, + dtype=self._dtype, + is_bias=False) + self.linear_bias = self.create_parameter( + shape=[embed_dim], + attr=linear_bias_attr, + dtype=self._dtype, + is_bias=True) + + # tensor model parallel + if nranks > 1: + assert ring_id != -1 + # column parallel + _set_var_distributed(self.qkv_weight) + _set_var_distributed(self.qkv_bias) + # row parallel + _set_var_distributed(self.linear_weight) + + if normalize_before: + self.pre_ln_scale = self.create_parameter( + attr=pre_ln_scale_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0)) + self.pre_ln_bias = self.create_parameter( + attr=pre_ln_bias_attr, shape=[embed_dim], is_bias=True) + self.ln_scale = None + self.ln_bias = None + else: + self.pre_ln_scale = None + self.pre_ln_bias = None + self.ln_scale = self.create_parameter( + attr=ln_scale_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0)) + self.ln_bias = self.create_parameter( + attr=ln_bias_attr, shape=[embed_dim], is_bias=True) + + self.dropout_rate = dropout_rate + self.attn_dropout_rate = attn_dropout_rate + + self.name = name + + def forward(self, query, key=None, value=None, attn_mask=None, cache=None): + out = incubate_f.fused_multi_head_attention( + x=query, + qkv_weight=self.qkv_weight, + linear_weight=self.linear_weight, + pre_layer_norm=self.normalize_before, + pre_ln_scale=self.pre_ln_scale, + pre_ln_bias=self.pre_ln_bias, + ln_scale=self.ln_scale, + ln_bias=self.ln_bias, + pre_ln_epsilon=self._epsilon, + qkv_bias=self.qkv_bias, + linear_bias=self.linear_bias, + attn_mask=attn_mask, + dropout_rate=self.dropout_rate, + attn_dropout_rate=self.attn_dropout_rate, + ln_epsilon=self._epsilon, + training=self.training, + ring_id=self._ring_id, + name=self.name) + return out + + +def get_param_attr(weight, bias): + weight_attr = paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer(weight)) + bias_attr = paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer(bias)) + return weight_attr, bias_attr + + +DTYPE = "float32" +MODEL_PARALLEL_SIZE = 2 +n_head = 2 * MODEL_PARALLEL_SIZE +d_key = 4 +hidden = n_head * d_key + + +def create_model(data, rank): + np.random.seed(2021) + pre_ln_w = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE) + pre_ln_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE) + qkv_w = np.random.uniform( + -1, 1, size=(3, n_head, d_key, hidden)).astype(DTYPE) + qkv_b = np.random.uniform(-1, 1, size=(3, n_head, d_key)).astype(DTYPE) + linear_w = np.random.uniform( + -1, 1, size=(n_head * d_key, hidden)).astype(DTYPE) + linear_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE) + + data.stop_gradient = False + if rank is not None: + start = 0 if rank == 0 else n_head // MODEL_PARALLEL_SIZE + end = start + n_head // MODEL_PARALLEL_SIZE + col_qkv_w = qkv_w[:, start:end, :, :] + col_qkv_b = qkv_b[:, start:end, :] + row_linear_w = linear_w[(start * d_key):(end * d_key), :] + + pre_ln_w_attr, pre_ln_b_attr = get_param_attr(pre_ln_w, pre_ln_b) + qkv_w_attr, qkv_b_attr = 
get_param_attr(col_qkv_w, col_qkv_b) + linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b) + + attn = ParallelFusedMultiHeadAttention( + hidden, + n_head, + dropout_rate=0.0, + attn_dropout_rate=0.0, + normalize_before=False, + qkv_weight_attr=qkv_w_attr, + qkv_bias_attr=qkv_b_attr, + linear_weight_attr=linear_w_attr, + linear_bias_attr=linear_b_attr, + pre_ln_scale_attr=pre_ln_w_attr, + pre_ln_bias_attr=pre_ln_b_attr, + ln_scale_attr=pre_ln_w_attr, + ln_bias_attr=pre_ln_b_attr, + nranks=MODEL_PARALLEL_SIZE, + ring_id=0) + result = attn(data) + else: + pre_ln_w_attr, pre_ln_b_attr = get_param_attr(pre_ln_w, pre_ln_b) + qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b) + linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b) + + attn = ParallelFusedMultiHeadAttention( + hidden, + n_head, + dropout_rate=0.0, + attn_dropout_rate=0.0, + normalize_before=False, + qkv_weight_attr=qkv_w_attr, + qkv_bias_attr=qkv_b_attr, + linear_weight_attr=linear_w_attr, + linear_bias_attr=linear_b_attr, + pre_ln_scale_attr=pre_ln_w_attr, + pre_ln_bias_attr=pre_ln_b_attr, + ln_scale_attr=pre_ln_w_attr, + ln_bias_attr=pre_ln_b_attr) + result = attn(data) + + predict = paddle.sum(result) + return predict + + +class TestModelParallel(TestDistRunnerBase): + def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None): + # Input data + seq_len = 2 + data_in = fluid.data( + name='data_in', shape=[batch_size, seq_len, hidden], dtype=DTYPE) + + if dist_strategy: + data_loader = fluid.io.DataLoader.from_generator( + feed_list=[data_in], + capacity=64, + use_double_buffer=False, + iterable=False) + + if dist_strategy: + fleet.init(is_collective=True) + strategy = fleet.DistributedStrategy() + strategy.tensor_parallel = True + strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2} + + rank = fleet.worker_index() if dist_strategy else None + avg_cost = create_model(data_in, rank) + opt = fluid.optimizer.SGD(0.1) + + if dist_strategy: + dist_opt = fleet.distributed_optimizer( + optimizer=opt, strategy=strategy) + dist_opt.minimize(avg_cost) + else: + opt.minimize(avg_cost) + + def gen_data(): + np.random.seed(2021) + while True: + data = [np.random.random([seq_len, hidden]).astype(DTYPE)] + yield data + + train_reader = paddle.batch(gen_data, batch_size=batch_size) + + if dist_strategy: + return None, avg_cost, train_reader, None, None, None, data_loader + else: + return None, avg_cost, train_reader, None, None, None + + +if __name__ == "__main__": + runtime_main(TestModelParallel) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 443703aa937d8aead8307b892961e7054ede6ed4..a3ae2a20dba23ef39510e962b148d40364f85e72 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -70,10 +70,12 @@ class TestFusedAttentionOp(OpTest): self.attn_mask_type = np.float64 self.pre_layer_norm = False self.has_attn_mask = True + self.has_cache_kv = False self.training = True self.batch_size = 8 self.query_length = 128 + self.cache_length = 128 self.head_dim = 64 self.num_heads = 16 self.embed_dim = self.head_dim * self.num_heads @@ -88,10 +90,22 @@ class TestFusedAttentionOp(OpTest): def generate_input_data(self): self.query = np.random.rand(self.batch_size, self.query_length, self.embed_dim).astype(self.x_type) + out_seq_len = self.key_length + if self.has_cache_kv: + assert self.training is False, ValueError( 
+ 'cache_kv can only be used in inference') + self.cache_kv = np.random.rand(2, self.batch_size, self.num_heads, + self.cache_length, + self.head_dim).astype(self.x_type) + out_seq_len += self.cache_length + else: + self.cache_kv = None + + if self.has_attn_mask: + # [B, n_head, seq_len, out_seq_len] self.attn_mask = np.ones( (self.batch_size, self.num_heads, self.query_length, - self.key_length), + out_seq_len), dtype=self.attn_mask_type) if self.attn_mask_type == np.int64: self.attn_mask = np.tril(self.attn_mask) @@ -110,6 +124,11 @@ class TestFusedAttentionOp(OpTest): def GetBaselineOut(self): paddle.disable_static(place=paddle.CUDAPlace(0)) tensor_query = paddle.to_tensor(self.query, stop_gradient=False) + + cache_kv = None + if self.has_cache_kv: + cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False) + if self.has_attn_mask: attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) else: @@ -130,6 +149,18 @@ class TestFusedAttentionOp(OpTest): v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + if self.has_cache_kv: + # [1, B, n_head, cache_seq_len, head_dim] + cache_k, cache_v = paddle.split(cache_kv, 2) + cache_k = paddle.squeeze(cache_k, axis=0) + cache_v = paddle.squeeze(cache_v, axis=0) + # [B, n_head, cache_seq_len + seq_len, head_dim] + # out_seq_len = cache_seq_len + seq_len + k_out = paddle.concat([cache_k, k_out], axis=-2) + v_out = paddle.concat([cache_v, v_out], axis=-2) + + # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim] + # --> [B, n_head, seq_len, out_seq_len] qk_out = layers.matmul( x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) @@ -146,6 +177,8 @@ class TestFusedAttentionOp(OpTest): self.dropout_prob, training=self.training, mode="upscale_in_train") + # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim] + # --> [B, n_head, seq_len, head_dim] qktv_out = tensor.matmul(dropout_out, v_out) else: qktv_out = tensor.matmul(softmax_out, v_out) @@ -160,6 +193,10 @@ class TestFusedAttentionOp(OpTest): final_out = self.norm1(residual_out) else: final_out = residual_out + + if self.has_cache_kv: + return final_out + paddle.autograd.backward( [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) return final_out, tensor_query.grad @@ -206,6 +243,9 @@ class TestFusedAttentionOp(OpTest): (3, self.num_heads, self.head_dim, self.embed_dim)) x = paddle.to_tensor(self.query, stop_gradient=False) + cache_kv = None + if self.has_cache_kv: + cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False) if self.has_attn_mask: attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) else: @@ -219,8 +259,12 @@ class TestFusedAttentionOp(OpTest): final_out = incubate_f.fused_multi_head_attention( x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, - out_linear_bias, attn_mask, self.dropout_prob, + out_linear_bias, cache_kv, attn_mask, self.dropout_prob, self.attn_dropout_prob, ln2_epsilon) + + if self.has_cache_kv: + return final_out[0], final_out[1] + paddle.autograd.backward( [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) return final_out, x.grad @@ -236,114 +280,27 @@ class TestFusedAttentionOp(OpTest): class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp): def config(self): - self.x_type = np.float32 - self.attn_mask_type = np.float64 - self.pre_layer_norm = False - self.has_attn_mask = True - self.training = True - - self.batch_size =
8 - self.query_length = 128 - self.head_dim = 64 - self.num_heads = 16 - self.embed_dim = self.head_dim * self.num_heads - - self.dropout_prob = 0.0 - self.attn_dropout_prob = 0.0 - self.weight_attr = None + super().config() self.bias_attr = False - self.kdim, self.vdim = self.embed_dim, self.embed_dim - self.key_length, self.value_length = self.query_length, self.query_length - - def test_fused_attention_op(self): - final_out_ref, x_grad_ref = self.GetBaselineOut() - final_out, x_grad = self.GetFusedAttentionOut() - np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) - np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) class TestFusedAttentionOpPreLn(TestFusedAttentionOp): def config(self): - self.x_type = np.float32 - self.attn_mask_type = np.float64 + super().config() self.pre_layer_norm = True - self.has_attn_mask = True - self.training = True - - self.batch_size = 8 - self.query_length = 128 - self.head_dim = 64 - self.num_heads = 16 - self.embed_dim = self.head_dim * self.num_heads - - self.dropout_prob = 0.0 - self.attn_dropout_prob = 0.0 - self.weight_attr = None - self.bias_attr = None - self.kdim, self.vdim = self.embed_dim, self.embed_dim - self.key_length, self.value_length = self.query_length, self.query_length - - def test_fused_attention_op(self): - final_out_ref, x_grad_ref = self.GetBaselineOut() - final_out, x_grad = self.GetFusedAttentionOut() - np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) - np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp): def config(self): - self.x_type = np.float32 - self.attn_mask_type = np.float64 + super().config() self.pre_layer_norm = True self.has_attn_mask = False - self.training = True - - self.batch_size = 8 - self.query_length = 128 - self.head_dim = 64 - self.num_heads = 16 - self.embed_dim = self.head_dim * self.num_heads - - self.dropout_prob = 0.0 - self.attn_dropout_prob = 0.0 - self.weight_attr = None - self.bias_attr = None - self.kdim, self.vdim = self.embed_dim, self.embed_dim - self.key_length, self.value_length = self.query_length, self.query_length - - def test_fused_attention_op(self): - final_out_ref, x_grad_ref = self.GetBaselineOut() - final_out, x_grad = self.GetFusedAttentionOut() - np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) - np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) class TestFusedAttentionOpFp16(TestFusedAttentionOp): def config(self): + super().config() self.x_type = np.float16 - self.attn_mask_type = np.float64 - self.pre_layer_norm = False - self.has_attn_mask = True - self.training = True - - self.batch_size = 8 - self.query_length = 128 - self.head_dim = 64 - self.num_heads = 16 - self.embed_dim = self.head_dim * self.num_heads - - self.dropout_prob = 0.0 - self.attn_dropout_prob = 0.0 - self.weight_attr = None - self.bias_attr = None - self.kdim, self.vdim = self.embed_dim, self.embed_dim - self.key_length, self.value_length = self.query_length, self.query_length def test_fused_attention_op(self): final_out_ref, x_grad_ref = self.GetBaselineOut() @@ -354,5 +311,21 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp): x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1) +class TestFusedAttentionOpCacheKV(TestFusedAttentionOp): + def config(self): + super().config() + self.has_cache_kv = True + self.training = False + self.query_length = 1 + 
self.key_length, self.value_length = 1, 1 + + def test_fused_attention_op(self): + with paddle.no_grad(): + final_out_ref = self.GetBaselineOut() + final_out, cache_kv_out = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_attention.py b/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ce8e8170fa187b223fe48aac6124fa0b736e17 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_attention.py @@ -0,0 +1,45 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase + +import os +import paddle + +paddle.enable_static() +flag_name = os.path.splitext(__file__)[0] + + +class TestStaticModelParallel(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl_comm_num = 1 + self._pipeline_mode = True + + def test_dist_static_model_parallel_fused_attention(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "static_model_parallel_fused_attention.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index d600cda8454cc696579df7fa7f6e6f4d6ae12600..457422ae3a4d602a6782d0949030fb3b3fb797c2 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -223,12 +223,14 @@ def fused_multi_head_attention(x, pre_ln_epsilon=1e-05, qkv_bias=None, linear_bias=None, + cache_kv=None, attn_mask=None, dropout_rate=0.5, attn_dropout_rate=0.5, ln_epsilon=1e-05, training=True, mode='upscale_in_train', + ring_id=-1, name=None): r""" Attention maps queries and a set of key-value pairs to outputs, and @@ -242,8 +244,8 @@ def fused_multi_head_attention(x, out = layer_norm(x) out = linear(out) + qkv) + bias else: - out = linear(x) + bias - out = transpose(out, perm=[2, 0, 3, 1, 4]) + out = linear(x) + bias + out = transpose(out, perm=[2, 0, 3, 1, 4]) # extract q, k and v from out. q = out[0:1,::] k = out[1:2,::] @@ -257,8 +259,8 @@ def fused_multi_head_attention(x, out = out_linear(out) if pre_layer_norm: out = x + dropout(linear_bias + out) - else: - out = layer_norm(x + dropout(linear_bias + out)) + else: + out = layer_norm(x + dropout(linear_bias + out)) Parameters: x (Tensor): The input tensor of fused_multi_head_attention. The shape is @@ -276,6 +278,7 @@ def fused_multi_head_attention(x, qkv_bias (Tensor, optional): The bias of qkv computation.
The shape is `[3, num_head, dim_head]`. Default None. linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None. + cache_kv (Tensor, optional): The cache KV structure used by the generation model. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None. attn_mask (Tensor, optional): A tensor used in multi-head attention to prevent attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the @@ -303,6 +306,7 @@ def fused_multi_head_attention(x, - train: out = input * mask - inference: out = input * (1.0 - p) + ring_id (int, optional): The ring id for distributed forward in tensor model parallel; only NCCL is supported. Default is -1, which means tensor model parallel is not used. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -333,7 +337,7 @@ def fused_multi_head_attention(x, output = F.fused_multi_head_attention( x, qkv_weight, linear_weight, False, None, None, None, None, 1e-5, qkv_bias, - linear_bias, attn_mask) + linear_bias, None, attn_mask) # [2, 4, 128] print(output.shape) """ @@ -359,17 +363,20 @@ def fused_multi_head_attention(x, assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[ 3], "embed_dim must be divisible by num_heads." - _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( - x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, - linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', - pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate', - dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon', - ln_epsilon, 'attn_dropout_is_test', not training, 'dropout_is_test', - not training, 'attn_dropout_fix_seed', seed is not None, - 'dropout_fix_seed', seed is not None, 'attn_dropout_seed', seed + _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, cache_kv_out, final_out = _C_ops.fused_attention( + x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, cache_kv, + attn_mask, linear_weight, linear_bias, ln_scale, ln_bias, + 'pre_layer_norm', pre_layer_norm, 'epsilon', pre_ln_epsilon, + 'dropout_rate', dropout_rate, 'attn_dropout_rate', + attn_dropout_rate, 'ln_epsilon', ln_epsilon, 'attn_dropout_is_test', + not training, 'dropout_is_test', not training, + 'attn_dropout_fix_seed', seed is not None, 'dropout_fix_seed', + seed is not None, 'attn_dropout_seed', seed if seed is not None else 0, 'dropout_seed', seed if seed is not None else 0, 'attn_dropout_implementation', mode, - 'dropout_implementation', mode) + 'dropout_implementation', mode, 'ring_id', ring_id) + if cache_kv is not None: + return final_out, cache_kv_out return final_out else: helper = LayerHelper('fused_multi_head_attention', **locals()) @@ -398,6 +405,7 @@ def fused_multi_head_attention(x, inputs['Ln2Scale'] = [ln_scale] if ln_bias: inputs['Ln2Bias'] = [ln_bias] + if cache_kv: inputs['CacheKV'] = [cache_kv] if (seed is None or seed == 0) and helper.main_program.random_seed != 0: seed = helper.main_program.random_seed @@ -417,6 +425,7 @@ def fused_multi_head_attention(x, 'dropout_seed': seed if seed is not None else 0, 'attn_dropout_implementation': mode, 'dropout_implementation': mode, + 'ring_id': ring_id } # set outputs @@ -449,6 +458,7 @@ def fused_multi_head_attention(x, bias_dropout_residual_out = helper.create_variable_for_type_inference( dtype=dtype) final_out =
helper.create_variable_for_type_inference(dtype=dtype) + cache_kv_out = helper.create_variable_for_type_inference(dtype=dtype) helper.append_op( type='fused_attention', @@ -472,7 +482,9 @@ def fused_multi_head_attention(x, "Ln2Mean": ln_mean_out, "Ln2Variance": ln_variance_out, "BiasDropoutResidualOut": bias_dropout_residual_out, - 'Y': final_out + 'Y': final_out, + 'CacheKVOut': cache_kv_out }, attrs=attrs) - return final_out + + return (final_out, cache_kv_out) if cache_kv else final_out
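Editor's note: the snippet below is a minimal, hypothetical usage sketch (not part of the patch) of the updated `paddle.incubate.nn.functional.fused_multi_head_attention` API after this change, exercising the new `cache_kv` input while leaving the new `ring_id` attribute at its default of -1 (no tensor model parallelism). It assumes a CUDA build of PaddlePaddle that includes the fused_attention op; all shapes and tensor values are illustrative only.

```python
# Hypothetical sketch: decode one new token against an existing KV cache.
# Assumes a CUDA build of PaddlePaddle with the fused_attention op available.
import paddle
import paddle.incubate.nn.functional as incubate_f

paddle.disable_static(paddle.CUDAPlace(0))

batch, num_heads, head_dim = 2, 4, 8
embed_dim = num_heads * head_dim      # num_head * dim_head == dim_embed
seq_len, cache_seq_len = 1, 16        # one new token, 16 cached positions

x = paddle.rand([batch, seq_len, embed_dim], dtype='float32')
qkv_weight = paddle.rand([3, num_heads, head_dim, embed_dim], dtype='float32')
qkv_bias = paddle.rand([3, num_heads, head_dim], dtype='float32')
linear_weight = paddle.rand([embed_dim, embed_dim], dtype='float32')
linear_bias = paddle.rand([embed_dim], dtype='float32')

# Cached K/V from earlier steps: [2, bsz, num_head, cache_seq_len, head_dim].
cache_kv = paddle.rand([2, batch, num_heads, cache_seq_len, head_dim],
                       dtype='float32')
# Mask broadcastable to [batch, num_heads, seq_len, cache_seq_len + seq_len].
attn_mask = paddle.zeros([batch, num_heads, seq_len, cache_seq_len + seq_len],
                         dtype='float32')

with paddle.no_grad():
    # When cache_kv is passed, the op also returns the concatenated cache of
    # shape [2, bsz, num_head, cache_seq_len + seq_len, head_dim].
    out, cache_kv_out = incubate_f.fused_multi_head_attention(
        x, qkv_weight, linear_weight, False, None, None, None, None, 1e-5,
        qkv_bias, linear_bias, cache_kv, attn_mask,
        dropout_rate=0.0, attn_dropout_rate=0.0, training=False)

print(out.shape)           # expected [2, 1, 32]
print(cache_kv_out.shape)  # expected [2, 2, 4, 17, 8]
```

In dygraph mode this maps onto the `_C_ops.fused_attention` call shown in the patch, which now returns `CacheKVOut` alongside `Y`; in static mode the same `(final_out, cache_kv_out)` pair is returned whenever `cache_kv` is fed.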