Unverified commit 1882c496, authored by Yuang Liu, committed by GitHub

[hybrid] Support tensor parallel and cache structure for fused attention op. (#40101)

Parent e24ca55e
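This commit adds two features to fused_attention: a ring_id attribute with in-kernel all-reduces for tensor model parallelism, and an optional CacheKV input / CacheKVOut output for incremental decoding. The sketch below is not part of the patch; it is a minimal dygraph illustration of the cache path on a CUDA build, with illustrative shapes and tensor parallel left at its default (ring_id=-1):

import paddle
import paddle.incubate.nn.functional as incubate_f

# Illustrative decode-step shapes: one new token attending to a 16-token cache.
bsz, seq_len, cache_len, num_head, head_dim = 2, 1, 16, 4, 8
embed_dim = num_head * head_dim

x = paddle.rand([bsz, seq_len, embed_dim], dtype='float32')
qkv_weight = paddle.rand([3, num_head, head_dim, embed_dim], dtype='float32')
linear_weight = paddle.rand([embed_dim, embed_dim], dtype='float32')
cache_kv = paddle.rand([2, bsz, num_head, cache_len, head_dim], dtype='float32')

# With cache_kv given, the op concatenates the new K/V onto the cache and
# returns the updated cache alongside the attention output.
out, cache_kv_out = incubate_f.fused_multi_head_attention(
    x, qkv_weight, linear_weight,
    pre_layer_norm=False,
    cache_kv=cache_kv,
    dropout_rate=0.0,
    attn_dropout_rate=0.0,
    training=False)

print(out.shape)           # [2, 1, 32]
print(cache_kv_out.shape)  # [2, 2, 4, 17, 8]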
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle {
......@@ -69,20 +70,21 @@ class FMHARef {
~FMHARef() {}
void ComputeForward(const Tensor& qkv_input_tensor,
const Tensor* cache_kv_tensor,
const Tensor* src_mask_tensor,
Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor,
Tensor* transpose_2_out_tensor,
Tensor* cache_kv_out_tensor, Tensor* qk_out_tensor,
Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor,
Tensor* dropout_mask_out_tensor,
Tensor* dropout_out_tensor, Tensor* qktv_out_tensor,
Tensor* fmha_out_tensor) {
// input shape: [bs, seq_len, 3, num_head, head_dim]
// transpose with perm [2, 0, 1, 3, 4],
// transpose with perm [2, 0, 3, 1, 4],
// output_shape: [3, bs, num_head, seq_len, head_dim]
int ndims = 5;
std::vector<int> perm_1 = {2, 0, 3, 1, 4};
TransposeGPUKernelDriver<T>(dev_ctx_, ndims, qkv_input_tensor, perm_1,
transpose_2_out_tensor);
T* qkv_data = transpose_2_out_tensor->data<T>();
T* qk_out_data = qk_out_tensor->data<T>();
T* qktv_out_data = qktv_out_tensor->data<T>();
......@@ -90,11 +92,30 @@ class FMHARef {
T* dropout_out_data = dropout_out_tensor->data<T>();
T* fmha_out_data = fmha_out_tensor->data<T>();
int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
int k_size = q_size;
auto out_seq_len = seq_len_;
if (cache_kv_tensor) {
// kv [2, bs, num_head, seq_len, head_dim]
auto kv_tensor = transpose_2_out_tensor->Slice(1, 3);
phi::funcs::ConcatFunctor<phi::GPUContext, T> concat;
// out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor);
out_seq_len = cache_kv_out_tensor->dims()[3];
}
int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
T* q_ptr = qkv_data;
T* k_ptr = q_ptr + q_size;
T* v_ptr = k_ptr + k_size;
T* k_ptr = nullptr;
T* v_ptr = nullptr;
if (cache_kv_tensor) {
int64_t k_size = cache_kv_out_tensor->numel() / 2;
k_ptr = cache_kv_out_tensor->data<T>();
v_ptr = k_ptr + k_size;
} else {
int64_t k_size = q_size;
k_ptr = q_ptr + q_size;
v_ptr = k_ptr + k_size;
}
// q*k^t, batched_gemm
CBLAS_TRANSPOSE transA = CblasNoTrans;
......@@ -102,7 +123,7 @@ class FMHARef {
auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
int gemm_batch_size = batch_size_ * num_head_;
int gemm_m = seq_len_;
int gemm_n = seq_len_;
int gemm_n = out_seq_len;
int gemm_k = head_dim_;
T alpha = static_cast<T>(1.0 / sqrt(head_dim_));
T beta = static_cast<T>(0.0);
......@@ -133,7 +154,7 @@ class FMHARef {
transB = CblasNoTrans;
gemm_m = seq_len_;
gemm_n = head_dim_;
gemm_k = seq_len_;
gemm_k = out_seq_len;
alpha = static_cast<T>(1.0);
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
......
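To make the cache path above concrete: the kernel slices the K/V half of transpose_2_out, concatenates the cached K/V along the sequence axis, and then runs both batched GEMMs against out_seq_len instead of seq_len. A NumPy sketch of the same shape bookkeeping, with illustrative sizes (this mirrors, but is not, the kernel code):

import numpy as np

bs, num_head, seq_len, cache_seq_len, head_dim = 2, 4, 1, 16, 8
# transpose_2_out: [3, bs, num_head, seq_len, head_dim] after the {2, 0, 3, 1, 4} transpose
qkv = np.random.rand(3, bs, num_head, seq_len, head_dim).astype('float32')
q, new_kv = qkv[0], qkv[1:3]
cache_kv = np.random.rand(2, bs, num_head, cache_seq_len, head_dim).astype('float32')

# ConcatFunctor on axis 3 -> cache_kv_out: [2, bs, num_head, cache_seq_len + seq_len, head_dim]
cache_kv_out = np.concatenate([cache_kv, new_kv], axis=3)
k, v = cache_kv_out[0], cache_kv_out[1]
out_seq_len = cache_kv_out.shape[3]

# q @ k^T: [bs, num_head, seq_len, out_seq_len]; gemm_n and the second gemm_k become out_seq_len
qk = np.matmul(q, k.transpose(0, 1, 3, 2)) / np.sqrt(head_dim)
ctx = np.matmul(qk, v)   # softmax/dropout omitted; shape [bs, num_head, seq_len, head_dim]
assert qk.shape == (bs, num_head, seq_len, out_seq_len)
assert ctx.shape == (bs, num_head, seq_len, head_dim)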
......@@ -61,6 +61,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut",
"FusedAttentionOp");
if (ctx->HasInput("CacheKV")) {
OP_INOUT_CHECK(ctx->HasOutput("CacheKVOut"), "Output", "CacheKVOut",
"FusedAttentionOp");
}
if (ctx->HasInput("SrcMask")) {
OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
"FusedAttentionOp");
......@@ -105,12 +109,14 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]",
x_dim, y_dim));
if (ctx->Attrs().Get<int>("ring_id") == -1) {
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
}
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
......@@ -132,20 +138,64 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// [3, batch_size, num_head, seq_len, head_size]
ctx->SetOutputDim("TransposeOut2",
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
// [batch, num_head, seq_len, seq_len]
ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
// cache_seq_len + seq_len if cache else seq_len
auto out_seq_len = x_dim[1];
if (ctx->HasInput("CacheKV")) {
// [2, batch_size, num_head, cache_seq_len, head_size]
auto c_dim = ctx->GetInputDim("CacheKV");
PADDLE_ENFORCE_EQ(
c_dim.size(), 5,
paddle::platform::errors::InvalidArgument(
"The CacheKV must be 5 dims, but got %d", c_dim.size()));
PADDLE_ENFORCE_EQ(c_dim[0], 2,
paddle::platform::errors::InvalidArgument(
"The first dim of CacheKV must be 2, but got %d",
c_dim[0])); // 2
PADDLE_ENFORCE_EQ(c_dim[1], x_dim[0],
paddle::platform::errors::InvalidArgument(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d",
x_dim[0], c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2], y_dim[1],
paddle::platform::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
y_dim[1], c_dim[2])); // num_head
PADDLE_ENFORCE_GE(
c_dim[3], 0,
paddle::platform::errors::InvalidArgument(
"The forth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2],
paddle::platform::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
y_dim[2], c_dim[4])); // head_size
out_seq_len += c_dim[3];
// [3, batch_size, num_head, cache_seq_len + seq_len, head_size]
ctx->SetOutputDim("CacheKVOut",
{c_dim[0], c_dim[1], c_dim[2], out_seq_len, c_dim[4]});
}
// [batch, num_head, seq_len, out_seq_len]
ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
if (ctx->HasInput("SrcMask")) {
ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
ctx->SetOutputDim("SrcMaskOut",
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
}
// the same as QKOut's shape.
ctx->SetOutputDim("AttnDropoutOut",
{x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
if (ctx->Attrs().Get<bool>("attn_dropout_is_test") == false) {
ctx->SetOutputDim("AttnDropoutMaskOut",
{x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
}
ctx->SetOutputDim("SoftmaxOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
ctx->SetOutputDim("SoftmaxOut",
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
// [batch_size, num_heads, seq_len, head_dim]
ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
// [batch_size, seq_len, number of heads*head size]
......@@ -182,6 +232,8 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable();
AddInput("QKVW", "The qkv weight tensor.");
AddInput("QKVBias", "The qkv bias tensor.").AsDispensable();
AddInput("CacheKV", "(optional) The cached KV for generation inference.")
.AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.");
......@@ -217,6 +269,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("BiasDropoutResidualOut",
"Result of residual + dropout(src + bias).")
.AsIntermediate();
AddOutput("CacheKVOut", "The udpated cache KV.");
AddOutput("Y", "Result after attention.");
AddAttr<bool>("pre_layer_norm",
......@@ -324,6 +377,10 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"0.0 and 0.001, But received [%s].",
ln_epsilon));
});
AddAttr<int>(
"ring_id",
"ring id for tensor model parallel. distributed training and inference")
.SetDefault(-1);
AddComment(R"DOC(
Add fused attention op whose logic is as follows:
......
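As a concrete reading of the InferShape changes, a small sketch of the output dims implied by the CacheKV branch (the numbers are illustrative assumptions, not values from the patch):

# x_dim = [bs, seq_len, embed_dim]; y_dim (QKVW) = [3, num_head, head_dim, embed_dim]
bs, seq_len, num_head, head_dim = 8, 1, 16, 64
cache_seq_len = 128                                       # c_dim[3] of CacheKV
out_seq_len = seq_len + cache_seq_len                     # 129

cache_kv_out = [2, bs, num_head, out_seq_len, head_dim]   # [2, 8, 16, 129, 64]
qk_out = [bs, num_head, seq_len, out_seq_len]             # [8, 16, 1, 129]
src_mask_out = softmax_out = attn_dropout_out = qk_out    # all follow QKOut's shape
qktv_out = [bs, num_head, seq_len, head_dim]              # [8, 16, 1, 64]
print(cache_kv_out, qk_out, qktv_out)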
......@@ -27,11 +27,39 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
static void AllReduce(framework::Tensor &tensor, // NOLINT
const int ring_id,
const platform::CUDADeviceContext &ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto dtype =
platform::ToNCCLDataType(framework::TransToProtoVarType(tensor.dtype()));
int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>();
auto place = ctx.GetPlace();
void *recvbuff = tensor.mutable_data<T>(place);
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
#else
PADDLE_THROW(platform::errors::Unimplemented(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."));
#endif
}
template <typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
public:
......@@ -56,6 +84,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *src_mask = ctx.Input<Tensor>("SrcMask");
auto *transpose_out_2 = ctx.Output<Tensor>("TransposeOut2");
auto *cache_kv = ctx.Input<Tensor>("CacheKV");
auto *cache_kv_out = ctx.Output<Tensor>("CacheKVOut");
auto *qk_out = ctx.Output<Tensor>("QKOut");
auto *qktv_out = ctx.Output<Tensor>("QKTVOut");
auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
......@@ -86,6 +116,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
int ring_id = ctx.Attr<int>("ring_id");
// final output.
auto *out = ctx.Output<Tensor>("Y");
......@@ -105,6 +136,10 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// get data ptr for FMHA.
auto *transpose_out_2_data =
transpose_out_2->mutable_data<T>(ctx.GetPlace());
auto *cache_kv_out_data =
(cache_kv_out == nullptr)
? nullptr
: cache_kv_out->mutable_data<T>(ctx.GetPlace());
auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data =
......@@ -161,9 +196,14 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
output_size = hidden_size;
// (transA, transB, compute_bias) = (false, false, false)
// NOTE(Yuang Liu): In the general case input size == output size, so swapping
// the two positions has no effect. For mp, the output size is mp_head * dkey,
// which is actually the input size, while the input size is hidden size,
// which is actually the output size. So for out linear, switch the input
// size and output size.
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, bsz_seq,
output_size, input_size, false);
input_size, output_size, false);
DropoutParam dropout_param2(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
......@@ -186,15 +226,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_bias_out);
}
if (qkv_bias == nullptr) {
fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
fmha_ref_compute.ComputeForward(
*qkv_out, cache_kv, src_mask, transpose_out_2, cache_kv_out, qk_out,
src_mask_out, softmax_out, attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
} else {
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
fmha_ref_compute.ComputeForward(
*qkv_bias_out, cache_kv, src_mask, transpose_out_2, cache_kv_out,
qk_out, src_mask_out, softmax_out, attn_dropout_mask_out,
attn_dropout_out, qktv_out, fmha_out);
}
// fmha_out: [batch_size, seq_len, num_head, head_dim]
......@@ -202,6 +242,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
out_linear_out, nullptr);
// tensor model parallel
AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());
if (pre_layer_norm) {
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias(
......@@ -244,6 +287,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
int ring_id = ctx.Attr<int>("ring_id");
// get inputs.
auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
......@@ -399,9 +443,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
transA = false;
transB = false;
bool compute_bias = false;
// (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
input_size, output_size, compute_bias);
DropoutParam dropout_param2(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
......@@ -475,6 +520,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out,
d_qkv_weight, d_qkv_bias);
}
// tensor model parallel
AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
ln_mean_data, ln_var_data, d_x_data,
d_ln_scale_data, d_ln_bias_data);
......@@ -486,6 +533,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x,
d_qkv_weight, d_qkv_bias);
}
// tensor model parallel
AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
}
// gradient accumulation
std::vector<const Tensor *> ins;
......
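The kernel now performs one all-reduce in the forward pass, right after the row-parallel out_linear matmul, and one in the backward pass on the gradient leaving the column-parallel QKV projection (d_ln_out when pre_layer_norm is true, d_x otherwise). The NumPy sketch below, with an assumed two-rank head split, shows why summing the per-rank partial products of the row-parallel projection reproduces the full matmul; it is an illustration, not the kernel code:

import numpy as np

np.random.seed(0)
bsz_seq, num_head, head_dim, nranks = 4, 4, 8, 2
hidden = num_head * head_dim

fmha_out = np.random.rand(bsz_seq, num_head * head_dim).astype('float64')
out_linear_w = np.random.rand(num_head * head_dim, hidden).astype('float64')

# Column parallel shards the heads, so each rank only produces the slice of
# fmha_out for its own heads; row parallel gives it the matching rows of out_linear_w.
shard = (num_head // nranks) * head_dim
partials = [
    fmha_out[:, r * shard:(r + 1) * shard] @ out_linear_w[r * shard:(r + 1) * shard, :]
    for r in range(nranks)
]

# Summing the partial products (what ncclAllReduce with ncclSum does) equals
# the unsharded projection on every rank.
assert np.allclose(sum(partials), fmha_out @ out_linear_w)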
......@@ -30,8 +30,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"layer_norm", {"X", "Scale", "Bias"}},
{"bincount", {"X", "Weights"}},
{"fused_attention",
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW",
"OutLinearBias", "Ln2Scale", "Ln2Bias"}},
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "SrcMask",
"OutLinearW", "OutLinearBias", "Ln2Scale", "Ln2Bias"}},
{"instance_norm", {"X", "Scale", "Bias"}},
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}},
......@@ -104,11 +104,16 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"batch_norm",
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
{"fused_attention",
{"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2",
"QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut",
"SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean",
"Ln2Variance", "BiasDropoutResidualOut", "Y"}},
{"fused_attention", {"LnMean", "LnVariance",
"LnOut", "QKVOut",
"QKVBiasOut", "TransposeOut2",
"QKOut", "QKTVOut",
"SoftmaxOut", "AttnDropoutMaskOut",
"AttnDropoutOut", "SrcMaskOut",
"FMHAOut", "OutLinearOut",
"DropoutMaskOut", "Ln2Mean",
"Ln2Variance", "BiasDropoutResidualOut",
"CacheKVOut", "Y"}},
{"sync_batch_norm",
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
......
......@@ -24,6 +24,7 @@ list(APPEND DIST_TEST_OPS test_pipeline)
list(APPEND DIST_TEST_OPS test_ir_pass_pipeline)
list(APPEND DIST_TEST_OPS test_static_model_parallel)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
......@@ -1155,6 +1156,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_ir_pass_pipeline PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_model_parallel PROPERTIES TIMEOUT 240)
set_tests_properties(test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_split_embedding
test_collective_split_embedding_none_divisible
test_collective_split_row_linear
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main
import paddle.distributed.fleet as fleet
import paddle.incubate.nn.functional as incubate_f
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import core
from paddle.nn.initializer import Constant
paddle.enable_static()
def _set_var_distributed(var):
if var is None:
return
var.is_distributed = True
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block = paddle.static.default_startup_program().current_block()
main_block = paddle.static.default_main_program().current_block()
startup_block._find_var_recursive(var.name).is_distributed = True
main_block._find_var_recursive(var.name).is_distributed = True
class ParallelFusedMultiHeadAttention(Layer):
def __init__(self,
embed_dim,
num_heads,
dropout_rate=0.5,
attn_dropout_rate=0.5,
kdim=None,
vdim=None,
normalize_before=False,
need_weights=False,
qkv_weight_attr=None,
qkv_bias_attr=None,
linear_weight_attr=None,
linear_bias_attr=None,
pre_ln_scale_attr=None,
pre_ln_bias_attr=None,
ln_scale_attr=None,
ln_bias_attr=None,
epsilon=1e-5,
nranks=1,
ring_id=-1,
name=None):
super(ParallelFusedMultiHeadAttention, self).__init__()
assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
"but recieved {}".format(embed_dim))
assert num_heads > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(num_heads))
self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype()
self._epsilon = epsilon
self._ring_id = ring_id
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.kdim = kdim
self.vdim = vdim
self.need_weights = need_weights
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
assert need_weights == False, "Only support need_weight is False now."
# tensor model parallel
assert num_heads % nranks == 0
num_heads = num_heads // nranks
self.qkv_weight = self.create_parameter(
shape=[3, num_heads, self.head_dim, embed_dim],
attr=qkv_weight_attr,
dtype=self._dtype,
is_bias=False)
self.qkv_bias = self.create_parameter(
shape=[3, num_heads, self.head_dim],
attr=qkv_bias_attr,
dtype=self._dtype,
is_bias=True)
self.linear_weight = self.create_parameter(
shape=[num_heads * self.head_dim, embed_dim],
attr=linear_weight_attr,
dtype=self._dtype,
is_bias=False)
self.linear_bias = self.create_parameter(
shape=[embed_dim],
attr=linear_bias_attr,
dtype=self._dtype,
is_bias=True)
# tensor model parallel
if nranks > 1:
assert ring_id != -1
# column parallel
_set_var_distributed(self.qkv_weight)
_set_var_distributed(self.qkv_bias)
# row parallel
_set_var_distributed(self.linear_weight)
if normalize_before:
self.pre_ln_scale = self.create_parameter(
attr=pre_ln_scale_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.pre_ln_bias = self.create_parameter(
attr=pre_ln_bias_attr, shape=[embed_dim], is_bias=True)
self.ln_scale = None
self.ln_bias = None
else:
self.pre_ln_scale = None
self.pre_ln_bias = None
self.ln_scale = self.create_parameter(
attr=ln_scale_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.ln_bias = self.create_parameter(
attr=ln_bias_attr, shape=[embed_dim], is_bias=True)
self.dropout_rate = dropout_rate
self.attn_dropout_rate = attn_dropout_rate
self.name = name
def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
out = incubate_f.fused_multi_head_attention(
x=query,
qkv_weight=self.qkv_weight,
linear_weight=self.linear_weight,
pre_layer_norm=self.normalize_before,
pre_ln_scale=self.pre_ln_scale,
pre_ln_bias=self.pre_ln_bias,
ln_scale=self.ln_scale,
ln_bias=self.ln_bias,
pre_ln_epsilon=self._epsilon,
qkv_bias=self.qkv_bias,
linear_bias=self.linear_bias,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
attn_dropout_rate=self.attn_dropout_rate,
ln_epsilon=self._epsilon,
training=self.training,
ring_id=self._ring_id,
name=self.name)
return out
def get_param_attr(weight, bias):
weight_attr = paddle.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(weight))
bias_attr = paddle.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(bias))
return weight_attr, bias_attr
DTYPE = "float32"
MODEL_PARALLEL_SIZE = 2
n_head = 2 * MODEL_PARALLEL_SIZE
d_key = 4
hidden = n_head * d_key
def create_model(data, rank):
np.random.seed(2021)
pre_ln_w = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
pre_ln_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
qkv_w = np.random.uniform(
-1, 1, size=(3, n_head, d_key, hidden)).astype(DTYPE)
qkv_b = np.random.uniform(-1, 1, size=(3, n_head, d_key)).astype(DTYPE)
linear_w = np.random.uniform(
-1, 1, size=(n_head * d_key, hidden)).astype(DTYPE)
linear_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
data.stop_gradient = False
if rank is not None:
start = 0 if rank == 0 else n_head // MODEL_PARALLEL_SIZE
end = start + n_head // MODEL_PARALLEL_SIZE
col_qkv_w = qkv_w[:, start:end, :, :]
col_qkv_b = qkv_b[:, start:end, :]
row_linear_w = linear_w[(start * d_key):(end * d_key), :]
pre_ln_w_attr, pre_ln_b_attr = get_param_attr(pre_ln_w, pre_ln_b)
qkv_w_attr, qkv_b_attr = get_param_attr(col_qkv_w, col_qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b)
attn = ParallelFusedMultiHeadAttention(
hidden,
n_head,
dropout_rate=0.0,
attn_dropout_rate=0.0,
normalize_before=False,
qkv_weight_attr=qkv_w_attr,
qkv_bias_attr=qkv_b_attr,
linear_weight_attr=linear_w_attr,
linear_bias_attr=linear_b_attr,
pre_ln_scale_attr=pre_ln_w_attr,
pre_ln_bias_attr=pre_ln_b_attr,
ln_scale_attr=pre_ln_w_attr,
ln_bias_attr=pre_ln_b_attr,
nranks=MODEL_PARALLEL_SIZE,
ring_id=0)
result = attn(data)
else:
pre_ln_w_attr, pre_ln_b_attr = get_param_attr(pre_ln_w, pre_ln_b)
qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b)
attn = ParallelFusedMultiHeadAttention(
hidden,
n_head,
dropout_rate=0.0,
attn_dropout_rate=0.0,
normalize_before=False,
qkv_weight_attr=qkv_w_attr,
qkv_bias_attr=qkv_b_attr,
linear_weight_attr=linear_w_attr,
linear_bias_attr=linear_b_attr,
pre_ln_scale_attr=pre_ln_w_attr,
pre_ln_bias_attr=pre_ln_b_attr,
ln_scale_attr=pre_ln_w_attr,
ln_bias_attr=pre_ln_b_attr)
result = attn(data)
predict = paddle.sum(result)
return predict
class TestModelParallel(TestDistRunnerBase):
def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None):
# Input data
seq_len = 2
data_in = fluid.data(
name='data_in', shape=[batch_size, seq_len, hidden], dtype=DTYPE)
if dist_strategy:
data_loader = fluid.io.DataLoader.from_generator(
feed_list=[data_in],
capacity=64,
use_double_buffer=False,
iterable=False)
if dist_strategy:
fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2}
rank = fleet.worker_index() if dist_strategy else None
avg_cost = create_model(data_in, rank)
opt = fluid.optimizer.SGD(0.1)
if dist_strategy:
dist_opt = fleet.distributed_optimizer(
optimizer=opt, strategy=strategy)
dist_opt.minimize(avg_cost)
else:
opt.minimize(avg_cost)
def gen_data():
np.random.seed(2021)
while True:
data = [np.random.random([seq_len, hidden]).astype(DTYPE)]
yield data
train_reader = paddle.batch(gen_data, batch_size=batch_size)
if dist_strategy:
return None, avg_cost, train_reader, None, None, None, data_loader
else:
return None, avg_cost, train_reader, None, None, None
if __name__ == "__main__":
runtime_main(TestModelParallel)
......@@ -70,10 +70,12 @@ class TestFusedAttentionOp(OpTest):
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.has_cache_kv = False
self.training = True
self.batch_size = 8
self.query_length = 128
self.cache_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
......@@ -88,10 +90,22 @@ class TestFusedAttentionOp(OpTest):
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
out_seq_len = self.key_length
if self.has_cache_kv:
assert self.training is False, ValueError(
'cache_kv can only be used in inference')
self.cache_kv = np.random.rand(2, self.batch_size, self.num_heads,
self.cache_length,
self.head_dim).astype(self.x_type)
out_seq_len += self.cache_length
else:
self.cache_kv = None
if self.has_attn_mask:
# [B, n_head, seq_len, out_seq_len]
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
out_seq_len),
dtype=self.attn_mask_type)
if self.attn_mask_type == np.int64:
self.attn_mask = np.tril(self.attn_mask)
......@@ -110,6 +124,11 @@ class TestFusedAttentionOp(OpTest):
def GetBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
cache_kv = None
if self.has_cache_kv:
cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
......@@ -130,6 +149,18 @@ class TestFusedAttentionOp(OpTest):
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])
if self.has_cache_kv:
# [1, B, n_head, cache_seq_len, head_dim]
cache_k, cache_v = paddle.split(cache_kv, 2)
cache_k = paddle.squeeze(cache_k, axis=0)
cache_v = paddle.squeeze(cache_v, axis=0)
# [B, n_head, cache_seq_len + seq_len, head_dim]
# out_seq_len = cache_seq_len + seq_len
k_out = paddle.concat([cache_k, k_out], axis=-2)
v_out = paddle.concat([cache_v, v_out], axis=-2)
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out = layers.matmul(
x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5)
......@@ -146,6 +177,8 @@ class TestFusedAttentionOp(OpTest):
self.dropout_prob,
training=self.training,
mode="upscale_in_train")
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out = tensor.matmul(dropout_out, v_out)
else:
qktv_out = tensor.matmul(softmax_out, v_out)
......@@ -160,6 +193,10 @@ class TestFusedAttentionOp(OpTest):
final_out = self.norm1(residual_out)
else:
final_out = residual_out
if self.has_cache_kv:
return final_out
paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
return final_out, tensor_query.grad
......@@ -206,6 +243,9 @@ class TestFusedAttentionOp(OpTest):
(3, self.num_heads, self.head_dim, self.embed_dim))
x = paddle.to_tensor(self.query, stop_gradient=False)
cache_kv = None
if self.has_cache_kv:
cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
......@@ -219,8 +259,12 @@ class TestFusedAttentionOp(OpTest):
final_out = incubate_f.fused_multi_head_attention(
x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm,
ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor,
out_linear_bias, attn_mask, self.dropout_prob,
out_linear_bias, cache_kv, attn_mask, self.dropout_prob,
self.attn_dropout_prob, ln2_epsilon)
if self.has_cache_kv:
return final_out[0], final_out[1]
paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
return final_out, x.grad
......@@ -236,122 +280,51 @@ class TestFusedAttentionOp(OpTest):
class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
super().config()
self.bias_attr = False
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
super().config()
self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
super().config()
self.pre_layer_norm = True
self.has_attn_mask = False
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
class TestFusedAttentionOpFp16(TestFusedAttentionOp):
def config(self):
super().config()
self.x_type = np.float16
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
class TestFusedAttentionOpFp16(TestFusedAttentionOp):
class TestFusedAttentionOpCacheKV(TestFusedAttentionOp):
def config(self):
self.x_type = np.float16
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
super().config()
self.has_cache_kv = True
self.training = False
self.query_length = 1
self.key_length, self.value_length = 1, 1
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
with paddle.no_grad():
final_out_ref = self.GetBaselineOut()
final_out, cache_kv_out = self.GetFusedAttentionOut()
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
if __name__ == "__main__":
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_dist_base import TestDistBase
import os
import paddle
paddle.enable_static()
flag_name = os.path.splitext(__file__)[0]
class TestStaticModelParallel(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl_comm_num = 1
self._pipeline_mode = True
def test_dist_static_model_parallel_fused_attention(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"static_model_parallel_fused_attention.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == '__main__':
unittest.main()
......@@ -223,12 +223,14 @@ def fused_multi_head_attention(x,
pre_ln_epsilon=1e-05,
qkv_bias=None,
linear_bias=None,
cache_kv=None,
attn_mask=None,
dropout_rate=0.5,
attn_dropout_rate=0.5,
ln_epsilon=1e-05,
training=True,
mode='upscale_in_train',
ring_id=-1,
name=None):
r"""
Attention maps queries and a set of key-value pairs to outputs, and
......@@ -276,6 +278,7 @@ def fused_multi_head_attention(x,
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
cache_kv (Tensor, optional): The cache structure tensor for the generation model. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevent attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
......@@ -303,6 +306,7 @@ def fused_multi_head_attention(x,
- train: out = input * mask
- inference: out = input * (1.0 - p)
ring_id (int, optional): The ring id for distributed forward in model parallel. Currently only NCCL and the forward pass are supported. Default is -1, which means model parallel is not used.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -333,7 +337,7 @@ def fused_multi_head_attention(x,
output = F.fused_multi_head_attention(
x, qkv_weight, linear_weight, False,
None, None, None, None, 1e-5, qkv_bias,
linear_bias, attn_mask)
linear_bias, None, attn_mask)
# [2, 4, 128]
print(output.shape)
"""
......@@ -359,17 +363,20 @@ def fused_multi_head_attention(x,
assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[
3], "embed_dim must be divisible by num_heads."
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention(
x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask,
linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm',
pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate',
dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon',
ln_epsilon, 'attn_dropout_is_test', not training, 'dropout_is_test',
not training, 'attn_dropout_fix_seed', seed is not None,
'dropout_fix_seed', seed is not None, 'attn_dropout_seed', seed
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, cache_kv_out, final_out = _C_ops.fused_attention(
x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, cache_kv,
attn_mask, linear_weight, linear_bias, ln_scale, ln_bias,
'pre_layer_norm', pre_layer_norm, 'epsilon', pre_ln_epsilon,
'dropout_rate', dropout_rate, 'attn_dropout_rate',
attn_dropout_rate, 'ln_epsilon', ln_epsilon, 'attn_dropout_is_test',
not training, 'dropout_is_test', not training,
'attn_dropout_fix_seed', seed is not None, 'dropout_fix_seed',
seed is not None, 'attn_dropout_seed', seed
if seed is not None else 0, 'dropout_seed', seed
if seed is not None else 0, 'attn_dropout_implementation', mode,
'dropout_implementation', mode)
'dropout_implementation', mode, 'ring_id', ring_id)
if cache_kv is not None:
return final_out, cache_kv_out
return final_out
else:
helper = LayerHelper('fused_multi_head_attention', **locals())
......@@ -398,6 +405,7 @@ def fused_multi_head_attention(x,
inputs['Ln2Scale'] = [ln_scale]
if ln_bias:
inputs['Ln2Bias'] = [ln_bias]
if cache_kv: inputs['CacheKV'] = [cache_kv]
if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed
......@@ -417,6 +425,7 @@ def fused_multi_head_attention(x,
'dropout_seed': seed if seed is not None else 0,
'attn_dropout_implementation': mode,
'dropout_implementation': mode,
'ring_id': ring_id
}
# set outputs
......@@ -449,6 +458,7 @@ def fused_multi_head_attention(x,
bias_dropout_residual_out = helper.create_variable_for_type_inference(
dtype=dtype)
final_out = helper.create_variable_for_type_inference(dtype=dtype)
cache_kv_out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='fused_attention',
......@@ -472,7 +482,9 @@ def fused_multi_head_attention(x,
"Ln2Mean": ln_mean_out,
"Ln2Variance": ln_variance_out,
"BiasDropoutResidualOut": bias_dropout_residual_out,
'Y': final_out
'Y': final_out,
'CacheKVOut': cache_kv_out
},
attrs=attrs)
return final_out
return (final_out, cache_kv_out) if cache_kv else final_out