未验证 提交 2dd0a46a 编写于 作者: Z zhangkaihuo 提交者: GitHub

add op: fused_feedforward(backward) (#35611)

这个PR是fused_feedforward反向的代码

相关kernel实现:fused_dropout_act_bias, fused_residual_dropout_bias, fused_layernorm_residual_dropout_bias

fused_feedforward是一个融合算子,该算子对transformer模型的feed forward层的算子进行融合和封装,使得前端只呈现一个接口,通过融合减少部分访存和kernel launch的时间,以此提升性能。
上级 39f19127
......@@ -206,9 +206,154 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout1_is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when is_test is false"));
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout2_is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when is_test is false"));
OP_INOUT_CHECK(ctx->HasInput("Dropout1Mask"), "Input", "Dropout1Mask",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout2Mask"), "Input", "Dropout1Mask",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear1Out"), "Input", "Linear1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout1Out"), "Input", "Dropout1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"), "Input", "Dropout2Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear1Weight"), "Input", "Linear1Weight",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear2Weight"), "Input", "Linear2Weight",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Mean"), "Input", "Ln1Mean",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Variance"), "Input", "Ln1Variance",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "FusedFeedForwardGrad");
auto d_out_dim = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), d_out_dim);
if (ctx->HasOutput(framework::GradVarName("Ln1Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Ln1Scale"),
ctx->GetInputDim("Ln1Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Ln1Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Ln1Bias"),
ctx->GetInputDim("Ln1Bias"));
}
if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
ctx->GetInputDim("Ln2Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
ctx->GetInputDim("Ln2Bias"));
}
ctx->SetOutputDim(framework::GradVarName("Linear1Weight"),
ctx->GetInputDim("Linear1Weight"));
if (ctx->HasOutput(framework::GradVarName("Linear1Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Linear1Bias"),
ctx->GetInputDim("Linear1Bias"));
}
ctx->SetOutputDim(framework::GradVarName("Linear2Weight"),
ctx->GetInputDim("Linear2Weight"));
if (ctx->HasOutput(framework::GradVarName("Linear2Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Linear2Bias"),
ctx->GetInputDim("Linear2Bias"));
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto input_data_type = input->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("fused_feedforward_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
op->SetInput("Linear1Weight", this->Input("Linear1Weight"));
op->SetInput("Linear1Bias", this->Input("Linear1Bias"));
op->SetInput("Linear2Weight", this->Input("Linear2Weight"));
op->SetInput("Ln1Scale", this->Input("Ln1Scale"));
op->SetInput("Ln1Bias", this->Input("Ln1Bias"));
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetInput("Dropout1Mask", this->Output("Dropout1Mask"));
op->SetInput("Dropout2Mask", this->Output("Dropout2Mask"));
op->SetInput("Linear1Out", this->Output("Linear1Out"));
op->SetInput("Ln1Out", this->Output("Ln1Out"));
op->SetInput("Ln1Mean", this->Output("Ln1Mean"));
op->SetInput("Ln1Variance", this->Output("Ln1Variance"));
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("Dropout1Out", this->Output("Dropout1Out"));
op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Ln1Scale"),
this->InputGrad("Ln1Scale"));
op->SetOutput(framework::GradVarName("Ln1Bias"),
this->InputGrad("Ln1Bias"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
this->InputGrad("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
op->SetOutput(framework::GradVarName("Linear1Weight"),
this->InputGrad("Linear1Weight"));
op->SetOutput(framework::GradVarName("Linear1Bias"),
this->InputGrad("Linear1Bias"));
op->SetOutput(framework::GradVarName("Linear2Weight"),
this->InputGrad("Linear2Weight"));
if (this->HasInput("Linear2Bias")) {
op->SetInput("Linear2Bias", this->Input("Linear2Bias"));
op->SetOutput(framework::GradVarName("Linear2Bias"),
this->InputGrad("Linear2Bias"));
}
op->SetAttrMap(this->Attrs());
}
};
template <typename T>
class FusedFeedForwardOpDoubleGradMaker
: public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
ops::FusedFeedForwardOpMaker);
ops::FusedFeedForwardOpMaker,
ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);
......@@ -171,6 +171,210 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
public:
void MatMulGrad(const platform::CUDADeviceContext& ctx,
const framework::Tensor& d_out, const framework::Tensor& a,
const framework::Tensor& b, framework::Tensor* d_a,
framework::Tensor* d_b) const {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
auto a_2d = FoldInitDims(a);
auto b_2d = FoldInitDims(b);
auto mat_dim_a = math::CreateMatrixDescriptor(a_2d.dims(), 0, true);
auto mat_dim_b = math::CreateMatrixDescriptor(b_2d.dims(), 0, true);
auto mat_dim_dout = math::CreateMatrixDescriptor(d_out.dims(), 0, false);
T alpha = static_cast<T>(1.0);
blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0));
blas.MatMul(a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T(0));
}
void FFNGrad(
const framework::Tensor& d_out, const framework::Tensor& x,
const framework::Tensor& dropout1_mask,
const framework::Tensor& dropout2_mask,
const framework::Tensor& linear1_out, const framework::Tensor& ln1_out,
const framework::Tensor& dropout1_out,
const framework::Tensor& dropout2_out,
const framework::Tensor& linear1_weight,
const framework::Tensor* linear1_bias,
const framework::Tensor& linear2_weight,
const framework::Tensor* ln1_gamma, const framework::Tensor* ln1_beta,
const framework::Tensor& ln1_mean, const framework::Tensor& ln1_variance,
const framework::Tensor* ln2_gamma, const framework::Tensor* ln2_beta,
const framework::Tensor& ln2_mean, const framework::Tensor& ln2_variance,
framework::Tensor* d_x, framework::Tensor* d_linear1_weight,
framework::Tensor* d_linear1_bias, framework::Tensor* d_linear2_weight,
framework::Tensor* d_linear2_bias, framework::Tensor* d_ln1_gamma,
framework::Tensor* d_ln1_beta, framework::Tensor* d_ln2_gamma,
framework::Tensor* d_ln2_beta, const int bsz_seq, const int d_model,
const int dim_feedforward, const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2, const std::string& act_method,
const bool pre_layer_norm, const float epsilon1, const float epsilon2,
const platform::CUDADeviceContext& ctx) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
ctx, bsz_seq, dim_feedforward, dropout_param1);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx, bsz_seq, d_model, dropout_param2, epsilon2);
auto place = ctx.GetPlace();
using U = LayerNormParamType<T>;
const U* ln1_gamma_ptr =
ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
const U* ln2_gamma_ptr =
ln2_gamma == nullptr ? nullptr : ln2_gamma->data<U>();
const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data<U>();
const T* linear1_bias_ptr =
linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
T* d_linear1_bias_ptr =
d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data<T>();
T* d_linear2_bias_ptr =
d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data<T>();
U* d_ln1_gamma_ptr =
d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data<U>();
U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data<U>();
U* d_ln2_gamma_ptr =
d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data<U>();
U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();
framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_residual.mutable_data<T>({bsz_seq, d_model}, place);
if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(),
d_linear2_out.data<T>(), d_residual.data<T>(), d_linear2_bias_ptr);
} else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_out.data<T>(),
dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean.data<U>(),
ln2_variance.data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
d_residual.data<T>());
}
framework::Tensor d_dropout1_out;
d_dropout1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place);
MatMulGrad(ctx, d_linear2_out, dropout1_out, linear2_weight,
&d_dropout1_out, d_linear2_weight);
framework::Tensor d_linear1_out;
d_linear1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place);
fused_act_dropout_helper.DropoutActBiasGrad(
ctx, d_dropout1_out.data<T>(), linear1_out.data<T>(), linear1_bias_ptr,
dropout1_mask.data<uint8_t>(), d_linear1_out.data<T>(),
d_linear1_bias_ptr, act_method);
if (pre_layer_norm) {
framework::Tensor d_ln1_out;
d_ln1_out.mutable_data<T>({bsz_seq, d_model}, place);
MatMulGrad(ctx, d_linear1_out, ln1_out, linear1_weight, &d_ln1_out,
d_linear1_weight);
pre_layernorm_helper.LayerNormGrad(ctx, d_ln1_out.data<T>(), x.data<T>(),
ln1_gamma_ptr, ln1_mean.data<U>(),
ln1_variance.data<U>(), d_x->data<T>(),
d_ln1_gamma_ptr, d_ln1_beta_ptr);
} else {
MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
}
}
void Compute(const framework::ExecutionContext& context) const override {
using U = LayerNormParamType<T>;
auto d_out =
*context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto x = *context.Input<framework::Tensor>("X");
auto dropout1_mask = *context.Input<framework::Tensor>("Dropout1Mask");
auto dropout2_mask = *context.Input<framework::Tensor>("Dropout2Mask");
auto linear1_out = *context.Input<framework::Tensor>("Linear1Out");
auto ln1_out = *context.Input<framework::Tensor>("Ln1Out");
auto dropout1_out = *context.Input<framework::Tensor>("Dropout1Out");
auto dropout2_out = *context.Input<framework::Tensor>("Dropout2Out");
auto linear1_weight = *context.Input<framework::Tensor>("Linear1Weight");
auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
auto linear2_weight = *context.Input<framework::Tensor>("Linear2Weight");
auto ln1_mean = *context.Input<framework::Tensor>("Ln1Mean");
auto ln1_variance = *context.Input<framework::Tensor>("Ln1Variance");
auto* ln1_scale = context.Input<framework::Tensor>("Ln1Scale");
auto* ln1_bias = context.Input<framework::Tensor>("Ln1Bias");
auto ln2_mean = *context.Input<framework::Tensor>("Ln2Mean");
auto ln2_variance = *context.Input<framework::Tensor>("Ln2Variance");
auto* ln2_scale = context.Input<framework::Tensor>("Ln2Scale");
auto* ln2_bias = context.Input<framework::Tensor>("Ln2Bias");
auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* d_ln1_scale =
context.Output<framework::Tensor>(framework::GradVarName("Ln1Scale"));
auto* d_ln1_bias =
context.Output<framework::Tensor>(framework::GradVarName("Ln1Bias"));
auto* d_ln2_scale =
context.Output<framework::Tensor>(framework::GradVarName("Ln2Scale"));
auto* d_ln2_bias =
context.Output<framework::Tensor>(framework::GradVarName("Ln2Bias"));
auto* d_linear1_weight = context.Output<framework::Tensor>(
framework::GradVarName("Linear1Weight"));
auto* d_linear1_bias = context.Output<framework::Tensor>(
framework::GradVarName("Linear1Bias"));
auto* d_linear2_weight = context.Output<framework::Tensor>(
framework::GradVarName("Linear2Weight"));
auto* d_linear2_bias = context.Output<framework::Tensor>(
framework::GradVarName("Linear2Bias"));
const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
const std::string act_method = context.Attr<std::string>("act_method");
DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2);
auto place = context.GetPlace();
d_x->mutable_data<T>(place);
if (d_ln1_scale) {
d_ln1_scale->mutable_data<U>(place);
}
if (d_ln1_bias) {
d_ln1_bias->mutable_data<U>(place);
}
if (d_ln2_scale) {
d_ln2_scale->mutable_data<U>(place);
}
if (d_ln2_bias) {
d_ln2_bias->mutable_data<U>(place);
}
if (d_linear1_bias) {
d_linear1_bias->mutable_data<T>(place);
}
if (d_linear2_bias) {
d_linear2_bias->mutable_data<T>(place);
}
d_linear1_weight->mutable_data<T>(place);
d_linear2_weight->mutable_data<T>(place);
auto x_dim = x.dims();
auto mat_dim_x =
math::CreateMatrixDescriptor(RowMatrixFromVector(x_dim), 0, false);
auto linear1_weight_dim = linear1_weight.dims();
int d_model = linear1_weight_dim[0];
int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
FFNGrad(d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out,
dropout1_out, dropout2_out, linear1_weight, linear1_bias,
linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance,
ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight,
d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale,
d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model,
dim_feedforward, dropout_param1, dropout_param2, act_method,
pre_layer_norm, epsilon1, epsilon2, context.cuda_device_context());
}
};
} // namespace operators
} // namespace paddle
......@@ -181,3 +385,10 @@ REGISTER_OP_CUDA_KERNEL(
ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext, double>,
ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
fused_feedforward_grad,
ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
......@@ -30,10 +30,10 @@ class TestFusedFFNOp(OpTest):
self.layer_norm_dtype = "float32"
def getShape(self):
self.batch_size = np.random.randint(1, 64)
self.query_length = np.random.randint(32, 256)
self.d_model = np.random.randint(32, 1024)
self.dim_feedforward = np.random.randint(32, 1024)
self.batch_size = np.random.randint(1, 32)
self.query_length = np.random.randint(32, 128)
self.d_model = np.random.randint(32, 512)
self.dim_feedforward = np.random.randint(32, 512)
def getDiff(self):
self.rtol = 1e-3
......@@ -48,6 +48,8 @@ class TestFusedFFNOp(OpTest):
def setUp(self):
paddle.disable_static()
self.__class__.op_type = "fused_feedforward"
#check grad in test_out_and_grad()
self.__class__.no_need_check_grad = True
self.getDtype()
self.getShape()
self.getDiff()
......@@ -82,6 +84,8 @@ class TestFusedFFNOp(OpTest):
self.src = np.random.random((self.batch_size, self.query_length,
self.d_model)).astype(self.dtype)
self.dout = np.random.random((self.batch_size, self.query_length,
self.d_model)).astype(self.dtype)
def Base(self):
paddle.disable_static()
......@@ -92,12 +96,17 @@ class TestFusedFFNOp(OpTest):
linear2_out = self.linear2(
self.dropout(self.activation(self.linear1(ln1_out))))
dropout2_out = residual + self.dropout2(linear2_out)
paddle.autograd.backward([dropout2_out],
[paddle.to_tensor(self.dout)], True)
return dropout2_out, tensor_src.grad
else:
linear2_out = self.linear2(
self.dropout(self.activation(self.linear1(tensor_src))))
dropout2_out = residual + self.dropout2(linear2_out)
dropout2_out = self.norm2(dropout2_out)
return dropout2_out
paddle.autograd.backward([dropout2_out],
[paddle.to_tensor(self.dout)], True)
return dropout2_out, tensor_src.grad
def FusedFFN(self):
paddle.disable_static()
......@@ -126,13 +135,19 @@ class TestFusedFFNOp(OpTest):
0.0,
activation=self.act_method,
pre_layer_norm=self.pre_layer_norm)
return out
paddle.autograd.backward([out], [paddle.to_tensor(self.dout)])
return out, x.grad
def test_fused_ffn(self):
base_out = self.Base()
fused_out = self.FusedFFN()
def test_out_and_grad(self):
base_out, base_grad = self.Base()
fused_out, fused_grad = self.FusedFFN()
np.testing.assert_allclose(
base_out.numpy(), fused_out.numpy(), rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(
base_grad.numpy(),
fused_grad.numpy(),
rtol=self.rtol,
atol=self.atol)
class TestFusedFFNOpFp16(TestFusedFFNOp):
......@@ -145,10 +160,10 @@ class TestFusedFFNOpFp16(TestFusedFFNOp):
self.atol = 1e-2
def getShape(self):
self.batch_size = 8
self.query_length = 128
self.d_model = 512
self.dim_feedforward = 512
self.batch_size = 4
self.query_length = 32
self.d_model = 128
self.dim_feedforward = 256
class TestFusedFFNOpFp64(TestFusedFFNOp):
......@@ -263,7 +278,7 @@ class APITestStaticFusedFFN(unittest.TestCase):
real_res.append(fetch)
self.assertTrue(
np.allclose(
real_res[0], real_res[1], atol=1e-5),
real_res[0], real_res[1], atol=1e-3),
"two value is check diff")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册