From 9dadf7df2467475690422a7f2eb12da820a2f0fc Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 26 Apr 2022 19:43:56 +0800 Subject: [PATCH] Add fused_multi_transformer op to optimize transformer generation performance (#41814) --- paddle/fluid/operators/fused/CMakeLists.txt | 2 + .../fused/fused_multi_transformer_op.cc | 259 ++++ .../fused/fused_multi_transformer_op.cu | 1338 +++++++++++++++++ paddle/fluid/pybind/op_function_generator.h | 6 + .../contrib/mixed_precision/fp16_lists.py | 1 + .../contrib/mixed_precision/fp16_utils.py | 2 + .../fluid/tests/unittests/CMakeLists.txt | 3 + ..._model_parallel_fused_multi_transformer.py | 193 +++ .../test_fused_multi_transformer_op.py | 542 +++++++ ..._model_parallel_fused_multi_transformer.py | 45 + python/paddle/incubate/nn/__init__.py | 3 +- .../paddle/incubate/nn/functional/__init__.py | 7 +- .../nn/functional/fused_transformer.py | 235 +++ .../incubate/nn/layer/fused_transformer.py | 401 +++++ 14 files changed, 3035 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_op.cc create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_op.cu create mode 100644 python/paddle/fluid/tests/unittests/static_model_parallel_fused_multi_transformer.py create mode 100644 python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_multi_transformer.py diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 80e7f5c001..68b9051d85 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -19,6 +19,7 @@ register_operators(EXCLUDES fused_attention_op fused_transformer_op fused_feedforward_op + fused_multi_transformer_op resnet_unit_op fused_gemm_epilogue_op) @@ -73,6 +74,7 @@ if (WITH_GPU OR WITH_ROCM) op_library(fused_feedforward_op) # fused_attention_op op_library(fused_attention_op) + op_library(fused_multi_transformer_op) endif() # resnet_unit needs cudnn 8.0 above if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000)) diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc new file mode 100644 index 0000000000..c95ca6fe0c --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -0,0 +1,259 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class FusedMultiTransformerOp : public framework::OperatorWithKernel { + private: + static constexpr const char *OpName = "FusedMultiTransformerOp"; + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { +#define CHECK_INPUT(name) \ + OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName) +#define CHECK_INPUTS(name) \ + OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName) +#define CHECK_OUTPUT(name) \ + OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName) +#define CHECK_OUTPUTS(name) \ + OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName) + + CHECK_INPUT(X); + + // attention + CHECK_INPUTS(QKVW); + CHECK_INPUTS(OutLinearW); + + if (ctx->HasInput("TimeStep")) { + CHECK_INPUTS(CacheKV); + } + + if (ctx->HasInputs("CacheKV")) { + CHECK_OUTPUTS(CacheKVOut); + } + + // ffn + CHECK_INPUTS(FFN1Weight); + CHECK_INPUTS(FFN2Weight); + + CHECK_OUTPUT(Out); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputsDim("QKVW")[0]; + PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument( + "The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ(y_dim.size(), 4, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3]" + "must be equal. But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, y_dim)); + + if (ctx->Attrs().Get("ring_id") == -1) { + PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3], + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "and must satisfy the limitations: " + "(num_head * dim_head == dim_embed)")); + } + + if (ctx->HasInputs("CacheKV")) { + // [2, batch_size, num_head, max_seq_len, head_size] + const auto &c_dims = ctx->GetInputsDim("CacheKV"); + const auto &c_dim = c_dims[0]; + + PADDLE_ENFORCE_EQ( + c_dim.size(), 5, + paddle::platform::errors::InvalidArgument( + "The CacheKV must be 5 dims, but got %d", c_dim.size())); + PADDLE_ENFORCE_EQ(c_dim[0], 2, + paddle::platform::errors::InvalidArgument( + "The first dim of CacheKV must be 2, but got %d", + c_dim[0])); // 2 + PADDLE_ENFORCE_EQ(c_dim[1], x_dim[0], + paddle::platform::errors::InvalidArgument( + "The second dim of CacheKV must be equal with " + "batch size %d, but got %d", + x_dim[0], c_dim[1])); // batch_size + PADDLE_ENFORCE_EQ(c_dim[2], y_dim[1], + paddle::platform::errors::InvalidArgument( + "The third dim of CacheKV must be equal with num " + "head %d, but got %d", + y_dim[1], c_dim[2])); // num_head + PADDLE_ENFORCE_GT( + c_dim[3], 0, + paddle::platform::errors::InvalidArgument( + "The forth dim of CacheKV must be greater than 0, but got %d", + c_dim[3])); // cache_seq_len + PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2], + paddle::platform::errors::InvalidArgument( + "The fifth dim of CacheKV must be equal with head " + "size %d, but got %d", + y_dim[2], c_dim[4])); // head_size + } + + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "TimeStep") { + VLOG(10) << "var_name:" << var_name << " need not to transform"; + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + +class FusedMultiTransformerOpOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("LnScale", + "Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("LnBias", + "Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); + AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); + AddInput("CacheKV", "(optional) The cached KV for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("TimeStep", + "(optional, int) The time step for generation inference.") + .AsDispensable(); + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") + .AsDispensable(); + AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); + AddInput("OutLinearBias", "The out_linear bias tensor.") + .AsDispensable() + .AsDuplicable(); + + AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op") + .AsDispensable() + .AsDuplicable(); + AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op") + .AsDispensable() + .AsDuplicable(); + + AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") + .AsDispensable() + .AsDuplicable(); + AddOutput("Out", "Result after multi ."); + + AddAttr("pre_layer_norm", + "if true, the attention op uses pre_layer_norm architecure, " + "else, uses post_layer_norm architecuture. " + "[default true].") + .SetDefault(true); + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true, + platform::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + }); + + AddAttr("dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true, + platform::errors::InvalidArgument( + "'dropout_rate' must be between 0.0 and 1.0.")); + }); + + AddAttr("dropout_is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "The meaning is the same as 'attn_dropout_implementation'.") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("act_method", "act_method").SetDefault("gelu"); + + AddAttr( + "ring_id", + "ring id for tensor model parallel. distributed training and inference") + .SetDefault(-1); + + AddComment(R"DOC(fused multi transformer layers op)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + fused_multi_transformer, ops::FusedMultiTransformerOp, + ops::FusedMultiTransformerOpOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu new file mode 100644 index 0000000000..f4a5319a68 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -0,0 +1,1338 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +// This file has been adapted from FasterTransformer file: +// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu +// We add License in the head. + +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" + +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/fluid/operators/fused/attention_layer_norm.h" +#include "paddle/fluid/operators/fused/attn_gemm.h" +#include "paddle/fluid/operators/fused/fmha_ref.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +// for debug +// #define _DEBUG_FUSED_MULTI_TRANSFORMER + +template +static void AllReduce(framework::Tensor &tensor, // NOLINT + const int ring_id, + const platform::CUDADeviceContext &ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(tensor.dtype())); + int64_t numel = tensor.numel(); + const void *sendbuff = tensor.data(); + auto place = ctx.GetPlace(); + void *recvbuff = tensor.mutable_data(place); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + auto stream = ctx.stream(); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream)); +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +namespace { + +namespace plat = paddle::platform; +using float16 = plat::float16; + +#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#define MMHA_USE_FP32_ACUM_FOR_OUT + +template +struct Masked_multihead_attention_params { + // output buffer, [B, 1(seq_len), num_head * dim_head] + T *out; + // qkv_out, [B, 1(seq_len), 3, num_head * dim_head] + const T *qkv; + // bias, [3, num_head, dim_head] + const T *qkv_bias; + // TODO(wangxi): optimize with input_lengths and max_input_len? + // [bsz, 1, 1, time_step(cache_seq_length)+1] + const T *attn_mask; + + // [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head] + // k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first + // v [B, num_head, max_seq_len, dim_head] + T *cache_kv; + + int batch_size; + int num_head; + int timestep; // cache_seq_length + int max_seq_length; + + // 1.f / sqrt(Dh) + float inv_sqrt_dh; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// clang-format off + +template struct Qk_vec_ {}; +template <> struct Qk_vec_ { using Type = float; }; +template <> struct Qk_vec_ { using Type = float2; }; +template <> struct Qk_vec_ { using Type = float4; }; +template <> struct Qk_vec_ { using Type = uint32_t; }; +template <> struct Qk_vec_ { using Type = uint32_t; }; +template <> struct Qk_vec_ { using Type = uint2; }; + +template struct K_vec_ {}; +template <> struct K_vec_ { using Type = float; }; +template <> struct K_vec_ { using Type = float2; }; +template <> struct K_vec_ { using Type = float4; }; +template <> struct K_vec_ { using Type = uint32_t; }; +template <> struct K_vec_ { using Type = uint2; }; +template <> struct K_vec_ { using Type = uint4; }; + +template struct V_vec_ {}; +template <> struct V_vec_ { using Type = float; }; +template <> struct V_vec_ { using Type = float2; }; +template <> struct V_vec_ { using Type = float4; }; +template <> struct V_vec_ { using Type = uint32_t; }; +template <> struct V_vec_ { using Type = uint2; }; +template <> struct V_vec_ { using Type = uint4; }; + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template struct V_vec_acum_fp32_ {}; +// template <> struct V_vec_acum_fp32_ { using Type = float; }; +// template <> struct V_vec_acum_fp32_ { using Type = float2; }; +template <> struct V_vec_acum_fp32_ { using Type = float4; }; +// template <> struct V_vec_acum_fp32_ { using Type = float2; }; +// template <> struct V_vec_acum_fp32_ { using Type = Float4_; }; +template <> struct V_vec_acum_fp32_ { using Type = Float8_; }; +#endif + +// clang-format on + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +inline __device__ float add(float a, float b) { return a + b; } + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +template +inline __device__ Acc mul(A a, B b); + +template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +template <> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +inline __device__ float sum(float v) { return v; } +inline __device__ float sum(float2 v) { return v.x + v.y; } +inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } +inline __device__ float sum(uint16_t v) { return half_to_float(v); } +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +inline __device__ constexpr uint32_t shfl_mask(int threads) { + return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; +} + +template +inline __device__ __host__ T div_up(T m, T n) { + return (m + n - 1) / n; +} + +inline __device__ float fma(float a, float b, float c) { return a * b + c; } + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); + return d; +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float cast_to_float(float u) { return u; } + +inline __device__ float2 cast_to_float(float2 u) { return u; } + +inline __device__ float4 cast_to_float(float4 u) { return u; } + +inline __device__ Float8_ cast_to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) { + K_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +template +inline __device__ float block_sum(float *red_smem, float sum) { + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + if (lane == 0) { + red_smem[warp] = sum; + } + __syncthreads(); + + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + return __shfl_sync(uint32_t(-1), sum, 0); +} + +inline __device__ void convert_from_float(float &dst, float src) { // NOLINT + dst = src; +} + +inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT + dst = src; +} + +inline __device__ void convert_from_float(plat::float16 &dst, // NOLINT + float src) { + dst = static_cast(src); +} + +inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT + +template +inline __device__ void zero(T &dst) { // NOLINT + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +template +__global__ void masked_multihead_attention_kernel( + Masked_multihead_attention_params params) { + static_assert(Dh % THREADS_PER_KEY == 0, ""); + static_assert(Dh % THREADS_PER_VALUE == 0, ""); + + constexpr int WARP_SIZE = 32; + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + + char *logits_smem_ = smem_; + // fp32 accum for logits + float *logits_smem = reinterpret_cast(logits_smem_); + + T *out_smem = reinterpret_cast(smem_); + + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + __shared__ T q_smem[Dh]; + + const int bi = blockIdx.y; + const int hi = blockIdx.x; + const int bhi = bi * params.num_head + hi; + const int tid = threadIdx.x; + + float qk_max = -FLT_MAX; + + // qkv [B, S=1, 3, num_head, head_dim] + int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; + + using Qk_vec = typename Qk_vec_::Type; + constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); + static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, ""); + constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE; + + // cache_k, [B, num_head, head_dim / x, max_seq_len, x] + // x == 4/8 for FP32/FP16, 128bit, 16Byte + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); + + const T *q_base = params.qkv; + const T *k_base = params.qkv + params.num_head * Dh; + const T *q_bias_base = params.qkv_bias; + const T *k_bias_base = params.qkv_bias + params.num_head * Dh; + + if (tid < QK_VECS_PER_WARP) { + int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; + int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE; + + Qk_vec q = *reinterpret_cast(&q_base[qk_offset]); + Qk_vec k = *reinterpret_cast(&k_base[qk_offset]); + + Qk_vec q_bias = + *reinterpret_cast(&q_bias_base[qk_bias_offset]); + Qk_vec k_bias = + *reinterpret_cast(&k_bias_base[qk_bias_offset]); + + q = add(q, q_bias); + // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 + // we may not require k_bias. + k = add(k, k_bias); + + *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; + + int co = tid / QK_VECS_IN_16B; + int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; + int offset = bhi * params.max_seq_length * Dh + + co * params.max_seq_length * QK_ELTS_IN_16B + + params.timestep * QK_ELTS_IN_16B + ci; + *reinterpret_cast(¶ms.cache_kv[offset]) = k; + + float qk = dot(q, k); +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + + qk *= params.inv_sqrt_dh; + if (tid == 0) { + // NOTE(wangxi): mask must be 0.0 + // T mask = params.attn_mask[ + // bi * (params.timestep + 1) + params.timestep]; + // qk += static_cast(mask); + qk_max = qk; + qk_smem[params.timestep] = qk; + } + } + __syncthreads(); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + if (bi == 0 && hi == 0 && tid == 0) { + printf("=======q_out=======\n"); + for (int i = 0; i < Dh; ++i) printf("%f ", static_cast(q_smem[i])); + printf("\n"); + } + __syncthreads(); +#endif + + using K_vec = typename K_vec_::Type; + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + static_assert(Dh % K_VEC_SIZE == 0, ""); + constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + int ko = tid / THREADS_PER_KEY; + int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; + + K_vec q[K_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < K_VECS_PER_THREAD; ++i) { + q[i] = *reinterpret_cast( + &q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]); + } + + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; + int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + K_vec k[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_seq_length + ti; + if (ti < params.timestep) { + k[ii] = *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]); + } + } + + float qk = Qk_dot::dot(q, k) * params.inv_sqrt_dh; + + // bool is_mask = false; + if (ti < params.timestep && tid % THREADS_PER_KEY == 0) { + // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + T mask = params.attn_mask[bi * (params.timestep + 1) + ti]; + qk += static_cast(mask); + qk_max = fmaxf(qk_max, qk); + + qk_smem[ti] = qk; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + const int warp = tid / WARP_SIZE; + const int lane = tid % WARP_SIZE; + + if (lane == 0) { + red_smem[warp] = qk_max; + } + + __syncthreads(); + + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + if (bi == 0 && hi == 0 && tid == 0) { + printf("=======qk_out=======\n"); + for (int i = 0; i <= params.timestep; ++i) printf("%f ", qk_smem[i]); + printf("qk_max=%f\n", qk_max); + } + __syncthreads(); +#endif + + float sum = 0.f; + for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { + // bool is_mask = false; + // float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max); + float logit = __expf(qk_smem[ti] - qk_max); + sum += logit; + qk_smem[ti] = logit; + } + + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // FIXME(wangxi): need add 1.e-6f? + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { + convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); + } + __syncthreads(); + + constexpr int V_VEC_SIZE = Dh / THREADS_PER_VALUE; + using V_vec = typename V_vec_::Type; + + int vo = tid / THREADS_PER_VALUE; + int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; + + T *v_cache = ¶ms.cache_kv[params.batch_size * params.num_head * + params.max_seq_length * Dh + + bhi * params.max_seq_length * Dh + vi]; + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec; +#endif + + V_vec_acum out; + zero(out); + + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) { + V_vec v = *reinterpret_cast(&v_cache[ti * Dh]); +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti]; + out = fma(logit, cast_to_float(v), out); +#else + T logit = logits_smem[ti]; + // Update the partial sums. + out = fma(logit, v, out); +#endif + } + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + if (bi == 0 && hi == 0 && tid == 0) { + printf("======logits_out=====\n"); + for (int i = 0; i <= params.timestep; ++i) printf("%f ", logits_smem[i]); + printf("\n"); + } + __syncthreads(); +#endif + + if (vo == (params.timestep % V_PER_ITER)) { + V_vec v = *reinterpret_cast( + ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); + V_vec v_bias = *reinterpret_cast( + ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); + v = add(v, v_bias); + *reinterpret_cast(&v_cache[params.timestep * Dh]) = v; + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + out = fma(logits_smem[params.timestep], cast_to_float(v), out); +#else + out = fma(logits_smem[params.timestep], v, out); +#endif + } + + __syncthreads(); + +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { + int midpoint = active_groups / 2; + + if (vo >= midpoint && vo < active_groups) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float( + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), + out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + if (vo < midpoint) { + out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + + if (vo == 0) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), + out); +#else + *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; +#endif + } + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + __syncthreads(); + if (bi == 0 && hi == 0 && tid == 0) { + printf("======fmha_out=====\n"); + for (int i = 0; i < Dh; ++i) + printf("%f ", static_cast(params.out[i])); + printf("\n"); + } +#endif +} + +template +inline size_t smem_size_in_bytes( + const Masked_multihead_attention_params ¶ms, int dim_head, + int threads_per_value, int threads_per_block) { + size_t qk_sz = div_up(params.timestep + 1, 4) * 16; + size_t logits_sz = 0; + +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(T) != 4) { + logits_sz = div_up(params.max_seq_length, 4) * 4 * sizeof(T); + } +#endif + size_t softmax_sz = qk_sz + logits_sz; + + int rows_per_red = threads_per_block / threads_per_value; + size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2; + + return max(softmax_sz, red_sz); +} + +#define MMHA_LAUNCH_KERNEL(T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \ + THDS_PER_BLOCK, stream) \ + size_t smem_sz = \ + smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_head, params.batch_size); \ + masked_multihead_attention_kernel< \ + T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \ + THDS_PER_BLOCK><<>>(params) + +template +void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, + const cudaStream_t &stream) { + constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16; + if (params.timestep < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, 4, THREADS_PER_VALUE, 64, stream); + } else if (params.timestep < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, 2, THREADS_PER_VALUE, 128, stream); + } else { + MMHA_LAUNCH_KERNEL(T, Dh, 1, THREADS_PER_VALUE, 256, stream); + } +} + +template +void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor, + const Tensor &qkv_bias_tensor, const Tensor &src_mask_tensor, + Tensor *cache_kv_tensor, Tensor *out_tensor, int batch_size, + int max_seq_length, int num_head, int dim_head, int timestep, + float inv_sqrt_dh) { + Masked_multihead_attention_params params; + params.out = out_tensor->data(); + params.qkv = qkv_tensor.data(); + params.qkv_bias = qkv_bias_tensor.data(); + params.attn_mask = src_mask_tensor.data(); + params.cache_kv = cache_kv_tensor->data(); + + params.batch_size = batch_size; + params.num_head = num_head; + params.timestep = timestep; + params.max_seq_length = max_seq_length; + params.inv_sqrt_dh = inv_sqrt_dh; + + switch (dim_head) { + case 32: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + case 64: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + case 128: + fmha_launch_kernel(params, dev_ctx.stream()); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "dim_head = %d is unsupport, only support " + "dim_head = 32, 64 or 128 for now.", + dim_head)); + } +} + +// NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8 +constexpr int VEC_16B = 16; + +template +__global__ void write_cache_k_kernel(T *cache_k, const T *k, const int num_head, + const int dim_head, const int seq_len, + const int max_seq_len) { + const int bi = blockIdx.y; + const int hi = blockIdx.z; + constexpr int X_ELEMS = VEC_16B / sizeof(T); + + // [bsz, num_head, seq_len, dim_head/x, x] + auto k_src = reinterpret_cast( + k + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head); + // [bsz, num_head, dim_head/x, max_seq_len, x] + auto k_dst = reinterpret_cast( + cache_k + bi * num_head * max_seq_len * dim_head + + hi * max_seq_len * dim_head); + + const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + // vec size + int dim_head_div_x = dim_head / X_ELEMS; + + // FIXME(wangxi): num_head is not need? + // if (out_idx >= num_head * dim_head_div_x * max_seq_len) return; + if (out_idx >= dim_head_div_x * max_seq_len) return; + + int idx = out_idx; + const int k_seq_len_id = idx % max_seq_len; + // idx = (idx - k_seq_len_id) / max_seq_len; + idx = idx / max_seq_len; + const int k_vec_id = idx % dim_head_div_x; + + if (k_seq_len_id < seq_len) { + k_dst[out_idx] = k_src[k_seq_len_id * dim_head_div_x + k_vec_id]; + } +} + +template +__global__ void write_cache_v_kernel(T *cache_v, const T *v, const int num_head, + const int dim_head, const int seq_len, + const int max_seq_len) { + const int bi = blockIdx.y; + const int hi = blockIdx.z; + + // [bsz, num_head, seq_len, dim_head/x, x] + auto v_src = reinterpret_cast( + v + bi * num_head * seq_len * dim_head + hi * seq_len * dim_head); + // [bsz, num_head, max_seq_len, dim_head/x, x] + auto v_dst = reinterpret_cast( + cache_v + bi * num_head * max_seq_len * dim_head + + hi * max_seq_len * dim_head); + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + constexpr int X_ELEMS = VEC_16B / sizeof(T); + const int dim_head_div_x = dim_head / X_ELEMS; + + if (idx >= dim_head_div_x * seq_len) return; + + v_dst[idx] = v_src[idx]; +} + +template +void write_cache_kv(const platform::CUDADeviceContext &dev_ctx, T *cache_k, + T *cache_v, const T *k, const T *v, const int bsz, + const int num_head, const int seq_len, + const int max_seq_len, const int dim_head) { + constexpr int block_sz = 128; + constexpr int x = VEC_16B / sizeof(T); + + assert(dim_head % x == 0); + PADDLE_ENFORCE_EQ( + dim_head % x, 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", dim_head, x)); + + int max_size = max_seq_len * dim_head / x; + int size = seq_len * dim_head / x; + dim3 grid(div_up(max_size, block_sz), bsz, num_head); + dim3 grid_v(div_up(size, block_sz), bsz, num_head); + + // transpose [bsz, num_head, seq_len, dim_head/x, x]-> + // [bsz, num_head, dim_head/x, max_seq_len, x] + write_cache_k_kernel<<>>( + cache_k, k, num_head, dim_head, seq_len, max_seq_len); + + // copy [bsz, num_head, seq_len, dim_head/x, x]-> + // [bsz, num_head, max_seq_len, dim_head/x, x] + write_cache_v_kernel<<>>( + cache_v, v, num_head, dim_head, seq_len, max_seq_len); +} + +} // namespace + +template +class FusedMultiTransformerOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto place = ctx.GetPlace(); + auto &dev_ctx = ctx.cuda_device_context(); + + auto *time_step = ctx.Input("TimeStep"); + // 0. input + auto *input_x = ctx.Input("X"); + const auto input_x_dims = input_x->dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + + // 1. layer norm + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + const float epsilon = ctx.Attr("epsilon"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); + + auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); + Tensor ln_mean, ln_var; + auto *ln_mean_data = ln_mean.mutable_data({bsz_seq}, place); + auto *ln_var_data = ln_var.mutable_data({bsz_seq}, place); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); + const auto qkv_w_dims = qkv_weights[0]->dims(); + int num_head = qkv_w_dims[1]; + int dim_head = qkv_w_dims[2]; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; + // (transA, transB, compute_bias) = (false, true, false) + auto qkv_compute = AttnMatMul(dev_ctx, false, true, bsz_seq, output_size, + input_size, compute_bias); + Tensor qkv_out; + auto *qkv_out_data = + qkv_out.mutable_data({bsz, seq_len, 3, num_head, dim_head}, place); + + // 3. fmha + AttnDropoutParam attn_param(true, "upscale_in_train", 0.0, true, true, 0, + nullptr); + auto fmha_compute = + FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + // auto *time_step = ctx.Input("TimeStep"); + + auto out_seq_len = seq_len; + if (time_step) { + PADDLE_ENFORCE_EQ(time_step->place(), platform::CPUPlace(), + platform::errors::PreconditionNotMet( + "The place of input(TimeStep) must be CPUPlace.")); + // cache_seq_len + int time_step_value = time_step->data()[0]; + PADDLE_ENFORCE_GT(time_step_value, 0, + platform::errors::PreconditionNotMet( + "The value of time_step must > 0, but now is %d", + time_step_value)); + PADDLE_ENFORCE_EQ( + seq_len, 1, + platform::errors::PreconditionNotMet( + "In decode stage, the seq_len of input must be 1, but now is %d", + seq_len)); + out_seq_len += time_step_value; + } + + Tensor transpose_out_2, qk_out; + auto *transpose_out_2_data = transpose_out_2.mutable_data( + {3, bsz, num_head, seq_len, dim_head}, place); + auto *qk_out_data = + qk_out.mutable_data({bsz, num_head, seq_len, out_seq_len}, place); + + Tensor src_mask_out, softmax_out; + Tensor attn_dropout_mask_out, attn_dropout_out; + Tensor qktv_out, fmha_out; + auto *src_mask_out_data = src_mask_out.mutable_data( + {bsz, num_head, seq_len, out_seq_len}, place); + auto *softmax_out_data = softmax_out.mutable_data( + {bsz, num_head, seq_len, out_seq_len}, place); + + auto *attn_dropout_mask_out_data = attn_dropout_mask_out.mutable_data( + {bsz, num_head, seq_len, out_seq_len}, place); + auto *attn_dropout_data_data = attn_dropout_out.mutable_data( + {bsz, num_head, seq_len, out_seq_len}, place); + + auto *qktv_out_data = + qktv_out.mutable_data({bsz, num_head, seq_len, dim_head}, place); + auto *fmha_out_data = + fmha_out.mutable_data({bsz, seq_len, num_head, dim_head}, place); + + // 4. out_linear + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + int ring_id = ctx.Attr("ring_id"); + // (transA, transB, compute_bias) = (false, false, false) + auto out_linear_compute = AttnMatMul(dev_ctx, false, false, bsz_seq, + dim_embed, hidden_size, false); + + // 5. ln(residual + bias) + DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + Tensor bias_dropout_residual_out, dropout_mask_out; + auto *bias_dropout_residual_out_data = + bias_dropout_residual_out.mutable_data({bsz, seq_len, dim_embed}, + place); + auto *dropout_mask_out_data = dropout_mask_out.mutable_data( + {bsz, seq_len, dim_embed}, place); + + // 6. ffn matmul1 + auto ffn1_weights = ctx.MultiInput("FFN1Weight"); + auto ffn1_biases = ctx.MultiInput("FFN1Bias"); + auto ffn1_weight_dim = ffn1_weights[0]->dims(); + + int dim_ffn = ffn1_weight_dim[1]; + auto ffn1_linear_compute = AttnMatMul(dev_ctx, false, false, bsz_seq, + dim_ffn, dim_embed, false); + Tensor ffn1_out; + auto *ffn1_out_data = ffn1_out.mutable_data({bsz_seq, dim_ffn}, place); + + // 7. ffn act + bias + DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); + Tensor ffn1_dropout_out, ffn1_dropout_mask; + auto *ffn1_dropout_out_data = + ffn1_dropout_out.mutable_data({bsz_seq, dim_ffn}, place); + auto *ffn1_dropout_mask_data = + ffn1_dropout_mask.mutable_data({bsz_seq, dim_ffn}, place); + + // 8. ffn2 matmul + auto ffn2_weights = ctx.MultiInput("FFN2Weight"); + auto ffn2_biases = ctx.MultiInput("FFN2Bias"); + auto ffn2_linear_compute = AttnMatMul(dev_ctx, false, false, bsz_seq, + dim_embed, dim_ffn, false); + + // 9. ffn2 residual bias + DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( + dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + + // calc + auto *out = ctx.Output("Out"); + auto *from_data = out->mutable_data(place); + Tensor *from_tensor = out; + Tensor tmp_out; + auto *tmp_out_data = + tmp_out.mutable_data({bsz, seq_len, dim_embed}, place); + + auto *x_data = input_x->data(); + Tensor *buf0 = nullptr; + Tensor *buf1 = nullptr; + + // step0: x --> buf1 + // step1: buf1 --> buf0 + // step2: buf0 --> buf1 + int layers = qkv_weights.size(); + if (layers & 1) { + // odd, set buf1 as out + buf0 = &tmp_out; + buf1 = out; + } else { + // even, set buf0 as out + buf0 = out; + buf1 = &tmp_out; + } + + for (int i = 0; i < layers; ++i) { + // step1. layer_norm + if (i == 0 && pre_layer_norm) { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + // TODO(wangxi): can remove mean var in inference + ln_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, + buf1->data(), ln_mean_data, ln_var_data); + } else if (!pre_layer_norm) { + PADDLE_THROW(platform::errors::Unimplemented( + "Unimplemented post_layer_norm for now.")); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step1"; +#endif + + // step2. qkv + const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + // NOTE: in decoder stage, bias is fused in fmha + const Tensor *bias = time_step ? nullptr : qkv_bias; + qkv_compute.ComputeForward(qkv_weights[i], buf1, bias, &qkv_out, + &qkv_out); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step2"; +#endif + + // step3. fmha + const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + + if (time_step) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + fmha(dev_ctx, qkv_out, *qkv_bias, *src_mask, cache_kv_out, &fmha_out, + bsz, max_seq_len, num_head, dim_head, time_step->data()[0], + 1. / sqrt(dim_head)); + } else if (cache_kv_out) { // generation context stage + // TODO(wangxi): can remove dropout in inference + fmha_compute.ComputeForward( + qkv_out, nullptr, src_mask, &transpose_out_2, nullptr, &qk_out, + &src_mask_out, &softmax_out, &attn_dropout_mask_out, + &attn_dropout_out, &qktv_out, &fmha_out); + // [3, bsz, num_head, seq_len, head_dim] + T *qkv_data = transpose_out_2_data; + int64_t q_size = bsz * seq_len * num_head * dim_head; + int64_t k_size = q_size; + const T *q_ptr = qkv_data; + const T *k_ptr = q_ptr + q_size; + const T *v_ptr = k_ptr + k_size; + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + + write_cache_kv(dev_ctx, cache_k_ptr, cache_v_ptr, k_ptr, v_ptr, bsz, + num_head, seq_len, max_seq_len, dim_head); + } else { // not generation + // TODO(wangxi): can remove dropout in inference + fmha_compute.ComputeForward( + qkv_out, cache_kv, src_mask, &transpose_out_2, cache_kv_out, + &qk_out, &src_mask_out, &softmax_out, &attn_dropout_mask_out, + &attn_dropout_out, &qktv_out, &fmha_out); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3"; +#endif + + // step4. out_linear + out_linear_compute.ComputeForward(out_linear_weights[i], &fmha_out, + nullptr, buf1, nullptr); + AllReduce(*buf1, ring_id, dev_ctx); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step4"; +#endif + + // step5. ln(residual + dropout(input + bias)) + if (pre_layer_norm) { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + + // inplace + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, buf1->data(), x_data, out_linear_bias_data, + ln_scale_data, ln_bias_data, bias_dropout_residual_out_data, + dropout_mask_out_data, buf1->data(), ln_mean_data, ln_var_data); + } else { + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step5"; +#endif + + // step6. ffn matmul1 + ffn1_linear_compute.ComputeForward(ffn1_weights[i], buf1, nullptr, + &ffn1_out, nullptr); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step6"; +#endif + + // step7. act bias + // TODO(wangxi): remove dropout mask in inference + fused_act_dropout_helper.DropoutActBias( + dev_ctx, ffn1_out_data, ffn1_biases[i]->data(), "gelu", + ffn1_dropout_out_data, ffn1_dropout_mask_data); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step7"; +#endif + + // step8. ffn matmul2 + ffn2_linear_compute.ComputeForward(ffn2_weights[i], &ffn1_dropout_out, + nullptr, buf1, nullptr); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step8.0"; +#endif + + AllReduce(*buf1, ring_id, dev_ctx); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step8.1"; +#endif + + // step9. residual bias + if (pre_layer_norm) { + // TODO(wangxi): remove dropout mask in inference + if (i < layers - 1) { + auto *ln_scale_data = ln_scales[i + 1]->data(); + auto *ln_bias_data = ln_biases[i + 1]->data(); + ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + dev_ctx, buf1->data(), bias_dropout_residual_out_data, + ffn2_biases[i]->data(), ln_scale_data, ln_bias_data, + buf1->data(), dropout_mask_out_data, buf0->data(), + ln_mean_data, ln_var_data); + } else { + ffn2_fused_dropout_helper.ResidualDropoutBias( + dev_ctx, buf1->data(), bias_dropout_residual_out_data, + ffn2_biases[i]->data(), buf1->data(), + dropout_mask_out_data); + } + } else { + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step9"; +#endif + x_data = buf1->data(); + std::swap(buf0, buf1); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_multi_transformer, + ops::FusedMultiTransformerOpKernel, + ops::FusedMultiTransformerOpKernel); diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 7b128bd3b0..2b849968c7 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -32,6 +32,10 @@ std::map> op_ins_map = { {"fused_attention", {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "SrcMask", "OutLinearW", "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, + {"fused_multi_transformer", + {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep", + "SrcMask", "OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias", + "FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, @@ -176,6 +180,7 @@ std::map> op_outs_map = { {"lamb", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, + {"fused_multi_transformer", {"CacheKVOut", "Out"}}, }; // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are @@ -253,6 +258,7 @@ std::map> op_passing_outs_map = { {"assign_value", {"Out"}}, {"split", {"Out"}}, {"concat", {"Out"}}, + {"fused_multi_transformer", {"CacheKVOut"}}, }; // NOTE(pangyoki): Tensor View Strategy. diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 9dba5d658d..7b2546f70a 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -162,6 +162,7 @@ gray_list = { 'split', 'fused_feedforward', 'fused_attention', + 'fused_multi_transformer', } # The set of ops that don't support fp16 calculation diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 760e9ceb9e..0100866806 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -109,6 +109,8 @@ def _keep_fp32_input(op, in_name): return in_name in { 'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" } + if op_type == 'fused_multi_transformer': + return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'} return False diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c6111391b7..12ed7b975a 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -25,6 +25,7 @@ list(APPEND DIST_TEST_OPS test_ir_pass_pipeline) list(APPEND DIST_TEST_OPS test_static_model_parallel) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward) list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention) +list(APPEND DIST_TEST_OPS test_static_model_parallel_fused_multi_transformer) list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height) @@ -128,6 +129,7 @@ if(NOT WITH_GPU) LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api) + LIST(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_op) LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer) endif() @@ -1187,6 +1189,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) set_tests_properties(test_static_model_parallel PROPERTIES TIMEOUT 240) set_tests_properties(test_static_model_parallel_fused_feedforward PROPERTIES TIMEOUT 120) set_tests_properties(test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120) + set_tests_properties(test_static_model_parallel_fused_multi_transformer PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_split_embedding test_collective_split_embedding_none_divisible test_collective_split_row_linear diff --git a/python/paddle/fluid/tests/unittests/static_model_parallel_fused_multi_transformer.py b/python/paddle/fluid/tests/unittests/static_model_parallel_fused_multi_transformer.py new file mode 100644 index 0000000000..f9c5d4d78c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/static_model_parallel_fused_multi_transformer.py @@ -0,0 +1,193 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np + +import paddle +import paddle.fluid as fluid +from test_dist_base import TestDistRunnerBase, runtime_main +from paddle.incubate.nn import FusedMultiTransformer +import paddle.distributed.fleet as fleet + +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.fluid.dygraph.layers import Layer +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid import core +from paddle.nn.initializer import Constant + +paddle.enable_static() + + +def get_param_attr(weight, bias): + weight_attr = paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer(weight)) + bias_attr = paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer(bias)) + return weight_attr, bias_attr + + +DTYPE = "float32" +MODEL_PARALLEL_SIZE = 2 +num_head = 2 * MODEL_PARALLEL_SIZE +dim_head = 4 +hidden = num_head * dim_head +dim_ffn = 4 * hidden + + +def create_model(data, rank): + np.random.seed(2021) + ln_w = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE) + ln_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE) + qkv_w = np.random.uniform( + -1, 1, size=(3, num_head, dim_head, hidden)).astype(DTYPE) + qkv_b = np.random.uniform(-1, 1, size=(3, num_head, dim_head)).astype(DTYPE) + linear_w = np.random.uniform( + -1, 1, size=(num_head * dim_head, hidden)).astype(DTYPE) + linear_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE) + + ffn_ln_w = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE) + ffn_ln_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE) + ffn1_w = np.random.uniform(-1, 1, size=(hidden, dim_ffn)).astype(DTYPE) + ffn1_b = np.random.uniform(-1, 1, size=(dim_ffn, )).astype(DTYPE) + ffn2_w = np.random.uniform(-1, 1, size=(dim_ffn, hidden)).astype(DTYPE) + ffn2_b = np.random.uniform(-1, 1, size=(hidden, )).astype(DTYPE) + + if rank is not None: + start = 0 if rank == 0 else (num_head // MODEL_PARALLEL_SIZE) + end = start + (num_head // MODEL_PARALLEL_SIZE) + col_qkv_w = qkv_w[:, start:end, :, :] + col_qkv_b = qkv_b[:, start:end, :] + row_linear_w = linear_w[(start * dim_head):(end * dim_head), :] + + ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b) + qkv_w_attr, qkv_b_attr = get_param_attr(col_qkv_w, col_qkv_b) + linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b) + + start = 0 if rank == 0 else (dim_ffn // MODEL_PARALLEL_SIZE) + end = start + (dim_ffn // MODEL_PARALLEL_SIZE) + col_ffn1_w = ffn1_w[:, start:end] + col_ffn1_b = ffn1_b[start:end] + row_ffn2_w = ffn2_w[start:end, :] + + ffn_ln_w_attr, ffn_ln_b_attr = get_param_attr(ffn_ln_w, ffn_ln_b) + ffn1_w_attr, ffn1_b_attr = get_param_attr(col_ffn1_w, col_ffn1_b) + ffn2_w_attr, ffn2_b_attr = get_param_attr(row_ffn2_w, ffn2_b) + + multi_transformer = FusedMultiTransformer( + hidden, + num_head, + dim_ffn, + dropout_rate=0.0, + activation="gelu", + normalize_before=True, + ln_scale_attrs=[ln_w_attr], + ln_bias_attrs=[ln_b_attr], + qkv_weight_attrs=[qkv_w_attr], + qkv_bias_attrs=[qkv_b_attr], + linear_weight_attrs=[linear_w_attr], + linear_bias_attrs=[linear_b_attr], + ffn_ln_scale_attrs=[ffn_ln_w_attr], + ffn_ln_bias_attrs=[ffn_ln_b_attr], + ffn1_weight_attrs=[ffn1_w_attr], + ffn1_bias_attrs=[ffn1_b_attr], + ffn2_weight_attrs=[ffn2_w_attr], + ffn2_bias_attrs=[ffn2_b_attr], + nranks=MODEL_PARALLEL_SIZE, + ring_id=0) + result = multi_transformer(data) + else: + ln_w_attr, ln_b_attr = get_param_attr(ln_w, ln_b) + qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b) + linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b) + + ffn_ln_w_attr, ffn_ln_b_attr = get_param_attr(ffn_ln_w, ffn_ln_b) + ffn1_w_attr, ffn1_b_attr = get_param_attr(ffn1_w, ffn1_b) + ffn2_w_attr, ffn2_b_attr = get_param_attr(ffn2_w, ffn2_b) + + multi_transformer = FusedMultiTransformer( + hidden, + num_head, + dim_ffn, + dropout_rate=0.0, + activation="gelu", + normalize_before=True, + ln_scale_attrs=[ln_w_attr], + ln_bias_attrs=[ln_b_attr], + qkv_weight_attrs=[qkv_w_attr], + qkv_bias_attrs=[qkv_b_attr], + linear_weight_attrs=[linear_w_attr], + linear_bias_attrs=[linear_b_attr], + ffn_ln_scale_attrs=[ffn_ln_w_attr], + ffn_ln_bias_attrs=[ffn_ln_b_attr], + ffn1_weight_attrs=[ffn1_w_attr], + ffn1_bias_attrs=[ffn1_b_attr], + ffn2_weight_attrs=[ffn2_w_attr], + ffn2_bias_attrs=[ffn2_b_attr]) + result = multi_transformer(data) + + # fused_multi_transformer have no backward + result.stop_gradient = True + predict = paddle.mean(result) + return predict + + +class TestModelParallel(TestDistRunnerBase): + def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None): + # Input data + seq_len = 2 + data_in = fluid.data( + name='data_in', shape=[batch_size, seq_len, hidden], dtype=DTYPE) + + if dist_strategy: + data_loader = fluid.io.DataLoader.from_generator( + feed_list=[data_in], + capacity=64, + use_double_buffer=False, + iterable=False) + + if dist_strategy: + fleet.init(is_collective=True) + strategy = fleet.DistributedStrategy() + strategy.tensor_parallel = True + strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2} + + rank = fleet.worker_index() if dist_strategy else None + avg_cost = create_model(data_in, rank) + opt = fluid.optimizer.SGD(0.1) + + if dist_strategy: + dist_opt = fleet.distributed_optimizer( + optimizer=opt, strategy=strategy) + dist_opt.minimize(avg_cost) + else: + opt.minimize(avg_cost) + + def gen_data(): + np.random.seed(2021) + while True: + data = [np.random.random([seq_len, hidden]).astype(DTYPE)] + yield data + + train_reader = paddle.batch(gen_data, batch_size=batch_size) + + if dist_strategy: + return None, avg_cost, train_reader, None, None, None, data_loader + else: + return None, avg_cost, train_reader, None, None, None + + +if __name__ == "__main__": + runtime_main(TestModelParallel) diff --git a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py new file mode 100644 index 0000000000..8f77972de8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py @@ -0,0 +1,542 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.fluid.core as core +import paddle.nn.functional as F +import paddle.incubate.nn.functional as incubate_f +from paddle.nn.layer.norm import LayerNorm +from paddle.nn.layer.common import Linear, Dropout +from paddle.nn.layer.transformer import _convert_attention_mask +from paddle import tensor +from paddle.fluid import layers +import unittest +from op_test import OpTest +from paddle.fluid.framework import default_main_program +from paddle.fluid.dygraph.layers import Layer +from paddle.fluid.layer_helper import LayerHelper +from paddle.nn.initializer import Constant +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.fluid.framework import _non_static_mode, default_main_program +from paddle import _C_ops +from paddle.incubate.nn.functional import fused_multi_transformer + +default_main_program().random_seed = 42 + + +class TestFusedMultiTransformerOp(OpTest): + def setUp(self): + self.config() + self.generate_input_data() + + self.rtol = 1e-5 + # FIXME(wangxi): Because there is a problem with the test precision + # on A100, atol is temporarily set to 1e-2, and it will be + # changed back after the precision problem is solved. + self.atol = 1e-2 + # make sure local development precision + if "V100" in paddle.device.cuda.get_device_name(): + self.atol = 1e-4 + if self.x_type is np.float16: + self.atol = 1e-1 + + paddle.set_default_dtype(self.x_type) + self.__class__.op_type = "fused_multi_transformer" + # use autograd to check grad in this unittest. + self.__class__.no_need_check_grad = False + + bias_attr = paddle.fluid.ParamAttr( + initializer=paddle.fluid.initializer.Constant(value=0.0005)) + self.q_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=bias_attr) + #bias_attr=self.bias_attr) + + self.k_proj = Linear( + self.kdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.v_proj = Linear( + self.vdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.out_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + + self.ffn1_proj = Linear( + self.embed_dim, + 4 * self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.ffn2_proj = Linear( + 4 * self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + + paddle.set_default_dtype(np.float32) + self.norm = LayerNorm(self.embed_dim) + self.ffn_norm = LayerNorm(self.embed_dim) + + paddle.set_default_dtype(self.x_type) + self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") + self.activation = getattr(F, self.act_method) + + def config(self): + # for debug + self.debug = False + + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.has_attn_mask = True + + # has_cache_kv, gen_cache_kv, stage + # False, False, not generation + # True, True, generation context stage + # True, False, generation decoder stage + self.has_cache_kv = False + self.gen_cache_kv = False + + self.training = False + + self.layers = 4 + self.batch_size = 8 + self.query_length = 128 + self.cache_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.act_method = 'gelu' + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def generate_input_data(self): + self.query = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + out_seq_len = self.key_length + if self.has_cache_kv: + assert self.training is False, ValueError( + 'cache_kv can only used in inference') + self.cache_kv = np.random.rand(2, self.batch_size, self.num_heads, + self.cache_length, + self.head_dim).astype(self.x_type) + if self.gen_cache_kv: + self.cache_kv[:] = 0 + else: + out_seq_len += self.cache_length + else: + self.cache_kv = None + + if self.has_attn_mask: + # [B, n_head, seq_len, out_seq_len] + self.attn_mask = np.ones( + (self.batch_size, 1, self.query_length, out_seq_len), + dtype=self.attn_mask_type) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + if self.has_cache_kv and not self.gen_cache_kv: + # NOTE: decoder stage, -1(out_seq_len) should no mask + self.attn_mask[:, :, :, -2] = 0.0 + self.attn_mask = (self.attn_mask - 1.0) * 1e4 + else: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e4 + else: + raise ValueError( + "'attn_mask_type' should be 'int64' or 'float64'.") + else: + self.attn_mask = None + self.key, self.value = self.query, self.query + + self.dout = np.random.random((self.batch_size, self.query_length, + self.embed_dim)).astype(self.x_type) + + def GetBaselineOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + tensor_query = paddle.to_tensor(self.query, stop_gradient=False) + + cache_kvs = [] + cache_kv = None + if self.has_cache_kv: + cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False) + + if self.has_attn_mask: + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + else: + attn_mask = None + + for i in range(self.layers): + residual = tensor_query + ln1_out = tensor_query + if self.pre_layer_norm: + ln1_out = self.norm(tensor_query) + + q = self.q_proj(ln1_out) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + k = self.k_proj(ln1_out) + v = self.v_proj(ln1_out) + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + if self.has_cache_kv: + # [1, B, n_head, cache_seq_len, head_dim] + cache_k, cache_v = paddle.split(cache_kv, 2) + cache_k = paddle.squeeze(cache_k, axis=0) + cache_v = paddle.squeeze(cache_v, axis=0) + # [B, n_head, cache_seq_len + seq_len, head_dim] + # out_seq_len = cache_seq_len + seq_len + if self.debug: + print('q out is') + print(q_out[0, 0, :, :]) + print('cache k out seq=128') + print(k_out[0, 0, :, :]) + if self.gen_cache_kv: + cache_kvs.append((k_out, v_out)) + else: + k_out = paddle.concat([cache_k, k_out], axis=-2) + v_out = paddle.concat([cache_v, v_out], axis=-2) + + # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim] + # --> [B, n_head, seq_len, out_seq_len] + qk_out = layers.matmul( + x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) + + if self.debug: + print('qk out is') + print(qk_out[0][0][0]) + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) + attn_mask_out = qk_out + attn_mask + if self.debug: + print('attn mask out is') + print(attn_mask_out[0][0][0]) + softmax_out = F.softmax(attn_mask_out) + else: + softmax_out = F.softmax(qk_out) + + if self.debug: + print('softmax out is') + print(softmax_out[0][0][0]) + if self.dropout_prob: + dropout_out = F.dropout( + softmax_out, + self.dropout_prob, + training=self.training, + mode="upscale_in_train") + # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim] + # --> [B, n_head, seq_len, head_dim] + qktv_out = tensor.matmul(dropout_out, v_out) + else: + qktv_out = tensor.matmul(softmax_out, v_out) + + fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) + if self.debug: + print('fmha out is') + print(fmha_out[0][0][0]) + out_linear_in = tensor.reshape( + x=fmha_out, + shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) + out = self.out_proj(out_linear_in) + + residual_out = residual + self.dropout(out) + if not self.pre_layer_norm: + attn_out = self.norm(residual_out) + else: + attn_out = residual_out + + ffn_ln_out = attn_out + if self.pre_layer_norm: + ffn_ln_out = self.ffn_norm(attn_out) + + ffn1_out = self.ffn1_proj(ffn_ln_out) + ffn1_out = self.dropout(self.activation(ffn1_out)) + ffn2_out = self.ffn2_proj(ffn1_out) + + residual_out = attn_out + self.dropout(ffn2_out) + final_out = residual_out + if not self.pre_layer_norm: + final_out = self.ffn_norm(residual_out) + + tensor_query = final_out + + if self.has_cache_kv and self.gen_cache_kv: + return final_out, cache_kvs + return final_out + + def GetFusedMultiTransformerOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + q_proj_weight = paddle.to_tensor( + self.q_proj.weight, stop_gradient=False) + k_proj_weight = paddle.to_tensor( + self.k_proj.weight, stop_gradient=False) + v_proj_weight = paddle.to_tensor( + self.v_proj.weight, stop_gradient=False) + out_linear_weight = paddle.to_tensor( + self.out_proj.weight, stop_gradient=False) + ffn1_weight = paddle.to_tensor( + self.ffn1_proj.weight, stop_gradient=False) + ffn2_weight = paddle.to_tensor( + self.ffn2_proj.weight, stop_gradient=False) + + if self.bias_attr is False: + qkv_bias_tensor = None + out_linear_bias = None + else: + q_proj_bias = paddle.to_tensor( + self.q_proj.bias, stop_gradient=False) + k_proj_bias = paddle.to_tensor( + self.k_proj.bias, stop_gradient=False) + v_proj_bias = paddle.to_tensor( + self.v_proj.bias, stop_gradient=False) + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False) + ffn1_bias = paddle.to_tensor( + self.ffn1_proj.bias, stop_gradient=False) + ffn2_bias = paddle.to_tensor( + self.ffn2_proj.bias, stop_gradient=False) + + ln_scale = paddle.to_tensor(self.norm.weight, stop_gradient=False) + ln_bias = paddle.to_tensor(self.norm.bias, stop_gradient=False) + ffn_ln_scale = paddle.to_tensor( + self.ffn_norm.weight, stop_gradient=False) + ffn_ln_bias = paddle.to_tensor(self.ffn_norm.bias, stop_gradient=False) + + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight)) + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim)) + + x = paddle.to_tensor(self.query, stop_gradient=False) + cache_kvs, cache_kv = None, None + time_step = None + if self.has_cache_kv: + cache_kvs = [] + + max_seq_length = (self.cache_length + 128) // 128 * 128 + cache_kv = np.zeros( + [ + 2, self.batch_size, self.num_heads, max_seq_length, + self.head_dim + ], + dtype=self.x_type) + + elems = 4 + if self.x_type is np.float16: + elems = 8 + + assert self.head_dim % elems == 0 + v_elems = self.head_dim // elems + + # [B, num_head, 128, head_dim] + # cache_k_tmp = self.cache_kv[0, :] + # [B, num_head, 128, head_dim / 4, 4] + cache_k_tmp = self.cache_kv[0].reshape([ + self.batch_size, self.num_heads, self.cache_length, v_elems, + elems + ]) + # [B, num_head, head_dim / 4, 128, 4] + cache_k_tmp = cache_k_tmp.transpose([0, 1, 3, 2, 4]) + + cache_kv[0, :].reshape([ + self.batch_size, self.num_heads, v_elems, max_seq_length, elems + ])[:, :, :, :self.cache_length, :] = cache_k_tmp + + cache_kv[1, :, :, :self.cache_length, :] = self.cache_kv[1] + if self.gen_cache_kv: + assert self.query_length == self.cache_length + cache_kv[:] = 0 + else: + time_step = paddle.to_tensor( + [self.cache_length], dtype='int32', place=paddle.CPUPlace()) + if self.has_attn_mask: + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + else: + attn_mask = None + qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) + epsilon = 1e-05 + ln2_epsilon = 1e-05 + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, x.dtype) + + qkv_weights, qkv_biases = [], [] + out_weights, out_biases = [], [] + ln_scales, ln_biases = [], [] + ffn1_weights, ffn1_biases = [], [] + ffn2_weights, ffn2_biases = [], [] + ffn_ln_scales, ffn_ln_biases = [], [] + for i in range(self.layers): + qkv_weights.append(qkv_weight_tensor) + qkv_biases.append(qkv_bias_tensor) + out_weights.append(out_linear_weight) + out_biases.append(out_linear_bias) + ln_scales.append(ln_scale) + ln_biases.append(ln_bias) + ffn1_weights.append(ffn1_weight) + ffn1_biases.append(ffn1_bias) + ffn2_weights.append(ffn2_weight) + ffn2_biases.append(ffn2_bias) + ffn_ln_scales.append(ffn_ln_scale) + ffn_ln_biases.append(ffn_ln_bias) + if self.has_cache_kv: + cache_kvs.append( + paddle.to_tensor( + cache_kv, stop_gradient=False)) + + final_out = fused_multi_transformer( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + out_weights, + out_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + pre_layer_norm=self.pre_layer_norm, + epsilon=epsilon, + cache_kvs=cache_kvs, + time_step=time_step, + attn_mask=attn_mask, + dropout_rate=self.dropout_prob, + training=self.training) + + if self.has_cache_kv: + return final_out[0], final_out[1] + + return final_out + + def test_fused_multi_transformer_op(self): + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedMultiTransformerOut() + if self.has_cache_kv: + final_out, cache_kv_out = final_out + s = cache_kv_out[0].shape + bsz = s[1] + num_head = s[2] + max_seq_len = s[3] + head_dim = s[4] + elems = 8 if self.x_type is np.float16 else 4 + v_elems = head_dim // elems + + if self.debug: + print("cache_k out timestep=128") + print(cache_kv_out[0].reshape([ + 2, bsz, num_head, v_elems, max_seq_len, elems + ])[0, 0, 0, :, self.cache_length, :]) + + print("cache_v out timestep=128") + print(cache_kv_out[0][1, 0, 0, self.cache_length, :]) + + if self.gen_cache_kv: + final_out_ref, cache_kvs = final_out_ref + for i in range(self.layers): + cache_k_ref = cache_kvs[i][0] + cache_v_ref = cache_kvs[i][1] + + cache_k = cache_kv_out[i][0, :] + cache_k = cache_k.reshape( + [bsz, num_head, v_elems, max_seq_len, elems]) + cache_k = cache_k[:, :, :, :self.cache_length, :] + cache_k = cache_k.transpose([0, 1, 3, 2, 4]) + cache_k = cache_k.reshape( + [bsz, num_head, self.cache_length, head_dim]) + + cache_v = cache_kv_out[i][1, :, :, :self.cache_length, :] + + np.testing.assert_allclose( + cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol) + np.testing.assert_allclose( + cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol) + if i == 0: + break + + np.testing.assert_allclose( + final_out_ref, final_out, rtol=self.rtol, atol=self.atol) + + +class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp): + def config(self): + super().config() + self.x_type = np.float16 + self.layers = 3 # odd layers + + +class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp): + def config(self): + super().config() + self.has_cache_kv = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.layers = 3 # odd layers + + +class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp): + def config(self): + super().config() + self.has_cache_kv = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.x_type = np.float16 + + +class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp): + def config(self): + super().config() + self.has_cache_kv = True + self.gen_cache_kv = True + + +class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp): + def config(self): + super().config() + self.has_cache_kv = True + self.gen_cache_kv = True + self.x_type = np.float16 + self.layers = 3 # odd layers + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_multi_transformer.py b/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_multi_transformer.py new file mode 100644 index 0000000000..5475fd4a10 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_static_model_parallel_fused_multi_transformer.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase + +import os +import paddle + +paddle.enable_static() +flag_name = os.path.splitext(__file__)[0] + + +class TestStaticModelParallel(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl_comm_num = 1 + self._pipeline_mode = True + + def test_dist_static_model_parallel_fused_multi_transformer(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "static_model_parallel_fused_multi_transformer.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index f359ec1e0d..43fcabf973 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -15,10 +15,11 @@ from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 from .layer.fused_transformer import FusedFeedForward # noqa: F401 from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401 +from .layer.fused_transformer import FusedMultiTransformer # noqa: F401 __all__ = [ #noqa 'FusedMultiHeadAttention', 'FusedFeedForward', 'FusedTransformerEncoderLayer', - + 'FusedMultiTransformer', ] diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index 4d1c3eee02..4da0904877 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -14,5 +14,10 @@ from .fused_transformer import fused_multi_head_attention from .fused_transformer import fused_feedforward +from .fused_transformer import fused_multi_transformer -__all__ = ['fused_multi_head_attention', 'fused_feedforward'] +__all__ = [ + 'fused_multi_head_attention', + 'fused_feedforward', + 'fused_multi_transformer', +] diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 800d5e832f..3e263f1c6d 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -488,3 +488,238 @@ def fused_multi_head_attention(x, attrs=attrs) return (final_out, cache_kv_out) if cache_kv else final_out + + +def fused_multi_transformer(x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + linear_weights, + linear_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + pre_layer_norm=True, + epsilon=1e-05, + cache_kvs=None, + time_step=None, + attn_mask=None, + dropout_rate=0.0, + activation="gelu", + training=False, + mode='upscale_in_train', + ring_id=-1, + name=None): + r""" + This is a fusion operator to compute multi transformer layers in transformer model architecture. + This operator only supports running on GPU. The function of the transformer layer is consistent + with the following pseudo code: + + .. code-block:: python + + if pre_layer_norm: + out = layer_norm(x) + out = qkv_linear(out) + qkv_bias + else: + out = qkv_linear(x) + qkv_bias + out = transpose(out, perm=[2, 0, 3, 1, 4]) + # extract q, k and v from out. + q = out[0:1, ::] + k = out[1:2, ::] + v = out[2:3, ::] + out = q * k^t + out = attn_mask + out + out = softmax(out) + out = dropout(out) + out = out * v + out = transpose(out, perm=[0, 2, 1, 3]) + out = linear(out) + if pre_layer_norm: + out = x + dropout(out + bias) + else: + out = layer_norm(x + dropout(out + bias)) + + residual = out; + if pre_layer_norm: + out = ffn_layer_norm(out) + out = ffn1_linear(out) + out = dropout(activation(out + ffn1_bias)) + out = ffn2_linear(out) + out = residual + dropout(out + ffn2_bias) + if not pre_layer_norm: + out = ffn_layer_norm(out) + + Args: + x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16 or float32, the shape is `[batch\_size, sequence\_length, d\_model]`. + ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of attention layer_norm, the shape is `[d\_model]`. + ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of attention layer_norm. the shape is `[d\_model]`. + qkv_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head, d\_model]`. + qkv_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head]`. + linear_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention linear. The shape is `[num\_head * dim\_head, d\_model]`. + linear_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention linear. The shape is `[d\_model]`. + ffn_ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward layer_norm, the shape is `[d\_model]` + ffn_ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of feedforward layer_norm, the shape is `[d\_model]` + ffn1_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward first linear, the shape is `[d\_model, dim\_feedforward]`. + ffn1_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward first linear, the shape is `[dim\_feedforward]`. + ffn2_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward second linear, the shape is `[dim\_feedforward, d\_model]`. + ffn2_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward second linear, the shape is `[d_model]`. + pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True. + epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5. + cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None. + time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None. + attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to + some unwanted positions, usually the paddings or the subsequent positions. It is a tensor + with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None. + dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0. + activation (str, optional): The activation. Default "gelu". + training (bool, optional): A flag indicating whether it is in train phrase or not. Default False. + mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] + + 1. upscale_in_train(default), upscale the output at training time + + - train: out = input * mask / ( 1.0 - p ) + - inference: out = input + + 2. downscale_in_infer, downscale the output at inference + + - train: out = input * mask + - inference: out = input * (1.0 - p) + ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using mp. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor|tuple: If `cache_kvs` is None, return a tensor that has + the same shape and data type with `x`, representing the output + of Transformer layers. If `cache_kvs` is not None, return the + tuple (output, cache_kvs), which output is the output of + Transformer layers, cache_kvs is inplace with input `cache_kvs`. + + Examples: + .. code-block:: python + + # required: gpu + import paddle + import paddle.incubate.nn.functional as F + import numpy as np + + # input: [batch_size, seq_len, embed_dim] + x = paddle.rand(shape=(2, 4, 128), dtype="float32") + + # ln_scale: [embed_dim], ln_bias: [embed_dim] + ln_scale = paddle.rand(shape=(128,), dtype="float32") + ln_bias = paddle.rand(shape=(128,), dtype="float32") + + # qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim] + qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") + qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") + + # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim] + linear_weight = paddle.rand(shape=(128, 128), dtype="float32") + linear_bias = paddle.rand(shape=(128,), dtype="float32") + + # ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim] + ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32") + ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32") + + # ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim] + ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32") + ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32") + + # ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim] + ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32") + ffn2_bias = paddle.rand(shape=(128,), dtype="float32") + + # self attention mask: [batch_size, 1, seq_len, seq_len] + attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32") + + # output: [batch_size, seq_len, embed_dim] + output = F.fused_multi_transformer( + x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias], + [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias], + [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias], + attn_mask=attn_mask) + # [2, 4, 128] + print(output.shape) + """ + if mode not in ('downscale_in_infer', 'upscale_in_train'): + raise ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + + if _non_static_mode(): + cache_kv_out, final_out = _C_ops.fused_multi_transformer( + x, ln_scales, ln_biases, qkv_weights, qkv_biases, cache_kvs, + time_step, attn_mask, linear_weights, linear_biases, ffn_ln_scales, + ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases, + cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon, + 'dropout_rate', dropout_rate, 'dropout_is_test', not training, + 'dropout_implementation', mode, 'act_method', activation, 'ring_id', + ring_id) + if cache_kvs is not None: + return final_out, cache_kv_out + return final_out + else: + helper = LayerHelper('fused_multi_transformer', **locals()) + dtype = x.dtype + # check dtypes + check_variable_and_dtype(x, 'x', ['float16', 'float32'], + 'fused_multi_transformer') + check_dtype(dtype, 'dtype', ['float16', 'float32'], + 'fused_multi_transformer') + + # set inputs + inputs = dict() + inputs['X'] = [x] + inputs['LnScale'] = ln_scales + inputs['LnBias'] = ln_biases + inputs['QKVW'] = qkv_weights + if qkv_biases is not None: + inputs['QKVBias'] = qkv_biases + if cache_kvs is not None: + assert len(cache_kvs) == len(qkv_weights) + inputs['CacheKV'] = cache_kvs + if time_step is not None: + inputs['TimeStep'] = time_step + inputs['SrcMask'] = attn_mask + inputs['OutLinearW'] = linear_weights + if linear_biases is not None: + inputs['OutLinearBias'] = linear_biases + + inputs['FFNLnScale'] = ffn_ln_scales + inputs['FFNLnBias'] = ffn_ln_biases + inputs['FFN1Weight'] = ffn1_weights + if ffn1_biases is not None: + inputs['FFN1Bias'] = ffn1_biases + inputs['FFN2Weight'] = ffn2_weights + if ffn2_biases is not None: + inputs['FFN2Bias'] = ffn2_biases + + # set attrs + attrs = { + 'pre_layer_norm': pre_layer_norm, + 'epsilon': epsilon, + 'dropout_rate': dropout_rate, + 'dropout_is_test': not training, + 'dropout_implementation': mode, + 'act_method': activation, + 'ring_id': ring_id + } + + outputs = dict() + final_out = helper.create_variable_for_type_inference(dtype=dtype) + outputs['Out'] = final_out + if cache_kvs: + # NOTE: inplace + outputs['CacheKVOut'] = cache_kvs + + helper.append_op( + type='fused_multi_transformer', + inputs=inputs, + outputs=outputs, + attrs=attrs) + + return (final_out, cache_kvs) if cache_kvs else final_out diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index d38e8d1193..d76b990958 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -22,6 +22,20 @@ from paddle.nn.initializer import Constant import collections +# for distributed tensor model parallel +def _set_var_distributed(var): + if var is None: + return + + var.is_distributed = True + + # NOTE: use current_block and find_var_recursive to support while_loop + startup_block = paddle.static.default_startup_program().current_block() + main_block = paddle.static.default_main_program().current_block() + startup_block._find_var_recursive(var.name).is_distributed = True + main_block._find_var_recursive(var.name).is_distributed = True + + class FusedMultiHeadAttention(Layer): """ Attention mapps queries and a set of key-value pairs to outputs, and @@ -608,3 +622,390 @@ class FusedTransformer(Layer): def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None): raise NotImplementedError() + + +class FusedMultiTransformer(Layer): + """ + FusedMultiTransformer is composed of multi transformer layers which contains two + sub-layers which are self (multi-head) attention and feedforward network. The + function of one transformer layer is consistent with the following pseudo code: + + .. code-block:: python + + if pre_layer_norm: + out = layer_norm(x) + out = qkv_linear(out) + qkv_bias + else: + out = qkv_linear(x) + qkv_bias + out = transpose(out, perm=[2, 0, 3, 1, 4]) + # extract q, k and v from out. + q = out[0:1, ::] + k = out[1:2, ::] + v = out[2:3, ::] + out = q * k^t + out = attn_mask + out + out = softmax(out) + out = dropout(out) + out = out * v + out = transpose(out, perm=[0, 2, 1, 3]) + out = linear(out) + if pre_layer_norm: + out = x + dropout(out + bias) + else: + out = layer_norm(x + dropout(out + bias)) + + residual = out; + if pre_layer_norm: + out = ffn_layer_norm(out) + out = ffn1_linear(out) + out = dropout(activation(out + ffn1_bias)) + out = ffn2_linear(out) + out = residual + dropout(out + ffn2_bias) + if not pre_layer_norm: + out = ffn_layer_norm(out) + + Parameters: + embed_dim (int): The expected feature size in the input and output. + num_heads (int): The number of heads in multi-head attention(MHA). + dim_feedforward (int): The hidden layer size in the feedforward network(FFN). + dropout_rate (float, optional): The dropout probability used in pre-process + and post-precess of MHA and FFN sub-layer. Default 0.0 + activation (str, optional): The activation function in the feedforward + network. Default "gelu". + normalize_before (bool, optional): Indicate whether to put layer normalization + into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer + normalization and post-precess includes dropout, residual connection. + Otherwise, no pre-process and post-precess includes dropout, residual + connection, layer normalization. Default True + ln_scale_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property + for Attention layer_norm. For Attention layer_norm weight, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. Default: None, which means the default weight + parameter property is used. See usage for details in :code:`ParamAttr`. + ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property + for Attention layer_norm. For Attention layer_norm bias, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. The `False` value means the corresponding layer would + not have trainable bias parameter. Default: None, which means the default bias + parameter property is used. See usage for details in :code:`ParamAttr`. + qkv_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property + for Attention qkv computation. For Attention qkv weight, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. Default: None, which means the default weight + parameter property is used. See usage for details in :code:`ParamAttr`. + qkv_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property + for Attention qkv computation. For Attention qkv bias, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. The `False` value means the corresponding layer would + not have trainable bias parameter. Default: None, which means the default bias + parameter property is used. See usage for details in :code:`ParamAttr`. + linear_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property + for Attention linear. For Attention linear weight, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. Default: None, which means the default weight + parameter property is used. See usage for details in :code:`ParamAttr`. + linear_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property + for Attention linear computation. For Attention linear bias, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. The `False` value means the corresponding layer would + not have trainable bias parameter. Default: None, which means the default bias + parameter property is used. See usage for details in :code:`ParamAttr`. + ffn_ln_scale_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property + for FFN layer_norm. For FFN layer_norm weight, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. Default: None, which means the default weight + parameter property is used. See usage for details in :code:`ParamAttr`. + ffn_ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property + for FFN layer_norm. For FFN layer_norm bias, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. The `False` value means the corresponding layer would + not have trainable bias parameter. Default: None, which means the default bias + parameter property is used. See usage for details in :code:`ParamAttr`. + ffn1_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property + for FFN first linear. For FFN first linear weight, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. Default: None, which means the default weight + parameter property is used. See usage for details in :code:`ParamAttr`. + ffn1_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property + for FFN first linear. For FFN first linear bias, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. The `False` value means the corresponding layer would + not have trainable bias parameter. Default: None, which means the default bias + parameter property is used. See usage for details in :code:`ParamAttr`. + ffn2_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property + for FFN second linear. For FFN second linear weight, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. Default: None, which means the default weight + parameter property is used. See usage for details in :code:`ParamAttr`. + ffn2_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property + for FFN second linear. For FFN second linear bias, if it is a list/tuple, `attrs[0]` + would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as + `attr` for transformer layer 1,etc. Otherwise, all layers both use it as + `attr` to create parameters. The `False` value means the corresponding layer would + not have trainable bias parameter. Default: None, which means the default bias + parameter property is used. See usage for details in :code:`ParamAttr`. + epsilon (float, optional): Small float value added to denominator of the layer_norm to + avoid dividing by zero. Default: 1e-05. + num_layers (int, optional): The number of layers of the transformer. If `qkv_weight_attrs` + is a list or tuple, the number of layers is obtained from `qkv_weight_attrs`. num_layers + only takes effect when `qkv_weight_attrs` is not a list or tuple. Default: -1. + nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using mp. + ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using mp. + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Examples: + + .. code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn import FusedMultiTransformer + + # encoder input: [batch_size, src_len, d_model] + enc_input = paddle.rand((2, 4, 128)) + # self attention mask: [batch_size, 1, src_len, src_len] + attn_mask = paddle.rand((2, 1, 4, 4)) + encoder_layers = FusedMultiTransformer(128, 2, 512, num_layers=1) + enc_output = encoder_layers(enc_input, attn_mask) # [2, 4, 128] + """ + + def __init__(self, + embed_dim, + num_heads, + dim_feedforward, + dropout_rate=0.0, + activation="gelu", + normalize_before=True, + ln_scale_attrs=None, + ln_bias_attrs=None, + qkv_weight_attrs=None, + qkv_bias_attrs=None, + linear_weight_attrs=None, + linear_bias_attrs=None, + ffn_ln_scale_attrs=None, + ffn_ln_bias_attrs=None, + ffn1_weight_attrs=None, + ffn1_bias_attrs=None, + ffn2_weight_attrs=None, + ffn2_bias_attrs=None, + epsilon=1e-5, + num_layers=-1, + nranks=1, + ring_id=-1, + name=None): + super(FusedMultiTransformer, self).__init__() + + assert embed_dim > 0, ("Expected embed_dim to be greater than 0, " + "but recieved {}".format(embed_dim)) + assert num_heads > 0, ("Expected nhead to be greater than 0, " + "but recieved {}".format(num_heads)) + assert dim_feedforward > 0, ( + "Expected dim_feedforward to be greater than 0, but recieved {}". + format(dim_feedforward)) + + self.normalize_before = normalize_before + self._dtype = self._helper.get_default_dtype() + self._epsilon = epsilon + self._ring_id = ring_id + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + + # tensor model parallel + if nranks > 1: + assert ring_id != -1 + assert num_heads % nranks == 0 + assert dim_feedforward % nranks == 0 + num_heads = num_heads // nranks + dim_feedforward = dim_feedforward // nranks + self._dim_feedforward = dim_feedforward + + if isinstance(qkv_weight_attrs, (list, tuple)): + num_layers = len(qkv_weight_attrs) + assert num_layers > 0 + + self.ln_scales, self.ln_biases = [], [] + self.qkv_weights, self.qkv_biases = [], [] + self.linear_weights, self.linear_biases = [], [] + self.ffn_ln_scales, self.ffn_ln_biases = [], [] + self.ffn1_weights, self.ffn1_biases = [], [] + self.ffn2_weights, self.ffn2_biases = [], [] + + def get_attr(attrs, idx): + if isinstance(attrs, (list, tuple)): + assert len(attrs) == num_layers + return attrs[idx] + return attrs + + for i in range(num_layers): + ln_scale_attr = get_attr(ln_scale_attrs, i) + ln_bias_attr = get_attr(ln_bias_attrs, i) + qkv_weight_attr = get_attr(qkv_weight_attrs, i) + qkv_bias_attr = get_attr(qkv_bias_attrs, i) + linear_weight_attr = get_attr(linear_weight_attrs, i) + linear_bias_attr = get_attr(linear_bias_attrs, i) + + ffn_ln_scale_attr = get_attr(ffn_ln_scale_attrs, i) + ffn_ln_bias_attr = get_attr(ffn_ln_bias_attrs, i) + ffn1_weight_attr = get_attr(ffn1_weight_attrs, i) + ffn1_bias_attr = get_attr(ffn1_bias_attrs, i) + ffn2_weight_attr = get_attr(ffn2_weight_attrs, i) + ffn2_bias_attr = get_attr(ffn2_bias_attrs, i) + + ln_scale = self.create_parameter( + attr=ln_scale_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0)) + ln_bias = self.create_parameter( + attr=ln_bias_attr, shape=[embed_dim], is_bias=True) + qkv_weight = self.create_parameter( + shape=[3, num_heads, self.head_dim, embed_dim], + attr=qkv_weight_attr, + dtype=self._dtype, + is_bias=False) + qkv_bias = self.create_parameter( + shape=[3, num_heads, self.head_dim], + attr=qkv_bias_attr, + dtype=self._dtype, + is_bias=True) + linear_weight = self.create_parameter( + shape=[num_heads * self.head_dim, embed_dim], + attr=linear_weight_attr, + dtype=self._dtype, + is_bias=False) + linear_bias = self.create_parameter( + shape=[embed_dim], + attr=linear_bias_attr, + dtype=self._dtype, + is_bias=True) + + ffn_ln_scale = self.create_parameter( + shape=[embed_dim], + attr=ffn_ln_scale_attr, + is_bias=False, + default_initializer=Constant(1.0)) + ffn_ln_bias = self.create_parameter( + shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True) + ffn1_weight = self.create_parameter( + shape=[embed_dim, dim_feedforward], + attr=ffn1_weight_attr, + dtype=self._dtype, + is_bias=False) + ffn1_bias = self.create_parameter( + shape=[dim_feedforward], + attr=ffn1_bias_attr, + dtype=self._dtype, + is_bias=True) + ffn2_weight = self.create_parameter( + shape=[dim_feedforward, embed_dim], + attr=ffn2_weight_attr, + dtype=self._dtype, + is_bias=False) + ffn2_bias = self.create_parameter( + shape=[embed_dim], + attr=ffn2_bias_attr, + dtype=self._dtype, + is_bias=True) + + # tensor model parallel + if nranks > 1: + # column parallel + _set_var_distributed(qkv_weight) + _set_var_distributed(qkv_bias) + _set_var_distributed(ffn1_weight) + _set_var_distributed(ffn1_bias) + # row parallel + _set_var_distributed(linear_weight) + _set_var_distributed(ffn2_weight) + + self.ln_scales.append(ln_scale) + self.ln_biases.append(ln_bias) + self.qkv_weights.append(qkv_weight) + self.qkv_biases.append(qkv_bias) + self.linear_weights.append(linear_weight) + self.linear_biases.append(linear_bias) + + self.ffn_ln_scales.append(ffn_ln_scale) + self.ffn_ln_biases.append(ffn_ln_bias) + self.ffn1_weights.append(ffn1_weight) + self.ffn1_biases.append(ffn1_bias) + self.ffn2_weights.append(ffn2_weight) + self.ffn2_biases.append(ffn2_bias) + + self.dropout_rate = dropout_rate + self.activation = activation + self.name = name + + def forward(self, src, attn_mask=None, caches=None, time_step=None): + """ + Applies multi transformer layers on the input. + + Parameters: + src (Tensor): The input of Transformer layers. It is + a tensor with shape `[batch_size, sequence_length, d_model]`. + The data type should be float16 or float32. + attn_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + `[batch_size, 1, sequence_length, sequence_length]`. It can be + None when nothing wanted or needed to be prevented attention to. + Default None. + caches (list(Tensor)|tuple(Tensor), optional): The cache structure + tensors for the inference generation model. It is only used for + inference and should be None for training. The shape is + `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None. + time_step (Tensor, optional): The time step tensor for the generation + model. Which used in decode stage, to represent the time step, + that is, the real seq_len of CacheKV. The shape is `[1]`, must be + in CPUPlace. Default None. + + Returns: + Tensor|tuple: If `caches` is None, return a tensor that has + the same shape and data type with `src`, representing the output + of Transformer layers. If `caches` is not None, return the + tuple (output, caches), which output is the output of + Transformer layers, caches is inplace with input `caches`. + """ + + if caches is not None: + assert len(caches) == len(self.qkv_weights) + out = incubate_f.fused_multi_transformer( + src, + self.ln_scales, + self.ln_biases, + self.qkv_weights, + self.qkv_biases, + self.linear_weights, + self.linear_biases, + self.ffn_ln_scales, + self.ffn_ln_biases, + self.ffn1_weights, + self.ffn1_biases, + self.ffn2_weights, + self.ffn2_biases, + pre_layer_norm=self.normalize_before, + epsilon=self._epsilon, + cache_kvs=caches, + time_step=time_step, + attn_mask=attn_mask, + dropout_rate=self.dropout_rate, + activation=self.activation, + training=self.training, + mode='upscale_in_train', + ring_id=self._ring_id, + name=self.name) + return out -- GitLab