未验证 提交 dceccd9d 编写于 作者: L Li Min 提交者: GitHub

Add fused_bias_dropout_residual_ln op and layer. (#43062)

* add fused_bias_dropout_residual_ln op and layer.
上级 e1e0deed
......@@ -22,6 +22,7 @@ register_operators(EXCLUDES
fused_transformer_op
fused_feedforward_op
fused_multi_transformer_op
fused_bias_dropout_residual_layer_norm_op
resnet_unit_op
fused_gemm_epilogue_op
fused_gate_attention_op)
......@@ -81,6 +82,7 @@ if (WITH_GPU OR WITH_ROCM)
# fused_attention_op
op_library(fused_attention_op)
op_library(fused_multi_transformer_op)
op_library(fused_bias_dropout_residual_layer_norm_op)
endif()
# resnet_unit needs cudnn 8.0 above
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FusedBiasDropoutResidualLnOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"FusedBiasDropoutResidualLnOp");
OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
"FusedBiasDropoutResidualLnOp");
OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance",
"FusedBiasDropoutResidualLnOp");
OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
"BiasDropoutResidualOut", "FusedBiasDropoutResidualLnOp");
OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut",
"FusedBiasDropoutResidualLnOp");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y",
"FusedBiasDropoutResidualLnOp");
auto x_dim = ctx->GetInputDim("X");
int left = 1;
for (int i = 0; i < x_dim.size() - 1; i++) {
left *= x_dim[i];
}
ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
if (ctx->Attrs().Get<bool>("dropout_is_test") == false) {
ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
}
ctx->SetOutputDim("LnMean", {left});
ctx->SetOutputDim("LnVariance", {left});
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto input_data_type = framework::TransToProtoVarType(input->dtype());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class FusedBiasDropoutResidualLnOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddInput("Residual", "The residual tensor.");
AddInput("Bias", "The linear bias tensor.").AsDispensable();
AddInput("LnScale",
"(optional) Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
AddInput("LnBias",
"(optional) Bias is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
AddOutput("BiasDropoutResidualOut", "Output of bias + dropout + residual.")
.AsIntermediate();
AddOutput("DropoutMaskOut", "The random sampled dropout mask.")
.AsIntermediate();
AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate();
AddOutput("LnVariance", "Variance of the current mini batch.")
.AsIntermediate();
AddOutput("Y", "Result.");
AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
.SetDefault(.5f)
.AddCustomChecker([](const float &drop_p) {
PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true,
platform::errors::InvalidArgument(
"'dropout_rate' must be between 0.0 and 1.0."));
});
AddAttr<bool>("dropout_is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<bool>("dropout_fix_seed",
"A flag indicating whether to use a fixed seed to generate "
"random mask. NOTE: DO NOT set this flag to true in "
"training. Setting this flag to true is only useful in "
"unittest or for debug that always the same output units "
"will be dropped.")
.SetDefault(true);
AddAttr<int>("dropout_seed", "Dropout random seed.").SetDefault(0);
AddAttr<std::string>(
"dropout_implementation",
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
"The meaning is the same as 'attn_dropout_implementation'.")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string &type) {
PADDLE_ENFORCE_EQ(
type == "downgrade_in_infer" || type == "upscale_in_train", true,
platform::errors::InvalidArgument(
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
AddAttr<float>("ln_epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float &ln_epsilon) {
PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' of the LayerNorm should be between "
"0.0 and 0.001, But received [%s].",
ln_epsilon));
});
AddComment(R"DOC(
Add fused bias_dropout_residual_layer_norm op whose logic is as follows:
// @input: [batch_size, seq_len, embed_dim]
// @final_out: [batch_size, seq_len, embed_dim]
y = layer_norm(residual + dropout(bias + x));
)DOC");
}
};
class FusedBiasDropoutResidualLnGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->Attrs().Get<bool>("dropout_is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when dropout_is_test is false"));
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedBiasDropoutResidualLnGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedBiasDropoutResidualLnGrad");
OP_INOUT_CHECK(ctx->HasInput("BiasDropoutResidualOut"), "Input",
"BiasDropoutResidualOut", "FusedBiasDropoutResidualLnGrad");
if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
ctx->SetOutputDim(framework::GradVarName("LnScale"),
ctx->GetInputDim("LnScale"));
}
if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
ctx->SetOutputDim(framework::GradVarName("LnBias"),
ctx->GetInputDim("LnBias"));
}
if (ctx->HasOutput(framework::GradVarName("Residual"))) {
ctx->SetOutputDim(framework::GradVarName("Residual"),
ctx->GetInputDim("Residual"));
}
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
ctx->GetInputDim("BiasDropoutResidualOut"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto input_data_type = framework::TransToProtoVarType(input->dtype());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class FusedBiasDropoutResidualLnGradOpMaker
: public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("fused_bias_dropout_residual_layer_norm_grad");
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetInput("X", this->Input("X"));
op->SetInput("Residual", this->Input("Residual"));
if (this->HasInput("Bias")) {
op->SetInput("Bias", this->Input("Bias"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
}
if (this->HasInput("LnScale")) {
op->SetInput("LnScale", this->Input("LnScale"));
op->SetOutput(framework::GradVarName("LnScale"),
this->InputGrad("LnScale"));
}
if (this->HasInput("LnBias")) {
op->SetInput("LnBias", this->Input("LnBias"));
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));
}
if (this->HasOutput("LnMean")) {
op->SetInput("LnMean", this->Output("LnMean"));
}
if (this->HasOutput("LnVariance")) {
op->SetInput("LnVariance", this->Output("LnVariance"));
}
if (this->HasOutput("BiasDropoutResidualOut")) {
op->SetInput("BiasDropoutResidualOut",
this->Output("BiasDropoutResidualOut"));
}
op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Residual"),
this->InputGrad("Residual"));
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
this->OutputGrad("BiasDropoutResidualOut"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_bias_dropout_residual_layer_norm, ops::FusedBiasDropoutResidualLnOp,
ops::FusedBiasDropoutResidualLnOpMaker,
ops::FusedBiasDropoutResidualLnGradOpMaker<paddle::framework::OpDesc>,
ops::FusedBiasDropoutResidualLnGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_bias_dropout_residual_layer_norm_grad,
ops::FusedBiasDropoutResidualLnGradOp);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
auto *input_x = ctx.Input<Tensor>("X");
auto *bias = ctx.Input<Tensor>("Bias");
auto *residual = ctx.Input<Tensor>("Residual");
const float ln_epsilon = ctx.Attr<float>("ln_epsilon");
auto *ln_scale = ctx.Input<Tensor>("LnScale");
auto *ln_bias = ctx.Input<Tensor>("LnBias");
auto *dropout_mask_out = ctx.Output<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Output<Tensor>("BiasDropoutResidualOut");
auto *ln_mean = ctx.Output<Tensor>("LnMean");
auto *ln_var = ctx.Output<Tensor>("LnVariance");
auto *y = ctx.Output<Tensor>("Y");
auto *x_data = input_x->data<T>();
auto *bias_data = (bias == nullptr) ? nullptr : bias->data<T>();
auto *residual_data = (residual == nullptr) ? nullptr : residual->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
auto *dropout_mask_out_data =
dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
const auto input_x_dims = input_x->dims();
int bsz_seq = 1;
for (int i = 0; i < input_x_dims.size() - 1; i++) {
bsz_seq *= input_x_dims[i];
}
int dim_embed = input_x_dims[input_x_dims.size() - 1];
DropoutParam dropout_param(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param,
ln_epsilon);
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(), x_data, residual_data, bias_data,
ln_scale_data, ln_bias_data, bias_dropout_residual_out_data,
dropout_mask_out_data, y_data, ln_mean_data, ln_var_data);
}
};
template <typename T>
class FusedBiasDropoutResidualLnGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const float ln_epsilon = ctx.Attr<float>("ln_epsilon");
auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto *ln_scale = ctx.Input<Tensor>("LnScale");
auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Input<Tensor>("BiasDropoutResidualOut");
auto *ln_mean = ctx.Input<Tensor>("LnMean");
auto *ln_var = ctx.Input<Tensor>("LnVariance");
auto *d_y_data = d_y->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data<T>();
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_residual = ctx.Output<Tensor>(framework::GradVarName("Residual"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto *d_bias_dropout_residual_out =
ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
auto *d_residual_data = d_residual->mutable_data<T>(ctx.GetPlace());
auto *d_bias_dropout_residual_out_data =
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *d_bias_data =
(d_bias == nullptr ? nullptr : d_bias->mutable_data<T>(ctx.GetPlace()));
auto *d_ln_scale_data =
(d_ln_scale == nullptr ? nullptr
: d_ln_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));
const auto input_x_dims = d_y->dims();
int bsz_seq = 1;
for (int i = 0; i < input_x_dims.size() - 1; i++) {
bsz_seq *= input_x_dims[i];
}
int dim_embed = input_x_dims[input_x_dims.size() - 1];
DropoutParam dropout_param(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param,
ln_epsilon);
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
dropout_mask_out_data, ln_scale_data, ln_mean_data, ln_var_data,
d_bias_dropout_residual_out_data, d_ln_scale_data, d_ln_bias_data,
d_x_data, d_bias_data, d_residual_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_bias_dropout_residual_layer_norm,
ops::FusedBiasDropoutResidualLnOpKernel<float>,
ops::FusedBiasDropoutResidualLnOpKernel<double>,
ops::FusedBiasDropoutResidualLnOpKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(
fused_bias_dropout_residual_layer_norm_grad,
ops::FusedBiasDropoutResidualLnGradKernel<float>,
ops::FusedBiasDropoutResidualLnGradKernel<double>,
ops::FusedBiasDropoutResidualLnGradKernel<plat::float16>);
......@@ -40,6 +40,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep",
"SrcMask", "OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias",
"FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias"}},
{"fused_bias_dropout_residual_layer_norm",
{"X", "Residual", "Bias", "LnScale", "LnBias"}},
{"instance_norm", {"X", "Scale", "Bias"}},
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}},
......@@ -152,6 +154,8 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
"DropoutMaskOut", "Ln2Mean",
"Ln2Variance", "BiasDropoutResidualOut",
"CacheKVOut", "Y"}},
{"fused_bias_dropout_residual_layer_norm",
{"BiasDropoutResidualOut", "DropoutMaskOut", "LnMean", "LnVariance", "Y"}},
{"fused_gate_attention",
{"QueryTransposeOut", "KeyTransposeOut", "ValueTransposeOut",
"QKVTransposeOut", "SoftmaxOut", "FMHAOut", "GateOut", "Out"}},
......
......@@ -131,6 +131,8 @@ if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
LIST(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
LIST(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
endif()
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program
default_main_program().random_seed = 42
class TestFusedBiasDropoutResidualLayerNormOp(OpTest):
def setUp(self):
self.config()
self.generate_input_data()
paddle.set_default_dtype(self.x_type)
self.__class__.op_type = "fused_bias_dropout_residual_layer_norm"
# use autograd to check grad in this unittest.
self.__class__.no_need_check_grad = True
paddle.set_default_dtype(np.float32)
self.norm1 = LayerNorm(self.embed_dim)
paddle.set_default_dtype(self.x_type)
self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")
def config(self):
self.x_type = np.float32
self.atol = 1e-4
self.training = True
self.batch_size = 8
self.query_length = 128
self.embed_dim = 1024
self.dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
def generate_input_data(self):
self.x = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
self.residual = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
self.linear_bias = np.random.rand(self.embed_dim).astype(self.x_type)
self.dout = np.random.random((self.batch_size, self.query_length,
self.embed_dim)).astype(self.x_type)
if self.bias_attr is False:
self.tensor_linear_bias = None
else:
self.tensor_linear_bias = paddle.to_tensor(
self.linear_bias, stop_gradient=False)
self.tensor_x = paddle.to_tensor(self.x, stop_gradient=False)
self.tensor_residual = paddle.to_tensor(
self.residual, stop_gradient=False)
def GetBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
if self.tensor_linear_bias is not None:
out = self.tensor_x + self.tensor_linear_bias
else:
out = self.tensor_x
residual_out = self.tensor_residual + self.dropout(out)
final_out = self.norm1(residual_out)
paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
if self.tensor_linear_bias is not None:
tensor_linear_bias_grad = self.tensor_linear_bias.grad
else:
tensor_linear_bias_grad = None
return final_out, self.tensor_x.grad, self.tensor_residual.grad, tensor_linear_bias_grad
def GetFusedBiasDropoutResidualLayerNormOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
ln_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
ln_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
epsilon = 1e-05
final_out = incubate_f.fused_bias_dropout_residual_layer_norm(
self.tensor_x, self.tensor_residual, self.tensor_linear_bias,
ln_scale, ln_bias, self.dropout_prob, epsilon)
paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
if self.tensor_linear_bias is not None:
tensor_linear_bias_grad = self.tensor_linear_bias.grad
else:
tensor_linear_bias_grad = None
return final_out, self.tensor_x.grad, self.tensor_residual.grad, tensor_linear_bias_grad
def test_fused_op(self):
out_ref, x_grad_ref, residual_grad_ref, linear_bias_grad_ref = self.GetBaselineOut(
)
out, x_grad, residual_grad, linear_bias_grad = self.GetFusedBiasDropoutResidualLayerNormOut(
)
np.testing.assert_allclose(
out_ref, out.numpy(), rtol=1e-5, atol=self.atol)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=self.atol)
np.testing.assert_allclose(
residual_grad_ref, residual_grad.numpy(), rtol=1e-5, atol=self.atol)
if linear_bias_grad_ref is not None:
np.testing.assert_allclose(
linear_bias_grad_ref,
linear_bias_grad.numpy(),
rtol=1e-5,
atol=self.atol)
class TestFusedBiasDropoutResidualLayerNormOpBiasIsNone(
TestFusedBiasDropoutResidualLayerNormOp):
def config(self):
super().config()
self.bias_attr = False
class TestFusedBiasDropoutResidualLayerNormOpFp16(
TestFusedBiasDropoutResidualLayerNormOp):
def config(self):
super().config()
self.x_type = np.float16
self.atol = 1e-1
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.nn.functional as F
from paddle.incubate.nn.layer.fused_transformer import FusedBiasDropoutResidualLayerNorm
from paddle import tensor
from paddle.fluid import layers
from paddle.static import Program, program_guard
import unittest
def layer_norm(x, has_scale, has_bias, weight, bias, epsilon=1e-05):
batch_size, src_len, d_model = x.shape
x = x.reshape((batch_size * src_len, d_model))
mu = np.mean(x, axis=1, keepdims=True)
sigma_squar = np.sum(np.square(x - mu), axis=1) / d_model
x1_up = (x - mu)
x1_down_1 = sigma_squar + epsilon
x1_down = np.sqrt(x1_down_1)
x1_down = x1_down.reshape((x1_down.shape[0], 1))
x1 = x1_up / x1_down
x_scaled = x1
if (has_scale):
x_scaled = weight * x1
x_scaled_bias = x_scaled
if (has_bias):
x_scaled_bias = x_scaled + bias
x_scaled_bias = x_scaled_bias.reshape((batch_size, src_len, d_model))
return x_scaled_bias
def compute_reference(x, residual, ln_scale, ln_bias, linear_bias):
batch_size = x.shape[0]
seq_len = x.shape[1]
embed_dim = x.shape[2]
has_bias = True
if ln_bias is None:
has_bias = False
# bias add, dropout, residual add, layer_norm.
if linear_bias is not None:
linear_bias_out = x + linear_bias
else:
linear_bias_out = x
linear_bias_dropout_out = linear_bias_out
linear_bias_dropout_residual_out = residual + linear_bias_dropout_out
linear_bias_dropout_residual_ln_out = layer_norm(
linear_bias_dropout_residual_out, True, has_bias, ln_scale, ln_bias)
return linear_bias_dropout_residual_ln_out
class TestFusedBiasDropoutResidualLayerNormAPI(unittest.TestCase):
def setUp(self):
self.setXType()
self.setBiasAttr()
self.config()
self.generate_input_data()
def setBiasAttr(self):
self.bias_attr = None
def setXType(self):
self.x_type = np.float32
self.atol = 1e-4
def config(self):
self.training = True
self.batch_size = 1
self.query_length = 2
self.embed_dim = 4
self.dropout_prob = 0.0
self.weight_attr = None
def generate_input_data(self):
self.x = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
self.residual = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
def run_imperative(self):
fused_bias_dropout_residual_ln = FusedBiasDropoutResidualLayerNorm(
self.embed_dim, self.dropout_prob, self.weight_attr, self.bias_attr)
linear_bias = None
if self.bias_attr is not False:
linear_bias = np.random.random(fused_bias_dropout_residual_ln.
linear_bias.shape).astype('float32')
fused_bias_dropout_residual_ln.linear_bias.set_value(
paddle.to_tensor(linear_bias))
out = fused_bias_dropout_residual_ln(
paddle.to_tensor(self.x), paddle.to_tensor(self.residual))
ln_bias = None
if self.bias_attr is not False:
ln_bias = fused_bias_dropout_residual_ln.ln_bias.numpy()
ln_scale = fused_bias_dropout_residual_ln.ln_scale.numpy(),
ref_out = compute_reference(self.x, self.residual, ln_scale, ln_bias,
linear_bias)
np.testing.assert_allclose(
ref_out, out.numpy(), rtol=1e-5, atol=self.atol)
def run_static(self):
fused_op = FusedBiasDropoutResidualLayerNorm(
self.embed_dim, self.dropout_prob, self.weight_attr, self.bias_attr)
x = paddle.static.data(
name='X',
shape=[self.batch_size, self.query_length, self.embed_dim],
dtype=self.x_type)
residual = paddle.static.data(
name='Residual',
shape=[self.batch_size, self.query_length, self.embed_dim],
dtype=self.x_type)
final_out = fused_op(x, residual)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
linear_bias = None
ln_bias = None
if self.bias_attr is False:
out, ln_scale = exe.run(
paddle.static.default_main_program(),
feed={"X": self.x,
"Residual": self.residual},
fetch_list=[final_out, fused_op.ln_scale])
else:
out, linear_bias, ln_scale, ln_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.x,
"Residual": self.residual},
fetch_list=[
final_out, fused_op.linear_bias, fused_op.ln_scale,
fused_op.ln_bias
])
return out, linear_bias, ln_scale, ln_bias
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(Program()):
out, linear_bias, ln_scale, ln_bias = self.run_static()
ref_out = compute_reference(self.x, self.residual, ln_scale, ln_bias,
linear_bias)
np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=self.atol)
def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_imperative()
class TestFusedBiasDropoutResidualLayerNormAPIBiasIsNone(
TestFusedBiasDropoutResidualLayerNormAPI):
def setBiasAttr(self):
self.bias_attr = False
if __name__ == "__main__":
unittest.main()
......@@ -16,10 +16,12 @@ from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401
from .layer.fused_transformer import FusedFeedForward # noqa: F401
from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401
from .layer.fused_transformer import FusedMultiTransformer # noqa: F401
from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401
__all__ = [ #noqa
'FusedMultiHeadAttention',
'FusedFeedForward',
'FusedTransformerEncoderLayer',
'FusedMultiTransformer',
'FusedBiasDropoutResidualLayerNorm',
]
......@@ -15,9 +15,11 @@
from .fused_transformer import fused_multi_head_attention
from .fused_transformer import fused_feedforward
from .fused_transformer import fused_multi_transformer
from .fused_transformer import fused_bias_dropout_residual_layer_norm
__all__ = [
'fused_multi_head_attention',
'fused_feedforward',
'fused_multi_transformer',
'fused_bias_dropout_residual_layer_norm',
]
......@@ -212,6 +212,151 @@ def fused_feedforward(x,
return out
def fused_bias_dropout_residual_layer_norm(x,
residual,
bias=None,
ln_scale=None,
ln_bias=None,
dropout_rate=0.5,
ln_epsilon=1e-5,
training=True,
mode='upscale_in_train',
name=None):
r"""
The fused_bias_dropout_residual_layer_norm operator. The pseudo code is as follows:
.. code-block:: python
y = layer_norm(residual + dropout(bias + x))
Parameters:
x (Tensor): The input tensor. The shape is `[*, embed\_dim]`.
residual (Tensor): The residual tensor. The shape is same as x.
bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
ln_scale (Tensor, optional): The weight tensor of layernorm. The shape is `[embed_dim]`. Default None.
ln_bias (Tensor, optional): The bias tensor of layernorm. The shape is `[embed_dim]`. Default None.
dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout after attention.
0 for no dropout. Default 0.5.
ln_epsilon (float, optional): Small float value added to denominator of layer_norm
to avoid dividing by zero. Default is 1e-5.
training (bool, optional): A flag indicating whether it is in train phrase or not. Default True.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - p )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - p)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The output Tensor, the data type and shape is same as `x`.
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.incubate.nn.functional as F
# input: [batch_size, seq_len, embed_dim]
x = paddle.rand(shape=(2, 4, 128), dtype="float32")
# residual: [batch_size, seq_len, embed_dim]
residual = paddle.rand(shape=(2, 4, 128), dtype="float32")
# linear bias: [embed_dim]
bias = paddle.rand(shape=[128], dtype="float32")
# output: [batch_size, seq_len, embed_dim]
output = F.fused_bias_dropout_residual_layer_norm(
x, residual, bias)
# [2, 4, 128]
print(output.shape)
"""
seed = None
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer
if ln_scale is not None:
assert len(ln_scale.
shape) == 1, "The dims of the shape of ln_scale should be 1."
assert x.shape[len(x.shape) - 1] == ln_scale.shape[
0], "The dim of ln_scale must equal to the last dim of x."
if ln_bias is not None:
assert len(
ln_bias.shape) == 1, "The dims of the shape of ln_bias should be 1."
assert x.shape[len(x.shape) - 1] == ln_bias.shape[
0], "The dim of ln_bias must equal to the last dim of x."
if _non_static_mode():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
_, _, _, _, final_out = _C_ops.fused_bias_dropout_residual_layer_norm(
x, residual, bias, ln_scale, ln_bias, 'dropout_rate', dropout_rate,
'ln_epsilon', ln_epsilon, 'dropout_is_test', not training,
'dropout_fix_seed', seed is not None, 'dropout_seed', seed
if seed is not None else 0, 'dropout_implementation', mode)
return final_out
else:
helper = LayerHelper('fused_bias_dropout_residual_layer_norm',
**locals())
dtype = x.dtype
# check dtypes
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'fused_bias_dropout_residual_layer_norm')
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
'fused_bias_dropout_residual_layer_norm')
# set inputs
inputs = dict()
inputs['X'] = [x]
inputs['Residual'] = [residual]
if bias is not None:
inputs['Bias'] = [bias]
if ln_scale:
inputs['LnScale'] = [ln_scale]
if ln_bias:
inputs['LnBias'] = [ln_bias]
if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed
# set attrs
attrs = {
'ln_epsilon': ln_epsilon,
'dropout_rate': dropout_rate,
'dropout_is_test': not training,
'dropout_fix_seed': seed is not None,
'dropout_seed': seed if seed is not None else 0,
'dropout_implementation': mode,
}
# set outputs
dropout_mask_out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
ln_mean_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
ln_variance_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
bias_dropout_residual_out = helper.create_variable_for_type_inference(
dtype=dtype)
final_out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='fused_bias_dropout_residual_layer_norm',
inputs=inputs,
outputs={
"BiasDropoutResidualOut": bias_dropout_residual_out,
"DropoutMaskOut": dropout_mask_out,
"LnMean": ln_mean_out,
"LnVariance": ln_variance_out,
'Y': final_out,
},
attrs=attrs)
return final_out
def fused_multi_head_attention(x,
qkv_weight,
linear_weight,
......
......@@ -36,6 +36,103 @@ def _set_var_distributed(var):
main_block._find_var_recursive(var.name).is_distributed = True
class FusedBiasDropoutResidualLayerNorm(Layer):
"""
Applies fused_bias_dropout_residual_layer_norm operation.
Parameters:
embed_dim (int): The expected feature size in the input and output.
dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout after attention.
0 for no dropout. Default 0.5.
bias_attr (ParamAttr|bool, optional): To specify the bias parameter property.
Default: None, which means the default bias parameter property is used.
If it is set to False, this layer will not have trainable bias parameter.
See usage for details in :code:`ParamAttr`.
epsilon (float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05.
Examples:
.. code-block:: python
# required: gpu
import paddle
# input: [batch_size, seq_len, embed_dim]
x = paddle.rand((2, 4, 128))
# residual: [batch_size, seq_len, embed_dim]
residual = paddle.rand((2, 4, 128))
fused_bias_dropout_residual_ln = paddle.incubate.nn.FusedBiasDropoutResidualLayerNorm(128)
output = fused_bias_dropout_residual_ln(x, residual) # [2, 4, 128]
"""
def __init__(self,
embed_dim,
dropout_rate=0.5,
weight_attr=None,
bias_attr=None,
epsilon=1e-5,
name=None):
super(FusedBiasDropoutResidualLayerNorm, self).__init__()
assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
"but recieved {}".format(embed_dim))
self._dtype = self._helper.get_default_dtype()
self._bias_attr = bias_attr
self._weight_attr = weight_attr
self.embed_dim = embed_dim
self.linear_bias = self.create_parameter(
shape=[embed_dim],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True)
self.ln_scale = self.create_parameter(
attr=self._weight_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.ln_bias = self.create_parameter(
attr=self._bias_attr, shape=[embed_dim], is_bias=True)
self.dropout_rate = dropout_rate
self._epsilon = epsilon
self.name = name
def forward(self, x, residual):
"""
Applies fused_bias_dropout_residual_layer_norm operation.
Parameters:
x (Tensor): The input tensor. It is a tensor with shape
`[batch_size, seq_len, embed_dim]`. The data type should be
float32 or float64.
residual (Tensor, optional): The residual tensor. It is a tensor
with shape `[batch_size, value_length, vdim]`. The data type
should be float32 or float64.
Returns:
Tensor|tuple: It is a tensor that has the same shape and data type \
as `x`.
"""
out = incubate_f.fused_bias_dropout_residual_layer_norm(
x=x,
residual=residual,
bias=self.linear_bias,
ln_scale=self.ln_scale,
ln_bias=self.ln_bias,
dropout_rate=self.dropout_rate,
ln_epsilon=self._epsilon,
training=self.training,
mode='upscale_in_train',
name=self.name)
return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'embed_dim={}, seq_len={}, dropout_rate={}, epsilon={}, dtype={}{}'.format(
self.embed_dim, self.seq_len, self.dropout_rate, self._epsilon,
self._dtype, name_str)
class FusedMultiHeadAttention(Layer):
"""
Attention mapps queries and a set of key-value pairs to outputs, and
......
......@@ -2047,6 +2047,8 @@ TETRAD_PARALLEL_JOB = [
'test_lambda',
'test_prod_op',
'test_fused_attention_op_api',
'test_fused_bias_dropout_residual_layer_norm_op',
'test_fused_bias_dropout_residual_layer_norm_op_api',
'test_complex_grad_accumulated',
'test_deg2rad',
'test_lgamma_op',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册