Unverified commit 19e866f9, authored by Yiqun Liu, committed by GitHub

Support optional residual add in fused_attention and fused_feedforward. (#43474)

* Support optional residual add in fused_attention and fused_feedforward.

* Add an op-version checkpoint and check that add_residual is true when pre_layer_norm is false.

* Add a TODO and change the Python API to accept an add_residual argument.
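
A minimal usage sketch of the new argument (assuming a GPU build of Paddle where the incubate fused functional APIs are available; the tensor shapes and weights below are illustrative only, not taken from this commit):

```python
import paddle
import paddle.incubate.nn.functional as incubate_f

# Illustrative shapes: [batch, seq_len, d_model] input and a 128 -> 512 -> 128 FFN.
x = paddle.randn([2, 16, 128], dtype='float32')
linear1_weight = paddle.randn([128, 512], dtype='float32')
linear2_weight = paddle.randn([512, 128], dtype='float32')

# add_residual defaults to True (the previous behavior). Skipping the residual
# add is only supported together with pre_layer_norm=True; the kernels added in
# this commit enforce add_residual == true when pre_layer_norm is false.
out = incubate_f.fused_feedforward(x,
                                   linear1_weight,
                                   linear2_weight,
                                   pre_layer_norm=True,
                                   add_residual=False)
print(out.shape)  # [2, 16, 128]
```

fused_multi_head_attention gains the same add_residual keyword with the same default and the same pre_layer_norm restriction.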
Parent 2540b023
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -378,6 +379,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -378,6 +379,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"0.0 and 0.001, But received [%s].", "0.0 and 0.001, But received [%s].",
ln_epsilon)); ln_epsilon));
}); });
AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
AddAttr<int>( AddAttr<int>(
"ring_id", "ring_id",
"ring id for tensor model parallel. distributed training and inference") "ring id for tensor model parallel. distributed training and inference")
...@@ -655,3 +657,11 @@ REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp, ...@@ -655,3 +657,11 @@ REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>, ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>); ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp); REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
REGISTER_OP_VERSION(fused_attention)
.AddCheckpoint(
R"ROC(
Add a new attribute [add_residual] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"add_residual", "A flag to indicate whether to add residual.",
true));
...@@ -246,26 +246,32 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> { ...@@ -246,26 +246,32 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// tensor model parallel // tensor model parallel
AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context()); AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());
bool add_residual = ctx.Attr<bool>("add_residual");
const T *residual_ptr = add_residual ? x_data : nullptr;
if (pre_layer_norm) { if (pre_layer_norm) {
// output = (residual + dropout(input + bias)) // output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias( fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data, ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
out_linear_bias_data, final_out_data, dropout_mask_out_data); out_linear_bias_data, final_out_data, dropout_mask_out_data);
} else { } else {
auto *ln_scale_2_data = // TODO(Xreki): support post layer_norm case when add_residual is false.
(ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>()); PADDLE_ENFORCE_EQ(add_residual, true,
auto *ln_bias_2_data = platform::errors::InvalidArgument(
(ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>()); "Attribute add_residual is expected to be true "
auto *bias_dropout_residual_out_data = "when pre_layer_norm is false."));
const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
T *bias_dropout_residual_out_ptr =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace()); bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace()); U *ln_mean_2_ptr = ln_mean_2->mutable_data<U>(ctx.GetPlace());
auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace()); U *ln_var_2_ptr = ln_var_2->mutable_data<U>(ctx.GetPlace());
// output = layernorm(residual + dropout(input + bias)) // output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias( fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data, ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, out_linear_bias_data, ln_scale_2_ptr, ln_bias_2_ptr,
bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, bias_dropout_residual_out_ptr, dropout_mask_out_data, final_out_data,
ln_mean_2_data, ln_var_2_data); ln_mean_2_ptr, ln_var_2_ptr);
} }
} }
}; };
...@@ -419,16 +425,17 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -419,16 +425,17 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
int output_size = 3 * hidden_size; int output_size = 3 * hidden_size;
int input_size = dim_embed; int input_size = dim_embed;
bool add_residual = ctx.Attr<bool>("add_residual");
Tensor d_residual; Tensor d_residual;
T *d_residual_data = nullptr;
if (add_residual) {
d_residual.Resize(input_x_dims); d_residual.Resize(input_x_dims);
T *d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace()); d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
}
bool transA = false; bool transA = false;
bool transB = true; bool transB = true;
bool compute_qkv_bias = true; bool compute_qkv_bias = qkv_bias ? true : false;
if (qkv_bias == nullptr) {
compute_qkv_bias = false;
}
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(), auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed); epsilon, bsz_seq, dim_embed);
auto qkv_compute = auto qkv_compute =
...@@ -539,17 +546,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -539,17 +546,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
// tensor model parallel // tensor model parallel
AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context()); AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
} }
if (add_residual) {
// gradient accumulation // gradient accumulation
std::vector<const Tensor *> ins; std::vector<const Tensor *> ins = {&d_residual, d_x};
std::vector<Tensor *> outs; std::vector<Tensor *> outs = {d_x};
ins.emplace_back(&d_residual); phi::funcs::ElementwiseKernel<T>(ctx.cuda_device_context(), ins, &outs,
ins.emplace_back(d_x);
outs.emplace_back(d_x);
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
ctx.cuda_device_context(), ins, &outs, elewise_add_axis,
phi::funcs::AddFunctor<T>()); phi::funcs::AddFunctor<T>());
} }
}
}; };
} // namespace operators } // namespace operators
......
...@@ -150,10 +150,11 @@ class FusedDropoutHelper { ...@@ -150,10 +150,11 @@ class FusedDropoutHelper {
LaunchResidualDropoutBiasGrad<T, uint8_t>( LaunchResidualDropoutBiasGrad<T, uint8_t>(
d_out, mask, dropout_param_.dropout_prob, d_out, mask, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
auto cuda_place = ctx.GetPlace(); if (d_residual) {
memory::Copy(cuda_place, d_residual, cuda_place, d_out, memory::Copy(ctx.GetPlace(), d_residual, ctx.GetPlace(), d_out,
rows_ * cols_ * sizeof(T), ctx.stream()); rows_ * cols_ * sizeof(T), ctx.stream());
} }
}
// out = dropout(activation(src + bias)) // out = dropout(activation(src + bias))
void DropoutActBias(const platform::CUDADeviceContext& ctx, const T* src, void DropoutActBias(const platform::CUDADeviceContext& ctx, const T* src,
......
...@@ -194,6 +194,7 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -194,6 +194,7 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(false); .SetDefault(false);
AddAttr<int>("dropout1_seed", "Dropout1 random seed.").SetDefault(0); AddAttr<int>("dropout1_seed", "Dropout1 random seed.").SetDefault(0);
AddAttr<int>("dropout2_seed", "Dropout2 random seed.").SetDefault(0); AddAttr<int>("dropout2_seed", "Dropout2 random seed.").SetDefault(0);
AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
AddAttr<int>("ring_id", "ring id for tensor model parallel.") AddAttr<int>("ring_id", "ring id for tensor model parallel.")
.SetDefault(-1); .SetDefault(-1);
AddComment(R"DOC( AddComment(R"DOC(
...@@ -367,3 +368,11 @@ REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp, ...@@ -367,3 +368,11 @@ REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>, ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>); ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad); REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);
REGISTER_OP_VERSION(fused_feedforward)
.AddCheckpoint(
R"ROC(
Add a new attribute [add_residual] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"add_residual", "A flag to indicate whether to add residual.",
true));
...@@ -69,7 +69,8 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -69,7 +69,8 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0)); blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
} }
void FFN(const framework::Tensor& x, const framework::Tensor& linear1_weight, void FFN(const platform::CUDADeviceContext& ctx, const framework::Tensor& x,
const framework::Tensor& linear1_weight,
const framework::Tensor* linear1_bias, const framework::Tensor* linear1_bias,
const framework::Tensor& linear2_weight, const framework::Tensor& linear2_weight,
const framework::Tensor* linear2_bias, const framework::Tensor* linear2_bias,
...@@ -84,10 +85,9 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -84,10 +85,9 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
framework::Tensor* dropout1_out, framework::Tensor* dropout2_out, framework::Tensor* dropout1_out, framework::Tensor* dropout2_out,
const int bsz_seq, const int d_model, const int dim_feedforward, const int bsz_seq, const int d_model, const int dim_feedforward,
const std::string& act_method, const bool pre_layer_norm, const std::string& act_method, const bool pre_layer_norm,
const float epsilon1, const float epsilon2, const int ring_id, const float epsilon1, const float epsilon2, const bool add_residual,
const DropoutParam& dropout_param1, const int ring_id, const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2, const DropoutParam& dropout_param2) const {
const platform::CUDADeviceContext& ctx) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper( FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1); bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper( FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
...@@ -127,15 +127,22 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -127,15 +127,22 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
// tensor model parallel // tensor model parallel
AllReduce<T>(linear2_out, ring_id, ctx); AllReduce<T>(linear2_out, ring_id, ctx);
const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
if (!pre_layer_norm) { if (!pre_layer_norm) {
// TODO(Xreki): support post layer_norm case when add_residual is false.
PADDLE_ENFORCE_EQ(add_residual, true,
platform::errors::InvalidArgument(
"Attribute add_residual is expected to be true "
"when pre_layer_norm is false."));
fused_dropout_layernorm_helper.LayernormResidualDropoutBias( fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr, ctx, linear2_out.data<T>(), residual_ptr, linear2_bias_ptr,
ln2_scale_ptr, ln2_bias_ptr, dropout2_out->data<T>(), ln2_scale_ptr, ln2_bias_ptr, dropout2_out->data<T>(),
dropout2_mask->data<uint8_t>(), out->data<T>(), ln2_mean->data<U>(), dropout2_mask->data<uint8_t>(), out->data<T>(), ln2_mean->data<U>(),
ln2_variance->data<U>()); ln2_variance->data<U>());
} else { } else {
fused_dropout_layernorm_helper.ResidualDropoutBias( fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr, ctx, linear2_out.data<T>(), residual_ptr, linear2_bias_ptr,
out->data<T>(), dropout2_mask->data<uint8_t>()); out->data<T>(), dropout2_mask->data<uint8_t>());
} }
} }
...@@ -183,6 +190,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -183,6 +190,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
const float epsilon1 = context.Attr<float>("ln1_epsilon"); const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon"); const float epsilon2 = context.Attr<float>("ln2_epsilon");
const int ring_id = context.Attr<int>("ring_id"); const int ring_id = context.Attr<int>("ring_id");
const bool add_residual = context.Attr<bool>("add_residual");
DropoutParam dropout_param1(context, 1); DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2); DropoutParam dropout_param2(context, 2);
...@@ -214,12 +222,12 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -214,12 +222,12 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
int dim_feedforward = dim[dim.size() - 1]; int dim_feedforward = dim[dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
FFN(*x, *linear1_weight, linear1_bias, *linear2_weight, linear2_bias, FFN(context.cuda_device_context(), *x, *linear1_weight, linear1_bias,
ln1_scale, ln1_bias, ln2_scale, ln2_bias, out, dropout1_mask, *linear2_weight, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias,
dropout2_mask, ln1_mean, ln1_variance, ln2_mean, ln2_variance, out, dropout1_mask, dropout2_mask, ln1_mean, ln1_variance, ln2_mean,
linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq, d_model, ln2_variance, linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq,
dim_feedforward, act_method, pre_layer_norm, epsilon1, epsilon2, d_model, dim_feedforward, act_method, pre_layer_norm, epsilon1,
ring_id, dropout_param1, dropout_param2, context.cuda_device_context()); epsilon2, add_residual, ring_id, dropout_param1, dropout_param2);
} }
}; };
...@@ -243,8 +251,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -243,8 +251,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
} }
void FFNGrad( void FFNGrad(
const framework::Tensor& d_out, const framework::Tensor& x, const platform::CUDADeviceContext& ctx, const framework::Tensor& d_out,
const framework::Tensor& dropout1_mask, const framework::Tensor& x, const framework::Tensor& dropout1_mask,
const framework::Tensor& dropout2_mask, const framework::Tensor& dropout2_mask,
const framework::Tensor& linear1_out, const framework::Tensor* ln1_out, const framework::Tensor& linear1_out, const framework::Tensor* ln1_out,
const framework::Tensor& dropout1_out, const framework::Tensor& dropout1_out,
...@@ -264,7 +272,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -264,7 +272,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const int dim_feedforward, const DropoutParam& dropout_param1, const int dim_feedforward, const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2, const std::string& act_method, const DropoutParam& dropout_param2, const std::string& act_method,
const bool pre_layer_norm, const float epsilon1, const float epsilon2, const bool pre_layer_norm, const float epsilon1, const float epsilon2,
const int ring_id, const platform::CUDADeviceContext& ctx) const { const bool add_residual, const int ring_id) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper( FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1); bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper( FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
...@@ -296,19 +304,22 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -296,19 +304,22 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
framework::Tensor d_linear2_out, d_dropout2_out, d_residual; framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place); d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place); d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_residual.mutable_data<T>(d_x->dims(), place);
T* d_residual_ptr = nullptr;
if (add_residual) {
d_residual_ptr = d_residual.mutable_data<T>(d_x->dims(), place);
}
if (pre_layer_norm) { if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad( fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(), ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(),
d_linear2_out.data<T>(), d_residual.data<T>(), d_linear2_bias_ptr); d_linear2_out.data<T>(), d_residual_ptr, d_linear2_bias_ptr);
} else { } else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_out.data<T>(), ctx, d_out.data<T>(), dropout2_out.data<T>(),
dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean->data<U>(), dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean->data<U>(),
ln2_variance->data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr, ln2_variance->data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr, d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
d_residual.data<T>()); d_residual_ptr);
} }
framework::Tensor d_dropout1_out; framework::Tensor d_dropout1_out;
...@@ -339,14 +350,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -339,14 +350,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
// tensor model parallel // tensor model parallel
AllReduce<T>(*d_x, ring_id, ctx); AllReduce<T>(*d_x, ring_id, ctx);
} }
std::vector<const Tensor*> ins(2);
std::vector<Tensor*> outs(1); if (add_residual) {
ins[0] = &d_residual; // gradient accumulation
ins[1] = d_x; std::vector<const Tensor*> ins = {&d_residual, d_x};
outs[0] = d_x; std::vector<Tensor*> outs = {d_x};
int elewise_add_axis = -1; phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs,
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>( phi::funcs::AddFunctor<T>());
ctx, ins, &outs, elewise_add_axis, phi::funcs::AddFunctor<T>()); }
} }
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -410,6 +421,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -410,6 +421,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const float epsilon1 = context.Attr<float>("ln1_epsilon"); const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon"); const float epsilon2 = context.Attr<float>("ln2_epsilon");
const bool add_residual = context.Attr<bool>("add_residual");
const int ring_id = context.Attr<int>("ring_id"); const int ring_id = context.Attr<int>("ring_id");
const std::string act_method = context.Attr<std::string>("act_method"); const std::string act_method = context.Attr<std::string>("act_method");
DropoutParam dropout_param1(context, 1); DropoutParam dropout_param1(context, 1);
...@@ -447,15 +459,15 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -447,15 +459,15 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1]; int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
FFNGrad(d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out, FFNGrad(context.cuda_device_context(), d_out, x, dropout1_mask,
dropout1_out, dropout2_out, linear1_weight, linear1_bias, dropout2_mask, linear1_out, ln1_out, dropout1_out, dropout2_out,
linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance, linear1_weight, linear1_bias, linear2_weight, ln1_scale, ln1_bias,
ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight, ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance,
d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale, d_x, d_linear1_weight, d_linear1_bias, d_linear2_weight,
d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model, d_linear2_bias, d_ln1_scale, d_ln1_bias, d_ln2_scale, d_ln2_bias,
dim_feedforward, dropout_param1, dropout_param2, act_method, bsz_seq, d_model, dim_feedforward, dropout_param1, dropout_param2,
pre_layer_norm, epsilon1, epsilon2, ring_id, act_method, pre_layer_norm, epsilon1, epsilon2, add_residual,
context.cuda_device_context()); ring_id);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -140,9 +140,12 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, ...@@ -140,9 +140,12 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
// dropout_prob == 1.0f // dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) { if (std::abs(dropout_prob - 1.0f) < 1e-5) {
if (residual == dst) return; if (residual == dst) return;
auto cuda_place = ctx.GetPlace(); if (residual) {
memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T), memory::Copy(ctx.GetPlace(), dst, ctx.GetPlace(), residual,
ctx.stream()); rows * cols * sizeof(T), ctx.stream());
} else {
SetZero<T>(ctx, dst, rows * cols);
}
if (!is_test) { if (!is_test) {
SetZero<MaskType>(ctx, mask_data, rows * cols); SetZero<MaskType>(ctx, mask_data, rows * cols);
} }
......
...@@ -29,8 +29,10 @@ PD_DECLARE_KERNEL(dropout_grad, GPU, ALL_LAYOUT); ...@@ -29,8 +29,10 @@ PD_DECLARE_KERNEL(dropout_grad, GPU, ALL_LAYOUT);
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
bool CheckEqual(float value, float ref) { return std::abs(value - ref) < 1e-5; }
/** /**
* @brief the unittest of fusedresidualdropoutbias * @brief the unittest of FusedResidualDropoutBias
* 1. random input data * 1. random input data
* 2. add bias, call paddle dropout op, add residual, and get the base result * 2. add bias, call paddle dropout op, add residual, and get the base result
* 3. call FusedResidualDropoutBias function get fused result * 3. call FusedResidualDropoutBias function get fused result
...@@ -38,7 +40,7 @@ namespace platform = paddle::platform; ...@@ -38,7 +40,7 @@ namespace platform = paddle::platform;
*/ */
template <typename T> template <typename T>
struct TestFusedResidualDropoutBias { struct FusedResidualDropoutBiasTester {
uint32_t rows; uint32_t rows;
uint32_t cols; uint32_t cols;
uint64_t seed; uint64_t seed;
...@@ -46,6 +48,8 @@ struct TestFusedResidualDropoutBias { ...@@ -46,6 +48,8 @@ struct TestFusedResidualDropoutBias {
bool is_upscale_in_train; bool is_upscale_in_train;
bool is_test; // default false, Set to true for inference only bool is_test; // default false, Set to true for inference only
bool has_bias = true; bool has_bias = true;
bool add_residual = true;
framework::Tensor src, residual, bias, out, mask; framework::Tensor src, residual, bias, out, mask;
framework::Tensor dsrc, dbias; framework::Tensor dsrc, dbias;
...@@ -56,37 +60,33 @@ struct TestFusedResidualDropoutBias { ...@@ -56,37 +60,33 @@ struct TestFusedResidualDropoutBias {
platform::CUDAPlace place; platform::CUDAPlace place;
platform::CUDADeviceContext *ctx; platform::CUDADeviceContext *ctx;
TestFusedResidualDropoutBias() { FusedResidualDropoutBiasTester() {
rows = 32; rows = 32;
cols = 32; cols = 32;
seed = 0; seed = 0;
dropout_prob = 0.0; dropout_prob = 0.0;
is_upscale_in_train = false; is_upscale_in_train = false;
is_test = false; is_test = false;
has_bias = true;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto device_ctx = pool.Get(place); auto device_ctx = pool.Get(place);
ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx); ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx);
} }
TestFusedResidualDropoutBias(int rows_, int cols_, uint64_t seed_ = 0, FusedResidualDropoutBiasTester(int rows, int cols, uint64_t seed = 0,
float dropout_prob_ = 0.0, float dropout_prob = 0.0,
bool is_upscale_in_train_ = false, bool is_upscale_in_train = false,
bool is_test_ = false) { bool is_test = false)
rows = rows_; : rows(rows),
cols = cols_; cols(cols),
seed = seed_; seed(seed),
dropout_prob = dropout_prob_; dropout_prob(dropout_prob),
is_upscale_in_train = is_upscale_in_train_; is_upscale_in_train(is_upscale_in_train),
is_test = is_test_; is_test(is_test) {
has_bias = true;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto device_ctx = pool.Get(place); auto device_ctx = pool.Get(place);
ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx); ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx);
} }
~TestFusedResidualDropoutBias() {}
void SetUp() { void SetUp() {
const int n = rows * cols; const int n = rows * cols;
correct_out.resize(n); correct_out.resize(n);
...@@ -95,7 +95,9 @@ struct TestFusedResidualDropoutBias { ...@@ -95,7 +95,9 @@ struct TestFusedResidualDropoutBias {
correct_dbias.resize(cols); correct_dbias.resize(cols);
src_vec.resize(n); src_vec.resize(n);
if (add_residual) {
residual_vec.resize(n); residual_vec.resize(n);
}
bias_vec.resize(cols); bias_vec.resize(cols);
std::default_random_engine random(time(NULL)); std::default_random_engine random(time(NULL));
std::uniform_real_distribution<float> dis(0.0, 1.0); std::uniform_real_distribution<float> dis(0.0, 1.0);
...@@ -103,7 +105,9 @@ struct TestFusedResidualDropoutBias { ...@@ -103,7 +105,9 @@ struct TestFusedResidualDropoutBias {
for (int i = 0; i < rows; i++) { for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) { for (int j = 0; j < cols; j++) {
src_vec[i * cols + j] = static_cast<T>(dis(random)); src_vec[i * cols + j] = static_cast<T>(dis(random));
if (add_residual) {
residual_vec[i * cols + j] = static_cast<T>(dis(random)); residual_vec[i * cols + j] = static_cast<T>(dis(random));
}
if (i == 0) { if (i == 0) {
bias_vec[j] = dis(random); bias_vec[j] = dis(random);
} }
...@@ -112,14 +116,15 @@ struct TestFusedResidualDropoutBias { ...@@ -112,14 +116,15 @@ struct TestFusedResidualDropoutBias {
framework::TensorFromVector<T>(src_vec, *ctx, &src); framework::TensorFromVector<T>(src_vec, *ctx, &src);
src.Resize({rows, cols}); src.Resize({rows, cols});
if (add_residual) {
framework::TensorFromVector<T>(residual_vec, *ctx, &residual); framework::TensorFromVector<T>(residual_vec, *ctx, &residual);
residual.Resize({rows, cols}); residual.Resize({rows, cols});
}
if (has_bias) { if (has_bias) {
framework::TensorFromVector<T>(bias_vec, *ctx, &bias); framework::TensorFromVector<T>(bias_vec, *ctx, &bias);
bias.Resize({cols}); bias.Resize({cols});
} }
{
out.mutable_data<T>({rows, cols}, place); out.mutable_data<T>({rows, cols}, place);
mask.mutable_data<uint8_t>({rows, cols}, place); mask.mutable_data<uint8_t>({rows, cols}, place);
dsrc.mutable_data<T>({rows, cols}, place); dsrc.mutable_data<T>({rows, cols}, place);
...@@ -128,31 +133,32 @@ struct TestFusedResidualDropoutBias { ...@@ -128,31 +133,32 @@ struct TestFusedResidualDropoutBias {
dbias.mutable_data<T>({cols}, place); dbias.mutable_data<T>({cols}, place);
} }
} }
}
void BaseForward() { void BaseForward() {
std::vector<T> out1(rows * cols), out2(rows * cols);
if (has_bias) { if (has_bias) {
// add bias // add bias
std::vector<T> bias_out(rows * cols);
for (int i = 0; i < rows; i++) { for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) { for (int j = 0; j < cols; j++) {
out1[i * cols + j] = src_vec[i * cols + j] + bias_vec[j]; bias_out[i * cols + j] = src_vec[i * cols + j] + bias_vec[j];
} }
} }
// call dropout // call dropout
Dropout<T>(out1, src.dims(), &out2, &correct_mask, *ctx, seed, Dropout<T>(bias_out, src.dims(), &correct_out, &correct_mask, *ctx, seed,
dropout_prob, is_upscale_in_train, is_test); dropout_prob, is_upscale_in_train, is_test);
} else { } else {
Dropout<T>(src_vec, src.dims(), &out2, &correct_mask, *ctx, seed, Dropout<T>(src_vec, src.dims(), &correct_out, &correct_mask, *ctx, seed,
dropout_prob, is_upscale_in_train, is_test); dropout_prob, is_upscale_in_train, is_test);
} }
ctx->Wait(); ctx->Wait();
PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError()); PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
if (add_residual) {
// add residual // add residual
for (int i = 0; i < rows; i++) { for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) { for (int j = 0; j < cols; j++) {
correct_out[i * cols + j] = int idx = i * cols + j;
residual_vec[i * cols + j] + out2[i * cols + j]; correct_out[idx] = residual_vec[idx] + correct_out[idx];
}
} }
} }
} }
...@@ -178,13 +184,11 @@ struct TestFusedResidualDropoutBias { ...@@ -178,13 +184,11 @@ struct TestFusedResidualDropoutBias {
1) * 1) *
VecSize; VecSize;
T *bias_ptr = nullptr; T *bias_ptr = has_bias ? bias.data<T>() : nullptr;
if (has_bias) { T *residual_ptr = add_residual ? residual.data<T>() : nullptr;
bias_ptr = bias.data<T>();
}
paddle::operators::LaunchResidualDropoutBias<T, uint8_t>( paddle::operators::LaunchResidualDropoutBias<T, uint8_t>(
rows, cols, increment, seed, dropout_prob, is_test, is_upscale_in_train, rows, cols, increment, seed, dropout_prob, is_test, is_upscale_in_train,
src.data<T>(), residual.data<T>(), bias_ptr, mask.data<uint8_t>(), src.data<T>(), residual_ptr, bias_ptr, mask.data<uint8_t>(),
out.data<T>(), *ctx); out.data<T>(), *ctx);
ctx->Wait(); ctx->Wait();
PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError()); PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
...@@ -195,10 +199,7 @@ struct TestFusedResidualDropoutBias { ...@@ -195,10 +199,7 @@ struct TestFusedResidualDropoutBias {
return; return;
} }
T *bias_ptr = nullptr; T *bias_ptr = has_bias ? dbias.data<T>() : nullptr;
if (has_bias) {
bias_ptr = dbias.data<T>();
}
paddle::operators::LaunchResidualDropoutBiasGrad<T, uint8_t>( paddle::operators::LaunchResidualDropoutBiasGrad<T, uint8_t>(
out.data<T>(), mask.data<uint8_t>(), dropout_prob, is_upscale_in_train, out.data<T>(), mask.data<uint8_t>(), dropout_prob, is_upscale_in_train,
rows, cols, dsrc.data<T>(), bias_ptr, *ctx); rows, cols, dsrc.data<T>(), bias_ptr, *ctx);
...@@ -214,17 +215,19 @@ struct TestFusedResidualDropoutBias { ...@@ -214,17 +215,19 @@ struct TestFusedResidualDropoutBias {
void CheckOut(const T diff) { void CheckOut(const T diff) {
const int n = rows * cols; const int n = rows * cols;
std::vector<T> _out(n); std::vector<T> fused_out(n);
std::vector<uint8_t> _mask(n); std::vector<uint8_t> fused_mask(n);
framework::TensorToVector(out, *ctx, &_out); framework::TensorToVector(out, *ctx, &fused_out);
if (!is_test) { if (!is_test) {
framework::TensorToVector<uint8_t>(mask, *ctx, &_mask); framework::TensorToVector<uint8_t>(mask, *ctx, &fused_mask);
} }
ctx->Wait(); ctx->Wait();
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff); EXPECT_LT(std::abs(fused_out[i] - correct_out[i]), diff);
if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]); if (!is_test) {
EXPECT_EQ(fused_mask[i], correct_mask[i]);
}
} }
} }
...@@ -255,16 +258,21 @@ struct TestFusedResidualDropoutBias { ...@@ -255,16 +258,21 @@ struct TestFusedResidualDropoutBias {
// test the shape and bias // test the shape and bias
template <typename T> template <typename T>
static void BaseTest(const bool is_fp16 = false) { static void BaseTest() {
const int rows = 16; const int rows = 16;
T default_diff = !is_fp16 ? static_cast<T>(1e-5) : static_cast<T>(1e-1); T max_diff = static_cast<T>(0);
if (std::is_same<T, paddle::platform::float16>::value) {
max_diff = static_cast<T>(1e-1);
} else {
max_diff = static_cast<T>(1e-5);
}
for (auto cols : {16, 17}) { for (auto cols : {16, 17}) {
for (auto has_bias : {true, false}) { for (auto has_bias : {true, false}) {
TestFusedResidualDropoutBias<T> test(rows, cols); FusedResidualDropoutBiasTester<T> test(rows, cols);
test.has_bias = has_bias; test.has_bias = has_bias;
test.Run(); test.Run();
test.CheckOut(default_diff); test.CheckOut(max_diff);
test.CheckGrad(default_diff); test.CheckGrad(max_diff);
} }
} }
} }
...@@ -274,14 +282,14 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest<float>(); } ...@@ -274,14 +282,14 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest<float>(); }
TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest<double>(); } TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest<double>(); }
TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) { TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) {
BaseTest<platform::float16>(true); BaseTest<platform::float16>();
} }
TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) { TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) {
const int rows = 16; const int rows = 16;
const int cols = 16; const int cols = 16;
for (auto is_upscale_in_train : {true, false}) { for (auto is_upscale_in_train : {true, false}) {
TestFusedResidualDropoutBias<float> test(rows, cols, 0, 1.0, FusedResidualDropoutBiasTester<float> test(rows, cols, 0, 1.0,
is_upscale_in_train, false); is_upscale_in_train, false);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
...@@ -292,7 +300,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) { ...@@ -292,7 +300,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) {
TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) { TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) {
const int rows = 16; const int rows = 16;
const int cols = 16; const int cols = 16;
TestFusedResidualDropoutBias<float> test(rows, cols, 0, 0.35, true, true); FusedResidualDropoutBiasTester<float> test(rows, cols, 0, 0.35, true, true);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5)); test.CheckGrad(static_cast<float>(1e-5));
...@@ -301,16 +309,32 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) { ...@@ -301,16 +309,32 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) {
TEST(FusedDropout, GPUFusedResidualDropoutBiasSeed) { TEST(FusedDropout, GPUFusedResidualDropoutBiasSeed) {
const int rows = 16; const int rows = 16;
const int cols = 16; const int cols = 16;
TestFusedResidualDropoutBias<float> test(rows, cols, 125, 0.0, false, false); FusedResidualDropoutBiasTester<float> test(rows, cols, 125, 0.0, false,
false);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5));
}
TEST(FusedDropout, NoResidual) {
const int rows = 16;
const int cols = 16;
for (float p : {0.0f, 0.5f, 1.0f}) {
FusedResidualDropoutBiasTester<float> test(rows, cols, 0, p, false, false);
test.add_residual = false;
test.Run(); test.Run();
// For a dropout_prob that is not 0 or 1, just test that it runs successfully.
if (CheckEqual(p, 0.0f) || CheckEqual(p, 1.0f)) {
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5)); test.CheckGrad(static_cast<float>(1e-5));
}
}
} }
TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShape) { TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShape) {
const int rows = 256; const int rows = 256;
const int cols = 4096; const int cols = 4096;
TestFusedResidualDropoutBias<float> test(rows, cols); FusedResidualDropoutBiasTester<float> test(rows, cols);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-3)); test.CheckGrad(static_cast<float>(1e-3));
...@@ -326,8 +350,8 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShapeFp16) { ...@@ -326,8 +350,8 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShapeFp16) {
if (std::getenv("_cols") != nullptr) { if (std::getenv("_cols") != nullptr) {
cols = atoi(std::getenv("_cols")); cols = atoi(std::getenv("_cols"));
} }
TestFusedResidualDropoutBias<platform::float16> test(rows, cols, 0, 0.0, true, FusedResidualDropoutBiasTester<platform::float16> test(rows, cols, 0, 0.0,
true); true, true);
test.Run(); test.Run();
test.CheckOut(static_cast<platform::float16>(1e-1)); test.CheckOut(static_cast<platform::float16>(1e-1));
test.CheckGrad(static_cast<platform::float16>(1e-1)); test.CheckGrad(static_cast<platform::float16>(1e-1));
......
...@@ -46,6 +46,7 @@ def fused_feedforward(x, ...@@ -46,6 +46,7 @@ def fused_feedforward(x,
training=True, training=True,
mode='upscale_in_train', mode='upscale_in_train',
ring_id=-1, ring_id=-1,
add_residual=True,
name=None): name=None):
r""" r"""
This is a fusion operator to compute feed forward layer in transformer model architecture. This is a fusion operator to compute feed forward layer in transformer model architecture.
...@@ -90,6 +91,7 @@ def fused_feedforward(x, ...@@ -90,6 +91,7 @@ def fused_feedforward(x,
- train: out = input * mask - train: out = input * mask
- inference: out = input * (1.0 - p) - inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using tensor parallel. ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using tensor parallel.
add_residual (bool, optional): Whether to add the residual at the end. Default is True.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
...@@ -134,7 +136,8 @@ def fused_feedforward(x, ...@@ -134,7 +136,8 @@ def fused_feedforward(x,
"dropout2_fix_seed", seed is not None, "dropout1_seed", "dropout2_fix_seed", seed is not None, "dropout1_seed",
seed if seed is not None else 0, "dropout2_seed", seed if seed is not None else 0, "dropout2_seed",
seed if seed is not None else 0, 'dropout1_implementation', mode, seed if seed is not None else 0, 'dropout1_implementation', mode,
'dropout2_implementation', mode, 'ring_id', ring_id) 'dropout2_implementation', mode, 'add_residual', add_residual,
'ring_id', ring_id)
return out return out
helper = LayerHelper("fused_feedforward") helper = LayerHelper("fused_feedforward")
...@@ -208,6 +211,7 @@ def fused_feedforward(x, ...@@ -208,6 +211,7 @@ def fused_feedforward(x,
'dropout2_seed': seed if seed is not None else 0, 'dropout2_seed': seed if seed is not None else 0,
'dropout1_implementation': mode, 'dropout1_implementation': mode,
'dropout2_implementation': mode, 'dropout2_implementation': mode,
'add_residual': add_residual,
'ring_id': ring_id, 'ring_id': ring_id,
}) })
return out return out
...@@ -378,6 +382,7 @@ def fused_multi_head_attention(x, ...@@ -378,6 +382,7 @@ def fused_multi_head_attention(x,
training=True, training=True,
mode='upscale_in_train', mode='upscale_in_train',
ring_id=-1, ring_id=-1,
add_residual=True,
name=None): name=None):
r""" r"""
Attention mapps queries and a set of key-value pairs to outputs, and Attention mapps queries and a set of key-value pairs to outputs, and
...@@ -454,6 +459,7 @@ def fused_multi_head_attention(x, ...@@ -454,6 +459,7 @@ def fused_multi_head_attention(x,
- train: out = input * mask - train: out = input * mask
- inference: out = input * (1.0 - p) - inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp
add_residual (bool, optional): Whether to add the residual at the end. Default is True.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
...@@ -521,7 +527,8 @@ def fused_multi_head_attention(x, ...@@ -521,7 +527,8 @@ def fused_multi_head_attention(x,
'dropout_fix_seed', seed is not None, 'attn_dropout_seed', 'dropout_fix_seed', seed is not None, 'attn_dropout_seed',
seed if seed is not None else 0, 'dropout_seed', seed if seed is not None else 0, 'dropout_seed',
seed if seed is not None else 0, 'attn_dropout_implementation', seed if seed is not None else 0, 'attn_dropout_implementation',
mode, 'dropout_implementation', mode, 'ring_id', ring_id) mode, 'dropout_implementation', mode, 'add_residual', add_residual,
'ring_id', ring_id)
if cache_kv is not None: if cache_kv is not None:
return final_out, cache_kv_out return final_out, cache_kv_out
return final_out return final_out
...@@ -571,6 +578,7 @@ def fused_multi_head_attention(x, ...@@ -571,6 +578,7 @@ def fused_multi_head_attention(x,
'dropout_seed': seed if seed is not None else 0, 'dropout_seed': seed if seed is not None else 0,
'attn_dropout_implementation': mode, 'attn_dropout_implementation': mode,
'dropout_implementation': mode, 'dropout_implementation': mode,
'add_residual': add_residual,
'ring_id': ring_id 'ring_id': ring_id
} }
......