From 19e866f92ef5e84f6e35934a2adb2269d4685bca Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Fri, 17 Jun 2022 18:10:00 +0800 Subject: [PATCH] Support optional residual add in fused_attention and fused_feedforward. (#43474) * Support optional residual add in fused_attention and fused_feedforward. * Add checkpoint and add the check of add_residual when pre_layer_norm is false. * Add TODO and change the python api to add add_residual argument. --- .../operators/fused/fused_attention_op.cc | 10 ++ .../operators/fused/fused_attention_op.cu | 61 +++---- .../operators/fused/fused_dropout_helper.h | 7 +- .../operators/fused/fused_feedforward_op.cc | 9 + .../operators/fused/fused_feedforward_op.cu | 84 ++++++---- .../fused/fused_residual_dropout_bias.h | 9 +- .../fused/fused_residual_dropout_bias_test.cu | 154 ++++++++++-------- .../nn/functional/fused_transformer.py | 12 +- 8 files changed, 209 insertions(+), 137 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 06ede8e2c7b..32dbe2b180c 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -378,6 +379,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "0.0 and 0.001, But received [%s].", ln_epsilon)); }); + AddAttr("add_residual", "Whether to add residual.").SetDefault(true); AddAttr( "ring_id", "ring id for tensor model parallel. distributed training and inference") @@ -655,3 +657,11 @@ REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp, ops::FusedAttentionGradOpMaker, ops::FusedAttentionGradOpMaker); REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp); + +REGISTER_OP_VERSION(fused_attention) + .AddCheckpoint( + R"ROC( + Add a new attribute [add_residual] )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "add_residual", "A flag to indicate whether to add residual.", + true)); diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 73fdd29fd62..3a17be50450 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -246,26 +246,32 @@ class FusedAttentionOpKernel : public framework::OpKernel { // tensor model parallel AllReduce(*out_linear_out, ring_id, ctx.cuda_device_context()); + bool add_residual = ctx.Attr("add_residual"); + const T *residual_ptr = add_residual ? x_data : nullptr; if (pre_layer_norm) { // output = (residual + dropout(input + bias)) fused_dropout_layernorm_helper.ResidualDropoutBias( - ctx.cuda_device_context(), out_linear_out_data, x_data, + ctx.cuda_device_context(), out_linear_out_data, residual_ptr, out_linear_bias_data, final_out_data, dropout_mask_out_data); } else { - auto *ln_scale_2_data = - (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); - auto *ln_bias_2_data = - (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); - auto *bias_dropout_residual_out_data = + // TODO(Xreki): support post layer_norm case when add_residual is false. + PADDLE_ENFORCE_EQ(add_residual, true, + platform::errors::InvalidArgument( + "Attribute add_residual is expected to be true " + "when pre_layer_norm is false.")); + + const U *ln_scale_2_ptr = ln_scale_2 ? 
ln_scale_2->data() : nullptr; + const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data() : nullptr; + T *bias_dropout_residual_out_ptr = bias_dropout_residual_out->mutable_data(ctx.GetPlace()); - auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); - auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); + U *ln_mean_2_ptr = ln_mean_2->mutable_data(ctx.GetPlace()); + U *ln_var_2_ptr = ln_var_2->mutable_data(ctx.GetPlace()); // output = layernorm(residual + dropout(input + bias)) fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - ctx.cuda_device_context(), out_linear_out_data, x_data, - out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, - bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, - ln_mean_2_data, ln_var_2_data); + ctx.cuda_device_context(), out_linear_out_data, residual_ptr, + out_linear_bias_data, ln_scale_2_ptr, ln_bias_2_ptr, + bias_dropout_residual_out_ptr, dropout_mask_out_data, final_out_data, + ln_mean_2_ptr, ln_var_2_ptr); } } }; @@ -419,16 +425,17 @@ class FusedAttentionGradKernel : public framework::OpKernel { int output_size = 3 * hidden_size; int input_size = dim_embed; + bool add_residual = ctx.Attr("add_residual"); Tensor d_residual; - d_residual.Resize(input_x_dims); - T *d_residual_data = d_residual.mutable_data(ctx.GetPlace()); + T *d_residual_data = nullptr; + if (add_residual) { + d_residual.Resize(input_x_dims); + d_residual_data = d_residual.mutable_data(ctx.GetPlace()); + } bool transA = false; bool transB = true; - bool compute_qkv_bias = true; - if (qkv_bias == nullptr) { - compute_qkv_bias = false; - } + bool compute_qkv_bias = qkv_bias ? true : false; auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed); auto qkv_compute = @@ -539,16 +546,14 @@ class FusedAttentionGradKernel : public framework::OpKernel { // tensor model parallel AllReduce(*d_x, ring_id, ctx.cuda_device_context()); } - // gradient accumulation - std::vector ins; - std::vector outs; - ins.emplace_back(&d_residual); - ins.emplace_back(d_x); - outs.emplace_back(d_x); - int elewise_add_axis = -1; - phi::funcs::BroadcastKernel( - ctx.cuda_device_context(), ins, &outs, elewise_add_axis, - phi::funcs::AddFunctor()); + + if (add_residual) { + // gradient accumulation + std::vector ins = {&d_residual, d_x}; + std::vector outs = {d_x}; + phi::funcs::ElementwiseKernel(ctx.cuda_device_context(), ins, &outs, + phi::funcs::AddFunctor()); + } } }; diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index 6dc1c446bd7..ab660362573 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -150,9 +150,10 @@ class FusedDropoutHelper { LaunchResidualDropoutBiasGrad( d_out, mask, dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); - auto cuda_place = ctx.GetPlace(); - memory::Copy(cuda_place, d_residual, cuda_place, d_out, - rows_ * cols_ * sizeof(T), ctx.stream()); + if (d_residual) { + memory::Copy(ctx.GetPlace(), d_residual, ctx.GetPlace(), d_out, + rows_ * cols_ * sizeof(T), ctx.stream()); + } } // out = dropout(activation(src + bias)) diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cc b/paddle/fluid/operators/fused/fused_feedforward_op.cc index d3cc1b91276..138515b21d9 100644 --- a/paddle/fluid/operators/fused/fused_feedforward_op.cc +++ b/paddle/fluid/operators/fused/fused_feedforward_op.cc @@ -194,6 +194,7 @@ class 
FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(false); AddAttr("dropout1_seed", "Dropout1 random seed.").SetDefault(0); AddAttr("dropout2_seed", "Dropout2 random seed.").SetDefault(0); + AddAttr("add_residual", "Whether to add residual.").SetDefault(true); AddAttr("ring_id", "ring id for tensor model parallel.") .SetDefault(-1); AddComment(R"DOC( @@ -367,3 +368,11 @@ REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp, ops::FusedFeedForwardOpGradMaker, ops::FusedFeedForwardOpGradMaker); REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad); + +REGISTER_OP_VERSION(fused_feedforward) + .AddCheckpoint( + R"ROC( + Add a new attribute [add_residual] )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "add_residual", "A flag to indicate whether to add residual.", + true)); diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu index 675ec29da67..e136501a8a3 100644 --- a/paddle/fluid/operators/fused/fused_feedforward_op.cu +++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu @@ -69,7 +69,8 @@ class FusedFeedForwardKernel : public framework::OpKernel { blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0)); } - void FFN(const framework::Tensor& x, const framework::Tensor& linear1_weight, + void FFN(const platform::CUDADeviceContext& ctx, const framework::Tensor& x, + const framework::Tensor& linear1_weight, const framework::Tensor* linear1_bias, const framework::Tensor& linear2_weight, const framework::Tensor* linear2_bias, @@ -84,10 +85,9 @@ class FusedFeedForwardKernel : public framework::OpKernel { framework::Tensor* dropout1_out, framework::Tensor* dropout2_out, const int bsz_seq, const int d_model, const int dim_feedforward, const std::string& act_method, const bool pre_layer_norm, - const float epsilon1, const float epsilon2, const int ring_id, - const DropoutParam& dropout_param1, - const DropoutParam& dropout_param2, - const platform::CUDADeviceContext& ctx) const { + const float epsilon1, const float epsilon2, const bool add_residual, + const int ring_id, const DropoutParam& dropout_param1, + const DropoutParam& dropout_param2) const { FusedDropoutLayerNormHelper pre_layernorm_helper( bsz_seq, d_model, epsilon1); FusedDropoutHelper fused_act_dropout_helper( @@ -127,15 +127,22 @@ class FusedFeedForwardKernel : public framework::OpKernel { // tensor model parallel AllReduce(linear2_out, ring_id, ctx); + const T* residual_ptr = add_residual ? x.data() : nullptr; if (!pre_layer_norm) { + // TODO(Xreki): support post layer_norm case when add_residual is false. 
+ PADDLE_ENFORCE_EQ(add_residual, true, + platform::errors::InvalidArgument( + "Attribute add_residual is expected to be true " + "when pre_layer_norm is false.")); + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - ctx, linear2_out.data(), x.data(), linear2_bias_ptr, + ctx, linear2_out.data(), residual_ptr, linear2_bias_ptr, ln2_scale_ptr, ln2_bias_ptr, dropout2_out->data(), dropout2_mask->data(), out->data(), ln2_mean->data(), ln2_variance->data()); } else { fused_dropout_layernorm_helper.ResidualDropoutBias( - ctx, linear2_out.data(), x.data(), linear2_bias_ptr, + ctx, linear2_out.data(), residual_ptr, linear2_bias_ptr, out->data(), dropout2_mask->data()); } } @@ -183,6 +190,7 @@ class FusedFeedForwardKernel : public framework::OpKernel { const float epsilon1 = context.Attr("ln1_epsilon"); const float epsilon2 = context.Attr("ln2_epsilon"); const int ring_id = context.Attr("ring_id"); + const bool add_residual = context.Attr("add_residual"); DropoutParam dropout_param1(context, 1); DropoutParam dropout_param2(context, 2); @@ -214,12 +222,12 @@ class FusedFeedForwardKernel : public framework::OpKernel { int dim_feedforward = dim[dim.size() - 1]; int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; - FFN(*x, *linear1_weight, linear1_bias, *linear2_weight, linear2_bias, - ln1_scale, ln1_bias, ln2_scale, ln2_bias, out, dropout1_mask, - dropout2_mask, ln1_mean, ln1_variance, ln2_mean, ln2_variance, - linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq, d_model, - dim_feedforward, act_method, pre_layer_norm, epsilon1, epsilon2, - ring_id, dropout_param1, dropout_param2, context.cuda_device_context()); + FFN(context.cuda_device_context(), *x, *linear1_weight, linear1_bias, + *linear2_weight, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias, + out, dropout1_mask, dropout2_mask, ln1_mean, ln1_variance, ln2_mean, + ln2_variance, linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq, + d_model, dim_feedforward, act_method, pre_layer_norm, epsilon1, + epsilon2, add_residual, ring_id, dropout_param1, dropout_param2); } }; @@ -243,8 +251,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { } void FFNGrad( - const framework::Tensor& d_out, const framework::Tensor& x, - const framework::Tensor& dropout1_mask, + const platform::CUDADeviceContext& ctx, const framework::Tensor& d_out, + const framework::Tensor& x, const framework::Tensor& dropout1_mask, const framework::Tensor& dropout2_mask, const framework::Tensor& linear1_out, const framework::Tensor* ln1_out, const framework::Tensor& dropout1_out, @@ -264,7 +272,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { const int dim_feedforward, const DropoutParam& dropout_param1, const DropoutParam& dropout_param2, const std::string& act_method, const bool pre_layer_norm, const float epsilon1, const float epsilon2, - const int ring_id, const platform::CUDADeviceContext& ctx) const { + const bool add_residual, const int ring_id) const { FusedDropoutLayerNormHelper pre_layernorm_helper( bsz_seq, d_model, epsilon1); FusedDropoutHelper fused_act_dropout_helper( @@ -296,19 +304,22 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { framework::Tensor d_linear2_out, d_dropout2_out, d_residual; d_linear2_out.mutable_data({bsz_seq, d_model}, place); d_dropout2_out.mutable_data({bsz_seq, d_model}, place); - d_residual.mutable_data(d_x->dims(), place); + T* d_residual_ptr = nullptr; + if (add_residual) { + d_residual_ptr = d_residual.mutable_data(d_x->dims(), place); + } if 
(pre_layer_norm) { fused_dropout_layernorm_helper.ResidualDropoutBiasGrad( ctx, d_out.data(), dropout2_mask.data(), - d_linear2_out.data(), d_residual.data(), d_linear2_bias_ptr); + d_linear2_out.data(), d_residual_ptr, d_linear2_bias_ptr); } else { fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( ctx, d_out.data(), dropout2_out.data(), dropout2_mask.data(), ln2_gamma_ptr, ln2_mean->data(), ln2_variance->data(), d_dropout2_out.data(), d_ln2_gamma_ptr, d_ln2_beta_ptr, d_linear2_out.data(), d_linear2_bias_ptr, - d_residual.data()); + d_residual_ptr); } framework::Tensor d_dropout1_out; @@ -339,14 +350,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { // tensor model parallel AllReduce(*d_x, ring_id, ctx); } - std::vector ins(2); - std::vector outs(1); - ins[0] = &d_residual; - ins[1] = d_x; - outs[0] = d_x; - int elewise_add_axis = -1; - phi::funcs::BroadcastKernel( - ctx, ins, &outs, elewise_add_axis, phi::funcs::AddFunctor()); + + if (add_residual) { + // gradient accumulation + std::vector ins = {&d_residual, d_x}; + std::vector outs = {d_x}; + phi::funcs::ElementwiseKernel(ctx, ins, &outs, + phi::funcs::AddFunctor()); + } } void Compute(const framework::ExecutionContext& context) const override { @@ -410,6 +421,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { const float epsilon1 = context.Attr("ln1_epsilon"); const float epsilon2 = context.Attr("ln2_epsilon"); + const bool add_residual = context.Attr("add_residual"); const int ring_id = context.Attr("ring_id"); const std::string act_method = context.Attr("act_method"); DropoutParam dropout_param1(context, 1); @@ -447,15 +459,15 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1]; int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; - FFNGrad(d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out, - dropout1_out, dropout2_out, linear1_weight, linear1_bias, - linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance, - ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight, - d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale, - d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model, - dim_feedforward, dropout_param1, dropout_param2, act_method, - pre_layer_norm, epsilon1, epsilon2, ring_id, - context.cuda_device_context()); + FFNGrad(context.cuda_device_context(), d_out, x, dropout1_mask, + dropout2_mask, linear1_out, ln1_out, dropout1_out, dropout2_out, + linear1_weight, linear1_bias, linear2_weight, ln1_scale, ln1_bias, + ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance, + d_x, d_linear1_weight, d_linear1_bias, d_linear2_weight, + d_linear2_bias, d_ln1_scale, d_ln1_bias, d_ln2_scale, d_ln2_bias, + bsz_seq, d_model, dim_feedforward, dropout_param1, dropout_param2, + act_method, pre_layer_norm, epsilon1, epsilon2, add_residual, + ring_id); } }; } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 0cc31e6fc32..eae1d13dcd0 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -140,9 +140,12 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, // dropout_prob == 1.0f if (std::abs(dropout_prob - 1.0f) < 1e-5) { if (residual == dst) return; - auto cuda_place = ctx.GetPlace(); - memory::Copy(cuda_place, dst, cuda_place, residual, rows 
* cols * sizeof(T), - ctx.stream()); + if (residual) { + memory::Copy(ctx.GetPlace(), dst, ctx.GetPlace(), residual, + rows * cols * sizeof(T), ctx.stream()); + } else { + SetZero(ctx, dst, rows * cols); + } if (!is_test) { SetZero(ctx, mask_data, rows * cols); } diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index caceac1228e..63a364cc182 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -29,8 +29,10 @@ PD_DECLARE_KERNEL(dropout_grad, GPU, ALL_LAYOUT); namespace framework = paddle::framework; namespace platform = paddle::platform; +bool CheckEqual(float value, float ref) { return std::abs(value - ref) < 1e-5; } + /** - * @brief the unittest of fusedresidualdropoutbias + * @brief the unittest of FusedResidualDropoutBias * 1. random input data * 2. add bias, call paddle dropout op, add residual, and get the base result * 3. call FusedResidualDropoutBias function get fused result @@ -38,7 +40,7 @@ namespace platform = paddle::platform; */ template -struct TestFusedResidualDropoutBias { +struct FusedResidualDropoutBiasTester { uint32_t rows; uint32_t cols; uint64_t seed; @@ -46,6 +48,8 @@ struct TestFusedResidualDropoutBias { bool is_upscale_in_train; bool is_test; // default false, Set to true for inference only bool has_bias = true; + bool add_residual = true; + framework::Tensor src, residual, bias, out, mask; framework::Tensor dsrc, dbias; @@ -56,37 +60,33 @@ struct TestFusedResidualDropoutBias { platform::CUDAPlace place; platform::CUDADeviceContext *ctx; - TestFusedResidualDropoutBias() { + FusedResidualDropoutBiasTester() { rows = 32; cols = 32; seed = 0; dropout_prob = 0.0; is_upscale_in_train = false; is_test = false; - has_bias = true; platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto device_ctx = pool.Get(place); ctx = reinterpret_cast(device_ctx); } - TestFusedResidualDropoutBias(int rows_, int cols_, uint64_t seed_ = 0, - float dropout_prob_ = 0.0, - bool is_upscale_in_train_ = false, - bool is_test_ = false) { - rows = rows_; - cols = cols_; - seed = seed_; - dropout_prob = dropout_prob_; - is_upscale_in_train = is_upscale_in_train_; - is_test = is_test_; - has_bias = true; + FusedResidualDropoutBiasTester(int rows, int cols, uint64_t seed = 0, + float dropout_prob = 0.0, + bool is_upscale_in_train = false, + bool is_test = false) + : rows(rows), + cols(cols), + seed(seed), + dropout_prob(dropout_prob), + is_upscale_in_train(is_upscale_in_train), + is_test(is_test) { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto device_ctx = pool.Get(place); ctx = reinterpret_cast(device_ctx); } - ~TestFusedResidualDropoutBias() {} - void SetUp() { const int n = rows * cols; correct_out.resize(n); @@ -95,7 +95,9 @@ struct TestFusedResidualDropoutBias { correct_dbias.resize(cols); src_vec.resize(n); - residual_vec.resize(n); + if (add_residual) { + residual_vec.resize(n); + } bias_vec.resize(cols); std::default_random_engine random(time(NULL)); std::uniform_real_distribution dis(0.0, 1.0); @@ -103,7 +105,9 @@ struct TestFusedResidualDropoutBias { for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { src_vec[i * cols + j] = static_cast(dis(random)); - residual_vec[i * cols + j] = static_cast(dis(random)); + if (add_residual) { + residual_vec[i * cols + j] = static_cast(dis(random)); + } if (i == 0) { bias_vec[j] = 
dis(random); } @@ -112,47 +116,49 @@ struct TestFusedResidualDropoutBias { framework::TensorFromVector(src_vec, *ctx, &src); src.Resize({rows, cols}); - framework::TensorFromVector(residual_vec, *ctx, &residual); - residual.Resize({rows, cols}); + if (add_residual) { + framework::TensorFromVector(residual_vec, *ctx, &residual); + residual.Resize({rows, cols}); + } if (has_bias) { framework::TensorFromVector(bias_vec, *ctx, &bias); bias.Resize({cols}); } - { - out.mutable_data({rows, cols}, place); - mask.mutable_data({rows, cols}, place); - dsrc.mutable_data({rows, cols}, place); + out.mutable_data({rows, cols}, place); + mask.mutable_data({rows, cols}, place); + dsrc.mutable_data({rows, cols}, place); - if (has_bias) { - dbias.mutable_data({cols}, place); - } + if (has_bias) { + dbias.mutable_data({cols}, place); } } void BaseForward() { - std::vector out1(rows * cols), out2(rows * cols); if (has_bias) { // add bias + std::vector bias_out(rows * cols); for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { - out1[i * cols + j] = src_vec[i * cols + j] + bias_vec[j]; + bias_out[i * cols + j] = src_vec[i * cols + j] + bias_vec[j]; } } // call dropout - Dropout(out1, src.dims(), &out2, &correct_mask, *ctx, seed, + Dropout(bias_out, src.dims(), &correct_out, &correct_mask, *ctx, seed, dropout_prob, is_upscale_in_train, is_test); } else { - Dropout(src_vec, src.dims(), &out2, &correct_mask, *ctx, seed, + Dropout(src_vec, src.dims(), &correct_out, &correct_mask, *ctx, seed, dropout_prob, is_upscale_in_train, is_test); } ctx->Wait(); PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError()); - // add residual - for (int i = 0; i < rows; i++) { - for (int j = 0; j < cols; j++) { - correct_out[i * cols + j] = - residual_vec[i * cols + j] + out2[i * cols + j]; + if (add_residual) { + // add residual + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + int idx = i * cols + j; + correct_out[idx] = residual_vec[idx] + correct_out[idx]; + } } } } @@ -178,13 +184,11 @@ struct TestFusedResidualDropoutBias { 1) * VecSize; - T *bias_ptr = nullptr; - if (has_bias) { - bias_ptr = bias.data(); - } + T *bias_ptr = has_bias ? bias.data() : nullptr; + T *residual_ptr = add_residual ? residual.data() : nullptr; paddle::operators::LaunchResidualDropoutBias( rows, cols, increment, seed, dropout_prob, is_test, is_upscale_in_train, - src.data(), residual.data(), bias_ptr, mask.data(), + src.data(), residual_ptr, bias_ptr, mask.data(), out.data(), *ctx); ctx->Wait(); PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError()); @@ -195,10 +199,7 @@ struct TestFusedResidualDropoutBias { return; } - T *bias_ptr = nullptr; - if (has_bias) { - bias_ptr = dbias.data(); - } + T *bias_ptr = has_bias ? 
dbias.data() : nullptr; paddle::operators::LaunchResidualDropoutBiasGrad( out.data(), mask.data(), dropout_prob, is_upscale_in_train, rows, cols, dsrc.data(), bias_ptr, *ctx); @@ -214,17 +215,19 @@ struct TestFusedResidualDropoutBias { void CheckOut(const T diff) { const int n = rows * cols; - std::vector _out(n); - std::vector _mask(n); - framework::TensorToVector(out, *ctx, &_out); + std::vector fused_out(n); + std::vector fused_mask(n); + framework::TensorToVector(out, *ctx, &fused_out); if (!is_test) { - framework::TensorToVector(mask, *ctx, &_mask); + framework::TensorToVector(mask, *ctx, &fused_mask); } ctx->Wait(); for (int i = 0; i < n; i++) { - EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff); - if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]); + EXPECT_LT(std::abs(fused_out[i] - correct_out[i]), diff); + if (!is_test) { + EXPECT_EQ(fused_mask[i], correct_mask[i]); + } } } @@ -255,16 +258,21 @@ struct TestFusedResidualDropoutBias { // test the shape and bias template -static void BaseTest(const bool is_fp16 = false) { +static void BaseTest() { const int rows = 16; - T default_diff = !is_fp16 ? static_cast(1e-5) : static_cast(1e-1); + T max_diff = static_cast(0); + if (std::is_same::value) { + max_diff = static_cast(1e-1); + } else { + max_diff = static_cast(1e-5); + } for (auto cols : {16, 17}) { for (auto has_bias : {true, false}) { - TestFusedResidualDropoutBias test(rows, cols); + FusedResidualDropoutBiasTester test(rows, cols); test.has_bias = has_bias; test.Run(); - test.CheckOut(default_diff); - test.CheckGrad(default_diff); + test.CheckOut(max_diff); + test.CheckGrad(max_diff); } } } @@ -274,15 +282,15 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest(); } TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest(); } TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) { - BaseTest(true); + BaseTest(); } TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) { const int rows = 16; const int cols = 16; for (auto is_upscale_in_train : {true, false}) { - TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, - is_upscale_in_train, false); + FusedResidualDropoutBiasTester test(rows, cols, 0, 1.0, + is_upscale_in_train, false); test.Run(); test.CheckOut(static_cast(1e-5)); test.CheckGrad(static_cast(1e-5)); @@ -292,7 +300,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) { TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) { const int rows = 16; const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols, 0, 0.35, true, true); + FusedResidualDropoutBiasTester test(rows, cols, 0, 0.35, true, true); test.Run(); test.CheckOut(static_cast(1e-5)); test.CheckGrad(static_cast(1e-5)); @@ -301,16 +309,32 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) { TEST(FusedDropout, GPUFusedResidualDropoutBiasSeed) { const int rows = 16; const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols, 125, 0.0, false, false); + FusedResidualDropoutBiasTester test(rows, cols, 125, 0.0, false, + false); test.Run(); test.CheckOut(static_cast(1e-5)); test.CheckGrad(static_cast(1e-5)); } +TEST(FusedDropout, NoResidual) { + const int rows = 16; + const int cols = 16; + for (float p : {0.0f, 0.5f, 1.0f}) { + FusedResidualDropoutBiasTester test(rows, cols, 0, p, false, false); + test.add_residual = false; + test.Run(); + // For a non 0 or 1 dropout_prob, just test whether it can run successly. 
+ if (CheckEqual(p, 0.0f) || CheckEqual(p, 1.0f)) { + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); + } + } +} + TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShape) { const int rows = 256; const int cols = 4096; - TestFusedResidualDropoutBias test(rows, cols); + FusedResidualDropoutBiasTester test(rows, cols); test.Run(); test.CheckOut(static_cast(1e-5)); test.CheckGrad(static_cast(1e-3)); @@ -326,8 +350,8 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShapeFp16) { if (std::getenv("_cols") != nullptr) { cols = atoi(std::getenv("_cols")); } - TestFusedResidualDropoutBias test(rows, cols, 0, 0.0, true, - true); + FusedResidualDropoutBiasTester test(rows, cols, 0, 0.0, + true, true); test.Run(); test.CheckOut(static_cast(1e-1)); test.CheckGrad(static_cast(1e-1)); diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 50746652485..8f490751aa6 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -46,6 +46,7 @@ def fused_feedforward(x, training=True, mode='upscale_in_train', ring_id=-1, + add_residual=True, name=None): r""" This is a fusion operator to compute feed forward layer in transformer model architecture. @@ -90,6 +91,7 @@ def fused_feedforward(x, - train: out = input * mask - inference: out = input * (1.0 - p) ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using tensor parallel. + add_residual (bool, optional): Whether add residual at the end. Default is True. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -134,7 +136,8 @@ def fused_feedforward(x, "dropout2_fix_seed", seed is not None, "dropout1_seed", seed if seed is not None else 0, "dropout2_seed", seed if seed is not None else 0, 'dropout1_implementation', mode, - 'dropout2_implementation', mode, 'ring_id', ring_id) + 'dropout2_implementation', mode, 'add_residual', add_residual, + 'ring_id', ring_id) return out helper = LayerHelper("fused_feedforward") @@ -208,6 +211,7 @@ def fused_feedforward(x, 'dropout2_seed': seed if seed is not None else 0, 'dropout1_implementation': mode, 'dropout2_implementation': mode, + 'add_residual': add_residual, 'ring_id': ring_id, }) return out @@ -378,6 +382,7 @@ def fused_multi_head_attention(x, training=True, mode='upscale_in_train', ring_id=-1, + add_residual=True, name=None): r""" Attention mapps queries and a set of key-value pairs to outputs, and @@ -454,6 +459,7 @@ def fused_multi_head_attention(x, - train: out = input * mask - inference: out = input * (1.0 - p) ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp + add_residual (bool, optional): Whether add residual at the end. Default is True. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. 
Returns: @@ -521,7 +527,8 @@ def fused_multi_head_attention(x, 'dropout_fix_seed', seed is not None, 'attn_dropout_seed', seed if seed is not None else 0, 'dropout_seed', seed if seed is not None else 0, 'attn_dropout_implementation', - mode, 'dropout_implementation', mode, 'ring_id', ring_id) + mode, 'dropout_implementation', mode, 'add_residual', add_residual, + 'ring_id', ring_id) if cache_kv is not None: return final_out, cache_kv_out return final_out @@ -571,6 +578,7 @@ def fused_multi_head_attention(x, 'dropout_seed': seed if seed is not None else 0, 'attn_dropout_implementation': mode, 'dropout_implementation': mode, + 'add_residual': add_residual, 'ring_id': ring_id } -- GitLab
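
For reference, a minimal usage sketch of the new `add_residual` argument introduced by this patch. It is not part of the patch itself; it assumes a CUDA build of PaddlePaddle (the fused ops only have GPU kernels) and that the leading positional arguments of `fused_feedforward` keep their existing public signature. Note that `add_residual=False` is currently only supported together with `pre_layer_norm=True`, matching the `PADDLE_ENFORCE_EQ` checks added in the CUDA kernels above.

```python
# Usage sketch (assumes a CUDA build of PaddlePaddle; fused_feedforward
# has no CPU kernel).
import paddle
import paddle.incubate.nn.functional as F

x = paddle.randn([1, 8, 8], dtype="float32")
linear1_weight = paddle.randn([8, 8], dtype="float32")
linear2_weight = paddle.randn([8, 8], dtype="float32")

# Default behavior is unchanged: the residual connection is kept, i.e. roughly
# out = layer_norm(x + dropout(linear2(activation(linear1(x))) + bias)).
out = F.fused_feedforward(x, linear1_weight, linear2_weight)

# New in this patch: skip the residual add. This is only valid with
# pre_layer_norm=True; with pre_layer_norm=False the op raises an
# InvalidArgument error (see the PADDLE_ENFORCE_EQ checks above).
out_no_residual = F.fused_feedforward(
    x, linear1_weight, linear2_weight,
    pre_layer_norm=True,
    add_residual=False)
```

The same `add_residual` flag is added to `fused_multi_head_attention` with the same restriction, and both ops register an `OpVersion` checkpoint so models saved before this change load with `add_residual=True`.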