Unverified commit 25b4ba7f, authored by Sonder, committed by GitHub

Move fused feedforward (#53166)

* transfer fused_feedforward Compute function to phi

* add register info

* remove maxfunctor

* move fused feedforward to phi

* remove sig file

* remove fluid include

* add include

* add include

* add sig file

* add output register info

* fix sig file

* Update fused_feedforward_sig.cc

* fix grad kernel

* update output register info

* fix

* open fused_feedforward static build

* add optional and fix code style

* fix output info for fused attention

* add optional param

* merge
Parent 18e9dcdc
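Note: the diff below replaces the old fluid OpKernel classes with free functions in the phi::fusion namespace. As a minimal sketch of the target pattern (hypothetical op name, not the actual Paddle source), a phi kernel is a function template over a device Context whose outputs are raw DenseTensor pointers the kernel allocates itself:

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace fusion {

// A phi kernel is a free function template rather than an OpKernel class;
// outputs arrive as raw pointers and must be allocated via the context.
template <typename T, typename Context>
void MyFusedKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  // ... launch the fused device computation here ...
}

}  // namespace fusion
}  // namespace phi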
@@ -813,6 +813,8 @@ PD_REGISTER_KERNEL(fused_attention,
phi::dtype::float16,
double,
float) {
kernel->OutputAt(9).SetDataType(phi::DataType::UINT8);
kernel->OutputAt(14).SetDataType(phi::DataType::UINT8);
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
......
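Note: the hunk above pins per-output dtypes in the registration body — the dropout masks (outputs 9 and 14 of fused_attention) are always uint8, and under float16 compute the layer-norm statistics outputs stay float32. The same constraint pattern, sketched on the hypothetical kernel above:

PD_REGISTER_KERNEL(my_fused_op,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::MyFusedKernel,
                   float,
                   phi::dtype::float16) {
  // Output 1 is a bit mask: uint8 regardless of the compute dtype.
  kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    // Statistics outputs keep full precision under fp16 compute.
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
  }
}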
@@ -13,634 +13,668 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/fluid/operators/fused/fused_attention_utils.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void MatMul(const phi::GPUContext& dev_ctx,
const phi::DenseTensor& a,
const phi::DenseTensor& b,
phi::DenseTensor* c) {
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
auto a_2d = phi::FoldInitDims(a);
auto b_2d = phi::FoldInitDims(b);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, false);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, false);
T alpha = static_cast<T>(1.0);
blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
}
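Note: this helper folds any leading batch dimensions before the GEMM — FoldInitDims collapses the leading axes of a tensor into one, so the matrix descriptors describe plain 2-D operands. An illustration with hypothetical shapes:

// x is [batch, seq, d_model] = [8, 128, 512] (hypothetical shapes).
phi::DenseTensor x;
x.Resize({8, 128, 512});
auto x_2d = phi::FoldInitDims(x);  // dims() becomes {8 * 128, 512}
// CreateMatrixDescriptor(x_2d.dims(), 0, false) then describes a
// 1024 x 512 matrix with no transpose.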
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
static void AllReduce(phi::DenseTensor& tensor, // NOLINT
const int ring_id,
const phi::GPUContext& ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup* pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait();
template <typename T, typename Context>
void FFN(const phi::GPUContext& dev_ctx,
const phi::DenseTensor& x,
const phi::DenseTensor& linear1_weight,
const phi::DenseTensor* linear1_bias,
const phi::DenseTensor& linear2_weight,
const phi::DenseTensor* linear2_bias,
const phi::DenseTensor* ln1_scale,
const phi::DenseTensor* ln1_bias,
const phi::DenseTensor* ln2_scale,
const phi::DenseTensor* ln2_bias,
phi::DenseTensor* out,
phi::DenseTensor* dropout1_mask,
phi::DenseTensor* dropout2_mask,
phi::DenseTensor* ln1_mean,
phi::DenseTensor* ln1_variance,
phi::DenseTensor* ln2_mean,
phi::DenseTensor* ln2_variance,
phi::DenseTensor* linear1_out,
phi::DenseTensor* ln1_out,
phi::DenseTensor* dropout1_out,
phi::DenseTensor* dropout2_out,
const int bsz_seq,
const int d_model,
const int dim_feedforward,
const std::string& act_method,
const bool pre_layer_norm,
const float epsilon1,
const float epsilon2,
const bool add_residual,
const int ring_id,
const phi::fusion::DropoutParam& dropout_param1,
const phi::fusion::DropoutParam& dropout_param2) {
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
phi::fusion::FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
dev_ctx, bsz_seq, dim_feedforward, dropout_param1);
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
fused_dropout_layernorm_helper(
dev_ctx, bsz_seq, d_model, dropout_param2, epsilon2);
using U = phi::funcs::LayerNormParamType<T>;
const phi::DenseTensor* in = &x;
const U* ln1_scale_ptr =
ln1_scale == nullptr ? nullptr : ln1_scale->data<U>();
const U* ln1_bias_ptr = ln1_bias == nullptr ? nullptr : ln1_bias->data<U>();
const U* ln2_scale_ptr =
ln2_scale == nullptr ? nullptr : ln2_scale->data<U>();
const U* ln2_bias_ptr = ln2_bias == nullptr ? nullptr : ln2_bias->data<U>();
const T* linear1_bias_ptr =
linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
const T* linear2_bias_ptr =
linear2_bias == nullptr ? nullptr : linear2_bias->data<T>();
if (pre_layer_norm) {
pre_layernorm_helper.LayerNorm(dev_ctx,
x.data<T>(),
ln1_scale_ptr,
ln1_bias_ptr,
ln1_out->data<T>(),
ln1_mean->data<U>(),
ln1_variance->data<U>());
in = ln1_out;
}
MatMul<T, Context>(dev_ctx, *in, linear1_weight, linear1_out);
fused_act_dropout_helper.DropoutActBias(dev_ctx,
linear1_out->data<T>(),
linear1_bias_ptr,
act_method,
dropout1_out->data<T>(),
dropout1_mask->data<uint8_t>());
phi::DenseTensor linear2_out;
linear2_out.Resize({bsz_seq, d_model});
dev_ctx.template Alloc<T>(&linear2_out, linear2_out.numel() * sizeof(T));
MatMul<T, Context>(dev_ctx, *dropout1_out, linear2_weight, &linear2_out);
// tensor model parallel
phi::fusion::AllReduce<T>(linear2_out, ring_id, dev_ctx);
const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
if (!pre_layer_norm) {
// TODO(Xreki): support post layer_norm case when add_residual is false.
PADDLE_ENFORCE_EQ(add_residual,
true,
phi::errors::InvalidArgument(
"Attribute add_residual is expected to be true "
"when pre_layer_norm is false."));
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx,
linear2_out.data<T>(),
residual_ptr,
linear2_bias_ptr,
ln2_scale_ptr,
ln2_bias_ptr,
dropout2_out->data<T>(),
dropout2_mask->data<uint8_t>(),
out->data<T>(),
ln2_mean->data<U>(),
ln2_variance->data<U>());
} else {
auto dtype = platform::ToNCCLDataType(
framework::TransToProtoVarType(tensor.dtype()));
int64_t numel = tensor.numel();
const void* sendbuff = tensor.data<T>();
auto place = ctx.GetPlace();
void* recvbuff = ctx.Alloc<T>(&tensor, tensor.numel() * sizeof(T));
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
fused_dropout_layernorm_helper.ResidualDropoutBias(
dev_ctx,
linear2_out.data<T>(),
residual_ptr,
linear2_bias_ptr,
out->data<T>(),
dropout2_mask->data<uint8_t>());
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."));
#endif
}
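Note: a hedged dataflow summary of what the forward FFN helper computes (comments only, not source; Dropout1 and the activation are fused by DropoutActBias, while the residual add, Dropout2, and LN2 are fused by the dropout-layernorm helper):

// pre_layer_norm  : out = x + Dropout2(Linear2(Dropout1(Act(Linear1(LN1(x))))))
// !pre_layer_norm : out = LN2(x + Dropout2(Linear2(Dropout1(Act(Linear1(x))))))
// The post-LN branch additionally requires add_residual == true, enforced
// by the PADDLE_ENFORCE_EQ above.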
template <typename T, typename DeviceContext>
class FusedFeedForwardKernel : public framework::OpKernel<T> {
public:
void MatMul(const phi::GPUContext& ctx,
const phi::DenseTensor& a,
const phi::DenseTensor& b,
phi::DenseTensor* c) const {
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto a_2d = FoldInitDims(a);
auto b_2d = FoldInitDims(b);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, false);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, false);
T alpha = static_cast<T>(1.0);
blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
template <typename T, typename Context>
void FusedFeedForwardKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& dropout1_seed,
const paddle::optional<DenseTensor>& dropout2_seed,
const DenseTensor& linear1_weight,
const paddle::optional<DenseTensor>& linear1_bias,
const DenseTensor& linear2_weight,
const paddle::optional<DenseTensor>& linear2_bias,
const paddle::optional<DenseTensor>& ln1_scale,
const paddle::optional<DenseTensor>& ln1_bias,
const paddle::optional<DenseTensor>& ln2_scale,
const paddle::optional<DenseTensor>& ln2_bias,
bool pre_layer_norm,
float ln1_epsilon,
float ln2_epsilon,
const std::string& act_method,
float dropout1_prob,
float dropout2_prob,
const std::string& dropout1_implementation,
const std::string& dropout2_implementation,
bool is_test,
bool dropout1_fix_seed,
bool dropout2_fix_seed,
int dropout1_seed_val,
int dropout2_seed_val,
bool add_residual,
int ring_id,
DenseTensor* out,
DenseTensor* dropout1_mask,
DenseTensor* dropout2_mask,
DenseTensor* ln1_mean,
DenseTensor* ln1_variance,
DenseTensor* ln2_mean,
DenseTensor* ln2_variance,
DenseTensor* linear1_out,
DenseTensor* ln1_out,
DenseTensor* dropout1_out,
DenseTensor* dropout2_out) {
auto* x_ptr = &x;
auto* linear1_weight_ptr = &linear1_weight;
auto* linear1_bias_ptr = linear1_bias.get_ptr();
auto* linear2_weight_ptr = &linear2_weight;
auto* linear2_bias_ptr = linear2_bias.get_ptr();
auto* ln1_scale_ptr = pre_layer_norm ? ln1_scale.get_ptr() : nullptr;
auto* ln1_bias_ptr = pre_layer_norm ? ln1_bias.get_ptr() : nullptr;
auto* ln2_scale_ptr = !pre_layer_norm ? ln2_scale.get_ptr() : nullptr;
auto* ln2_bias_ptr = !pre_layer_norm ? ln2_bias.get_ptr() : nullptr;
if (!pre_layer_norm) {
ln1_mean = nullptr;
ln1_variance = nullptr;
ln1_out = nullptr;
} else {
ln2_mean = nullptr;
ln2_variance = nullptr;
}
void FFN(const phi::GPUContext& ctx,
const phi::DenseTensor& x,
const phi::DenseTensor& linear1_weight,
const phi::DenseTensor* linear1_bias,
const phi::DenseTensor& linear2_weight,
const phi::DenseTensor* linear2_bias,
const phi::DenseTensor* ln1_scale,
const phi::DenseTensor* ln1_bias,
const phi::DenseTensor* ln2_scale,
const phi::DenseTensor* ln2_bias,
phi::DenseTensor* out,
phi::DenseTensor* dropout1_mask,
phi::DenseTensor* dropout2_mask,
phi::DenseTensor* ln1_mean,
phi::DenseTensor* ln1_variance,
phi::DenseTensor* ln2_mean,
phi::DenseTensor* ln2_variance,
phi::DenseTensor* linear1_out,
phi::DenseTensor* ln1_out,
phi::DenseTensor* dropout1_out,
phi::DenseTensor* dropout2_out,
const int bsz_seq,
const int d_model,
const int dim_feedforward,
const std::string& act_method,
const bool pre_layer_norm,
const float epsilon1,
const float epsilon2,
const bool add_residual,
const int ring_id,
const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
ctx, bsz_seq, dim_feedforward, dropout_param1);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx, bsz_seq, d_model, dropout_param2, epsilon2);
using U = phi::funcs::LayerNormParamType<T>;
const phi::DenseTensor* in = &x;
const U* ln1_scale_ptr =
ln1_scale == nullptr ? nullptr : ln1_scale->data<U>();
const U* ln1_bias_ptr = ln1_bias == nullptr ? nullptr : ln1_bias->data<U>();
const U* ln2_scale_ptr =
ln2_scale == nullptr ? nullptr : ln2_scale->data<U>();
const U* ln2_bias_ptr = ln2_bias == nullptr ? nullptr : ln2_bias->data<U>();
const T* linear1_bias_ptr =
linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
const T* linear2_bias_ptr =
linear2_bias == nullptr ? nullptr : linear2_bias->data<T>();
if (pre_layer_norm) {
pre_layernorm_helper.LayerNorm(ctx,
x.data<T>(),
ln1_scale_ptr,
ln1_bias_ptr,
ln1_out->data<T>(),
ln1_mean->data<U>(),
ln1_variance->data<U>());
in = ln1_out;
}
MatMul(ctx, *in, linear1_weight, linear1_out);
fused_act_dropout_helper.DropoutActBias(ctx,
linear1_out->data<T>(),
linear1_bias_ptr,
act_method,
dropout1_out->data<T>(),
dropout1_mask->data<uint8_t>());
phi::DenseTensor linear2_out;
linear2_out.Resize({bsz_seq, d_model});
ctx.Alloc<T>(&linear2_out, linear2_out.numel() * sizeof(T));
MatMul(ctx, *dropout1_out, linear2_weight, &linear2_out);
bool is_upscale_in_train1 = dropout1_implementation == "upscale_in_train";
bool is_upscale_in_train2 = dropout2_implementation == "upscale_in_train";
auto* dropout1_seed_ptr = dropout1_seed.get_ptr();
auto* dropout2_seed_ptr = dropout2_seed.get_ptr();
phi::fusion::DropoutParam dropout_param1(dropout1_fix_seed,
0,
is_test,
is_upscale_in_train1,
dropout1_prob,
dropout1_seed_ptr,
dropout1_seed_val);
phi::fusion::DropoutParam dropout_param2(dropout2_fix_seed,
0,
is_test,
is_upscale_in_train2,
dropout2_prob,
dropout2_seed_ptr,
dropout2_seed_val);
using U = phi::funcs::LayerNormParamType<T>;
dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
dev_ctx.template Alloc<uint8_t>(dropout1_mask,
dropout1_mask->numel() * sizeof(uint8_t));
dev_ctx.template Alloc<uint8_t>(dropout2_mask,
dropout2_mask->numel() * sizeof(uint8_t));
if (pre_layer_norm) {
dev_ctx.template Alloc<U>(ln1_mean, ln1_mean->numel() * sizeof(U));
dev_ctx.template Alloc<U>(ln1_variance, ln1_variance->numel() * sizeof(U));
dev_ctx.template Alloc<T>(ln1_out, ln1_out->numel() * sizeof(T));
} else {
dev_ctx.template Alloc<U>(ln2_mean, ln2_mean->numel() * sizeof(U));
dev_ctx.template Alloc<U>(ln2_variance, ln2_variance->numel() * sizeof(U));
}
// tensor model parallel
AllReduce<T>(linear2_out, ring_id, ctx);
const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
if (!pre_layer_norm) {
// TODO(Xreki): support post layer_norm case when add_residual is false.
PADDLE_ENFORCE_EQ(add_residual,
true,
platform::errors::InvalidArgument(
"Attribute add_residual is expected to be true "
"when pre_layer_norm is false."));
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx,
linear2_out.data<T>(),
residual_ptr,
linear2_bias_ptr,
ln2_scale_ptr,
ln2_bias_ptr,
dropout2_out->data<T>(),
dropout2_mask->data<uint8_t>(),
out->data<T>(),
ln2_mean->data<U>(),
ln2_variance->data<U>());
} else {
fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx,
linear2_out.data<T>(),
residual_ptr,
linear2_bias_ptr,
out->data<T>(),
dropout2_mask->data<uint8_t>());
}
dev_ctx.template Alloc<T>(linear1_out, linear1_out->numel() * sizeof(T));
dev_ctx.template Alloc<T>(dropout1_out, dropout1_out->numel() * sizeof(T));
dev_ctx.template Alloc<T>(dropout2_out, dropout2_out->numel() * sizeof(T));
auto x_dim = x_ptr->dims();
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
phi::RowMatrixFromVector(x_dim), 0, false);
auto dim = linear1_weight_ptr->dims();
int d_model = dim[0];
int dim_feedforward = dim[dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
phi::fusion::FFN<T, Context>(dev_ctx,
x,
linear1_weight,
linear1_bias_ptr,
linear2_weight,
linear2_bias_ptr,
ln1_scale_ptr,
ln1_bias_ptr,
ln2_scale_ptr,
ln2_bias_ptr,
out,
dropout1_mask,
dropout2_mask,
ln1_mean,
ln1_variance,
ln2_mean,
ln2_variance,
linear1_out,
ln1_out,
dropout1_out,
dropout2_out,
bsz_seq,
d_model,
dim_feedforward,
act_method,
pre_layer_norm,
ln1_epsilon,
ln2_epsilon,
add_residual,
ring_id,
dropout_param1,
dropout_param2);
}
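Note: the kernel maps the dropout*_implementation attributes onto DropoutParam before dispatching to FFN. The two dropout conventions differ only in where the 1/(1-p) rescale happens; a self-contained sketch of the train-time elementwise rule (illustrative helper, not Paddle source):

// upscale_in_train  : train y = x * m / (1 - p); inference y = x.
// downgrade_in_infer: train y = x * m;           inference y = x * (1 - p).
template <typename T>
T DropoutTrainValue(T x, bool keep, float p, bool upscale_in_train) {
  if (!keep) return T(0);
  return upscale_in_train ? static_cast<T>(x / (1.0f - p)) : x;
}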
template <typename T, typename Context>
void MatMulGrad(const phi::GPUContext& dev_ctx,
const phi::DenseTensor& d_out,
const phi::DenseTensor& a,
const phi::DenseTensor& b,
phi::DenseTensor* d_a,
phi::DenseTensor* d_b) {
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
auto a_2d = phi::FoldInitDims(a);
auto b_2d = phi::FoldInitDims(b);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, true);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, true);
auto mat_dim_dout =
phi::funcs::CreateMatrixDescriptor(d_out.dims(), 0, false);
T alpha = static_cast<T>(1.0);
blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0));
blas.MatMul(a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T(0));
}
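Note: MatMulGrad is the standard GEMM backward for Y = X W; the transpose flags set on mat_dim_a and mat_dim_b (and left unset on mat_dim_dout) encode the two transposed products:

\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\, W^{\top},
\qquad
\frac{\partial L}{\partial W} = X^{\top}\, \frac{\partial L}{\partial Y}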
template <typename T, typename Context>
void FFNGrad(const phi::GPUContext& dev_ctx,
const phi::DenseTensor& d_out,
const phi::DenseTensor& x,
const phi::DenseTensor& dropout1_mask,
const phi::DenseTensor& dropout2_mask,
const phi::DenseTensor& linear1_out,
const phi::DenseTensor* ln1_out,
const phi::DenseTensor& dropout1_out,
const phi::DenseTensor* dropout2_out,
const phi::DenseTensor& linear1_weight,
const phi::DenseTensor* linear1_bias,
const phi::DenseTensor& linear2_weight,
const phi::DenseTensor* ln1_gamma,
const phi::DenseTensor* ln1_beta,
const phi::DenseTensor* ln1_mean,
const phi::DenseTensor* ln1_variance,
const phi::DenseTensor* ln2_gamma,
const phi::DenseTensor* ln2_beta,
const phi::DenseTensor* ln2_mean,
const phi::DenseTensor* ln2_variance,
phi::DenseTensor* d_x,
phi::DenseTensor* d_linear1_weight,
phi::DenseTensor* d_linear1_bias,
phi::DenseTensor* d_linear2_weight,
phi::DenseTensor* d_linear2_bias,
phi::DenseTensor* d_ln1_gamma,
phi::DenseTensor* d_ln1_beta,
phi::DenseTensor* d_ln2_gamma,
phi::DenseTensor* d_ln2_beta,
const int bsz_seq,
const int d_model,
const int dim_feedforward,
const phi::fusion::DropoutParam& dropout_param1,
const phi::fusion::DropoutParam& dropout_param2,
const std::string& act_method,
const bool pre_layer_norm,
const float epsilon1,
const float epsilon2,
const bool add_residual,
const int ring_id) {
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
phi::fusion::FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
dev_ctx, bsz_seq, dim_feedforward, dropout_param1);
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
fused_dropout_layernorm_helper(
dev_ctx, bsz_seq, d_model, dropout_param2, epsilon2);
using U = phi::funcs::LayerNormParamType<T>;
const U* ln1_gamma_ptr =
ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
const U* ln2_gamma_ptr =
ln2_gamma == nullptr ? nullptr : ln2_gamma->data<U>();
const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data<U>();
const T* linear1_bias_ptr =
linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
T* d_linear1_bias_ptr =
d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data<T>();
T* d_linear2_bias_ptr =
d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data<T>();
U* d_ln1_gamma_ptr =
d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data<U>();
U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data<U>();
U* d_ln2_gamma_ptr =
d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data<U>();
U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();
phi::DenseTensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.Resize({bsz_seq, d_model});
dev_ctx.template Alloc<T>(&d_linear2_out, d_linear2_out.numel() * sizeof(T));
d_dropout2_out.Resize({bsz_seq, d_model});
dev_ctx.template Alloc<T>(&d_dropout2_out,
d_dropout2_out.numel() * sizeof(T));
T* d_residual_ptr = nullptr;
if (add_residual) {
d_residual.Resize(d_x->dims());
d_residual_ptr =
dev_ctx.template Alloc<T>(&d_residual, d_residual.numel() * sizeof(T));
}
if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
dev_ctx,
d_out.data<T>(),
dropout2_mask.data<uint8_t>(),
d_linear2_out.data<T>(),
d_residual_ptr,
d_linear2_bias_ptr);
} else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
dev_ctx,
d_out.data<T>(),
dropout2_out->data<T>(),
dropout2_mask.data<uint8_t>(),
ln2_gamma_ptr,
ln2_mean->data<U>(),
ln2_variance->data<U>(),
d_dropout2_out.data<T>(),
d_ln2_gamma_ptr,
d_ln2_beta_ptr,
d_linear2_out.data<T>(),
d_linear2_bias_ptr,
d_residual_ptr);
}
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<phi::DenseTensor>("X");
auto* linear1_weight = context.Input<phi::DenseTensor>("Linear1Weight");
auto* linear1_bias = context.Input<phi::DenseTensor>("Linear1Bias");
auto* linear2_weight = context.Input<phi::DenseTensor>("Linear2Weight");
auto* linear2_bias = context.Input<phi::DenseTensor>("Linear2Bias");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
auto& dev_ctx = context.template device_context<phi::GPUContext>();
auto* ln1_scale =
pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Scale") : nullptr;
auto* ln1_bias =
pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Bias") : nullptr;
auto* ln2_scale =
!pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Scale") : nullptr;
auto* ln2_bias =
!pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Bias") : nullptr;
auto* ln1_mean =
pre_layer_norm ? context.Output<phi::DenseTensor>("Ln1Mean") : nullptr;
auto* ln1_variance = pre_layer_norm
? context.Output<phi::DenseTensor>("Ln1Variance")
: nullptr;
auto* ln2_mean =
!pre_layer_norm ? context.Output<phi::DenseTensor>("Ln2Mean") : nullptr;
auto* ln2_variance = !pre_layer_norm
? context.Output<phi::DenseTensor>("Ln2Variance")
: nullptr;
auto* out = context.Output<phi::DenseTensor>("Out");
auto* dropout1_mask = context.Output<phi::DenseTensor>("Dropout1Mask");
auto* dropout2_mask = context.Output<phi::DenseTensor>("Dropout2Mask");
auto* linear1_out = context.Output<phi::DenseTensor>("Linear1Out");
auto* ln1_out =
pre_layer_norm ? context.Output<phi::DenseTensor>("Ln1Out") : nullptr;
auto* dropout1_out = context.Output<phi::DenseTensor>("Dropout1Out");
auto* dropout2_out = context.Output<phi::DenseTensor>("Dropout2Out");
const std::string act_method = context.Attr<std::string>("act_method");
const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon");
const int ring_id = context.Attr<int>("ring_id");
const bool add_residual = context.Attr<bool>("add_residual");
DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2);
using U = phi::funcs::LayerNormParamType<T>;
dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
dev_ctx.Alloc<uint8_t>(dropout1_mask,
dropout1_mask->numel() * sizeof(uint8_t));
dev_ctx.Alloc<uint8_t>(dropout2_mask,
dropout2_mask->numel() * sizeof(uint8_t));
if (pre_layer_norm) {
dev_ctx.Alloc<U>(ln1_mean, ln1_mean->numel() * sizeof(U));
dev_ctx.Alloc<U>(ln1_variance, ln1_variance->numel() * sizeof(U));
dev_ctx.Alloc<T>(ln1_out, ln1_out->numel() * sizeof(T));
} else {
dev_ctx.Alloc<U>(ln2_mean, ln2_mean->numel() * sizeof(U));
dev_ctx.Alloc<U>(ln2_variance, ln2_variance->numel() * sizeof(U));
}
dev_ctx.Alloc<T>(linear1_out, linear1_out->numel() * sizeof(T));
dev_ctx.Alloc<T>(dropout1_out, dropout1_out->numel() * sizeof(T));
dev_ctx.Alloc<T>(dropout2_out, dropout2_out->numel() * sizeof(T));
auto x_dim = x->dims();
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
RowMatrixFromVector(x_dim), 0, false);
auto dim = linear1_weight->dims();
int d_model = dim[0];
int dim_feedforward = dim[dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
FFN(context.cuda_device_context(),
*x,
*linear1_weight,
linear1_bias,
*linear2_weight,
linear2_bias,
ln1_scale,
ln1_bias,
ln2_scale,
ln2_bias,
out,
dropout1_mask,
dropout2_mask,
ln1_mean,
ln1_variance,
ln2_mean,
ln2_variance,
linear1_out,
ln1_out,
dropout1_out,
dropout2_out,
bsz_seq,
d_model,
dim_feedforward,
act_method,
pre_layer_norm,
epsilon1,
epsilon2,
add_residual,
ring_id,
dropout_param1,
dropout_param2);
phi::DenseTensor d_dropout1_out;
d_dropout1_out.Resize({bsz_seq, dim_feedforward});
dev_ctx.template Alloc<T>(&d_dropout1_out,
d_dropout1_out.numel() * sizeof(T));
MatMulGrad<T, Context>(dev_ctx,
d_linear2_out,
dropout1_out,
linear2_weight,
&d_dropout1_out,
d_linear2_weight);
phi::DenseTensor d_linear1_out;
d_linear1_out.Resize({bsz_seq, dim_feedforward});
dev_ctx.template Alloc<T>(&d_linear1_out, d_linear1_out.numel() * sizeof(T));
fused_act_dropout_helper.DropoutActBiasGrad(dev_ctx,
d_dropout1_out.data<T>(),
linear1_out.data<T>(),
linear1_bias_ptr,
dropout1_mask.data<uint8_t>(),
d_linear1_out.data<T>(),
d_linear1_bias_ptr,
act_method);
if (pre_layer_norm) {
phi::DenseTensor d_ln1_out;
d_ln1_out.Resize({bsz_seq, d_model});
dev_ctx.template Alloc<T>(&d_ln1_out, d_ln1_out.numel() * sizeof(T));
MatMulGrad<T, Context>(dev_ctx,
d_linear1_out,
*ln1_out,
linear1_weight,
&d_ln1_out,
d_linear1_weight);
// tensor model parallel
phi::fusion::AllReduce<T>(d_ln1_out, ring_id, dev_ctx);
pre_layernorm_helper.LayerNormGrad(dev_ctx,
d_ln1_out.data<T>(),
x.data<T>(),
ln1_gamma_ptr,
ln1_mean->data<U>(),
ln1_variance->data<U>(),
d_x->data<T>(),
d_ln1_gamma_ptr,
d_ln1_beta_ptr);
} else {
MatMulGrad<T, Context>(
dev_ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
// tensor model parallel
phi::fusion::AllReduce<T>(*d_x, ring_id, dev_ctx);
}
};
template <typename T, typename DeviceContext>
class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
public:
void MatMulGrad(const phi::GPUContext& ctx,
const phi::DenseTensor& d_out,
const phi::DenseTensor& a,
const phi::DenseTensor& b,
phi::DenseTensor* d_a,
phi::DenseTensor* d_b) const {
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto a_2d = FoldInitDims(a);
auto b_2d = FoldInitDims(b);
auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, true);
auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, true);
auto mat_dim_dout =
phi::funcs::CreateMatrixDescriptor(d_out.dims(), 0, false);
T alpha = static_cast<T>(1.0);
blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0));
blas.MatMul(a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T(0));
if (add_residual) {
// gradient accumulation
std::vector<const phi::DenseTensor*> ins = {&d_residual, d_x};
std::vector<phi::DenseTensor*> outs = {d_x};
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, phi::funcs::AddFunctor<T>());
}
}
void FFNGrad(const phi::GPUContext& ctx,
const phi::DenseTensor& d_out,
const phi::DenseTensor& x,
const phi::DenseTensor& dropout1_mask,
const phi::DenseTensor& dropout2_mask,
const phi::DenseTensor& linear1_out,
const phi::DenseTensor* ln1_out,
const phi::DenseTensor& dropout1_out,
const phi::DenseTensor* dropout2_out,
const phi::DenseTensor& linear1_weight,
const phi::DenseTensor* linear1_bias,
const phi::DenseTensor& linear2_weight,
const phi::DenseTensor* ln1_gamma,
const phi::DenseTensor* ln1_beta,
const phi::DenseTensor* ln1_mean,
const phi::DenseTensor* ln1_variance,
const phi::DenseTensor* ln2_gamma,
const phi::DenseTensor* ln2_beta,
const phi::DenseTensor* ln2_mean,
const phi::DenseTensor* ln2_variance,
phi::DenseTensor* d_x,
phi::DenseTensor* d_linear1_weight,
phi::DenseTensor* d_linear1_bias,
phi::DenseTensor* d_linear2_weight,
phi::DenseTensor* d_linear2_bias,
phi::DenseTensor* d_ln1_gamma,
phi::DenseTensor* d_ln1_beta,
phi::DenseTensor* d_ln2_gamma,
phi::DenseTensor* d_ln2_beta,
const int bsz_seq,
const int d_model,
const int dim_feedforward,
const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2,
const std::string& act_method,
const bool pre_layer_norm,
const float epsilon1,
const float epsilon2,
const bool add_residual,
const int ring_id) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
ctx, bsz_seq, dim_feedforward, dropout_param1);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx, bsz_seq, d_model, dropout_param2, epsilon2);
using U = phi::funcs::LayerNormParamType<T>;
const U* ln1_gamma_ptr =
ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
const U* ln2_gamma_ptr =
ln2_gamma == nullptr ? nullptr : ln2_gamma->data<U>();
const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data<U>();
const T* linear1_bias_ptr =
linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
T* d_linear1_bias_ptr =
d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data<T>();
T* d_linear2_bias_ptr =
d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data<T>();
U* d_ln1_gamma_ptr =
d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data<U>();
U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data<U>();
U* d_ln2_gamma_ptr =
d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data<U>();
U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();
phi::DenseTensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.Resize({bsz_seq, d_model});
ctx.Alloc<T>(&d_linear2_out, d_linear2_out.numel() * sizeof(T));
d_dropout2_out.Resize({bsz_seq, d_model});
ctx.Alloc<T>(&d_dropout2_out, d_dropout2_out.numel() * sizeof(T));
T* d_residual_ptr = nullptr;
if (add_residual) {
d_residual.Resize(d_x->dims());
d_residual_ptr =
ctx.Alloc<T>(&d_residual, d_residual.numel() * sizeof(T));
}
if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
ctx,
d_out.data<T>(),
dropout2_mask.data<uint8_t>(),
d_linear2_out.data<T>(),
d_residual_ptr,
d_linear2_bias_ptr);
} else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx,
d_out.data<T>(),
dropout2_out->data<T>(),
dropout2_mask.data<uint8_t>(),
ln2_gamma_ptr,
ln2_mean->data<U>(),
ln2_variance->data<U>(),
d_dropout2_out.data<T>(),
d_ln2_gamma_ptr,
d_ln2_beta_ptr,
d_linear2_out.data<T>(),
d_linear2_bias_ptr,
d_residual_ptr);
}
phi::DenseTensor d_dropout1_out;
d_dropout1_out.Resize({bsz_seq, dim_feedforward});
ctx.Alloc<T>(&d_dropout1_out, d_dropout1_out.numel() * sizeof(T));
MatMulGrad(ctx,
d_linear2_out,
dropout1_out,
linear2_weight,
&d_dropout1_out,
d_linear2_weight);
phi::DenseTensor d_linear1_out;
d_linear1_out.Resize({bsz_seq, dim_feedforward});
ctx.Alloc<T>(&d_linear1_out, d_linear1_out.numel() * sizeof(T));
fused_act_dropout_helper.DropoutActBiasGrad(ctx,
d_dropout1_out.data<T>(),
linear1_out.data<T>(),
linear1_bias_ptr,
dropout1_mask.data<uint8_t>(),
d_linear1_out.data<T>(),
d_linear1_bias_ptr,
act_method);
if (pre_layer_norm) {
phi::DenseTensor d_ln1_out;
d_ln1_out.Resize({bsz_seq, d_model});
ctx.Alloc<T>(&d_ln1_out, d_ln1_out.numel() * sizeof(T));
MatMulGrad(ctx,
d_linear1_out,
*ln1_out,
linear1_weight,
&d_ln1_out,
d_linear1_weight);
// tensor model parallel
AllReduce<T>(d_ln1_out, ring_id, ctx);
pre_layernorm_helper.LayerNormGrad(ctx,
d_ln1_out.data<T>(),
x.data<T>(),
ln1_gamma_ptr,
ln1_mean->data<U>(),
ln1_variance->data<U>(),
d_x->data<T>(),
d_ln1_gamma_ptr,
d_ln1_beta_ptr);
} else {
MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
// tensor model parallel
AllReduce<T>(*d_x, ring_id, ctx);
}
if (add_residual) {
// gradient accumulation
std::vector<const phi::DenseTensor*> ins = {&d_residual, d_x};
std::vector<phi::DenseTensor*> outs = {d_x};
phi::funcs::ElementwiseKernel<T>(
ctx, ins, &outs, phi::funcs::AddFunctor<T>());
}
template <typename T, typename Context>
void FusedFeedForwardGradKernel(
const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& linear1_weight,
const paddle::optional<DenseTensor>& linear1_bias,
const DenseTensor& linear2_weight,
const DenseTensor& dropout1_mask,
const DenseTensor& dropout2_mask,
const DenseTensor& linear1_out,
const DenseTensor& dropout1_out,
const paddle::optional<DenseTensor>& dropout2_out,
const paddle::optional<DenseTensor>& ln1_scale,
const paddle::optional<DenseTensor>& ln1_bias,
const paddle::optional<DenseTensor>& ln1_out,
const paddle::optional<DenseTensor>& ln1_mean,
const paddle::optional<DenseTensor>& ln1_variance,
const paddle::optional<DenseTensor>& ln2_scale,
const paddle::optional<DenseTensor>& ln2_bias,
const paddle::optional<DenseTensor>& ln2_mean,
const paddle::optional<DenseTensor>& ln2_variance,
const paddle::optional<DenseTensor>& linear2_bias,
bool pre_layer_norm,
float ln1_epsilon,
float ln2_epsilon,
const std::string& act_method,
float dropout1_prob,
float dropout2_prob,
const std::string& dropout1_implementation,
const std::string& dropout2_implementation,
bool is_test,
bool dropout1_fix_seed,
bool dropout2_fix_seed,
int dropout1_seed_val,
int dropout2_seed_val,
bool add_residual,
int ring_id,
DenseTensor* x_grad,
DenseTensor* ln1_scale_grad,
DenseTensor* ln1_bias_grad,
DenseTensor* ln2_scale_grad,
DenseTensor* ln2_bias_grad,
DenseTensor* linear1_weight_grad,
DenseTensor* linear1_bias_grad,
DenseTensor* linear2_weight_grad,
DenseTensor* linear2_bias_grad) {
using U = phi::funcs::LayerNormParamType<T>;
auto* ln1_out_ptr = pre_layer_norm ? ln1_out.get_ptr() : nullptr;
auto* dropout2_out_ptr = dropout2_out.get_ptr();
auto* linear1_bias_ptr = linear1_bias.get_ptr();
auto* ln1_mean_ptr = pre_layer_norm ? ln1_mean.get_ptr() : nullptr;
auto* ln1_variance_ptr = pre_layer_norm ? ln1_variance.get_ptr() : nullptr;
auto* ln1_scale_ptr = pre_layer_norm ? ln1_scale.get_ptr() : nullptr;
auto* ln1_bias_ptr = pre_layer_norm ? ln1_bias.get_ptr() : nullptr;
auto* ln2_mean_ptr = !pre_layer_norm ? ln2_mean.get_ptr() : nullptr;
auto* ln2_variance_ptr = !pre_layer_norm ? ln2_variance.get_ptr() : nullptr;
auto* ln2_scale_ptr = !pre_layer_norm ? ln2_scale.get_ptr() : nullptr;
auto* ln2_bias_ptr = !pre_layer_norm ? ln2_bias.get_ptr() : nullptr;
auto* d_x = x_grad;
auto* d_ln1_scale = pre_layer_norm ? ln1_scale_grad : nullptr;
auto* d_ln1_bias = pre_layer_norm ? ln1_bias_grad : nullptr;
auto* d_ln2_scale = pre_layer_norm ? nullptr : ln2_scale_grad;
auto* d_ln2_bias = pre_layer_norm ? nullptr : ln2_bias_grad;
auto* d_linear1_weight = linear1_weight_grad;
auto* d_linear1_bias = linear1_bias_grad;
auto* d_linear2_weight = linear2_weight_grad;
auto* d_linear2_bias = linear2_bias_grad;
bool is_upscale_in_train1 = dropout1_implementation == "upscale_in_train";
bool is_upscale_in_train2 = dropout2_implementation == "upscale_in_train";
phi::fusion::DropoutParam dropout_param1(dropout1_fix_seed,
0,
is_test,
is_upscale_in_train1,
dropout1_prob,
nullptr,
dropout1_seed_val);
phi::fusion::DropoutParam dropout_param2(dropout2_fix_seed,
0,
is_test,
is_upscale_in_train2,
dropout2_prob,
nullptr,
dropout2_seed_val);
dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
if (d_ln1_scale) {
dev_ctx.template Alloc<U>(d_ln1_scale, d_ln1_scale->numel() * sizeof(U));
}
if (d_ln1_bias) {
dev_ctx.template Alloc<U>(d_ln1_bias, d_ln1_bias->numel() * sizeof(U));
}
if (d_ln2_scale) {
dev_ctx.template Alloc<U>(d_ln2_scale, d_ln2_scale->numel() * sizeof(U));
}
if (d_ln2_bias) {
dev_ctx.template Alloc<U>(d_ln2_bias, d_ln2_bias->numel() * sizeof(U));
}
if (d_linear1_bias) {
dev_ctx.template Alloc<T>(d_linear1_bias,
d_linear1_bias->numel() * sizeof(T));
}
if (d_linear2_bias) {
dev_ctx.template Alloc<T>(d_linear2_bias,
d_linear2_bias->numel() * sizeof(T));
}
dev_ctx.template Alloc<T>(d_linear1_weight,
d_linear1_weight->numel() * sizeof(T));
dev_ctx.template Alloc<T>(d_linear2_weight,
d_linear2_weight->numel() * sizeof(T));
auto x_dim = x.dims();
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
phi::RowMatrixFromVector(x_dim), 0, false);
auto linear1_weight_dim = linear1_weight.dims();
int d_model = linear1_weight_dim[0];
int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
FFNGrad<T, Context>(dev_ctx,
out_grad,
x,
dropout1_mask,
dropout2_mask,
linear1_out,
ln1_out_ptr,
dropout1_out,
dropout2_out_ptr,
linear1_weight,
linear1_bias_ptr,
linear2_weight,
ln1_scale_ptr,
ln1_bias_ptr,
ln1_mean_ptr,
ln1_variance_ptr,
ln2_scale_ptr,
ln2_bias_ptr,
ln2_mean_ptr,
ln2_variance_ptr,
d_x,
d_linear1_weight,
d_linear1_bias,
d_linear2_weight,
d_linear2_bias,
d_ln1_scale,
d_ln1_bias,
d_ln2_scale,
d_ln2_bias,
bsz_seq,
d_model,
dim_feedforward,
dropout_param1,
dropout_param2,
act_method,
pre_layer_norm,
ln1_epsilon,
ln2_epsilon,
add_residual,
ring_id);
}
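Note: a hedged reading of the chain FFNGrad walks for the pre-layer-norm case (comment summary, not an authoritative spec):

// d_out --ResidualDropoutBiasGrad--> d_linear2_out (+ d_residual, d_linear2_bias)
// d_linear2_out --MatMulGrad--> d_dropout1_out, d_linear2_weight
// d_dropout1_out --DropoutActBiasGrad--> d_linear1_out (+ d_linear1_bias)
// d_linear1_out --MatMulGrad--> d_ln1_out, d_linear1_weight
// d_ln1_out --AllReduce (tensor parallel)--> LayerNormGrad --> d_x (+ d_ln1_gamma, d_ln1_beta)
// if add_residual: d_x += d_residual via ElementwiseKernel + AddFunctor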
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_feedforward,
GPU,
ALL_LAYOUT,
phi::fusion::FusedFeedForwardKernel,
float,
double,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
kernel->OutputAt(2).SetDataType(phi::DataType::UINT8);
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
}
}
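Note: the commit message mentions adding and fixing fused_feedforward_sig.cc, which maps the legacy operator's named inputs/outputs/attributes onto this phi kernel. A hedged sketch of what such a compat mapping looks like (argument lists abridged and illustrative; consult the actual sig file for the exact ordering):

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature FusedFeedForwardOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "fused_feedforward",
      {"X", "Dropout1Seed", "Dropout2Seed", "Linear1Weight", "Linear1Bias",
       "Linear2Weight", "Linear2Bias", "Ln1Scale", "Ln1Bias", "Ln2Scale",
       "Ln2Bias"},
      {"pre_layer_norm", "ln1_epsilon", "ln2_epsilon", "act_method" /* ... */},
      {"Out", "Dropout1Mask", "Dropout2Mask", "Ln1Mean", "Ln1Variance",
       "Ln2Mean", "Ln2Variance", "Linear1Out", "Ln1Out", "Dropout1Out",
       "Dropout2Out"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(fused_feedforward,
                           phi::FusedFeedForwardOpArgumentMapping);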
void Compute(const framework::ExecutionContext& context) const override {
using U = phi::funcs::LayerNormParamType<T>;
auto& dev_ctx = context.template device_context<phi::GPUContext>();
auto d_out =
*context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto x = *context.Input<phi::DenseTensor>("X");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
auto dropout1_mask = *context.Input<phi::DenseTensor>("Dropout1Mask");
auto dropout2_mask = *context.Input<phi::DenseTensor>("Dropout2Mask");
auto linear1_out = *context.Input<phi::DenseTensor>("Linear1Out");
auto* ln1_out =
pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Out") : nullptr;
auto dropout1_out = *context.Input<phi::DenseTensor>("Dropout1Out");
auto* dropout2_out = context.Input<phi::DenseTensor>("Dropout2Out");
auto linear1_weight = *context.Input<phi::DenseTensor>("Linear1Weight");
auto* linear1_bias = context.Input<phi::DenseTensor>("Linear1Bias");
auto linear2_weight = *context.Input<phi::DenseTensor>("Linear2Weight");
auto* ln1_mean =
pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Mean") : nullptr;
auto* ln1_variance = pre_layer_norm
? context.Input<phi::DenseTensor>("Ln1Variance")
: nullptr;
auto* ln1_scale =
pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Scale") : nullptr;
auto* ln1_bias =
pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Bias") : nullptr;
auto* ln2_mean =
!pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Mean") : nullptr;
auto* ln2_variance = !pre_layer_norm
? context.Input<phi::DenseTensor>("Ln2Variance")
: nullptr;
auto* ln2_scale =
!pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Scale") : nullptr;
auto* ln2_bias =
!pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Bias") : nullptr;
auto* d_x = context.Output<phi::DenseTensor>(framework::GradVarName("X"));
auto* d_ln1_scale = pre_layer_norm ? context.Output<phi::DenseTensor>(
framework::GradVarName("Ln1Scale"))
: nullptr;
auto* d_ln1_bias = pre_layer_norm ? context.Output<phi::DenseTensor>(
framework::GradVarName("Ln1Bias"))
: nullptr;
auto* d_ln2_scale = pre_layer_norm
? nullptr
: context.Output<phi::DenseTensor>(
framework::GradVarName("Ln2Scale"));
auto* d_ln2_bias = pre_layer_norm ? nullptr
: context.Output<phi::DenseTensor>(
framework::GradVarName("Ln2Bias"));
auto* d_linear1_weight = context.Output<phi::DenseTensor>(
framework::GradVarName("Linear1Weight"));
auto* d_linear1_bias =
context.Output<phi::DenseTensor>(framework::GradVarName("Linear1Bias"));
auto* d_linear2_weight = context.Output<phi::DenseTensor>(
framework::GradVarName("Linear2Weight"));
auto* d_linear2_bias =
context.Output<phi::DenseTensor>(framework::GradVarName("Linear2Bias"));
const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon");
const bool add_residual = context.Attr<bool>("add_residual");
const int ring_id = context.Attr<int>("ring_id");
const std::string act_method = context.Attr<std::string>("act_method");
DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2);
dev_ctx.Alloc<T>(d_x, d_x->numel() * sizeof(T));
if (d_ln1_scale) {
dev_ctx.Alloc<U>(d_ln1_scale, d_ln1_scale->numel() * sizeof(U));
}
if (d_ln1_bias) {
dev_ctx.Alloc<U>(d_ln1_bias, d_ln1_bias->numel() * sizeof(U));
}
if (d_ln2_scale) {
dev_ctx.Alloc<U>(d_ln2_scale, d_ln2_scale->numel() * sizeof(U));
}
if (d_ln2_bias) {
dev_ctx.Alloc<U>(d_ln2_bias, d_ln2_bias->numel() * sizeof(U));
}
if (d_linear1_bias) {
dev_ctx.Alloc<T>(d_linear1_bias, d_linear1_bias->numel() * sizeof(T));
}
if (d_linear2_bias) {
dev_ctx.Alloc<T>(d_linear2_bias, d_linear2_bias->numel() * sizeof(T));
}
dev_ctx.Alloc<T>(d_linear1_weight, d_linear1_weight->numel() * sizeof(T));
dev_ctx.Alloc<T>(d_linear2_weight, d_linear2_weight->numel() * sizeof(T));
auto x_dim = x.dims();
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
RowMatrixFromVector(x_dim), 0, false);
auto linear1_weight_dim = linear1_weight.dims();
int d_model = linear1_weight_dim[0];
int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
FFNGrad(context.cuda_device_context(),
d_out,
x,
dropout1_mask,
dropout2_mask,
linear1_out,
ln1_out,
dropout1_out,
dropout2_out,
linear1_weight,
linear1_bias,
linear2_weight,
ln1_scale,
ln1_bias,
ln1_mean,
ln1_variance,
ln2_scale,
ln2_bias,
ln2_mean,
ln2_variance,
d_x,
d_linear1_weight,
d_linear1_bias,
d_linear2_weight,
d_linear2_bias,
d_ln1_scale,
d_ln1_bias,
d_ln2_scale,
d_ln2_bias,
bsz_seq,
d_model,
dim_feedforward,
dropout_param1,
dropout_param2,
act_method,
pre_layer_norm,
epsilon1,
epsilon2,
add_residual,
ring_id);
PD_REGISTER_KERNEL(fused_feedforward_grad,
GPU,
ALL_LAYOUT,
phi::fusion::FusedFeedForwardGradKernel,
float,
double,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL(fused_feedforward,
GPU,
ALL_LAYOUT,
ops::FusedFeedForwardKernel,
float,
double,
plat::float16) {}
PD_REGISTER_STRUCT_KERNEL(fused_feedforward_grad,
GPU,
ALL_LAYOUT,
ops::FusedFeedForwardGradKernel,
float,
double,
plat::float16) {}
}
@@ -25,11 +25,6 @@ struct MulGradFunctor {
inline HOSTDEVICE T Dy(T x, T y) { return x; }
};
template <typename T>
struct MaxFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? b : a; }
};
template <typename T>
struct AddGradFunctor {
inline HOSTDEVICE T Dx(T x, T y) { return static_cast<T>(1.); }
......
@@ -24,13 +24,13 @@ void FusedFeedForwardGradKernel(
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& linear1_weight,
const DenseTensor& linear1_bias,
const paddle::optional<DenseTensor>& linear1_bias,
const DenseTensor& linear2_weight,
const DenseTensor& dropout1_mask,
const DenseTensor& dropout2_mask,
const DenseTensor& linear1_out,
const DenseTensor& dropout1_out,
const DenseTensor& dropout2_out,
const paddle::optional<DenseTensor>& dropout2_out,
const paddle::optional<DenseTensor>& ln1_scale,
const paddle::optional<DenseTensor>& ln1_bias,
const paddle::optional<DenseTensor>& ln1_out,
......
@@ -30,6 +30,8 @@
#include "paddle/phi/kernels/fusion/gpu/fused_residual_dropout_bias.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
PHI_DECLARE_bool(use_fast_math);
namespace phi {
namespace fusion {
@@ -292,21 +294,22 @@ class FusedDropoutHelper {
T* d_bias,
const std::string& act_method) {
if (act_method == "gelu") {
phi::funcs::GeluGradFunctor<T> gelu_grad;
phi::fusion::
LaunchDropoutActBiasGrad<T, MaskType, phi::funcs::GeluGradFunctor<T>>(
gelu_grad,
dout,
mask,
src,
bias,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
rows_,
cols_,
d_src,
d_bias,
ctx);
phi::fusion::GeluGradFunctor<T> gelu_grad;
phi::fusion::LaunchDropoutActBiasGrad<T,
MaskType,
phi::fusion::GeluGradFunctor<T>>(
gelu_grad,
dout,
mask,
src,
bias,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
rows_,
cols_,
d_src,
d_bias,
ctx);
} else if (act_method == "relu") {
phi::funcs::ReluGradFunctor<T> relu_grad;
phi::fusion::
......
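Note: the hunk above only moves GeluGradFunctor from phi::funcs to phi::fusion; the math is unchanged. For reference, a self-contained sketch of the exact-erf GELU derivative it applies elementwise (not Paddle's actual functor):

#include <cmath>

// gelu(x) = x * Phi(x), so gelu'(x) = Phi(x) + x * phi(x), where Phi and
// phi are the standard normal CDF and PDF.
inline float GeluGradReference(float x) {
  const float cdf = 0.5f * (1.0f + std::erf(x * 0.70710678f));  // Phi(x)
  const float pdf = 0.39894228f * std::exp(-0.5f * x * x);      // phi(x)
  return cdf + x * pdf;
}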
@@ -366,13 +366,13 @@ void FusedFeedForwardGradKernel(
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& linear1_weight,
const DenseTensor& linear1_bias,
const paddle::optional<DenseTensor>& linear1_bias,
const DenseTensor& linear2_weight,
const DenseTensor& dropout1_mask,
const DenseTensor& dropout2_mask,
const DenseTensor& linear1_out,
const DenseTensor& dropout1_out,
const DenseTensor& dropout2_out,
const paddle::optional<DenseTensor>& dropout2_out,
const paddle::optional<DenseTensor>& ln1_scale,
const paddle::optional<DenseTensor>& ln1_bias,
const paddle::optional<DenseTensor>& ln1_out,
@@ -417,7 +417,7 @@ void FusedFeedForwardGradKernel(
auto* ln1_out_ptr = pre_layer_norm ? ln1_out.get_ptr() : nullptr;
auto* dropout1_out_ptr = &dropout1_out;
auto* dropout2_out_ptr = &dropout2_out;
auto* dropout2_out_ptr = dropout2_out.get_ptr();
auto* linear1_weight_ptr = &linear1_weight;
auto* linear2_weight_ptr = &linear2_weight;
......
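Note: this is the optional-input fix named in the commit message — dropout2_out is only produced on the post-layer-norm path, so it now arrives as paddle::optional<DenseTensor> and must be unwrapped with get_ptr() instead of taken by address. The pattern, sketched:

// get_ptr() yields nullptr when the optional input is absent, which the
// downstream helpers already accept (sketch, mirroring the line above).
const phi::DenseTensor* dropout2_out_ptr = dropout2_out.get_ptr();
if (dropout2_out_ptr != nullptr) {
  // post-LN backward consumes the saved dropout2 output here
}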
@@ -1162,6 +1162,8 @@ set(STATIC_BUILD_TESTS
test_fetch_lod_tensor_array
test_fused_attention_op
test_fused_attention_op_api
test_fused_feedforward_op
test_fused_feedforward_pass
test_imperative_optimizer
test_lamb_op
test_layer_norm_op
@@ -1191,6 +1193,8 @@ set(STATIC_BUILD_TESTS
if(NOT WITH_GPU)
list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op)
list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op_api)
list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_feedforward_op)
list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_feedforward_pass)
endif()
foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
......