未验证 提交 36dd295e 编写于 作者: Z zhangkaihuo 提交者: GitHub

[cherry-pick-2.2.1]fix fused_transformer_encoder_layer bug (#37229)

修复了fused_transformer_encoder_layer fine-tune过程发现的一些问题:

    fused_attention_op添加attn_mask=None的支持:PR
    pre_layer_norm处理问题:PR
    参数处理,计算错误的问题:PR
    add_bias计算错误问题:PR
    添加pure fp16的支持:PR
上级 79b9f47e
......@@ -266,6 +266,13 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
pair.first != "X") {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
if (pair.first == "LnScale" || pair.first == "LnBias" ||
pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
pair.first == "Ln1Scale" || pair.first == "Ln1Bias") {
continue;
}
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
......
......@@ -87,7 +87,8 @@ __global__ void BroadcastKernelBinary(
kernel_primitives::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
result, arg0, arg1, func);
// store
kernel_primitives::WriteData<OutT, VecSize, 1, 1>(out + fix, result, num);
kernel_primitives::WriteData<OutT, VecSize, 1, 1, true>(out + fix, result,
num);
}
// bias add forward impl for "[m, n] + [n] = [m, n]"
......@@ -267,25 +268,24 @@ __global__ void BiasAddBw1DReduceKernel(const ReduceParamType<T>* temp_sum,
}
template <typename T>
void Launch2DColumnReduce(gpuStream_t stream, const int max_threads,
const int reduce_num, const int left_num,
const T* d_out, T* d_bias) {
void Launch2DColumnReduce(const platform::CUDADeviceContext& dev_ctx,
const int max_threads, const int reduce_num,
const int left_num, const T* d_out, T* d_bias) {
dim3 block;
dim3 grid;
bool should_reduce_again = false;
int blocking_size = 1;
SetConfigForColumnReduce(max_threads, reduce_num, left_num, &blocking_size,
&should_reduce_again, &block, &grid);
const auto& stream = dev_ctx.stream();
if (!should_reduce_again) {
BiasAddBwSinglePassKernel<T><<<grid, block, 0, stream>>>(d_out, reduce_num,
left_num, d_bias);
} else {
framework::Tensor tmp_sum;
tmp_sum.mutable_data<ReduceParamType<T>>(
framework::make_ddim({static_cast<int64_t>(
left_num * grid.y * sizeof(ReduceParamType<T>))}),
paddle::platform::CUDAPlace());
tmp_sum.Resize({grid.y, left_num});
tmp_sum.mutable_data<ReduceParamType<T>>(dev_ctx.GetPlace());
BiasAddBw2DReduceKernel<T><<<grid, block, 0, stream>>>(
d_out, reduce_num, left_num, blocking_size,
......@@ -311,8 +311,8 @@ void LaunchBiasAddBwKernel(const platform::CUDADeviceContext& dev_ctx, int m,
Launch1DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num,
d_out, d_bias);
} else {
Launch2DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num,
d_out, d_bias);
Launch2DColumnReduce(dev_ctx, max_threads, reduce_num, left_num, d_out,
d_bias);
}
}
......
......@@ -11,10 +11,13 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
namespace paddle {
namespace operators {
......@@ -36,8 +39,10 @@ class AttnMatMul {
~AttnMatMul() {}
void ComputeForward(const T* weight_data, const T* input_data,
const T* bias_data, T* output_data, T* bias_out_data) {
void ComputeForward(const framework::Tensor* weight,
const framework::Tensor* input,
const framework::Tensor* bias, framework::Tensor* output,
framework::Tensor* bias_out) {
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// here: (transa, transb): nt, input * weight.
CBLAS_TRANSPOSE transA = CblasNoTrans;
......@@ -54,16 +59,25 @@ class AttnMatMul {
// here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
input_data, weight_data, beta, output_data);
input->data<T>(), weight->data<T>(), beta, output->data<T>());
if (compute_bias_) {
// compute output + bias
LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
bias_data, bias_out_data);
std::vector<const Tensor*> ins;
std::vector<Tensor*> outs;
ins.emplace_back(output);
ins.emplace_back(bias);
outs.emplace_back(bias_out);
int elewise_add_axis = -1;
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
}
}
void ComputeBackward(const T* input, const T* weight, const T* d_output,
T* d_input, T* d_weight, T* d_bias) {
void ComputeBackward(const framework::Tensor* input,
const framework::Tensor* weight,
const framework::Tensor* d_output,
framework::Tensor* d_input, framework::Tensor* d_weight,
framework::Tensor* d_bias) {
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
......@@ -81,11 +95,11 @@ class AttnMatMul {
T* dB_input_1_ptr = nullptr;
T* dB_input_2_ptr = nullptr;
T* dB_output_ptr = d_weight;
T* dB_output_ptr = d_weight->data<T>();
T* dA_input_1_ptr = nullptr;
T* dA_input_2_ptr = nullptr;
T* dA_output_ptr = d_input;
T* dA_output_ptr = d_input->data<T>();
if (!transA_) {
// fw: gemm-nt
......@@ -104,10 +118,10 @@ class AttnMatMul {
dA_n = input_size_;
dA_k = output_size_;
blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
input, beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
weight, beta, dA_output_ptr);
blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
d_output->data<T>(), input->data<T>(), beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
} else { // fw: gemm-nn
// bw: gemm-tn, dB = A^t * dC
dB_transA = CblasTrans;
......@@ -123,10 +137,10 @@ class AttnMatMul {
dA_n = input_size_;
dA_k = output_size_;
blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
d_output, beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
weight, beta, dA_output_ptr);
blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
input->data<T>(), d_output->data<T>(), beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
}
} else if (transB_) {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -138,7 +152,27 @@ class AttnMatMul {
"parameters."));
}
if (compute_bias_) {
LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
// reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2}
const auto input_dims = d_output->dims();
const auto output_dims = d_bias->dims();
bool support_case_1 =
(input_dims.size() == 5 && output_dims.size() == 3 &&
(input_dims[2] == output_dims[0]) &&
(input_dims[3] == output_dims[1]) &&
(input_dims[4] == output_dims[2]));
bool support_case_2 =
(input_dims.size() == 3 && output_dims.size() == 1 &&
(input_dims[2] == output_dims[0]));
if (support_case_1 || support_case_2) {
gpuStream_t stream = dev_ctx_.stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support reduce when the input dims are [0,1,2,3,4] and "
"output is [2,3,4]"
"or input is [0,1,2] and output is [2]."));
}
}
}
......
......@@ -69,7 +69,7 @@ class FMHARef {
~FMHARef() {}
void ComputeForward(const Tensor& qkv_input_tensor,
const Tensor& src_mask_tensor,
const Tensor* src_mask_tensor,
Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor,
Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor,
Tensor* dropout_mask_out_tensor,
......@@ -111,17 +111,17 @@ class FMHARef {
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr,
k_ptr, beta, qk_out_data, gemm_batch_size, stride_a,
stride_b);
int softmax_axis = -1;
if (src_mask_tensor != nullptr) {
std::vector<const Tensor*> ins;
std::vector<Tensor*> outs;
ins.emplace_back(qk_out_tensor);
ins.emplace_back(&src_mask_tensor);
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
int softmax_axis = -1;
if (&src_mask_tensor != nullptr) {
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
softmax_axis, softmax_out_tensor);
} else {
......@@ -165,7 +165,7 @@ class FMHARef {
}
void ComputeBackward(
const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor,
const Tensor& transpose_2_out_tensor, const Tensor* src_mask_tensor,
const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor,
const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor,
const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor,
......@@ -249,7 +249,7 @@ class FMHARef {
softmax_out_grad_tensor);
}
if (&src_mask_tensor != nullptr) {
if (src_mask_tensor != nullptr) {
SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
*softmax_out_grad_tensor, softmax_axis,
src_mask_out_grad_tensor);
......
......@@ -27,8 +27,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionOp");
......@@ -44,6 +42,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
"FusedAttentionOp");
} else {
OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
"BiasDropoutResidualOut", "FusedAttentionOp");
}
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
......@@ -57,8 +62,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut",
"FusedAttentionOp");
if (ctx->HasInput("SrcMask")) {
OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
"FusedAttentionOp");
}
OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output",
......@@ -69,12 +77,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
"BiasDropoutResidualOut", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp");
......@@ -108,6 +111,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
} else {
ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
}
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut",
......@@ -119,7 +126,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
// [batch, num_head, seq_len, seq_len]
ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
if (ctx->HasInput("SrcMask")) {
ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
}
// the same as QKOut's shape.
ctx->SetOutputDim("AttnDropoutOut",
{x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
......@@ -134,12 +144,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));
ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
if (ctx->Attrs().Get<bool>("dropout_is_test") == false) {
ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
}
ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
}
......@@ -320,7 +328,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
{
out = transpose(out, perm=[2, 0, 3, 1, 4]);
out = q * k^t;
out = attn_mark + out;
out = attn_mask + out;
out = softmax(out);
out = dropout(out);
out = out * v;
......@@ -328,6 +336,9 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
}
out = out_linear(out);
if (pre_layernorm)
final_out = residual + dropout(bias + out);
else
final_out = layer_norm(residual + dropout(bias + out));
)DOC");
}
......@@ -343,6 +354,7 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"GradOp is only callable when attn_dropout_is_test is false"));
if (ctx->Attrs().Get<bool>("pre_layer_norm") == false) {
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
......@@ -355,8 +367,7 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
ctx->GetInputDim("Ln2Bias"));
}
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
} else {
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
......@@ -364,12 +375,12 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
"FusedAttentionGrad");
}
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
......@@ -400,6 +411,9 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim(framework::GradVarName("LnOut"),
ctx->GetInputDim("LnOut"));
} else {
ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
ctx->GetInputDim("BiasDropoutResidualOut"));
}
ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut"));
......@@ -413,16 +427,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("SoftmaxOut"));
ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
ctx->GetInputDim("AttnDropoutOut"));
if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) {
ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
ctx->GetInputDim("SrcMaskOut"));
}
ctx->SetOutputDim(framework::GradVarName("QKVOut"),
ctx->GetInputDim("QKVOut"));
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
ctx->GetInputDim("QKVBiasOut"));
ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
ctx->GetInputDim("OutLinearOut"));
ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
ctx->GetInputDim("BiasDropoutResidualOut"));
}
protected:
......@@ -448,7 +463,14 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("X", this->Input("X"));
op->SetInput("QKVW", this->Input("QKVW"));
op->SetInput("QKVBias", this->Input("QKVBias"));
if (this->HasInput("SrcMask")) {
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
op->SetOutput(framework::GradVarName("SrcMaskOut"),
this->OutputGrad("SrcMaskOut"));
}
op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
......@@ -466,8 +488,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));
}
}
} else {
if (this->HasInput("Ln2Scale")) {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
......@@ -478,6 +499,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
}
}
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
......@@ -499,6 +521,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
if (this->HasOutput("LnVariance")) {
op->SetInput("LnVariance", this->Output("LnVariance"));
}
} else {
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("BiasDropoutResidualOut",
this->Output("BiasDropoutResidualOut"));
}
op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
......@@ -508,15 +535,10 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
op->SetInput("FMHAOut", this->Output("FMHAOut"));
op->SetInput("OutLinearOut", this->Output("OutLinearOut"));
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
op->SetInput("BiasDropoutResidualOut",
this->Output("BiasDropoutResidualOut"));
op->SetInput("QKVOut", this->Output("QKVOut"));
// backward outputs: dinput
......@@ -525,7 +547,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("LnOut"),
this->OutputGrad("LnOut"));
}
} else {
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
this->OutputGrad("BiasDropoutResidualOut"));
}
op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
......@@ -538,12 +564,9 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->OutputGrad("SoftmaxOut"));
op->SetOutput(framework::GradVarName("AttnDropoutOut"),
this->OutputGrad("AttnDropoutOut"));
op->SetOutput(framework::GradVarName("SrcMaskOut"),
this->OutputGrad("SrcMaskOut"));
op->SetOutput(framework::GradVarName("FMHAOut"),
this->OutputGrad("FMHAOut"));
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
this->OutputGrad("BiasDropoutResidualOut"));
op->SetOutput(framework::GradVarName("OutLinearOut"),
this->OutputGrad("OutLinearOut"));
}
......
......@@ -95,15 +95,6 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
const auto qkv_w_dims = qkv_weight->dims();
auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data =
pre_layer_norm ? ln_mean->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_var_data =
pre_layer_norm ? ln_var->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_out_data =
pre_layer_norm ? ln_out->mutable_data<T>(ctx.GetPlace()) : nullptr;
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
......@@ -114,7 +105,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
transpose_out_2->mutable_data<T>(ctx.GetPlace());
auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr
: src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
auto *attn_dropout_mask_out_data =
attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
......@@ -128,16 +121,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());
// get data ptr for bias+dropout+residual+layernorm
auto *ln_scale_2_data =
(ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
auto *ln_bias_2_data =
(ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
auto *dropout_mask_out_data =
dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());
int batch_size = input_x_dims[0];
......@@ -176,23 +161,45 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
ln_epsilon);
if (pre_layer_norm) {
auto *ln_scale_data =
(ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());
layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
ln_out_data, ln_mean_data, ln_var_data);
qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out,
qkv_bias_out);
} else {
qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
qkv_bias_out);
}
fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
nullptr, out_linear_out_data, nullptr);
out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
out_linear_out, nullptr);
if (pre_layer_norm) {
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data,
out_linear_bias_data, final_out_data, dropout_mask_out_data);
} else {
auto *ln_scale_2_data =
(ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
auto *ln_bias_2_data =
(ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data,
......@@ -200,6 +207,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
ln_mean_2_data, ln_var_2_data);
}
}
};
template <typename T>
......@@ -265,12 +273,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *qk_out_data = qk_out->data<T>();
auto *qktv_out_data = qktv_out->data<T>();
auto *softmax_out_data = softmax_out->data<T>();
auto *src_mask_out_data = src_mask_out->data<T>();
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
auto *out_linear_out_data = out_linear_out->data<T>();
auto *ln_2_mean_data = ln_2_mean->data<U>();
auto *ln_2_var_data = ln_2_var->data<U>();
auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data<T>();
// output's grad
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
......@@ -302,12 +308,12 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
auto *d_attn_dropout_out_data =
d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
auto *d_src_mask_out_data = d_src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *d_src_mask_out_data =
(src_mask == nullptr) ? nullptr
: d_src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_out_data =
d_out_linear_out->mutable_data<T>(ctx.GetPlace());
auto *d_bias_dropout_residual_out_data =
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
// parameter grad
auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
......@@ -325,12 +331,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_bias_data =
d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
auto *d_ln_2_scale_data =
(d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
ctx.GetPlace()));
auto *d_ln_2_bias_data =
(d_ln_2_bias == nullptr ? nullptr
: d_ln_2_bias->mutable_data<U>(ctx.GetPlace()));
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
......@@ -376,17 +376,37 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
ln2epsilon);
if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
ctx.cuda_device_context(), d_y_data, dropout_mask_out_data,
d_out_linear_out_data, d_residual_data, d_out_linear_bias_data);
} else {
auto *ln_2_mean_data = ln_2_mean->data<U>();
auto *ln_2_var_data = ln_2_var->data<U>();
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->data<T>();
auto *d_ln_2_scale_data =
(d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
ctx.GetPlace()));
auto *d_ln_2_bias_data =
(d_ln_2_bias == nullptr ? nullptr : d_ln_2_bias->mutable_data<U>(
ctx.GetPlace()));
auto *d_bias_dropout_residual_out_data =
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
}
out_linear_compute.ComputeBackward(fmha_out, out_linear_weight,
d_out_linear_out, d_fmha_out,
d_out_linear_weight, nullptr);
out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
d_out_linear_out_data, d_fmha_out_data,
d_out_linear_weight_data, nullptr);
fmha_ref_compute.ComputeBackward(
*transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out,
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out);
......@@ -413,15 +433,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));
qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
d_qkv_bias_out_data, d_ln_out_data,
d_qkv_weight_data, d_qkv_bias_data);
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out,
d_qkv_weight, d_qkv_bias);
layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
ln_mean_data, ln_var_data, d_x_data,
d_ln_scale_data, d_ln_bias_data);
} else {
qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data,
d_x_data, d_qkv_weight_data, d_qkv_bias_data);
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
d_qkv_weight, d_qkv_bias);
}
// gradient accumulation
std::vector<const Tensor *> ins;
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
......@@ -261,7 +262,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_residual.mutable_data<T>({bsz_seq, d_model}, place);
d_residual.mutable_data<T>(d_x->dims(), place);
if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
......@@ -301,6 +302,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
} else {
MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
}
std::vector<const Tensor*> ins(2);
std::vector<Tensor*> outs(1);
ins[0] = &d_residual;
ins[1] = d_x;
outs[0] = d_x;
int elewise_add_axis = -1;
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, elewise_add_axis, AddFunctor<T>());
}
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -94,6 +94,7 @@ if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
endif()
if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
......
......@@ -26,6 +26,9 @@ from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program
default_main_program().random_seed = 42
class TestFusedAttentionOp(OpTest):
......@@ -66,6 +69,7 @@ class TestFusedAttentionOp(OpTest):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True
self.batch_size = 8
......@@ -84,6 +88,7 @@ class TestFusedAttentionOp(OpTest):
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
if self.has_attn_mask:
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
......@@ -93,7 +98,10 @@ class TestFusedAttentionOp(OpTest):
elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
raise ValueError(
"'attn_mask_type' should be 'int64' or 'float64'.")
else:
self.attn_mask = None
self.key, self.value = self.query, self.query
self.dout = np.random.random((self.batch_size, self.query_length,
......@@ -102,7 +110,10 @@ class TestFusedAttentionOp(OpTest):
def GetBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
attn_mask = None
residual = tensor_query
ln1_out = tensor_query
......@@ -147,8 +158,8 @@ class TestFusedAttentionOp(OpTest):
residual_out = residual + self.dropout(out)
if not self.pre_layer_norm:
final_out = self.norm1(residual_out)
if self.pre_layer_norm:
final_out = self.norm2(residual_out)
else:
final_out = residual_out
paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
return final_out, tensor_query.grad
......@@ -187,7 +198,10 @@ class TestFusedAttentionOp(OpTest):
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
x = paddle.to_tensor(self.query, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
attn_mask = None
qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
epsilon = 1e-05
......@@ -208,9 +222,9 @@ class TestFusedAttentionOp(OpTest):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5)
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
......@@ -218,6 +232,7 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True
self.batch_size = 8
......@@ -237,9 +252,39 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = False
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
class TestFusedAttentionOpFp16(TestFusedAttentionOp):
......@@ -247,6 +292,7 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
self.x_type = np.float16
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True
self.batch_size = 8
......
......@@ -89,27 +89,32 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
qkv_weight.shape[2] * qkv_weight.shape[3])
qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
qkv_bias.shape[2])
if (pre_layer_norm):
ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
qkv = fc(ln_out, qkv_weight)
qkv_bias_out = qkv + qkv_bias
ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
else:
query = query.reshape(batch_size * seq_len, embed_dim)
qkv = fc(query, qkv_weight)
qkv_bias_out = qkv + qkv_bias
query = query.reshape(batch_size, seq_len, embed_dim)
qkv = qkv.reshape(batch_size, seq_len, 3, num_head, head_dim)
qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head,
head_dim)
# q*k^t
qkv = qkv.transpose(
qkv_bias_out = qkv_bias_out.transpose(
(2, 0, 1, 3, 4)) # 3, batch_size, seq_len, num_head, head_dim
qkv = qkv.transpose(
qkv_bias_out = qkv_bias_out.transpose(
(0, 1, 3, 2, 4)) # 3, batch_size, num_head, seq_len, head_dim
q = qkv[0:1, ::]
q = qkv_bias_out[0:1, ::]
q = q.reshape(batch_size, num_head, seq_len, head_dim)
k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim]
k = qkv_bias_out[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim]
k = k.reshape(batch_size, num_head, seq_len, head_dim)
v = qkv[2::]
v = qkv_bias_out[2::]
v = v.reshape(batch_size, num_head, seq_len, head_dim)
k = k.transpose([0, 1, 3, 2]) #[batch_size, num_head, head_dim, seq_len]
......@@ -138,9 +143,11 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
out_linear_bias_out = out_linear_out + out_linear_bias
out_linear_bias_dropout_out = out_linear_bias_out
out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out
out_linear_bias_dropout_residual_ln_out = layer_norm(
out_linear_bias_dropout_residual_out, True, True, ln_2_scale, ln_2_bias)
return out_linear_bias_dropout_residual_ln_out
if not pre_layer_norm:
out_linear_bias_dropout_residual_out = layer_norm(
out_linear_bias_dropout_residual_out, True, True, ln_2_scale,
ln_2_bias)
return out_linear_bias_dropout_residual_out
class TestFusedAttentionAPI(unittest.TestCase):
......@@ -152,6 +159,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True
self.need_weight = False
......@@ -172,6 +180,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
if self.has_attn_mask:
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
......@@ -181,18 +190,27 @@ class TestFusedAttentionAPI(unittest.TestCase):
elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
raise ValueError(
"'attn_mask_type' should be 'int64' or 'float64'.")
else:
self.attn_mask = None
self.key, self.value = self.query, self.query
def run_imperative(self):
if self.has_attn_mask:
attn_mask_tensor = paddle.to_tensor(self.attn_mask)
else:
attn_mask_tensor = None
fused_attn = FusedMultiHeadAttention(
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr)
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32')
fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
out = fused_attn(
paddle.to_tensor(self.query),
paddle.to_tensor(self.query),
paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask))
paddle.to_tensor(self.query), attn_mask_tensor)
ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask,
fused_attn.pre_ln_scale.numpy(),
......@@ -203,7 +221,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
fused_attn.qkv_bias.numpy(),
fused_attn.linear_weight.numpy(),
fused_attn.linear_bias.numpy())
self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5))
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)
def run_static(self):
fused_attn = FusedMultiHeadAttention(
......@@ -215,6 +233,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
name='X',
shape=[self.batch_size, self.query_length, self.embed_dim],
dtype=self.x_type)
if self.has_attn_mask:
attn_mask = paddle.static.data(
name='SrcMask',
shape=[
......@@ -223,10 +242,13 @@ class TestFusedAttentionAPI(unittest.TestCase):
],
dtype=self.attn_mask_type)
final_out = fused_attn(x, x, x, attn_mask)
else:
final_out = fused_attn(x, x, x)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
if self.has_attn_mask:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
......@@ -237,7 +259,16 @@ class TestFusedAttentionAPI(unittest.TestCase):
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias
def test_static_api(self):
......@@ -249,14 +280,36 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
linear_weight, linear_bias)
self.assertTrue(
np.allclose(
np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5))
np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)
def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_imperative()
class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = False
self.training = True
self.need_weight = False
self.batch_size = 1
self.query_length = 2
self.head_dim = 2
self.num_heads = 2
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
if __name__ == "__main__":
unittest.main()
......@@ -23,6 +23,7 @@ from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program
class TestFusedFFNOp(OpTest):
......@@ -91,7 +92,7 @@ class TestFusedFFNOp(OpTest):
def Base(self):
paddle.disable_static()
tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
residual = paddle.to_tensor(self.src)
residual = tensor_src
if self.pre_layer_norm:
ln1_out = self.norm1(tensor_src)
linear2_out = self.linear2(
......@@ -140,6 +141,7 @@ class TestFusedFFNOp(OpTest):
return out, x.grad
def test_out_and_grad(self):
default_main_program().random_seed = 42
base_out, base_grad = self.Base()
fused_out, fused_grad = self.FusedFFN()
np.testing.assert_allclose(
......@@ -192,6 +194,7 @@ class TestFusedFFNOpNormalizeBefore(TestFusedFFNOp):
class APITestStaticFusedFFN(unittest.TestCase):
def test_static(self):
paddle.enable_static()
default_main_program().random_seed = 42
dtype = "float32"
layer_norm_dtype = "float32"
batch_size = 1
......@@ -324,6 +327,18 @@ class TestFusedFFNOpError(unittest.TestCase):
self.assertRaises(ValueError, test_dropout_rate_value)
def test_dropout_mode():
x = paddle.static.data(
name='x3', shape=[1, 10, 10], dtype="float32")
linear1_weight = paddle.static.data(
name='linear1_weight3', shape=[10, 10], dtype="float32")
linear2_weight = paddle.static.data(
name='linear2_weight3', shape=[10, 10], dtype="float32")
incubate_f.fused_feedforward(
x, linear1_weight, linear2_weight, mode='test')
self.assertRaises(ValueError, test_dropout_mode)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle.incubate.nn import FusedTransformerEncoderLayer
from paddle.nn import TransformerEncoderLayer
from paddle.fluid.framework import default_main_program
import unittest
class TestFusedTransformerEncoderLayer(unittest.TestCase):
def setActivation(self):
self.activation = 'gelu'
def setPreLayerNorm(self):
self.pre_layer_norm = False
def setAttnMask(self):
self.has_attn_mask = True
def setUp(self):
self.batch_size = np.random.randint(1, 8)
self.query_length = np.random.randint(1, 128)
self.nhead = 16
self.head_dim = 4
self.num_heads = self.nhead
self.d_model = self.head_dim * self.num_heads
self.embed_dim = self.d_model
self.dim_feedforward = np.random.randint(1, 32)
self.dropout_rate = 0
self.attn_dropout_rate = None
self.act_dropout_rate = None
self.attn_mask_type = np.float64
self.key_length = self.query_length
self.dtype = 'float32'
self.setActivation()
self.setPreLayerNorm()
self.setAttnMask()
def fused_weight(self, weight, num_head):
a = paddle.transpose(weight, perm=[1, 0])
return paddle.reshape(
a, shape=[1, num_head, int(a.shape[0] / num_head), a.shape[1]])
def fused_qkv(self, q, k, v, num_head):
fq = self.fused_weight(q, num_head)
fk = self.fused_weight(k, num_head)
fv = self.fused_weight(v, num_head)
return paddle.concat(x=[fq, fk, fv], axis=0)
def test_out(self):
default_main_program().random_seed = 42
base_encoder = TransformerEncoderLayer(
self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate,
self.activation, self.attn_dropout_rate, self.act_dropout_rate,
self.pre_layer_norm)
src = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.dtype)
if self.has_attn_mask:
attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
dtype=self.attn_mask_type)
attn_mask_tensor = paddle.to_tensor(attn_mask)
else:
attn_mask = None
attn_mask_tensor = None
dout = np.random.random(src.shape).astype(self.dtype)
base_out = base_encoder(
paddle.to_tensor(
src, stop_gradient=False), attn_mask_tensor)
paddle.autograd.backward([base_out], [paddle.to_tensor(dout)], True)
fused_encoder = FusedTransformerEncoderLayer(
self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate,
self.activation, self.attn_dropout_rate, self.act_dropout_rate,
self.pre_layer_norm)
fused_encoder.ffn._linear1_weight.set_value(base_encoder.linear1.weight)
fused_encoder.ffn._linear1_bias.set_value(base_encoder.linear1.bias)
fused_encoder.ffn._linear2_weight.set_value(base_encoder.linear2.weight)
fused_encoder.ffn._linear2_bias.set_value(base_encoder.linear2.bias)
if self.pre_layer_norm:
fused_encoder.ffn._ln1_scale.set_value(base_encoder.norm2.weight)
fused_encoder.ffn._ln1_bias.set_value(base_encoder.norm2.bias)
else:
fused_encoder.ffn._ln2_scale.set_value(base_encoder.norm2.weight)
fused_encoder.ffn._ln2_bias.set_value(base_encoder.norm2.bias)
fused_encoder.fused_attn.linear_weight.set_value(
base_encoder.self_attn.out_proj.weight)
fused_encoder.fused_attn.linear_bias.set_value(
base_encoder.self_attn.out_proj.bias)
if self.pre_layer_norm:
fused_encoder.fused_attn.pre_ln_scale.set_value(
base_encoder.norm1.weight)
fused_encoder.fused_attn.pre_ln_bias.set_value(
base_encoder.norm1.bias)
else:
fused_encoder.fused_attn.ln_scale.set_value(
base_encoder.norm1.weight)
fused_encoder.fused_attn.ln_bias.set_value(base_encoder.norm1.bias)
q = base_encoder.self_attn.q_proj.weight
q_bias = base_encoder.self_attn.q_proj.bias
k = base_encoder.self_attn.k_proj.weight
k_bias = base_encoder.self_attn.k_proj.bias
v = base_encoder.self_attn.v_proj.weight
v_bias = base_encoder.self_attn.v_proj.bias
qkv_weight = self.fused_qkv(q, k, v, self.num_heads)
fused_encoder.fused_attn.qkv_weight.set_value(qkv_weight)
tmp = paddle.concat(x=[q_bias, k_bias, v_bias], axis=0)
qkv_bias = paddle.reshape(
tmp,
shape=[3, self.num_heads, int(tmp.shape[0] / 3 / self.num_heads)])
fused_encoder.fused_attn.qkv_bias.set_value(qkv_bias)
fused_out = fused_encoder(
paddle.to_tensor(
src, stop_gradient=False), attn_mask_tensor)
paddle.autograd.backward([fused_out], [paddle.to_tensor(dout)], True)
correct_ffn_str = 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}'.format(
self.d_model, self.dim_feedforward, self.dropout_rate,
fused_encoder.ffn._epsilon, self.activation, self.dropout_rate,
self.pre_layer_norm, self.dtype)
self.assertTrue(fused_encoder.ffn.extra_repr(), correct_ffn_str)
correct_attn_str = 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}'.format(
self.embed_dim, self.num_heads, self.dropout_rate,
self.dropout_rate, fused_encoder.fused_attn._epsilon, None, None,
self.pre_layer_norm, False, self.dtype)
self.assertTrue(fused_encoder.fused_attn.extra_repr(), correct_attn_str)
np.testing.assert_allclose(
fused_out.numpy(), base_out.numpy(), rtol=1e-3, atol=1e-4)
self.assertTrue(
np.allclose(
fused_out.grad.numpy(),
base_out.grad.numpy(),
rtol=1e-3,
atol=1e-4))
class TestFusedTransformerEncoderLayerAct(TestFusedTransformerEncoderLayer):
def setActivation(self):
self.activation = 'relu'
class TestFusedTransformerEncoderLayerPreLayerNorm(
TestFusedTransformerEncoderLayer):
def setPreLayerNorm(self):
self.pre_layer_norm = True
class TestFusedTransformerEncoderLayerAttnMaskIsNone(
TestFusedTransformerEncoderLayer):
def setAttnMask(self):
self.has_attn_mask = False
class TestFusedTransformerEncoderLayerPreLnTrueAttnMaskIsNone(
TestFusedTransformerEncoderLayer):
def setPreLayerNorm(self):
self.pre_layer_norm = True
def setAttnMask(self):
self.has_attn_mask = False
if __name__ == "__main__":
unittest.main()
......@@ -13,7 +13,7 @@
# limitations under the License.
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode, default_main_program
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid import core, dygraph_utils
from paddle import _C_ops
......@@ -43,6 +43,8 @@ def fused_feedforward(x,
ln1_epsilon=1e-5,
ln2_epsilon=1e-5,
pre_layer_norm=False,
training=True,
mode='upscale_in_train',
name=None):
"""
This is a fusion operator to compute feed forward layer in transformer model architecture.
......@@ -74,6 +76,18 @@ def fused_feedforward(x,
ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
pre_layer_norm (bool, optional): add layer_norm in the pre-processing stage or post-processing state.
training (bool, optional): A flag indicating whether it is in train phrase or not. Default True.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - p )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - p)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -98,13 +112,27 @@ def fused_feedforward(x,
_verify_dropout_rate(dropout1_rate)
_verify_dropout_rate(dropout2_rate)
seed = None
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer
if in_dygraph_mode():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
out, _, _, _, _, _, _, _, _, _, _ = _C_ops.fused_feedforward(
x, None, None, linear1_weight, linear1_bias, linear2_weight,
linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias,
'pre_layer_norm', pre_layer_norm, 'ln1_epsilon', ln1_epsilon,
'ln2_epsilon', ln2_epsilon, 'act_method', activation,
'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate)
'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate,
"dropout1_is_test", not training, "dropout2_is_test", not training,
"dropout1_fix_seed", seed is not None, "dropout2_fix_seed",
seed is not None, "dropout1_seed", seed
if seed is not None else 0, "dropout2_seed", seed
if seed is not None else 0, 'dropout1_implementation', mode,
'dropout2_implementation', mode)
return out
helper = LayerHelper("fused_feedforward")
......@@ -136,6 +164,9 @@ def fused_feedforward(x,
dropout2_out = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True)
if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed
helper.append_op(
type='fused_feedforward',
inputs={
......@@ -169,6 +200,14 @@ def fused_feedforward(x,
'pre_layer_norm': pre_layer_norm,
'ln1_epsilon': ln1_epsilon,
'ln2_epsilon': ln2_epsilon,
'dropout1_is_test': not training,
'dropout2_is_test': not training,
'dropout1_fix_seed': seed is not None,
'dropout2_fix_seed': seed is not None,
'dropout1_seed': seed if seed is not None else 0,
'dropout2_seed': seed if seed is not None else 0,
'dropout1_implementation': mode,
'dropout2_implementation': mode
})
return out
......@@ -188,6 +227,8 @@ def fused_multi_head_attention(x,
dropout_rate=0.5,
attn_dropout_rate=0.5,
ln_epsilon=1e-05,
training=True,
mode='upscale_in_train',
name=None):
"""
Attention mapps queries and a set of key-value pairs to outputs, and
......@@ -214,6 +255,9 @@ def fused_multi_head_attention(x,
out = out * v
out = transpose(out, perm=[0, 2, 1, 3])
out = out_linear(out)
if pre_layer_norm:
out = x + dropout(linear_bias + out)
else:
out = layer_norm(x + dropout(linear_bias + out))
Parameters:
......@@ -247,6 +291,19 @@ def fused_multi_head_attention(x,
0 for no dropout. Default 0.5.
ln_epsilon (float, optional): Small float value added to denominator of layer_norm
to avoid dividing by zero. Default is 1e-5.
training (bool, optional): A flag indicating whether it is in train phrase or not. Default True.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - p )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - p)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The output Tensor, the data type and shape is same as `x`.
......@@ -280,7 +337,16 @@ def fused_multi_head_attention(x,
# [2, 4, 128]
print(output.shape)
"""
seed = None
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer
if in_dygraph_mode():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
# pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
# qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
# linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
......@@ -295,7 +361,12 @@ def fused_multi_head_attention(x,
linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm',
pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate',
dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon',
ln_epsilon)
ln_epsilon, 'attn_dropout_is_test', not training, 'dropout_is_test',
not training, 'attn_dropout_fix_seed', seed is not None,
'dropout_fix_seed', seed is not None, 'attn_dropout_seed', seed
if seed is not None else 0, 'dropout_seed', seed
if seed is not None else 0, 'attn_dropout_implementation', mode,
'dropout_implementation', mode)
return final_out
else:
helper = LayerHelper('fused_multi_head_attention', **locals())
......@@ -323,13 +394,24 @@ def fused_multi_head_attention(x,
if ln_bias:
inputs['Ln2Bias'] = [ln_bias]
if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed
# set attrs
attrs = {
'pre_layer_norm': pre_layer_norm,
'epsilon': pre_ln_epsilon,
'ln_epsilon': ln_epsilon,
'dropout_rate': dropout_rate,
'attn_dropout_rate': attn_dropout_rate
'attn_dropout_rate': attn_dropout_rate,
'attn_dropout_is_test': not training,
'dropout_is_test': not training,
'attn_dropout_fix_seed': seed is not None,
'dropout_fix_seed': seed is not None,
'attn_dropout_seed': seed if seed is not None else 0,
'dropout_seed': seed if seed is not None else 0,
'attn_dropout_implementation': mode,
'dropout_implementation': mode,
}
# set outputs
......
......@@ -54,6 +54,8 @@ class FusedMultiHeadAttention(Layer):
Default: None, which means the default bias parameter property is used.
If it is set to False, this layer will not have trainable bias parameter.
See usage for details in :code:`ParamAttr`.
epsilon (float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05.
Examples:
......@@ -80,6 +82,7 @@ class FusedMultiHeadAttention(Layer):
need_weights=False,
weight_attr=None,
bias_attr=None,
epsilon=1e-5,
name=None):
super(FusedMultiHeadAttention, self).__init__()
......@@ -88,13 +91,18 @@ class FusedMultiHeadAttention(Layer):
assert num_heads > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(num_heads))
attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate
self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self._epsilon = epsilon
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.kdim = kdim
self.vdim = vdim
self.need_weights = need_weights
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
assert need_weights == False, "Only support need_weight is False now."
......@@ -186,15 +194,24 @@ class FusedMultiHeadAttention(Layer):
pre_ln_bias=self.pre_ln_bias,
ln_scale=self.ln_scale,
ln_bias=self.ln_bias,
pre_ln_epsilon=1e-05,
pre_ln_epsilon=self._epsilon,
qkv_bias=self.qkv_bias,
linear_bias=self.linear_bias,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
attn_dropout_rate=self.attn_dropout_rate,
ln_epsilon=1e-05)
ln_epsilon=self._epsilon,
training=self.training,
name=self.name)
return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}{}'.format(
self.embed_dim, self.num_heads, self.dropout_rate,
self.attn_dropout_rate, self._epsilon, self.kdim, self.vdim,
self.normalize_before, self.need_weights, self._dtype, name_str)
class FusedFeedForward(Layer):
"""
......@@ -203,6 +220,8 @@ class FusedFeedForward(Layer):
dim_feedforward (int): The hidden layer size.
dropout_rate (float, optional): The dropout probability used in pre-process
and post-precess. Default 0.1
epsilon (float, optional): he small value added to the variance to prevent
division by zero. Default: 1e-05.
activation (str, optional): The activation function. Default relu.
act_dropout_rate (float, optional): The dropout probability after activition.
If None, use the value of `dropout_rate`. Default None
......@@ -235,11 +254,13 @@ class FusedFeedForward(Layer):
d_model,
dim_feedforward,
dropout_rate=0.1,
epsilon=1e-05,
activation="relu",
act_dropout_rate=None,
normalize_before=False,
weight_attr=None,
bias_attr=None):
bias_attr=None,
name=None):
super(FusedFeedForward, self).__init__()
assert d_model > 0, (
......@@ -256,6 +277,7 @@ class FusedFeedForward(Layer):
self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
self._act_method = activation
self._normalize_before = normalize_before
self._epsilon = epsilon
self._linear1_weight = self.create_parameter(
shape=[d_model, dim_feedforward],
......@@ -292,15 +314,36 @@ class FusedFeedForward(Layer):
default_initializer=Constant(1.0))
self._ln2_bias = self.create_parameter(
shape=[d_model], attr=None, is_bias=True)
self.name = name
def forward(self, src, cache=None):
out = incubate_f.fused_feedforward(
src, self._linear1_weight, self._linear2_weight, self._linear1_bias,
self._linear2_bias, self._ln1_scale, self._ln1_bias,
self._ln2_scale, self._ln2_bias, self._dropout_rate,
self._act_dropout_rate, self._act_method, self._normalize_before)
src,
self._linear1_weight,
self._linear2_weight,
self._linear1_bias,
self._linear2_bias,
self._ln1_scale,
self._ln1_bias,
self._ln2_scale,
self._ln2_bias,
dropout1_rate=self._act_dropout_rate,
dropout2_rate=self._dropout_rate,
activation=self._act_method,
ln1_epsilon=self._epsilon,
ln2_epsilon=self._epsilon,
pre_layer_norm=self._normalize_before,
training=self.training,
name=self.name)
return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}{}'.format(
self._d_model, self._dim_feedforward, self._dropout_rate,
self._epsilon, self._act_method, self._act_dropout_rate,
self._normalize_before, self._dtype, name_str)
class FusedTransformerEncoderLayer(Layer):
"""
......@@ -393,7 +436,9 @@ class FusedTransformerEncoderLayer(Layer):
self.fused_attn = FusedMultiHeadAttention(
d_model,
nhead,
dropout_rate=attn_dropout_rate,
dropout_rate=dropout_rate,
attn_dropout_rate=attn_dropout_rate,
normalize_before=self.normalize_before,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0])
......@@ -401,6 +446,7 @@ class FusedTransformerEncoderLayer(Layer):
d_model,
dim_feedforward,
dropout_rate=dropout_rate,
activation=activation,
act_dropout_rate=act_dropout_rate,
normalize_before=self.normalize_before,
weight_attr=weight_attrs[1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册