Unverified commit 0660d5f2, authored by Zhang Ting, committed by GitHub

[cherry pick] Support optional residual add in fused ops and slice large tensor for cudnn_softmax (#43719)

cherry-pick #43635 #43681 #43474
Parent 8e6a1945
@@ -15,6 +15,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 namespace paddle {
 namespace operators {
@@ -372,18 +373,21 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
                         "0.0 and 0.001, But received [%s].",
                         ln_epsilon));
        });
+  AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
   AddAttr<int>(
       "ring_id",
       "ring id for tensor model parallel. distributed training and inference")
       .SetDefault(-1);
   AddComment(R"DOC(
-  Add fused attention op whose logic is as follows:
-  // @input: [batch_size, seq_len, 3, num_head, head_dim]
+  The fused_attention operator is the same as following pseudo codes:
+  // @input: [batch_size, seq_len, embed_dim]
   // @final_out: [batch_size, seq_len, num_heads, head_dim]
+  residual = input
   if (pre_layernorm)
-    out = layer_norm(input);
-  out = compute_qkv(out) + bias;
+    query = layer_norm(input);
+  out = compute_qkv(query) + qkv_bias;
   // fmha module
   {
     out = transpose(out, perm=[2, 0, 3, 1, 4]);
@@ -395,11 +399,14 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
     out = transpose(out, perm=[0, 2, 1, 3]);
   }
-  out = out_linear(out);
-  if (pre_layernorm)
-    final_out = residual + dropout(bias + out);
-  else
-    final_out = layer_norm(residual + dropout(bias + out));
+  // out linear
+  out = linear(out);
+  if add_residual:
+    out = residual + dropout(out);
+  else:
+    out = dropout(out);
+  if (!pre_layernorm)
+    out = layer_norm(out);
   )DOC");
  }
};
@@ -649,3 +656,11 @@ REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
                   ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                   ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
+REGISTER_OP_VERSION(fused_attention)
+    .AddCheckpoint(
+        R"ROC(
+              Add a new attribute [add_residual] )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "add_residual", "A flag to indicate whether to add residual.",
+            true));
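The reworked DOC pseudocode above is the contract that the new add_residual attribute changes: the residual connection and the post layer_norm now sit behind explicit branches. As a quick reference, here is a minimal NumPy sketch of that epilogue, assuming an inverted-dropout mask and a plain layer_norm; the helper names are illustrative, not the operator's internals.

import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def attention_epilogue(linear_out, residual, dropout_mask,
                       add_residual, pre_layernorm):
    # dropout(out): the mask is assumed to already carry the 1/(1-p) scaling
    out = linear_out * dropout_mask
    if add_residual:
        out = residual + out          # out = residual + dropout(out)
    if not pre_layernorm:
        out = layer_norm(out)         # post layer_norm branch
    return out

Note that the CUDA kernel below still requires add_residual to be true whenever pre_layer_norm is false.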
...@@ -245,26 +245,32 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> { ...@@ -245,26 +245,32 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// tensor model parallel // tensor model parallel
AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context()); AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());
bool add_residual = ctx.Attr<bool>("add_residual");
const T *residual_ptr = add_residual ? x_data : nullptr;
if (pre_layer_norm) { if (pre_layer_norm) {
// output = (residual + dropout(input + bias)) // output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias( fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data, ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
out_linear_bias_data, final_out_data, dropout_mask_out_data); out_linear_bias_data, final_out_data, dropout_mask_out_data);
} else { } else {
auto *ln_scale_2_data = // TODO(Xreki): support post layer_norm case when add_residual is false.
(ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>()); PADDLE_ENFORCE_EQ(add_residual, true,
auto *ln_bias_2_data = platform::errors::InvalidArgument(
(ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>()); "Attribute add_residual is expected to be true "
auto *bias_dropout_residual_out_data = "when pre_layer_norm is false."));
const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
T *bias_dropout_residual_out_ptr =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace()); bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace()); U *ln_mean_2_ptr = ln_mean_2->mutable_data<U>(ctx.GetPlace());
auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace()); U *ln_var_2_ptr = ln_var_2->mutable_data<U>(ctx.GetPlace());
// output = layernorm(residual + dropout(input + bias)) // output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias( fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data, ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, out_linear_bias_data, ln_scale_2_ptr, ln_bias_2_ptr,
bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, bias_dropout_residual_out_ptr, dropout_mask_out_data, final_out_data,
ln_mean_2_data, ln_var_2_data); ln_mean_2_ptr, ln_var_2_ptr);
} }
} }
}; };
...@@ -418,16 +424,17 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -418,16 +424,17 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
int output_size = 3 * hidden_size; int output_size = 3 * hidden_size;
int input_size = dim_embed; int input_size = dim_embed;
bool add_residual = ctx.Attr<bool>("add_residual");
Tensor d_residual; Tensor d_residual;
T *d_residual_data = nullptr;
if (add_residual) {
d_residual.Resize(input_x_dims); d_residual.Resize(input_x_dims);
T *d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace()); d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
}
bool transA = false; bool transA = false;
bool transB = true; bool transB = true;
bool compute_qkv_bias = true; bool compute_qkv_bias = qkv_bias ? true : false;
if (qkv_bias == nullptr) {
compute_qkv_bias = false;
}
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(), auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed); epsilon, bsz_seq, dim_embed);
auto qkv_compute = auto qkv_compute =
...@@ -536,17 +543,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -536,17 +543,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
// tensor model parallel // tensor model parallel
AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context()); AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
} }
if (add_residual) {
// gradient accumulation // gradient accumulation
std::vector<const Tensor *> ins; std::vector<const Tensor *> ins = {&d_residual, d_x};
std::vector<Tensor *> outs; std::vector<Tensor *> outs = {d_x};
ins.emplace_back(&d_residual); phi::funcs::ElementwiseKernel<T>(ctx.cuda_device_context(), ins, &outs,
ins.emplace_back(d_x); phi::funcs::AddFunctor<T>());
outs.emplace_back(d_x); }
int elewise_add_axis = -1;
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
T>(
ctx.cuda_device_context(), ins, &outs, elewise_add_axis,
AddFunctor<T>());
} }
}; };
......
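In the gradient kernel above, the new `if (add_residual)` block is the only added control flow: because the forward output is residual + dropout(out_linear(...)) with residual == x, x receives gradient from both the attention path and the skip connection, so d_residual is folded into d_x with an elementwise add. A one-line NumPy equivalent of that ElementwiseKernel/AddFunctor launch, for orientation only:

# d_x: gradient propagated back through the attention path
# d_residual: gradient carried by the skip connection
def accumulate_residual_grad(d_x, d_residual, add_residual):
    return d_x + d_residual if add_residual else d_x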
...@@ -150,10 +150,11 @@ class FusedDropoutHelper { ...@@ -150,10 +150,11 @@ class FusedDropoutHelper {
LaunchResidualDropoutBiasGrad<T, uint8_t>( LaunchResidualDropoutBiasGrad<T, uint8_t>(
d_out, mask, dropout_param_.dropout_prob, d_out, mask, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
auto cuda_place = ctx.GetPlace(); if (d_residual) {
memory::Copy(cuda_place, d_residual, cuda_place, d_out, memory::Copy(ctx.GetPlace(), d_residual, ctx.GetPlace(), d_out,
rows_ * cols_ * sizeof(T), ctx.stream()); rows_ * cols_ * sizeof(T), ctx.stream());
} }
}
// out = dropout(activation(src + bias)) // out = dropout(activation(src + bias))
void DropoutActBias(const platform::CUDADeviceContext& ctx, const T* src, void DropoutActBias(const platform::CUDADeviceContext& ctx, const T* src,
......
@@ -193,19 +193,28 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
       .SetDefault(false);
   AddAttr<int>("dropout1_seed", "Dropout1 random seed.").SetDefault(0);
   AddAttr<int>("dropout2_seed", "Dropout2 random seed.").SetDefault(0);
+  AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
   AddAttr<int>("ring_id", "ring id for tensor model parallel.")
       .SetDefault(-1);
   AddComment(R"DOC(
-  the function of fused_feedforward operator is the same as the following pseudo code:
+  The fused_feedforward operator is the same as the following pseudo codes:
   residual = src;
-  ln1_out = src;
-  if(pre_layer_norm){
+  if (pre_layer_norm)
     ln1_out = layer_norm(src);
-  }
-  out = linear(dropout(activation(dropout(linear(ln1_out)))));
-  if(!pre_layer_norm) {
+  else
+    ln1_out = src;
+  // linear 1
+  out = linear(ln1_out);
+  out = dropout(activation(out));
+  // linear 2
+  out = linear(out);
+  if (add_residual)
+    out = residual + dropout(out);
+  else
+    out = dropout(out);
+  if (!pre_layer_norm)
     out = layer_norm(out);
-  }
   )DOC");
  }
};
@@ -366,3 +375,11 @@ REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
                   ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
                   ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);
+REGISTER_OP_VERSION(fused_feedforward)
+    .AddCheckpoint(
+        R"ROC(
+              Add a new attribute [add_residual] )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "add_residual", "A flag to indicate whether to add residual.",
+            true));
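For readers who prefer a runnable reference over the DOC pseudocode, the sketch below mirrors the fused_feedforward forward pass in NumPy, assuming a ReLU activation and precomputed inverted-dropout masks; it illustrates the documented semantics, not the fused kernel implementation.

import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def fused_feedforward_reference(x, w1, b1, w2, b2, mask1, mask2,
                                pre_layer_norm=True, add_residual=True):
    residual = x
    ln1_out = layer_norm(x) if pre_layer_norm else x
    out = np.maximum(ln1_out @ w1 + b1, 0.0)  # linear 1 + relu activation
    out = out * mask1                         # dropout 1
    out = out @ w2 + b2                       # linear 2
    out = out * mask2                         # dropout 2
    if add_residual:
        out = residual + out
    if not pre_layer_norm:
        out = layer_norm(out)
    return out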
...@@ -69,7 +69,8 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -69,7 +69,8 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0)); blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
} }
void FFN(const framework::Tensor& x, const framework::Tensor& linear1_weight, void FFN(const platform::CUDADeviceContext& ctx, const framework::Tensor& x,
const framework::Tensor& linear1_weight,
const framework::Tensor* linear1_bias, const framework::Tensor* linear1_bias,
const framework::Tensor& linear2_weight, const framework::Tensor& linear2_weight,
const framework::Tensor* linear2_bias, const framework::Tensor* linear2_bias,
...@@ -84,10 +85,9 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -84,10 +85,9 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
framework::Tensor* dropout1_out, framework::Tensor* dropout2_out, framework::Tensor* dropout1_out, framework::Tensor* dropout2_out,
const int bsz_seq, const int d_model, const int dim_feedforward, const int bsz_seq, const int d_model, const int dim_feedforward,
const std::string& act_method, const bool pre_layer_norm, const std::string& act_method, const bool pre_layer_norm,
const float epsilon1, const float epsilon2, const int ring_id, const float epsilon1, const float epsilon2, const bool add_residual,
const DropoutParam& dropout_param1, const int ring_id, const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2, const DropoutParam& dropout_param2) const {
const platform::CUDADeviceContext& ctx) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper( FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1); bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper( FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
...@@ -127,15 +127,22 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -127,15 +127,22 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
// tensor model parallel // tensor model parallel
AllReduce<T>(linear2_out, ring_id, ctx); AllReduce<T>(linear2_out, ring_id, ctx);
const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
if (!pre_layer_norm) { if (!pre_layer_norm) {
// TODO(Xreki): support post layer_norm case when add_residual is false.
PADDLE_ENFORCE_EQ(add_residual, true,
platform::errors::InvalidArgument(
"Attribute add_residual is expected to be true "
"when pre_layer_norm is false."));
fused_dropout_layernorm_helper.LayernormResidualDropoutBias( fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr, ctx, linear2_out.data<T>(), residual_ptr, linear2_bias_ptr,
ln2_scale_ptr, ln2_bias_ptr, dropout2_out->data<T>(), ln2_scale_ptr, ln2_bias_ptr, dropout2_out->data<T>(),
dropout2_mask->data<uint8_t>(), out->data<T>(), ln2_mean->data<U>(), dropout2_mask->data<uint8_t>(), out->data<T>(), ln2_mean->data<U>(),
ln2_variance->data<U>()); ln2_variance->data<U>());
} else { } else {
fused_dropout_layernorm_helper.ResidualDropoutBias( fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr, ctx, linear2_out.data<T>(), residual_ptr, linear2_bias_ptr,
out->data<T>(), dropout2_mask->data<uint8_t>()); out->data<T>(), dropout2_mask->data<uint8_t>());
} }
} }
...@@ -183,6 +190,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -183,6 +190,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
const float epsilon1 = context.Attr<float>("ln1_epsilon"); const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon"); const float epsilon2 = context.Attr<float>("ln2_epsilon");
const int ring_id = context.Attr<int>("ring_id"); const int ring_id = context.Attr<int>("ring_id");
const bool add_residual = context.Attr<bool>("add_residual");
DropoutParam dropout_param1(context, 1); DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2); DropoutParam dropout_param2(context, 2);
...@@ -214,12 +222,12 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -214,12 +222,12 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
int dim_feedforward = dim[dim.size() - 1]; int dim_feedforward = dim[dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
FFN(*x, *linear1_weight, linear1_bias, *linear2_weight, linear2_bias, FFN(context.cuda_device_context(), *x, *linear1_weight, linear1_bias,
ln1_scale, ln1_bias, ln2_scale, ln2_bias, out, dropout1_mask, *linear2_weight, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias,
dropout2_mask, ln1_mean, ln1_variance, ln2_mean, ln2_variance, out, dropout1_mask, dropout2_mask, ln1_mean, ln1_variance, ln2_mean,
linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq, d_model, ln2_variance, linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq,
dim_feedforward, act_method, pre_layer_norm, epsilon1, epsilon2, d_model, dim_feedforward, act_method, pre_layer_norm, epsilon1,
ring_id, dropout_param1, dropout_param2, context.cuda_device_context()); epsilon2, add_residual, ring_id, dropout_param1, dropout_param2);
} }
}; };
...@@ -243,8 +251,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -243,8 +251,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
} }
void FFNGrad( void FFNGrad(
const framework::Tensor& d_out, const framework::Tensor& x, const platform::CUDADeviceContext& ctx, const framework::Tensor& d_out,
const framework::Tensor& dropout1_mask, const framework::Tensor& x, const framework::Tensor& dropout1_mask,
const framework::Tensor& dropout2_mask, const framework::Tensor& dropout2_mask,
const framework::Tensor& linear1_out, const framework::Tensor* ln1_out, const framework::Tensor& linear1_out, const framework::Tensor* ln1_out,
const framework::Tensor& dropout1_out, const framework::Tensor& dropout1_out,
...@@ -264,7 +272,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -264,7 +272,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const int dim_feedforward, const DropoutParam& dropout_param1, const int dim_feedforward, const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2, const std::string& act_method, const DropoutParam& dropout_param2, const std::string& act_method,
const bool pre_layer_norm, const float epsilon1, const float epsilon2, const bool pre_layer_norm, const float epsilon1, const float epsilon2,
const int ring_id, const platform::CUDADeviceContext& ctx) const { const bool add_residual, const int ring_id) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper( FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1); bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper( FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
...@@ -296,19 +304,22 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -296,19 +304,22 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
framework::Tensor d_linear2_out, d_dropout2_out, d_residual; framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place); d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place); d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_residual.mutable_data<T>(d_x->dims(), place);
T* d_residual_ptr = nullptr;
if (add_residual) {
d_residual_ptr = d_residual.mutable_data<T>(d_x->dims(), place);
}
if (pre_layer_norm) { if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad( fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(), ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(),
d_linear2_out.data<T>(), d_residual.data<T>(), d_linear2_bias_ptr); d_linear2_out.data<T>(), d_residual_ptr, d_linear2_bias_ptr);
} else { } else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_out.data<T>(), ctx, d_out.data<T>(), dropout2_out.data<T>(),
dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean->data<U>(), dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean->data<U>(),
ln2_variance->data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr, ln2_variance->data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr, d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
d_residual.data<T>()); d_residual_ptr);
} }
framework::Tensor d_dropout1_out; framework::Tensor d_dropout1_out;
...@@ -339,15 +350,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -339,15 +350,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
// tensor model parallel // tensor model parallel
AllReduce<T>(*d_x, ring_id, ctx); AllReduce<T>(*d_x, ring_id, ctx);
} }
std::vector<const Tensor*> ins(2);
std::vector<Tensor*> outs(1); if (add_residual) {
ins[0] = &d_residual; // gradient accumulation
ins[1] = d_x; std::vector<const Tensor*> ins = {&d_residual, d_x};
outs[0] = d_x; std::vector<Tensor*> outs = {d_x};
int elewise_add_axis = -1; phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs,
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, phi::funcs::AddFunctor<T>());
T>( }
ctx, ins, &outs, elewise_add_axis, AddFunctor<T>());
} }
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -412,6 +422,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -412,6 +422,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const float epsilon1 = context.Attr<float>("ln1_epsilon"); const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon"); const float epsilon2 = context.Attr<float>("ln2_epsilon");
const bool add_residual = context.Attr<bool>("add_residual");
const int ring_id = context.Attr<int>("ring_id"); const int ring_id = context.Attr<int>("ring_id");
const std::string act_method = context.Attr<std::string>("act_method"); const std::string act_method = context.Attr<std::string>("act_method");
DropoutParam dropout_param1(context, 1); DropoutParam dropout_param1(context, 1);
...@@ -449,15 +460,15 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -449,15 +460,15 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1]; int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
FFNGrad(d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out, FFNGrad(context.cuda_device_context(), d_out, x, dropout1_mask,
dropout1_out, dropout2_out, linear1_weight, linear1_bias, dropout2_mask, linear1_out, ln1_out, dropout1_out, dropout2_out,
linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance, linear1_weight, linear1_bias, linear2_weight, ln1_scale, ln1_bias,
ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight, ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance,
d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale, d_x, d_linear1_weight, d_linear1_bias, d_linear2_weight,
d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model, d_linear2_bias, d_ln1_scale, d_ln1_bias, d_ln2_scale, d_ln2_bias,
dim_feedforward, dropout_param1, dropout_param2, act_method, bsz_seq, d_model, dim_feedforward, dropout_param1, dropout_param2,
pre_layer_norm, epsilon1, epsilon2, ring_id, act_method, pre_layer_norm, epsilon1, epsilon2, add_residual,
context.cuda_device_context()); ring_id);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -140,9 +140,12 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, ...@@ -140,9 +140,12 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
// dropout_prob == 1.0f // dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) { if (std::abs(dropout_prob - 1.0f) < 1e-5) {
if (residual == dst) return; if (residual == dst) return;
auto cuda_place = ctx.GetPlace(); if (residual) {
memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T), memory::Copy(ctx.GetPlace(), dst, ctx.GetPlace(), residual,
ctx.stream()); rows * cols * sizeof(T), ctx.stream());
} else {
SetZero<T>(ctx, dst, rows * cols);
}
if (!is_test) { if (!is_test) {
SetZero<MaskType>(ctx, mask_data, rows * cols); SetZero<MaskType>(ctx, mask_data, rows * cols);
} }
......
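The change above also covers the dropout_prob == 1.0 fast path: with every element dropped, the fused kernel only needs to copy the residual into dst, or zero dst when no residual pointer is given, and zero the mask in training mode. A compact NumPy sketch of that branch:

import numpy as np

def residual_dropout_bias_p1(residual, rows, cols, is_test):
    # dropout_prob == 1.0: dropout(src + bias) is all zeros
    dst = residual.copy() if residual is not None else np.zeros((rows, cols))
    mask = None if is_test else np.zeros((rows, cols), dtype=np.uint8)
    return dst, mask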
...@@ -29,8 +29,10 @@ PD_DECLARE_KERNEL(dropout_grad, GPU, ALL_LAYOUT); ...@@ -29,8 +29,10 @@ PD_DECLARE_KERNEL(dropout_grad, GPU, ALL_LAYOUT);
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
bool CheckEqual(float value, float ref) { return std::abs(value - ref) < 1e-5; }
/** /**
* @brief the unittest of fusedresidualdropoutbias * @brief the unittest of FusedResidualDropoutBias
* 1. random input data * 1. random input data
* 2. add bias, call paddle dropout op, add residual, and get the base result * 2. add bias, call paddle dropout op, add residual, and get the base result
* 3. call FusedResidualDropoutBias function get fused result * 3. call FusedResidualDropoutBias function get fused result
...@@ -38,7 +40,7 @@ namespace platform = paddle::platform; ...@@ -38,7 +40,7 @@ namespace platform = paddle::platform;
*/ */
template <typename T> template <typename T>
struct TestFusedResidualDropoutBias { struct FusedResidualDropoutBiasTester {
uint32_t rows; uint32_t rows;
uint32_t cols; uint32_t cols;
uint64_t seed; uint64_t seed;
...@@ -46,6 +48,8 @@ struct TestFusedResidualDropoutBias { ...@@ -46,6 +48,8 @@ struct TestFusedResidualDropoutBias {
bool is_upscale_in_train; bool is_upscale_in_train;
bool is_test; // default false, Set to true for inference only bool is_test; // default false, Set to true for inference only
bool has_bias = true; bool has_bias = true;
bool add_residual = true;
framework::Tensor src, residual, bias, out, mask; framework::Tensor src, residual, bias, out, mask;
framework::Tensor dsrc, dbias; framework::Tensor dsrc, dbias;
...@@ -56,37 +60,33 @@ struct TestFusedResidualDropoutBias { ...@@ -56,37 +60,33 @@ struct TestFusedResidualDropoutBias {
platform::CUDAPlace place; platform::CUDAPlace place;
platform::CUDADeviceContext *ctx; platform::CUDADeviceContext *ctx;
TestFusedResidualDropoutBias() { FusedResidualDropoutBiasTester() {
rows = 32; rows = 32;
cols = 32; cols = 32;
seed = 0; seed = 0;
dropout_prob = 0.0; dropout_prob = 0.0;
is_upscale_in_train = false; is_upscale_in_train = false;
is_test = false; is_test = false;
has_bias = true;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto device_ctx = pool.Get(place); auto device_ctx = pool.Get(place);
ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx); ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx);
} }
TestFusedResidualDropoutBias(int rows_, int cols_, uint64_t seed_ = 0, FusedResidualDropoutBiasTester(int rows, int cols, uint64_t seed = 0,
float dropout_prob_ = 0.0, float dropout_prob = 0.0,
bool is_upscale_in_train_ = false, bool is_upscale_in_train = false,
bool is_test_ = false) { bool is_test = false)
rows = rows_; : rows(rows),
cols = cols_; cols(cols),
seed = seed_; seed(seed),
dropout_prob = dropout_prob_; dropout_prob(dropout_prob),
is_upscale_in_train = is_upscale_in_train_; is_upscale_in_train(is_upscale_in_train),
is_test = is_test_; is_test(is_test) {
has_bias = true;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto device_ctx = pool.Get(place); auto device_ctx = pool.Get(place);
ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx); ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx);
} }
~TestFusedResidualDropoutBias() {}
void SetUp() { void SetUp() {
const int n = rows * cols; const int n = rows * cols;
correct_out.resize(n); correct_out.resize(n);
...@@ -95,7 +95,9 @@ struct TestFusedResidualDropoutBias { ...@@ -95,7 +95,9 @@ struct TestFusedResidualDropoutBias {
correct_dbias.resize(cols); correct_dbias.resize(cols);
src_vec.resize(n); src_vec.resize(n);
if (add_residual) {
residual_vec.resize(n); residual_vec.resize(n);
}
bias_vec.resize(cols); bias_vec.resize(cols);
std::default_random_engine random(time(NULL)); std::default_random_engine random(time(NULL));
std::uniform_real_distribution<float> dis(0.0, 1.0); std::uniform_real_distribution<float> dis(0.0, 1.0);
...@@ -103,7 +105,9 @@ struct TestFusedResidualDropoutBias { ...@@ -103,7 +105,9 @@ struct TestFusedResidualDropoutBias {
for (int i = 0; i < rows; i++) { for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) { for (int j = 0; j < cols; j++) {
src_vec[i * cols + j] = static_cast<T>(dis(random)); src_vec[i * cols + j] = static_cast<T>(dis(random));
if (add_residual) {
residual_vec[i * cols + j] = static_cast<T>(dis(random)); residual_vec[i * cols + j] = static_cast<T>(dis(random));
}
if (i == 0) { if (i == 0) {
bias_vec[j] = dis(random); bias_vec[j] = dis(random);
} }
...@@ -112,14 +116,15 @@ struct TestFusedResidualDropoutBias { ...@@ -112,14 +116,15 @@ struct TestFusedResidualDropoutBias {
framework::TensorFromVector<T>(src_vec, *ctx, &src); framework::TensorFromVector<T>(src_vec, *ctx, &src);
src.Resize({rows, cols}); src.Resize({rows, cols});
if (add_residual) {
framework::TensorFromVector<T>(residual_vec, *ctx, &residual); framework::TensorFromVector<T>(residual_vec, *ctx, &residual);
residual.Resize({rows, cols}); residual.Resize({rows, cols});
}
if (has_bias) { if (has_bias) {
framework::TensorFromVector<T>(bias_vec, *ctx, &bias); framework::TensorFromVector<T>(bias_vec, *ctx, &bias);
bias.Resize({cols}); bias.Resize({cols});
} }
{
out.mutable_data<T>({rows, cols}, place); out.mutable_data<T>({rows, cols}, place);
mask.mutable_data<uint8_t>({rows, cols}, place); mask.mutable_data<uint8_t>({rows, cols}, place);
dsrc.mutable_data<T>({rows, cols}, place); dsrc.mutable_data<T>({rows, cols}, place);
...@@ -128,31 +133,32 @@ struct TestFusedResidualDropoutBias { ...@@ -128,31 +133,32 @@ struct TestFusedResidualDropoutBias {
dbias.mutable_data<T>({cols}, place); dbias.mutable_data<T>({cols}, place);
} }
} }
}
void BaseForward() { void BaseForward() {
std::vector<T> out1(rows * cols), out2(rows * cols);
if (has_bias) { if (has_bias) {
// add bias // add bias
std::vector<T> bias_out(rows * cols);
for (int i = 0; i < rows; i++) { for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) { for (int j = 0; j < cols; j++) {
out1[i * cols + j] = src_vec[i * cols + j] + bias_vec[j]; bias_out[i * cols + j] = src_vec[i * cols + j] + bias_vec[j];
} }
} }
// call dropout // call dropout
Dropout<T>(out1, src.dims(), &out2, &correct_mask, *ctx, seed, Dropout<T>(bias_out, src.dims(), &correct_out, &correct_mask, *ctx, seed,
dropout_prob, is_upscale_in_train, is_test); dropout_prob, is_upscale_in_train, is_test);
} else { } else {
Dropout<T>(src_vec, src.dims(), &out2, &correct_mask, *ctx, seed, Dropout<T>(src_vec, src.dims(), &correct_out, &correct_mask, *ctx, seed,
dropout_prob, is_upscale_in_train, is_test); dropout_prob, is_upscale_in_train, is_test);
} }
ctx->Wait(); ctx->Wait();
PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError()); PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
if (add_residual) {
// add residual // add residual
for (int i = 0; i < rows; i++) { for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) { for (int j = 0; j < cols; j++) {
correct_out[i * cols + j] = int idx = i * cols + j;
residual_vec[i * cols + j] + out2[i * cols + j]; correct_out[idx] = residual_vec[idx] + correct_out[idx];
}
} }
} }
} }
...@@ -178,13 +184,11 @@ struct TestFusedResidualDropoutBias { ...@@ -178,13 +184,11 @@ struct TestFusedResidualDropoutBias {
1) * 1) *
VecSize; VecSize;
T *bias_ptr = nullptr; T *bias_ptr = has_bias ? bias.data<T>() : nullptr;
if (has_bias) { T *residual_ptr = add_residual ? residual.data<T>() : nullptr;
bias_ptr = bias.data<T>();
}
paddle::operators::LaunchResidualDropoutBias<T, uint8_t>( paddle::operators::LaunchResidualDropoutBias<T, uint8_t>(
rows, cols, increment, seed, dropout_prob, is_test, is_upscale_in_train, rows, cols, increment, seed, dropout_prob, is_test, is_upscale_in_train,
src.data<T>(), residual.data<T>(), bias_ptr, mask.data<uint8_t>(), src.data<T>(), residual_ptr, bias_ptr, mask.data<uint8_t>(),
out.data<T>(), *ctx); out.data<T>(), *ctx);
ctx->Wait(); ctx->Wait();
PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError()); PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
...@@ -195,10 +199,7 @@ struct TestFusedResidualDropoutBias { ...@@ -195,10 +199,7 @@ struct TestFusedResidualDropoutBias {
return; return;
} }
T *bias_ptr = nullptr; T *bias_ptr = has_bias ? dbias.data<T>() : nullptr;
if (has_bias) {
bias_ptr = dbias.data<T>();
}
paddle::operators::LaunchResidualDropoutBiasGrad<T, uint8_t>( paddle::operators::LaunchResidualDropoutBiasGrad<T, uint8_t>(
out.data<T>(), mask.data<uint8_t>(), dropout_prob, is_upscale_in_train, out.data<T>(), mask.data<uint8_t>(), dropout_prob, is_upscale_in_train,
rows, cols, dsrc.data<T>(), bias_ptr, *ctx); rows, cols, dsrc.data<T>(), bias_ptr, *ctx);
...@@ -214,17 +215,19 @@ struct TestFusedResidualDropoutBias { ...@@ -214,17 +215,19 @@ struct TestFusedResidualDropoutBias {
void CheckOut(const T diff) { void CheckOut(const T diff) {
const int n = rows * cols; const int n = rows * cols;
std::vector<T> _out(n); std::vector<T> fused_out(n);
std::vector<uint8_t> _mask(n); std::vector<uint8_t> fused_mask(n);
framework::TensorToVector(out, *ctx, &_out); framework::TensorToVector(out, *ctx, &fused_out);
if (!is_test) { if (!is_test) {
framework::TensorToVector<uint8_t>(mask, *ctx, &_mask); framework::TensorToVector<uint8_t>(mask, *ctx, &fused_mask);
} }
ctx->Wait(); ctx->Wait();
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff); EXPECT_LT(std::abs(fused_out[i] - correct_out[i]), diff);
if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]); if (!is_test) {
EXPECT_EQ(fused_mask[i], correct_mask[i]);
}
} }
} }
...@@ -255,16 +258,21 @@ struct TestFusedResidualDropoutBias { ...@@ -255,16 +258,21 @@ struct TestFusedResidualDropoutBias {
// test the shape and bias // test the shape and bias
template <typename T> template <typename T>
static void BaseTest(const bool is_fp16 = false) { static void BaseTest() {
const int rows = 16; const int rows = 16;
T default_diff = !is_fp16 ? static_cast<T>(1e-5) : static_cast<T>(1e-1); T max_diff = static_cast<T>(0);
if (std::is_same<T, paddle::platform::float16>::value) {
max_diff = static_cast<T>(1e-1);
} else {
max_diff = static_cast<T>(1e-5);
}
for (auto cols : {16, 17}) { for (auto cols : {16, 17}) {
for (auto has_bias : {true, false}) { for (auto has_bias : {true, false}) {
TestFusedResidualDropoutBias<T> test(rows, cols); FusedResidualDropoutBiasTester<T> test(rows, cols);
test.has_bias = has_bias; test.has_bias = has_bias;
test.Run(); test.Run();
test.CheckOut(default_diff); test.CheckOut(max_diff);
test.CheckGrad(default_diff); test.CheckGrad(max_diff);
} }
} }
} }
...@@ -274,14 +282,14 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest<float>(); } ...@@ -274,14 +282,14 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest<float>(); }
TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest<double>(); } TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest<double>(); }
TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) { TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) {
BaseTest<platform::float16>(true); BaseTest<platform::float16>();
} }
TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) { TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) {
const int rows = 16; const int rows = 16;
const int cols = 16; const int cols = 16;
for (auto is_upscale_in_train : {true, false}) { for (auto is_upscale_in_train : {true, false}) {
TestFusedResidualDropoutBias<float> test(rows, cols, 0, 1.0, FusedResidualDropoutBiasTester<float> test(rows, cols, 0, 1.0,
is_upscale_in_train, false); is_upscale_in_train, false);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
...@@ -292,7 +300,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) { ...@@ -292,7 +300,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) {
TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) { TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) {
const int rows = 16; const int rows = 16;
const int cols = 16; const int cols = 16;
TestFusedResidualDropoutBias<float> test(rows, cols, 0, 0.35, true, true); FusedResidualDropoutBiasTester<float> test(rows, cols, 0, 0.35, true, true);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5)); test.CheckGrad(static_cast<float>(1e-5));
...@@ -301,16 +309,32 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) { ...@@ -301,16 +309,32 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) {
TEST(FusedDropout, GPUFusedResidualDropoutBiasSeed) { TEST(FusedDropout, GPUFusedResidualDropoutBiasSeed) {
const int rows = 16; const int rows = 16;
const int cols = 16; const int cols = 16;
TestFusedResidualDropoutBias<float> test(rows, cols, 125, 0.0, false, false); FusedResidualDropoutBiasTester<float> test(rows, cols, 125, 0.0, false,
false);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5));
}
TEST(FusedDropout, NoResidual) {
const int rows = 16;
const int cols = 16;
for (float p : {0.0f, 0.5f, 1.0f}) {
FusedResidualDropoutBiasTester<float> test(rows, cols, 0, p, false, false);
test.add_residual = false;
test.Run(); test.Run();
// For a non 0 or 1 dropout_prob, just test whether it can run successfully.
if (CheckEqual(p, 0.0f) || CheckEqual(p, 1.0f)) {
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5)); test.CheckGrad(static_cast<float>(1e-5));
}
}
} }
TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShape) { TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShape) {
const int rows = 256; const int rows = 256;
const int cols = 4096; const int cols = 4096;
TestFusedResidualDropoutBias<float> test(rows, cols); FusedResidualDropoutBiasTester<float> test(rows, cols);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-3)); test.CheckGrad(static_cast<float>(1e-3));
...@@ -326,8 +350,8 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShapeFp16) { ...@@ -326,8 +350,8 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShapeFp16) {
if (std::getenv("_cols") != nullptr) { if (std::getenv("_cols") != nullptr) {
cols = atoi(std::getenv("_cols")); cols = atoi(std::getenv("_cols"));
} }
TestFusedResidualDropoutBias<platform::float16> test(rows, cols, 0, 0.0, true, FusedResidualDropoutBiasTester<platform::float16> test(rows, cols, 0, 0.0,
true); true, true);
test.Run(); test.Run();
test.CheckOut(static_cast<platform::float16>(1e-1)); test.CheckOut(static_cast<platform::float16>(1e-1));
test.CheckGrad(static_cast<platform::float16>(1e-1)); test.CheckGrad(static_cast<platform::float16>(1e-1));
......
...@@ -786,15 +786,12 @@ static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims, ...@@ -786,15 +786,12 @@ static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,
template <typename T> template <typename T>
void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& x, const T* x_data,
const int axis, const int axis,
const int rank,
const bool log_mode, const bool log_mode,
DenseTensor* out) { const std::vector<int>& tensor_dims,
auto* out_data = out->data<T>(); T* out_data) {
const int rank = x.dims().size();
std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
...@@ -809,7 +806,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -809,7 +806,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
handle, handle,
paddle::platform::CudnnDataType<T>::kOne(), paddle::platform::CudnnDataType<T>::kOne(),
desc, desc,
x.data<T>(), x_data,
paddle::platform::CudnnDataType<T>::kZero(), paddle::platform::CudnnDataType<T>::kZero(),
desc, desc,
out_data, out_data,
...@@ -826,7 +823,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -826,7 +823,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
mode, mode,
paddle::platform::CudnnDataType<T>::kOne(), paddle::platform::CudnnDataType<T>::kOne(),
desc, desc,
x.data<T>(), x_data,
paddle::platform::CudnnDataType<T>::kZero(), paddle::platform::CudnnDataType<T>::kZero(),
desc, desc,
out_data)); out_data));
...@@ -834,17 +831,39 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -834,17 +831,39 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
} }
template <typename T> template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& out, const DenseTensor& x,
const DenseTensor& dout,
const int axis, const int axis,
const bool log_mode, const bool log_mode,
DenseTensor* dx) { DenseTensor* out) {
auto* dx_data = dx->data<T>(); auto* out_data = out->data<T>();
auto* x_data = x.data<T>();
const int rank = x.dims().size();
int rank = out.dims().size(); std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis); int64_t remaining = tensor_dims[0];
int dim = tensor_dims[1];
int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxForwardCudnnKernel<T>(
dev_ctx, x_data, axis, rank, log_mode, tensor_dims, out_data);
x_data += offset;
out_data += offset;
remaining -= batch_size;
}
}
template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
const T* out_data,
const T* dout_data,
const int axis,
const int rank,
const bool log_mode,
const std::vector<int>& tensor_dims,
T* dx_data) {
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
...@@ -860,9 +879,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -860,9 +879,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
handle, handle,
paddle::platform::CudnnDataType<T>::kOne(), paddle::platform::CudnnDataType<T>::kOne(),
desc, desc,
out.data<T>(), out_data,
desc, desc,
dout.data<T>(), dout_data,
paddle::platform::CudnnDataType<T>::kZero(), paddle::platform::CudnnDataType<T>::kZero(),
desc, desc,
dx_data, dx_data,
...@@ -879,9 +898,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -879,9 +898,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
mode, mode,
paddle::platform::CudnnDataType<T>::kOne(), paddle::platform::CudnnDataType<T>::kOne(),
desc, desc,
out.data<T>(), out_data,
desc, desc,
dout.data<T>(), dout_data,
paddle::platform::CudnnDataType<T>::kZero(), paddle::platform::CudnnDataType<T>::kZero(),
desc, desc,
dx_data)); dx_data));
...@@ -889,21 +908,42 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -889,21 +908,42 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
} }
template <typename T> template <typename T>
static bool CanUseCudnnSoftmax(const GPUContext& dev_ctx) { void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
if (dev_ctx.cudnn_handle() != nullptr) { const DenseTensor& out,
if (std::is_same<T, phi::dtype::bfloat16>::value) { const DenseTensor& dout,
#if CUDNN_VERSION < 8100 const int axis,
return false; const bool log_mode,
#endif DenseTensor* dx) {
} auto* dx_data = dx->data<T>();
return true; auto* out_data = out.data<T>();
auto* dout_data = dout.data<T>();
int rank = out.dims().size();
std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
int64_t remaining = tensor_dims[0];
int dim = tensor_dims[1];
int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxBackwardCudnnKernel<T>(dev_ctx,
out_data,
dout_data,
axis,
rank,
log_mode,
tensor_dims,
dx_data);
out_data += offset;
dout_data += offset;
dx_data += offset;
remaining -= batch_size;
} }
return false;
} }
#if CUDNN_VERSION < 8100 #if CUDNN_VERSION < 8100
template <> template <>
inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>( inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx, const GPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const int axis, const int axis,
...@@ -914,7 +954,7 @@ inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>( ...@@ -914,7 +954,7 @@ inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
"8100.")); "8100."));
} }
template <> template <>
inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>( inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx, const GPUContext& dev_ctx,
const DenseTensor& out, const DenseTensor& out,
const DenseTensor& dout, const DenseTensor& dout,
...@@ -927,6 +967,25 @@ inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>( ...@@ -927,6 +967,25 @@ inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
} }
#endif #endif
template <typename T>
bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) {
bool cudnn_available = ctx.cudnn_handle();
if (!ctx.cudnn_handle()) {
if (std::is_same<T, phi::dtype::bfloat16>::value) {
#if CUDNN_VERSION < 8100
cudnn_available = false;
#endif
}
}
constexpr int max_dim = 512;
if (!cudnn_available || !last_dim ||
(softmax_dim <= max_dim && sizeof(T) <= 4)) {
return false;
} else {
return true;
}
}
template <typename T, bool LogMode = false> template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -941,10 +1000,8 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, ...@@ -941,10 +1000,8 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
int dim = tensor_dims[1]; int dim = tensor_dims[1];
int D = tensor_dims[2]; int D = tensor_dims[2];
constexpr int max_dim = 512; if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
if (D == 1 &&
(!CanUseCudnnSoftmax<T>(dev_ctx) || (dim <= max_dim && sizeof(T) <= 4))) {
int dim_log2 = static_cast<int>(Log2Ceil(dim)); int dim_log2 = static_cast<int>(Log2Ceil(dim));
int dim_ceil = 1 << dim_log2; int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32; int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
...@@ -993,11 +1050,12 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, ...@@ -993,11 +1050,12 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
dim, dim,
dim_log2); dim_log2);
} }
} else if (D > 1) { } else {
LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
}
} else {
LaunchNormalSoftmaxForward<T, LogMode>( LaunchNormalSoftmaxForward<T, LogMode>(
dev_ctx, out_data, x.data<T>(), N, dim, D); dev_ctx, out_data, x.data<T>(), N, dim, D);
} else {
SoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
} }
} }
...@@ -1016,10 +1074,8 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx, ...@@ -1016,10 +1074,8 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
int dim = tensor_dims[1]; int dim = tensor_dims[1];
int D = tensor_dims[2]; int D = tensor_dims[2];
constexpr int max_dim = 512; if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
if (D == 1 &&
(!CanUseCudnnSoftmax<T>(dev_ctx) || (dim <= max_dim && sizeof(T) <= 4))) {
int dim_log2 = Log2Ceil(dim); int dim_log2 = Log2Ceil(dim);
int dim_ceil = 1 << dim_log2; int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32; int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
...@@ -1069,11 +1125,13 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx, ...@@ -1069,11 +1125,13 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
dim, dim,
dim_log2); dim_log2);
} }
} else if (D > 1) { } else {
LaunchSoftmaxBackwardCudnnKernel<T>(
dev_ctx, out, dout, axis, LogMode, dx);
}
} else {
LaunchNormalSoftmaxBackward<T, LogMode>( LaunchNormalSoftmaxBackward<T, LogMode>(
dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D); dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
} else {
SoftmaxBackwardCudnnKernel<T>(dev_ctx, out, dout, axis, LogMode, dx);
} }
} }
......
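The Launch* wrappers above implement the "slice large tensor for cudnn_softmax" part of this commit: cuDNN tensor descriptors use 32-bit extents, so the (N, dim) problem is processed in chunks of at most INT32_MAX / dim rows, and the data pointers advance by batch_size * dim between calls. A NumPy sketch of the same loop, with cudnn_softmax_2d standing in for the per-chunk cuDNN call (the name is illustrative):

import numpy as np

INT32_MAX = 2**31 - 1

def sliced_softmax(x, cudnn_softmax_2d):
    n, dim = x.shape                   # softmax over the last dimension
    out = np.empty_like(x)
    batch_size = INT32_MAX // dim      # max rows per cuDNN call
    start = 0
    while start < n:
        stop = min(start + batch_size, n)
        out[start:stop] = cudnn_softmax_2d(x[start:stop])
        start = stop
    return out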
...@@ -46,6 +46,7 @@ def fused_feedforward(x, ...@@ -46,6 +46,7 @@ def fused_feedforward(x,
training=True, training=True,
mode='upscale_in_train', mode='upscale_in_train',
ring_id=-1, ring_id=-1,
add_residual=True,
name=None): name=None):
r""" r"""
This is a fusion operator to compute feed forward layer in transformer model architecture. This is a fusion operator to compute feed forward layer in transformer model architecture.
...@@ -54,12 +55,19 @@ def fused_feedforward(x, ...@@ -54,12 +55,19 @@ def fused_feedforward(x,
        .. code-block:: python

-           residual = src;
+           residual = x
            if pre_layer_norm:
-               src = layer_norm(src)
-           src = linear(dropout(activation(dropout(linear(src)))))
+               out = layer_norm1(x)
+           else:
+               out = x
+           out = linear2(dropout1(activation(linear1(src))))
+           if add_residual:
+               out = residual + dropout2(out)
+           else:
+               out = dropout2(out)
            if not pre_layer_norm:
-               src = layer_norm(out)
+               out = layer_norm2(out)
Args: Args:
x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16, float32 or float64, the shape is`[batch\_size, sequence\_length, d_model]`. x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16, float32 or float64, the shape is`[batch\_size, sequence\_length, d_model]`.
...@@ -90,6 +98,7 @@ def fused_feedforward(x, ...@@ -90,6 +98,7 @@ def fused_feedforward(x,
- train: out = input * mask - train: out = input * mask
- inference: out = input * (1.0 - p) - inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using tensor parallel. ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using tensor parallel.
add_residual (bool, optional): Whether add residual at the end. Default is True.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
...@@ -100,15 +109,13 @@ def fused_feedforward(x, ...@@ -100,15 +109,13 @@ def fused_feedforward(x,
            # required: gpu
            import paddle
-           import numpy as np
-           x_data = np.random.random((1, 8, 8)).astype("float32")
-           linear1_weight_data = np.random.random((8, 8)).astype("float32")
-           linear2_weight_data = np.random.random((8, 8)).astype("float32")
-           x = paddle.to_tensor(x_data)
-           linear1_weight = paddle.to_tensor(linear1_weight_data)
-           linear2_weight = paddle.to_tensor(linear2_weight_data)
-           out = paddle.incubate.nn.functional.fused_feedforward(x, linear1_weight, linear2_weight)
-           print(out.numpy().shape)
+           import paddle.incubate.nn.functional as F
+
+           x = paddle.randn(shape=(1, 8, 8), dtype="float32")
+           linear1_weight = paddle.randn(shape=(8, 8), dtype="float32")
+           linear2_weight = paddle.randn(shape=(8, 8), dtype="float32")
+           out = F.fused_feedforward(x, linear1_weight, linear2_weight)
+           print(out.shape)
            # (1, 8, 8)
""" """
_verify_dropout_rate(dropout1_rate) _verify_dropout_rate(dropout1_rate)
...@@ -133,7 +140,8 @@ def fused_feedforward(x, ...@@ -133,7 +140,8 @@ def fused_feedforward(x,
"dropout2_fix_seed", seed is not None, "dropout1_seed", seed "dropout2_fix_seed", seed is not None, "dropout1_seed", seed
if seed is not None else 0, "dropout2_seed", seed if seed is not None else 0, "dropout2_seed", seed
if seed is not None else 0, 'dropout1_implementation', mode, if seed is not None else 0, 'dropout1_implementation', mode,
'dropout2_implementation', mode, 'ring_id', ring_id) 'dropout2_implementation', mode, 'add_residual', add_residual,
'ring_id', ring_id)
return out return out
helper = LayerHelper("fused_feedforward") helper = LayerHelper("fused_feedforward")
...@@ -208,6 +216,7 @@ def fused_feedforward(x, ...@@ -208,6 +216,7 @@ def fused_feedforward(x,
'dropout2_seed': seed if seed is not None else 0, 'dropout2_seed': seed if seed is not None else 0,
'dropout1_implementation': mode, 'dropout1_implementation': mode,
'dropout2_implementation': mode, 'dropout2_implementation': mode,
'add_residual': add_residual,
'ring_id': ring_id, 'ring_id': ring_id,
}) })
return out return out
...@@ -232,6 +241,7 @@ def fused_multi_head_attention(x, ...@@ -232,6 +241,7 @@ def fused_multi_head_attention(x,
training=True, training=True,
mode='upscale_in_train', mode='upscale_in_train',
ring_id=-1, ring_id=-1,
add_residual=True,
name=None): name=None):
r""" r"""
Attention mapps queries and a set of key-value pairs to outputs, and Attention mapps queries and a set of key-value pairs to outputs, and
@@ -241,27 +251,34 @@ def fused_multi_head_attention(x,
        .. code-block:: python

+           residual = x
            if pre_layer_norm:
                out = layer_norm(x)
-               out = linear(out) + qkv) + bias
            else:
-               out = linear(x) + bias
+               out = x
+           # compute q, k, v
+           out = matmul(out, qkv_weight) + qkv_bias
            out = transpose(out, perm=[2, 0, 3, 1, 4])
-           # extract q, k and v from out.
-           q = out[0:1,::]
+           # extract q, k and v from out
+           q = out[0:1,::] * (head_dim ** -0.5)
            k = out[1:2,::]
            v = out[2:3,::]
-           out = q * k^t
-           out = attn_mask + out
+           out = matmul(q, k, transpose_y=True)
+           out = out + attn_mask
            out = softmax(out)
            out = dropout(out)
-           out = out * v
+           out = matmul(out, v)
+           # combine heads
            out = transpose(out, perm=[0, 2, 1, 3])
-           out = out_linear(out)
-           if pre_layer_norm:
-               out = x + dropout(linear_bias + out)
+           # project to output
+           out = linear(out)
+           if add_residual:
+               out = residual + dropout(out)
            else:
-               out = layer_norm(x + dropout(linear_bias + out))
+               out = dropout(out)
+           if not pre_layer_norm:
+               out = layer_norm(out)
Parameters: Parameters:
x (Tensor): The input tensor of fused_multi_head_attention. The shape is x (Tensor): The input tensor of fused_multi_head_attention. The shape is
...@@ -308,6 +325,7 @@ def fused_multi_head_attention(x, ...@@ -308,6 +325,7 @@ def fused_multi_head_attention(x,
- train: out = input * mask - train: out = input * mask
- inference: out = input * (1.0 - p) - inference: out = input * (1.0 - p)
ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp
add_residual (bool, optional): Whether add residual at the end. Default is True.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
...@@ -374,7 +392,8 @@ def fused_multi_head_attention(x, ...@@ -374,7 +392,8 @@ def fused_multi_head_attention(x,
'dropout_fix_seed', seed is not None, 'attn_dropout_seed', seed 'dropout_fix_seed', seed is not None, 'attn_dropout_seed', seed
if seed is not None else 0, 'dropout_seed', seed if seed is not None else 0, 'dropout_seed', seed
if seed is not None else 0, 'attn_dropout_implementation', mode, if seed is not None else 0, 'attn_dropout_implementation', mode,
'dropout_implementation', mode, 'ring_id', ring_id) 'dropout_implementation', mode, 'add_residual', add_residual,
'ring_id', ring_id)
if cache_kv is not None: if cache_kv is not None:
return final_out, cache_kv_out return final_out, cache_kv_out
return final_out return final_out
...@@ -424,6 +443,7 @@ def fused_multi_head_attention(x, ...@@ -424,6 +443,7 @@ def fused_multi_head_attention(x,
'dropout_seed': seed if seed is not None else 0, 'dropout_seed': seed if seed is not None else 0,
'attn_dropout_implementation': mode, 'attn_dropout_implementation': mode,
'dropout_implementation': mode, 'dropout_implementation': mode,
'add_residual': add_residual,
'ring_id': ring_id 'ring_id': ring_id
} }
......
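With the attribute plumbed through both Python wrappers, disabling the final residual add is a one-argument change. Below is a minimal dynamic-graph usage sketch for fused_feedforward (GPU required); the shapes follow the docstring example above, and fused_multi_head_attention accepts the same add_residual keyword.

import paddle
import paddle.incubate.nn.functional as F

x = paddle.randn(shape=(1, 8, 8), dtype="float32")
linear1_weight = paddle.randn(shape=(8, 8), dtype="float32")
linear2_weight = paddle.randn(shape=(8, 8), dtype="float32")

# add_residual=False: x is not added back before the final dropout/layer_norm
out = F.fused_feedforward(x, linear1_weight, linear2_weight, add_residual=False)
print(out.shape)  # [1, 8, 8]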