Unverified commit 36dd295e authored by zhangkaihuo, committed by GitHub

[cherry-pick-2.2.1]fix fused_transformer_encoder_layer bug (#37229)

Fixes several issues found while fine-tuning fused_transformer_encoder_layer:

    fused_attention_op: add support for attn_mask=None: PR
    fix pre_layer_norm handling: PR
    fix parameter handling and the resulting computation errors: PR
    fix an add_bias computation error: PR
    add pure fp16 support: PR
Parent 79b9f47e
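The user-visible effect of the pre_layer_norm fix is easiest to see as a NumPy sketch of the op's final output, simplified from the DOC comment and the updated unit tests below (QKV/out-linear projections and dropout are omitted; the helper names here are illustrative only, not a Paddle API):

import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def final_output(residual, attn_out, pre_layer_norm):
    # pre_layer_norm == True : final_out = residual + dropout(bias + out)
    # pre_layer_norm == False: final_out = layer_norm(residual + dropout(bias + out))
    out = residual + attn_out          # dropout_prob = 0 for simplicity
    return out if pre_layer_norm else layer_norm(out)

x = np.random.rand(2, 8, 64)
assert final_output(x, x, pre_layer_norm=True).shape == (2, 8, 64)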
@@ -266,6 +266,13 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
         pair.first != "X") {
       continue;
     }
+    if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
+      if (pair.first == "LnScale" || pair.first == "LnBias" ||
+          pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
+          pair.first == "Ln1Scale" || pair.first == "Ln1Bias") {
+        continue;
+      }
+    }
     VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
             << GetDtypeStr(*pair.second.cbegin()) << " to "
             << framework::DataTypeToString(dst_type);
...
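Under pure fp16 (AMP level O2), this hunk keeps the LayerNorm scales and biases of the two fused ops in fp32. A small Python sketch of the same skip rule (the helper name is hypothetical, not a Paddle API):

LN_PARAMS = {"LnScale", "LnBias", "Ln1Scale", "Ln1Bias", "Ln2Scale", "Ln2Bias"}

def should_cast_to_fp16(op_type: str, input_slot: str) -> bool:
    # LayerNorm parameters of the fused ops are left in fp32 under pure fp16.
    if op_type in ("fused_attention", "fused_feedforward") and input_slot in LN_PARAMS:
        return False
    return True

assert not should_cast_to_fp16("fused_attention", "Ln2Scale")
assert should_cast_to_fp16("fused_attention", "QKVW")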
@@ -87,7 +87,8 @@ __global__ void BroadcastKernelBinary(
   kernel_primitives::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
       result, arg0, arg1, func);
   // store
-  kernel_primitives::WriteData<OutT, VecSize, 1, 1>(out + fix, result, num);
+  kernel_primitives::WriteData<OutT, VecSize, 1, 1, true>(out + fix, result,
+                                                          num);
 }
 // bias add forward impl for "[m, n] + [n] = [m, n]"
@@ -267,25 +268,24 @@ __global__ void BiasAddBw1DReduceKernel(const ReduceParamType<T>* temp_sum,
 }
 template <typename T>
-void Launch2DColumnReduce(gpuStream_t stream, const int max_threads,
-                          const int reduce_num, const int left_num,
-                          const T* d_out, T* d_bias) {
+void Launch2DColumnReduce(const platform::CUDADeviceContext& dev_ctx,
+                          const int max_threads, const int reduce_num,
+                          const int left_num, const T* d_out, T* d_bias) {
   dim3 block;
   dim3 grid;
   bool should_reduce_again = false;
   int blocking_size = 1;
   SetConfigForColumnReduce(max_threads, reduce_num, left_num, &blocking_size,
                            &should_reduce_again, &block, &grid);
+  const auto& stream = dev_ctx.stream();
   if (!should_reduce_again) {
     BiasAddBwSinglePassKernel<T><<<grid, block, 0, stream>>>(d_out, reduce_num,
                                                              left_num, d_bias);
   } else {
     framework::Tensor tmp_sum;
-    tmp_sum.mutable_data<ReduceParamType<T>>(
-        framework::make_ddim({static_cast<int64_t>(
-            left_num * grid.y * sizeof(ReduceParamType<T>))}),
-        paddle::platform::CUDAPlace());
+    tmp_sum.Resize({grid.y, left_num});
+    tmp_sum.mutable_data<ReduceParamType<T>>(dev_ctx.GetPlace());
     BiasAddBw2DReduceKernel<T><<<grid, block, 0, stream>>>(
         d_out, reduce_num, left_num, blocking_size,
@@ -311,8 +311,8 @@ void LaunchBiasAddBwKernel(const platform::CUDADeviceContext& dev_ctx, int m,
     Launch1DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num,
                          d_out, d_bias);
   } else {
-    Launch2DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num,
-                         d_out, d_bias);
+    Launch2DColumnReduce(dev_ctx, max_threads, reduce_num, left_num, d_out,
+                         d_bias);
   }
 }
...
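Launch2DColumnReduce now receives the device context and sizes tmp_sum with Resize plus mutable_data instead of a hand-computed byte count. The reduction itself is a column sum of d_out over the row dimension; in NumPy terms:

import numpy as np

# Bias-add backward reduces d_out of shape [m, n] to d_bias of shape [n]
# by summing over rows; the 2D variant only changes how blocks tile the sum.
m, n = 4096, 768
d_out = np.random.rand(m, n).astype(np.float32)
d_bias = d_out.sum(axis=0)
assert d_bias.shape == (n,)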
@@ -11,10 +11,13 @@ limitations under the License. */
 #pragma once
-#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 namespace paddle {
 namespace operators {
@@ -36,8 +39,10 @@ class AttnMatMul {
   ~AttnMatMul() {}
-  void ComputeForward(const T* weight_data, const T* input_data,
-                      const T* bias_data, T* output_data, T* bias_out_data) {
+  void ComputeForward(const framework::Tensor* weight,
+                      const framework::Tensor* input,
+                      const framework::Tensor* bias, framework::Tensor* output,
+                      framework::Tensor* bias_out) {
     // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
     // here: (transa, transb): nt, input * weight.
     CBLAS_TRANSPOSE transA = CblasNoTrans;
@@ -54,16 +59,25 @@ class AttnMatMul {
     // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
     blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
-              input_data, weight_data, beta, output_data);
+              input->data<T>(), weight->data<T>(), beta, output->data<T>());
     if (compute_bias_) {
       // compute output + bias
-      LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
-                            bias_data, bias_out_data);
+      std::vector<const Tensor*> ins;
+      std::vector<Tensor*> outs;
+      ins.emplace_back(output);
+      ins.emplace_back(bias);
+      outs.emplace_back(bias_out);
+      int elewise_add_axis = -1;
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
     }
   }
-  void ComputeBackward(const T* input, const T* weight, const T* d_output,
-                       T* d_input, T* d_weight, T* d_bias) {
+  void ComputeBackward(const framework::Tensor* input,
+                       const framework::Tensor* weight,
+                       const framework::Tensor* d_output,
+                       framework::Tensor* d_input, framework::Tensor* d_weight,
+                       framework::Tensor* d_bias) {
    T alpha = static_cast<T>(1.0);
    T beta = static_cast<T>(0.0);
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
@@ -81,11 +95,11 @@ class AttnMatMul {
     T* dB_input_1_ptr = nullptr;
     T* dB_input_2_ptr = nullptr;
-    T* dB_output_ptr = d_weight;
+    T* dB_output_ptr = d_weight->data<T>();
     T* dA_input_1_ptr = nullptr;
     T* dA_input_2_ptr = nullptr;
-    T* dA_output_ptr = d_input;
+    T* dA_output_ptr = d_input->data<T>();
     if (!transA_) {
       // fw: gemm-nt
@@ -104,10 +118,10 @@ class AttnMatMul {
       dA_n = input_size_;
       dA_k = output_size_;
-      blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
-                input, beta, dB_output_ptr);
-      blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
-                weight, beta, dA_output_ptr);
+      blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+                d_output->data<T>(), input->data<T>(), beta, dB_output_ptr);
+      blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+                d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
     } else {  // fw: gemm-nn
       // bw: gemm-tn, dB = A^t * dC
       dB_transA = CblasTrans;
@@ -123,10 +137,10 @@ class AttnMatMul {
       dA_n = input_size_;
       dA_k = output_size_;
-      blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
-                d_output, beta, dB_output_ptr);
-      blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
-                weight, beta, dA_output_ptr);
+      blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
+                input->data<T>(), d_output->data<T>(), beta, dB_output_ptr);
+      blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
+                d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
     }
   } else if (transB_) {
     PADDLE_THROW(platform::errors::InvalidArgument(
@@ -138,7 +152,27 @@ class AttnMatMul {
         "parameters."));
   }
   if (compute_bias_) {
-    LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
+    // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2}
+    const auto input_dims = d_output->dims();
+    const auto output_dims = d_bias->dims();
+    bool support_case_1 =
+        (input_dims.size() == 5 && output_dims.size() == 3 &&
+         (input_dims[2] == output_dims[0]) &&
+         (input_dims[3] == output_dims[1]) &&
+         (input_dims[4] == output_dims[2]));
+    bool support_case_2 =
+        (input_dims.size() == 3 && output_dims.size() == 1 &&
+         (input_dims[2] == output_dims[0]));
+    if (support_case_1 || support_case_2) {
+      gpuStream_t stream = dev_ctx_.stream();
+      TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
+                                               stream);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only support reduce when the input dims are [0,1,2,3,4] and "
+          "output is [2,3,4]"
+          "or input is [0,1,2] and output is [2]."));
+    }
   }
 }
...
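ComputeBackward now obtains d_bias by a sum-reduction over the first two axes of d_output (TensorReduceFunctorImpl with reduce dims {0, 1}), covering both supported layouts. The equivalent NumPy computation:

import numpy as np

# Case 1: {0, 1, 2, 3, 4} -> {2, 3, 4}, e.g. the QKV bias gradient.
d_out_qkv = np.random.rand(2, 8, 3, 4, 16)      # [batch, seq, 3, num_head, head_dim]
d_bias_qkv = d_out_qkv.sum(axis=(0, 1))         # -> [3, num_head, head_dim]

# Case 2: {0, 1, 2} -> {2}, e.g. the out-linear bias gradient.
d_out_linear = np.random.rand(2, 8, 64)         # [batch, seq, embed_dim]
d_bias_linear = d_out_linear.sum(axis=(0, 1))   # -> [embed_dim]

assert d_bias_qkv.shape == (3, 4, 16) and d_bias_linear.shape == (64,)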
@@ -69,7 +69,7 @@ class FMHARef {
   ~FMHARef() {}
   void ComputeForward(const Tensor& qkv_input_tensor,
-                      const Tensor& src_mask_tensor,
+                      const Tensor* src_mask_tensor,
                       Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor,
                       Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor,
                       Tensor* dropout_mask_out_tensor,
@@ -111,17 +111,17 @@ class FMHARef {
     blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr,
                      k_ptr, beta, qk_out_data, gemm_batch_size, stride_a,
                      stride_b);
-    std::vector<const Tensor*> ins;
-    std::vector<Tensor*> outs;
-    ins.emplace_back(qk_out_tensor);
-    ins.emplace_back(&src_mask_tensor);
-    outs.emplace_back(src_mask_out_tensor);
-    int elewise_add_axis = -1;
     int softmax_axis = -1;
-    if (&src_mask_tensor != nullptr) {
+    if (src_mask_tensor != nullptr) {
+      std::vector<const Tensor*> ins;
+      std::vector<Tensor*> outs;
+      ins.emplace_back(qk_out_tensor);
+      ins.emplace_back(src_mask_tensor);
+      outs.emplace_back(src_mask_out_tensor);
+      int elewise_add_axis = -1;
       LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
           dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
       SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
                                         softmax_axis, softmax_out_tensor);
     } else {
@@ -165,7 +165,7 @@ class FMHARef {
   }
   void ComputeBackward(
-      const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor,
+      const Tensor& transpose_2_out_tensor, const Tensor* src_mask_tensor,
       const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor,
       const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor,
       const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor,
@@ -249,7 +249,7 @@ class FMHARef {
                                          softmax_out_grad_tensor);
     }
-    if (&src_mask_tensor != nullptr) {
+    if (src_mask_tensor != nullptr) {
       SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
                                          *softmax_out_grad_tensor, softmax_axis,
                                          src_mask_out_grad_tensor);
...
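With src_mask passed as a pointer, the elementwise add of the mask and the following softmax on the masked logits only run when a mask is actually supplied; otherwise the softmax is applied to qk_out directly. In NumPy terms:

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def fmha_softmax(qk_out, src_mask=None):
    # src_mask is broadcast-added to qk_out only when it is not None.
    if src_mask is not None:
        qk_out = qk_out + src_mask
    return softmax(qk_out, axis=-1)

qk = np.random.rand(2, 4, 8, 8)                 # [batch, num_head, seq, seq]
assert fmha_softmax(qk).shape == (2, 4, 8, 8)   # attn_mask=None path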
@@ -27,8 +27,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
-                   "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
                    "FusedAttentionOp");
@@ -44,6 +42,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                      "FusedAttentionOp");
       OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
                      "FusedAttentionOp");
+    } else {
+      OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
+                     "FusedAttentionOp");
+      OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
+                     "FusedAttentionOp");
+      OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
+                     "BiasDropoutResidualOut", "FusedAttentionOp");
     }
     // qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
@@ -57,8 +62,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut",
                    "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
-                   "FusedAttentionOp");
+    if (ctx->HasInput("SrcMask")) {
+      OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
+                     "FusedAttentionOp");
+    }
     OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut",
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output",
@@ -69,12 +77,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut",
                    "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
-                   "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
-                   "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
-                   "BiasDropoutResidualOut", "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut",
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp");
@@ -108,6 +111,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
       ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
       ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
+    } else {
+      ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
+      ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
+      ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
     }
     // [batch_size, seq_len, 3, num_head, head_size]
     ctx->SetOutputDim("QKVOut",
@@ -119,7 +126,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                       {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
     // [batch, num_head, seq_len, seq_len]
     ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
-    ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+    if (ctx->HasInput("SrcMask")) {
+      ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+    }
     // the same as QKOut's shape.
     ctx->SetOutputDim("AttnDropoutOut",
                       {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
@@ -134,12 +144,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
     ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));
-    ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
-    ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
     if (ctx->Attrs().Get<bool>("dropout_is_test") == false) {
       ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
     }
-    ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
     ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
   }
@@ -310,25 +318,28 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
     });
     AddComment(R"DOC(
   Add fused attention op whose logic is as follows:
   // @input: [batch_size, seq_len, 3, num_head, head_dim]
   // @final_out: [batch_size, seq_len, num_heads, head_dim]
   if (pre_layernorm)
     out = layer_norm(input);
   out = compute_qkv(out) + bias;
   // fmha module
   {
     out = transpose(out, perm=[2, 0, 3, 1, 4]);
     out = q * k^t;
-    out = attn_mark + out;
+    out = attn_mask + out;
     out = softmax(out);
     out = dropout(out);
     out = out * v;
     out = transpose(out, perm=[0, 2, 1, 3]);
   }
   out = out_linear(out);
-  final_out = layer_norm(residual + dropout(bias + out));
+  if (pre_layernorm)
+    final_out = residual + dropout(bias + out);
+  else
+    final_out = layer_norm(residual + dropout(bias + out));
   )DOC");
   }
 };
@@ -343,20 +354,20 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
                       platform::errors::InvalidArgument(
                           "GradOp is only callable when attn_dropout_is_test is false"));
-    OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
-                   "FusedAttentionGrad");
-    OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
-                   "FusedAttentionGrad");
-    if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
-      ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
-                        ctx->GetInputDim("Ln2Scale"));
-    }
-    if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
-      ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
-                        ctx->GetInputDim("Ln2Bias"));
-    }
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
-    if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
+    if (ctx->Attrs().Get<bool>("pre_layer_norm") == false) {
+      OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
+                     "FusedAttentionGrad");
+      OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
+                     "FusedAttentionGrad");
+      if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
+        ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
+                          ctx->GetInputDim("Ln2Scale"));
+      }
+      if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
+        ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
+                          ctx->GetInputDim("Ln2Bias"));
+      }
+    } else {
       OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
                      "FusedAttentionGrad");
       OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
@@ -364,12 +375,12 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
       OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
                      "FusedAttentionGrad");
     }
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
                    "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
                    "FusedAttentionGrad");
-    OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
-                   "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
                    "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
@@ -400,6 +411,9 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
     if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
       ctx->SetOutputDim(framework::GradVarName("LnOut"),
                         ctx->GetInputDim("LnOut"));
+    } else {
+      ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
+                        ctx->GetInputDim("BiasDropoutResidualOut"));
     }
     ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
                       ctx->GetInputDim("FMHAOut"));
@@ -413,16 +427,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
                       ctx->GetInputDim("SoftmaxOut"));
     ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
                       ctx->GetInputDim("AttnDropoutOut"));
-    ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
-                      ctx->GetInputDim("SrcMaskOut"));
+    if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) {
+      ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
+                        ctx->GetInputDim("SrcMaskOut"));
+    }
     ctx->SetOutputDim(framework::GradVarName("QKVOut"),
                       ctx->GetInputDim("QKVOut"));
     ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
                       ctx->GetInputDim("QKVBiasOut"));
     ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
                       ctx->GetInputDim("OutLinearOut"));
-    ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
-                      ctx->GetInputDim("BiasDropoutResidualOut"));
   }

 protected:
@@ -448,7 +463,14 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("X", this->Input("X"));
     op->SetInput("QKVW", this->Input("QKVW"));
     op->SetInput("QKVBias", this->Input("QKVBias"));
-    op->SetInput("SrcMask", this->Input("SrcMask"));
+    if (this->HasInput("SrcMask")) {
+      op->SetInput("SrcMask", this->Input("SrcMask"));
+      op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
+      op->SetOutput(framework::GradVarName("SrcMaskOut"),
+                    this->OutputGrad("SrcMaskOut"));
+    }
     op->SetInput("OutLinearW", this->Input("OutLinearW"));
     op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
@@ -466,17 +488,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
         op->SetOutput(framework::GradVarName("LnBias"),
                       this->InputGrad("LnBias"));
       }
-    }
-    if (this->HasInput("Ln2Scale")) {
-      op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
-      op->SetOutput(framework::GradVarName("Ln2Scale"),
-                    this->InputGrad("Ln2Scale"));
-    }
-    if (this->HasInput("Ln2Bias")) {
-      op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
-      op->SetOutput(framework::GradVarName("Ln2Bias"),
-                    this->InputGrad("Ln2Bias"));
-    }
+    } else {
+      if (this->HasInput("Ln2Scale")) {
+        op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
+        op->SetOutput(framework::GradVarName("Ln2Scale"),
+                      this->InputGrad("Ln2Scale"));
+      }
+      if (this->HasInput("Ln2Bias")) {
+        op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
+        op->SetOutput(framework::GradVarName("Ln2Bias"),
+                      this->InputGrad("Ln2Bias"));
+      }
+    }
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
@@ -499,6 +521,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
       if (this->HasOutput("LnVariance")) {
         op->SetInput("LnVariance", this->Output("LnVariance"));
       }
+    } else {
+      op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
+      op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
+      op->SetInput("BiasDropoutResidualOut",
+                   this->Output("BiasDropoutResidualOut"));
     }
     op->SetInput("QKVOut", this->Output("QKVOut"));
     op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
@@ -508,15 +535,10 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
     op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
     op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
-    op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
     op->SetInput("FMHAOut", this->Output("FMHAOut"));
     op->SetInput("OutLinearOut", this->Output("OutLinearOut"));
-    op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
-    op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
     op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
-    op->SetInput("BiasDropoutResidualOut",
-                 this->Output("BiasDropoutResidualOut"));
     op->SetInput("QKVOut", this->Output("QKVOut"));
     // backward outputs: dinput
@@ -525,7 +547,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
       op->SetOutput(framework::GradVarName("LnOut"),
                     this->OutputGrad("LnOut"));
       }
+    } else {
+      op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
+                    this->OutputGrad("BiasDropoutResidualOut"));
     }
     op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
     op->SetOutput(framework::GradVarName("QKVBiasOut"),
                   this->OutputGrad("QKVBiasOut"));
@@ -538,12 +564,9 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
                   this->OutputGrad("SoftmaxOut"));
     op->SetOutput(framework::GradVarName("AttnDropoutOut"),
                   this->OutputGrad("AttnDropoutOut"));
-    op->SetOutput(framework::GradVarName("SrcMaskOut"),
-                  this->OutputGrad("SrcMaskOut"));
     op->SetOutput(framework::GradVarName("FMHAOut"),
                   this->OutputGrad("FMHAOut"));
-    op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
-                  this->OutputGrad("BiasDropoutResidualOut"));
     op->SetOutput(framework::GradVarName("OutLinearOut"),
                   this->OutputGrad("OutLinearOut"));
   }
...
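At the Python API level, the changes above are what let the fused attention layer run with no attention mask and with pre-layer-norm. A usage sketch, assuming the paddle.incubate.nn.FusedMultiHeadAttention layer shipped in this release (exact argument names may differ; the fused ops require a CUDA build of Paddle):

import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

paddle.set_device("gpu")  # the fused ops only have CUDA kernels
attn = FusedMultiHeadAttention(embed_dim=256, num_heads=4,
                               normalize_before=True)   # pre_layer_norm path
x = paddle.rand([2, 8, 256], dtype="float32")
out = attn(x)              # attn_mask=None is now accepted
print(out.shape)           # [2, 8, 256]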
@@ -95,15 +95,6 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     const auto qkv_w_dims = qkv_weight->dims();
     auto *x_data = input_x->data<T>();
-    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
-    auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
-    auto *ln_mean_data =
-        pre_layer_norm ? ln_mean->mutable_data<U>(ctx.GetPlace()) : nullptr;
-    auto *ln_var_data =
-        pre_layer_norm ? ln_var->mutable_data<U>(ctx.GetPlace()) : nullptr;
-    auto *ln_out_data =
-        pre_layer_norm ? ln_out->mutable_data<T>(ctx.GetPlace()) : nullptr;
     auto *qkv_weight_data = qkv_weight->data<T>();
     auto *qkv_bias_data = qkv_bias->data<T>();
     auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
@@ -114,7 +105,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
         transpose_out_2->mutable_data<T>(ctx.GetPlace());
     auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
     auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
-    auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
+    auto *src_mask_out_data =
+        (src_mask == nullptr) ? nullptr
+                              : src_mask_out->mutable_data<T>(ctx.GetPlace());
     auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
     auto *attn_dropout_mask_out_data =
         attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
@@ -128,16 +121,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());
     // get data ptr for bias+dropout+residual+layernorm
-    auto *ln_scale_2_data =
-        (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
-    auto *ln_bias_2_data =
-        (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
     auto *dropout_mask_out_data =
         dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
-    auto *bias_dropout_residual_out_data =
-        bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
-    auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
-    auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
     auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());
     int batch_size = input_x_dims[0];
@@ -176,29 +161,52 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
         ln_epsilon);
     if (pre_layer_norm) {
+      auto *ln_scale_data =
+          (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
+      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
+      auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
+      auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
+      auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());
       layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
                                         ln_out_data, ln_mean_data, ln_var_data);
-      qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
-                                 qkv_out_data, qkv_bias_out_data);
+      qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out,
+                                 qkv_bias_out);
     } else {
-      qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
-                                 qkv_out_data, qkv_bias_out_data);
+      qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
+                                 qkv_bias_out);
     }
-    fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
+    fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
                                     qk_out, src_mask_out, softmax_out,
                                     attn_dropout_mask_out, attn_dropout_out,
                                     qktv_out, fmha_out);
     // fmha_out: [batch_size, seq_len, num_head, head_dim]
     // weight: [embed_dim, embed_dim]
     // out_linear_out: [batch_size, seq_len, embed_dim]
-    out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
-                                      nullptr, out_linear_out_data, nullptr);
+    out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
+                                      out_linear_out, nullptr);
-    // output = layernorm(residual + dropout(input + bias))
-    fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
-        ctx.cuda_device_context(), out_linear_out_data, x_data,
-        out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
-        bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
-        ln_mean_2_data, ln_var_2_data);
+    if (pre_layer_norm) {
+      // output = (residual + dropout(input + bias))
+      fused_dropout_layernorm_helper.ResidualDropoutBias(
+          ctx.cuda_device_context(), out_linear_out_data, x_data,
+          out_linear_bias_data, final_out_data, dropout_mask_out_data);
+    } else {
+      auto *ln_scale_2_data =
+          (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
+      auto *ln_bias_2_data =
+          (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
+      auto *bias_dropout_residual_out_data =
+          bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
+      auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
+      auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
+      // output = layernorm(residual + dropout(input + bias))
+      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
+          ctx.cuda_device_context(), out_linear_out_data, x_data,
+          out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
+          bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
+          ln_mean_2_data, ln_var_2_data);
+    }
   }
 };
@@ -265,12 +273,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *qk_out_data = qk_out->data<T>();
     auto *qktv_out_data = qktv_out->data<T>();
     auto *softmax_out_data = softmax_out->data<T>();
-    auto *src_mask_out_data = src_mask_out->data<T>();
+    auto *src_mask_out_data =
+        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
     auto *out_linear_out_data = out_linear_out->data<T>();
-    auto *ln_2_mean_data = ln_2_mean->data<U>();
-    auto *ln_2_var_data = ln_2_var->data<U>();
     auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
-    auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data<T>();
     // output's grad
     auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -302,12 +308,12 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
     auto *d_attn_dropout_out_data =
         d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_src_mask_out_data = d_src_mask_out->mutable_data<T>(ctx.GetPlace());
+    auto *d_src_mask_out_data =
+        (src_mask == nullptr) ? nullptr
+                              : d_src_mask_out->mutable_data<T>(ctx.GetPlace());
     auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
     auto *d_out_linear_out_data =
         d_out_linear_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_bias_dropout_residual_out_data =
-        d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
     // parameter grad
     auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
@@ -325,12 +331,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
     auto *d_out_linear_bias_data =
         d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
-    auto *d_ln_2_scale_data =
-        (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
-                                                 ctx.GetPlace()));
-    auto *d_ln_2_bias_data =
-        (d_ln_2_bias == nullptr ? nullptr
-                                : d_ln_2_bias->mutable_data<U>(ctx.GetPlace()));
     const auto input_x_dims = input_x->dims();
     const auto qkv_w_dims = qkv_weight->dims();
@@ -376,17 +376,37 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
         ln2epsilon);
-    fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
-        ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
-        dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
-        d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
-        d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
+    if (pre_layer_norm) {
+      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
+          ctx.cuda_device_context(), d_y_data, dropout_mask_out_data,
+          d_out_linear_out_data, d_residual_data, d_out_linear_bias_data);
+    } else {
+      auto *ln_2_mean_data = ln_2_mean->data<U>();
+      auto *ln_2_var_data = ln_2_var->data<U>();
+      auto *bias_dropout_residual_out_data =
+          bias_dropout_residual_out->data<T>();
+      auto *d_ln_2_scale_data =
+          (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
+                                                   ctx.GetPlace()));
+      auto *d_ln_2_bias_data =
+          (d_ln_2_bias == nullptr ? nullptr : d_ln_2_bias->mutable_data<U>(
+                                                  ctx.GetPlace()));
+      auto *d_bias_dropout_residual_out_data =
+          d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
+      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
+          ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
+          dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
+          d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
+          d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
+    }
-    out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
-                                       d_out_linear_out_data, d_fmha_out_data,
-                                       d_out_linear_weight_data, nullptr);
+    out_linear_compute.ComputeBackward(fmha_out, out_linear_weight,
+                                       d_out_linear_out, d_fmha_out,
+                                       d_out_linear_weight, nullptr);
     fmha_ref_compute.ComputeBackward(
-        *transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out,
+        *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
         *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
         d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
         d_transpose_out_2, nullptr, d_qkv_bias_out);
@@ -413,15 +433,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         (d_ln_bias == nullptr ? nullptr
                               : d_ln_bias->mutable_data<U>(ctx.GetPlace()));
-      qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
-                                  d_qkv_bias_out_data, d_ln_out_data,
-                                  d_qkv_weight_data, d_qkv_bias_data);
+      qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out,
+                                  d_qkv_weight, d_qkv_bias);
       layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
                                          ln_mean_data, ln_var_data, d_x_data,
                                          d_ln_scale_data, d_ln_bias_data);
     } else {
-      qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data,
-                                  d_x_data, d_qkv_weight_data, d_qkv_bias_data);
+      qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
+                                  d_qkv_weight, d_qkv_bias);
     }
     // gradient accumulation
     std::vector<const Tensor *> ins;
...
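In the pre_layer_norm backward path the helper only has to undo residual + dropout(x + bias). A NumPy sketch of the gradients that step produces (assuming upscale-in-train dropout; a simplification, not the CUDA kernel):

import numpy as np

def residual_dropout_bias_grad(d_y, dropout_mask, dropout_prob):
    # Forward was: y = residual + dropout(x + bias)
    d_x = d_y * dropout_mask / (1.0 - dropout_prob)       # grad of out_linear output
    d_residual = d_y                                      # grad of the residual input
    d_bias = d_x.reshape(-1, d_x.shape[-1]).sum(axis=0)   # grad of out_linear bias
    return d_x, d_residual, d_bias

d_y = np.random.rand(2, 8, 64).astype(np.float32)
mask = (np.random.rand(2, 8, 64) > 0.1).astype(np.float32)
d_x, d_res, d_b = residual_dropout_bias_grad(d_y, mask, 0.1)
assert d_b.shape == (64,)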
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/matmul_v2_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/fused/fused_dropout_helper.h"
 #include "paddle/fluid/operators/layer_norm_kernel.cu.h"
@@ -261,7 +262,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
     d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
     d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
-    d_residual.mutable_data<T>({bsz_seq, d_model}, place);
+    d_residual.mutable_data<T>(d_x->dims(), place);
     if (pre_layer_norm) {
       fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
@@ -301,6 +302,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     } else {
       MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
     }
+    std::vector<const Tensor*> ins(2);
+    std::vector<Tensor*> outs(1);
+    ins[0] = &d_residual;
+    ins[1] = d_x;
+    outs[0] = d_x;
+    int elewise_add_axis = -1;
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        ctx, ins, &outs, elewise_add_axis, AddFunctor<T>());
   }

   void Compute(const framework::ExecutionContext& context) const override {
...
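The added block folds the residual-path gradient into d_x; the element-wise kernel is doing nothing more than an in-place add, which in NumPy terms is:

import numpy as np

d_x = np.random.rand(16, 512).astype(np.float32)
d_residual = np.random.rand(16, 512).astype(np.float32)
d_x += d_residual   # ins = {d_residual, d_x}, outs = {d_x}, broadcast axis = -1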
@@ -94,6 +94,7 @@ if(NOT WITH_GPU)
     LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
     LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
     LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
+    LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
 endif()

 if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
...
@@ -26,6 +26,9 @@ from paddle import tensor
 from paddle.fluid import layers
 import unittest
 from op_test import OpTest
+from paddle.fluid.framework import default_main_program
+default_main_program().random_seed = 42

 class TestFusedAttentionOp(OpTest):
@@ -66,6 +69,7 @@ class TestFusedAttentionOp(OpTest):
         self.x_type = np.float32
         self.attn_mask_type = np.float64
         self.pre_layer_norm = False
+        self.has_attn_mask = True
         self.training = True

         self.batch_size = 8
@@ -84,16 +88,20 @@ class TestFusedAttentionOp(OpTest):
     def generate_input_data(self):
         self.query = np.random.rand(self.batch_size, self.query_length,
                                     self.embed_dim).astype(self.x_type)
-        self.attn_mask = np.ones(
-            (self.batch_size, self.num_heads, self.query_length,
-             self.key_length),
-            dtype=self.attn_mask_type)
-        if self.attn_mask_type == np.int64:
-            self.attn_mask = np.tril(self.attn_mask)
-        elif self.attn_mask_type == np.float64:
-            self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
+        if self.has_attn_mask:
+            self.attn_mask = np.ones(
+                (self.batch_size, self.num_heads, self.query_length,
+                 self.key_length),
+                dtype=self.attn_mask_type)
+            if self.attn_mask_type == np.int64:
+                self.attn_mask = np.tril(self.attn_mask)
+            elif self.attn_mask_type == np.float64:
+                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
+            else:
+                raise ValueError(
+                    "'attn_mask_type' should be 'int64' or 'float64'.")
         else:
-            raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
+            self.attn_mask = None
         self.key, self.value = self.query, self.query

         self.dout = np.random.random((self.batch_size, self.query_length,
@@ -102,7 +110,10 @@ class TestFusedAttentionOp(OpTest):
     def GetBaselineOut(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
         tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
-        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        if self.has_attn_mask:
+            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        else:
+            attn_mask = None
         residual = tensor_query
         ln1_out = tensor_query
@@ -147,8 +158,8 @@ class TestFusedAttentionOp(OpTest):
         residual_out = residual + self.dropout(out)
         if not self.pre_layer_norm:
             final_out = self.norm1(residual_out)
-        if self.pre_layer_norm:
-            final_out = self.norm2(residual_out)
+        else:
+            final_out = residual_out
         paddle.autograd.backward(
             [final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
         return final_out, tensor_query.grad
@@ -187,7 +198,10 @@ class TestFusedAttentionOp(OpTest):
         qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
         x = paddle.to_tensor(self.query, stop_gradient=False)
-        attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        if self.has_attn_mask:
+            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        else:
+            attn_mask = None
         qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
         qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
         epsilon = 1e-05
@@ -208,9 +222,9 @@ class TestFusedAttentionOp(OpTest):
         final_out_ref, x_grad_ref = self.GetBaselineOut()
         final_out, x_grad = self.GetFusedAttentionOut()
         np.testing.assert_allclose(
-            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5)
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
         np.testing.assert_allclose(
-            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)

 class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
@@ -218,6 +232,7 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
         self.x_type = np.float32
         self.attn_mask_type = np.float64
         self.pre_layer_norm = True
+        self.has_attn_mask = True
         self.training = True

         self.batch_size = 8
@@ -237,9 +252,39 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
         final_out_ref, x_grad_ref = self.GetBaselineOut()
         final_out, x_grad = self.GetFusedAttentionOut()
         np.testing.assert_allclose(
-            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
         np.testing.assert_allclose(
-            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
+
+
+class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
+    def config(self):
+        self.x_type = np.float32
+        self.attn_mask_type = np.float64
+        self.pre_layer_norm = True
+        self.has_attn_mask = False
+        self.training = True
+
+        self.batch_size = 8
+        self.query_length = 128
+        self.head_dim = 64
+        self.num_heads = 16
+        self.embed_dim = self.head_dim * self.num_heads
+
+        self.dropout_prob = 0.0
+        self.attn_dropout_prob = 0.0
+        self.weight_attr = None
+        self.bias_attr = None
+        self.kdim, self.vdim = self.embed_dim, self.embed_dim
+        self.key_length, self.value_length = self.query_length, self.query_length
+
+    def test_fused_attention_op(self):
+        final_out_ref, x_grad_ref = self.GetBaselineOut()
+        final_out, x_grad = self.GetFusedAttentionOut()
+        np.testing.assert_allclose(
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
+        np.testing.assert_allclose(
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)

 class TestFusedAttentionOpFp16(TestFusedAttentionOp):
@@ -247,6 +292,7 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
         self.x_type = np.float16
         self.attn_mask_type = np.float64
         self.pre_layer_norm = False
+        self.has_attn_mask = True
         self.training = True

         self.batch_size = 8
...
@@ -89,27 +89,32 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
     qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
                                     qkv_weight.shape[2] * qkv_weight.shape[3])
+    qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
+                                qkv_bias.shape[2])
     if (pre_layer_norm):
         ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
         qkv = fc(ln_out, qkv_weight)
+        qkv_bias_out = qkv + qkv_bias
         ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
     else:
         query = query.reshape(batch_size * seq_len, embed_dim)
         qkv = fc(query, qkv_weight)
+        qkv_bias_out = qkv + qkv_bias
         query = query.reshape(batch_size, seq_len, embed_dim)
-    qkv = qkv.reshape(batch_size, seq_len, 3, num_head, head_dim)
+    qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head,
+                                        head_dim)
     # q*k^t
-    qkv = qkv.transpose(
+    qkv_bias_out = qkv_bias_out.transpose(
         (2, 0, 1, 3, 4))  # 3, batch_size, seq_len, num_head, head_dim
-    qkv = qkv.transpose(
+    qkv_bias_out = qkv_bias_out.transpose(
         (0, 1, 3, 2, 4))  # 3, batch_size, num_head, seq_len, head_dim
-    q = qkv[0:1, ::]
+    q = qkv_bias_out[0:1, ::]
     q = q.reshape(batch_size, num_head, seq_len, head_dim)
-    k = qkv[1:2, ::]  #[1, batch_size, num_head, seq_len, head_dim]
+    k = qkv_bias_out[1:2, ::]  #[1, batch_size, num_head, seq_len, head_dim]
     k = k.reshape(batch_size, num_head, seq_len, head_dim)
-    v = qkv[2::]
+    v = qkv_bias_out[2::]
     v = v.reshape(batch_size, num_head, seq_len, head_dim)

     k = k.transpose([0, 1, 3, 2])  #[batch_size, num_head, head_dim, seq_len]
...@@ -138,9 +143,11 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, ...@@ -138,9 +143,11 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
out_linear_bias_out = out_linear_out + out_linear_bias out_linear_bias_out = out_linear_out + out_linear_bias
out_linear_bias_dropout_out = out_linear_bias_out out_linear_bias_dropout_out = out_linear_bias_out
out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out
out_linear_bias_dropout_residual_ln_out = layer_norm( if not pre_layer_norm:
out_linear_bias_dropout_residual_out, True, True, ln_2_scale, ln_2_bias) out_linear_bias_dropout_residual_out = layer_norm(
return out_linear_bias_dropout_residual_ln_out out_linear_bias_dropout_residual_out, True, True, ln_2_scale,
ln_2_bias)
return out_linear_bias_dropout_residual_out
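In plain terms, the reference change above does two things: it adds the QKV bias before the head split, and it applies the second layer_norm only in the post-norm configuration. A self-contained NumPy sketch of the updated post-processing (helper and variable names here are illustrative, not the test's own):

```python
import numpy as np

def layer_norm_np(x, scale, bias, eps=1e-5):
    # Normalize over the last axis, then apply the affine transform.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * scale + bias

def reference_postprocess(query, out_linear_out, linear_bias,
                          ln_2_scale, ln_2_bias, pre_layer_norm):
    # dropout_prob is 0.0 in these tests, so dropout is the identity here.
    residual_out = query + (out_linear_out + linear_bias)
    if not pre_layer_norm:
        # Post-norm: the second layer_norm runs after the residual add.
        residual_out = layer_norm_np(residual_out, ln_2_scale, ln_2_bias)
    # Pre-norm: the residual output is returned without a trailing layer_norm.
    return residual_out
```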
class TestFusedAttentionAPI(unittest.TestCase): class TestFusedAttentionAPI(unittest.TestCase):
...@@ -152,6 +159,7 @@ class TestFusedAttentionAPI(unittest.TestCase): ...@@ -152,6 +159,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.x_type = np.float32 self.x_type = np.float32
self.attn_mask_type = np.float64 self.attn_mask_type = np.float64
self.pre_layer_norm = True self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True self.training = True
self.need_weight = False self.need_weight = False
...@@ -172,27 +180,37 @@ class TestFusedAttentionAPI(unittest.TestCase): ...@@ -172,27 +180,37 @@ class TestFusedAttentionAPI(unittest.TestCase):
def generate_input_data(self): def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length, self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type) self.embed_dim).astype(self.x_type)
self.attn_mask = np.ones( if self.has_attn_mask:
(self.batch_size, self.num_heads, self.query_length, self.attn_mask = np.ones(
self.key_length), (self.batch_size, self.num_heads, self.query_length,
dtype=self.attn_mask_type) self.key_length),
if self.attn_mask_type == np.int64: dtype=self.attn_mask_type)
self.attn_mask = np.tril(self.attn_mask) if self.attn_mask_type == np.int64:
elif self.attn_mask_type == np.float64: self.attn_mask = np.tril(self.attn_mask)
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
else:
raise ValueError(
"'attn_mask_type' should be 'int64' or 'float64'.")
else: else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") self.attn_mask = None
self.key, self.value = self.query, self.query self.key, self.value = self.query, self.query
def run_imperative(self): def run_imperative(self):
if self.has_attn_mask:
attn_mask_tensor = paddle.to_tensor(self.attn_mask)
else:
attn_mask_tensor = None
fused_attn = FusedMultiHeadAttention( fused_attn = FusedMultiHeadAttention(
self.embed_dim, self.num_heads, self.dropout_prob, self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm, self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr) self.need_weight, self.weight_attr, self.bias_attr)
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32')
fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
out = fused_attn( out = fused_attn(
paddle.to_tensor(self.query), paddle.to_tensor(self.query),
paddle.to_tensor(self.query), paddle.to_tensor(self.query),
paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask)) paddle.to_tensor(self.query), attn_mask_tensor)
ref_out = compute_reference(self.pre_layer_norm, self.query, ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask, self.attn_mask,
fused_attn.pre_ln_scale.numpy(), fused_attn.pre_ln_scale.numpy(),
...@@ -203,7 +221,7 @@ class TestFusedAttentionAPI(unittest.TestCase): ...@@ -203,7 +221,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
fused_attn.qkv_bias.numpy(), fused_attn.qkv_bias.numpy(),
fused_attn.linear_weight.numpy(), fused_attn.linear_weight.numpy(),
fused_attn.linear_bias.numpy()) fused_attn.linear_bias.numpy())
self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5)) np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)
def run_static(self): def run_static(self):
fused_attn = FusedMultiHeadAttention( fused_attn = FusedMultiHeadAttention(
...@@ -215,29 +233,42 @@ class TestFusedAttentionAPI(unittest.TestCase): ...@@ -215,29 +233,42 @@ class TestFusedAttentionAPI(unittest.TestCase):
name='X', name='X',
shape=[self.batch_size, self.query_length, self.embed_dim], shape=[self.batch_size, self.query_length, self.embed_dim],
dtype=self.x_type) dtype=self.x_type)
attn_mask = paddle.static.data( if self.has_attn_mask:
name='SrcMask', attn_mask = paddle.static.data(
shape=[ name='SrcMask',
self.batch_size, self.num_heads, self.query_length, shape=[
self.key_length self.batch_size, self.num_heads, self.query_length,
], self.key_length
dtype=self.attn_mask_type) ],
final_out = fused_attn(x, x, x, attn_mask) dtype=self.attn_mask_type)
final_out = fused_attn(x, x, x, attn_mask)
else:
final_out = fused_attn(x, x, x)
place = paddle.CUDAPlace(0) place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( if self.has_attn_mask:
paddle.static.default_main_program(), out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
feed={"X": self.query, paddle.static.default_main_program(),
"SrcMask": self.attn_mask}, feed={"X": self.query,
fetch_list=[ "SrcMask": self.attn_mask},
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, fetch_list=[
fused_attn.linear_weight, fused_attn.linear_bias, final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.ln_scale, fused_attn.ln_bias fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
]) fused_attn.ln_scale, fused_attn.ln_bias
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias
def test_static_api(self): def test_static_api(self):
...@@ -249,14 +280,36 @@ class TestFusedAttentionAPI(unittest.TestCase): ...@@ -249,14 +280,36 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.attn_mask, ln_scale, ln_bias, self.attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias, ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
linear_weight, linear_bias) linear_weight, linear_bias)
self.assertTrue( np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)
np.allclose(
np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5))
def test_dynamic_api(self): def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0)) paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_imperative() self.run_imperative()
class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = False
self.training = True
self.need_weight = False
self.batch_size = 1
self.query_length = 2
self.head_dim = 2
self.num_heads = 2
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
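The headline change exercised by the new None-mask test cases is that the fused layer now accepts attn_mask=None. A minimal dygraph sketch of that call, assuming a CUDA-enabled Paddle build and the forward signature shown in these tests (shapes are illustrative):

```python
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

# Dropout rates set to 0.0 so the call is deterministic for this demo.
x = paddle.rand([2, 4, 64], dtype='float32')
attn = FusedMultiHeadAttention(
    embed_dim=64, num_heads=8, dropout_rate=0.0, attn_dropout_rate=0.0)

# attn_mask=None is now accepted by the fused path as well.
out = attn(x, x, x, attn_mask=None)
print(out.shape)  # [2, 4, 64]
```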
...@@ -23,6 +23,7 @@ from paddle.nn.layer.norm import LayerNorm ...@@ -23,6 +23,7 @@ from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout from paddle.nn.layer.common import Linear, Dropout
import unittest import unittest
from op_test import OpTest from op_test import OpTest
from paddle.fluid.framework import default_main_program
class TestFusedFFNOp(OpTest): class TestFusedFFNOp(OpTest):
...@@ -91,7 +92,7 @@ class TestFusedFFNOp(OpTest): ...@@ -91,7 +92,7 @@ class TestFusedFFNOp(OpTest):
def Base(self): def Base(self):
paddle.disable_static() paddle.disable_static()
tensor_src = paddle.to_tensor(self.src, stop_gradient=False) tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
residual = paddle.to_tensor(self.src) residual = tensor_src
if self.pre_layer_norm: if self.pre_layer_norm:
ln1_out = self.norm1(tensor_src) ln1_out = self.norm1(tensor_src)
linear2_out = self.linear2( linear2_out = self.linear2(
...@@ -140,6 +141,7 @@ class TestFusedFFNOp(OpTest): ...@@ -140,6 +141,7 @@ class TestFusedFFNOp(OpTest):
return out, x.grad return out, x.grad
def test_out_and_grad(self): def test_out_and_grad(self):
default_main_program().random_seed = 42
base_out, base_grad = self.Base() base_out, base_grad = self.Base()
fused_out, fused_grad = self.FusedFFN() fused_out, fused_grad = self.FusedFFN()
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -192,6 +194,7 @@ class TestFusedFFNOpNormalizeBefore(TestFusedFFNOp): ...@@ -192,6 +194,7 @@ class TestFusedFFNOpNormalizeBefore(TestFusedFFNOp):
class APITestStaticFusedFFN(unittest.TestCase): class APITestStaticFusedFFN(unittest.TestCase):
def test_static(self): def test_static(self):
paddle.enable_static() paddle.enable_static()
default_main_program().random_seed = 42
dtype = "float32" dtype = "float32"
layer_norm_dtype = "float32" layer_norm_dtype = "float32"
batch_size = 1 batch_size = 1
...@@ -324,6 +327,18 @@ class TestFusedFFNOpError(unittest.TestCase): ...@@ -324,6 +327,18 @@ class TestFusedFFNOpError(unittest.TestCase):
self.assertRaises(ValueError, test_dropout_rate_value) self.assertRaises(ValueError, test_dropout_rate_value)
def test_dropout_mode():
x = paddle.static.data(
name='x3', shape=[1, 10, 10], dtype="float32")
linear1_weight = paddle.static.data(
name='linear1_weight3', shape=[10, 10], dtype="float32")
linear2_weight = paddle.static.data(
name='linear2_weight3', shape=[10, 10], dtype="float32")
incubate_f.fused_feedforward(
x, linear1_weight, linear2_weight, mode='test')
self.assertRaises(ValueError, test_dropout_mode)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
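For context, the functional API these tests exercise is called as below; a minimal dygraph sketch with the two new keyword arguments, assuming a CUDA-enabled Paddle build (shapes are illustrative and the optional layer-norm parameters are left at their defaults):

```python
import paddle
import paddle.incubate.nn.functional as incubate_f

# Illustrative shapes: batch 1, sequence 8, d_model 16, dim_feedforward 32.
x = paddle.rand([1, 8, 16], dtype='float32')
linear1_weight = paddle.rand([16, 32], dtype='float32')
linear2_weight = paddle.rand([32, 16], dtype='float32')

# `training` toggles the dropout is_test attributes; `mode` selects the
# dropout scaling scheme ('upscale_in_train' or 'downscale_in_infer').
out = incubate_f.fused_feedforward(
    x, linear1_weight, linear2_weight,
    dropout1_rate=0.0,
    dropout2_rate=0.0,
    training=False,
    mode='upscale_in_train')
print(out.shape)  # [1, 8, 16]
```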
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle.incubate.nn import FusedTransformerEncoderLayer
from paddle.nn import TransformerEncoderLayer
from paddle.fluid.framework import default_main_program
import unittest
class TestFusedTransformerEncoderLayer(unittest.TestCase):
def setActivation(self):
self.activation = 'gelu'
def setPreLayerNorm(self):
self.pre_layer_norm = False
def setAttnMask(self):
self.has_attn_mask = True
def setUp(self):
self.batch_size = np.random.randint(1, 8)
self.query_length = np.random.randint(1, 128)
self.nhead = 16
self.head_dim = 4
self.num_heads = self.nhead
self.d_model = self.head_dim * self.num_heads
self.embed_dim = self.d_model
self.dim_feedforward = np.random.randint(1, 32)
self.dropout_rate = 0
self.attn_dropout_rate = None
self.act_dropout_rate = None
self.attn_mask_type = np.float64
self.key_length = self.query_length
self.dtype = 'float32'
self.setActivation()
self.setPreLayerNorm()
self.setAttnMask()
def fused_weight(self, weight, num_head):
a = paddle.transpose(weight, perm=[1, 0])
return paddle.reshape(
a, shape=[1, num_head, int(a.shape[0] / num_head), a.shape[1]])
def fused_qkv(self, q, k, v, num_head):
fq = self.fused_weight(q, num_head)
fk = self.fused_weight(k, num_head)
fv = self.fused_weight(v, num_head)
return paddle.concat(x=[fq, fk, fv], axis=0)
def test_out(self):
default_main_program().random_seed = 42
base_encoder = TransformerEncoderLayer(
self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate,
self.activation, self.attn_dropout_rate, self.act_dropout_rate,
self.pre_layer_norm)
src = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.dtype)
if self.has_attn_mask:
attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
dtype=self.attn_mask_type)
attn_mask_tensor = paddle.to_tensor(attn_mask)
else:
attn_mask = None
attn_mask_tensor = None
dout = np.random.random(src.shape).astype(self.dtype)
base_out = base_encoder(
paddle.to_tensor(
src, stop_gradient=False), attn_mask_tensor)
paddle.autograd.backward([base_out], [paddle.to_tensor(dout)], True)
fused_encoder = FusedTransformerEncoderLayer(
self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate,
self.activation, self.attn_dropout_rate, self.act_dropout_rate,
self.pre_layer_norm)
fused_encoder.ffn._linear1_weight.set_value(base_encoder.linear1.weight)
fused_encoder.ffn._linear1_bias.set_value(base_encoder.linear1.bias)
fused_encoder.ffn._linear2_weight.set_value(base_encoder.linear2.weight)
fused_encoder.ffn._linear2_bias.set_value(base_encoder.linear2.bias)
if self.pre_layer_norm:
fused_encoder.ffn._ln1_scale.set_value(base_encoder.norm2.weight)
fused_encoder.ffn._ln1_bias.set_value(base_encoder.norm2.bias)
else:
fused_encoder.ffn._ln2_scale.set_value(base_encoder.norm2.weight)
fused_encoder.ffn._ln2_bias.set_value(base_encoder.norm2.bias)
fused_encoder.fused_attn.linear_weight.set_value(
base_encoder.self_attn.out_proj.weight)
fused_encoder.fused_attn.linear_bias.set_value(
base_encoder.self_attn.out_proj.bias)
if self.pre_layer_norm:
fused_encoder.fused_attn.pre_ln_scale.set_value(
base_encoder.norm1.weight)
fused_encoder.fused_attn.pre_ln_bias.set_value(
base_encoder.norm1.bias)
else:
fused_encoder.fused_attn.ln_scale.set_value(
base_encoder.norm1.weight)
fused_encoder.fused_attn.ln_bias.set_value(base_encoder.norm1.bias)
q = base_encoder.self_attn.q_proj.weight
q_bias = base_encoder.self_attn.q_proj.bias
k = base_encoder.self_attn.k_proj.weight
k_bias = base_encoder.self_attn.k_proj.bias
v = base_encoder.self_attn.v_proj.weight
v_bias = base_encoder.self_attn.v_proj.bias
qkv_weight = self.fused_qkv(q, k, v, self.num_heads)
fused_encoder.fused_attn.qkv_weight.set_value(qkv_weight)
tmp = paddle.concat(x=[q_bias, k_bias, v_bias], axis=0)
qkv_bias = paddle.reshape(
tmp,
shape=[3, self.num_heads, int(tmp.shape[0] / 3 / self.num_heads)])
fused_encoder.fused_attn.qkv_bias.set_value(qkv_bias)
fused_out = fused_encoder(
paddle.to_tensor(
src, stop_gradient=False), attn_mask_tensor)
paddle.autograd.backward([fused_out], [paddle.to_tensor(dout)], True)
correct_ffn_str = 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}'.format(
self.d_model, self.dim_feedforward, self.dropout_rate,
fused_encoder.ffn._epsilon, self.activation, self.dropout_rate,
self.pre_layer_norm, self.dtype)
        self.assertEqual(fused_encoder.ffn.extra_repr(), correct_ffn_str)
correct_attn_str = 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}'.format(
self.embed_dim, self.num_heads, self.dropout_rate,
self.dropout_rate, fused_encoder.fused_attn._epsilon, None, None,
self.pre_layer_norm, False, self.dtype)
        self.assertEqual(fused_encoder.fused_attn.extra_repr(), correct_attn_str)
np.testing.assert_allclose(
fused_out.numpy(), base_out.numpy(), rtol=1e-3, atol=1e-4)
self.assertTrue(
np.allclose(
fused_out.grad.numpy(),
base_out.grad.numpy(),
rtol=1e-3,
atol=1e-4))
class TestFusedTransformerEncoderLayerAct(TestFusedTransformerEncoderLayer):
def setActivation(self):
self.activation = 'relu'
class TestFusedTransformerEncoderLayerPreLayerNorm(
TestFusedTransformerEncoderLayer):
def setPreLayerNorm(self):
self.pre_layer_norm = True
class TestFusedTransformerEncoderLayerAttnMaskIsNone(
TestFusedTransformerEncoderLayer):
def setAttnMask(self):
self.has_attn_mask = False
class TestFusedTransformerEncoderLayerPreLnTrueAttnMaskIsNone(
TestFusedTransformerEncoderLayer):
def setPreLayerNorm(self):
self.pre_layer_norm = True
def setAttnMask(self):
self.has_attn_mask = False
if __name__ == "__main__":
unittest.main()
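The least obvious part of the test above is fused_qkv/fused_weight, which repack the baseline layer's separate q/k/v projections into the single parameter layout the fused layer expects. The same repacking in plain NumPy (hypothetical helper name, shapes follow the paddle code above):

```python
import numpy as np

def pack_qkv(q_w, k_w, v_w, q_b, k_b, v_b, num_heads):
    # weight: [embed_dim, embed_dim] -> [3, num_heads, head_dim, embed_dim]
    # bias:   3 x [embed_dim]        -> [3, num_heads, head_dim]
    def to_fused(w):
        w_t = w.T
        head_dim = w_t.shape[0] // num_heads
        return w_t.reshape(1, num_heads, head_dim, w_t.shape[1])
    qkv_weight = np.concatenate(
        [to_fused(q_w), to_fused(k_w), to_fused(v_w)], axis=0)
    bias = np.concatenate([q_b, k_b, v_b], axis=0)
    qkv_bias = bias.reshape(3, num_heads, bias.shape[0] // 3 // num_heads)
    return qkv_weight, qkv_bias

embed_dim, num_heads = 64, 16
w = np.zeros((embed_dim, embed_dim), dtype=np.float32)
b = np.zeros((embed_dim,), dtype=np.float32)
qkv_weight, qkv_bias = pack_qkv(w, w, w, b, b, b, num_heads)
print(qkv_weight.shape, qkv_bias.shape)  # (3, 16, 4, 64) (3, 16, 4)
```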
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode, default_main_program
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid import core, dygraph_utils from paddle.fluid import core, dygraph_utils
from paddle import _C_ops from paddle import _C_ops
...@@ -43,6 +43,8 @@ def fused_feedforward(x, ...@@ -43,6 +43,8 @@ def fused_feedforward(x,
ln1_epsilon=1e-5, ln1_epsilon=1e-5,
ln2_epsilon=1e-5, ln2_epsilon=1e-5,
pre_layer_norm=False, pre_layer_norm=False,
training=True,
mode='upscale_in_train',
name=None): name=None):
""" """
This is a fusion operator to compute feed forward layer in transformer model architecture. This is a fusion operator to compute feed forward layer in transformer model architecture.
...@@ -74,6 +76,18 @@ def fused_feedforward(x, ...@@ -74,6 +76,18 @@ def fused_feedforward(x,
ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5. ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5. ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
pre_layer_norm (bool, optional): add layer_norm in the pre-processing stage or post-processing state. pre_layer_norm (bool, optional): add layer_norm in the pre-processing stage or post-processing state.
        training (bool, optional): A flag indicating whether it is in train phase or not. Default True.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - p )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - p)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
...@@ -98,13 +112,27 @@ def fused_feedforward(x, ...@@ -98,13 +112,27 @@ def fused_feedforward(x,
_verify_dropout_rate(dropout1_rate) _verify_dropout_rate(dropout1_rate)
_verify_dropout_rate(dropout2_rate) _verify_dropout_rate(dropout2_rate)
seed = None
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # map the public name to the attribute value the dropout op expects
if in_dygraph_mode(): if in_dygraph_mode():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
out, _, _, _, _, _, _, _, _, _, _ = _C_ops.fused_feedforward( out, _, _, _, _, _, _, _, _, _, _ = _C_ops.fused_feedforward(
x, None, None, linear1_weight, linear1_bias, linear2_weight, x, None, None, linear1_weight, linear1_bias, linear2_weight,
linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias,
'pre_layer_norm', pre_layer_norm, 'ln1_epsilon', ln1_epsilon, 'pre_layer_norm', pre_layer_norm, 'ln1_epsilon', ln1_epsilon,
'ln2_epsilon', ln2_epsilon, 'act_method', activation, 'ln2_epsilon', ln2_epsilon, 'act_method', activation,
'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate) 'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate,
"dropout1_is_test", not training, "dropout2_is_test", not training,
"dropout1_fix_seed", seed is not None, "dropout2_fix_seed",
seed is not None, "dropout1_seed", seed
if seed is not None else 0, "dropout2_seed", seed
if seed is not None else 0, 'dropout1_implementation', mode,
'dropout2_implementation', mode)
return out return out
helper = LayerHelper("fused_feedforward") helper = LayerHelper("fused_feedforward")
...@@ -136,6 +164,9 @@ def fused_feedforward(x, ...@@ -136,6 +164,9 @@ def fused_feedforward(x,
dropout2_out = helper.create_variable_for_type_inference( dropout2_out = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True) x.dtype, stop_gradient=True)
if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed
helper.append_op( helper.append_op(
type='fused_feedforward', type='fused_feedforward',
inputs={ inputs={
...@@ -169,6 +200,14 @@ def fused_feedforward(x, ...@@ -169,6 +200,14 @@ def fused_feedforward(x,
'pre_layer_norm': pre_layer_norm, 'pre_layer_norm': pre_layer_norm,
'ln1_epsilon': ln1_epsilon, 'ln1_epsilon': ln1_epsilon,
'ln2_epsilon': ln2_epsilon, 'ln2_epsilon': ln2_epsilon,
'dropout1_is_test': not training,
'dropout2_is_test': not training,
'dropout1_fix_seed': seed is not None,
'dropout2_fix_seed': seed is not None,
'dropout1_seed': seed if seed is not None else 0,
'dropout2_seed': seed if seed is not None else 0,
'dropout1_implementation': mode,
'dropout2_implementation': mode
}) })
return out return out
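The seed handling added in both branches above follows one small rule: an explicit seed wins, a non-zero random_seed on the program becomes a fixed dropout seed, and otherwise the op seeds itself. A standalone sketch of that rule (hypothetical helper, not part of the patch):

```python
def resolve_dropout_seed(program_random_seed, seed=None):
    # Mirrors the logic in the diff: prefer an explicit seed, then the
    # program's random_seed, otherwise let the op generate its own.
    if (seed is None or seed == 0) and program_random_seed != 0:
        seed = program_random_seed
    return {'fix_seed': seed is not None,
            'seed': seed if seed is not None else 0}

print(resolve_dropout_seed(0))    # {'fix_seed': False, 'seed': 0}
print(resolve_dropout_seed(42))   # {'fix_seed': True, 'seed': 42}
```

This is why the tests above set default_main_program().random_seed = 42 before comparing against the unfused baseline: with a fixed seed, both paths draw the same dropout masks.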
...@@ -188,6 +227,8 @@ def fused_multi_head_attention(x, ...@@ -188,6 +227,8 @@ def fused_multi_head_attention(x,
dropout_rate=0.5, dropout_rate=0.5,
attn_dropout_rate=0.5, attn_dropout_rate=0.5,
ln_epsilon=1e-05, ln_epsilon=1e-05,
training=True,
mode='upscale_in_train',
name=None): name=None):
""" """
    Attention maps queries and a set of key-value pairs to outputs, and Attention maps queries and a set of key-value pairs to outputs, and
...@@ -214,7 +255,10 @@ def fused_multi_head_attention(x, ...@@ -214,7 +255,10 @@ def fused_multi_head_attention(x,
out = out * v out = out * v
out = transpose(out, perm=[0, 2, 1, 3]) out = transpose(out, perm=[0, 2, 1, 3])
out = out_linear(out) out = out_linear(out)
out = layer_norm(x + dropout(linear_bias + out)) if pre_layer_norm:
out = x + dropout(linear_bias + out)
else:
out = layer_norm(x + dropout(linear_bias + out))
Parameters: Parameters:
x (Tensor): The input tensor of fused_multi_head_attention. The shape is x (Tensor): The input tensor of fused_multi_head_attention. The shape is
...@@ -247,6 +291,19 @@ def fused_multi_head_attention(x, ...@@ -247,6 +291,19 @@ def fused_multi_head_attention(x,
0 for no dropout. Default 0.5. 0 for no dropout. Default 0.5.
ln_epsilon (float, optional): Small float value added to denominator of layer_norm ln_epsilon (float, optional): Small float value added to denominator of layer_norm
to avoid dividing by zero. Default is 1e-5. to avoid dividing by zero. Default is 1e-5.
        training (bool, optional): A flag indicating whether it is in train phase or not. Default True.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - p )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - p)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
Tensor: The output Tensor, the data type and shape is same as `x`. Tensor: The output Tensor, the data type and shape is same as `x`.
...@@ -280,7 +337,16 @@ def fused_multi_head_attention(x, ...@@ -280,7 +337,16 @@ def fused_multi_head_attention(x,
# [2, 4, 128] # [2, 4, 128]
print(output.shape) print(output.shape)
""" """
seed = None
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # map the public name to the attribute value the dropout op expects
if in_dygraph_mode(): if in_dygraph_mode():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
# pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
# qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
# linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
...@@ -295,7 +361,12 @@ def fused_multi_head_attention(x, ...@@ -295,7 +361,12 @@ def fused_multi_head_attention(x,
linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm',
pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate', pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate',
dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon', dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon',
ln_epsilon) ln_epsilon, 'attn_dropout_is_test', not training, 'dropout_is_test',
not training, 'attn_dropout_fix_seed', seed is not None,
'dropout_fix_seed', seed is not None, 'attn_dropout_seed', seed
if seed is not None else 0, 'dropout_seed', seed
if seed is not None else 0, 'attn_dropout_implementation', mode,
'dropout_implementation', mode)
return final_out return final_out
else: else:
helper = LayerHelper('fused_multi_head_attention', **locals()) helper = LayerHelper('fused_multi_head_attention', **locals())
...@@ -323,13 +394,24 @@ def fused_multi_head_attention(x, ...@@ -323,13 +394,24 @@ def fused_multi_head_attention(x,
if ln_bias: if ln_bias:
inputs['Ln2Bias'] = [ln_bias] inputs['Ln2Bias'] = [ln_bias]
if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed
# set attrs # set attrs
attrs = { attrs = {
'pre_layer_norm': pre_layer_norm, 'pre_layer_norm': pre_layer_norm,
'epsilon': pre_ln_epsilon, 'epsilon': pre_ln_epsilon,
'ln_epsilon': ln_epsilon, 'ln_epsilon': ln_epsilon,
'dropout_rate': dropout_rate, 'dropout_rate': dropout_rate,
'attn_dropout_rate': attn_dropout_rate 'attn_dropout_rate': attn_dropout_rate,
'attn_dropout_is_test': not training,
'dropout_is_test': not training,
'attn_dropout_fix_seed': seed is not None,
'dropout_fix_seed': seed is not None,
'attn_dropout_seed': seed if seed is not None else 0,
'dropout_seed': seed if seed is not None else 0,
'attn_dropout_implementation': mode,
'dropout_implementation': mode,
} }
# set outputs # set outputs
......
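The upscale_in_train / downscale_in_infer semantics documented for both fused_feedforward and fused_multi_head_attention differ only in where the 1/(1-p) scaling is applied; a tiny NumPy illustration of the two formulas from the docstrings:

```python
import numpy as np

p = 0.25                                                     # dropout probability
x = np.ones(8, dtype=np.float32)
mask = np.array([1, 0, 1, 1, 0, 1, 1, 1], dtype=np.float32)  # sampled keep-mask

# upscale_in_train: scale while training so inference is a plain pass-through.
train_upscale = x * mask / (1.0 - p)
infer_upscale = x

# downscale_in_infer: keep training output unscaled, shrink it at inference.
train_downscale = x * mask
infer_downscale = x * (1.0 - p)
```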
...@@ -43,7 +43,7 @@ class FusedMultiHeadAttention(Layer): ...@@ -43,7 +43,7 @@ class FusedMultiHeadAttention(Layer):
`embed_dim`. Default None. `embed_dim`. Default None.
vdim (int, optional): The feature size in value. If None, assumed equal to vdim (int, optional): The feature size in value. If None, assumed equal to
`embed_dim`. Default None. `embed_dim`. Default None.
normalize_before (bool, optional): Indicate whether it is pre_layer_norm normalize_before (bool, optional): Indicate whether it is pre_layer_norm
(True) or post_layer_norm architecture (False). Default False. (True) or post_layer_norm architecture (False). Default False.
need_weights (bool, optional): Indicate whether to return the attention need_weights (bool, optional): Indicate whether to return the attention
weights. Now, only False is supported. Default False. weights. Now, only False is supported. Default False.
...@@ -54,6 +54,8 @@ class FusedMultiHeadAttention(Layer): ...@@ -54,6 +54,8 @@ class FusedMultiHeadAttention(Layer):
Default: None, which means the default bias parameter property is used. Default: None, which means the default bias parameter property is used.
If it is set to False, this layer will not have trainable bias parameter. If it is set to False, this layer will not have trainable bias parameter.
See usage for details in :code:`ParamAttr`. See usage for details in :code:`ParamAttr`.
epsilon (float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05.
Examples: Examples:
...@@ -80,6 +82,7 @@ class FusedMultiHeadAttention(Layer): ...@@ -80,6 +82,7 @@ class FusedMultiHeadAttention(Layer):
need_weights=False, need_weights=False,
weight_attr=None, weight_attr=None,
bias_attr=None, bias_attr=None,
epsilon=1e-5,
name=None): name=None):
super(FusedMultiHeadAttention, self).__init__() super(FusedMultiHeadAttention, self).__init__()
...@@ -88,13 +91,18 @@ class FusedMultiHeadAttention(Layer): ...@@ -88,13 +91,18 @@ class FusedMultiHeadAttention(Layer):
assert num_heads > 0, ("Expected nhead to be greater than 0, " assert num_heads > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(num_heads)) "but recieved {}".format(num_heads))
attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate
self.normalize_before = normalize_before self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype() self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._bias_attr = bias_attr self._bias_attr = bias_attr
self._epsilon = epsilon
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads self.head_dim = embed_dim // num_heads
self.kdim = kdim
self.vdim = vdim
self.need_weights = need_weights
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
assert need_weights == False, "Only support need_weight is False now." assert need_weights == False, "Only support need_weight is False now."
...@@ -186,15 +194,24 @@ class FusedMultiHeadAttention(Layer): ...@@ -186,15 +194,24 @@ class FusedMultiHeadAttention(Layer):
pre_ln_bias=self.pre_ln_bias, pre_ln_bias=self.pre_ln_bias,
ln_scale=self.ln_scale, ln_scale=self.ln_scale,
ln_bias=self.ln_bias, ln_bias=self.ln_bias,
pre_ln_epsilon=1e-05, pre_ln_epsilon=self._epsilon,
qkv_bias=self.qkv_bias, qkv_bias=self.qkv_bias,
linear_bias=self.linear_bias, linear_bias=self.linear_bias,
attn_mask=attn_mask, attn_mask=attn_mask,
dropout_rate=self.dropout_rate, dropout_rate=self.dropout_rate,
attn_dropout_rate=self.attn_dropout_rate, attn_dropout_rate=self.attn_dropout_rate,
ln_epsilon=1e-05) ln_epsilon=self._epsilon,
training=self.training,
name=self.name)
return out return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}{}'.format(
self.embed_dim, self.num_heads, self.dropout_rate,
self.attn_dropout_rate, self._epsilon, self.kdim, self.vdim,
self.normalize_before, self.need_weights, self._dtype, name_str)
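A quick way to see the new constructor argument and extra_repr in action (a sketch, assuming a Paddle build that ships the incubate fused layers; constructing the layer only creates parameters, while the fused forward kernel itself still needs a CUDA device):

```python
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

# epsilon now replaces the hard-coded 1e-05 of both layer_norms, and the
# repr reports epsilon, kdim/vdim and need_weights alongside the old fields.
attn = FusedMultiHeadAttention(embed_dim=64, num_heads=16, epsilon=1e-6)
print(attn.extra_repr())
```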
class FusedFeedForward(Layer): class FusedFeedForward(Layer):
""" """
...@@ -203,6 +220,8 @@ class FusedFeedForward(Layer): ...@@ -203,6 +220,8 @@ class FusedFeedForward(Layer):
dim_feedforward (int): The hidden layer size. dim_feedforward (int): The hidden layer size.
dropout_rate (float, optional): The dropout probability used in pre-process dropout_rate (float, optional): The dropout probability used in pre-process
            and post-process. Default 0.1             and post-process. Default 0.1
        epsilon (float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05.
activation (str, optional): The activation function. Default relu. activation (str, optional): The activation function. Default relu.
        act_dropout_rate (float, optional): The dropout probability after activation.         act_dropout_rate (float, optional): The dropout probability after activation.
If None, use the value of `dropout_rate`. Default None If None, use the value of `dropout_rate`. Default None
...@@ -235,11 +254,13 @@ class FusedFeedForward(Layer): ...@@ -235,11 +254,13 @@ class FusedFeedForward(Layer):
d_model, d_model,
dim_feedforward, dim_feedforward,
dropout_rate=0.1, dropout_rate=0.1,
epsilon=1e-05,
activation="relu", activation="relu",
act_dropout_rate=None, act_dropout_rate=None,
normalize_before=False, normalize_before=False,
weight_attr=None, weight_attr=None,
bias_attr=None): bias_attr=None,
name=None):
super(FusedFeedForward, self).__init__() super(FusedFeedForward, self).__init__()
assert d_model > 0, ( assert d_model > 0, (
...@@ -256,6 +277,7 @@ class FusedFeedForward(Layer): ...@@ -256,6 +277,7 @@ class FusedFeedForward(Layer):
self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
self._act_method = activation self._act_method = activation
self._normalize_before = normalize_before self._normalize_before = normalize_before
self._epsilon = epsilon
self._linear1_weight = self.create_parameter( self._linear1_weight = self.create_parameter(
shape=[d_model, dim_feedforward], shape=[d_model, dim_feedforward],
...@@ -292,15 +314,36 @@ class FusedFeedForward(Layer): ...@@ -292,15 +314,36 @@ class FusedFeedForward(Layer):
default_initializer=Constant(1.0)) default_initializer=Constant(1.0))
self._ln2_bias = self.create_parameter( self._ln2_bias = self.create_parameter(
shape=[d_model], attr=None, is_bias=True) shape=[d_model], attr=None, is_bias=True)
self.name = name
def forward(self, src, cache=None): def forward(self, src, cache=None):
out = incubate_f.fused_feedforward( out = incubate_f.fused_feedforward(
src, self._linear1_weight, self._linear2_weight, self._linear1_bias, src,
self._linear2_bias, self._ln1_scale, self._ln1_bias, self._linear1_weight,
self._ln2_scale, self._ln2_bias, self._dropout_rate, self._linear2_weight,
self._act_dropout_rate, self._act_method, self._normalize_before) self._linear1_bias,
self._linear2_bias,
self._ln1_scale,
self._ln1_bias,
self._ln2_scale,
self._ln2_bias,
dropout1_rate=self._act_dropout_rate,
dropout2_rate=self._dropout_rate,
activation=self._act_method,
ln1_epsilon=self._epsilon,
ln2_epsilon=self._epsilon,
pre_layer_norm=self._normalize_before,
training=self.training,
name=self.name)
return out return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}{}'.format(
self._d_model, self._dim_feedforward, self._dropout_rate,
self._epsilon, self._act_method, self._act_dropout_rate,
self._normalize_before, self._dtype, name_str)
class FusedTransformerEncoderLayer(Layer): class FusedTransformerEncoderLayer(Layer):
""" """
...@@ -393,7 +436,9 @@ class FusedTransformerEncoderLayer(Layer): ...@@ -393,7 +436,9 @@ class FusedTransformerEncoderLayer(Layer):
self.fused_attn = FusedMultiHeadAttention( self.fused_attn = FusedMultiHeadAttention(
d_model, d_model,
nhead, nhead,
dropout_rate=attn_dropout_rate, dropout_rate=dropout_rate,
attn_dropout_rate=attn_dropout_rate,
normalize_before=self.normalize_before,
weight_attr=weight_attrs[0], weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0]) bias_attr=bias_attrs[0])
...@@ -401,6 +446,7 @@ class FusedTransformerEncoderLayer(Layer): ...@@ -401,6 +446,7 @@ class FusedTransformerEncoderLayer(Layer):
d_model, d_model,
dim_feedforward, dim_feedforward,
dropout_rate=dropout_rate, dropout_rate=dropout_rate,
activation=activation,
act_dropout_rate=act_dropout_rate, act_dropout_rate=act_dropout_rate,
normalize_before=self.normalize_before, normalize_before=self.normalize_before,
weight_attr=weight_attrs[1], weight_attr=weight_attrs[1],
......
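The point of the last hunk is that FusedTransformerEncoderLayer now forwards dropout_rate, attn_dropout_rate, normalize_before and activation to its sub-layers instead of silently using the sub-layer defaults. A construction-only sketch (argument names assumed from this diff; running the layer still needs CUDA):

```python
import paddle
from paddle.incubate.nn import FusedTransformerEncoderLayer

enc = FusedTransformerEncoderLayer(
    d_model=64, nhead=16, dim_feedforward=32,
    dropout_rate=0.1, activation='gelu', normalize_before=True)

# Both sub-layers should now report the configuration that was requested.
print(enc.fused_attn.extra_repr())
print(enc.ffn.extra_repr())
```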
...@@ -235,8 +235,8 @@ def interpolate(x, ...@@ -235,8 +235,8 @@ def interpolate(x,
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import numpy as np import numpy as np
import paddle.nn.functional as F import paddle.nn.functional as F
# given out size # given out size
...@@ -244,7 +244,7 @@ def interpolate(x, ...@@ -244,7 +244,7 @@ def interpolate(x,
x = paddle.to_tensor(input_data) x = paddle.to_tensor(input_data)
output_1 = F.interpolate(x=x, size=[12,12]) output_1 = F.interpolate(x=x, size=[12,12])
print(output_1.shape) print(output_1.shape)
# [2L, 3L, 12L, 12L] # [2L, 3L, 12L, 12L]
# given scale # given scale
output_2 = F.interpolate(x=x, scale_factor=[2,1]) output_2 = F.interpolate(x=x, scale_factor=[2,1])
......