diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt old mode 100644 new mode 100755 index a86d26bcd58a7bbeb4922acc49715487b5f554f6..e23891d899de69dc5defd5e269614d76d09779e3 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -22,6 +22,7 @@ register_operators(EXCLUDES fused_transformer_op fused_feedforward_op fused_multi_transformer_op + fused_bias_dropout_residual_layer_norm_op resnet_unit_op fused_gemm_epilogue_op fused_gate_attention_op) @@ -81,6 +82,7 @@ if (WITH_GPU OR WITH_ROCM) # fused_attention_op op_library(fused_attention_op) op_library(fused_multi_transformer_op) + op_library(fused_bias_dropout_residual_layer_norm_op) endif() # resnet_unit needs cudnn 8.0 above if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000)) diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6187544456b373fd1e3407b167b5fbdd2153a870 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc @@ -0,0 +1,240 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class FusedBiasDropoutResidualLnOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", + "FusedBiasDropoutResidualLnOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", + "FusedBiasDropoutResidualLnOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", + "FusedBiasDropoutResidualLnOp"); + OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", + "BiasDropoutResidualOut", "FusedBiasDropoutResidualLnOp"); + OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut", + "FusedBiasDropoutResidualLnOp"); + OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", + "FusedBiasDropoutResidualLnOp"); + auto x_dim = ctx->GetInputDim("X"); + int left = 1; + for (int i = 0; i < x_dim.size() - 1; i++) { + left *= x_dim[i]; + } + ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); + if (ctx->Attrs().Get("dropout_is_test") == false) { + ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); + } + ctx->SetOutputDim("LnMean", {left}); + ctx->SetOutputDim("LnVariance", {left}); + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input = ctx.Input("X"); + auto input_data_type = framework::TransToProtoVarType(input->dtype()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +class FusedBiasDropoutResidualLnOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("Residual", "The residual tensor."); + AddInput("Bias", "The linear bias tensor.").AsDispensable(); + AddInput("LnScale", + "(optional) Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddInput("LnBias", + "(optional) Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddOutput("BiasDropoutResidualOut", "Output of bias + dropout + residual.") + .AsIntermediate(); + AddOutput("DropoutMaskOut", "The random sampled dropout mask.") + .AsIntermediate(); + AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("LnVariance", "Variance of the current mini batch.") + .AsIntermediate(); + AddOutput("Y", "Result."); + AddAttr("dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true, + platform::errors::InvalidArgument( + "'dropout_rate' must be between 0.0 and 1.0.")); + }); + AddAttr("dropout_is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("dropout_fix_seed", + "A flag indicating whether to use a fixed seed to generate " + "random mask. NOTE: DO NOT set this flag to true in " + "training. Setting this flag to true is only useful in " + "unittest or for debug that always the same output units " + "will be dropped.") + .SetDefault(true); + AddAttr("dropout_seed", "Dropout random seed.").SetDefault(0); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "The meaning is the same as 'attn_dropout_implementation'.") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("ln_epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &ln_epsilon) { + PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true, + platform::errors::InvalidArgument( + "'epsilon' of the LayerNorm should be between " + "0.0 and 0.001, But received [%s].", + ln_epsilon)); + }); + + AddComment(R"DOC( + Add fused bias_dropout_residual_layer_norm op whose logic is as follows: + // @input: [batch_size, seq_len, embed_dim] + // @final_out: [batch_size, seq_len, embed_dim] + y = layer_norm(residual + dropout(bias + x)); + )DOC"); + } +}; + +class FusedBiasDropoutResidualLnGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->Attrs().Get("dropout_is_test"), false, + platform::errors::InvalidArgument( + "GradOp is only callable when dropout_is_test is false")); + OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean", + "FusedBiasDropoutResidualLnGrad"); + OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance", + "FusedBiasDropoutResidualLnGrad"); + OP_INOUT_CHECK(ctx->HasInput("BiasDropoutResidualOut"), "Input", + "BiasDropoutResidualOut", "FusedBiasDropoutResidualLnGrad"); + if (ctx->HasOutput(framework::GradVarName("LnScale"))) { + ctx->SetOutputDim(framework::GradVarName("LnScale"), + ctx->GetInputDim("LnScale")); + } + if (ctx->HasOutput(framework::GradVarName("LnBias"))) { + ctx->SetOutputDim(framework::GradVarName("LnBias"), + ctx->GetInputDim("LnBias")); + } + if (ctx->HasOutput(framework::GradVarName("Residual"))) { + ctx->SetOutputDim(framework::GradVarName("Residual"), + ctx->GetInputDim("Residual")); + } + if (ctx->HasOutput(framework::GradVarName("Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Bias"), + ctx->GetInputDim("Bias")); + } + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), + ctx->GetInputDim("BiasDropoutResidualOut")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input = ctx.Input("X"); + auto input_data_type = framework::TransToProtoVarType(input->dtype()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +template +class FusedBiasDropoutResidualLnGradOpMaker + : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("fused_bias_dropout_residual_layer_norm_grad"); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + op->SetInput("X", this->Input("X")); + op->SetInput("Residual", this->Input("Residual")); + if (this->HasInput("Bias")) { + op->SetInput("Bias", this->Input("Bias")); + op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); + } + if (this->HasInput("LnScale")) { + op->SetInput("LnScale", this->Input("LnScale")); + op->SetOutput(framework::GradVarName("LnScale"), + this->InputGrad("LnScale")); + } + if (this->HasInput("LnBias")) { + op->SetInput("LnBias", this->Input("LnBias")); + op->SetOutput(framework::GradVarName("LnBias"), + this->InputGrad("LnBias")); + } + if (this->HasOutput("LnMean")) { + op->SetInput("LnMean", this->Output("LnMean")); + } + if (this->HasOutput("LnVariance")) { + op->SetInput("LnVariance", this->Output("LnVariance")); + } + if (this->HasOutput("BiasDropoutResidualOut")) { + op->SetInput("BiasDropoutResidualOut", + this->Output("BiasDropoutResidualOut")); + } + op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Residual"), + this->InputGrad("Residual")); + op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"), + this->OutputGrad("BiasDropoutResidualOut")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + fused_bias_dropout_residual_layer_norm, ops::FusedBiasDropoutResidualLnOp, + ops::FusedBiasDropoutResidualLnOpMaker, + ops::FusedBiasDropoutResidualLnGradOpMaker, + ops::FusedBiasDropoutResidualLnGradOpMaker); +REGISTER_OPERATOR(fused_bias_dropout_residual_layer_norm_grad, + ops::FusedBiasDropoutResidualLnGradOp); diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..71a2c9728cc6b0455b963448c7b840b16990ce58 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu @@ -0,0 +1,148 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto *input_x = ctx.Input("X"); + auto *bias = ctx.Input("Bias"); + auto *residual = ctx.Input("Residual"); + const float ln_epsilon = ctx.Attr("ln_epsilon"); + auto *ln_scale = ctx.Input("LnScale"); + auto *ln_bias = ctx.Input("LnBias"); + auto *dropout_mask_out = ctx.Output("DropoutMaskOut"); + auto *bias_dropout_residual_out = + ctx.Output("BiasDropoutResidualOut"); + auto *ln_mean = ctx.Output("LnMean"); + auto *ln_var = ctx.Output("LnVariance"); + auto *y = ctx.Output("Y"); + auto *x_data = input_x->data(); + auto *bias_data = (bias == nullptr) ? nullptr : bias->data(); + auto *residual_data = (residual == nullptr) ? nullptr : residual->data(); + auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); + auto *bias_dropout_residual_out_data = + bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); + auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); + auto *dropout_mask_out_data = + dropout_mask_out->mutable_data(ctx.GetPlace()); + auto *y_data = y->mutable_data(ctx.GetPlace()); + + const auto input_x_dims = input_x->dims(); + int bsz_seq = 1; + for (int i = 0; i < input_x_dims.size() - 1; i++) { + bsz_seq *= input_x_dims[i]; + } + int dim_embed = input_x_dims[input_x_dims.size() - 1]; + DropoutParam dropout_param(ctx, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param, + ln_epsilon); + // output = layernorm(residual + dropout(input + bias)) + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + ctx.cuda_device_context(), x_data, residual_data, bias_data, + ln_scale_data, ln_bias_data, bias_dropout_residual_out_data, + dropout_mask_out_data, y_data, ln_mean_data, ln_var_data); + } +}; + +template +class FusedBiasDropoutResidualLnGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + const float ln_epsilon = ctx.Attr("ln_epsilon"); + + auto *d_y = ctx.Input(framework::GradVarName("Y")); + auto *ln_scale = ctx.Input("LnScale"); + auto *dropout_mask_out = ctx.Input("DropoutMaskOut"); + auto *bias_dropout_residual_out = + ctx.Input("BiasDropoutResidualOut"); + auto *ln_mean = ctx.Input("LnMean"); + auto *ln_var = ctx.Input("LnVariance"); + auto *d_y_data = d_y->data(); + auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *dropout_mask_out_data = dropout_mask_out->data(); + auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data(); + auto *ln_mean_data = ln_mean->data(); + auto *ln_var_data = ln_var->data(); + + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_residual = ctx.Output(framework::GradVarName("Residual")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + auto *d_bias_dropout_residual_out = + ctx.Output(framework::GradVarName("BiasDropoutResidualOut")); + auto *d_ln_scale = ctx.Output(framework::GradVarName("LnScale")); + auto *d_ln_bias = ctx.Output(framework::GradVarName("LnBias")); + auto *d_x_data = d_x->mutable_data(ctx.GetPlace()); + auto *d_residual_data = d_residual->mutable_data(ctx.GetPlace()); + auto *d_bias_dropout_residual_out_data = + d_bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + auto *d_bias_data = + (d_bias == nullptr ? nullptr : d_bias->mutable_data(ctx.GetPlace())); + auto *d_ln_scale_data = + (d_ln_scale == nullptr ? nullptr + : d_ln_scale->mutable_data(ctx.GetPlace())); + auto *d_ln_bias_data = + (d_ln_bias == nullptr ? nullptr + : d_ln_bias->mutable_data(ctx.GetPlace())); + + const auto input_x_dims = d_y->dims(); + int bsz_seq = 1; + for (int i = 0; i < input_x_dims.size() - 1; i++) { + bsz_seq *= input_x_dims[i]; + } + int dim_embed = input_x_dims[input_x_dims.size() - 1]; + DropoutParam dropout_param(ctx, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param, + ln_epsilon); + fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( + ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data, + dropout_mask_out_data, ln_scale_data, ln_mean_data, ln_var_data, + d_bias_dropout_residual_out_data, d_ln_scale_data, d_ln_bias_data, + d_x_data, d_bias_data, d_residual_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_bias_dropout_residual_layer_norm, + ops::FusedBiasDropoutResidualLnOpKernel, + ops::FusedBiasDropoutResidualLnOpKernel, + ops::FusedBiasDropoutResidualLnOpKernel); +REGISTER_OP_CUDA_KERNEL( + fused_bias_dropout_residual_layer_norm_grad, + ops::FusedBiasDropoutResidualLnGradKernel, + ops::FusedBiasDropoutResidualLnGradKernel, + ops::FusedBiasDropoutResidualLnGradKernel); diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index df32f65a794f3e3e5919e6bdb5a41757dc25b659..bc84863d7d60782d5b3648eb53f5c9df16295999 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -40,6 +40,8 @@ std::map> op_ins_map = { {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep", "SrcMask", "OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias", "FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias"}}, + {"fused_bias_dropout_residual_layer_norm", + {"X", "Residual", "Bias", "LnScale", "LnBias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, @@ -152,6 +154,8 @@ std::map> op_outs_map = { "DropoutMaskOut", "Ln2Mean", "Ln2Variance", "BiasDropoutResidualOut", "CacheKVOut", "Y"}}, + {"fused_bias_dropout_residual_layer_norm", + {"BiasDropoutResidualOut", "DropoutMaskOut", "LnMean", "LnVariance", "Y"}}, {"fused_gate_attention", {"QueryTransposeOut", "KeyTransposeOut", "ValueTransposeOut", "QKVTransposeOut", "SoftmaxOut", "FMHAOut", "GateOut", "Out"}}, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e0cd0c4bf4d41f4608ddb3af41c626402463f787..34237d47a56599eba1b460fa6908b2a82e1187ee 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -131,6 +131,8 @@ if(NOT WITH_GPU) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api) LIST(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_op) LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer) + LIST(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op) + LIST(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api) endif() LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op) diff --git a/python/paddle/fluid/tests/unittests/test_fused_bias_dropout_residual_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_fused_bias_dropout_residual_layer_norm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d47450837a455f4e9e5e9f63592c24bad97ac88c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_bias_dropout_residual_layer_norm_op.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.fluid.core as core +import paddle.nn.functional as F +import paddle.incubate.nn.functional as incubate_f +from paddle.nn.layer.norm import LayerNorm +from paddle.nn.layer.common import Linear, Dropout +from paddle.nn.layer.transformer import _convert_attention_mask +from paddle import tensor +from paddle.fluid import layers +import unittest +from op_test import OpTest +from paddle.fluid.framework import default_main_program + +default_main_program().random_seed = 42 + + +class TestFusedBiasDropoutResidualLayerNormOp(OpTest): + def setUp(self): + self.config() + self.generate_input_data() + paddle.set_default_dtype(self.x_type) + self.__class__.op_type = "fused_bias_dropout_residual_layer_norm" + # use autograd to check grad in this unittest. + self.__class__.no_need_check_grad = True + paddle.set_default_dtype(np.float32) + self.norm1 = LayerNorm(self.embed_dim) + paddle.set_default_dtype(self.x_type) + self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") + + def config(self): + self.x_type = np.float32 + self.atol = 1e-4 + self.training = True + self.batch_size = 8 + self.query_length = 128 + self.embed_dim = 1024 + self.dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + + def generate_input_data(self): + self.x = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + self.residual = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + self.linear_bias = np.random.rand(self.embed_dim).astype(self.x_type) + self.dout = np.random.random((self.batch_size, self.query_length, + self.embed_dim)).astype(self.x_type) + + if self.bias_attr is False: + self.tensor_linear_bias = None + else: + self.tensor_linear_bias = paddle.to_tensor( + self.linear_bias, stop_gradient=False) + + self.tensor_x = paddle.to_tensor(self.x, stop_gradient=False) + self.tensor_residual = paddle.to_tensor( + self.residual, stop_gradient=False) + + def GetBaselineOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + + if self.tensor_linear_bias is not None: + out = self.tensor_x + self.tensor_linear_bias + else: + out = self.tensor_x + + residual_out = self.tensor_residual + self.dropout(out) + final_out = self.norm1(residual_out) + + paddle.autograd.backward( + [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) + + if self.tensor_linear_bias is not None: + tensor_linear_bias_grad = self.tensor_linear_bias.grad + else: + tensor_linear_bias_grad = None + return final_out, self.tensor_x.grad, self.tensor_residual.grad, tensor_linear_bias_grad + + def GetFusedBiasDropoutResidualLayerNormOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + + ln_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) + ln_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) + epsilon = 1e-05 + + final_out = incubate_f.fused_bias_dropout_residual_layer_norm( + self.tensor_x, self.tensor_residual, self.tensor_linear_bias, + ln_scale, ln_bias, self.dropout_prob, epsilon) + + paddle.autograd.backward( + [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) + if self.tensor_linear_bias is not None: + tensor_linear_bias_grad = self.tensor_linear_bias.grad + else: + tensor_linear_bias_grad = None + return final_out, self.tensor_x.grad, self.tensor_residual.grad, tensor_linear_bias_grad + + def test_fused_op(self): + out_ref, x_grad_ref, residual_grad_ref, linear_bias_grad_ref = self.GetBaselineOut( + ) + out, x_grad, residual_grad, linear_bias_grad = self.GetFusedBiasDropoutResidualLayerNormOut( + ) + np.testing.assert_allclose( + out_ref, out.numpy(), rtol=1e-5, atol=self.atol) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=self.atol) + np.testing.assert_allclose( + residual_grad_ref, residual_grad.numpy(), rtol=1e-5, atol=self.atol) + if linear_bias_grad_ref is not None: + np.testing.assert_allclose( + linear_bias_grad_ref, + linear_bias_grad.numpy(), + rtol=1e-5, + atol=self.atol) + + +class TestFusedBiasDropoutResidualLayerNormOpBiasIsNone( + TestFusedBiasDropoutResidualLayerNormOp): + def config(self): + super().config() + self.bias_attr = False + + +class TestFusedBiasDropoutResidualLayerNormOpFp16( + TestFusedBiasDropoutResidualLayerNormOp): + def config(self): + super().config() + self.x_type = np.float16 + self.atol = 1e-1 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_bias_dropout_residual_layer_norm_op_api.py b/python/paddle/fluid/tests/unittests/test_fused_bias_dropout_residual_layer_norm_op_api.py new file mode 100644 index 0000000000000000000000000000000000000000..19fc3972e58d4809f771fc2a94004e4aefbfdf5b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_bias_dropout_residual_layer_norm_op_api.py @@ -0,0 +1,175 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.fluid.core as core +import paddle.nn.functional as F +from paddle.incubate.nn.layer.fused_transformer import FusedBiasDropoutResidualLayerNorm +from paddle import tensor +from paddle.fluid import layers +from paddle.static import Program, program_guard +import unittest + + +def layer_norm(x, has_scale, has_bias, weight, bias, epsilon=1e-05): + batch_size, src_len, d_model = x.shape + x = x.reshape((batch_size * src_len, d_model)) + mu = np.mean(x, axis=1, keepdims=True) + sigma_squar = np.sum(np.square(x - mu), axis=1) / d_model + x1_up = (x - mu) + x1_down_1 = sigma_squar + epsilon + x1_down = np.sqrt(x1_down_1) + x1_down = x1_down.reshape((x1_down.shape[0], 1)) + x1 = x1_up / x1_down + x_scaled = x1 + if (has_scale): + x_scaled = weight * x1 + x_scaled_bias = x_scaled + if (has_bias): + x_scaled_bias = x_scaled + bias + x_scaled_bias = x_scaled_bias.reshape((batch_size, src_len, d_model)) + return x_scaled_bias + + +def compute_reference(x, residual, ln_scale, ln_bias, linear_bias): + batch_size = x.shape[0] + seq_len = x.shape[1] + embed_dim = x.shape[2] + + has_bias = True + if ln_bias is None: + has_bias = False + # bias add, dropout, residual add, layer_norm. + if linear_bias is not None: + linear_bias_out = x + linear_bias + else: + linear_bias_out = x + linear_bias_dropout_out = linear_bias_out + linear_bias_dropout_residual_out = residual + linear_bias_dropout_out + linear_bias_dropout_residual_ln_out = layer_norm( + linear_bias_dropout_residual_out, True, has_bias, ln_scale, ln_bias) + return linear_bias_dropout_residual_ln_out + + +class TestFusedBiasDropoutResidualLayerNormAPI(unittest.TestCase): + def setUp(self): + self.setXType() + self.setBiasAttr() + self.config() + self.generate_input_data() + + def setBiasAttr(self): + self.bias_attr = None + + def setXType(self): + self.x_type = np.float32 + self.atol = 1e-4 + + def config(self): + self.training = True + self.batch_size = 1 + self.query_length = 2 + self.embed_dim = 4 + self.dropout_prob = 0.0 + self.weight_attr = None + + def generate_input_data(self): + self.x = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + self.residual = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + + def run_imperative(self): + fused_bias_dropout_residual_ln = FusedBiasDropoutResidualLayerNorm( + self.embed_dim, self.dropout_prob, self.weight_attr, self.bias_attr) + + linear_bias = None + if self.bias_attr is not False: + linear_bias = np.random.random(fused_bias_dropout_residual_ln. + linear_bias.shape).astype('float32') + fused_bias_dropout_residual_ln.linear_bias.set_value( + paddle.to_tensor(linear_bias)) + out = fused_bias_dropout_residual_ln( + paddle.to_tensor(self.x), paddle.to_tensor(self.residual)) + + ln_bias = None + if self.bias_attr is not False: + ln_bias = fused_bias_dropout_residual_ln.ln_bias.numpy() + ln_scale = fused_bias_dropout_residual_ln.ln_scale.numpy(), + ref_out = compute_reference(self.x, self.residual, ln_scale, ln_bias, + linear_bias) + np.testing.assert_allclose( + ref_out, out.numpy(), rtol=1e-5, atol=self.atol) + + def run_static(self): + fused_op = FusedBiasDropoutResidualLayerNorm( + self.embed_dim, self.dropout_prob, self.weight_attr, self.bias_attr) + + x = paddle.static.data( + name='X', + shape=[self.batch_size, self.query_length, self.embed_dim], + dtype=self.x_type) + residual = paddle.static.data( + name='Residual', + shape=[self.batch_size, self.query_length, self.embed_dim], + dtype=self.x_type) + final_out = fused_op(x, residual) + + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + + linear_bias = None + ln_bias = None + if self.bias_attr is False: + out, ln_scale = exe.run( + paddle.static.default_main_program(), + feed={"X": self.x, + "Residual": self.residual}, + fetch_list=[final_out, fused_op.ln_scale]) + else: + out, linear_bias, ln_scale, ln_bias = exe.run( + paddle.static.default_main_program(), + feed={"X": self.x, + "Residual": self.residual}, + fetch_list=[ + final_out, fused_op.linear_bias, fused_op.ln_scale, + fused_op.ln_bias + ]) + return out, linear_bias, ln_scale, ln_bias + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(Program()): + out, linear_bias, ln_scale, ln_bias = self.run_static() + ref_out = compute_reference(self.x, self.residual, ln_scale, ln_bias, + linear_bias) + np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=self.atol) + + def test_dynamic_api(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + self.run_imperative() + + +class TestFusedBiasDropoutResidualLayerNormAPIBiasIsNone( + TestFusedBiasDropoutResidualLayerNormAPI): + def setBiasAttr(self): + self.bias_attr = False + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index 43fcabf97317e33ed08e2d37f79b78cfc53963f0..3c806aa646ebe3157ab06819c12829806f502aa0 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -16,10 +16,12 @@ from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 from .layer.fused_transformer import FusedFeedForward # noqa: F401 from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401 from .layer.fused_transformer import FusedMultiTransformer # noqa: F401 +from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401 __all__ = [ #noqa 'FusedMultiHeadAttention', 'FusedFeedForward', 'FusedTransformerEncoderLayer', 'FusedMultiTransformer', + 'FusedBiasDropoutResidualLayerNorm', ] diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index 4da090487785bc78971e4219f8b418eeadc6c191..02e44548ce5d87a4be505dc6a2981405ee3cc938 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -15,9 +15,11 @@ from .fused_transformer import fused_multi_head_attention from .fused_transformer import fused_feedforward from .fused_transformer import fused_multi_transformer +from .fused_transformer import fused_bias_dropout_residual_layer_norm __all__ = [ 'fused_multi_head_attention', 'fused_feedforward', 'fused_multi_transformer', + 'fused_bias_dropout_residual_layer_norm', ] diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 3e263f1c6d3aef62396d8c8c39da229dee6458d3..ee85642d4166401a770ad760185198b52b7b04c6 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -212,6 +212,151 @@ def fused_feedforward(x, return out +def fused_bias_dropout_residual_layer_norm(x, + residual, + bias=None, + ln_scale=None, + ln_bias=None, + dropout_rate=0.5, + ln_epsilon=1e-5, + training=True, + mode='upscale_in_train', + name=None): + r""" + The fused_bias_dropout_residual_layer_norm operator. The pseudo code is as follows: + + .. code-block:: python + y = layer_norm(residual + dropout(bias + x)) + + Parameters: + x (Tensor): The input tensor. The shape is `[*, embed\_dim]`. + residual (Tensor): The residual tensor. The shape is same as x. + bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None. + ln_scale (Tensor, optional): The weight tensor of layernorm. The shape is `[embed_dim]`. Default None. + ln_bias (Tensor, optional): The bias tensor of layernorm. The shape is `[embed_dim]`. Default None. + dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout after attention. + 0 for no dropout. Default 0.5. + ln_epsilon (float, optional): Small float value added to denominator of layer_norm + to avoid dividing by zero. Default is 1e-5. + training (bool, optional): A flag indicating whether it is in train phrase or not. Default True. + mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] + + 1. upscale_in_train(default), upscale the output at training time + + - train: out = input * mask / ( 1.0 - p ) + - inference: out = input + + 2. downscale_in_infer, downscale the output at inference + + - train: out = input * mask + - inference: out = input * (1.0 - p) + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The output Tensor, the data type and shape is same as `x`. + + Examples: + + .. code-block:: python + + # required: gpu + import paddle + import paddle.incubate.nn.functional as F + + # input: [batch_size, seq_len, embed_dim] + x = paddle.rand(shape=(2, 4, 128), dtype="float32") + # residual: [batch_size, seq_len, embed_dim] + residual = paddle.rand(shape=(2, 4, 128), dtype="float32") + # linear bias: [embed_dim] + bias = paddle.rand(shape=[128], dtype="float32") + # output: [batch_size, seq_len, embed_dim] + output = F.fused_bias_dropout_residual_layer_norm( + x, residual, bias) + # [2, 4, 128] + print(output.shape) + """ + seed = None + if mode not in ('downscale_in_infer', 'upscale_in_train'): + raise ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + + if ln_scale is not None: + assert len(ln_scale. + shape) == 1, "The dims of the shape of ln_scale should be 1." + assert x.shape[len(x.shape) - 1] == ln_scale.shape[ + 0], "The dim of ln_scale must equal to the last dim of x." + if ln_bias is not None: + assert len( + ln_bias.shape) == 1, "The dims of the shape of ln_bias should be 1." + assert x.shape[len(x.shape) - 1] == ln_bias.shape[ + 0], "The dim of ln_bias must equal to the last dim of x." + + if _non_static_mode(): + if default_main_program().random_seed != 0: + seed = default_main_program().random_seed + _, _, _, _, final_out = _C_ops.fused_bias_dropout_residual_layer_norm( + x, residual, bias, ln_scale, ln_bias, 'dropout_rate', dropout_rate, + 'ln_epsilon', ln_epsilon, 'dropout_is_test', not training, + 'dropout_fix_seed', seed is not None, 'dropout_seed', seed + if seed is not None else 0, 'dropout_implementation', mode) + return final_out + else: + helper = LayerHelper('fused_bias_dropout_residual_layer_norm', + **locals()) + dtype = x.dtype + # check dtypes + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'fused_bias_dropout_residual_layer_norm') + check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], + 'fused_bias_dropout_residual_layer_norm') + # set inputs + inputs = dict() + inputs['X'] = [x] + inputs['Residual'] = [residual] + if bias is not None: + inputs['Bias'] = [bias] + if ln_scale: + inputs['LnScale'] = [ln_scale] + if ln_bias: + inputs['LnBias'] = [ln_bias] + if (seed is None or seed == 0) and helper.main_program.random_seed != 0: + seed = helper.main_program.random_seed + # set attrs + attrs = { + 'ln_epsilon': ln_epsilon, + 'dropout_rate': dropout_rate, + 'dropout_is_test': not training, + 'dropout_fix_seed': seed is not None, + 'dropout_seed': seed if seed is not None else 0, + 'dropout_implementation': mode, + } + # set outputs + dropout_mask_out = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + ln_mean_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + ln_variance_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + bias_dropout_residual_out = helper.create_variable_for_type_inference( + dtype=dtype) + final_out = helper.create_variable_for_type_inference(dtype=dtype) + + helper.append_op( + type='fused_bias_dropout_residual_layer_norm', + inputs=inputs, + outputs={ + "BiasDropoutResidualOut": bias_dropout_residual_out, + "DropoutMaskOut": dropout_mask_out, + "LnMean": ln_mean_out, + "LnVariance": ln_variance_out, + 'Y': final_out, + }, + attrs=attrs) + return final_out + + def fused_multi_head_attention(x, qkv_weight, linear_weight, diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 072c7d9fccade2fd6caffa27a67a4d1ba5160f5d..a64b7e506021cef6c0c03c036c76412b88bf11d8 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -36,6 +36,103 @@ def _set_var_distributed(var): main_block._find_var_recursive(var.name).is_distributed = True +class FusedBiasDropoutResidualLayerNorm(Layer): + """ + Applies fused_bias_dropout_residual_layer_norm operation. + + Parameters: + embed_dim (int): The expected feature size in the input and output. + dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout after attention. + 0 for no dropout. Default 0.5. + bias_attr (ParamAttr|bool, optional): To specify the bias parameter property. + Default: None, which means the default bias parameter property is used. + If it is set to False, this layer will not have trainable bias parameter. + See usage for details in :code:`ParamAttr`. + epsilon (float, optional): The small value added to the variance to prevent + division by zero. Default: 1e-05. + + Examples: + + .. code-block:: python + + # required: gpu + import paddle + # input: [batch_size, seq_len, embed_dim] + x = paddle.rand((2, 4, 128)) + # residual: [batch_size, seq_len, embed_dim] + residual = paddle.rand((2, 4, 128)) + fused_bias_dropout_residual_ln = paddle.incubate.nn.FusedBiasDropoutResidualLayerNorm(128) + output = fused_bias_dropout_residual_ln(x, residual) # [2, 4, 128] + """ + + def __init__(self, + embed_dim, + dropout_rate=0.5, + weight_attr=None, + bias_attr=None, + epsilon=1e-5, + name=None): + super(FusedBiasDropoutResidualLayerNorm, self).__init__() + assert embed_dim > 0, ("Expected embed_dim to be greater than 0, " + "but recieved {}".format(embed_dim)) + self._dtype = self._helper.get_default_dtype() + self._bias_attr = bias_attr + self._weight_attr = weight_attr + self.embed_dim = embed_dim + self.linear_bias = self.create_parameter( + shape=[embed_dim], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True) + self.ln_scale = self.create_parameter( + attr=self._weight_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0)) + self.ln_bias = self.create_parameter( + attr=self._bias_attr, shape=[embed_dim], is_bias=True) + self.dropout_rate = dropout_rate + self._epsilon = epsilon + + self.name = name + + def forward(self, x, residual): + """ + Applies fused_bias_dropout_residual_layer_norm operation. + + Parameters: + x (Tensor): The input tensor. It is a tensor with shape + `[batch_size, seq_len, embed_dim]`. The data type should be + float32 or float64. + residual (Tensor, optional): The residual tensor. It is a tensor + with shape `[batch_size, value_length, vdim]`. The data type + should be float32 or float64. + + Returns: + Tensor|tuple: It is a tensor that has the same shape and data type \ + as `x`. + """ + + out = incubate_f.fused_bias_dropout_residual_layer_norm( + x=x, + residual=residual, + bias=self.linear_bias, + ln_scale=self.ln_scale, + ln_bias=self.ln_bias, + dropout_rate=self.dropout_rate, + ln_epsilon=self._epsilon, + training=self.training, + mode='upscale_in_train', + name=self.name) + return out + + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'embed_dim={}, seq_len={}, dropout_rate={}, epsilon={}, dtype={}{}'.format( + self.embed_dim, self.seq_len, self.dropout_rate, self._epsilon, + self._dtype, name_str) + + class FusedMultiHeadAttention(Layer): """ Attention mapps queries and a set of key-value pairs to outputs, and diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index 7702e8be9c95887ca7308fa0f1b65d49ea5dd968..7c43ef1a6d2e3231c46c04330e1700a8dbba02db 100755 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -2047,6 +2047,8 @@ TETRAD_PARALLEL_JOB = [ 'test_lambda', 'test_prod_op', 'test_fused_attention_op_api', + 'test_fused_bias_dropout_residual_layer_norm_op', + 'test_fused_bias_dropout_residual_layer_norm_op_api', 'test_complex_grad_accumulated', 'test_deg2rad', 'test_lgamma_op',