Unverified commit 428fb804, authored by S sneaxiy, committed by GitHub

Save fused_attention op memory when dropout_rate = 0.0 (#48902)

* save fused_attention memory when dropout_rate = 0.0

* add unit test

* fix unit test bug

* fix fused_layernorm_residual_dropout_bias_test.cu
Parent 889e5834
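Note on the overall approach: most of the diffs below apply the same guard. When attn_dropout_rate (or the residual dropout_rate) is 0.0, the dropout mask and dropout output tensors are never allocated, nullptr is threaded through in their place, and the kernels are dispatched to a no-dropout specialization. A minimal stand-alone sketch of the allocation guard, using hypothetical names rather than Paddle's real API:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical workspace: the attention-dropout mask only exists when
// dropout is actually applied (dropout rate != 0).
struct AttnWorkspace {
  std::vector<uint8_t> attn_dropout_mask;
};

uint8_t* MaybeAllocAttnDropoutMask(AttnWorkspace* ws, std::size_t numel,
                                   float attn_dropout_rate) {
  const bool has_attn_dropout = (attn_dropout_rate != 0.0f);
  if (!has_attn_dropout) {
    return nullptr;  // callers treat nullptr as "dropout disabled"
  }
  ws->attn_dropout_mask.resize(numel);
  return ws->attn_dropout_mask.data();
}

The saving is noticeable because the attention-dropout mask and output have shape [batch, num_heads, seq_len, seq_len]; for example, with batch 32, 16 heads and seq_len 512, the uint8 mask alone is 32 * 16 * 512 * 512 B = 128 MiB.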
......@@ -572,15 +572,17 @@ fused_attention_dygraph_function(
egr::EagerUtils::CheckAndRetainGrad(SoftmaxOut);
grad_node->SetGradOutMeta(SoftmaxOut, 19);
auto AttnDropoutOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(
p_autograd_AttnDropoutOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutOut, 0);
egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut,
AttnDropoutOut_accumulation_node);
AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0);
egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut);
grad_node->SetGradOutMeta(AttnDropoutOut, 20);
if (AttnDropoutOut.initialized()) {
auto AttnDropoutOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(
p_autograd_AttnDropoutOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutOut, 0);
egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut,
AttnDropoutOut_accumulation_node);
AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0);
egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut);
grad_node->SetGradOutMeta(AttnDropoutOut, 20);
}
auto FMHAOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_FMHAOut);
......
......@@ -476,7 +476,7 @@ class fused_attentionGradNodeCompat : public egr::GradNodeBase {
SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false);
}
void SetTensorWrapperSrcMask(const paddle::experimental::Tensor& SrcMask) {
SrcMask_ = egr::TensorWrapper(SrcMask, false);
SrcMask_ = egr::TensorWrapper(SrcMask, true);
}
void SetTensorWrapperSrcMaskOut(
const paddle::experimental::Tensor& SrcMaskOut) {
......
......@@ -102,7 +102,6 @@ class FMHARef {
T* qk_out_data = qk_out_tensor->data<T>();
T* qktv_out_data = qktv_out_tensor->data<T>();
T* softmax_out_data = softmax_out_tensor->data<T>();
T* dropout_out_data = dropout_out_tensor->data<T>();
T* fmha_out_data = fmha_out_tensor->data<T>();
auto out_seq_len = seq_len_;
......@@ -219,6 +218,7 @@ class FMHARef {
dropout_mask_out_tensor,
dropout_out_tensor,
false);
T* dropout_out_data = dropout_out_tensor->data<T>();
blas.BatchedGEMM(transA,
transB,
gemm_m,
......@@ -462,8 +462,6 @@ class FMHARef {
const T* softmax_out_data = softmax_out_tensor.data<T>();
T* softmax_out_grad_data = softmax_out_grad_tensor->data<T>();
const T* dropout_out_data = dropout_out_tensor.data<T>();
T* dropout_out_grad_data = dropout_out_grad_tensor->data<T>();
T* qktv_out_grad_data = qktv_out_grad_tensor->data<T>();
// transpose bw
......@@ -485,6 +483,7 @@ class FMHARef {
int64_t stride_b = gemm_k * gemm_n;
// bw: dy = x^t * dout
if (dropout_param_.dropout_prob_) {
const T* dropout_out_data = dropout_out_tensor.data<T>();
blas.BatchedGEMM(transA,
transB,
gemm_m,
......@@ -522,6 +521,7 @@ class FMHARef {
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
if (dropout_param_.dropout_prob_) {
T* dropout_out_grad_data = dropout_out_grad_tensor->data<T>();
blas.BatchedGEMM(transA,
transB,
gemm_m,
......
......@@ -545,8 +545,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("QKOut"));
ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
ctx->GetInputDim("SoftmaxOut"));
ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
ctx->GetInputDim("AttnDropoutOut"));
if (ctx->HasOutput(framework::GradVarName("AttnDropoutOut"))) {
ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
ctx->GetInputDim("AttnDropoutOut"));
}
if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) {
ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
......@@ -707,7 +709,8 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(FusedAttentionGradNoNeedBufferInferer,
"QKVOut",
"QKOut",
"QKTVOut",
"OutLinearOut");
"OutLinearOut",
"SrcMask");
} // namespace operators
} // namespace paddle
......
......@@ -121,6 +121,10 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
const float ln_epsilon = ctx.Attr<float>("ln_epsilon");
float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
const bool has_attn_dropout = (attn_dropout_rate != 0.0f);
DropoutParam dropout_param2(ctx, 0);
const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
bool is_test_1 = ctx.Attr<bool>("is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
......@@ -169,11 +173,16 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
src_mask_out->numel() * sizeof(T));
auto *softmax_out_data = dev_ctx.template Alloc<T>(
softmax_out, softmax_out->numel() * sizeof(T));
auto *attn_dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
attn_dropout_mask_out,
attn_dropout_mask_out->numel() * sizeof(uint8_t));
auto *attn_dropout_out_data = dev_ctx.template Alloc<T>(
attn_dropout_out, attn_dropout_out->numel() * sizeof(T));
auto *attn_dropout_mask_out_data =
has_attn_dropout ? dev_ctx.template Alloc<uint8_t>(
attn_dropout_mask_out,
attn_dropout_mask_out->numel() * sizeof(uint8_t))
: nullptr;
auto *attn_dropout_out_data =
has_attn_dropout
? dev_ctx.template Alloc<T>(attn_dropout_out,
attn_dropout_out->numel() * sizeof(T))
: nullptr;
auto *fmha_out_data =
dev_ctx.template Alloc<T>(fmha_out, fmha_out->numel() * sizeof(T));
......@@ -185,8 +194,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
out_linear_out, out_linear_out->numel() * sizeof(T));
// get data ptr for bias+dropout+residual+layernorm
auto *dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t));
auto *dropout_mask_out_data =
has_dropout
? dev_ctx.template Alloc<uint8_t>(
dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t))
: nullptr;
auto *final_out_data =
dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
......@@ -246,7 +258,6 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
input_size,
output_size,
false);
DropoutParam dropout_param2(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(),
bsz_seq,
......@@ -367,7 +378,11 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
const float epsilon = ctx.Attr<float>("epsilon");
const float ln2epsilon = ctx.Attr<float>("ln_epsilon");
float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
const float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
const bool has_attn_dropout = (attn_dropout_prob != 0.0f);
DropoutParam dropout_param2(ctx, 0);
const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
bool is_test_1 = ctx.Attr<bool>("is_test");
auto &dropout_implementation_1 =
......@@ -398,7 +413,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
auto *out_linear_weight_data = out_linear_weight->data<T>();
......@@ -424,7 +438,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *softmax_out_data = softmax_out->data<T>();
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
auto *dropout_mask_out_data =
has_dropout ? dropout_mask_out->data<uint8_t>() : nullptr;
// output's grad
auto *d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
......@@ -470,8 +485,11 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
dev_ctx.template Alloc<T>(d_qk_out, d_qk_out->numel() * sizeof(T));
auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
d_softmax_out, d_softmax_out->numel() * sizeof(T));
auto *d_attn_dropout_out_data = dev_ctx.template Alloc<T>(
d_attn_dropout_out, d_attn_dropout_out->numel() * sizeof(T));
auto *d_attn_dropout_out_data =
has_attn_dropout
? dev_ctx.template Alloc<T>(d_attn_dropout_out,
d_attn_dropout_out->numel() * sizeof(T))
: nullptr;
auto *d_src_mask_out_data =
(src_mask == nullptr)
? nullptr
......@@ -571,7 +589,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
input_size,
output_size,
compute_bias);
DropoutParam dropout_param2(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(),
bsz_seq,
......@@ -631,7 +648,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
if (qkv_bias != nullptr) {
fmha_ref_compute.ComputeBackward(*transpose_out_2,
src_mask,
has_attn_dropout ? src_mask : nullptr,
*softmax_out,
*attn_dropout_mask_out,
*attn_dropout_out,
......@@ -648,7 +665,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_qkv_bias_out);
} else {
fmha_ref_compute.ComputeBackward(*transpose_out_2,
src_mask,
has_attn_dropout ? src_mask : nullptr,
*softmax_out,
*attn_dropout_mask_out,
*attn_dropout_out,
......
......@@ -290,7 +290,7 @@ struct TestFusedLayernormResidualDropoutBias {
framework::TensorToVector(layernorm_out, *ctx, &_layernorm_out);
framework::TensorToVector(means, *ctx, &_means);
framework::TensorToVector(vars, *ctx, &_vars);
if (!is_test) {
if (!is_test && dropout_prob != 0.0f) {
framework::TensorToVector(mask, *ctx, &_mask);
}
ctx->Wait();
......@@ -298,7 +298,9 @@ struct TestFusedLayernormResidualDropoutBias {
for (int i = 0; i < n; i++) {
EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff);
EXPECT_LT(std::abs(_layernorm_out[i] - correct_layernorm_out[i]), diff);
if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]);
if (!is_test && dropout_prob != 0.0f) {
EXPECT_EQ(_mask[i], correct_mask[i]);
}
}
for (int i = 0; i < rows; i++) {
EXPECT_LT(std::abs(_means[i] - correct_means[i]), static_cast<U>(diff));
......
......@@ -30,7 +30,8 @@ template <typename T,
bool Activation,
typename Functor,
typename InType = T,
typename OutType = T>
typename OutType = T,
bool HasDropout = true>
__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
const int row_id,
const int col_id,
......@@ -84,7 +85,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
}
MaskStoreT mask_vec;
if (!is_test) {
if (!is_test && HasDropout) {
float rand[VecSize];
RandVec<VecSize>(state, rand);
#pragma unroll
......@@ -114,8 +115,12 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
if (Activation) {
tmp = act_func(tmp);
}
dest_vec[ii] =
tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
if (HasDropout) {
dest_vec[ii] =
tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
} else {
dest_vec[ii] = tmp * factor + residual_vec[ii];
}
if (ComputeLayerNorm) {
U tmp = static_cast<U>(dest_vec[ii]);
*mean_val += tmp;
......@@ -138,7 +143,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
phi::Store<T, VecSize>(dest_vec,
reinterpret_cast<T *>(&dst[row_id * cols + col_id]));
}
if (!is_test) {
if (!is_test && HasDropout) {
phi::Store<MaskType, VecSize>(mask_vec, &mask[row_id * cols + col_id]);
}
}
......@@ -154,7 +159,8 @@ template <typename T,
typename MaskType,
int VecSize,
typename InType = T,
typename OutType = T>
typename OutType = T,
bool HasDropout = true>
__global__ void FusedResidualDropoutBias(
const size_t rows,
const size_t cols,
......@@ -175,8 +181,15 @@ __global__ void FusedResidualDropoutBias(
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
if (HasDropout) {
curand_init(seed, idx, increment, &state);
}
T factor;
if (HasDropout) {
factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
} else {
factor = static_cast<T>(1);
}
phi::funcs::ReluFunctor<T> relu;
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
......@@ -188,24 +201,25 @@ __global__ void FusedResidualDropoutBias(
false,
phi::funcs::ReluFunctor<T>,
InType,
OutType>(r,
i,
cols,
&state,
dropout_prob,
factor,
src,
residual,
bias,
dst,
mask,
is_test,
nullptr,
nullptr,
relu,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
OutType,
HasDropout>(r,
i,
cols,
&state,
dropout_prob,
factor,
src,
residual,
bias,
dst,
mask,
is_test,
nullptr,
nullptr,
relu,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
}
}
}
......@@ -256,43 +270,64 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
FusedResidualDropoutBias<T, uint8_t, VecSize, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
rows,
cols,
seed,
dropout_prob,
is_upscale_in_train,
src,
residual,
bias,
mask_data,
dst,
increment,
is_test,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
#define PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(__has_dropout) \
do { \
if (cols % VecSize == 0) { \
FusedResidualDropoutBias<T, \
uint8_t, \
VecSize, \
InType, \
OutType, \
__has_dropout> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
ctx.stream()>>>(rows, \
cols, \
seed, \
dropout_prob, \
is_upscale_in_train, \
src, \
residual, \
bias, \
mask_data, \
dst, \
increment, \
is_test, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_next_in_scale); \
} else { \
FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType, __has_dropout> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
ctx.stream()>>>(rows, \
cols, \
seed, \
dropout_prob, \
is_upscale_in_train, \
src, \
residual, \
bias, \
mask_data, \
dst, \
increment, \
is_test, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_next_in_scale); \
} \
} while (0)
if (dropout_prob != 0.0f) {
PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(true);
} else {
FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
rows,
cols,
seed,
dropout_prob,
is_upscale_in_train,
src,
residual,
bias,
mask_data,
dst,
increment,
is_test,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(false);
}
#undef PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL
}
/*
......@@ -334,7 +369,8 @@ template <typename T,
typename MaskType,
int BlockSizeX,
int BlockSizeY,
int VecSize>
int VecSize,
bool HasDropout>
__global__ void FusedResidualDropoutBiasGrad(const T *dout,
const MaskType *mask,
const T factor,
......@@ -350,6 +386,9 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout,
T tmp_sum[VecSize] = {static_cast<T>(0)};
// calculate the dx and temporary sum
const bool not_need_dx = (dx == nullptr) || (dx == dout && !HasDropout &&
factor == static_cast<T>(1.0));
if (col_id * VecSize < cols) {
for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
int index = row_id * cols + col_id * VecSize;
......@@ -357,15 +396,27 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout,
MaskLoadT mask_vec;
StoreT dx_vec;
phi::Load<T, VecSize>(&dout[index], &out_vec);
phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
if (HasDropout) {
phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
}
if (not_need_dx) {
#pragma unroll
for (int i = 0; i < VecSize; i++) {
tmp_sum[i] += out_vec[i];
}
} else {
#pragma unroll
for (int i = 0; i < VecSize; i++) {
dx_vec[i] = out_vec[i] * static_cast<T>(mask_vec[i]) * factor;
tmp_sum[i] += out_vec[i];
for (int i = 0; i < VecSize; i++) {
if (HasDropout) {
dx_vec[i] = out_vec[i] * static_cast<T>(mask_vec[i]) * factor;
} else {
dx_vec[i] = out_vec[i] * factor;
}
tmp_sum[i] += out_vec[i];
}
phi::Store<T, VecSize>(dx_vec, &dx[index]);
}
phi::Store<T, VecSize>(dx_vec, &dx[index]);
}
}
......@@ -395,35 +446,68 @@ void LaunchResidualDropoutBiasGrad(const T *dout,
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
if (dbias != nullptr) {
const auto threads = 8;
auto blocks = std::max(static_cast<uint32_t>(1),
(cols / real_vec_size + threads - 1) / threads);
dim3 block_dim(threads, 128, 1);
dim3 grid_dim(blocks, 1, 1);
if (cols % VecSize == 0) {
FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, VecSize>
<<<grid_dim, block_dim, 0, ctx.stream()>>>(
dout, mask, factor, rows, cols, dx, dbias);
} else {
FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, 1>
<<<grid_dim, block_dim, 0, ctx.stream()>>>(
dout, mask, factor, rows, cols, dx, dbias);
}
#define PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(__has_dropout) \
do { \
if (dbias != nullptr) { \
const auto threads = 8; \
auto blocks = std::max(static_cast<uint32_t>(1), \
(cols / real_vec_size + threads - 1) / threads); \
dim3 block_dim(threads, 128, 1); \
dim3 grid_dim(blocks, 1, 1); \
if (cols % VecSize == 0) { \
FusedResidualDropoutBiasGrad<T, \
MaskType, \
8, \
128, \
VecSize, \
__has_dropout> \
<<<grid_dim, block_dim, 0, ctx.stream()>>>( \
dout, mask, factor, rows, cols, dx, dbias); \
} else { \
FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, 1, __has_dropout> \
<<<grid_dim, block_dim, 0, ctx.stream()>>>( \
dout, mask, factor, rows, cols, dx, dbias); \
} \
} else { \
if (dropout_prob == 0.0f) { \
if (dx == nullptr || dx == dout) { \
return; \
} \
memory::Copy(ctx.GetPlace(), \
dx, \
ctx.GetPlace(), \
dout, \
rows *cols * sizeof(T), \
ctx.stream()); \
} else { \
const uint64_t n = rows * cols; \
platform::GpuLaunchConfig config = \
platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size); \
if (n % VecSize == 0) { \
FusedResidualDropoutGrad<T, MaskType, VecSize> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
ctx.stream()>>>(dout, mask, factor, n, dx); \
} else { \
FusedResidualDropoutGrad<T, MaskType, 1> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
ctx.stream()>>>(dout, mask, factor, n, dx); \
} \
} \
} \
} while (0)
if (dropout_prob != 0.0f) {
PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(true);
} else {
const uint64_t n = rows * cols;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
if (n % VecSize == 0) {
FusedResidualDropoutGrad<T, MaskType, VecSize>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
dout, mask, factor, n, dx);
} else {
FusedResidualDropoutGrad<T, MaskType, 1>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
dout, mask, factor, n, dx);
}
PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(false);
}
#undef PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL
}
} // namespace operators
......
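Note on the HasDropout template parameter introduced above: making the dropout flag a compile-time bool lets the no-dropout instantiation skip curand initialization, the mask load/store, and the per-element mask scaling, instead of branching on dropout_prob inside the kernel. A minimal sketch of that dispatch, with hypothetical kernel and launcher names (not the real FusedResidualDropoutBias):

#include <cuda_runtime.h>
#include <cstdint>

template <typename T, bool HasDropout>
__global__ void ResidualDropoutSketch(const T* src, const T* residual,
                                      const uint8_t* mask, T factor,
                                      T* dst, size_t n) {
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i >= n) return;
  T v = src[i];
  if (HasDropout) {
    // Only this instantiation ever reads the mask buffer.
    v = v * static_cast<T>(mask[i]) * factor;
  }
  dst[i] = v + residual[i];
}

template <typename T>
void LaunchResidualDropoutSketch(const T* src, const T* residual,
                                 const uint8_t* mask, float dropout_prob,
                                 T factor, T* dst, size_t n,
                                 cudaStream_t stream) {
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  if (dropout_prob != 0.0f) {
    ResidualDropoutSketch<T, true><<<blocks, threads, 0, stream>>>(
        src, residual, mask, factor, dst, n);
  } else {
    // mask may be nullptr here; the HasDropout = false body never touches it.
    ResidualDropoutSketch<T, false><<<blocks, threads, 0, stream>>>(
        src, residual, nullptr, factor, dst, n);
  }
}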
......@@ -258,14 +258,14 @@ struct FusedResidualDropoutBiasTester {
std::vector<T> fused_out(n);
std::vector<uint8_t> fused_mask(n);
framework::TensorToVector(out, *ctx, &fused_out);
if (!is_test) {
if (!is_test && dropout_prob != 0.0f) {
framework::TensorToVector<uint8_t>(mask, *ctx, &fused_mask);
}
ctx->Wait();
for (int i = 0; i < n; i++) {
EXPECT_LT(std::abs(fused_out[i] - correct_out[i]), diff);
if (!is_test) {
if (!is_test && dropout_prob != 0.0f) {
EXPECT_EQ(fused_mask[i], correct_mask[i]);
}
}
......
......@@ -501,7 +501,8 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
}
#ifdef PADDLE_WITH_CUDA
template <bool isFusedDropoutResidualLn,
template <bool IsFusedDropoutResidualLn,
bool NeedDDropoutSrcPtr,
typename T,
typename U,
typename ScaleT = U,
......@@ -531,6 +532,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0),
T *d_dropout_src_ptr = nullptr) {
static_assert(
!IsFusedDropoutResidualLn || NeedDDropoutSrcPtr,
"When IsFusedDropoutResidualLn = true, NeedDDropoutSrcPtr must be true.");
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
......@@ -585,7 +590,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
phi::Load<T, VecSize>(dout_ptr + row * ELTS_PER_ROW + col * VecSize,
&dout[it]);
phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
if (isFusedDropoutResidualLn) {
if (IsFusedDropoutResidualLn) {
phi::Load<MaskType, VecSize>(
mask_ptr + row * ELTS_PER_ROW + col * VecSize, &mask_vec[it]);
}
......@@ -671,7 +676,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
U dx_tmp = var_cur_row * (dy_tmp - sum_loss2 * y_tmp - sum_loss1);
// Note: reuse x and dout vec register to store dx and d_dropout_src.
x[it][jt] = static_cast<T>(dx_tmp);
if (isFusedDropoutResidualLn) {
if (IsFusedDropoutResidualLn) {
dout[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor;
}
}
......@@ -683,9 +688,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
for (int it = 0; it < LDGS; it++) {
phi::Store<T, VecSize>(x[it],
dx_ptr + row * ELTS_PER_ROW + col * VecSize);
if (isFusedDropoutResidualLn) {
if (IsFusedDropoutResidualLn) {
phi::Store<T, VecSize>(
dout[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize);
} else if (NeedDDropoutSrcPtr) {
phi::Store<T, VecSize>(
x[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize);
}
col += THREADS_PER_ROW;
}
......@@ -955,6 +963,7 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
}
#define LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
fused_ln_bwd_fast_kernel<true, \
true, \
T, \
U, \
ScaleT, \
......@@ -993,8 +1002,10 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
#undef LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
} else {
#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL_BASE( \
vec_size, ele_per_row, need_d_dropout_src_ptr) \
fused_ln_bwd_fast_kernel<false, \
need_d_dropout_src_ptr, \
T, \
U, \
ScaleT, \
......@@ -1013,7 +1024,19 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
dout_ptr, \
dscale_temp_ptr, \
dbias_temp_ptr, \
dx_ptr);
dx_ptr, \
nullptr, \
factor, \
d_dropout_src_ptr);
#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
do { \
if (d_dropout_src_ptr != nullptr) { \
LAUNCH_FUSED_LN_BWD_FAST_KERNEL_BASE(vec_size, ele_per_row, true); \
} else { \
LAUNCH_FUSED_LN_BWD_FAST_KERNEL_BASE(vec_size, ele_per_row, false); \
} \
} while (0)
if (cols == 1024) {
LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
......
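Note on the static_assert added to fused_ln_bwd_fast_kernel above: it encodes an implication between the two template flags, since fusing dropout + residual + layernorm only makes sense when the kernel is also asked to produce d_dropout_src. Writing "A implies B" as "!A || B" rejects the invalid combination at compile time. A tiny self-contained illustration of the same idiom (hypothetical struct, not the kernel itself):

template <bool IsFusedDropoutResidualLn, bool NeedDDropoutSrcPtr>
struct LnBwdConfigSketch {
  static_assert(!IsFusedDropoutResidualLn || NeedDDropoutSrcPtr,
                "IsFusedDropoutResidualLn = true requires "
                "NeedDDropoutSrcPtr = true");
};

// LnBwdConfigSketch<true, true>, <false, true> and <false, false> compile;
// LnBwdConfigSketch<true, false> fails with the message above.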
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.incubate.nn import FusedMultiHeadAttention
def random_init_model(model, seed):
paddle.seed(seed)
for p in model.parameters():
shape = p.shape
dtype = p.dtype
value = paddle.randn(shape=shape, dtype=dtype)
p.set_value(value.numpy())
class FusedAttentionTestLayer(FusedMultiHeadAttention):
def __init__(self, embed_dim, num_heads, normalize_before=False):
super().__init__(
embed_dim=embed_dim,
num_heads=num_heads,
attn_dropout_rate=0.0,
dropout_rate=0.0,
normalize_before=normalize_before,
)
def _reshape_and_transpose(self, x):
assert len(x.shape) == 3
bs, seq_len = x.shape[:2]
x = x.reshape([bs, seq_len, self.num_heads, self.head_dim])
x = x.transpose([0, 2, 1, 3])
return x
def _transpose_and_reshape(self, x):
assert len(x.shape) == 4
x = x.transpose([0, 2, 1, 3])
bs = x.shape[0]
x = x.reshape([bs, -1, self.embed_dim])
return x
def forward(self, x, attn_mask, use_ref=False):
if use_ref:
return self.ref_forward(x, attn_mask)
else:
return super().forward(x, attn_mask)
def ref_forward(self, x, attn_mask):
residual = x
if self.normalize_before:
assert len(self.pre_ln_scale.shape) == 1
out = F.layer_norm(
x,
self.pre_ln_scale.shape,
weight=self.pre_ln_scale,
bias=self.pre_ln_bias,
epsilon=self._epsilon,
)
else:
out = x
qkv_weight = self.qkv_weight.reshape(
[3 * self.embed_dim, self.embed_dim]
)
qkv_bias = self.qkv_bias.reshape([3 * self.embed_dim])
out = paddle.matmul(out, qkv_weight, transpose_y=True) + qkv_bias
# [BS, seq_len, head_dim]
# [BS, seq_len, head_dim * 3]
q, k, v = paddle.split(out, 3, axis=-1)
q = self._reshape_and_transpose(q)
k = self._reshape_and_transpose(k)
v = self._reshape_and_transpose(v)
q *= self.head_dim**-0.5
out = paddle.matmul(q, k, transpose_y=True)
if attn_mask is not None:
out += attn_mask
out = F.softmax(out)
out = paddle.matmul(out, v)
out = self._transpose_and_reshape(out)
out = F.linear(out, weight=self.linear_weight, bias=self.linear_bias)
add_residual = True
if add_residual:
out = residual + out
if not self.normalize_before:
assert len(self.ln_scale.shape) == 1
out = F.layer_norm(
out,
self.ln_scale.shape,
weight=self.ln_scale,
bias=self.ln_bias,
epsilon=self._epsilon,
)
return out
class TestFusedAttention(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.num_heads = 16
self.max_seq_len = 128
self.hidden_size = 256
self.dtype = "float32"
self.normalize_before = False
self.seed = 10
self.use_mask = False
self.set_configs()
def set_configs(self):
pass
def generate_inputs(self):
np.random.seed(self.seed)
hidden_state = np.random.random(
size=[self.batch_size, self.max_seq_len, self.hidden_size]
).astype(self.dtype)
hidden_state = paddle.to_tensor(hidden_state)
hidden_state.stop_gradient = False
if self.use_mask:
seq_lens = np.random.randint(
low=int(self.max_seq_len / 3),
high=self.max_seq_len,
size=[self.batch_size],
)
mask = np.zeros(
shape=[self.batch_size, self.max_seq_len], dtype=self.dtype
)
for i in range(self.batch_size):
mask[i][0 : seq_lens[i]] = 1
mask = mask.reshape([self.batch_size, 1, 1, self.max_seq_len])
broadcast_shape = [
self.batch_size,
self.num_heads,
self.max_seq_len,
self.max_seq_len,
]
mask = np.broadcast_to(mask, broadcast_shape)
mask = (1 - mask) * -1e9
return hidden_state, paddle.to_tensor(mask.astype(self.dtype))
else:
return hidden_state, None
def run_fwd_bwd(self, use_ref=False):
x, mask = self.generate_inputs()
layer = FusedAttentionTestLayer(
self.hidden_size,
self.num_heads,
normalize_before=self.normalize_before,
)
random_init_model(layer, self.seed + 100)
out = layer(x, mask, use_ref)
loss = out.mean()
loss.backward()
vars_need_gradients = [('out', x)] + list(layer.named_parameters())
numpy_values = [out.numpy()]
for i, (name, var) in enumerate(vars_need_gradients):
tmp = var.grad.numpy()
numpy_values.append(tmp)
return numpy_values
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
values1 = self.run_fwd_bwd(True)
paddle.device.cuda.synchronize()
values2 = self.run_fwd_bwd(False)
paddle.device.cuda.synchronize()
self.assertEqual(len(values1), len(values2))
for i, (v1, v2) in enumerate(zip(values1, values2)):
if not self.normalize_before:
np.testing.assert_allclose(v1, v2, atol=1e-6, rtol=1e-5)
else:
np.testing.assert_equal(v1, v2)
class TestFusedAttentionNormalizeBefore(TestFusedAttention):
def set_configs(self):
self.normalize_before = True
if __name__ == "__main__":
unittest.main()