From 2dd0a46a1931ae8dea2cf8cbf7f86f9f0951591b Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Mon, 25 Oct 2021 18:48:31 +0800
Subject: [PATCH] add op: fused_feedforward(backward) (#35611)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR adds the backward code of fused_feedforward. Related kernel
implementations: fused_dropout_act_bias, fused_residual_dropout_bias,
fused_layernorm_residual_dropout_bias.

fused_feedforward is a fused operator. It fuses and wraps the operators of
the feed-forward layer of the transformer model so that the frontend exposes
only a single interface; the fusion cuts part of the memory-access and
kernel-launch time and thereby improves performance.
---
 .../operators/fused/fused_feedforward_op.cc   | 147 +++++++++++-
 .../operators/fused/fused_feedforward_op.cu   | 211 ++++++++++++++++++
 .../unittests/test_fused_feedforward_op.py    |  43 ++--
 3 files changed, 386 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cc b/paddle/fluid/operators/fused/fused_feedforward_op.cc
index 0b23b30b171..4e03c7369d1 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cc
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cc
@@ -206,9 +206,154 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout1_is_test"), false,
+                      platform::errors::InvalidArgument(
+                          "GradOp is only callable when is_test is false"));
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout2_is_test"), false,
+                      platform::errors::InvalidArgument(
+                          "GradOp is only callable when is_test is false"));
+    OP_INOUT_CHECK(ctx->HasInput("Dropout1Mask"), "Input", "Dropout1Mask",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Dropout2Mask"), "Input", "Dropout2Mask",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Linear1Out"), "Input", "Linear1Out",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Dropout1Out"), "Input", "Dropout1Out",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"), "Input", "Dropout2Out",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Linear1Weight"), "Input", "Linear1Weight",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Linear2Weight"), "Input", "Linear2Weight",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln1Mean"), "Input", "Ln1Mean",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln1Variance"), "Input", "Ln1Variance",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
+                   "FusedFeedForwardGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
+                   "FusedFeedForwardGrad");
+
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "FusedFeedForwardGrad");
+
+    auto d_out_dim = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), d_out_dim);
+    if (ctx->HasOutput(framework::GradVarName("Ln1Scale"))) {
+      ctx->SetOutputDim(framework::GradVarName("Ln1Scale"),
+                        ctx->GetInputDim("Ln1Scale"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("Ln1Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Ln1Bias"),
+                        ctx->GetInputDim("Ln1Bias"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
+      ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
+                        ctx->GetInputDim("Ln2Scale"));
+ } + if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Ln2Bias"), + ctx->GetInputDim("Ln2Bias")); + } + ctx->SetOutputDim(framework::GradVarName("Linear1Weight"), + ctx->GetInputDim("Linear1Weight")); + if (ctx->HasOutput(framework::GradVarName("Linear1Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Linear1Bias"), + ctx->GetInputDim("Linear1Bias")); + } + ctx->SetOutputDim(framework::GradVarName("Linear2Weight"), + ctx->GetInputDim("Linear2Weight")); + if (ctx->HasOutput(framework::GradVarName("Linear2Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Linear2Bias"), + ctx->GetInputDim("Linear2Bias")); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input = ctx.Input("X"); + auto input_data_type = input->type(); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +template +class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("fused_feedforward_grad"); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetInput("X", this->Input("X")); + op->SetInput("Linear1Weight", this->Input("Linear1Weight")); + op->SetInput("Linear1Bias", this->Input("Linear1Bias")); + op->SetInput("Linear2Weight", this->Input("Linear2Weight")); + op->SetInput("Ln1Scale", this->Input("Ln1Scale")); + op->SetInput("Ln1Bias", this->Input("Ln1Bias")); + op->SetInput("Ln2Scale", this->Input("Ln2Scale")); + op->SetInput("Ln2Bias", this->Input("Ln2Bias")); + op->SetInput("Dropout1Mask", this->Output("Dropout1Mask")); + op->SetInput("Dropout2Mask", this->Output("Dropout2Mask")); + op->SetInput("Linear1Out", this->Output("Linear1Out")); + op->SetInput("Ln1Out", this->Output("Ln1Out")); + op->SetInput("Ln1Mean", this->Output("Ln1Mean")); + op->SetInput("Ln1Variance", this->Output("Ln1Variance")); + op->SetInput("Ln2Mean", this->Output("Ln2Mean")); + op->SetInput("Ln2Variance", this->Output("Ln2Variance")); + op->SetInput("Dropout1Out", this->Output("Dropout1Out")); + op->SetInput("Dropout2Out", this->Output("Dropout2Out")); + + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Ln1Scale"), + this->InputGrad("Ln1Scale")); + op->SetOutput(framework::GradVarName("Ln1Bias"), + this->InputGrad("Ln1Bias")); + op->SetOutput(framework::GradVarName("Ln2Scale"), + this->InputGrad("Ln2Scale")); + op->SetOutput(framework::GradVarName("Ln2Bias"), + this->InputGrad("Ln2Bias")); + op->SetOutput(framework::GradVarName("Linear1Weight"), + this->InputGrad("Linear1Weight")); + op->SetOutput(framework::GradVarName("Linear1Bias"), + this->InputGrad("Linear1Bias")); + op->SetOutput(framework::GradVarName("Linear2Weight"), + this->InputGrad("Linear2Weight")); + if (this->HasInput("Linear2Bias")) { + op->SetInput("Linear2Bias", this->Input("Linear2Bias")); + op->SetOutput(framework::GradVarName("Linear2Bias"), + this->InputGrad("Linear2Bias")); + } + + op->SetAttrMap(this->Attrs()); + } +}; + +template +class FusedFeedForwardOpDoubleGradMaker + : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override {} +}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(fused_feedforward, 
-                  ops::FusedFeedForwardOpMaker);
+                  ops::FusedFeedForwardOpMaker,
+                  ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
+                  ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu
index 03f94372517..61a8a9a82f2 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cu
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -171,6 +171,210 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
+ public:
+  void MatMulGrad(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& d_out, const framework::Tensor& a,
+                  const framework::Tensor& b, framework::Tensor* d_a,
+                  framework::Tensor* d_b) const {
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    auto a_2d = FoldInitDims(a);
+    auto b_2d = FoldInitDims(b);
+    auto mat_dim_a = math::CreateMatrixDescriptor(a_2d.dims(), 0, true);
+    auto mat_dim_b = math::CreateMatrixDescriptor(b_2d.dims(), 0, true);
+    auto mat_dim_dout = math::CreateMatrixDescriptor(d_out.dims(), 0, false);
+    T alpha = static_cast<T>(1.0);
+    blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0));
+    blas.MatMul(a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T(0));
+  }
+
+  void FFNGrad(
+      const framework::Tensor& d_out, const framework::Tensor& x,
+      const framework::Tensor& dropout1_mask,
+      const framework::Tensor& dropout2_mask,
+      const framework::Tensor& linear1_out, const framework::Tensor& ln1_out,
+      const framework::Tensor& dropout1_out,
+      const framework::Tensor& dropout2_out,
+      const framework::Tensor& linear1_weight,
+      const framework::Tensor* linear1_bias,
+      const framework::Tensor& linear2_weight,
+      const framework::Tensor* ln1_gamma, const framework::Tensor* ln1_beta,
+      const framework::Tensor& ln1_mean, const framework::Tensor& ln1_variance,
+      const framework::Tensor* ln2_gamma, const framework::Tensor* ln2_beta,
+      const framework::Tensor& ln2_mean, const framework::Tensor& ln2_variance,
+      framework::Tensor* d_x, framework::Tensor* d_linear1_weight,
+      framework::Tensor* d_linear1_bias, framework::Tensor* d_linear2_weight,
+      framework::Tensor* d_linear2_bias, framework::Tensor* d_ln1_gamma,
+      framework::Tensor* d_ln1_beta, framework::Tensor* d_ln2_gamma,
+      framework::Tensor* d_ln2_beta, const int bsz_seq, const int d_model,
+      const int dim_feedforward, const DropoutParam& dropout_param1,
+      const DropoutParam& dropout_param2, const std::string& act_method,
+      const bool pre_layer_norm, const float epsilon1, const float epsilon2,
+      const platform::CUDADeviceContext& ctx) const {
+    FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
+        bsz_seq, d_model, epsilon1);
+    FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
+        ctx, bsz_seq, dim_feedforward, dropout_param1);
+    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
+        ctx, bsz_seq, d_model, dropout_param2, epsilon2);
+
+    auto place = ctx.GetPlace();
+    using U = LayerNormParamType<T>;
+    const U* ln1_gamma_ptr =
+        ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
+    const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
+    const U* ln2_gamma_ptr =
+        ln2_gamma == nullptr ? nullptr : ln2_gamma->data<U>();
+    const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data<U>();
+    const T* linear1_bias_ptr =
+        linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
+    T* d_linear1_bias_ptr =
+        d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data<T>();
+    T* d_linear2_bias_ptr =
+        d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data<T>();
+    U* d_ln1_gamma_ptr =
+        d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data<U>();
+    U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data<U>();
+    U* d_ln2_gamma_ptr =
+        d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data<U>();
+    U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();
+
+    framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
+    d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
+    d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
+    d_residual.mutable_data<T>({bsz_seq, d_model}, place);
+
+    if (pre_layer_norm) {
+      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
+          ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(),
+          d_linear2_out.data<T>(), d_residual.data<T>(), d_linear2_bias_ptr);
+    } else {
+      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
+          ctx, d_out.data<T>(), dropout2_out.data<T>(),
+          dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean.data<U>(),
+          ln2_variance.data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
+          d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
+          d_residual.data<T>());
+    }
+
+    framework::Tensor d_dropout1_out;
+    d_dropout1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place);
+    MatMulGrad(ctx, d_linear2_out, dropout1_out, linear2_weight,
+               &d_dropout1_out, d_linear2_weight);
+
+    framework::Tensor d_linear1_out;
+    d_linear1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place);
+    fused_act_dropout_helper.DropoutActBiasGrad(
+        ctx, d_dropout1_out.data<T>(), linear1_out.data<T>(), linear1_bias_ptr,
+        dropout1_mask.data<uint8_t>(), d_linear1_out.data<T>(),
+        d_linear1_bias_ptr, act_method);
+
+    if (pre_layer_norm) {
+      framework::Tensor d_ln1_out;
+      d_ln1_out.mutable_data<T>({bsz_seq, d_model}, place);
+      MatMulGrad(ctx, d_linear1_out, ln1_out, linear1_weight, &d_ln1_out,
+                 d_linear1_weight);
+
+      pre_layernorm_helper.LayerNormGrad(ctx, d_ln1_out.data<T>(), x.data<T>(),
+                                         ln1_gamma_ptr, ln1_mean.data<U>(),
+                                         ln1_variance.data<U>(), d_x->data<T>(),
+                                         d_ln1_gamma_ptr, d_ln1_beta_ptr);
+    } else {
+      MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
+    }
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    using U = LayerNormParamType<T>;
+    auto d_out =
+        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto x = *context.Input<framework::Tensor>("X");
+    auto dropout1_mask = *context.Input<framework::Tensor>("Dropout1Mask");
+    auto dropout2_mask = *context.Input<framework::Tensor>("Dropout2Mask");
+    auto linear1_out = *context.Input<framework::Tensor>("Linear1Out");
+    auto ln1_out = *context.Input<framework::Tensor>("Ln1Out");
+    auto dropout1_out = *context.Input<framework::Tensor>("Dropout1Out");
+    auto dropout2_out = *context.Input<framework::Tensor>("Dropout2Out");
+    auto linear1_weight = *context.Input<framework::Tensor>("Linear1Weight");
+    auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
+    auto linear2_weight = *context.Input<framework::Tensor>("Linear2Weight");
+    auto ln1_mean = *context.Input<framework::Tensor>("Ln1Mean");
+    auto ln1_variance = *context.Input<framework::Tensor>("Ln1Variance");
+    auto* ln1_scale = context.Input<framework::Tensor>("Ln1Scale");
+    auto* ln1_bias = context.Input<framework::Tensor>("Ln1Bias");
+    auto ln2_mean = *context.Input<framework::Tensor>("Ln2Mean");
+    auto ln2_variance = *context.Input<framework::Tensor>("Ln2Variance");
+    auto* ln2_scale = context.Input<framework::Tensor>("Ln2Scale");
+    auto* ln2_bias = context.Input<framework::Tensor>("Ln2Bias");
+
+    auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* d_ln1_scale =
+        context.Output<framework::Tensor>(framework::GradVarName("Ln1Scale"));
+    auto* d_ln1_bias =
+        context.Output<framework::Tensor>(framework::GradVarName("Ln1Bias"));
+    auto* d_ln2_scale =
+        context.Output<framework::Tensor>(framework::GradVarName("Ln2Scale"));
+    auto* d_ln2_bias =
context.Output(framework::GradVarName("Ln2Bias")); + auto* d_linear1_weight = context.Output( + framework::GradVarName("Linear1Weight")); + auto* d_linear1_bias = context.Output( + framework::GradVarName("Linear1Bias")); + auto* d_linear2_weight = context.Output( + framework::GradVarName("Linear2Weight")); + auto* d_linear2_bias = context.Output( + framework::GradVarName("Linear2Bias")); + + const float epsilon1 = context.Attr("ln1_epsilon"); + const float epsilon2 = context.Attr("ln2_epsilon"); + const bool pre_layer_norm = context.Attr("pre_layer_norm"); + const std::string act_method = context.Attr("act_method"); + DropoutParam dropout_param1(context, 1); + DropoutParam dropout_param2(context, 2); + + auto place = context.GetPlace(); + d_x->mutable_data(place); + if (d_ln1_scale) { + d_ln1_scale->mutable_data(place); + } + if (d_ln1_bias) { + d_ln1_bias->mutable_data(place); + } + if (d_ln2_scale) { + d_ln2_scale->mutable_data(place); + } + if (d_ln2_bias) { + d_ln2_bias->mutable_data(place); + } + if (d_linear1_bias) { + d_linear1_bias->mutable_data(place); + } + if (d_linear2_bias) { + d_linear2_bias->mutable_data(place); + } + d_linear1_weight->mutable_data(place); + d_linear2_weight->mutable_data(place); + + auto x_dim = x.dims(); + auto mat_dim_x = + math::CreateMatrixDescriptor(RowMatrixFromVector(x_dim), 0, false); + + auto linear1_weight_dim = linear1_weight.dims(); + int d_model = linear1_weight_dim[0]; + int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1]; + int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; + + FFNGrad(d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out, + dropout1_out, dropout2_out, linear1_weight, linear1_bias, + linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance, + ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight, + d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale, + d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model, + dim_feedforward, dropout_param1, dropout_param2, act_method, + pre_layer_norm, epsilon1, epsilon2, context.cuda_device_context()); + } +}; } // namespace operators } // namespace paddle @@ -181,3 +385,10 @@ REGISTER_OP_CUDA_KERNEL( ops::FusedFeedForwardKernel, ops::FusedFeedForwardKernel); +REGISTER_OP_CUDA_KERNEL( + fused_feedforward_grad, + ops::FusedFeedForwardGradKernel, + ops::FusedFeedForwardGradKernel, + ops::FusedFeedForwardGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py b/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py index a0b341bf6cf..d926512b592 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py @@ -30,10 +30,10 @@ class TestFusedFFNOp(OpTest): self.layer_norm_dtype = "float32" def getShape(self): - self.batch_size = np.random.randint(1, 64) - self.query_length = np.random.randint(32, 256) - self.d_model = np.random.randint(32, 1024) - self.dim_feedforward = np.random.randint(32, 1024) + self.batch_size = np.random.randint(1, 32) + self.query_length = np.random.randint(32, 128) + self.d_model = np.random.randint(32, 512) + self.dim_feedforward = np.random.randint(32, 512) def getDiff(self): self.rtol = 1e-3 @@ -48,6 +48,8 @@ class TestFusedFFNOp(OpTest): def setUp(self): paddle.disable_static() self.__class__.op_type = "fused_feedforward" + #check grad in test_out_and_grad() + self.__class__.no_need_check_grad = True self.getDtype() self.getShape() self.getDiff() @@ -82,6 +84,8 @@ class 
         self.src = np.random.random((self.batch_size, self.query_length,
                                      self.d_model)).astype(self.dtype)
+        self.dout = np.random.random((self.batch_size, self.query_length,
+                                      self.d_model)).astype(self.dtype)
 
     def Base(self):
         paddle.disable_static()
@@ -92,12 +96,17 @@ class TestFusedFFNOp(OpTest):
             linear2_out = self.linear2(
                 self.dropout(self.activation(self.linear1(ln1_out))))
             dropout2_out = residual + self.dropout2(linear2_out)
+            paddle.autograd.backward([dropout2_out],
+                                     [paddle.to_tensor(self.dout)], True)
+            return dropout2_out, tensor_src.grad
         else:
             linear2_out = self.linear2(
                 self.dropout(self.activation(self.linear1(tensor_src))))
             dropout2_out = residual + self.dropout2(linear2_out)
             dropout2_out = self.norm2(dropout2_out)
-        return dropout2_out
+            paddle.autograd.backward([dropout2_out],
+                                     [paddle.to_tensor(self.dout)], True)
+            return dropout2_out, tensor_src.grad
 
     def FusedFFN(self):
         paddle.disable_static()
@@ -126,13 +135,19 @@ class TestFusedFFNOp(OpTest):
             0.0,
             activation=self.act_method,
             pre_layer_norm=self.pre_layer_norm)
-        return out
+        paddle.autograd.backward([out], [paddle.to_tensor(self.dout)])
+        return out, x.grad
 
-    def test_fused_ffn(self):
-        base_out = self.Base()
-        fused_out = self.FusedFFN()
+    def test_out_and_grad(self):
+        base_out, base_grad = self.Base()
+        fused_out, fused_grad = self.FusedFFN()
         np.testing.assert_allclose(
             base_out.numpy(), fused_out.numpy(), rtol=self.rtol, atol=self.atol)
+        np.testing.assert_allclose(
+            base_grad.numpy(),
+            fused_grad.numpy(),
+            rtol=self.rtol,
+            atol=self.atol)
 
 
 class TestFusedFFNOpFp16(TestFusedFFNOp):
@@ -145,10 +160,10 @@ class TestFusedFFNOpFp16(TestFusedFFNOp):
         self.atol = 1e-2
 
     def getShape(self):
-        self.batch_size = 8
-        self.query_length = 128
-        self.d_model = 512
-        self.dim_feedforward = 512
+        self.batch_size = 4
+        self.query_length = 32
+        self.d_model = 128
+        self.dim_feedforward = 256
 
 
 class TestFusedFFNOpFp64(TestFusedFFNOp):
@@ -263,7 +278,7 @@ class APITestStaticFusedFFN(unittest.TestCase):
             real_res.append(fetch)
         self.assertTrue(
             np.allclose(
-                real_res[0], real_res[1], atol=1e-5),
+                real_res[0], real_res[1], atol=1e-3),
            "two value is check diff")
--
GitLab
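Reference sketch (not part of the patch): the computation being fused is the
standard transformer feed-forward block that Base() in the unit test spells
out: optional pre layer_norm, linear1, activation, dropout1, linear2,
dropout2, residual add, and (in post-norm mode) a final layer_norm. Below is
a minimal NumPy sketch under stated assumptions; the names layer_norm,
ffn_reference, and w1/b1/w2/b2 are illustrative only, with ReLU activation
and zero dropout as in the test.

import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last axis (d_model); scale/bias omitted for brevity.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def ffn_reference(src, w1, b1, w2, b2, pre_layer_norm=True):
    residual = src
    x = layer_norm(src) if pre_layer_norm else src
    x = np.maximum(x @ w1 + b1, 0.0)  # linear1 + relu; dropout1 rate is 0.0
    x = x @ w2 + b2                   # linear2; dropout2 rate is 0.0
    out = residual + x                # residual add
    return out if pre_layer_norm else layer_norm(out)

# Shapes mirror TestFusedFFNOpFp16.getShape(): src is (batch, seq_len,
# d_model); w1 maps d_model -> dim_feedforward and w2 maps back.
src = np.random.rand(4, 32, 128).astype("float32")
w1 = np.random.rand(128, 256).astype("float32")
b1 = np.zeros(256, dtype="float32")
w2 = np.random.rand(256, 128).astype("float32")
b2 = np.zeros(128, dtype="float32")
out = ffn_reference(src, w1, b1, w2, b2, pre_layer_norm=True)
assert out.shape == src.shape

The grad kernel in the patch reverses these steps: for each linear layer,
MatMulGrad computes d_a = d_out * b^T and d_b = a^T * d_out by marking the
non-differentiated operand as transposed in its matrix descriptor, while the
FusedDropout*Helper calls produce the dropout, bias, and layer-norm gradients
in single fused kernels.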