diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index f2ea692ad088085becd56b6ebfdde2af84abe468..5f314b0f925759844e9a4fce94623c1059ecb7fe 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -266,6 +266,13 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type, pair.first != "X") { continue; } + if ((op_type == "fused_attention" || op_type == "fused_feedforward")) { + if (pair.first == "LnScale" || pair.first == "LnBias" || + pair.first == "Ln2Scale" || pair.first == "Ln2Bias" || + pair.first == "Ln1Scale" || pair.first == "Ln1Bias") { + continue; + } + } VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to " << framework::DataTypeToString(dst_type); diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h index 18ae932c9325a9b197b43815150acf3ca4dd05c2..f7478364cdfc5177780f7576f0667cc07d18e701 100644 --- a/paddle/fluid/operators/fused/attn_bias_add.cu.h +++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h @@ -87,7 +87,8 @@ __global__ void BroadcastKernelBinary( kernel_primitives::ElementwiseBinary( result, arg0, arg1, func); // store - kernel_primitives::WriteData(out + fix, result, num); + kernel_primitives::WriteData(out + fix, result, + num); } // bias add forward impl for "[m, n] + [n] = [m, n]" @@ -267,25 +268,24 @@ __global__ void BiasAddBw1DReduceKernel(const ReduceParamType* temp_sum, } template -void Launch2DColumnReduce(gpuStream_t stream, const int max_threads, - const int reduce_num, const int left_num, - const T* d_out, T* d_bias) { +void Launch2DColumnReduce(const platform::CUDADeviceContext& dev_ctx, + const int max_threads, const int reduce_num, + const int left_num, const T* d_out, T* d_bias) { dim3 block; dim3 grid; bool should_reduce_again = false; int blocking_size = 1; SetConfigForColumnReduce(max_threads, reduce_num, left_num, &blocking_size, &should_reduce_again, &block, &grid); + const auto& stream = dev_ctx.stream(); if (!should_reduce_again) { BiasAddBwSinglePassKernel<<>>(d_out, reduce_num, left_num, d_bias); } else { framework::Tensor tmp_sum; - tmp_sum.mutable_data>( - framework::make_ddim({static_cast( - left_num * grid.y * sizeof(ReduceParamType))}), - paddle::platform::CUDAPlace()); + tmp_sum.Resize({grid.y, left_num}); + tmp_sum.mutable_data>(dev_ctx.GetPlace()); BiasAddBw2DReduceKernel<<>>( d_out, reduce_num, left_num, blocking_size, @@ -311,8 +311,8 @@ void LaunchBiasAddBwKernel(const platform::CUDADeviceContext& dev_ctx, int m, Launch1DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num, d_out, d_bias); } else { - Launch2DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num, - d_out, d_bias); + Launch2DColumnReduce(dev_ctx, max_threads, reduce_num, left_num, d_out, + d_bias); } } diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h index a2001d0a814922d33ad2d4473b219766b8e04b9a..21875cc52146f65b52c3c4a7a928a75617ee3c50 100644 --- a/paddle/fluid/operators/fused/attn_gemm.h +++ b/paddle/fluid/operators/fused/attn_gemm.h @@ -11,10 +11,13 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/fused/attn_bias_add.cu.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" +#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" + namespace paddle { namespace operators { @@ -36,8 +39,10 @@ class AttnMatMul { ~AttnMatMul() {} - void ComputeForward(const T* weight_data, const T* input_data, - const T* bias_data, T* output_data, T* bias_out_data) { + void ComputeForward(const framework::Tensor* weight, + const framework::Tensor* input, + const framework::Tensor* bias, framework::Tensor* output, + framework::Tensor* bias_out) { // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. // here: (transa, transb): nt, input * weight. CBLAS_TRANSPOSE transA = CblasNoTrans; @@ -54,16 +59,25 @@ class AttnMatMul { // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) auto blas = math::GetBlas(dev_ctx_); blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha, - input_data, weight_data, beta, output_data); + input->data(), weight->data(), beta, output->data()); if (compute_bias_) { // compute output + bias - LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data, - bias_data, bias_out_data); + std::vector ins; + std::vector outs; + ins.emplace_back(output); + ins.emplace_back(bias); + outs.emplace_back(bias_out); + int elewise_add_axis = -1; + LaunchElementwiseCudaKernel( + dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor()); } } - void ComputeBackward(const T* input, const T* weight, const T* d_output, - T* d_input, T* d_weight, T* d_bias) { + void ComputeBackward(const framework::Tensor* input, + const framework::Tensor* weight, + const framework::Tensor* d_output, + framework::Tensor* d_input, framework::Tensor* d_weight, + framework::Tensor* d_bias) { T alpha = static_cast(1.0); T beta = static_cast(0.0); auto blas = math::GetBlas(dev_ctx_); @@ -81,11 +95,11 @@ class AttnMatMul { T* dB_input_1_ptr = nullptr; T* dB_input_2_ptr = nullptr; - T* dB_output_ptr = d_weight; + T* dB_output_ptr = d_weight->data(); T* dA_input_1_ptr = nullptr; T* dA_input_2_ptr = nullptr; - T* dA_output_ptr = d_input; + T* dA_output_ptr = d_input->data(); if (!transA_) { // fw: gemm-nt @@ -104,10 +118,10 @@ class AttnMatMul { dA_n = input_size_; dA_k = output_size_; - blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output, - input, beta, dB_output_ptr); - blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output, - weight, beta, dA_output_ptr); + blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, + d_output->data(), input->data(), beta, dB_output_ptr); + blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, + d_output->data(), weight->data(), beta, dA_output_ptr); } else { // fw: gemm-nn // bw: gemm-tn, dB = A^t * dC dB_transA = CblasTrans; @@ -123,10 +137,10 @@ class AttnMatMul { dA_n = input_size_; dA_k = output_size_; - blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input, - d_output, beta, dB_output_ptr); - blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output, - weight, beta, dA_output_ptr); + blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, + input->data(), d_output->data(), beta, dB_output_ptr); + blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, + d_output->data(), weight->data(), beta, dA_output_ptr); } } else if (transB_) { PADDLE_THROW(platform::errors::InvalidArgument( @@ -138,7 +152,27 @@ class AttnMatMul { "parameters.")); } if (compute_bias_) { - LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias); + // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2} + const auto input_dims = d_output->dims(); + const auto output_dims = d_bias->dims(); + bool support_case_1 = + (input_dims.size() == 5 && output_dims.size() == 3 && + (input_dims[2] == output_dims[0]) && + (input_dims[3] == output_dims[1]) && + (input_dims[4] == output_dims[2])); + bool support_case_2 = + (input_dims.size() == 3 && output_dims.size() == 1 && + (input_dims[2] == output_dims[0])); + if (support_case_1 || support_case_2) { + gpuStream_t stream = dev_ctx_.stream(); + TensorReduceFunctorImpl(*d_output, d_bias, {0, 1}, + stream); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only support reduce when the input dims are [0,1,2,3,4] and " + "output is [2,3,4]" + "or input is [0,1,2] and output is [2].")); + } } } diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index bef0052a00d6b26050a8cd5764b7b083578e1122..066e7e15e88312740beae3e6f9203d5312b37a16 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -69,7 +69,7 @@ class FMHARef { ~FMHARef() {} void ComputeForward(const Tensor& qkv_input_tensor, - const Tensor& src_mask_tensor, + const Tensor* src_mask_tensor, Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor, Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor, Tensor* dropout_mask_out_tensor, @@ -111,17 +111,17 @@ class FMHARef { blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr, k_ptr, beta, qk_out_data, gemm_batch_size, stride_a, stride_b); - - std::vector ins; - std::vector outs; - ins.emplace_back(qk_out_tensor); - ins.emplace_back(&src_mask_tensor); - outs.emplace_back(src_mask_out_tensor); - int elewise_add_axis = -1; int softmax_axis = -1; - if (&src_mask_tensor != nullptr) { + if (src_mask_tensor != nullptr) { + std::vector ins; + std::vector outs; + ins.emplace_back(qk_out_tensor); + ins.emplace_back(src_mask_tensor); + outs.emplace_back(src_mask_out_tensor); + int elewise_add_axis = -1; LaunchElementwiseCudaKernel( dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor()); + SoftmaxForwardCUDAKernelDriver(dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); } else { @@ -165,7 +165,7 @@ class FMHARef { } void ComputeBackward( - const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor, + const Tensor& transpose_2_out_tensor, const Tensor* src_mask_tensor, const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor, const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor, const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor, @@ -249,7 +249,7 @@ class FMHARef { softmax_out_grad_tensor); } - if (&src_mask_tensor != nullptr) { + if (src_mask_tensor != nullptr) { SoftmaxBackwardCUDAKernelDriver(dev_ctx_, softmax_out_tensor, *softmax_out_grad_tensor, softmax_axis, src_mask_out_grad_tensor); diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index f7c7129c7732b0a656881d1160ea3cd05054e603..11601a5ce40d5a7d82311e08d95db3d28d478d20 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -27,8 +27,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask", - "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", "FusedAttentionOp"); @@ -44,6 +42,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", "FusedAttentionOp"); + } else { + OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", + "BiasDropoutResidualOut", "FusedAttentionOp"); } // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] @@ -57,8 +62,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut", - "FusedAttentionOp"); + + if (ctx->HasInput("SrcMask")) { + OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut", + "FusedAttentionOp"); + } OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output", @@ -69,12 +77,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", - "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", - "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", - "BiasDropoutResidualOut", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp"); @@ -108,6 +111,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); + } else { + ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); } // [batch_size, seq_len, 3, num_head, head_size] ctx->SetOutputDim("QKVOut", @@ -119,7 +126,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); // [batch, num_head, seq_len, seq_len] ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); - ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + + if (ctx->HasInput("SrcMask")) { + ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + } // the same as QKOut's shape. ctx->SetOutputDim("AttnDropoutOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); @@ -134,12 +144,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}); ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X")); - ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); if (ctx->Attrs().Get("dropout_is_test") == false) { ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); } - ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); } @@ -310,25 +318,28 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { }); AddComment(R"DOC( - Add fused attention op whose logic is as follows: - // @input: [batch_size, seq_len, 3, num_head, head_dim] - // @final_out: [batch_size, seq_len, num_heads, head_dim] - if (pre_layernorm) - out = layer_norm(input); + Add fused attention op whose logic is as follows: + // @input: [batch_size, seq_len, 3, num_head, head_dim] + // @final_out: [batch_size, seq_len, num_heads, head_dim] + if (pre_layernorm) + out = layer_norm(input); out = compute_qkv(out) + bias; // fmha module - { - out = transpose(out, perm=[2, 0, 3, 1, 4]); - out = q * k^t; - out = attn_mark + out; - out = softmax(out); - out = dropout(out); - out = out * v; - out = transpose(out, perm=[0, 2, 1, 3]); + { + out = transpose(out, perm=[2, 0, 3, 1, 4]); + out = q * k^t; + out = attn_mask + out; + out = softmax(out); + out = dropout(out); + out = out * v; + out = transpose(out, perm=[0, 2, 1, 3]); - } + } out = out_linear(out); - final_out = layer_norm(residual + dropout(bias + out)); + if (pre_layernorm) + final_out = residual + dropout(bias + out); + else + final_out = layer_norm(residual + dropout(bias + out)); )DOC"); } }; @@ -343,20 +354,20 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "GradOp is only callable when attn_dropout_is_test is false")); - OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", - "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", - "FusedAttentionGrad"); - if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) { - ctx->SetOutputDim(framework::GradVarName("Ln2Scale"), - ctx->GetInputDim("Ln2Scale")); - } - if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) { - ctx->SetOutputDim(framework::GradVarName("Ln2Bias"), - ctx->GetInputDim("Ln2Bias")); - } - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); - if (ctx->Attrs().Get("pre_layer_norm") == true) { + if (ctx->Attrs().Get("pre_layer_norm") == false) { + OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", + "FusedAttentionGrad"); + if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) { + ctx->SetOutputDim(framework::GradVarName("Ln2Scale"), + ctx->GetInputDim("Ln2Scale")); + } + if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Ln2Bias"), + ctx->GetInputDim("Ln2Bias")); + } + } else { OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance", @@ -364,12 +375,12 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut", "FusedAttentionGrad"); } + + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask", - "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", @@ -400,6 +411,9 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { if (ctx->Attrs().Get("pre_layer_norm") == true) { ctx->SetOutputDim(framework::GradVarName("LnOut"), ctx->GetInputDim("LnOut")); + } else { + ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), + ctx->GetInputDim("BiasDropoutResidualOut")); } ctx->SetOutputDim(framework::GradVarName("FMHAOut"), ctx->GetInputDim("FMHAOut")); @@ -413,16 +427,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ctx->GetInputDim("SoftmaxOut")); ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"), ctx->GetInputDim("AttnDropoutOut")); - ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"), - ctx->GetInputDim("SrcMaskOut")); + + if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) { + ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"), + ctx->GetInputDim("SrcMaskOut")); + } ctx->SetOutputDim(framework::GradVarName("QKVOut"), ctx->GetInputDim("QKVOut")); ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), ctx->GetInputDim("QKVBiasOut")); ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), ctx->GetInputDim("OutLinearOut")); - ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), - ctx->GetInputDim("BiasDropoutResidualOut")); } protected: @@ -448,7 +463,14 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("X", this->Input("X")); op->SetInput("QKVW", this->Input("QKVW")); op->SetInput("QKVBias", this->Input("QKVBias")); - op->SetInput("SrcMask", this->Input("SrcMask")); + + if (this->HasInput("SrcMask")) { + op->SetInput("SrcMask", this->Input("SrcMask")); + op->SetInput("SrcMaskOut", this->Output("SrcMaskOut")); + op->SetOutput(framework::GradVarName("SrcMaskOut"), + this->OutputGrad("SrcMaskOut")); + } + op->SetInput("OutLinearW", this->Input("OutLinearW")); op->SetInput("OutLinearBias", this->Input("OutLinearBias")); @@ -466,17 +488,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetOutput(framework::GradVarName("LnBias"), this->InputGrad("LnBias")); } - } - - if (this->HasInput("Ln2Scale")) { - op->SetInput("Ln2Scale", this->Input("Ln2Scale")); - op->SetOutput(framework::GradVarName("Ln2Scale"), - this->InputGrad("Ln2Scale")); - } - if (this->HasInput("Ln2Bias")) { - op->SetInput("Ln2Bias", this->Input("Ln2Bias")); - op->SetOutput(framework::GradVarName("Ln2Bias"), - this->InputGrad("Ln2Bias")); + } else { + if (this->HasInput("Ln2Scale")) { + op->SetInput("Ln2Scale", this->Input("Ln2Scale")); + op->SetOutput(framework::GradVarName("Ln2Scale"), + this->InputGrad("Ln2Scale")); + } + if (this->HasInput("Ln2Bias")) { + op->SetInput("Ln2Bias", this->Input("Ln2Bias")); + op->SetOutput(framework::GradVarName("Ln2Bias"), + this->InputGrad("Ln2Bias")); + } } op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); @@ -499,6 +521,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { if (this->HasOutput("LnVariance")) { op->SetInput("LnVariance", this->Output("LnVariance")); } + } else { + op->SetInput("Ln2Mean", this->Output("Ln2Mean")); + op->SetInput("Ln2Variance", this->Output("Ln2Variance")); + op->SetInput("BiasDropoutResidualOut", + this->Output("BiasDropoutResidualOut")); } op->SetInput("QKVOut", this->Output("QKVOut")); op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); @@ -508,15 +535,10 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("SoftmaxOut", this->Output("SoftmaxOut")); op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut")); op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut")); - op->SetInput("SrcMaskOut", this->Output("SrcMaskOut")); + op->SetInput("FMHAOut", this->Output("FMHAOut")); op->SetInput("OutLinearOut", this->Output("OutLinearOut")); - - op->SetInput("Ln2Mean", this->Output("Ln2Mean")); - op->SetInput("Ln2Variance", this->Output("Ln2Variance")); op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut")); - op->SetInput("BiasDropoutResidualOut", - this->Output("BiasDropoutResidualOut")); op->SetInput("QKVOut", this->Output("QKVOut")); // backward outputs: dinput @@ -525,7 +547,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut")); } + } else { + op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"), + this->OutputGrad("BiasDropoutResidualOut")); } + op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); op->SetOutput(framework::GradVarName("QKVBiasOut"), this->OutputGrad("QKVBiasOut")); @@ -538,12 +564,9 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { this->OutputGrad("SoftmaxOut")); op->SetOutput(framework::GradVarName("AttnDropoutOut"), this->OutputGrad("AttnDropoutOut")); - op->SetOutput(framework::GradVarName("SrcMaskOut"), - this->OutputGrad("SrcMaskOut")); + op->SetOutput(framework::GradVarName("FMHAOut"), this->OutputGrad("FMHAOut")); - op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"), - this->OutputGrad("BiasDropoutResidualOut")); op->SetOutput(framework::GradVarName("OutLinearOut"), this->OutputGrad("OutLinearOut")); } diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 01bc49bcf407937310acc3b81d8f6d0907866467..5bcf12856083698962d0186c8de40ff0363e5dc9 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -95,15 +95,6 @@ class FusedAttentionOpKernel : public framework::OpKernel { const auto qkv_w_dims = qkv_weight->dims(); auto *x_data = input_x->data(); - auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); - auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); - auto *ln_mean_data = - pre_layer_norm ? ln_mean->mutable_data(ctx.GetPlace()) : nullptr; - auto *ln_var_data = - pre_layer_norm ? ln_var->mutable_data(ctx.GetPlace()) : nullptr; - auto *ln_out_data = - pre_layer_norm ? ln_out->mutable_data(ctx.GetPlace()) : nullptr; - auto *qkv_weight_data = qkv_weight->data(); auto *qkv_bias_data = qkv_bias->data(); auto *qkv_out_data = qkv_out->mutable_data(ctx.GetPlace()); @@ -114,7 +105,9 @@ class FusedAttentionOpKernel : public framework::OpKernel { transpose_out_2->mutable_data(ctx.GetPlace()); auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); auto *qktv_out_data = qktv_out->mutable_data(ctx.GetPlace()); - auto *src_mask_out_data = src_mask_out->mutable_data(ctx.GetPlace()); + auto *src_mask_out_data = + (src_mask == nullptr) ? nullptr + : src_mask_out->mutable_data(ctx.GetPlace()); auto *softmax_out_data = softmax_out->mutable_data(ctx.GetPlace()); auto *attn_dropout_mask_out_data = attn_dropout_mask_out->mutable_data(ctx.GetPlace()); @@ -128,16 +121,8 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); // get data ptr for bias+dropout+residual+layernorm - auto *ln_scale_2_data = - (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); - auto *ln_bias_2_data = - (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); auto *dropout_mask_out_data = dropout_mask_out->mutable_data(ctx.GetPlace()); - auto *bias_dropout_residual_out_data = - bias_dropout_residual_out->mutable_data(ctx.GetPlace()); - auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); - auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); auto *final_out_data = out->mutable_data(ctx.GetPlace()); int batch_size = input_x_dims[0]; @@ -176,29 +161,52 @@ class FusedAttentionOpKernel : public framework::OpKernel { ln_epsilon); if (pre_layer_norm) { + auto *ln_scale_data = + (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); + auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); + auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); + auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace()); + layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, ln_out_data, ln_mean_data, ln_var_data); - qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data, - qkv_out_data, qkv_bias_out_data); + qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out, + qkv_bias_out); } else { - qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data, - qkv_out_data, qkv_bias_out_data); + qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out, + qkv_bias_out); } - fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2, + fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2, qk_out, src_mask_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, qktv_out, fmha_out); + // fmha_out: [batch_size, seq_len, num_head, head_dim] // weight: [embed_dim, embed_dim] // out_linear_out: [batch_size, seq_len, embed_dim] - out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data, - nullptr, out_linear_out_data, nullptr); - // output = layernorm(residual + dropout(input + bias)) - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - ctx.cuda_device_context(), out_linear_out_data, x_data, - out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, - bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, - ln_mean_2_data, ln_var_2_data); + out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr, + out_linear_out, nullptr); + if (pre_layer_norm) { + // output = (residual + dropout(input + bias)) + fused_dropout_layernorm_helper.ResidualDropoutBias( + ctx.cuda_device_context(), out_linear_out_data, x_data, + out_linear_bias_data, final_out_data, dropout_mask_out_data); + } else { + auto *ln_scale_2_data = + (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); + auto *ln_bias_2_data = + (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); + auto *bias_dropout_residual_out_data = + bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); + auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); + // output = layernorm(residual + dropout(input + bias)) + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + ctx.cuda_device_context(), out_linear_out_data, x_data, + out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, + bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, + ln_mean_2_data, ln_var_2_data); + } } }; @@ -265,12 +273,10 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *qk_out_data = qk_out->data(); auto *qktv_out_data = qktv_out->data(); auto *softmax_out_data = softmax_out->data(); - auto *src_mask_out_data = src_mask_out->data(); + auto *src_mask_out_data = + (src_mask == nullptr) ? nullptr : src_mask_out->data(); auto *out_linear_out_data = out_linear_out->data(); - auto *ln_2_mean_data = ln_2_mean->data(); - auto *ln_2_var_data = ln_2_var->data(); auto *dropout_mask_out_data = dropout_mask_out->data(); - auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data(); // output's grad auto *d_x = ctx.Output(framework::GradVarName("X")); @@ -302,12 +308,12 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_softmax_out_data = d_softmax_out->mutable_data(ctx.GetPlace()); auto *d_attn_dropout_out_data = d_attn_dropout_out->mutable_data(ctx.GetPlace()); - auto *d_src_mask_out_data = d_src_mask_out->mutable_data(ctx.GetPlace()); + auto *d_src_mask_out_data = + (src_mask == nullptr) ? nullptr + : d_src_mask_out->mutable_data(ctx.GetPlace()); auto *d_fmha_out_data = d_fmha_out->mutable_data(ctx.GetPlace()); auto *d_out_linear_out_data = d_out_linear_out->mutable_data(ctx.GetPlace()); - auto *d_bias_dropout_residual_out_data = - d_bias_dropout_residual_out->mutable_data(ctx.GetPlace()); // parameter grad auto *d_qkv_weight = ctx.Output(framework::GradVarName("QKVW")); @@ -325,12 +331,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { d_out_linear_weight->mutable_data(ctx.GetPlace()); auto *d_out_linear_bias_data = d_out_linear_bias->mutable_data(ctx.GetPlace()); - auto *d_ln_2_scale_data = - (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data( - ctx.GetPlace())); - auto *d_ln_2_bias_data = - (d_ln_2_bias == nullptr ? nullptr - : d_ln_2_bias->mutable_data(ctx.GetPlace())); const auto input_x_dims = input_x->dims(); const auto qkv_w_dims = qkv_weight->dims(); @@ -376,17 +376,37 @@ class FusedAttentionGradKernel : public framework::OpKernel { ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, ln2epsilon); - fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( - ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data, - dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data, - d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data, - d_out_linear_out_data, d_out_linear_bias_data, d_residual_data); + if (pre_layer_norm) { + fused_dropout_layernorm_helper.ResidualDropoutBiasGrad( + ctx.cuda_device_context(), d_y_data, dropout_mask_out_data, + d_out_linear_out_data, d_residual_data, d_out_linear_bias_data); + } else { + auto *ln_2_mean_data = ln_2_mean->data(); + auto *ln_2_var_data = ln_2_var->data(); + auto *bias_dropout_residual_out_data = + bias_dropout_residual_out->data(); + auto *d_ln_2_scale_data = + (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data( + ctx.GetPlace())); + auto *d_ln_2_bias_data = + (d_ln_2_bias == nullptr ? nullptr : d_ln_2_bias->mutable_data( + ctx.GetPlace())); + auto *d_bias_dropout_residual_out_data = + d_bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + + fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( + ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data, + dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data, + d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data, + d_out_linear_out_data, d_out_linear_bias_data, d_residual_data); + } + + out_linear_compute.ComputeBackward(fmha_out, out_linear_weight, + d_out_linear_out, d_fmha_out, + d_out_linear_weight, nullptr); - out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data, - d_out_linear_out_data, d_fmha_out_data, - d_out_linear_weight_data, nullptr); fmha_ref_compute.ComputeBackward( - *transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out, + *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out, *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, d_transpose_out_2, nullptr, d_qkv_bias_out); @@ -413,15 +433,14 @@ class FusedAttentionGradKernel : public framework::OpKernel { (d_ln_bias == nullptr ? nullptr : d_ln_bias->mutable_data(ctx.GetPlace())); - qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data, - d_qkv_bias_out_data, d_ln_out_data, - d_qkv_weight_data, d_qkv_bias_data); + qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out, + d_qkv_weight, d_qkv_bias); layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data, ln_mean_data, ln_var_data, d_x_data, d_ln_scale_data, d_ln_bias_data); } else { - qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data, - d_x_data, d_qkv_weight_data, d_qkv_bias_data); + qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x, + d_qkv_weight, d_qkv_bias); } // gradient accumulation std::vector ins; diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu index 3b47e65c4833d6581fef6fa77392ee2bf297e794..a241e3c30272504bcd8492607a30fd2fd81536b2 100644 --- a/paddle/fluid/operators/fused/fused_feedforward_op.cu +++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/matmul_v2_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h" @@ -261,7 +262,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { framework::Tensor d_linear2_out, d_dropout2_out, d_residual; d_linear2_out.mutable_data({bsz_seq, d_model}, place); d_dropout2_out.mutable_data({bsz_seq, d_model}, place); - d_residual.mutable_data({bsz_seq, d_model}, place); + d_residual.mutable_data(d_x->dims(), place); if (pre_layer_norm) { fused_dropout_layernorm_helper.ResidualDropoutBiasGrad( @@ -301,6 +302,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { } else { MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight); } + std::vector ins(2); + std::vector outs(1); + ins[0] = &d_residual; + ins[1] = d_x; + outs[0] = d_x; + int elewise_add_axis = -1; + LaunchElementwiseCudaKernel( + ctx, ins, &outs, elewise_add_axis, AddFunctor()); } void Compute(const framework::ExecutionContext& context) const override { diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 59d0df85bb14e9298823fef431d921d208b55e98..18c19266aca91fd7a1c08f92d1dae3355f01225f 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -94,6 +94,7 @@ if(NOT WITH_GPU) LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api) + LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer) endif() if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index c33e1f53dfdb62e0e4561098dbb11f24ea9c931e..b2b5cac2bff965363e087d6d09c38477e0e0847a 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -26,6 +26,9 @@ from paddle import tensor from paddle.fluid import layers import unittest from op_test import OpTest +from paddle.fluid.framework import default_main_program + +default_main_program().random_seed = 42 class TestFusedAttentionOp(OpTest): @@ -66,6 +69,7 @@ class TestFusedAttentionOp(OpTest): self.x_type = np.float32 self.attn_mask_type = np.float64 self.pre_layer_norm = False + self.has_attn_mask = True self.training = True self.batch_size = 8 @@ -84,16 +88,20 @@ class TestFusedAttentionOp(OpTest): def generate_input_data(self): self.query = np.random.rand(self.batch_size, self.query_length, self.embed_dim).astype(self.x_type) - self.attn_mask = np.ones( - (self.batch_size, self.num_heads, self.query_length, - self.key_length), - dtype=self.attn_mask_type) - if self.attn_mask_type == np.int64: - self.attn_mask = np.tril(self.attn_mask) - elif self.attn_mask_type == np.float64: - self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + if self.has_attn_mask: + self.attn_mask = np.ones( + (self.batch_size, self.num_heads, self.query_length, + self.key_length), + dtype=self.attn_mask_type) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + else: + raise ValueError( + "'attn_mask_type' should be 'int64' or 'float64'.") else: - raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") + self.attn_mask = None self.key, self.value = self.query, self.query self.dout = np.random.random((self.batch_size, self.query_length, @@ -102,7 +110,10 @@ class TestFusedAttentionOp(OpTest): def GetBaselineOut(self): paddle.disable_static(place=paddle.CUDAPlace(0)) tensor_query = paddle.to_tensor(self.query, stop_gradient=False) - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + if self.has_attn_mask: + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + else: + attn_mask = None residual = tensor_query ln1_out = tensor_query @@ -147,8 +158,8 @@ class TestFusedAttentionOp(OpTest): residual_out = residual + self.dropout(out) if not self.pre_layer_norm: final_out = self.norm1(residual_out) - if self.pre_layer_norm: - final_out = self.norm2(residual_out) + else: + final_out = residual_out paddle.autograd.backward( [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) return final_out, tensor_query.grad @@ -187,7 +198,10 @@ class TestFusedAttentionOp(OpTest): qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) x = paddle.to_tensor(self.query, stop_gradient=False) - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + if self.has_attn_mask: + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + else: + attn_mask = None qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) epsilon = 1e-05 @@ -208,9 +222,9 @@ class TestFusedAttentionOp(OpTest): final_out_ref, x_grad_ref = self.GetBaselineOut() final_out, x_grad = self.GetFusedAttentionOut() np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5) + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) class TestFusedAttentionOpPreLn(TestFusedAttentionOp): @@ -218,6 +232,7 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp): self.x_type = np.float32 self.attn_mask_type = np.float64 self.pre_layer_norm = True + self.has_attn_mask = True self.training = True self.batch_size = 8 @@ -237,9 +252,39 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp): final_out_ref, x_grad_ref = self.GetBaselineOut() final_out, x_grad = self.GetFusedAttentionOut() np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1) + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) + + +class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp): + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.has_attn_mask = False + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def test_fused_attention_op(self): + final_out_ref, x_grad_ref = self.GetBaselineOut() + final_out, x_grad = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) class TestFusedAttentionOpFp16(TestFusedAttentionOp): @@ -247,6 +292,7 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp): self.x_type = np.float16 self.attn_mask_type = np.float64 self.pre_layer_norm = False + self.has_attn_mask = True self.training = True self.batch_size = 8 diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py index 5fa9446763b1fe711490806c60a36754ac4e2cb7..70c2ff5cbc8f23b09a5622795610b5800b4dfe9b 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py @@ -89,27 +89,32 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] * qkv_weight.shape[2] * qkv_weight.shape[3]) + qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] * + qkv_bias.shape[2]) if (pre_layer_norm): ln_out = ln_out.reshape(batch_size * seq_len, embed_dim) qkv = fc(ln_out, qkv_weight) + qkv_bias_out = qkv + qkv_bias ln_out = ln_out.reshape(batch_size, seq_len, embed_dim) else: query = query.reshape(batch_size * seq_len, embed_dim) qkv = fc(query, qkv_weight) + qkv_bias_out = qkv + qkv_bias query = query.reshape(batch_size, seq_len, embed_dim) - qkv = qkv.reshape(batch_size, seq_len, 3, num_head, head_dim) + qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head, + head_dim) # q*k^t - qkv = qkv.transpose( + qkv_bias_out = qkv_bias_out.transpose( (2, 0, 1, 3, 4)) # 3, batch_size, seq_len, num_head, head_dim - qkv = qkv.transpose( + qkv_bias_out = qkv_bias_out.transpose( (0, 1, 3, 2, 4)) # 3, batch_size, num_head, seq_len, head_dim - q = qkv[0:1, ::] + q = qkv_bias_out[0:1, ::] q = q.reshape(batch_size, num_head, seq_len, head_dim) - k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim] + k = qkv_bias_out[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim] k = k.reshape(batch_size, num_head, seq_len, head_dim) - v = qkv[2::] + v = qkv_bias_out[2::] v = v.reshape(batch_size, num_head, seq_len, head_dim) k = k.transpose([0, 1, 3, 2]) #[batch_size, num_head, head_dim, seq_len] @@ -138,9 +143,11 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, out_linear_bias_out = out_linear_out + out_linear_bias out_linear_bias_dropout_out = out_linear_bias_out out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out - out_linear_bias_dropout_residual_ln_out = layer_norm( - out_linear_bias_dropout_residual_out, True, True, ln_2_scale, ln_2_bias) - return out_linear_bias_dropout_residual_ln_out + if not pre_layer_norm: + out_linear_bias_dropout_residual_out = layer_norm( + out_linear_bias_dropout_residual_out, True, True, ln_2_scale, + ln_2_bias) + return out_linear_bias_dropout_residual_out class TestFusedAttentionAPI(unittest.TestCase): @@ -152,6 +159,7 @@ class TestFusedAttentionAPI(unittest.TestCase): self.x_type = np.float32 self.attn_mask_type = np.float64 self.pre_layer_norm = True + self.has_attn_mask = True self.training = True self.need_weight = False @@ -172,27 +180,37 @@ class TestFusedAttentionAPI(unittest.TestCase): def generate_input_data(self): self.query = np.random.rand(self.batch_size, self.query_length, self.embed_dim).astype(self.x_type) - self.attn_mask = np.ones( - (self.batch_size, self.num_heads, self.query_length, - self.key_length), - dtype=self.attn_mask_type) - if self.attn_mask_type == np.int64: - self.attn_mask = np.tril(self.attn_mask) - elif self.attn_mask_type == np.float64: - self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + if self.has_attn_mask: + self.attn_mask = np.ones( + (self.batch_size, self.num_heads, self.query_length, + self.key_length), + dtype=self.attn_mask_type) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + else: + raise ValueError( + "'attn_mask_type' should be 'int64' or 'float64'.") else: - raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") + self.attn_mask = None self.key, self.value = self.query, self.query def run_imperative(self): + if self.has_attn_mask: + attn_mask_tensor = paddle.to_tensor(self.attn_mask) + else: + attn_mask_tensor = None fused_attn = FusedMultiHeadAttention( self.embed_dim, self.num_heads, self.dropout_prob, self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm, self.need_weight, self.weight_attr, self.bias_attr) + qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32') + fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias)) out = fused_attn( paddle.to_tensor(self.query), paddle.to_tensor(self.query), - paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask)) + paddle.to_tensor(self.query), attn_mask_tensor) ref_out = compute_reference(self.pre_layer_norm, self.query, self.attn_mask, fused_attn.pre_ln_scale.numpy(), @@ -203,7 +221,7 @@ class TestFusedAttentionAPI(unittest.TestCase): fused_attn.qkv_bias.numpy(), fused_attn.linear_weight.numpy(), fused_attn.linear_bias.numpy()) - self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5)) + np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5) def run_static(self): fused_attn = FusedMultiHeadAttention( @@ -215,29 +233,42 @@ class TestFusedAttentionAPI(unittest.TestCase): name='X', shape=[self.batch_size, self.query_length, self.embed_dim], dtype=self.x_type) - attn_mask = paddle.static.data( - name='SrcMask', - shape=[ - self.batch_size, self.num_heads, self.query_length, - self.key_length - ], - dtype=self.attn_mask_type) - final_out = fused_attn(x, x, x, attn_mask) + if self.has_attn_mask: + attn_mask = paddle.static.data( + name='SrcMask', + shape=[ + self.batch_size, self.num_heads, self.query_length, + self.key_length + ], + dtype=self.attn_mask_type) + final_out = fused_attn(x, x, x, attn_mask) + else: + final_out = fused_attn(x, x, x) place = paddle.CUDAPlace(0) exe = paddle.static.Executor(place) exe.run(paddle.static.default_startup_program()) - out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( - paddle.static.default_main_program(), - feed={"X": self.query, - "SrcMask": self.attn_mask}, - fetch_list=[ - final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, - fused_attn.linear_weight, fused_attn.linear_bias, - fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, - fused_attn.ln_scale, fused_attn.ln_bias - ]) - + if self.has_attn_mask: + out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( + paddle.static.default_main_program(), + feed={"X": self.query, + "SrcMask": self.attn_mask}, + fetch_list=[ + final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, + fused_attn.linear_weight, fused_attn.linear_bias, + fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, + fused_attn.ln_scale, fused_attn.ln_bias + ]) + else: + out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( + paddle.static.default_main_program(), + feed={"X": self.query, }, + fetch_list=[ + final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, + fused_attn.linear_weight, fused_attn.linear_bias, + fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, + fused_attn.ln_scale, fused_attn.ln_bias + ]) return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias def test_static_api(self): @@ -249,14 +280,36 @@ class TestFusedAttentionAPI(unittest.TestCase): self.attn_mask, ln_scale, ln_bias, ln_2_scale, ln_2_bias, qkv_weight, qkv_bias, linear_weight, linear_bias) - self.assertTrue( - np.allclose( - np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5)) + np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5) def test_dynamic_api(self): paddle.disable_static(place=paddle.CUDAPlace(0)) self.run_imperative() +class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI): + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.has_attn_mask = False + self.training = True + self.need_weight = False + + self.batch_size = 1 + self.query_length = 2 + self.head_dim = 2 + self.num_heads = 2 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py b/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py index 5ea43d2edf0e668baf7a671ab6bb856eec2f56d6..a533b5d87a5a9be1809f24f8107501f380afdfe7 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py @@ -23,6 +23,7 @@ from paddle.nn.layer.norm import LayerNorm from paddle.nn.layer.common import Linear, Dropout import unittest from op_test import OpTest +from paddle.fluid.framework import default_main_program class TestFusedFFNOp(OpTest): @@ -91,7 +92,7 @@ class TestFusedFFNOp(OpTest): def Base(self): paddle.disable_static() tensor_src = paddle.to_tensor(self.src, stop_gradient=False) - residual = paddle.to_tensor(self.src) + residual = tensor_src if self.pre_layer_norm: ln1_out = self.norm1(tensor_src) linear2_out = self.linear2( @@ -140,6 +141,7 @@ class TestFusedFFNOp(OpTest): return out, x.grad def test_out_and_grad(self): + default_main_program().random_seed = 42 base_out, base_grad = self.Base() fused_out, fused_grad = self.FusedFFN() np.testing.assert_allclose( @@ -192,6 +194,7 @@ class TestFusedFFNOpNormalizeBefore(TestFusedFFNOp): class APITestStaticFusedFFN(unittest.TestCase): def test_static(self): paddle.enable_static() + default_main_program().random_seed = 42 dtype = "float32" layer_norm_dtype = "float32" batch_size = 1 @@ -324,6 +327,18 @@ class TestFusedFFNOpError(unittest.TestCase): self.assertRaises(ValueError, test_dropout_rate_value) + def test_dropout_mode(): + x = paddle.static.data( + name='x3', shape=[1, 10, 10], dtype="float32") + linear1_weight = paddle.static.data( + name='linear1_weight3', shape=[10, 10], dtype="float32") + linear2_weight = paddle.static.data( + name='linear2_weight3', shape=[10, 10], dtype="float32") + incubate_f.fused_feedforward( + x, linear1_weight, linear2_weight, mode='test') + + self.assertRaises(ValueError, test_dropout_mode) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py b/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..e0281d6e21e5ad556bf9ff51e4b32117db692e4a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py @@ -0,0 +1,188 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +import paddle +from paddle.incubate.nn import FusedTransformerEncoderLayer +from paddle.nn import TransformerEncoderLayer +from paddle.fluid.framework import default_main_program +import unittest + + +class TestFusedTransformerEncoderLayer(unittest.TestCase): + def setActivation(self): + self.activation = 'gelu' + + def setPreLayerNorm(self): + self.pre_layer_norm = False + + def setAttnMask(self): + self.has_attn_mask = True + + def setUp(self): + self.batch_size = np.random.randint(1, 8) + self.query_length = np.random.randint(1, 128) + self.nhead = 16 + self.head_dim = 4 + self.num_heads = self.nhead + self.d_model = self.head_dim * self.num_heads + self.embed_dim = self.d_model + self.dim_feedforward = np.random.randint(1, 32) + self.dropout_rate = 0 + self.attn_dropout_rate = None + self.act_dropout_rate = None + self.attn_mask_type = np.float64 + self.key_length = self.query_length + self.dtype = 'float32' + self.setActivation() + self.setPreLayerNorm() + self.setAttnMask() + + def fused_weight(self, weight, num_head): + a = paddle.transpose(weight, perm=[1, 0]) + return paddle.reshape( + a, shape=[1, num_head, int(a.shape[0] / num_head), a.shape[1]]) + + def fused_qkv(self, q, k, v, num_head): + fq = self.fused_weight(q, num_head) + fk = self.fused_weight(k, num_head) + fv = self.fused_weight(v, num_head) + return paddle.concat(x=[fq, fk, fv], axis=0) + + def test_out(self): + default_main_program().random_seed = 42 + base_encoder = TransformerEncoderLayer( + self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate, + self.activation, self.attn_dropout_rate, self.act_dropout_rate, + self.pre_layer_norm) + src = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.dtype) + + if self.has_attn_mask: + attn_mask = np.ones( + (self.batch_size, self.num_heads, self.query_length, + self.key_length), + dtype=self.attn_mask_type) + attn_mask_tensor = paddle.to_tensor(attn_mask) + else: + attn_mask = None + attn_mask_tensor = None + + dout = np.random.random(src.shape).astype(self.dtype) + + base_out = base_encoder( + paddle.to_tensor( + src, stop_gradient=False), attn_mask_tensor) + paddle.autograd.backward([base_out], [paddle.to_tensor(dout)], True) + + fused_encoder = FusedTransformerEncoderLayer( + self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate, + self.activation, self.attn_dropout_rate, self.act_dropout_rate, + self.pre_layer_norm) + + fused_encoder.ffn._linear1_weight.set_value(base_encoder.linear1.weight) + fused_encoder.ffn._linear1_bias.set_value(base_encoder.linear1.bias) + fused_encoder.ffn._linear2_weight.set_value(base_encoder.linear2.weight) + fused_encoder.ffn._linear2_bias.set_value(base_encoder.linear2.bias) + if self.pre_layer_norm: + fused_encoder.ffn._ln1_scale.set_value(base_encoder.norm2.weight) + fused_encoder.ffn._ln1_bias.set_value(base_encoder.norm2.bias) + else: + fused_encoder.ffn._ln2_scale.set_value(base_encoder.norm2.weight) + fused_encoder.ffn._ln2_bias.set_value(base_encoder.norm2.bias) + + fused_encoder.fused_attn.linear_weight.set_value( + base_encoder.self_attn.out_proj.weight) + fused_encoder.fused_attn.linear_bias.set_value( + base_encoder.self_attn.out_proj.bias) + if self.pre_layer_norm: + fused_encoder.fused_attn.pre_ln_scale.set_value( + base_encoder.norm1.weight) + fused_encoder.fused_attn.pre_ln_bias.set_value( + base_encoder.norm1.bias) + else: + fused_encoder.fused_attn.ln_scale.set_value( + base_encoder.norm1.weight) + fused_encoder.fused_attn.ln_bias.set_value(base_encoder.norm1.bias) + + q = base_encoder.self_attn.q_proj.weight + q_bias = base_encoder.self_attn.q_proj.bias + k = base_encoder.self_attn.k_proj.weight + k_bias = base_encoder.self_attn.k_proj.bias + v = base_encoder.self_attn.v_proj.weight + v_bias = base_encoder.self_attn.v_proj.bias + qkv_weight = self.fused_qkv(q, k, v, self.num_heads) + fused_encoder.fused_attn.qkv_weight.set_value(qkv_weight) + + tmp = paddle.concat(x=[q_bias, k_bias, v_bias], axis=0) + qkv_bias = paddle.reshape( + tmp, + shape=[3, self.num_heads, int(tmp.shape[0] / 3 / self.num_heads)]) + fused_encoder.fused_attn.qkv_bias.set_value(qkv_bias) + + fused_out = fused_encoder( + paddle.to_tensor( + src, stop_gradient=False), attn_mask_tensor) + paddle.autograd.backward([fused_out], [paddle.to_tensor(dout)], True) + + correct_ffn_str = 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}'.format( + self.d_model, self.dim_feedforward, self.dropout_rate, + fused_encoder.ffn._epsilon, self.activation, self.dropout_rate, + self.pre_layer_norm, self.dtype) + self.assertTrue(fused_encoder.ffn.extra_repr(), correct_ffn_str) + + correct_attn_str = 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}'.format( + self.embed_dim, self.num_heads, self.dropout_rate, + self.dropout_rate, fused_encoder.fused_attn._epsilon, None, None, + self.pre_layer_norm, False, self.dtype) + self.assertTrue(fused_encoder.fused_attn.extra_repr(), correct_attn_str) + + np.testing.assert_allclose( + fused_out.numpy(), base_out.numpy(), rtol=1e-3, atol=1e-4) + self.assertTrue( + np.allclose( + fused_out.grad.numpy(), + base_out.grad.numpy(), + rtol=1e-3, + atol=1e-4)) + + +class TestFusedTransformerEncoderLayerAct(TestFusedTransformerEncoderLayer): + def setActivation(self): + self.activation = 'relu' + + +class TestFusedTransformerEncoderLayerPreLayerNorm( + TestFusedTransformerEncoderLayer): + def setPreLayerNorm(self): + self.pre_layer_norm = True + + +class TestFusedTransformerEncoderLayerAttnMaskIsNone( + TestFusedTransformerEncoderLayer): + def setAttnMask(self): + self.has_attn_mask = False + + +class TestFusedTransformerEncoderLayerPreLnTrueAttnMaskIsNone( + TestFusedTransformerEncoderLayer): + def setPreLayerNorm(self): + self.pre_layer_norm = True + + def setAttnMask(self): + self.has_attn_mask = False + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 6c447a73c5251a1617353ceace9c6d53550b2727..df9cc68a02d8dc7ad0db20b30fec0414140a73b6 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -13,7 +13,7 @@ # limitations under the License. from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import in_dygraph_mode, default_main_program from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid import core, dygraph_utils from paddle import _C_ops @@ -43,6 +43,8 @@ def fused_feedforward(x, ln1_epsilon=1e-5, ln2_epsilon=1e-5, pre_layer_norm=False, + training=True, + mode='upscale_in_train', name=None): """ This is a fusion operator to compute feed forward layer in transformer model architecture. @@ -74,6 +76,18 @@ def fused_feedforward(x, ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5. ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5. pre_layer_norm (bool, optional): add layer_norm in the pre-processing stage or post-processing state. + training (bool, optional): A flag indicating whether it is in train phrase or not. Default True. + mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] + + 1. upscale_in_train(default), upscale the output at training time + + - train: out = input * mask / ( 1.0 - p ) + - inference: out = input + + 2. downscale_in_infer, downscale the output at inference + + - train: out = input * mask + - inference: out = input * (1.0 - p) name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -98,13 +112,27 @@ def fused_feedforward(x, _verify_dropout_rate(dropout1_rate) _verify_dropout_rate(dropout2_rate) + seed = None + if mode not in ('downscale_in_infer', 'upscale_in_train'): + raise ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + if in_dygraph_mode(): + if default_main_program().random_seed != 0: + seed = default_main_program().random_seed out, _, _, _, _, _, _, _, _, _, _ = _C_ops.fused_feedforward( x, None, None, linear1_weight, linear1_bias, linear2_weight, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias, 'pre_layer_norm', pre_layer_norm, 'ln1_epsilon', ln1_epsilon, 'ln2_epsilon', ln2_epsilon, 'act_method', activation, - 'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate) + 'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate, + "dropout1_is_test", not training, "dropout2_is_test", not training, + "dropout1_fix_seed", seed is not None, "dropout2_fix_seed", + seed is not None, "dropout1_seed", seed + if seed is not None else 0, "dropout2_seed", seed + if seed is not None else 0, 'dropout1_implementation', mode, + 'dropout2_implementation', mode) return out helper = LayerHelper("fused_feedforward") @@ -136,6 +164,9 @@ def fused_feedforward(x, dropout2_out = helper.create_variable_for_type_inference( x.dtype, stop_gradient=True) + if (seed is None or seed == 0) and helper.main_program.random_seed != 0: + seed = helper.main_program.random_seed + helper.append_op( type='fused_feedforward', inputs={ @@ -169,6 +200,14 @@ def fused_feedforward(x, 'pre_layer_norm': pre_layer_norm, 'ln1_epsilon': ln1_epsilon, 'ln2_epsilon': ln2_epsilon, + 'dropout1_is_test': not training, + 'dropout2_is_test': not training, + 'dropout1_fix_seed': seed is not None, + 'dropout2_fix_seed': seed is not None, + 'dropout1_seed': seed if seed is not None else 0, + 'dropout2_seed': seed if seed is not None else 0, + 'dropout1_implementation': mode, + 'dropout2_implementation': mode }) return out @@ -188,6 +227,8 @@ def fused_multi_head_attention(x, dropout_rate=0.5, attn_dropout_rate=0.5, ln_epsilon=1e-05, + training=True, + mode='upscale_in_train', name=None): """ Attention mapps queries and a set of key-value pairs to outputs, and @@ -214,7 +255,10 @@ def fused_multi_head_attention(x, out = out * v out = transpose(out, perm=[0, 2, 1, 3]) out = out_linear(out) - out = layer_norm(x + dropout(linear_bias + out)) + if pre_layer_norm: + out = x + dropout(linear_bias + out) + else: + out = layer_norm(x + dropout(linear_bias + out)) Parameters: x (Tensor): The input tensor of fused_multi_head_attention. The shape is @@ -247,6 +291,19 @@ def fused_multi_head_attention(x, 0 for no dropout. Default 0.5. ln_epsilon (float, optional): Small float value added to denominator of layer_norm to avoid dividing by zero. Default is 1e-5. + training (bool, optional): A flag indicating whether it is in train phrase or not. Default True. + mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] + + 1. upscale_in_train(default), upscale the output at training time + + - train: out = input * mask / ( 1.0 - p ) + - inference: out = input + + 2. downscale_in_infer, downscale the output at inference + + - train: out = input * mask + - inference: out = input * (1.0 - p) + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: Tensor: The output Tensor, the data type and shape is same as `x`. @@ -280,7 +337,16 @@ def fused_multi_head_attention(x, # [2, 4, 128] print(output.shape) """ + + seed = None + if mode not in ('downscale_in_infer', 'upscale_in_train'): + raise ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + if in_dygraph_mode(): + if default_main_program().random_seed != 0: + seed = default_main_program().random_seed # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out @@ -295,7 +361,12 @@ def fused_multi_head_attention(x, linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate', dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon', - ln_epsilon) + ln_epsilon, 'attn_dropout_is_test', not training, 'dropout_is_test', + not training, 'attn_dropout_fix_seed', seed is not None, + 'dropout_fix_seed', seed is not None, 'attn_dropout_seed', seed + if seed is not None else 0, 'dropout_seed', seed + if seed is not None else 0, 'attn_dropout_implementation', mode, + 'dropout_implementation', mode) return final_out else: helper = LayerHelper('fused_multi_head_attention', **locals()) @@ -323,13 +394,24 @@ def fused_multi_head_attention(x, if ln_bias: inputs['Ln2Bias'] = [ln_bias] + if (seed is None or seed == 0) and helper.main_program.random_seed != 0: + seed = helper.main_program.random_seed + # set attrs attrs = { 'pre_layer_norm': pre_layer_norm, 'epsilon': pre_ln_epsilon, 'ln_epsilon': ln_epsilon, 'dropout_rate': dropout_rate, - 'attn_dropout_rate': attn_dropout_rate + 'attn_dropout_rate': attn_dropout_rate, + 'attn_dropout_is_test': not training, + 'dropout_is_test': not training, + 'attn_dropout_fix_seed': seed is not None, + 'dropout_fix_seed': seed is not None, + 'attn_dropout_seed': seed if seed is not None else 0, + 'dropout_seed': seed if seed is not None else 0, + 'attn_dropout_implementation': mode, + 'dropout_implementation': mode, } # set outputs diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index a3d8a74844b19b8293789769de38265bee0e0424..d38e8d1193beffeecd35c19fafdf47c10aaf8927 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -43,7 +43,7 @@ class FusedMultiHeadAttention(Layer): `embed_dim`. Default None. vdim (int, optional): The feature size in value. If None, assumed equal to `embed_dim`. Default None. - normalize_before (bool, optional): Indicate whether it is pre_layer_norm + normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True) or post_layer_norm architecture (False). Default False. need_weights (bool, optional): Indicate whether to return the attention weights. Now, only False is supported. Default False. @@ -54,6 +54,8 @@ class FusedMultiHeadAttention(Layer): Default: None, which means the default bias parameter property is used. If it is set to False, this layer will not have trainable bias parameter. See usage for details in :code:`ParamAttr`. + epsilon (float, optional): The small value added to the variance to prevent + division by zero. Default: 1e-05. Examples: @@ -80,6 +82,7 @@ class FusedMultiHeadAttention(Layer): need_weights=False, weight_attr=None, bias_attr=None, + epsilon=1e-5, name=None): super(FusedMultiHeadAttention, self).__init__() @@ -88,13 +91,18 @@ class FusedMultiHeadAttention(Layer): assert num_heads > 0, ("Expected nhead to be greater than 0, " "but recieved {}".format(num_heads)) - attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate self.normalize_before = normalize_before self._dtype = self._helper.get_default_dtype() self._weight_attr = weight_attr self._bias_attr = bias_attr + self._epsilon = epsilon + self.embed_dim = embed_dim + self.num_heads = num_heads self.head_dim = embed_dim // num_heads + self.kdim = kdim + self.vdim = vdim + self.need_weights = need_weights assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" assert need_weights == False, "Only support need_weight is False now." @@ -186,15 +194,24 @@ class FusedMultiHeadAttention(Layer): pre_ln_bias=self.pre_ln_bias, ln_scale=self.ln_scale, ln_bias=self.ln_bias, - pre_ln_epsilon=1e-05, + pre_ln_epsilon=self._epsilon, qkv_bias=self.qkv_bias, linear_bias=self.linear_bias, attn_mask=attn_mask, dropout_rate=self.dropout_rate, attn_dropout_rate=self.attn_dropout_rate, - ln_epsilon=1e-05) + ln_epsilon=self._epsilon, + training=self.training, + name=self.name) return out + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}{}'.format( + self.embed_dim, self.num_heads, self.dropout_rate, + self.attn_dropout_rate, self._epsilon, self.kdim, self.vdim, + self.normalize_before, self.need_weights, self._dtype, name_str) + class FusedFeedForward(Layer): """ @@ -203,6 +220,8 @@ class FusedFeedForward(Layer): dim_feedforward (int): The hidden layer size. dropout_rate (float, optional): The dropout probability used in pre-process and post-precess. Default 0.1 + epsilon (float, optional): he small value added to the variance to prevent + division by zero. Default: 1e-05. activation (str, optional): The activation function. Default relu. act_dropout_rate (float, optional): The dropout probability after activition. If None, use the value of `dropout_rate`. Default None @@ -235,11 +254,13 @@ class FusedFeedForward(Layer): d_model, dim_feedforward, dropout_rate=0.1, + epsilon=1e-05, activation="relu", act_dropout_rate=None, normalize_before=False, weight_attr=None, - bias_attr=None): + bias_attr=None, + name=None): super(FusedFeedForward, self).__init__() assert d_model > 0, ( @@ -256,6 +277,7 @@ class FusedFeedForward(Layer): self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate self._act_method = activation self._normalize_before = normalize_before + self._epsilon = epsilon self._linear1_weight = self.create_parameter( shape=[d_model, dim_feedforward], @@ -292,15 +314,36 @@ class FusedFeedForward(Layer): default_initializer=Constant(1.0)) self._ln2_bias = self.create_parameter( shape=[d_model], attr=None, is_bias=True) + self.name = name def forward(self, src, cache=None): out = incubate_f.fused_feedforward( - src, self._linear1_weight, self._linear2_weight, self._linear1_bias, - self._linear2_bias, self._ln1_scale, self._ln1_bias, - self._ln2_scale, self._ln2_bias, self._dropout_rate, - self._act_dropout_rate, self._act_method, self._normalize_before) + src, + self._linear1_weight, + self._linear2_weight, + self._linear1_bias, + self._linear2_bias, + self._ln1_scale, + self._ln1_bias, + self._ln2_scale, + self._ln2_bias, + dropout1_rate=self._act_dropout_rate, + dropout2_rate=self._dropout_rate, + activation=self._act_method, + ln1_epsilon=self._epsilon, + ln2_epsilon=self._epsilon, + pre_layer_norm=self._normalize_before, + training=self.training, + name=self.name) return out + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}{}'.format( + self._d_model, self._dim_feedforward, self._dropout_rate, + self._epsilon, self._act_method, self._act_dropout_rate, + self._normalize_before, self._dtype, name_str) + class FusedTransformerEncoderLayer(Layer): """ @@ -393,7 +436,9 @@ class FusedTransformerEncoderLayer(Layer): self.fused_attn = FusedMultiHeadAttention( d_model, nhead, - dropout_rate=attn_dropout_rate, + dropout_rate=dropout_rate, + attn_dropout_rate=attn_dropout_rate, + normalize_before=self.normalize_before, weight_attr=weight_attrs[0], bias_attr=bias_attrs[0]) @@ -401,6 +446,7 @@ class FusedTransformerEncoderLayer(Layer): d_model, dim_feedforward, dropout_rate=dropout_rate, + activation=activation, act_dropout_rate=act_dropout_rate, normalize_before=self.normalize_before, weight_attr=weight_attrs[1], diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index fdd370d7f81e72462500cf3b47beed8062a14c90..7461528bfd975cae006eabf4b77d47e6b3d08558 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -235,8 +235,8 @@ def interpolate(x, Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np import paddle.nn.functional as F # given out size @@ -244,7 +244,7 @@ def interpolate(x, x = paddle.to_tensor(input_data) output_1 = F.interpolate(x=x, size=[12,12]) print(output_1.shape) - # [2L, 3L, 12L, 12L] + # [2L, 3L, 12L, 12L] # given scale output_2 = F.interpolate(x=x, scale_factor=[2,1])