Unverified commit a7ec8958, authored by S Sonder, committed by GitHub

Move fused_attention op to phi [migrate the forward GPU OpKernel] (#51743)

* add kernel functions

* update kernel functions

* update func parameters' name

* create codes for gpu device

* adjust file locations

* fix include error

* remove dependent files to phi/

* restore fused_attention_op.cu

* fix dependence errors

* fix dependence errors

* fix include error

* fix all dependence errors [build success]

* remove useless include

* recover useless include

* use phi::ToNCCLDataType

* fix namespace

* update new register code

* fix error in fused_gemm_epilogue_utils

* fix error in FusedAttentionKernel param

* finish fused_attention register code [build success]

* add paddle::optional

* add sig file

* fix build error

* fix a include error

* update CMakeLists

* fix parameter sequence

* add include file

* update #if before include

* fix grammar error

* update codes for DropoutParam

* remove const cast

* trans some fluid api to phi api

* add #if

* update test code

* update test codes

* recover test codes

* trans fused_attention to fluid

* move #endif to end

* move #endif

* delete useless files

* use fused attention utils and recover random seed

* remove fluid include in phi
Parent 6df4a667
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#include "paddle/phi/core/errors.h"
namespace phi {
namespace fusion {
template <typename T>
static void AllReduce(phi::DenseTensor &tensor, // NOLINT
const int ring_id,
const phi::GPUContext &dev_ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
auto pg_nccl = static_cast<paddle::distributed::ProcessGroupNCCL *>(pg);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = paddle::distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait();
} else {
auto dtype = phi::ToNCCLDataType(tensor.dtype());
int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>();
auto place = dev_ctx.GetPlace();
void *recvbuff =
dev_ctx.template Alloc<T>(&tensor, tensor.numel() * sizeof(T));
auto comm =
paddle::platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = dev_ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"PaddlePaddle should be compiled with NCCL or RCCL when using the tensor "
"model parallel op."));
#endif
}
} // namespace fusion
} // namespace phi
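A minimal usage sketch for the AllReduce helper above (the tensor name, data type, and ring id are illustrative assumptions, not part of this commit): with a phi::GPUContext named dev_ctx and an established communication ring, the partial result held by each rank is summed in place.

// Hypothetical call site, e.g. after the output-projection GEMM of fused attention:
phi::DenseTensor out_linear_out;               // partial result on this rank
// ... compute out_linear_out ...
phi::fusion::AllReduce<float>(out_linear_out,  // reduced in place with SUM
                              /*ring_id=*/0,   // ring_id == -1 skips the all-reduce
                              dev_ctx);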
......@@ -19,6 +19,8 @@ limitations under the License. */
#include <cuda_runtime_api.h> // NOLINT
#include "cuda.h" // NOLINT
#include "paddle/phi/backends/dynload/cublasLt.h"
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
......
......@@ -24,6 +24,7 @@ namespace cub = hipcub;
#include <iostream>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/memory_utils.h"
......
......@@ -1137,14 +1137,32 @@ void ReduceKernel(const KPDevice& dev_ctx,
is_mean);
}
template <typename Tx,
typename Ty,
template <typename>
class ReduceOp,
typename TransformOp>
void TensorReduceImpl(const phi::GPUContext& dev_ctx,
const phi::DenseTensor& x,
phi::DenseTensor* y,
const TransformOp& transform,
const std::vector<int>& origin_reduce_dims,
gpuStream_t stream,
bool is_mean = false) {
dev_ctx.template Alloc<Ty>(y);
ReduceKernel<Tx, Ty, ReduceOp, TransformOp>(
static_cast<const phi::GPUContext&>(dev_ctx),
x,
y,
transform,
origin_reduce_dims,
is_mean);
}
#endif
template <typename DeviceContext,
typename T,
size_t D,
size_t R_D,
typename Functor>
void ReduceFunctor(const DeviceContext& context,
template <typename Context, typename T, size_t D, size_t R_D, typename Functor>
void ReduceFunctor(const Context& context,
const phi::DenseTensor& input,
phi::DenseTensor* output,
const std::vector<int64_t>& dims,
......@@ -1181,10 +1199,10 @@ void ReduceFunctor(const DeviceContext& context,
}
}
#define HANDLE_REDUCE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
ReduceFunctor<DeviceContext, OutT, NDIM, RDIM, Functor>( \
dev_ctx, input, output, dims, keep_dim); \
#define HANDLE_REDUCE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
ReduceFunctor<Context, OutT, NDIM, RDIM, Functor>( \
dev_ctx, input, output, dims, keep_dim); \
}
//////////////// HandleLargeDim
......@@ -1220,8 +1238,8 @@ inline void GetShuffledDim(const DDim& src_dims,
}
}
template <typename DeviceContext, typename OutT>
void GetShuffledInput(const DeviceContext& dev_ctx,
template <typename Context, typename OutT>
void GetShuffledInput(const Context& dev_ctx,
const phi::DenseTensor& input,
phi::DenseTensor* shuffled_input,
const std::vector<int64_t>& dims) {
......@@ -1232,19 +1250,19 @@ void GetShuffledInput(const DeviceContext& dev_ctx,
shuffled_input->Resize(shuffled_dims);
dev_ctx.template Alloc<OutT>(shuffled_input);
phi::funcs::TransposeNormal<DeviceContext, OutT> trans;
phi::funcs::TransposeNormal<Context, OutT> trans;
trans(dev_ctx, input, shuffled_input, perm_axis);
}
template <typename DeviceContext, typename OutT, typename Functor>
void HandleLargeDim(const DeviceContext& dev_ctx,
template <typename Context, typename OutT, typename Functor>
void HandleLargeDim(const Context& dev_ctx,
const phi::DenseTensor& input,
phi::DenseTensor* output,
const std::vector<int64_t>& dims,
bool keep_dim) {
// shuffle the reduced dim to the end
phi::DenseTensor shuffled_input;
GetShuffledInput<DeviceContext, OutT>(dev_ctx, input, &shuffled_input, dims);
GetShuffledInput<Context, OutT>(dev_ctx, input, &shuffled_input, dims);
// transpose to 2D tensor whose shape is {unreduced, reduced}.
const int64_t unreduced = output->numel();
......@@ -1266,15 +1284,15 @@ void HandleLargeDim(const DeviceContext& dev_ctx,
DDim output_dim = output->dims();
output->ResizeAndAllocate({unreduced});
ReduceFunctor<DeviceContext, OutT, 2, 1, Functor>(
ReduceFunctor<Context, OutT, 2, 1, Functor>(
dev_ctx, shuffled_input, output, {1}, keep_dim);
output->ResizeAndAllocate(output_dim);
}
////////////// ReduceKernel
template <typename DeviceContext, typename T, typename OutT, typename Functor>
void ReduceKernelImpl(const DeviceContext& dev_ctx,
template <typename Context, typename T, typename OutT, typename Functor>
void ReduceKernelImpl(const Context& dev_ctx,
const phi::DenseTensor& input,
phi::DenseTensor* output,
const std::vector<int64_t>& dims,
......@@ -1295,7 +1313,7 @@ void ReduceKernelImpl(const DeviceContext& dev_ctx,
int ndim = input.dims().size();
int rdim = dims.size();
if (ndim > 6) {
HandleLargeDim<DeviceContext, OutT, Functor>(
HandleLargeDim<Context, OutT, Functor>(
dev_ctx, input, output, dims, keep_dim);
} else {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
namespace phi {
namespace fusion {
// NOTE: T must be the same as OutType in ComputeBackward
template <typename T, typename InType = T, typename OutType = T>
class AttnLayerNorm {
public:
AttnLayerNorm(const phi::GPUContext& dev_ctx,
float epsilon,
int64_t batch_size,
int64_t feature_size)
: dev_ctx_(dev_ctx),
epsilon_(epsilon),
batch_size_(batch_size),
feature_size_(feature_size) {}
~AttnLayerNorm() {}
void ComputeForward(const InType* x_data,
const phi::funcs::LayerNormParamType<T>* scale_data,
const phi::funcs::LayerNormParamType<T>* bias_data,
OutType* y_data,
phi::funcs::LayerNormParamType<T>* mean_data,
phi::funcs::LayerNormParamType<T>* var_data,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto stream = dev_ctx_.stream();
switch (phi::funcs::GetDesiredBlockDim(feature_size_)) {
FIXED_BLOCK_DIM_CASE(
phi::funcs::LayerNormForward<T,
phi::funcs::LayerNormParamType<T>,
kBlockDim,
false,
InType,
OutType>
<<<batch_size_, kBlockDim, 0, stream>>>(x_data,
scale_data,
bias_data,
y_data,
mean_data,
var_data,
epsilon_,
feature_size_,
dequant_out_scale_data,
quant_out_scale_offset,
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound));
default:
PADDLE_THROW(
phi::errors::InvalidArgument("Feature_size must be larger than 1"));
break;
}
}
void ComputeBackward(const T* x_data,
const T* d_y_data,
const phi::funcs::LayerNormParamType<T>* scale_data,
const phi::funcs::LayerNormParamType<T>* mean_data,
const phi::funcs::LayerNormParamType<T>* var_data,
T* d_x_data,
phi::funcs::LayerNormParamType<T>* d_scale_data,
phi::funcs::LayerNormParamType<T>* d_bias_data) {
phi::funcs::LayerNormBackward<T, phi::funcs::LayerNormParamType<T>>(
x_data,
d_y_data,
scale_data,
mean_data,
var_data,
d_x_data,
d_scale_data,
d_bias_data,
epsilon_,
batch_size_,
feature_size_,
dev_ctx_);
}
private:
const phi::GPUContext& dev_ctx_;
int64_t batch_size_;
int64_t feature_size_;
float epsilon_;
};
} // namespace fusion
} // namespace phi
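A hedged usage sketch of AttnLayerNorm (the pointer and dimension names are assumptions for illustration; LayerNormParamType<float> is float):

// Normalize a [bsz * seq_len, embed_dim] activation, writing mean/variance
// for the backward pass.
phi::fusion::AttnLayerNorm<float> layer_norm(
    dev_ctx, /*epsilon=*/1e-5f, /*batch_size=*/bsz * seq_len,
    /*feature_size=*/embed_dim);
layer_norm.ComputeForward(x_data, gamma_data, beta_data,
                          y_data, mean_data, var_data);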
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_CUDA)
#include "paddle/phi/backends/dynload/cublasLt.h"
#endif
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace phi {
namespace fusion {
// supports gemm-nt and gemm-nn, which are used in fused_attention_op.
template <typename T>
class AttnMatMul {
public:
// (m, n, k) = bsz_seq, output_size, input_size
AttnMatMul(const phi::GPUContext& dev_ctx,
bool transA,
bool transB,
int bsz_seq,
int output_size,
int input_size,
bool compute_bias)
: dev_ctx_(dev_ctx),
transA_(transA),
transB_(transB),
bsz_seq_(bsz_seq),
output_size_(output_size),
input_size_(input_size),
compute_bias_(compute_bias) {}
void ComputeForward(const phi::DenseTensor* weight,
const phi::DenseTensor* input,
const phi::DenseTensor* bias,
phi::DenseTensor* output,
phi::DenseTensor* bias_out,
bool fused = false) {
VLOG(6) << "input.shape={" << input->dims() << "}, weight.shape={"
<< weight->dims() << "}, output.shape={" << output->dims()
<< "}, batch_size=" << bsz_seq_ << ", output_size=" << output_size_
<< ", input_size=" << input_size_ << ", transA=" << transA_
<< ", transB=" << transB_ << ", compute_bias=" << compute_bias_
<< ", fused=" << fused;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (compute_bias_ && fused) {
PADDLE_ENFORCE_EQ(
!output || output == bias_out,
true,
phi::errors::InvalidArgument(
"The output (= input * weight) is expected to be nullptr or the "
"same as bias_out when fused is true."));
phi::funcs::ComputeFusedGemmEpilogueForward<T>(dev_ctx_,
input,
weight,
bias,
bsz_seq_, // M
output_size_, // N
input_size_, // K
transA_,
transB_,
"none",
bias_out,
nullptr);
return;
}
#endif
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// here: (transa, transb): nt, input * weight.
CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans;
CBLAS_TRANSPOSE transB = transB_ ? CblasTrans : CblasNoTrans;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
// (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
blas.GEMM(transA,
transB,
bsz_seq_,
output_size_,
input_size_,
alpha,
input->data<T>(),
weight->data<T>(),
beta,
output->data<T>());
if (compute_bias_) {
// bias_out = output + bias
std::vector<const phi::DenseTensor*> ins = {output, bias};
std::vector<phi::DenseTensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
}
}
void ComputeBackward(const phi::DenseTensor* input,
const phi::DenseTensor* weight,
const phi::DenseTensor* d_output,
phi::DenseTensor* d_input,
phi::DenseTensor* d_weight,
phi::DenseTensor* d_bias,
bool use_addto = false,
bool fused = false) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (compute_bias_ && fused) {
phi::funcs::ComputeFusedGemmEpilogueBackward<T>(dev_ctx_,
d_output,
input,
weight,
nullptr,
bsz_seq_, // M
output_size_, // N
input_size_, // K
transA_,
transB_,
"none",
d_input,
d_weight,
d_bias,
use_addto);
return;
}
#endif
T alpha = static_cast<T>(1.0);
T beta_dA = use_addto ? static_cast<T>(1.0) : static_cast<T>(0.0);
T beta_dB = static_cast<T>(0.0);
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
if (!transA_) {
// forward: gemm-nt
if (transB_) {
// backward: gemm-tn, dB = (dC)^T * A
if (d_weight) {
int dB_m = output_size_;
int dB_n = input_size_;
int dB_k = bsz_seq_;
T* dB_output_ptr = d_weight->data<T>();
blas.GEMM(CblasTrans,
CblasNoTrans,
dB_m,
dB_n,
dB_k,
alpha,
d_output->data<T>(),
input->data<T>(),
beta_dB,
dB_output_ptr);
}
// backward: gemm-nn, dA = dC * B
if (d_input) {
int dA_m = bsz_seq_;
int dA_n = input_size_;
int dA_k = output_size_;
T* dA_output_ptr = d_input->data<T>();
blas.GEMM(CblasNoTrans,
CblasNoTrans,
dA_m,
dA_n,
dA_k,
alpha,
d_output->data<T>(),
weight->data<T>(),
beta_dA,
dA_output_ptr);
}
} else { // fw: gemm-nn
// backward: gemm-tn, dB = A^T * dC
if (d_weight) {
int dB_m = input_size_;
int dB_n = output_size_;
int dB_k = bsz_seq_;
T* dB_output_ptr = d_weight->data<T>();
blas.GEMM(CblasTrans,
CblasNoTrans,
dB_m,
dB_n,
dB_k,
alpha,
input->data<T>(),
d_output->data<T>(),
beta_dB,
dB_output_ptr);
}
// backward: gemm-nt, dA = dC * B^T
if (d_input) {
int dA_m = bsz_seq_;
int dA_n = input_size_;
int dA_k = output_size_;
T* dA_output_ptr = d_input->data<T>();
blas.GEMM(CblasNoTrans,
CblasTrans,
dA_m,
dA_n,
dA_k,
alpha,
d_output->data<T>(),
weight->data<T>(),
beta_dA,
dA_output_ptr);
}
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The AttnMatMul wrapper does not support (transA=T, transB=T/N) "
"parameters."));
}
if (compute_bias_ && d_bias) {
// reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2} or {0,1,2,3}
// -> {3} or {0,1,2,3,4} -> {3,4}
const auto input_dims = d_output->dims();
const auto output_dims = d_bias->dims();
bool support_case_1 =
(input_dims.size() == 5 && output_dims.size() == 3 &&
(input_dims[2] == output_dims[0]) &&
(input_dims[3] == output_dims[1]) &&
(input_dims[4] == output_dims[2]));
bool support_case_2 =
(input_dims.size() == 3 && output_dims.size() == 1 &&
(input_dims[2] == output_dims[0]));
bool support_case_3 =
(input_dims.size() == 4 && output_dims.size() == 1 &&
input_dims[3] == output_dims[0]);
bool support_case_4 =
(input_dims.size() == 5 && output_dims.size() == 2 &&
input_dims[3] == output_dims[0] && input_dims[4] == output_dims[1]);
gpuStream_t stream = dev_ctx_.stream();
if (support_case_1 || support_case_2) {
phi::funcs::
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx_,
*d_output,
d_bias,
kps::IdentityFunctor<T>(),
{0, 1},
stream);
} else if (support_case_3 || support_case_4) {
phi::funcs::
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx_,
*d_output,
d_bias,
kps::IdentityFunctor<T>(),
{0, 1, 2},
stream);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Only supports reduce when the input dims are [0,1,2,3,4] and the "
"output is [2,3,4], or the input is [0,1,2] and the output is [2]."));
}
}
}
private:
const phi::GPUContext& dev_ctx_;
bool transA_;
bool transB_;
int bsz_seq_;
int output_size_;
int input_size_;
int compute_bias_;
};
} // namespace fusion
} // namespace phi
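A sketch of how the AttnMatMul wrapper above is typically driven for the QKV projection of fused attention (the shapes and trans flags below are illustrative assumptions, not copied from the migrated kernel):

// (m, n, k) = (bsz * seq_len, 3 * num_head * dim_head, dim_embed), gemm-nt.
phi::fusion::AttnMatMul<float> qkv_compute(
    dev_ctx, /*transA=*/false, /*transB=*/true,
    /*bsz_seq=*/bsz * seq_len, /*output_size=*/3 * num_head * dim_head,
    /*input_size=*/dim_embed, /*compute_bias=*/true);
// With compute_bias = true: qkv_bias_out = ln_out * qkv_weight^T + qkv_bias.
qkv_compute.ComputeForward(&qkv_weight, &ln_out, &qkv_bias,
                           &qkv_out, &qkv_bias_out);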
This diff is collapsed.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_common.h"
#include "paddle/phi/kernels/fusion/gpu/fused_residual_dropout_bias.h"
#include "paddle/phi/kernels/gpu/gelu_funcs.h"
namespace phi {
namespace fusion {
template <typename T>
struct GeluFunctor {
inline __host__ __device__ T operator()(const T x) const {
using U = phi::funcs::LayerNormParamType<T>;
const U casted_x = static_cast<U>(x);
const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
return static_cast<T>(out);
}
};
template <typename T>
struct FastGeluFunctor {
inline __device__ T operator()(const T x) const {
return phi::GeluFwd<T, true>(x);
}
};
/**
*@brief the gelu grad functor
*/
template <typename T>
struct GeluGradFunctor {
inline __host__ __device__ T UseOut(const T x) const {
using U = phi::funcs::LayerNormParamType<T>;
auto casted_x = static_cast<U>(x);
auto first =
static_cast<U>(0.5) *
(static_cast<U>(1) + erf(casted_x * static_cast<U>(M_SQRT1_2)));
auto second = static_cast<U>(0.5 * M_2_SQRTPI * M_SQRT1_2) * casted_x *
exp(-static_cast<U>(0.5) * casted_x * casted_x);
return static_cast<T>((first + second));
}
};
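For reference, the exact-GELU math implemented by GeluFunctor and GeluGradFunctor above; UseOut receives the pre-activation value (src + bias), and the constant 0.5 * M_2_SQRTPI * M_SQRT1_2 equals 1/sqrt(2*pi):

//   gelu(x)  = 0.5 * x * (1 + erf(x / sqrt(2)))
//   gelu'(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi)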
/**
* @brief dst = dropout(activation(src + bias));
* the src, mask and dst shape is (rows, cols)
* the bias shape is (1, cols)
*/
template <typename T,
typename MaskType,
int VecSize,
typename Functor,
typename InType = T,
typename OutType = T>
__global__ void FusedDropoutActBias(
Functor act,
const uint64_t seed,
const uint64_t rows,
const uint64_t cols,
const int increment,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const InType *__restrict__ src,
const T *__restrict__ bias,
OutType *dst,
MaskType *mask,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
const T factor =
phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
i += blockDim.x * gridDim.x * VecSize) {
phi::fusion::FusedResidualDropoutBiasOneThread<T,
MaskType,
VecSize,
false,
true,
Functor,
InType,
OutType>(
r,
i,
cols,
&state,
dropout_prob,
factor,
src,
nullptr,
bias,
dst,
mask,
is_test,
nullptr,
nullptr,
act,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
}
}
template <typename T,
int VecSize,
typename Functor,
typename InType = T,
typename OutType = T>
__global__ void FusedActBias(Functor act,
const uint64_t elem_cnt,
const uint64_t cols,
const InType *__restrict__ src,
const T *__restrict__ bias,
OutType *dst,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
const int32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using LoadInType = phi::AlignedVector<InType, VecSize>;
using LoadFloat = phi::AlignedVector<float, VecSize>;
using StoreOutType = phi::AlignedVector<OutType, VecSize>;
LoadInType src_vec;
LoadT bias_vec;
StoreOutType out_vec;
LoadFloat dequant_out_scale_vec;
for (int32_t idx = global_thread_idx * VecSize,
step = blockDim.x * gridDim.x * VecSize;
idx < elem_cnt;
idx += step) {
const int32_t col_idx = idx % cols;
phi::Load<InType, VecSize>(&src[idx], &src_vec);
phi::Load<float, VecSize>(&dequant_out_scale_data[col_idx],
&dequant_out_scale_vec);
if (bias) {
phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
}
#pragma unroll
for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
T tmp;
if (std::is_same<InType, int32_t>::value) {
tmp = static_cast<T>(static_cast<float>(src_vec[unroll_idx]) *
dequant_out_scale_vec[unroll_idx]);
if (bias) {
tmp = static_cast<T>(act(tmp + bias_vec[unroll_idx]));
} else {
tmp = static_cast<T>(act(tmp));
}
out_vec[unroll_idx] = phi::funcs::quant_helper(tmp,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
if (bias) {
out_vec[unroll_idx] = static_cast<OutType>(
act(static_cast<T>(src_vec[unroll_idx]) + bias_vec[unroll_idx]));
} else {
out_vec[unroll_idx] =
static_cast<OutType>(act(static_cast<T>(src_vec[unroll_idx])));
}
}
}
phi::Store<OutType, VecSize>(out_vec, &dst[idx]);
}
}
/**
* @brief dst = dropout(activation(src + bias));
*/
template <typename T,
typename MaskType,
typename Functor,
typename InType = T,
typename OutType = T>
void LaunchDropoutActBias(Functor act_functor,
const uint64_t seed,
const uint32_t rows,
const uint32_t cols,
const int increment,
const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test,
const InType *src,
const T *bias,
OutType *dst,
MaskType *mask_data,
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
phi::fusion::SetZero<T>(ctx, reinterpret_cast<T *>(dst), rows * cols);
phi::fusion::SetZero<MaskType>(ctx, mask_data, rows * cols);
return;
}
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
const auto config =
phi::fusion::Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
if (is_test) {
const int32_t elem_cnt = rows * cols;
const int32_t pack_num = elem_cnt / VecSize;
const int32_t tmp_cols = cols / VecSize;
int block_size =
std::max(static_cast<int32_t>(32), std::min(tmp_cols, 128));
const int grid_size = std::max(static_cast<int32_t>(1),
(pack_num + block_size - 1) / block_size);
FusedActBias<T, VecSize, Functor, InType, OutType>
<<<grid_size, block_size, 0, ctx.stream()>>>(act_functor,
elem_cnt,
cols,
src,
bias,
dst,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
} else {
FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor,
seed,
rows,
cols,
increment,
dropout_prob,
is_upscale_in_train,
is_test,
src,
bias,
dst,
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
}
} else {
FusedDropoutActBias<T, MaskType, 1, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor,
seed,
rows,
cols,
increment,
dropout_prob,
is_upscale_in_train,
is_test,
src,
bias,
dst,
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
}
}
/*
* @brief calculate the grad when there is no bias
*/
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActGrad(Functor act_grad,
const T *dout,
const MaskType *mask,
const T *src,
const T factor,
const int64_t size,
T *dx) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
LoadT dout_vec;
LoadT src_vec;
MaskLoadT mask_vec;
phi::Load<T, VecSize>(&dout[i], &dout_vec);
phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);
phi::Load<T, VecSize>(&src[i], &src_vec);
StoreT dx_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
T tmp = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]);
}
phi::Store<T, VecSize>(dx_vec, &dx[i]);
}
}
/**
* blocks(128 * 8)
* 1. calculate the dx and reduce total rows to 128 rows
* 2. save 128*8 temporary sum in 8*128 shared memory
* 3. reduce the sum of 128 cols data by 8*VecSize warps
*/
template <typename T,
typename MaskType,
int BlockSizeX,
int BlockSizeY,
int VecSize,
typename Functor,
int THREADS_PER_CTA = BlockSizeX *BlockSizeY>
__global__ __launch_bounds__(THREADS_PER_CTA) void FusedDropoutActBiasGrad(
Functor act_grad,
const T *dout,
const MaskType *mask,
const T *src,
const T *bias,
const T factor,
const int64_t rows,
const int64_t cols,
T *dx,
T *dbias) {
int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
T tmp_sum[VecSize] = {static_cast<T>(0)};
// calculate the dx and temporary sum
if (col_id * VecSize < cols) {
for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
int index = row_id * cols + col_id * VecSize;
LoadT dout_vec;
LoadT src_vec;
LoadT bias_vec;
MaskLoadT mask_vec;
phi::Load<T, VecSize>(&dout[index], &dout_vec);
phi::Load<T, VecSize>(&src[index], &src_vec);
phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
phi::Load<T, VecSize>(&bias[col_id * VecSize], &bias_vec);
StoreT dx_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
T val;
T tmp = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
val = tmp * act_grad.UseOut(src_vec[i] + bias_vec[i]);
dx_vec[i] = val;
tmp_sum[i] += val;
}
phi::Store<T, VecSize>(dx_vec, &dx[index]);
}
}
phi::fusion::CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(
tmp_sum, dbias, cols);
}
/**
* @brief to launch kernel FusedDropoutActBiasGrad or FusedDropoutActGrad
*/
template <typename T, typename MaskType, typename Functor>
void LaunchDropoutActBiasGrad(Functor act_functor,
const T *dout,
const MaskType *mask,
const T *src,
const T *bias,
const float dropout_prob,
const bool is_upscale_in_train,
const uint32_t rows,
const uint32_t cols,
T *dx,
T *dbias,
const phi::GPUContext &ctx) {
const T zero = static_cast<T>(0.0);
auto factor = dropout_prob == static_cast<float>(1.0f)
? zero
: static_cast<T>(1.0 / (1.0 - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
if (dbias != nullptr) {
const auto threads = 8;
const auto blocks =
std::max(static_cast<uint32_t>(1),
(cols / real_vec_size + threads - 1) / threads);
dim3 block_dim(threads, 128, 1);
dim3 grid_dim(blocks, 1, 1);
if (cols % VecSize == 0) {
FusedDropoutActBiasGrad<T, MaskType, 8, 128, VecSize, Functor>
<<<grid_dim, block_dim, 0, ctx.stream()>>>(act_functor,
dout,
mask,
src,
bias,
factor,
rows,
cols,
dx,
dbias);
} else {
FusedDropoutActBiasGrad<T, MaskType, 8, 128, 1, Functor>
<<<grid_dim, block_dim, 0, ctx.stream()>>>(act_functor,
dout,
mask,
src,
bias,
factor,
rows,
cols,
dx,
dbias);
}
} else {
const uint64_t n = rows * cols;
phi::backends::gpu::GpuLaunchConfig config =
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
if (n % VecSize == 0) {
FusedDropoutActGrad<T, MaskType, VecSize, Functor>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, dout, mask, src, factor, n, dx);
} else {
FusedDropoutActGrad<T, MaskType, 1, Functor>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, dout, mask, src, factor, n, dx);
}
}
}
} // namespace fusion
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CUDA)
#include <cooperative_groups.h>
#include <cuda.h>
#include <curand_kernel.h>
#endif
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
namespace phi {
namespace fusion {
#define CACHE_LINE 128
#define MAX_CACHE_BYTES (CACHE_LINE / CHAR_BIT)
/**
* get the threads for fused_residual_dropout_bias:
* 1D blocks: blockDim.x = cols
* 2D grids: gridDim.y = rows
*/
inline phi::backends::gpu::GpuLaunchConfig Get1DBlocksAnd2DGrids(
const phi::GPUContext &ctx,
const uint32_t rows,
const uint32_t cols,
const int vec_size) {
const uint32_t tmp_cols = cols / vec_size;
// NOTE(wangxi): We set max_block_size to 512 because `FusedResidualDropoutBias`
// needs too many register resources. If data_type is float16, CUDA error 701
// ('cudaErrorLaunchOutOfResources') occurs when block_size is 1024, which
// indicates that the launch did not happen because it did not have
// appropriate resources.
// Of course, this kernel can be optimized later to reduce register usage.
int threads = std::max(static_cast<uint32_t>(32),
std::min(tmp_cols,
static_cast<uint32_t>(std::min(
ctx.GetMaxThreadsPerBlock(), 512))));
const auto blocks_x =
std::max(static_cast<uint32_t>(1), (tmp_cols + threads - 1) / threads);
const auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
phi::backends::gpu::GpuLaunchConfig config;
config.block_per_grid.x = blocks_x;
config.block_per_grid.y = blocks_y;
config.thread_per_block.x = threads;
return config;
}
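A worked example of the launch configuration above (the numbers are assumed for illustration):

// rows = 4096, cols = 1024, vec_size = 8, GetMaxThreadsPerBlock() = 1024:
//   tmp_cols = 1024 / 8 = 128
//   threads  = max(32, min(128, min(1024, 512))) = 128
//   blocks_x = max(1, (128 + 127) / 128) = 1
//   blocks_y = max(1, 4096) = 4096
// i.e. one grid row per tensor row, with 128 threads each handling a vector
// of 8 columns.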
template <int VecSize>
__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
float *data);
template <>
__forceinline__ __device__ void RandVec<1>(curandStatePhilox4_32_10_t *state,
float *data) {
data[0] = curand_uniform(state);
}
template <>
__forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state,
float *data) {
data[0] = curand_uniform(state);
data[1] = curand_uniform(state);
}
template <>
__forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state,
float *data) {
float4 rand4 = curand_uniform4(state);
data[0] = rand4.x;
data[1] = rand4.y;
data[2] = rand4.w;
data[3] = rand4.z;
}
template <>
__forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state,
float *data) {
RandVec<4>(state, data);
RandVec<4>(state, data + 4);
}
template <typename T>
inline void SetZero(const phi::GPUContext &ctx, T *ptr, const size_t size) {
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(ptr, 0, size * sizeof(T), ctx.stream()));
}
/**
* reduce the sum of 128 cols data by 8*VecSize warps
**/
template <typename T, int VecSize, int BlockSizeX, int BlockSizeY>
inline __device__ void CalculateDBias(const T *tmp_sum,
T *dbias,
const int cols) {
// save temporary sum to cache and do transpose
__shared__ T cache[BlockSizeX * VecSize][BlockSizeY];
for (int i = 0; i < VecSize; i++) {
cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i];
}
__syncthreads();
// reduce sum
T sum[2] = {static_cast<T>(0)};
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int x = tid >> 5; // warp id
int y = tid & 31; // thread id on warp 0~31
// need BlockSizeX * VecSize warps
for (int j = x; j < BlockSizeX * VecSize; j += 32) {
// reduce 128 to 32
#pragma unroll
for (int i = 0; i < (BlockSizeY >> 5); i++) {
sum[(j >> 5)] += cache[j][y + i * 32];
}
}
int reduce_num_pre_thread = (BlockSizeX * VecSize + 31) / 32;
// reduce 32 to 1
for (int i = 0; i < reduce_num_pre_thread; i++) {
sum[i] = phi::funcs::WarpReduceSum(sum[i]);
}
// save sum to dbias
if (y == 0 && x < BlockSizeX * VecSize) {
for (int i = 0; i < reduce_num_pre_thread; i++) {
int bias_id = blockIdx.x * BlockSizeX * VecSize + x + i * 32;
if (bias_id < cols) {
dbias[bias_id] = sum[i];
}
}
}
}
template <typename T>
inline __device__ T GetFactor(const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test) {
T factor = is_upscale_in_train ? static_cast<T>(1.0f / (1.0f - dropout_prob))
: static_cast<T>(1.0f);
if (is_test) {
factor = is_upscale_in_train ? static_cast<T>(1.0f)
: static_cast<T>(1.0f - dropout_prob);
}
return factor;
}
} // namespace fusion
} // namespace phi
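The factor returned by GetFactor above is what the fused dropout kernels multiply into every kept element; for a dropout probability p (e.g. p = 0.1) it works out to:

//   upscale_in_train,   training : factor = 1 / (1 - p)  (~1.111 for p = 0.1)
//   upscale_in_train,   is_test  : factor = 1
//   downscale_in_infer, training : factor = 1
//   downscale_in_infer, is_test  : factor = 1 - p         (0.9 for p = 0.1)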
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_CUDA)
#include "paddle/phi/backends/dynload/cublasLt.h"
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_common.h"
#include "paddle/phi/kernels/fusion/gpu/fused_layernorm_residual_dropout_bias.h"
#include "paddle/phi/kernels/fusion/gpu/fused_residual_dropout_bias.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
namespace phi {
namespace fusion {
struct DropoutParam {
uint64_t seed;
float dropout_prob;
bool is_upscale_in_train;
bool is_test;
bool fix_seed;
int increment{};
const phi::DenseTensor* tensor_seed;
int seed_val;
DropoutParam() {
fix_seed = false;
seed = 0;
is_test = false;
is_upscale_in_train = false;
dropout_prob = 0.5;
tensor_seed = nullptr;
seed_val = 0;
}
DropoutParam(bool fix_seed_,
uint64_t seed_,
bool is_test_,
bool is_upscale_in_train_,
float dropout_prob_,
const phi::DenseTensor* tensor_seed_,
int seed_val_) {
fix_seed = fix_seed_;
seed = seed_;
is_test = is_test_;
is_upscale_in_train = is_upscale_in_train_;
dropout_prob = dropout_prob_;
tensor_seed = tensor_seed_;
seed_val = seed_val_;
}
int UpdateSeedAndIncrement(const phi::GPUContext& dev_ctx, const int offset) {
uint64_t tmp_increment;
phi::funcs::GetSeedDataAndIncrement(dev_ctx,
tensor_seed,
fix_seed,
seed_val,
offset,
&seed,
&tmp_increment);
increment = static_cast<int>(tmp_increment);
return increment;
}
};
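A minimal construction sketch for DropoutParam (the variable names here are assumptions; in the migrated kernel the values come from the op's dropout attributes and the optional seed tensor):

phi::fusion::DropoutParam dropout_param(/*fix_seed_=*/is_fix_seed,
                                        /*seed_=*/0,
                                        /*is_test_=*/is_test,
                                        /*is_upscale_in_train_=*/true,
                                        /*dropout_prob_=*/0.1f,
                                        /*tensor_seed_=*/seed_tensor,  // may be nullptr
                                        /*seed_val_=*/seed_val);
// Each helper later refreshes `seed` and `increment` via
// UpdateSeedAndIncrement(dev_ctx, offset) before launching its kernel.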
template <typename T>
struct DataTypeTraits {
using DataType = T;
};
template <>
struct DataTypeTraits<phi::dtype::float16> {
// Since LayerNormDirectCUDAFunctor registers the half type, we need to
// convert phi::float16 to half.
using DataType = half;
};
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
class FusedDropoutHelper {
private:
int GetIncrement(const phi::GPUContext& ctx) {
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
const int real_vec_size = cols_ % VecSize == 0 ? VecSize : 1;
auto config = Get1DBlocksAnd2DGrids(ctx,
static_cast<uint64_t>(rows_),
static_cast<uint64_t>(cols_),
real_vec_size);
int increment = ((cols_ - 1) / (config.thread_per_block.x *
config.block_per_grid.x * real_vec_size) +
1) *
real_vec_size;
increment = dropout_param_.UpdateSeedAndIncrement(ctx, increment);
return increment;
}
public:
FusedDropoutHelper() {}
FusedDropoutHelper(const phi::GPUContext& ctx,
const int rows,
const int cols,
const DropoutParam& dropout_param) {
rows_ = rows;
cols_ = cols;
dropout_param_ = dropout_param;
}
// out = residual + dropout( src + bias )
void ResidualDropoutBias(const phi::GPUContext& ctx,
const InType* src,
const T* residual,
const T* bias,
OutType* out,
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0) {
auto increment = GetIncrement(ctx);
LaunchResidualDropoutBias<T, MaskType, InType, OutType>(
rows_,
cols_,
increment,
dropout_param_.seed,
dropout_param_.dropout_prob,
dropout_param_.is_test,
dropout_param_.is_upscale_in_train,
src,
residual,
bias,
mask,
out,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
}
void ResidualDropoutBiasGrad(const phi::GPUContext& ctx,
const T* d_out,
const MaskType* mask,
T* d_src,
T* d_residual,
T* d_bias) {
LaunchResidualDropoutBiasGrad<T, uint8_t>(
d_out,
mask,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
rows_,
cols_,
d_src,
d_bias,
ctx);
if (d_residual) {
phi::memory_utils::Copy(ctx.GetPlace(),
d_residual,
ctx.GetPlace(),
d_out,
rows_ * cols_ * sizeof(T),
ctx.stream());
}
}
// out = dropout(activation(src + bias))
void DropoutActBias(const phi::GPUContext& ctx,
const InType* src,
const T* bias,
const std::string& act_method,
OutType* out,
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto increment = GetIncrement(ctx);
if (act_method == "gelu") {
if (FLAGS_use_fast_math) {
phi::fusion::FastGeluFunctor<T> fast_gelu;
phi::fusion::LaunchDropoutActBias<T,
MaskType,
phi::fusion::FastGeluFunctor<T>,
InType,
OutType>(
fast_gelu,
dropout_param_.seed,
rows_,
cols_,
dropout_param_.increment,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
dropout_param_.is_test,
src,
bias,
out,
mask,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
phi::fusion::GeluFunctor<T> gelu;
phi::fusion::LaunchDropoutActBias<T,
MaskType,
phi::fusion::GeluFunctor<T>,
InType,
OutType>(
gelu,
dropout_param_.seed,
rows_,
cols_,
dropout_param_.increment,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
dropout_param_.is_test,
src,
bias,
out,
mask,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
} else if (act_method == "relu") {
phi::funcs::ReluFunctor<T> relu;
phi::fusion::LaunchDropoutActBias<T,
MaskType,
phi::funcs::ReluFunctor<T>,
InType,
OutType>(
relu,
dropout_param_.seed,
rows_,
cols_,
increment,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
dropout_param_.is_test,
src,
bias,
out,
mask,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
PADDLE_THROW(errors::InvalidArgument(
"Currently only supports gelu or relu activation functions!"));
}
}
void DropoutActBiasGrad(const phi::GPUContext& ctx,
const T* dout,
const T* src,
const T* bias,
const MaskType* mask,
T* d_src,
T* d_bias,
const std::string& act_method) {
if (act_method == "gelu") {
phi::funcs::GeluGradFunctor<T> gelu_grad;
phi::fusion::
LaunchDropoutActBiasGrad<T, MaskType, phi::funcs::GeluGradFunctor<T>>(
gelu_grad,
dout,
mask,
src,
bias,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
rows_,
cols_,
d_src,
d_bias,
ctx);
} else if (act_method == "relu") {
phi::funcs::ReluGradFunctor<T> relu_grad;
phi::fusion::
LaunchDropoutActBiasGrad<T, MaskType, phi::funcs::ReluGradFunctor<T>>(
relu_grad,
dout,
mask,
src,
bias,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
rows_,
cols_,
d_src,
d_bias,
ctx);
} else {
PADDLE_THROW(errors::InvalidArgument(
"Currently only supports gelu or relu activation functions!"));
}
}
protected:
int rows_;
int cols_;
DropoutParam dropout_param_;
};
template <typename T,
typename MaskType,
typename InType = T,
typename OutType = T>
class FusedDropoutLayerNormHelper
: public FusedDropoutHelper<T, MaskType, InType, OutType> {
public:
FusedDropoutLayerNormHelper() {}
FusedDropoutLayerNormHelper(const int rows,
const int cols,
const float epsilon) {
using U = phi::funcs::LayerNormParamType<T>;
this->rows_ = rows;
this->cols_ = cols;
epsilon_ = epsilon;
}
FusedDropoutLayerNormHelper(const phi::GPUContext& ctx,
const int rows,
const int cols,
const DropoutParam& dropout_param,
const float epsilon)
: FusedDropoutHelper<T, MaskType, InType, OutType>(
ctx, rows, cols, dropout_param) {
using U = phi::funcs::LayerNormParamType<T>;
epsilon_ = epsilon;
}
// call layer_norm
void LayerNorm(const phi::GPUContext& ctx,
const InType* src,
const phi::funcs::LayerNormParamType<T>* gamma,
const phi::funcs::LayerNormParamType<T>* beta,
OutType* out,
phi::funcs::LayerNormParamType<T>* mean,
phi::funcs::LayerNormParamType<T>* variance) {
using InDataType = typename DataTypeTraits<InType>::DataType;
using OutDataType = typename DataTypeTraits<OutType>::DataType;
phi::LayerNormDirectCUDAFunctor<InDataType,
phi::funcs::LayerNormParamType<T>>
layer_norm;
std::vector<int> src_shape{this->rows_, this->cols_};
layer_norm(ctx.stream(),
reinterpret_cast<const InDataType*>(src),
src_shape,
beta,
gamma,
reinterpret_cast<OutDataType*>(out),
mean,
variance,
1,
epsilon_);
}
void LayerNormGrad(const phi::GPUContext& ctx,
const T* dout,
const T* src,
const phi::funcs::LayerNormParamType<T>* gamma,
const phi::funcs::LayerNormParamType<T>* mean,
const phi::funcs::LayerNormParamType<T>* variance,
T* d_src,
phi::funcs::LayerNormParamType<T>* d_scale,
phi::funcs::LayerNormParamType<T>* d_bias) {
using U = phi::funcs::LayerNormParamType<T>;
phi::funcs::LayerNormBackward<T, U>(src,
dout,
gamma,
mean,
variance,
d_src,
d_scale,
d_bias,
epsilon_,
this->rows_,
this->cols_,
ctx);
}
// out = layernorm(residual + dropout(src + bias))
template <typename P = phi::funcs::LayerNormParamType<T>,
bool is_same_type = false>
void LayernormResidualDropoutBias(
const phi::GPUContext& ctx,
const InType* src,
const T* residual,
const T* bias,
const P* gamma,
const P* beta,
T* dropout_out,
MaskType* mask,
OutType* out,
phi::funcs::LayerNormParamType<T>* mean,
phi::funcs::LayerNormParamType<T>* variance,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
using U = phi::funcs::LayerNormParamType<T>;
int vec_size = MAX_CACHE_BYTES / sizeof(T);
if (this->cols_ % vec_size != 0) {
vec_size = 1;
}
int threads = phi::funcs::GetDesiredBlockDim(this->cols_ / vec_size);
int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
LaunchLayernormResidualDropoutBias<T,
MaskType,
U,
is_same_type,
InType,
OutType>(
this->rows_,
this->cols_,
increment,
this->dropout_param_.seed,
this->dropout_param_.dropout_prob,
epsilon_,
this->dropout_param_.is_upscale_in_train,
this->dropout_param_.is_test,
src,
residual,
bias,
gamma,
beta,
mask,
dropout_out,
out,
mean,
variance,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
template <typename P = phi::funcs::LayerNormParamType<T>,
bool is_same_type = false>
void LayernormResidualDropoutBiasGrad(
const phi::GPUContext& ctx,
const T* d_out,
const T* layernorm_src,
const MaskType* mask,
const P* gamma,
const phi::funcs::LayerNormParamType<T>* mean,
const phi::funcs::LayerNormParamType<T>* variance,
T* d_layernorm_src,
P* d_scale,
P* d_layernorm_bias,
T* d_dropout_src,
T* d_bias,
T* d_residual) {
using U = phi::funcs::LayerNormParamType<T>;
bool can_call_1024_kernel = false;
// Fast impl for cases when cols is 1024 and linear_bias is nullptr.
// In fact, the fast impl is also feasible when linear_bias is not nullptr,
// but we do not support that case here.
if (this->cols_ == 1024 && d_bias == nullptr && d_scale != nullptr &&
d_layernorm_bias != nullptr && sizeof(T) <= 4) {
can_call_1024_kernel = true;
}
VLOG(6) << "LaunchLayernormResidualDropoutGrad = " << can_call_1024_kernel;
if (can_call_1024_kernel) {
LaunchLayernormResidualDropoutGrad<T, U, MaskType, is_same_type>(
ctx,
this->rows_,
this->cols_,
epsilon_,
this->dropout_param_.dropout_prob,
this->dropout_param_.is_upscale_in_train,
d_out,
layernorm_src,
gamma,
mean,
variance,
mask,
d_scale,
d_layernorm_bias,
d_residual,
d_dropout_src);
} else {
phi::funcs::LayerNormBackward<T, U, is_same_type>(layernorm_src,
d_out,
gamma,
mean,
variance,
d_layernorm_src,
d_scale,
d_layernorm_bias,
epsilon_,
this->rows_,
this->cols_,
ctx);
this->ResidualDropoutBiasGrad(
ctx, d_layernorm_src, mask, d_dropout_src, d_residual, d_bias);
}
}
protected:
float epsilon_;
};
} // namespace fusion
} // namespace phi
This diff is collapsed.
......@@ -28,8 +28,12 @@ from paddle.nn.layer.common import Dropout, Linear
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.transformer import _convert_attention_mask
random.seed(42)
default_main_program().random_seed = 42
seed = 42
random.seed(seed)
default_main_program().random_seed = seed
np.random.seed(seed)
paddle.seed(seed)
class TestFusedMultiTransformerOp(OpTest):
......