未验证 提交 50bfe420 编写于 作者: W WangXi 提交者: GitHub

[cherry-pick 2.3] Add fused_multi_transformer op to optimize transformer...

[cherry-pick 2.3] Add fused_multi_transformer op to optimize transformer generation performance (#42311)

* Add fused_multi_transformer op to optimize transformer generation performance (#41814)

* fix fused_multi_transformer compile failed in cuda arch < sm53 (#42315)

* fix ci timeout
上级 765fbb59
...@@ -19,6 +19,7 @@ register_operators(EXCLUDES ...@@ -19,6 +19,7 @@ register_operators(EXCLUDES
fused_attention_op fused_attention_op
fused_transformer_op fused_transformer_op
fused_feedforward_op fused_feedforward_op
fused_multi_transformer_op
resnet_unit_op resnet_unit_op
fused_gemm_epilogue_op) fused_gemm_epilogue_op)
...@@ -73,6 +74,7 @@ if (WITH_GPU OR WITH_ROCM) ...@@ -73,6 +74,7 @@ if (WITH_GPU OR WITH_ROCM)
op_library(fused_feedforward_op) op_library(fused_feedforward_op)
# fused_attention_op # fused_attention_op
op_library(fused_attention_op) op_library(fused_attention_op)
op_library(fused_multi_transformer_op)
endif() endif()
# resnet_unit needs cudnn 8.0 above # resnet_unit needs cudnn 8.0 above
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000)) if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FusedMultiTransformerOp : public framework::OperatorWithKernel {
private:
static constexpr const char *OpName = "FusedMultiTransformerOp";
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
#define CHECK_INPUT(name) \
OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName)
#define CHECK_INPUTS(name) \
OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName)
#define CHECK_OUTPUT(name) \
OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName)
#define CHECK_OUTPUTS(name) \
OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName)
CHECK_INPUT(X);
// attention
CHECK_INPUTS(QKVW);
CHECK_INPUTS(OutLinearW);
if (ctx->HasInput("TimeStep")) {
CHECK_INPUTS(CacheKV);
}
if (ctx->HasInputs("CacheKV")) {
CHECK_OUTPUTS(CacheKVOut);
}
// ffn
CHECK_INPUTS(FFN1Weight);
CHECK_INPUTS(FFN2Weight);
CHECK_OUTPUT(Out);
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputsDim("QKVW")[0];
PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument(
"The dimensions of x must be 3"
"(batch_size, seq_len, dim_embed),"
"but received dimensions of"
"Input is [%d]",
x_dim.size()));
PADDLE_ENFORCE_EQ(y_dim.size(), 4,
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"but received dimensions of"
"Input is [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3],
platform::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3]"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
x_dim, y_dim));
if (ctx->Attrs().Get<int>("ring_id") == -1) {
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
}
if (ctx->HasInputs("CacheKV")) {
// [2, batch_size, num_head, max_seq_len, head_size]
const auto &c_dims = ctx->GetInputsDim("CacheKV");
const auto &c_dim = c_dims[0];
PADDLE_ENFORCE_EQ(
c_dim.size(), 5,
paddle::platform::errors::InvalidArgument(
"The CacheKV must be 5 dims, but got %d", c_dim.size()));
PADDLE_ENFORCE_EQ(c_dim[0], 2,
paddle::platform::errors::InvalidArgument(
"The first dim of CacheKV must be 2, but got %d",
c_dim[0])); // 2
PADDLE_ENFORCE_EQ(c_dim[1], x_dim[0],
paddle::platform::errors::InvalidArgument(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d",
x_dim[0], c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2], y_dim[1],
paddle::platform::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
y_dim[1], c_dim[2])); // num_head
PADDLE_ENFORCE_GT(
c_dim[3], 0,
paddle::platform::errors::InvalidArgument(
"The forth dim of CacheKV must be greater than 0, but got %d",
c_dim[3])); // cache_seq_len
PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2],
paddle::platform::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
y_dim[2], c_dim[4])); // head_size
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "TimeStep") {
VLOG(10) << "var_name:" << var_name << " need not to transform";
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class FusedMultiTransformerOpOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddInput("LnScale",
"Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDuplicable();
AddInput("LnBias",
"Bias is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDuplicable();
AddInput("QKVW", "The qkv weight tensor.").AsDuplicable();
AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable();
AddInput("CacheKV", "(optional) The cached KV for generation inference.")
.AsDispensable()
.AsDuplicable();
AddInput("TimeStep",
"(optional, int) The time step for generation inference.")
.AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
AddInput("OutLinearBias", "The out_linear bias tensor.")
.AsDispensable()
.AsDuplicable();
AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op")
.AsDuplicable();
AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op")
.AsDuplicable();
AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op")
.AsDuplicable();
AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op")
.AsDispensable()
.AsDuplicable();
AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op")
.AsDuplicable();
AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op")
.AsDispensable()
.AsDuplicable();
AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV")
.AsDispensable()
.AsDuplicable();
AddOutput("Out", "Result after multi .");
AddAttr<bool>("pre_layer_norm",
"if true, the attention op uses pre_layer_norm architecure, "
"else, uses post_layer_norm architecuture. "
"[default true].")
.SetDefault(true);
AddAttr<float>("epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' in Op(LayerNorm) should be between"
"0.0 and 0.001, But received [%s].",
epsilon));
});
AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
.SetDefault(.5f)
.AddCustomChecker([](const float &drop_p) {
PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true,
platform::errors::InvalidArgument(
"'dropout_rate' must be between 0.0 and 1.0."));
});
AddAttr<bool>("dropout_is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<std::string>(
"dropout_implementation",
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
"The meaning is the same as 'attn_dropout_implementation'.")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string &type) {
PADDLE_ENFORCE_EQ(
type == "downgrade_in_infer" || type == "upscale_in_train", true,
platform::errors::InvalidArgument(
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
AddAttr<int>(
"ring_id",
"ring id for tensor model parallel. distributed training and inference")
.SetDefault(-1);
AddComment(R"DOC(fused multi transformer layers op)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_multi_transformer, ops::FusedMultiTransformerOp,
ops::FusedMultiTransformerOpOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This file has been adapted from FasterTransformer file:
// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu
// We add License in the head.
#include <cuda_fp16.h>
#include <float.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// for debug
// #define _DEBUG_FUSED_MULTI_TRANSFORMER
template <typename T>
static void AllReduce(framework::Tensor &tensor, // NOLINT
const int ring_id,
const platform::CUDADeviceContext &ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto dtype =
platform::ToNCCLDataType(framework::TransToProtoVarType(tensor.dtype()));
int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>();
auto place = ctx.GetPlace();
void *recvbuff = tensor.mutable_data<T>(place);
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
#else
PADDLE_THROW(platform::errors::Unimplemented(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."));
#endif
}
namespace {
namespace plat = paddle::platform;
using float16 = plat::float16;
#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#define MMHA_USE_FP32_ACUM_FOR_OUT
template <typename T>
struct Masked_multihead_attention_params {
// output buffer, [B, 1(seq_len), num_head * dim_head]
T *out;
// qkv_out, [B, 1(seq_len), 3, num_head * dim_head]
const T *qkv;
// bias, [3, num_head, dim_head]
const T *qkv_bias;
// TODO(wangxi): optimize with input_lengths and max_input_len?
// [bsz, 1, 1, time_step(cache_seq_length)+1]
const T *attn_mask;
// [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head]
// k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first
// v [B, num_head, max_seq_len, dim_head]
T *cache_kv;
int batch_size;
int num_head;
int timestep; // cache_seq_length
int max_seq_length;
// 1.f / sqrt(Dh)
float inv_sqrt_dh;
};
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
// clang-format off
template <typename T, int Dh> struct Qk_vec_ {};
template <> struct Qk_vec_<float, 32> { using Type = float; };
template <> struct Qk_vec_<float, 64> { using Type = float2; };
template <> struct Qk_vec_<float, 128> { using Type = float4; };
template <> struct Qk_vec_<float16, 32> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 64> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 128> { using Type = uint2; };
template <typename T, int THREADS_PER_KEY> struct K_vec_ {};
template <> struct K_vec_<float, 4> { using Type = float; };
template <> struct K_vec_<float, 2> { using Type = float2; };
template <> struct K_vec_<float, 1> { using Type = float4; };
template <> struct K_vec_<float16, 4> { using Type = uint32_t; };
template <> struct K_vec_<float16, 2> { using Type = uint2; };
template <> struct K_vec_<float16, 1> { using Type = uint4; };
template <typename T, int V_VEC_SIZE> struct V_vec_ {};
template <> struct V_vec_<float, 1> { using Type = float; };
template <> struct V_vec_<float, 2> { using Type = float2; };
template <> struct V_vec_<float, 4> { using Type = float4; };
template <> struct V_vec_<float16, 2> { using Type = uint32_t; };
template <> struct V_vec_<float16, 4> { using Type = uint2; };
template <> struct V_vec_<float16, 8> { using Type = uint4; };
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template <typename T> struct V_vec_acum_fp32_ {};
// template <> struct V_vec_acum_fp32_<float> { using Type = float; };
// template <> struct V_vec_acum_fp32_<float2> { using Type = float2; };
template <> struct V_vec_acum_fp32_<float4> { using Type = float4; };
// template <> struct V_vec_acum_fp32_<uint32_t> { using Type = float2; };
// template <> struct V_vec_acum_fp32_<uint2 > { using Type = Float4_; };
template <> struct V_vec_acum_fp32_<uint4> { using Type = Float8_; };
#endif
// clang-format on
inline __device__ float half_to_float(uint16_t h) {
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
inline __device__ float2 half2_to_float2(uint32_t v) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
}
inline __device__ uint32_t float2_to_half2(float2 f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
return tmp.u32;
}
inline __device__ float add(float a, float b) { return a + b; }
inline __device__ float2 add(float2 a, float2 b) {
float2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ float4 add(float4 a, float4 b) {
float4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
uint16_t c;
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
inline __device__ uint2 add(uint2 a, uint2 b) {
uint2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ uint4 add(uint4 a, uint4 b) {
uint4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(uint32_t a, float2 fb) {
float2 fa = half2_to_float2(a);
return add(fa, fb);
}
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);
template <>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}
template <>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
template <>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
return c;
}
template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
return c;
}
inline __device__ float sum(float v) { return v; }
inline __device__ float sum(float2 v) { return v.x + v.y; }
inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; }
inline __device__ float sum(uint16_t v) { return half_to_float(v); }
inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y);
return sum(c);
}
inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
return sum(c);
}
template <typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b));
}
template <typename A, typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b));
}
inline __device__ constexpr uint32_t shfl_mask(int threads) {
return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
}
template <typename T>
inline __device__ __host__ T div_up(T m, T n) {
return (m + n - 1) / n;
}
inline __device__ float fma(float a, float b, float c) { return a * b + c; }
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
float4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
return d;
}
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
uint2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
uint4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ float2 fma(float a, float2 b, float2 c) {
float2 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ float4 fma(float a, float4 b, float4 c) {
float4 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
Float8_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
inline __device__ uint32_t h0_h0(uint16_t a) {
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
}
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
return fma(h0_h0(a), b, c);
}
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
uint32_t s = h0_h0(a);
uint2 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
uint32_t s = h0_h0(a);
uint4 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float cast_to_float(float u) { return u; }
inline __device__ float2 cast_to_float(float2 u) { return u; }
inline __device__ float4 cast_to_float(float4 u) { return u; }
inline __device__ Float8_ cast_to_float(uint4 u) {
Float8_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
tmp.z = half2_to_float2(u.z);
tmp.w = half2_to_float2(u.w);
return tmp;
}
template <int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) {
K_vec qk_vec = mul<K_vec, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}
template <typename T, int THREADS_PER_KEY>
struct Qk_dot {
template <typename K_vec, int N>
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) {
return qk_dot_<THREADS_PER_KEY>(q, k);
}
};
template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float *red_smem, float sum) {
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
if (lane == 0) {
red_smem[warp] = sum;
}
__syncthreads();
if (lane < WARPS_PER_BLOCK) {
sum = red_smem[lane];
}
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
return __shfl_sync(uint32_t(-1), sum, 0);
}
inline __device__ void convert_from_float(float &dst, float src) { // NOLINT
dst = src;
}
inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT
dst = src;
}
inline __device__ void convert_from_float(plat::float16 &dst, // NOLINT
float src) {
dst = static_cast<plat::float16>(src);
}
inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
dst.z = float2_to_half2(src.z);
dst.w = float2_to_half2(src.w);
}
inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT
template <typename T>
inline __device__ void zero(T &dst) { // NOLINT
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
template <typename T, int Dh, int THREADS_PER_KEY, int THREADS_PER_VALUE,
int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert(Dh % THREADS_PER_KEY == 0, "");
static_assert(Dh % THREADS_PER_VALUE == 0, "");
constexpr int WARP_SIZE = 32;
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
extern __shared__ char smem_[];
float *qk_smem = reinterpret_cast<float *>(smem_);
char *logits_smem_ = smem_;
// fp32 accum for logits
float *logits_smem = reinterpret_cast<float *>(logits_smem_);
T *out_smem = reinterpret_cast<T *>(smem_);
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
__shared__ T q_smem[Dh];
const int bi = blockIdx.y;
const int hi = blockIdx.x;
const int bhi = bi * params.num_head + hi;
const int tid = threadIdx.x;
float qk_max = -FLT_MAX;
// qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
using Qk_vec = typename Qk_vec_<T, Dh>::Type;
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE;
// cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// x == 4/8 for FP32/FP16, 128bit, 16Byte
constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
const T *q_base = params.qkv;
const T *k_base = params.qkv + params.num_head * Dh;
const T *q_bias_base = params.qkv_bias;
const T *k_bias_base = params.qkv_bias + params.num_head * Dh;
if (tid < QK_VECS_PER_WARP) {
int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
Qk_vec q = *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset]);
Qk_vec k = *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset]);
Qk_vec q_bias =
*reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset]);
Qk_vec k_bias =
*reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset]);
q = add(q, q_bias);
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
// we may not require k_bias.
k = add(k, k_bias);
*reinterpret_cast<Qk_vec *>(&q_smem[tid * QK_VEC_SIZE]) = q;
int co = tid / QK_VECS_IN_16B;
int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
int offset = bhi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B +
params.timestep * QK_ELTS_IN_16B + ci;
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
float qk = dot<Qk_vec, Qk_vec>(q, k);
#pragma unroll
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
}
qk *= params.inv_sqrt_dh;
if (tid == 0) {
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk_max = qk;
qk_smem[params.timestep] = qk;
}
}
__syncthreads();
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if (bi == 0 && hi == 0 && tid == 0) {
printf("=======q_out=======\n");
for (int i = 0; i < Dh; ++i) printf("%f ", static_cast<float>(q_smem[i]));
printf("\n");
}
__syncthreads();
#endif
using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
static_assert(Dh % K_VEC_SIZE == 0, "");
constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY;
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
int ko = tid / THREADS_PER_KEY;
int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE;
K_vec q[K_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < K_VECS_PER_THREAD; ++i) {
q[i] = *reinterpret_cast<const K_vec *>(
&q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]);
}
constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
K_vec k[K_VECS_PER_THREAD];
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_seq_length + ti;
if (ti < params.timestep) {
k[ii] = *reinterpret_cast<const K_vec *>(&k_cache[jj * QK_ELTS_IN_16B]);
}
}
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k) * params.inv_sqrt_dh;
// bool is_mask = false;
if (ti < params.timestep && tid % THREADS_PER_KEY == 0) {
// qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
T mask = params.attn_mask[bi * (params.timestep + 1) + ti];
qk += static_cast<float>(mask);
qk_max = fmaxf(qk_max, qk);
qk_smem[ti] = qk;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
const int warp = tid / WARP_SIZE;
const int lane = tid % WARP_SIZE;
if (lane == 0) {
red_smem[warp] = qk_max;
}
__syncthreads();
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if (bi == 0 && hi == 0 && tid == 0) {
printf("=======qk_out=======\n");
for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]);
printf("qk_max=%f\n", qk_max);
}
__syncthreads();
#endif
float sum = 0.f;
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
// bool is_mask = false;
// float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
float logit = __expf(qk_smem[ti] - qk_max);
sum += logit;
qk_smem[ti] = logit;
}
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
// FIXME(wangxi): need add 1.e-6f?
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
}
__syncthreads();
constexpr int V_VEC_SIZE = Dh / THREADS_PER_VALUE;
using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
int vo = tid / THREADS_PER_VALUE;
int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE;
T *v_cache = &params.cache_kv[params.batch_size * params.num_head *
params.max_seq_length * Dh +
bhi * params.max_seq_length * Dh + vi];
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
#else
using V_vec_acum = V_vec;
#endif
V_vec_acum out;
zero(out);
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti];
out = fma(logit, cast_to_float(v), out);
#else
T logit = logits_smem[ti];
// Update the partial sums.
out = fma(logit, v, out);
#endif
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
if (bi == 0 && hi == 0 && tid == 0) {
printf("======logits_out=====\n");
for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]);
printf("\n");
}
__syncthreads();
#endif
if (vo == (params.timestep % V_PER_ITER)) {
V_vec v = *reinterpret_cast<const V_vec *>(
&params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
V_vec v_bias = *reinterpret_cast<const V_vec *>(
&params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
v = add(v, v_bias);
*reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v;
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
out = fma(logits_smem[params.timestep], cast_to_float(v), out);
#else
out = fma(logits_smem[params.timestep], v, out);
#endif
}
__syncthreads();
#pragma unroll
for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
int midpoint = active_groups / 2;
if (vo >= midpoint && vo < active_groups) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
out);
#else
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
}
__syncthreads();
if (vo < midpoint) {
out = add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
}
if (vo == 0) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]),
out);
#else
*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]) = out;
#endif
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
__syncthreads();
if (bi == 0 && hi == 0 && tid == 0) {
printf("======fmha_out=====\n");
for (int i = 0; i < Dh; ++i)
printf("%f ", static_cast<float>(params.out[i]));
printf("\n");
}
#endif
#else
assert(false);
#endif
}
template <typename T>
inline size_t smem_size_in_bytes(
const Masked_multihead_attention_params<T> &params, int dim_head,
int threads_per_value, int threads_per_block) {
size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
size_t logits_sz = 0;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if (sizeof(T) != 4) {
logits_sz = div_up(params.max_seq_length, 4) * 4 * sizeof(T);
}
#endif
size_t softmax_sz = qk_sz + logits_sz;
int rows_per_red = threads_per_block / threads_per_value;
size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2;
return max(softmax_sz, red_sz);
}
#define MMHA_LAUNCH_KERNEL(T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, stream) \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel< \
T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
template <typename T, int Dh>
void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
const cudaStream_t &stream) {
constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16;
if (params.timestep < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, 4, THREADS_PER_VALUE, 64, stream);
} else if (params.timestep < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, 2, THREADS_PER_VALUE, 128, stream);
} else {
MMHA_LAUNCH_KERNEL(T, Dh, 1, THREADS_PER_VALUE, 256, stream);
}
}
template <typename T>
void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
const Tensor &qkv_bias_tensor, const Tensor &src_mask_tensor,
Tensor *cache_kv_tensor, Tensor *out_tensor, int batch_size,
int max_seq_length, int num_head, int dim_head, int timestep,
float inv_sqrt_dh) {
Masked_multihead_attention_params<T> params;
params.out = out_tensor->data<T>();
params.qkv = qkv_tensor.data<T>();
params.qkv_bias = qkv_bias_tensor.data<T>();
params.attn_mask = src_mask_tensor.data<T>();
params.cache_kv = cache_kv_tensor->data<T>();
params.batch_size = batch_size;
params.num_head = num_head;
params.timestep = timestep;
params.max_seq_length = max_seq_length;
params.inv_sqrt_dh = inv_sqrt_dh;
switch (dim_head) {
case 32:
fmha_launch_kernel<T, 32>(params, dev_ctx.stream());
break;
case 64:
fmha_launch_kernel<T, 64>(params, dev_ctx.stream());
break;
case 128:
fmha_launch_kernel<T, 128>(params, dev_ctx.stream());
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"dim_head = %d is unsupport, only support "
"dim_head = 32, 64 or 128 for now.",
dim_head));
}
}
// NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8
constexpr int VEC_16B = 16;
template <typename T>
__global__ void write_cache_k_kernel(T *cache_k, const T *k, const int num_head,
const int dim_head, const int seq_len,
const int max_seq_len) {
const int bi = blockIdx.y;
const int hi = blockIdx.z;
constexpr int X_ELEMS = VEC_16B / sizeof(T);
// [bsz, num_head, seq_len, dim_head/x, x]
auto k_src = reinterpret_cast<const uint4 *>(
k + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head);
// [bsz, num_head, dim_head/x, max_seq_len, x]
auto k_dst = reinterpret_cast<uint4 *>(
cache_k + bi * num_head * max_seq_len * dim_head +
hi * max_seq_len * dim_head);
const int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
// vec size
int dim_head_div_x = dim_head / X_ELEMS;
// FIXME(wangxi): num_head is not need?
// if (out_idx >= num_head * dim_head_div_x * max_seq_len) return;
if (out_idx >= dim_head_div_x * max_seq_len) return;
int idx = out_idx;
const int k_seq_len_id = idx % max_seq_len;
// idx = (idx - k_seq_len_id) / max_seq_len;
idx = idx / max_seq_len;
const int k_vec_id = idx % dim_head_div_x;
if (k_seq_len_id < seq_len) {
k_dst[out_idx] = k_src[k_seq_len_id * dim_head_div_x + k_vec_id];
}
}
template <typename T>
__global__ void write_cache_v_kernel(T *cache_v, const T *v, const int num_head,
const int dim_head, const int seq_len,
const int max_seq_len) {
const int bi = blockIdx.y;
const int hi = blockIdx.z;
// [bsz, num_head, seq_len, dim_head/x, x]
auto v_src = reinterpret_cast<const uint4 *>(
v + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head);
// [bsz, num_head, max_seq_len, dim_head/x, x]
auto v_dst = reinterpret_cast<uint4 *>(
cache_v + bi * num_head * max_seq_len * dim_head +
hi * max_seq_len * dim_head);
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
constexpr int X_ELEMS = VEC_16B / sizeof(T);
const int dim_head_div_x = dim_head / X_ELEMS;
if (idx >= dim_head_div_x * seq_len) return;
v_dst[idx] = v_src[idx];
}
template <typename T>
void write_cache_kv(const platform::CUDADeviceContext &dev_ctx, T *cache_k,
T *cache_v, const T *k, const T *v, const int bsz,
const int num_head, const int seq_len,
const int max_seq_len, const int dim_head) {
constexpr int block_sz = 128;
constexpr int x = VEC_16B / sizeof(T);
assert(dim_head % x == 0);
PADDLE_ENFORCE_EQ(
dim_head % x, 0,
platform::errors::PreconditionNotMet(
"dim_head=%d must be divisible by vec_size=%d", dim_head, x));
int max_size = max_seq_len * dim_head / x;
int size = seq_len * dim_head / x;
dim3 grid(div_up(max_size, block_sz), bsz, num_head);
dim3 grid_v(div_up(size, block_sz), bsz, num_head);
// transpose [bsz, num_head, seq_len, dim_head/x, x]->
// [bsz, num_head, dim_head/x, max_seq_len, x]
write_cache_k_kernel<<<grid, block_sz, 0, dev_ctx.stream()>>>(
cache_k, k, num_head, dim_head, seq_len, max_seq_len);
// copy [bsz, num_head, seq_len, dim_head/x, x]->
// [bsz, num_head, max_seq_len, dim_head/x, x]
write_cache_v_kernel<<<grid_v, block_sz, 0, dev_ctx.stream()>>>(
cache_v, v, num_head, dim_head, seq_len, max_seq_len);
}
} // namespace
template <typename T>
class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
auto place = ctx.GetPlace();
auto &dev_ctx = ctx.cuda_device_context();
auto *time_step = ctx.Input<Tensor>("TimeStep");
// 0. input
auto *input_x = ctx.Input<Tensor>("X");
const auto input_x_dims = input_x->dims();
int bsz = input_x_dims[0];
int seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int bsz_seq = bsz * seq_len;
// 1. layer norm
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
auto ln_scales = ctx.MultiInput<Tensor>("LnScale");
auto ln_biases = ctx.MultiInput<Tensor>("LnBias");
auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
Tensor ln_mean, ln_var;
auto *ln_mean_data = ln_mean.mutable_data<U>({bsz_seq}, place);
auto *ln_var_data = ln_var.mutable_data<U>({bsz_seq}, place);
// 2. qkv
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto qkv_weights = ctx.MultiInput<Tensor>("QKVW");
auto qkv_biases = ctx.MultiInput<Tensor>("QKVBias");
const auto qkv_w_dims = qkv_weights[0]->dims();
int num_head = qkv_w_dims[1];
int dim_head = qkv_w_dims[2];
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
// (transA, transB, compute_bias) = (false, true, false)
auto qkv_compute = AttnMatMul<T>(dev_ctx, false, true, bsz_seq, output_size,
input_size, compute_bias);
Tensor qkv_out;
auto *qkv_out_data =
qkv_out.mutable_data<T>({bsz, seq_len, 3, num_head, dim_head}, place);
// 3. fmha
AttnDropoutParam attn_param(true, "upscale_in_train", 0.0, true, true, 0,
nullptr);
auto fmha_compute =
FMHARef<T>(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param);
auto *src_mask = ctx.Input<Tensor>("SrcMask");
auto cache_kvs = ctx.MultiInput<Tensor>("CacheKV");
auto cache_kv_outs = ctx.MultiOutput<Tensor>("CacheKVOut");
// auto *time_step = ctx.Input<Tensor>("TimeStep");
auto out_seq_len = seq_len;
if (time_step) {
PADDLE_ENFORCE_EQ(time_step->place(), platform::CPUPlace(),
platform::errors::PreconditionNotMet(
"The place of input(TimeStep) must be CPUPlace."));
// cache_seq_len
int time_step_value = time_step->data<int>()[0];
PADDLE_ENFORCE_GT(time_step_value, 0,
platform::errors::PreconditionNotMet(
"The value of time_step must > 0, but now is %d",
time_step_value));
PADDLE_ENFORCE_EQ(
seq_len, 1,
platform::errors::PreconditionNotMet(
"In decode stage, the seq_len of input must be 1, but now is %d",
seq_len));
out_seq_len += time_step_value;
}
Tensor transpose_out_2, qk_out;
auto *transpose_out_2_data = transpose_out_2.mutable_data<T>(
{3, bsz, num_head, seq_len, dim_head}, place);
auto *qk_out_data =
qk_out.mutable_data<T>({bsz, num_head, seq_len, out_seq_len}, place);
Tensor src_mask_out, softmax_out;
Tensor attn_dropout_mask_out, attn_dropout_out;
Tensor qktv_out, fmha_out;
auto *src_mask_out_data = src_mask_out.mutable_data<T>(
{bsz, num_head, seq_len, out_seq_len}, place);
auto *softmax_out_data = softmax_out.mutable_data<T>(
{bsz, num_head, seq_len, out_seq_len}, place);
auto *attn_dropout_mask_out_data = attn_dropout_mask_out.mutable_data<T>(
{bsz, num_head, seq_len, out_seq_len}, place);
auto *attn_dropout_data_data = attn_dropout_out.mutable_data<T>(
{bsz, num_head, seq_len, out_seq_len}, place);
auto *qktv_out_data =
qktv_out.mutable_data<T>({bsz, num_head, seq_len, dim_head}, place);
auto *fmha_out_data =
fmha_out.mutable_data<T>({bsz, seq_len, num_head, dim_head}, place);
// 4. out_linear
auto out_linear_weights = ctx.MultiInput<Tensor>("OutLinearW");
auto out_linear_biases = ctx.MultiInput<Tensor>("OutLinearBias");
int ring_id = ctx.Attr<int>("ring_id");
// (transA, transB, compute_bias) = (false, false, false)
auto out_linear_compute = AttnMatMul<T>(dev_ctx, false, false, bsz_seq,
dim_embed, hidden_size, false);
// 5. ln(residual + bias)
DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon);
auto ffn_ln_scales = ctx.MultiInput<Tensor>("FFNLnScale");
auto ffn_ln_biases = ctx.MultiInput<Tensor>("FFNLnBias");
Tensor bias_dropout_residual_out, dropout_mask_out;
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed},
place);
auto *dropout_mask_out_data = dropout_mask_out.mutable_data<uint8_t>(
{bsz, seq_len, dim_embed}, place);
// 6. ffn matmul1
auto ffn1_weights = ctx.MultiInput<Tensor>("FFN1Weight");
auto ffn1_biases = ctx.MultiInput<Tensor>("FFN1Bias");
auto ffn1_weight_dim = ffn1_weights[0]->dims();
int dim_ffn = ffn1_weight_dim[1];
auto ffn1_linear_compute = AttnMatMul<T>(dev_ctx, false, false, bsz_seq,
dim_ffn, dim_embed, false);
Tensor ffn1_out;
auto *ffn1_out_data = ffn1_out.mutable_data<T>({bsz_seq, dim_ffn}, place);
// 7. ffn act + bias
DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param);
Tensor ffn1_dropout_out, ffn1_dropout_mask;
auto *ffn1_dropout_out_data =
ffn1_dropout_out.mutable_data<T>({bsz_seq, dim_ffn}, place);
auto *ffn1_dropout_mask_data =
ffn1_dropout_mask.mutable_data<uint8_t>({bsz_seq, dim_ffn}, place);
// 8. ffn2 matmul
auto ffn2_weights = ctx.MultiInput<Tensor>("FFN2Weight");
auto ffn2_biases = ctx.MultiInput<Tensor>("FFN2Bias");
auto ffn2_linear_compute = AttnMatMul<T>(dev_ctx, false, false, bsz_seq,
dim_embed, dim_ffn, false);
// 9. ffn2 residual bias
DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
// calc
auto *out = ctx.Output<Tensor>("Out");
auto *from_data = out->mutable_data<T>(place);
Tensor *from_tensor = out;
Tensor tmp_out;
auto *tmp_out_data =
tmp_out.mutable_data<T>({bsz, seq_len, dim_embed}, place);
auto *x_data = input_x->data<T>();
Tensor *buf0 = nullptr;
Tensor *buf1 = nullptr;
// step0: x --> buf1
// step1: buf1 --> buf0
// step2: buf0 --> buf1
int layers = qkv_weights.size();
if (layers & 1) {
// odd, set buf1 as out
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
for (int i = 0; i < layers; ++i) {
// step1. layer_norm
if (i == 0 && pre_layer_norm) {
auto *ln_scale_data = ln_scales[i]->data<U>();
auto *ln_bias_data = ln_biases[i]->data<U>();
// TODO(wangxi): can remove mean var in inference
ln_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
buf1->data<T>(), ln_mean_data, ln_var_data);
} else if (!pre_layer_norm) {
PADDLE_THROW(platform::errors::Unimplemented(
"Unimplemented post_layer_norm for now."));
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step1";
#endif
// step2. qkv
const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr;
// NOTE: in decoder stage, bias is fused in fmha
const Tensor *bias = time_step ? nullptr : qkv_bias;
qkv_compute.ComputeForward(qkv_weights[i], buf1, bias, &qkv_out,
&qkv_out);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step2";
#endif
// step3. fmha
const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr;
Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr;
if (time_step) { // generation decoder stage
// [2, batch_size, num_head, max_seq_len, head_size]
int max_seq_len = cache_kv->dims()[3];
fmha<T>(dev_ctx, qkv_out, *qkv_bias, *src_mask, cache_kv_out, &fmha_out,
bsz, max_seq_len, num_head, dim_head, time_step->data<int>()[0],
1. / sqrt(dim_head));
} else if (cache_kv_out) { // generation context stage
// TODO(wangxi): can remove dropout in inference
fmha_compute.ComputeForward(
qkv_out, nullptr, src_mask, &transpose_out_2, nullptr, &qk_out,
&src_mask_out, &softmax_out, &attn_dropout_mask_out,
&attn_dropout_out, &qktv_out, &fmha_out);
// [3, bsz, num_head, seq_len, head_dim]
T *qkv_data = transpose_out_2_data;
int64_t q_size = bsz * seq_len * num_head * dim_head;
int64_t k_size = q_size;
const T *q_ptr = qkv_data;
const T *k_ptr = q_ptr + q_size;
const T *v_ptr = k_ptr + k_size;
// [2, bsz, num_head, max_seq_len, head_dim]
int max_seq_len = cache_kv_out->dims()[3];
T *cache_kv_data = cache_kv_out->data<T>();
int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head;
T *cache_k_ptr = cache_kv_data;
T *cache_v_ptr = cache_kv_data + cache_k_size;
write_cache_kv<T>(dev_ctx, cache_k_ptr, cache_v_ptr, k_ptr, v_ptr, bsz,
num_head, seq_len, max_seq_len, dim_head);
} else { // not generation
// TODO(wangxi): can remove dropout in inference
fmha_compute.ComputeForward(
qkv_out, cache_kv, src_mask, &transpose_out_2, cache_kv_out,
&qk_out, &src_mask_out, &softmax_out, &attn_dropout_mask_out,
&attn_dropout_out, &qktv_out, &fmha_out);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step3";
#endif
// step4. out_linear
out_linear_compute.ComputeForward(out_linear_weights[i], &fmha_out,
nullptr, buf1, nullptr);
AllReduce<T>(*buf1, ring_id, dev_ctx);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step4";
#endif
// step5. ln(residual + dropout(input + bias))
if (pre_layer_norm) {
auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
// inplace
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx, buf1->data<T>(), x_data, out_linear_bias_data,
ln_scale_data, ln_bias_data, bias_dropout_residual_out_data,
dropout_mask_out_data, buf1->data<T>(), ln_mean_data, ln_var_data);
} else {
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step5";
#endif
// step6. ffn matmul1
ffn1_linear_compute.ComputeForward(ffn1_weights[i], buf1, nullptr,
&ffn1_out, nullptr);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step6";
#endif
// step7. act bias
// TODO(wangxi): remove dropout mask in inference
fused_act_dropout_helper.DropoutActBias(
dev_ctx, ffn1_out_data, ffn1_biases[i]->data<T>(), "gelu",
ffn1_dropout_out_data, ffn1_dropout_mask_data);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step7";
#endif
// step8. ffn matmul2
ffn2_linear_compute.ComputeForward(ffn2_weights[i], &ffn1_dropout_out,
nullptr, buf1, nullptr);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step8.0";
#endif
AllReduce<T>(*buf1, ring_id, dev_ctx);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step8.1";
#endif
// step9. residual bias
if (pre_layer_norm) {
// TODO(wangxi): remove dropout mask in inference
if (i < layers - 1) {
auto *ln_scale_data = ln_scales[i + 1]->data<U>();
auto *ln_bias_data = ln_biases[i + 1]->data<U>();
ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
dev_ctx, buf1->data<T>(), bias_dropout_residual_out_data,
ffn2_biases[i]->data<T>(), ln_scale_data, ln_bias_data,
buf1->data<T>(), dropout_mask_out_data, buf0->data<T>(),
ln_mean_data, ln_var_data);
} else {
ffn2_fused_dropout_helper.ResidualDropoutBias(
dev_ctx, buf1->data<T>(), bias_dropout_residual_out_data,
ffn2_biases[i]->data<T>(), buf1->data<T>(),
dropout_mask_out_data);
}
} else {
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step9";
#endif
x_data = buf1->data<T>();
std::swap(buf0, buf1);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_multi_transformer,
ops::FusedMultiTransformerOpKernel<plat::float16>,
ops::FusedMultiTransformerOpKernel<float>);
...@@ -32,6 +32,10 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -32,6 +32,10 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"fused_attention", {"fused_attention",
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "SrcMask", {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "SrcMask",
"OutLinearW", "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, "OutLinearW", "OutLinearBias", "Ln2Scale", "Ln2Bias"}},
{"fused_multi_transformer",
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep",
"SrcMask", "OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias",
"FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias"}},
{"instance_norm", {"X", "Scale", "Bias"}}, {"instance_norm", {"X", "Scale", "Bias"}},
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}}, {"label_smooth", {"X", "PriorDist"}},
...@@ -176,6 +180,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = { ...@@ -176,6 +180,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"lamb", {"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}}, "MasterParamOut"}},
{"fused_multi_transformer", {"CacheKVOut", "Out"}},
}; };
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
...@@ -253,6 +258,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = { ...@@ -253,6 +258,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"assign_value", {"Out"}}, {"assign_value", {"Out"}},
{"split", {"Out"}}, {"split", {"Out"}},
{"concat", {"Out"}}, {"concat", {"Out"}},
{"fused_multi_transformer", {"CacheKVOut"}},
}; };
// NOTE(pangyoki): Tensor View Strategy. // NOTE(pangyoki): Tensor View Strategy.
......
...@@ -162,6 +162,7 @@ gray_list = { ...@@ -162,6 +162,7 @@ gray_list = {
'split', 'split',
'fused_feedforward', 'fused_feedforward',
'fused_attention', 'fused_attention',
'fused_multi_transformer',
} }
# The set of ops that don't support fp16 calculation # The set of ops that don't support fp16 calculation
......
...@@ -109,6 +109,8 @@ def _keep_fp32_input(op, in_name): ...@@ -109,6 +109,8 @@ def _keep_fp32_input(op, in_name):
return in_name in { return in_name in {
'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" 'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias"
} }
if op_type == 'fused_multi_transformer':
return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'}
return False return False
......
...@@ -25,6 +25,7 @@ list(APPEND DIST_TEST_OPS test_ir_pass_pipeline) ...@@ -25,6 +25,7 @@ list(APPEND DIST_TEST_OPS test_ir_pass_pipeline)
list(APPEND DIST_TEST_OPS test_static_model_parallel) list(APPEND DIST_TEST_OPS test_static_model_parallel)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention)
list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_multi_transformer)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext) list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
...@@ -128,6 +129,7 @@ if(NOT WITH_GPU) ...@@ -128,6 +129,7 @@ if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op) LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
LIST(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer) LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op) LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op) LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
...@@ -1185,6 +1187,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) ...@@ -1185,6 +1187,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_static_model_parallel PROPERTIES TIMEOUT 240) set_tests_properties(test_static_model_parallel PROPERTIES TIMEOUT 240)
set_tests_properties(test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120) set_tests_properties(test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120) set_tests_properties(test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_model_parallel_fused_multi_transformer PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_split_embedding set_tests_properties(test_collective_split_embedding
test_collective_split_embedding_none_divisible test_collective_split_embedding_none_divisible
test_collective_split_row_linear test_collective_split_row_linear
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main
from paddle.incubate.nn import FusedMultiTransformer
import paddle.distributed.fleet as fleet
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import core
from paddle.nn.initializer import Constant
paddle.enable_static()
def get_param_attr(weight, bias):
weight_attr = paddle.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(weight))
bias_attr = paddle.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(bias))
return weight_attr, bias_attr
DTYPE = "float32"
MODEL_PARALLEL_SIZE = 2
num_head = 2 * MODEL_PARALLEL_SIZE
dim_head = 4
hidden = num_head * dim_head
dim_ffn = 4 * hidden
def create_model(data, rank):
np.random.seed(2021)
ln_w = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
ln_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
qkv_w = np.random.uniform(
-1, 1, size=(3, num_head, dim_head, hidden)).astype(DTYPE)
qkv_b = np.random.uniform(-1, 1, size=(3, num_head, dim_head)).astype(DTYPE)
linear_w = np.random.uniform(
-1, 1, size=(num_head * dim_head, hidden)).astype(DTYPE)
linear_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
ffn_ln_w = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
ffn_ln_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
ffn1_w = np.random.uniform(-1, 1, size=(hidden, dim_ffn)).astype(DTYPE)
ffn1_b = np.random.uniform(-1, 1, size=(dim_ffn, )).astype(DTYPE)
ffn2_w = np.random.uniform(-1, 1, size=(dim_ffn, hidden)).astype(DTYPE)
ffn2_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE)
if rank is not None:
start = 0 if rank == 0 else (num_head // MODEL_PARALLEL_SIZE)
end = start + (num_head // MODEL_PARALLEL_SIZE)
col_qkv_w = qkv_w[:, start:end, :, :]
col_qkv_b = qkv_b[:, start:end, :]
row_linear_w = linear_w[(start * dim_head):(end * dim_head), :]
ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b)
qkv_w_attr, qkv_b_attr = get_param_attr(col_qkv_w, col_qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b)
start = 0 if rank == 0 else (dim_ffn // MODEL_PARALLEL_SIZE)
end = start + (dim_ffn // MODEL_PARALLEL_SIZE)
col_ffn1_w = ffn1_w[:, start:end]
col_ffn1_b = ffn1_b[start:end]
row_ffn2_w = ffn2_w[start:end, :]
ffn_ln_w_attr, ffn_ln_b_attr = get_param_attr(ffn_ln_w, ffn_ln_b)
ffn1_w_attr, ffn1_b_attr = get_param_attr(col_ffn1_w, col_ffn1_b)
ffn2_w_attr, ffn2_b_attr = get_param_attr(row_ffn2_w, ffn2_b)
multi_transformer = FusedMultiTransformer(
hidden,
num_head,
dim_ffn,
dropout_rate=0.0,
activation="gelu",
normalize_before=True,
ln_scale_attrs=[ln_w_attr],
ln_bias_attrs=[ln_b_attr],
qkv_weight_attrs=[qkv_w_attr],
qkv_bias_attrs=[qkv_b_attr],
linear_weight_attrs=[linear_w_attr],
linear_bias_attrs=[linear_b_attr],
ffn_ln_scale_attrs=[ffn_ln_w_attr],
ffn_ln_bias_attrs=[ffn_ln_b_attr],
ffn1_weight_attrs=[ffn1_w_attr],
ffn1_bias_attrs=[ffn1_b_attr],
ffn2_weight_attrs=[ffn2_w_attr],
ffn2_bias_attrs=[ffn2_b_attr],
nranks=MODEL_PARALLEL_SIZE,
ring_id=0)
result = multi_transformer(data)
else:
ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b)
qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b)
ffn_ln_w_attr, ffn_ln_b_attr = get_param_attr(ffn_ln_w, ffn_ln_b)
ffn1_w_attr, ffn1_b_attr = get_param_attr(ffn1_w, ffn1_b)
ffn2_w_attr, ffn2_b_attr = get_param_attr(ffn2_w, ffn2_b)
multi_transformer = FusedMultiTransformer(
hidden,
num_head,
dim_ffn,
dropout_rate=0.0,
activation="gelu",
normalize_before=True,
ln_scale_attrs=[ln_w_attr],
ln_bias_attrs=[ln_b_attr],
qkv_weight_attrs=[qkv_w_attr],
qkv_bias_attrs=[qkv_b_attr],
linear_weight_attrs=[linear_w_attr],
linear_bias_attrs=[linear_b_attr],
ffn_ln_scale_attrs=[ffn_ln_w_attr],
ffn_ln_bias_attrs=[ffn_ln_b_attr],
ffn1_weight_attrs=[ffn1_w_attr],
ffn1_bias_attrs=[ffn1_b_attr],
ffn2_weight_attrs=[ffn2_w_attr],
ffn2_bias_attrs=[ffn2_b_attr])
result = multi_transformer(data)
# fused_multi_transformer have no backward
result.stop_gradient = True
predict = paddle.mean(result)
return predict
class TestModelParallel(TestDistRunnerBase):
def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None):
# Input data
seq_len = 2
data_in = fluid.data(
name='data_in', shape=[batch_size, seq_len, hidden], dtype=DTYPE)
if dist_strategy:
data_loader = fluid.io.DataLoader.from_generator(
feed_list=[data_in],
capacity=64,
use_double_buffer=False,
iterable=False)
if dist_strategy:
fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2}
rank = fleet.worker_index() if dist_strategy else None
avg_cost = create_model(data_in, rank)
opt = fluid.optimizer.SGD(0.1)
if dist_strategy:
dist_opt = fleet.distributed_optimizer(
optimizer=opt, strategy=strategy)
dist_opt.minimize(avg_cost)
else:
opt.minimize(avg_cost)
def gen_data():
np.random.seed(2021)
while True:
data = [np.random.random([seq_len, hidden]).astype(DTYPE)]
yield data
train_reader = paddle.batch(gen_data, batch_size=batch_size)
if dist_strategy:
return None, avg_cost, train_reader, None, None, None, data_loader
else:
return None, avg_cost, train_reader, None, None, None
if __name__ == "__main__":
runtime_main(TestModelParallel)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn.initializer import Constant
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.framework import _non_static_mode, default_main_program
from paddle import _C_ops
from paddle.incubate.nn.functional import fused_multi_transformer
default_main_program().random_seed = 42
class TestFusedMultiTransformerOp(OpTest):
def setUp(self):
self.config()
self.generate_input_data()
self.rtol = 1e-5
# FIXME(wangxi): Because there is a problem with the test precision
# on A100, atol is temporarily set to 1e-2, and it will be
# changed back after the precision problem is solved.
self.atol = 1e-2
# make sure local development precision
if "V100" in paddle.device.cuda.get_device_name():
self.atol = 1e-4
if self.x_type is np.float16:
self.atol = 1e-1
paddle.set_default_dtype(self.x_type)
self.__class__.op_type = "fused_multi_transformer"
# use autograd to check grad in this unittest.
self.__class__.no_need_check_grad = False
bias_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.Constant(value=0.0005))
self.q_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=bias_attr)
#bias_attr=self.bias_attr)
self.k_proj = Linear(
self.kdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.v_proj = Linear(
self.vdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.out_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn1_proj = Linear(
self.embed_dim,
4 * self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn2_proj = Linear(
4 * self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
paddle.set_default_dtype(np.float32)
self.norm = LayerNorm(self.embed_dim)
self.ffn_norm = LayerNorm(self.embed_dim)
paddle.set_default_dtype(self.x_type)
self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")
self.activation = getattr(F, self.act_method)
def config(self):
# for debug
self.debug = False
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = True
# has_cache_kv, gen_cache_kv, stage
# False, False, not generation
# True, True, generation context stage
# True, False, generation decoder stage
self.has_cache_kv = False
self.gen_cache_kv = False
self.training = False
self.layers = 4
self.batch_size = 4
self.query_length = 128
self.cache_length = 128
self.head_dim = 64
self.num_heads = 8
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.act_method = 'gelu'
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
out_seq_len = self.key_length
if self.has_cache_kv:
assert self.training is False, ValueError(
'cache_kv can only used in inference')
self.cache_kv = np.random.rand(2, self.batch_size, self.num_heads,
self.cache_length,
self.head_dim).astype(self.x_type)
if self.gen_cache_kv:
self.cache_kv[:] = 0
else:
out_seq_len += self.cache_length
else:
self.cache_kv = None
if self.has_attn_mask:
# [B, n_head, seq_len, out_seq_len]
self.attn_mask = np.ones(
(self.batch_size, 1, self.query_length, out_seq_len),
dtype=self.attn_mask_type)
if self.attn_mask_type == np.int64:
self.attn_mask = np.tril(self.attn_mask)
elif self.attn_mask_type == np.float64:
if self.has_cache_kv and not self.gen_cache_kv:
# NOTE: decoder stage, -1(out_seq_len) should no mask
self.attn_mask[:, :, :, -2] = 0.0
self.attn_mask = (self.attn_mask - 1.0) * 1e4
else:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e4
else:
raise ValueError(
"'attn_mask_type' should be 'int64' or 'float64'.")
else:
self.attn_mask = None
self.key, self.value = self.query, self.query
self.dout = np.random.random((self.batch_size, self.query_length,
self.embed_dim)).astype(self.x_type)
def GetBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
cache_kvs = []
cache_kv = None
if self.has_cache_kv:
cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
attn_mask = None
for i in range(self.layers):
residual = tensor_query
ln1_out = tensor_query
if self.pre_layer_norm:
ln1_out = self.norm(tensor_query)
q = self.q_proj(ln1_out)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(ln1_out)
v = self.v_proj(ln1_out)
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])
if self.has_cache_kv:
# [1, B, n_head, cache_seq_len, head_dim]
cache_k, cache_v = paddle.split(cache_kv, 2)
cache_k = paddle.squeeze(cache_k, axis=0)
cache_v = paddle.squeeze(cache_v, axis=0)
# [B, n_head, cache_seq_len + seq_len, head_dim]
# out_seq_len = cache_seq_len + seq_len
if self.debug:
print('q out is')
print(q_out[0, 0, :, :])
print('cache k out seq=128')
print(k_out[0, 0, :, :])
if self.gen_cache_kv:
cache_kvs.append((k_out, v_out))
else:
k_out = paddle.concat([cache_k, k_out], axis=-2)
v_out = paddle.concat([cache_v, v_out], axis=-2)
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out = layers.matmul(
x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5)
if self.debug:
print('qk out is')
print(qk_out[0][0][0])
if attn_mask is not None:
attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
attn_mask_out = qk_out + attn_mask
if self.debug:
print('attn mask out is')
print(attn_mask_out[0][0][0])
softmax_out = F.softmax(attn_mask_out)
else:
softmax_out = F.softmax(qk_out)
if self.debug:
print('softmax out is')
print(softmax_out[0][0][0])
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train")
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out = tensor.matmul(dropout_out, v_out)
else:
qktv_out = tensor.matmul(softmax_out, v_out)
fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
if self.debug:
print('fmha out is')
print(fmha_out[0][0][0])
out_linear_in = tensor.reshape(
x=fmha_out,
shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
out = self.out_proj(out_linear_in)
residual_out = residual + self.dropout(out)
if not self.pre_layer_norm:
attn_out = self.norm(residual_out)
else:
attn_out = residual_out
ffn_ln_out = attn_out
if self.pre_layer_norm:
ffn_ln_out = self.ffn_norm(attn_out)
ffn1_out = self.ffn1_proj(ffn_ln_out)
ffn1_out = self.dropout(self.activation(ffn1_out))
ffn2_out = self.ffn2_proj(ffn1_out)
residual_out = attn_out + self.dropout(ffn2_out)
final_out = residual_out
if not self.pre_layer_norm:
final_out = self.ffn_norm(residual_out)
tensor_query = final_out
if self.has_cache_kv and self.gen_cache_kv:
return final_out, cache_kvs
return final_out
def GetFusedMultiTransformerOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_proj_weight = paddle.to_tensor(
self.q_proj.weight, stop_gradient=False)
k_proj_weight = paddle.to_tensor(
self.k_proj.weight, stop_gradient=False)
v_proj_weight = paddle.to_tensor(
self.v_proj.weight, stop_gradient=False)
out_linear_weight = paddle.to_tensor(
self.out_proj.weight, stop_gradient=False)
ffn1_weight = paddle.to_tensor(
self.ffn1_proj.weight, stop_gradient=False)
ffn2_weight = paddle.to_tensor(
self.ffn2_proj.weight, stop_gradient=False)
if self.bias_attr is False:
qkv_bias_tensor = None
out_linear_bias = None
else:
q_proj_bias = paddle.to_tensor(
self.q_proj.bias, stop_gradient=False)
k_proj_bias = paddle.to_tensor(
self.k_proj.bias, stop_gradient=False)
v_proj_bias = paddle.to_tensor(
self.v_proj.bias, stop_gradient=False)
qkv_bias = np.concatenate(
(q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
out_linear_bias = paddle.to_tensor(
self.out_proj.bias, stop_gradient=False)
ffn1_bias = paddle.to_tensor(
self.ffn1_proj.bias, stop_gradient=False)
ffn2_bias = paddle.to_tensor(
self.ffn2_proj.bias, stop_gradient=False)
ln_scale = paddle.to_tensor(self.norm.weight, stop_gradient=False)
ln_bias = paddle.to_tensor(self.norm.bias, stop_gradient=False)
ffn_ln_scale = paddle.to_tensor(
self.ffn_norm.weight, stop_gradient=False)
ffn_ln_bias = paddle.to_tensor(self.ffn_norm.bias, stop_gradient=False)
q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
qkv_weight = np.concatenate(
(q_proj_weight, k_proj_weight, v_proj_weight))
qkv_weight = qkv_weight.reshape(
(3, self.num_heads, self.head_dim, self.embed_dim))
x = paddle.to_tensor(self.query, stop_gradient=False)
cache_kvs, cache_kv = None, None
time_step = None
if self.has_cache_kv:
cache_kvs = []
max_seq_length = (self.cache_length + 128) // 128 * 128
cache_kv = np.zeros(
[
2, self.batch_size, self.num_heads, max_seq_length,
self.head_dim
],
dtype=self.x_type)
elems = 4
if self.x_type is np.float16:
elems = 8
assert self.head_dim % elems == 0
v_elems = self.head_dim // elems
# [B, num_head, 128, head_dim]
# cache_k_tmp = self.cache_kv[0, :]
# [B, num_head, 128, head_dim / 4, 4]
cache_k_tmp = self.cache_kv[0].reshape([
self.batch_size, self.num_heads, self.cache_length, v_elems,
elems
])
# [B, num_head, head_dim / 4, 128, 4]
cache_k_tmp = cache_k_tmp.transpose([0, 1, 3, 2, 4])
cache_kv[0, :].reshape([
self.batch_size, self.num_heads, v_elems, max_seq_length, elems
])[:, :, :, :self.cache_length, :] = cache_k_tmp
cache_kv[1, :, :, :self.cache_length, :] = self.cache_kv[1]
if self.gen_cache_kv:
assert self.query_length == self.cache_length
cache_kv[:] = 0
else:
time_step = paddle.to_tensor(
[self.cache_length], dtype='int32', place=paddle.CPUPlace())
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
attn_mask = None
qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
epsilon = 1e-05
ln2_epsilon = 1e-05
if attn_mask is not None:
attn_mask = _convert_attention_mask(attn_mask, x.dtype)
qkv_weights, qkv_biases = [], []
out_weights, out_biases = [], []
ln_scales, ln_biases = [], []
ffn1_weights, ffn1_biases = [], []
ffn2_weights, ffn2_biases = [], []
ffn_ln_scales, ffn_ln_biases = [], []
for i in range(self.layers):
qkv_weights.append(qkv_weight_tensor)
qkv_biases.append(qkv_bias_tensor)
out_weights.append(out_linear_weight)
out_biases.append(out_linear_bias)
ln_scales.append(ln_scale)
ln_biases.append(ln_bias)
ffn1_weights.append(ffn1_weight)
ffn1_biases.append(ffn1_bias)
ffn2_weights.append(ffn2_weight)
ffn2_biases.append(ffn2_bias)
ffn_ln_scales.append(ffn_ln_scale)
ffn_ln_biases.append(ffn_ln_bias)
if self.has_cache_kv:
cache_kvs.append(
paddle.to_tensor(
cache_kv, stop_gradient=False))
final_out = fused_multi_transformer(
x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
out_weights,
out_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
pre_layer_norm=self.pre_layer_norm,
epsilon=epsilon,
cache_kvs=cache_kvs,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_prob,
training=self.training)
if self.has_cache_kv:
return final_out[0], final_out[1]
return final_out
def test_fused_multi_transformer_op(self):
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOut()
if self.has_cache_kv:
final_out, cache_kv_out = final_out
s = cache_kv_out[0].shape
bsz = s[1]
num_head = s[2]
max_seq_len = s[3]
head_dim = s[4]
elems = 8 if self.x_type is np.float16 else 4
v_elems = head_dim // elems
if self.debug:
print("cache_k out timestep=128")
print(cache_kv_out[0].reshape([
2, bsz, num_head, v_elems, max_seq_len, elems
])[0, 0, 0, :, self.cache_length, :])
print("cache_v out timestep=128")
print(cache_kv_out[0][1, 0, 0, self.cache_length, :])
if self.gen_cache_kv:
final_out_ref, cache_kvs = final_out_ref
for i in range(self.layers):
cache_k_ref = cache_kvs[i][0]
cache_v_ref = cache_kvs[i][1]
cache_k = cache_kv_out[i][0, :]
cache_k = cache_k.reshape(
[bsz, num_head, v_elems, max_seq_len, elems])
cache_k = cache_k[:, :, :, :self.cache_length, :]
cache_k = cache_k.transpose([0, 1, 3, 2, 4])
cache_k = cache_k.reshape(
[bsz, num_head, self.cache_length, head_dim])
cache_v = cache_kv_out[i][1, :, :, :self.cache_length, :]
np.testing.assert_allclose(
cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(
cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol)
if i == 0:
break
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol)
class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
self.layers = 3 # odd layers
class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.layers = 3 # odd layers
class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.x_type = np.float16
class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.x_type = np.float16
self.layers = 3 # odd layers
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_dist_base import TestDistBase
import os
import paddle
paddle.enable_static()
flag_name = os.path.splitext(__file__)[0]
class TestStaticModelParallel(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl_comm_num = 1
self._pipeline_mode = True
def test_dist_static_model_parallel_fused_multi_transformer(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"static_model_parallel_fused_multi_transformer.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == '__main__':
unittest.main()
...@@ -15,10 +15,11 @@ ...@@ -15,10 +15,11 @@
from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401
from .layer.fused_transformer import FusedFeedForward # noqa: F401 from .layer.fused_transformer import FusedFeedForward # noqa: F401
from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401 from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401
from .layer.fused_transformer import FusedMultiTransformer # noqa: F401
__all__ = [ #noqa __all__ = [ #noqa
'FusedMultiHeadAttention', 'FusedMultiHeadAttention',
'FusedFeedForward', 'FusedFeedForward',
'FusedTransformerEncoderLayer', 'FusedTransformerEncoderLayer',
'FusedMultiTransformer',
] ]
...@@ -14,5 +14,10 @@ ...@@ -14,5 +14,10 @@
from .fused_transformer import fused_multi_head_attention from .fused_transformer import fused_multi_head_attention
from .fused_transformer import fused_feedforward from .fused_transformer import fused_feedforward
from .fused_transformer import fused_multi_transformer
__all__ = ['fused_multi_head_attention', 'fused_feedforward'] __all__ = [
'fused_multi_head_attention',
'fused_feedforward',
'fused_multi_transformer',
]
...@@ -488,3 +488,238 @@ def fused_multi_head_attention(x, ...@@ -488,3 +488,238 @@ def fused_multi_head_attention(x,
attrs=attrs) attrs=attrs)
return (final_out, cache_kv_out) if cache_kv else final_out return (final_out, cache_kv_out) if cache_kv else final_out
def fused_multi_transformer(x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
linear_weights,
linear_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
pre_layer_norm=True,
epsilon=1e-05,
cache_kvs=None,
time_step=None,
attn_mask=None,
dropout_rate=0.0,
activation="gelu",
training=False,
mode='upscale_in_train',
ring_id=-1,
name=None):
r"""
This is a fusion operator to compute multi transformer layers in transformer model architecture.
This operator only supports running on GPU. The function of the transformer layer is consistent
with the following pseudo code:
.. code-block:: python
if pre_layer_norm:
out = layer_norm(x)
out = qkv_linear(out) + qkv_bias
else:
out = qkv_linear(x) + qkv_bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
# extract q, k and v from out.
q = out[0:1, ::]
k = out[1:2, ::]
v = out[2:3, ::]
out = q * k^t
out = attn_mask + out
out = softmax(out)
out = dropout(out)
out = out * v
out = transpose(out, perm=[0, 2, 1, 3])
out = linear(out)
if pre_layer_norm:
out = x + dropout(out + bias)
else:
out = layer_norm(x + dropout(out + bias))
residual = out;
if pre_layer_norm:
out = ffn_layer_norm(out)
out = ffn1_linear(out)
out = dropout(activation(out + ffn1_bias))
out = ffn2_linear(out)
out = residual + dropout(out + ffn2_bias)
if not pre_layer_norm:
out = ffn_layer_norm(out)
Args:
x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16 or float32, the shape is `[batch\_size, sequence\_length, d\_model]`.
ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of attention layer_norm, the shape is `[d\_model]`.
ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of attention layer_norm. the shape is `[d\_model]`.
qkv_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head, d\_model]`.
qkv_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head]`.
linear_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention linear. The shape is `[num\_head * dim\_head, d\_model]`.
linear_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention linear. The shape is `[d\_model]`.
ffn_ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward layer_norm, the shape is `[d\_model]`
ffn_ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of feedforward layer_norm, the shape is `[d\_model]`
ffn1_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward first linear, the shape is `[d\_model, dim\_feedforward]`.
ffn1_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward first linear, the shape is `[dim\_feedforward]`.
ffn2_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward second linear, the shape is `[dim\_feedforward, d\_model]`.
ffn2_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward second linear, the shape is `[d_model]`.
pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True.
epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None.
dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0.
activation (str, optional): The activation. Default "gelu".
training (bool, optional): A flag indicating whether it is in train phrase or not. Default False.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - p )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using mp.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor|tuple: If `cache_kvs` is None, return a tensor that has
the same shape and data type with `x`, representing the output
of Transformer layers. If `cache_kvs` is not None, return the
tuple (output, cache_kvs), which output is the output of
Transformer layers, cache_kvs is inplace with input `cache_kvs`.
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.incubate.nn.functional as F
import numpy as np
# input: [batch_size, seq_len, embed_dim]
x = paddle.rand(shape=(2, 4, 128), dtype="float32")
# ln_scale: [embed_dim], ln_bias: [embed_dim]
ln_scale = paddle.rand(shape=(128,), dtype="float32")
ln_bias = paddle.rand(shape=(128,), dtype="float32")
# qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim]
qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
# linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim]
linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
linear_bias = paddle.rand(shape=(128,), dtype="float32")
# ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim]
ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32")
ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32")
# ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim]
ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32")
ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32")
# ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim]
ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32")
ffn2_bias = paddle.rand(shape=(128,), dtype="float32")
# self attention mask: [batch_size, 1, seq_len, seq_len]
attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32")
# output: [batch_size, seq_len, embed_dim]
output = F.fused_multi_transformer(
x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias],
[linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias],
[ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias],
attn_mask=attn_mask)
# [2, 4, 128]
print(output.shape)
"""
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer
if _non_static_mode():
cache_kv_out, final_out = _C_ops.fused_multi_transformer(
x, ln_scales, ln_biases, qkv_weights, qkv_biases, cache_kvs,
time_step, attn_mask, linear_weights, linear_biases, ffn_ln_scales,
ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases,
cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon,
'dropout_rate', dropout_rate, 'dropout_is_test', not training,
'dropout_implementation', mode, 'act_method', activation, 'ring_id',
ring_id)
if cache_kvs is not None:
return final_out, cache_kv_out
return final_out
else:
helper = LayerHelper('fused_multi_transformer', **locals())
dtype = x.dtype
# check dtypes
check_variable_and_dtype(x, 'x', ['float16', 'float32'],
'fused_multi_transformer')
check_dtype(dtype, 'dtype', ['float16', 'float32'],
'fused_multi_transformer')
# set inputs
inputs = dict()
inputs['X'] = [x]
inputs['LnScale'] = ln_scales
inputs['LnBias'] = ln_biases
inputs['QKVW'] = qkv_weights
if qkv_biases is not None:
inputs['QKVBias'] = qkv_biases
if cache_kvs is not None:
assert len(cache_kvs) == len(qkv_weights)
inputs['CacheKV'] = cache_kvs
if time_step is not None:
inputs['TimeStep'] = time_step
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = linear_weights
if linear_biases is not None:
inputs['OutLinearBias'] = linear_biases
inputs['FFNLnScale'] = ffn_ln_scales
inputs['FFNLnBias'] = ffn_ln_biases
inputs['FFN1Weight'] = ffn1_weights
if ffn1_biases is not None:
inputs['FFN1Bias'] = ffn1_biases
inputs['FFN2Weight'] = ffn2_weights
if ffn2_biases is not None:
inputs['FFN2Bias'] = ffn2_biases
# set attrs
attrs = {
'pre_layer_norm': pre_layer_norm,
'epsilon': epsilon,
'dropout_rate': dropout_rate,
'dropout_is_test': not training,
'dropout_implementation': mode,
'act_method': activation,
'ring_id': ring_id
}
outputs = dict()
final_out = helper.create_variable_for_type_inference(dtype=dtype)
outputs['Out'] = final_out
if cache_kvs:
# NOTE: inplace
outputs['CacheKVOut'] = cache_kvs
helper.append_op(
type='fused_multi_transformer',
inputs=inputs,
outputs=outputs,
attrs=attrs)
return (final_out, cache_kvs) if cache_kvs else final_out
...@@ -22,6 +22,20 @@ from paddle.nn.initializer import Constant ...@@ -22,6 +22,20 @@ from paddle.nn.initializer import Constant
import collections import collections
# for distributed tensor model parallel
def _set_var_distributed(var):
if var is None:
return
var.is_distributed = True
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block = paddle.static.default_startup_program().current_block()
main_block = paddle.static.default_main_program().current_block()
startup_block._find_var_recursive(var.name).is_distributed = True
main_block._find_var_recursive(var.name).is_distributed = True
class FusedMultiHeadAttention(Layer): class FusedMultiHeadAttention(Layer):
""" """
Attention mapps queries and a set of key-value pairs to outputs, and Attention mapps queries and a set of key-value pairs to outputs, and
...@@ -608,3 +622,390 @@ class FusedTransformer(Layer): ...@@ -608,3 +622,390 @@ class FusedTransformer(Layer):
def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None): def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
raise NotImplementedError() raise NotImplementedError()
class FusedMultiTransformer(Layer):
"""
FusedMultiTransformer is composed of multi transformer layers which contains two
sub-layers which are self (multi-head) attention and feedforward network. The
function of one transformer layer is consistent with the following pseudo code:
.. code-block:: python
if pre_layer_norm:
out = layer_norm(x)
out = qkv_linear(out) + qkv_bias
else:
out = qkv_linear(x) + qkv_bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
# extract q, k and v from out.
q = out[0:1, ::]
k = out[1:2, ::]
v = out[2:3, ::]
out = q * k^t
out = attn_mask + out
out = softmax(out)
out = dropout(out)
out = out * v
out = transpose(out, perm=[0, 2, 1, 3])
out = linear(out)
if pre_layer_norm:
out = x + dropout(out + bias)
else:
out = layer_norm(x + dropout(out + bias))
residual = out;
if pre_layer_norm:
out = ffn_layer_norm(out)
out = ffn1_linear(out)
out = dropout(activation(out + ffn1_bias))
out = ffn2_linear(out)
out = residual + dropout(out + ffn2_bias)
if not pre_layer_norm:
out = ffn_layer_norm(out)
Parameters:
embed_dim (int): The expected feature size in the input and output.
num_heads (int): The number of heads in multi-head attention(MHA).
dim_feedforward (int): The hidden layer size in the feedforward network(FFN).
dropout_rate (float, optional): The dropout probability used in pre-process
and post-precess of MHA and FFN sub-layer. Default 0.0
activation (str, optional): The activation function in the feedforward
network. Default "gelu".
normalize_before (bool, optional): Indicate whether to put layer normalization
into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer
normalization and post-precess includes dropout, residual connection.
Otherwise, no pre-process and post-precess includes dropout, residual
connection, layer normalization. Default True
ln_scale_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
for Attention layer_norm. For Attention layer_norm weight, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
for Attention layer_norm. For Attention layer_norm bias, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
qkv_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
for Attention qkv computation. For Attention qkv weight, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
qkv_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
for Attention qkv computation. For Attention qkv bias, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
linear_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
for Attention linear. For Attention linear weight, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
linear_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
for Attention linear computation. For Attention linear bias, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
ffn_ln_scale_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
for FFN layer_norm. For FFN layer_norm weight, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
ffn_ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
for FFN layer_norm. For FFN layer_norm bias, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
ffn1_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
for FFN first linear. For FFN first linear weight, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
ffn1_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
for FFN first linear. For FFN first linear bias, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
ffn2_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
for FFN second linear. For FFN second linear weight, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. Default: None, which means the default weight
parameter property is used. See usage for details in :code:`ParamAttr`.
ffn2_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
for FFN second linear. For FFN second linear bias, if it is a list/tuple, `attrs[0]`
would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
`attr` for transformer layer 1,etc. Otherwise, all layers both use it as
`attr` to create parameters. The `False` value means the corresponding layer would
not have trainable bias parameter. Default: None, which means the default bias
parameter property is used. See usage for details in :code:`ParamAttr`.
epsilon (float, optional): Small float value added to denominator of the layer_norm to
avoid dividing by zero. Default: 1e-05.
num_layers (int, optional): The number of layers of the transformer. If `qkv_weight_attrs`
is a list or tuple, the number of layers is obtained from `qkv_weight_attrs`. num_layers
only takes effect when `qkv_weight_attrs` is not a list or tuple. Default: -1.
nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using mp.
ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using mp.
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn import FusedMultiTransformer
# encoder input: [batch_size, src_len, d_model]
enc_input = paddle.rand((2, 4, 128))
# self attention mask: [batch_size, 1, src_len, src_len]
attn_mask = paddle.rand((2, 1, 4, 4))
encoder_layers = FusedMultiTransformer(128, 2, 512, num_layers=1)
enc_output = encoder_layers(enc_input, attn_mask) # [2, 4, 128]
"""
def __init__(self,
embed_dim,
num_heads,
dim_feedforward,
dropout_rate=0.0,
activation="gelu",
normalize_before=True,
ln_scale_attrs=None,
ln_bias_attrs=None,
qkv_weight_attrs=None,
qkv_bias_attrs=None,
linear_weight_attrs=None,
linear_bias_attrs=None,
ffn_ln_scale_attrs=None,
ffn_ln_bias_attrs=None,
ffn1_weight_attrs=None,
ffn1_bias_attrs=None,
ffn2_weight_attrs=None,
ffn2_bias_attrs=None,
epsilon=1e-5,
num_layers=-1,
nranks=1,
ring_id=-1,
name=None):
super(FusedMultiTransformer, self).__init__()
assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
"but recieved {}".format(embed_dim))
assert num_heads > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(num_heads))
assert dim_feedforward > 0, (
"Expected dim_feedforward to be greater than 0, but recieved {}".
format(dim_feedforward))
self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype()
self._epsilon = epsilon
self._ring_id = ring_id
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
# tensor model parallel
if nranks > 1:
assert ring_id != -1
assert num_heads % nranks == 0
assert dim_feedforward % nranks == 0
num_heads = num_heads // nranks
dim_feedforward = dim_feedforward // nranks
self._dim_feedforward = dim_feedforward
if isinstance(qkv_weight_attrs, (list, tuple)):
num_layers = len(qkv_weight_attrs)
assert num_layers > 0
self.ln_scales, self.ln_biases = [], []
self.qkv_weights, self.qkv_biases = [], []
self.linear_weights, self.linear_biases = [], []
self.ffn_ln_scales, self.ffn_ln_biases = [], []
self.ffn1_weights, self.ffn1_biases = [], []
self.ffn2_weights, self.ffn2_biases = [], []
def get_attr(attrs, idx):
if isinstance(attrs, (list, tuple)):
assert len(attrs) == num_layers
return attrs[idx]
return attrs
for i in range(num_layers):
ln_scale_attr = get_attr(ln_scale_attrs, i)
ln_bias_attr = get_attr(ln_bias_attrs, i)
qkv_weight_attr = get_attr(qkv_weight_attrs, i)
qkv_bias_attr = get_attr(qkv_bias_attrs, i)
linear_weight_attr = get_attr(linear_weight_attrs, i)
linear_bias_attr = get_attr(linear_bias_attrs, i)
ffn_ln_scale_attr = get_attr(ffn_ln_scale_attrs, i)
ffn_ln_bias_attr = get_attr(ffn_ln_bias_attrs, i)
ffn1_weight_attr = get_attr(ffn1_weight_attrs, i)
ffn1_bias_attr = get_attr(ffn1_bias_attrs, i)
ffn2_weight_attr = get_attr(ffn2_weight_attrs, i)
ffn2_bias_attr = get_attr(ffn2_bias_attrs, i)
ln_scale = self.create_parameter(
attr=ln_scale_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
ln_bias = self.create_parameter(
attr=ln_bias_attr, shape=[embed_dim], is_bias=True)
qkv_weight = self.create_parameter(
shape=[3, num_heads, self.head_dim, embed_dim],
attr=qkv_weight_attr,
dtype=self._dtype,
is_bias=False)
qkv_bias = self.create_parameter(
shape=[3, num_heads, self.head_dim],
attr=qkv_bias_attr,
dtype=self._dtype,
is_bias=True)
linear_weight = self.create_parameter(
shape=[num_heads * self.head_dim, embed_dim],
attr=linear_weight_attr,
dtype=self._dtype,
is_bias=False)
linear_bias = self.create_parameter(
shape=[embed_dim],
attr=linear_bias_attr,
dtype=self._dtype,
is_bias=True)
ffn_ln_scale = self.create_parameter(
shape=[embed_dim],
attr=ffn_ln_scale_attr,
is_bias=False,
default_initializer=Constant(1.0))
ffn_ln_bias = self.create_parameter(
shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True)
ffn1_weight = self.create_parameter(
shape=[embed_dim, dim_feedforward],
attr=ffn1_weight_attr,
dtype=self._dtype,
is_bias=False)
ffn1_bias = self.create_parameter(
shape=[dim_feedforward],
attr=ffn1_bias_attr,
dtype=self._dtype,
is_bias=True)
ffn2_weight = self.create_parameter(
shape=[dim_feedforward, embed_dim],
attr=ffn2_weight_attr,
dtype=self._dtype,
is_bias=False)
ffn2_bias = self.create_parameter(
shape=[embed_dim],
attr=ffn2_bias_attr,
dtype=self._dtype,
is_bias=True)
# tensor model parallel
if nranks > 1:
# column parallel
_set_var_distributed(qkv_weight)
_set_var_distributed(qkv_bias)
_set_var_distributed(ffn1_weight)
_set_var_distributed(ffn1_bias)
# row parallel
_set_var_distributed(linear_weight)
_set_var_distributed(ffn2_weight)
self.ln_scales.append(ln_scale)
self.ln_biases.append(ln_bias)
self.qkv_weights.append(qkv_weight)
self.qkv_biases.append(qkv_bias)
self.linear_weights.append(linear_weight)
self.linear_biases.append(linear_bias)
self.ffn_ln_scales.append(ffn_ln_scale)
self.ffn_ln_biases.append(ffn_ln_bias)
self.ffn1_weights.append(ffn1_weight)
self.ffn1_biases.append(ffn1_bias)
self.ffn2_weights.append(ffn2_weight)
self.ffn2_biases.append(ffn2_bias)
self.dropout_rate = dropout_rate
self.activation = activation
self.name = name
def forward(self, src, attn_mask=None, caches=None, time_step=None):
"""
Applies multi transformer layers on the input.
Parameters:
src (Tensor): The input of Transformer layers. It is
a tensor with shape `[batch_size, sequence_length, d_model]`.
The data type should be float16 or float32.
attn_mask (Tensor, optional): A tensor used in multi-head attention
to prevents attention to some unwanted positions, usually the
paddings or the subsequent positions. It is a tensor with shape
`[batch_size, 1, sequence_length, sequence_length]`. It can be
None when nothing wanted or needed to be prevented attention to.
Default None.
caches (list(Tensor)|tuple(Tensor), optional): The cache structure
tensors for the inference generation model. It is only used for
inference and should be None for training. The shape is
`[2, batch_size, num_head, max_seq_len, head_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation
model. Which used in decode stage, to represent the time step,
that is, the real seq_len of CacheKV. The shape is `[1]`, must be
in CPUPlace. Default None.
Returns:
Tensor|tuple: If `caches` is None, return a tensor that has
the same shape and data type with `src`, representing the output
of Transformer layers. If `caches` is not None, return the
tuple (output, caches), which output is the output of
Transformer layers, caches is inplace with input `caches`.
"""
if caches is not None:
assert len(caches) == len(self.qkv_weights)
out = incubate_f.fused_multi_transformer(
src,
self.ln_scales,
self.ln_biases,
self.qkv_weights,
self.qkv_biases,
self.linear_weights,
self.linear_biases,
self.ffn_ln_scales,
self.ffn_ln_biases,
self.ffn1_weights,
self.ffn1_biases,
self.ffn2_weights,
self.ffn2_biases,
pre_layer_norm=self.normalize_before,
epsilon=self._epsilon,
cache_kvs=caches,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
activation=self.activation,
training=self.training,
mode='upscale_in_train',
ring_id=self._ring_id,
name=self.name)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册