Unverified commit 428fb804, authored by: S sneaxiy, committed by: GitHub

Save fused_attention op memory when dropout_rate = 0.0 (#48902)

* save fused_attention memory when dropout_rate = 0.0

* add ut

* fix ut bug

* fix fused_layernorm_residual_dropout_bias_test.cu
Parent 889e5834
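In short, the commit exploits the fact that dropout with probability 0 is an identity: the dropout mask and dropout-output buffers are only allocated when the probability is non-zero, and the CUDA kernels gain a `bool HasDropout` template parameter so the random-number path compiles away. A minimal standalone C++ sketch of that pattern (hypothetical helper names, not the Paddle implementation) is:

```cpp
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

// Sketch only: compile the dropout path in or out with a bool template
// parameter, and write the mask only when it is actually needed.
template <bool HasDropout>
void ResidualDropout(const std::vector<float>& src,
                     const std::vector<float>& residual,
                     float dropout_prob,
                     std::vector<float>* dst,
                     std::vector<uint8_t>* mask) {
  std::mt19937 rng(42);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  // With dropout_prob == 0 the scale factor is exactly 1 (identity).
  const float factor = HasDropout ? 1.0f / (1.0f - dropout_prob) : 1.0f;
  for (size_t i = 0; i < src.size(); ++i) {
    float keep = 1.0f;
    if (HasDropout) {
      keep = uniform(rng) >= dropout_prob ? 1.0f : 0.0f;
      (*mask)[i] = static_cast<uint8_t>(keep);  // mask written only here
    }
    (*dst)[i] = src[i] * keep * factor + residual[i];
  }
}

int main() {
  const size_t n = 8;
  std::vector<float> src(n, 1.0f), residual(n, 0.5f), dst(n);
  std::vector<uint8_t> mask;  // stays empty when dropout_prob == 0

  const float dropout_prob = 0.0f;
  if (dropout_prob != 0.0f) {
    mask.resize(n);  // this per-element buffer is the memory being saved
    ResidualDropout<true>(src, residual, dropout_prob, &dst, &mask);
  } else {
    ResidualDropout<false>(src, residual, dropout_prob, &dst, &mask);
  }
  std::printf("dst[0] = %f, mask bytes = %zu\n", dst[0], mask.size());
  return 0;
}
```

The diff below applies the same switch at several levels: the op's InferShape and kernels (`has_attn_dropout`, `has_dropout`) skip allocating the dropout output and mask tensors, the fused residual-dropout-bias kernels take a `HasDropout` template argument, and the eager-mode bookkeeping only tracks gradients for `AttnDropoutOut` when it was actually produced.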
...
@@ -572,15 +572,17 @@ fused_attention_dygraph_function(
       egr::EagerUtils::CheckAndRetainGrad(SoftmaxOut);
       grad_node->SetGradOutMeta(SoftmaxOut, 19);
-      auto AttnDropoutOut_accumulation_node =
-          std::make_shared<egr::GradNodeAccumulation>(
-              p_autograd_AttnDropoutOut);
-      egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutOut, 0);
-      egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut,
-                                  AttnDropoutOut_accumulation_node);
-      AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0);
-      egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut);
-      grad_node->SetGradOutMeta(AttnDropoutOut, 20);
+      if (AttnDropoutOut.initialized()) {
+        auto AttnDropoutOut_accumulation_node =
+            std::make_shared<egr::GradNodeAccumulation>(
+                p_autograd_AttnDropoutOut);
+        egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutOut, 0);
+        egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut,
+                                    AttnDropoutOut_accumulation_node);
+        AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0);
+        egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut);
+        grad_node->SetGradOutMeta(AttnDropoutOut, 20);
+      }
       auto FMHAOut_accumulation_node =
           std::make_shared<egr::GradNodeAccumulation>(p_autograd_FMHAOut);
...
...
@@ -476,7 +476,7 @@ class fused_attentionGradNodeCompat : public egr::GradNodeBase {
     SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false);
   }
   void SetTensorWrapperSrcMask(const paddle::experimental::Tensor& SrcMask) {
-    SrcMask_ = egr::TensorWrapper(SrcMask, false);
+    SrcMask_ = egr::TensorWrapper(SrcMask, true);
   }
   void SetTensorWrapperSrcMaskOut(
       const paddle::experimental::Tensor& SrcMaskOut) {
...
...
@@ -102,7 +102,6 @@ class FMHARef {
     T* qk_out_data = qk_out_tensor->data<T>();
     T* qktv_out_data = qktv_out_tensor->data<T>();
     T* softmax_out_data = softmax_out_tensor->data<T>();
-    T* dropout_out_data = dropout_out_tensor->data<T>();
     T* fmha_out_data = fmha_out_tensor->data<T>();
     auto out_seq_len = seq_len_;
...
@@ -219,6 +218,7 @@ class FMHARef {
           dropout_mask_out_tensor,
           dropout_out_tensor,
           false);
+      T* dropout_out_data = dropout_out_tensor->data<T>();
       blas.BatchedGEMM(transA,
                        transB,
                        gemm_m,
...
@@ -462,8 +462,6 @@ class FMHARef {
     const T* softmax_out_data = softmax_out_tensor.data<T>();
     T* softmax_out_grad_data = softmax_out_grad_tensor->data<T>();
-    const T* dropout_out_data = dropout_out_tensor.data<T>();
-    T* dropout_out_grad_data = dropout_out_grad_tensor->data<T>();
     T* qktv_out_grad_data = qktv_out_grad_tensor->data<T>();
     // transpose bw
...
@@ -485,6 +483,7 @@ class FMHARef {
     int64_t stride_b = gemm_k * gemm_n;
     // bw: dy = x^t * dout
     if (dropout_param_.dropout_prob_) {
+      const T* dropout_out_data = dropout_out_tensor.data<T>();
       blas.BatchedGEMM(transA,
                        transB,
                        gemm_m,
...
@@ -522,6 +521,7 @@ class FMHARef {
     stride_a = gemm_m * gemm_k;
     stride_b = gemm_k * gemm_n;
     if (dropout_param_.dropout_prob_) {
+      T* dropout_out_grad_data = dropout_out_grad_tensor->data<T>();
       blas.BatchedGEMM(transA,
                        transB,
                        gemm_m,
...
...
@@ -545,8 +545,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
                       ctx->GetInputDim("QKOut"));
     ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
                       ctx->GetInputDim("SoftmaxOut"));
-    ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
-                      ctx->GetInputDim("AttnDropoutOut"));
+    if (ctx->HasOutput(framework::GradVarName("AttnDropoutOut"))) {
+      ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
+                        ctx->GetInputDim("AttnDropoutOut"));
+    }
     if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) {
       ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
...
@@ -707,7 +709,8 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(FusedAttentionGradNoNeedBufferInferer,
                                     "QKVOut",
                                     "QKOut",
                                     "QKTVOut",
-                                    "OutLinearOut");
+                                    "OutLinearOut",
+                                    "SrcMask");
 }  // namespace operators
 }  // namespace paddle
...
...
@@ -121,6 +121,10 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     const float ln_epsilon = ctx.Attr<float>("ln_epsilon");
     float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
+    const bool has_attn_dropout = (attn_dropout_rate != 0.0f);
+    DropoutParam dropout_param2(ctx, 0);
+    const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
+
     bool is_test_1 = ctx.Attr<bool>("is_test");
     auto &dropout_implementation_1 =
         ctx.Attr<std::string>("attn_dropout_implementation");
...
@@ -169,11 +173,16 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
                                src_mask_out->numel() * sizeof(T));
     auto *softmax_out_data = dev_ctx.template Alloc<T>(
         softmax_out, softmax_out->numel() * sizeof(T));
-    auto *attn_dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
-        attn_dropout_mask_out,
-        attn_dropout_mask_out->numel() * sizeof(uint8_t));
-    auto *attn_dropout_out_data = dev_ctx.template Alloc<T>(
-        attn_dropout_out, attn_dropout_out->numel() * sizeof(T));
+    auto *attn_dropout_mask_out_data =
+        has_attn_dropout ? dev_ctx.template Alloc<uint8_t>(
+                               attn_dropout_mask_out,
+                               attn_dropout_mask_out->numel() * sizeof(uint8_t))
+                         : nullptr;
+    auto *attn_dropout_out_data =
+        has_attn_dropout
+            ? dev_ctx.template Alloc<T>(attn_dropout_out,
+                                        attn_dropout_out->numel() * sizeof(T))
+            : nullptr;
     auto *fmha_out_data =
         dev_ctx.template Alloc<T>(fmha_out, fmha_out->numel() * sizeof(T));
...
@@ -185,8 +194,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
         out_linear_out, out_linear_out->numel() * sizeof(T));
     // get data ptr for bias+dropout+residual+layernorm
-    auto *dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
-        dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t));
+    auto *dropout_mask_out_data =
+        has_dropout
+            ? dev_ctx.template Alloc<uint8_t>(
+                  dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t))
+            : nullptr;
     auto *final_out_data =
         dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
...
@@ -246,7 +258,6 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
         input_size,
         output_size,
         false);
-    DropoutParam dropout_param2(ctx, 0);
     FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
         ctx.cuda_device_context(),
         bsz_seq,
...
@@ -367,7 +378,11 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     const float epsilon = ctx.Attr<float>("epsilon");
     const float ln2epsilon = ctx.Attr<float>("ln_epsilon");
-    float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
+    const float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
+    const bool has_attn_dropout = (attn_dropout_prob != 0.0f);
+    DropoutParam dropout_param2(ctx, 0);
+    const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
+
     auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     bool is_test_1 = ctx.Attr<bool>("is_test");
     auto &dropout_implementation_1 =
...
@@ -398,7 +413,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
     auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
     auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
-    auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
     auto *qkv_weight_data = qkv_weight->data<T>();
     auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
     auto *out_linear_weight_data = out_linear_weight->data<T>();
...
@@ -424,7 +438,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *softmax_out_data = softmax_out->data<T>();
     auto *src_mask_out_data =
         (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
-    auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
+    auto *dropout_mask_out_data =
+        has_dropout ? dropout_mask_out->data<uint8_t>() : nullptr;
     // output's grad
     auto *d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
...
@@ -470,8 +485,11 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         dev_ctx.template Alloc<T>(d_qk_out, d_qk_out->numel() * sizeof(T));
     auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
         d_softmax_out, d_softmax_out->numel() * sizeof(T));
-    auto *d_attn_dropout_out_data = dev_ctx.template Alloc<T>(
-        d_attn_dropout_out, d_attn_dropout_out->numel() * sizeof(T));
+    auto *d_attn_dropout_out_data =
+        has_attn_dropout
+            ? dev_ctx.template Alloc<T>(d_attn_dropout_out,
+                                        d_attn_dropout_out->numel() * sizeof(T))
+            : nullptr;
     auto *d_src_mask_out_data =
         (src_mask == nullptr)
             ? nullptr
...
@@ -571,7 +589,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         input_size,
         output_size,
         compute_bias);
-    DropoutParam dropout_param2(ctx, 0);
     FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
         ctx.cuda_device_context(),
         bsz_seq,
...
@@ -631,7 +648,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     if (qkv_bias != nullptr) {
       fmha_ref_compute.ComputeBackward(*transpose_out_2,
-                                       src_mask,
+                                       has_attn_dropout ? src_mask : nullptr,
                                        *softmax_out,
                                        *attn_dropout_mask_out,
                                        *attn_dropout_out,
...
@@ -648,7 +665,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
                                        d_qkv_bias_out);
     } else {
       fmha_ref_compute.ComputeBackward(*transpose_out_2,
-                                       src_mask,
+                                       has_attn_dropout ? src_mask : nullptr,
                                        *softmax_out,
                                        *attn_dropout_mask_out,
                                        *attn_dropout_out,
...
...
@@ -290,7 +290,7 @@ struct TestFusedLayernormResidualDropoutBias {
     framework::TensorToVector(layernorm_out, *ctx, &_layernorm_out);
     framework::TensorToVector(means, *ctx, &_means);
     framework::TensorToVector(vars, *ctx, &_vars);
-    if (!is_test) {
+    if (!is_test && dropout_prob != 0.0f) {
       framework::TensorToVector(mask, *ctx, &_mask);
     }
     ctx->Wait();
...
@@ -298,7 +298,9 @@ struct TestFusedLayernormResidualDropoutBias {
     for (int i = 0; i < n; i++) {
       EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff);
       EXPECT_LT(std::abs(_layernorm_out[i] - correct_layernorm_out[i]), diff);
-      if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]);
+      if (!is_test && dropout_prob != 0.0f) {
+        EXPECT_EQ(_mask[i], correct_mask[i]);
+      }
     }
     for (int i = 0; i < rows; i++) {
       EXPECT_LT(std::abs(_means[i] - correct_means[i]), static_cast<U>(diff));
...
...
@@ -30,7 +30,8 @@ template <typename T,
           bool Activation,
           typename Functor,
           typename InType = T,
-          typename OutType = T>
+          typename OutType = T,
+          bool HasDropout = true>
 __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
     const int row_id,
     const int col_id,
...
@@ -84,7 +85,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
   }
   MaskStoreT mask_vec;
-  if (!is_test) {
+  if (!is_test && HasDropout) {
     float rand[VecSize];
     RandVec<VecSize>(state, rand);
 #pragma unroll
...
@@ -114,8 +115,12 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
     if (Activation) {
       tmp = act_func(tmp);
     }
-    dest_vec[ii] =
-        tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
+    if (HasDropout) {
+      dest_vec[ii] =
+          tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
+    } else {
+      dest_vec[ii] = tmp * factor + residual_vec[ii];
+    }
     if (ComputeLayerNorm) {
       U tmp = static_cast<U>(dest_vec[ii]);
       *mean_val += tmp;
...
@@ -138,7 +143,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
     phi::Store<T, VecSize>(dest_vec,
                            reinterpret_cast<T *>(&dst[row_id * cols + col_id]));
   }
-  if (!is_test) {
+  if (!is_test && HasDropout) {
     phi::Store<MaskType, VecSize>(mask_vec, &mask[row_id * cols + col_id]);
   }
 }
...
@@ -154,7 +159,8 @@ template <typename T,
           typename MaskType,
           int VecSize,
           typename InType = T,
-          typename OutType = T>
+          typename OutType = T,
+          bool HasDropout = true>
 __global__ void FusedResidualDropoutBias(
     const size_t rows,
     const size_t cols,
...
@@ -175,8 +181,15 @@ __global__ void FusedResidualDropoutBias(
   int row_id = blockIdx.y;
   int idx = row_id * cols + col_id;
   curandStatePhilox4_32_10_t state;
-  curand_init(seed, idx, increment, &state);
-  const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
+  if (HasDropout) {
+    curand_init(seed, idx, increment, &state);
+  }
+  T factor;
+  if (HasDropout) {
+    factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
+  } else {
+    factor = static_cast<T>(1);
+  }
   phi::funcs::ReluFunctor<T> relu;
   for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
     for (int i = col_id * VecSize; i < cols;
...
@@ -188,24 +201,25 @@ __global__ void FusedResidualDropoutBias(
                                        false,
                                        phi::funcs::ReluFunctor<T>,
                                        InType,
-                                       OutType>(r,
-                                                i,
-                                                cols,
-                                                &state,
-                                                dropout_prob,
-                                                factor,
-                                                src,
-                                                residual,
-                                                bias,
-                                                dst,
-                                                mask,
-                                                is_test,
-                                                nullptr,
-                                                nullptr,
-                                                relu,
-                                                quant_last_in_scale,
-                                                dequant_out_scale_data,
-                                                quant_next_in_scale);
+                                       OutType,
+                                       HasDropout>(r,
+                                                   i,
+                                                   cols,
+                                                   &state,
+                                                   dropout_prob,
+                                                   factor,
+                                                   src,
+                                                   residual,
+                                                   bias,
+                                                   dst,
+                                                   mask,
+                                                   is_test,
+                                                   nullptr,
+                                                   nullptr,
+                                                   relu,
+                                                   quant_last_in_scale,
+                                                   dequant_out_scale_data,
+                                                   quant_next_in_scale);
     }
   }
 }
...
@@ -256,43 +270,64 @@ void LaunchResidualDropoutBias(const uint32_t rows,
   const int VecSize = MAX_CACHE_BYTES / sizeof(T);
   const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
   auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
-  if (cols % VecSize == 0) {
-    FusedResidualDropoutBias<T, uint8_t, VecSize, InType, OutType>
-        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
-            rows,
-            cols,
-            seed,
-            dropout_prob,
-            is_upscale_in_train,
-            src,
-            residual,
-            bias,
-            mask_data,
-            dst,
-            increment,
-            is_test,
-            quant_last_in_scale,
-            dequant_out_scale_data,
-            quant_next_in_scale);
+
+#define PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(__has_dropout) \
+  do { \
+    if (cols % VecSize == 0) { \
+      FusedResidualDropoutBias<T, \
+                               uint8_t, \
+                               VecSize, \
+                               InType, \
+                               OutType, \
+                               __has_dropout> \
+          <<<config.block_per_grid, \
+             config.thread_per_block, \
+             0, \
+             ctx.stream()>>>(rows, \
+                             cols, \
+                             seed, \
+                             dropout_prob, \
+                             is_upscale_in_train, \
+                             src, \
+                             residual, \
+                             bias, \
+                             mask_data, \
+                             dst, \
+                             increment, \
+                             is_test, \
+                             quant_last_in_scale, \
+                             dequant_out_scale_data, \
+                             quant_next_in_scale); \
+    } else { \
+      FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType, __has_dropout> \
+          <<<config.block_per_grid, \
+             config.thread_per_block, \
+             0, \
+             ctx.stream()>>>(rows, \
+                             cols, \
+                             seed, \
+                             dropout_prob, \
+                             is_upscale_in_train, \
+                             src, \
+                             residual, \
+                             bias, \
+                             mask_data, \
+                             dst, \
+                             increment, \
+                             is_test, \
+                             quant_last_in_scale, \
+                             dequant_out_scale_data, \
+                             quant_next_in_scale); \
+    } \
+  } while (0)
+
+  if (dropout_prob != 0.0f) {
+    PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(true);
   } else {
-    FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType>
-        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
-            rows,
-            cols,
-            seed,
-            dropout_prob,
-            is_upscale_in_train,
-            src,
-            residual,
-            bias,
-            mask_data,
-            dst,
-            increment,
-            is_test,
-            quant_last_in_scale,
-            dequant_out_scale_data,
-            quant_next_in_scale);
+    PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(false);
   }
+
+#undef PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL
 }
 
 /*
...
@@ -334,7 +369,8 @@ template <typename T,
           typename MaskType,
           int BlockSizeX,
           int BlockSizeY,
-          int VecSize>
+          int VecSize,
+          bool HasDropout>
 __global__ void FusedResidualDropoutBiasGrad(const T *dout,
                                              const MaskType *mask,
                                              const T factor,
...
@@ -350,6 +386,9 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout,
   T tmp_sum[VecSize] = {static_cast<T>(0)};
   // calculate the dx and temporary sum
+  const bool not_need_dx = (dx == nullptr) || (dx == dout && !HasDropout &&
+                                               factor == static_cast<T>(1.0));
+
   if (col_id * VecSize < cols) {
     for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
       int index = row_id * cols + col_id * VecSize;
...
@@ -357,15 +396,27 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout,
       MaskLoadT mask_vec;
       StoreT dx_vec;
       phi::Load<T, VecSize>(&dout[index], &out_vec);
-      phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
+      if (HasDropout) {
+        phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
+      }
+      if (not_need_dx) {
+#pragma unroll
+        for (int i = 0; i < VecSize; i++) {
+          tmp_sum[i] += out_vec[i];
+        }
+      } else {
 #pragma unroll
-      for (int i = 0; i < VecSize; i++) {
-        dx_vec[i] = out_vec[i] * static_cast<T>(mask_vec[i]) * factor;
-        tmp_sum[i] += out_vec[i];
+        for (int i = 0; i < VecSize; i++) {
+          if (HasDropout) {
+            dx_vec[i] = out_vec[i] * static_cast<T>(mask_vec[i]) * factor;
+          } else {
+            dx_vec[i] = out_vec[i] * factor;
+          }
+          tmp_sum[i] += out_vec[i];
+        }
+        phi::Store<T, VecSize>(dx_vec, &dx[index]);
       }
-      phi::Store<T, VecSize>(dx_vec, &dx[index]);
     }
   }
...
@@ -395,35 +446,68 @@ void LaunchResidualDropoutBiasGrad(const T *dout,
   const int VecSize = MAX_CACHE_BYTES / sizeof(T);
   int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
-  if (dbias != nullptr) {
-    const auto threads = 8;
-    auto blocks = std::max(static_cast<uint32_t>(1),
-                           (cols / real_vec_size + threads - 1) / threads);
-    dim3 block_dim(threads, 128, 1);
-    dim3 grid_dim(blocks, 1, 1);
-    if (cols % VecSize == 0) {
-      FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, VecSize>
-          <<<grid_dim, block_dim, 0, ctx.stream()>>>(
-              dout, mask, factor, rows, cols, dx, dbias);
-    } else {
-      FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, 1>
-          <<<grid_dim, block_dim, 0, ctx.stream()>>>(
-              dout, mask, factor, rows, cols, dx, dbias);
-    }
+
+#define PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(__has_dropout) \
+  do { \
+    if (dbias != nullptr) { \
+      const auto threads = 8; \
+      auto blocks = std::max(static_cast<uint32_t>(1), \
+                             (cols / real_vec_size + threads - 1) / threads); \
+      dim3 block_dim(threads, 128, 1); \
+      dim3 grid_dim(blocks, 1, 1); \
+      if (cols % VecSize == 0) { \
+        FusedResidualDropoutBiasGrad<T, \
+                                     MaskType, \
+                                     8, \
+                                     128, \
+                                     VecSize, \
+                                     __has_dropout> \
+            <<<grid_dim, block_dim, 0, ctx.stream()>>>( \
+                dout, mask, factor, rows, cols, dx, dbias); \
+      } else { \
+        FusedResidualDropoutBiasGrad<T, MaskType, 8, 128, 1, __has_dropout> \
+            <<<grid_dim, block_dim, 0, ctx.stream()>>>( \
+                dout, mask, factor, rows, cols, dx, dbias); \
+      } \
+    } else { \
+      if (dropout_prob == 0.0f) { \
+        if (dx == nullptr || dx == dout) { \
+          return; \
+        } \
+        memory::Copy(ctx.GetPlace(), \
+                     dx, \
+                     ctx.GetPlace(), \
+                     dout, \
+                     rows *cols * sizeof(T), \
+                     ctx.stream()); \
+      } else { \
+        const uint64_t n = rows * cols; \
+        platform::GpuLaunchConfig config = \
+            platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size); \
+        if (n % VecSize == 0) { \
+          FusedResidualDropoutGrad<T, MaskType, VecSize> \
+              <<<config.block_per_grid, \
+                 config.thread_per_block, \
+                 0, \
+                 ctx.stream()>>>(dout, mask, factor, n, dx); \
+        } else { \
+          FusedResidualDropoutGrad<T, MaskType, 1> \
+              <<<config.block_per_grid, \
+                 config.thread_per_block, \
+                 0, \
+                 ctx.stream()>>>(dout, mask, factor, n, dx); \
+        } \
+      } \
+    } \
+  } while (0)
+
+  if (dropout_prob != 0.0f) {
+    PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(true);
   } else {
-    const uint64_t n = rows * cols;
-    platform::GpuLaunchConfig config =
-        platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
-    if (n % VecSize == 0) {
-      FusedResidualDropoutGrad<T, MaskType, VecSize>
-          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
-              dout, mask, factor, n, dx);
-    } else {
-      FusedResidualDropoutGrad<T, MaskType, 1>
-          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
-              dout, mask, factor, n, dx);
-    }
+    PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(false);
   }
+
+#undef PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL
 }
 
 }  // namespace operators
...
...
@@ -258,14 +258,14 @@ struct FusedResidualDropoutBiasTester {
     std::vector<T> fused_out(n);
     std::vector<uint8_t> fused_mask(n);
     framework::TensorToVector(out, *ctx, &fused_out);
-    if (!is_test) {
+    if (!is_test && dropout_prob != 0.0f) {
       framework::TensorToVector<uint8_t>(mask, *ctx, &fused_mask);
     }
     ctx->Wait();
     for (int i = 0; i < n; i++) {
       EXPECT_LT(std::abs(fused_out[i] - correct_out[i]), diff);
-      if (!is_test) {
+      if (!is_test && dropout_prob != 0.0f) {
         EXPECT_EQ(fused_mask[i], correct_mask[i]);
       }
     }
...
...
@@ -501,7 +501,8 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
 }
 
 #ifdef PADDLE_WITH_CUDA
-template <bool isFusedDropoutResidualLn,
+template <bool IsFusedDropoutResidualLn,
+          bool NeedDDropoutSrcPtr,
           typename T,
           typename U,
           typename ScaleT = U,
...
@@ -531,6 +532,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
     const MaskType *mask_ptr = nullptr,
     T factor = static_cast<T>(0),
     T *d_dropout_src_ptr = nullptr) {
+  static_assert(
+      !IsFusedDropoutResidualLn || NeedDDropoutSrcPtr,
+      "When IsFusedDropoutResidualLn = true, NeedDDropoutSrcPtr must be true.");
+
   using Vec = phi::AlignedVector<T, VecSize>;
   using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
   using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
...
@@ -585,7 +590,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
       phi::Load<T, VecSize>(dout_ptr + row * ELTS_PER_ROW + col * VecSize,
                             &dout[it]);
       phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
-      if (isFusedDropoutResidualLn) {
+      if (IsFusedDropoutResidualLn) {
         phi::Load<MaskType, VecSize>(
             mask_ptr + row * ELTS_PER_ROW + col * VecSize, &mask_vec[it]);
       }
...
@@ -671,7 +676,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
         U dx_tmp = var_cur_row * (dy_tmp - sum_loss2 * y_tmp - sum_loss1);
         // Note: reuse x and dout vec register to store dx and d_dropout_src.
         x[it][jt] = static_cast<T>(dx_tmp);
-        if (isFusedDropoutResidualLn) {
+        if (IsFusedDropoutResidualLn) {
           dout[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor;
         }
       }
...
@@ -683,9 +688,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
     for (int it = 0; it < LDGS; it++) {
       phi::Store<T, VecSize>(x[it],
                              dx_ptr + row * ELTS_PER_ROW + col * VecSize);
-      if (isFusedDropoutResidualLn) {
+      if (IsFusedDropoutResidualLn) {
         phi::Store<T, VecSize>(
             dout[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize);
+      } else if (NeedDDropoutSrcPtr) {
+        phi::Store<T, VecSize>(
+            x[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize);
       }
       col += THREADS_PER_ROW;
     }
...
@@ -955,6 +963,7 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
     }
 
 #define LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
   fused_ln_bwd_fast_kernel<true, \
+                           true, \
                            T, \
                            U, \
                            ScaleT, \
...
@@ -993,8 +1002,10 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
 #undef LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
   } else {
-#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
+#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL_BASE( \
+    vec_size, ele_per_row, need_d_dropout_src_ptr) \
   fused_ln_bwd_fast_kernel<false, \
+                           need_d_dropout_src_ptr, \
                            T, \
                            U, \
                            ScaleT, \
...
@@ -1013,7 +1024,19 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
                            dout_ptr, \
                            dscale_temp_ptr, \
                            dbias_temp_ptr, \
-                           dx_ptr);
+                           dx_ptr, \
+                           nullptr, \
+                           factor, \
+                           d_dropout_src_ptr);
+
+#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
+  do { \
+    if (d_dropout_src_ptr != nullptr) { \
+      LAUNCH_FUSED_LN_BWD_FAST_KERNEL_BASE(vec_size, ele_per_row, true); \
+    } else { \
+      LAUNCH_FUSED_LN_BWD_FAST_KERNEL_BASE(vec_size, ele_per_row, false); \
+    } \
+  } while (0)
+
     if (cols == 1024) {
       LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.incubate.nn import FusedMultiHeadAttention
def random_init_model(model, seed):
paddle.seed(seed)
for p in model.parameters():
shape = p.shape
dtype = p.dtype
value = paddle.randn(shape=shape, dtype=dtype)
p.set_value(value.numpy())
class FusedAttentionTestLayer(FusedMultiHeadAttention):
def __init__(self, embed_dim, num_heads, normalize_before=False):
super().__init__(
embed_dim=embed_dim,
num_heads=num_heads,
attn_dropout_rate=0.0,
dropout_rate=0.0,
normalize_before=normalize_before,
)
def _reshape_and_transpose(self, x):
assert len(x.shape) == 3
bs, seq_len = x.shape[:2]
x = x.reshape([bs, seq_len, self.num_heads, self.head_dim])
x = x.transpose([0, 2, 1, 3])
return x
def _transpose_and_reshape(self, x):
assert len(x.shape) == 4
x = x.transpose([0, 2, 1, 3])
bs = x.shape[0]
x = x.reshape([bs, -1, self.embed_dim])
return x
def forward(self, x, attn_mask, use_ref=False):
if use_ref:
return self.ref_forward(x, attn_mask)
else:
return super().forward(x, attn_mask)
def ref_forward(self, x, attn_mask):
residual = x
if self.normalize_before:
assert len(self.pre_ln_scale.shape) == 1
out = F.layer_norm(
x,
self.pre_ln_scale.shape,
weight=self.pre_ln_scale,
bias=self.pre_ln_bias,
epsilon=self._epsilon,
)
else:
out = x
qkv_weight = self.qkv_weight.reshape(
[3 * self.embed_dim, self.embed_dim]
)
qkv_bias = self.qkv_bias.reshape([3 * self.embed_dim])
out = paddle.matmul(out, qkv_weight, transpose_y=True) + qkv_bias
# [BS, seq_len, head_dim]
# [BS, seq_len, head_dim * 3]
q, k, v = paddle.split(out, 3, axis=-1)
q = self._reshape_and_transpose(q)
k = self._reshape_and_transpose(k)
v = self._reshape_and_transpose(v)
q *= self.head_dim**-0.5
out = paddle.matmul(q, k, transpose_y=True)
if attn_mask is not None:
out += attn_mask
out = F.softmax(out)
out = paddle.matmul(out, v)
out = self._transpose_and_reshape(out)
out = F.linear(out, weight=self.linear_weight, bias=self.linear_bias)
add_residual = True
if add_residual:
out = residual + out
if not self.normalize_before:
assert len(self.ln_scale.shape) == 1
out = F.layer_norm(
out,
self.ln_scale.shape,
weight=self.ln_scale,
bias=self.ln_bias,
epsilon=self._epsilon,
)
return out
class TestFusedAttention(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.num_heads = 16
self.max_seq_len = 128
self.hidden_size = 256
self.dtype = "float32"
self.normalize_before = False
self.seed = 10
self.use_mask = False
self.set_configs()
def set_configs(self):
pass
def generate_inputs(self):
np.random.seed(self.seed)
hidden_state = np.random.random(
size=[self.batch_size, self.max_seq_len, self.hidden_size]
).astype(self.dtype)
hidden_state = paddle.to_tensor(hidden_state)
hidden_state.stop_gradient = False
if self.use_mask:
seq_lens = np.random.randint(
low=int(self.max_seq_len / 3),
high=self.max_seq_len,
size=[self.batch_size],
)
mask = np.zeros(
shape=[self.batch_size, self.max_seq_len], dtype=self.dtype
)
for i in range(self.batch_size):
mask[i][0 : seq_lens[i]] = 1
mask = mask.reshape([self.batch_size, 1, 1, self.max_seq_len])
broadcast_shape = [
self.batch_size,
self.num_heads,
self.max_seq_len,
self.max_seq_len,
]
mask = np.broadcast_to(mask, broadcast_shape)
mask = (1 - mask) * -1e9
return hidden_state, paddle.to_tensor(mask.astype(self.dtype))
else:
return hidden_state, None
def run_fwd_bwd(self, use_ref=False):
x, mask = self.generate_inputs()
layer = FusedAttentionTestLayer(
self.hidden_size,
self.num_heads,
normalize_before=self.normalize_before,
)
random_init_model(layer, self.seed + 100)
out = layer(x, mask, use_ref)
loss = out.mean()
loss.backward()
vars_need_gradients = [('out', x)] + list(layer.named_parameters())
numpy_values = [out.numpy()]
for i, (name, var) in enumerate(vars_need_gradients):
tmp = var.grad.numpy()
numpy_values.append(tmp)
return numpy_values
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
values1 = self.run_fwd_bwd(True)
paddle.device.cuda.synchronize()
values2 = self.run_fwd_bwd(False)
paddle.device.cuda.synchronize()
self.assertEqual(len(values1), len(values2))
for i, (v1, v2) in enumerate(zip(values1, values2)):
if not self.normalize_before:
np.testing.assert_allclose(v1, v2, atol=1e-6, rtol=1e-5)
else:
np.testing.assert_equal(v1, v2)
class TestFusedAttentionNormalizeBefore(TestFusedAttention):
def set_configs(self):
self.normalize_before = True
if __name__ == "__main__":
unittest.main()