Unverified commit 4bbbed9a, authored by W Wilber, committed by GitHub

Fix fused cuda op's mutable data [2] (#45562)

Parent 26d161ef
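Every hunk below applies the same mechanical change: the fused CUDA kernels stop allocating their outputs through `Tensor::mutable_data<T>(place)` and instead ask the device context for the memory via `Alloc<T>(tensor, size_in_bytes)`, with the byte count derived from the tensor's element count. Kernels that only hold an `ExecutionContext` first fetch the `phi::GPUContext`; helper classes that already receive a `phi::GPUContext` call `ctx.template Alloc<T>(...)` directly and drop the now-unused `place` variable. A condensed sketch of the recurring pattern (here `out` is a stand-in for whichever output tensor is being allocated, not a specific variable from the diff):

    // Old pattern: the tensor allocates for itself from a Place.
    //   out->mutable_data<T>(ctx.GetPlace());

    // New pattern: the GPU context performs the allocation; the requested
    // size is the element count times the element size.
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));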
@@ -326,7 +326,8 @@ void Launch2DColumnReduce(const phi::GPUContext& dev_ctx,
   } else {
     framework::Tensor tmp_sum;
     tmp_sum.Resize({grid.y, left_num});
-    tmp_sum.mutable_data<ReduceParamType<T>>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<ReduceParamType<T>>(
+        &tmp_sum, tmp_sum.numel() * sizeof(ReduceParamType<T>));
     BiasAddBw2DReduceKernel<T><<<grid, block, 0, stream>>>(
         d_out,
...
@@ -49,7 +49,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     auto* bias = ctx.Input<Tensor>("Bias");
     auto* residual = ctx.Input<Tensor>("ResidualData");
     auto* output = ctx.Output<Tensor>("Output");
-    output->mutable_data<T>(ctx.GetPlace());
+    dev_ctx.template Alloc<T>(output, output->numel() * sizeof(T));
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
...
@@ -84,7 +84,6 @@ class CudnnBNStatsFinalize {
                float momentum,
                int64_t ele_count,
                bool is_train) {
-    auto place = ctx.GetPlace();
     if (is_train) {
       TrainInit(ctx);
     } else {
@@ -98,12 +97,18 @@ class CudnnBNStatsFinalize {
         const_cast<float *>(sum_of_squares.data<float>());
     float *scale_ptr = const_cast<float *>(scale.data<float>());
     float *bias_ptr = const_cast<float *>(bias.data<float>());
-    float *saved_mean_ptr = saved_mean->mutable_data<float>(place);
-    float *saved_invstd_ptr = saved_invstd->mutable_data<float>(place);
-    float *running_mean_ptr = running_mean->mutable_data<float>(place);
-    float *running_var_ptr = running_var->mutable_data<float>(place);
-    T *equiv_scale_ptr = equiv_scale->mutable_data<T>(place);
-    T *equiv_bias_ptr = equiv_bias->mutable_data<T>(place);
+    float *saved_mean_ptr = ctx.template Alloc<float>(
+        saved_mean, saved_mean->numel() * sizeof(float));
+    float *saved_invstd_ptr = ctx.template Alloc<float>(
+        saved_invstd, saved_invstd->numel() * sizeof(float));
+    float *running_mean_ptr = ctx.template Alloc<float>(
+        running_mean, running_mean->numel() * sizeof(float));
+    float *running_var_ptr = ctx.template Alloc<float>(
+        running_var, running_var->numel() * sizeof(float));
+    T *equiv_scale_ptr =
+        ctx.template Alloc<T>(equiv_scale, equiv_scale->numel() * sizeof(T));
+    T *equiv_bias_ptr =
+        ctx.template Alloc<T>(equiv_bias, equiv_bias->numel() * sizeof(T));
     op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
     op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
     op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr);
...
@@ -193,7 +193,6 @@ class CudnnNormConvolution {
                Tensor *sum,
                Tensor *sum_of_squares) {
     auto cudnn_handle = ctx.cudnn_handle();
-    auto place = ctx.GetPlace();
     CudnnFusionOp *fwd_op = GetForwardOp(ctx);
     size_t workspace_size = RoundUp(
@@ -210,9 +209,11 @@ class CudnnNormConvolution {
         CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);
     // output ptr
-    T *output_ptr = output->mutable_data<T>(place);
-    float *sum_ptr = sum->mutable_data<float>(place);
-    float *sum_of_squares_ptr = sum_of_squares->mutable_data<float>(place);
+    T *output_ptr = ctx.template Alloc<T>(output, output->numel() * sizeof(T));
+    float *sum_ptr =
+        ctx.template Alloc<float>(sum, sum->numel() * sizeof(float));
+    float *sum_of_squares_ptr = ctx.template Alloc<float>(
+        sum_of_squares, sum_of_squares->numel() * sizeof(float));
     fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
     fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
     fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
@@ -311,17 +312,18 @@ class CudnnNormConvolutionGrad {
                 Tensor *input_grad,
                 Tensor *filter_grad,
                 bool use_addto = false) {
-    auto place = ctx.GetPlace();
     T *input_ptr = const_cast<T *>(input.data<T>());
     T *filter_ptr = const_cast<T *>(filter.data<T>());
     T *output_grad_ptr = const_cast<T *>(output_grad.data<T>());
     if (filter_grad) {
-      T *filter_grad_ptr = filter_grad->mutable_data<T>(place);
+      T *filter_grad_ptr =
+          ctx.template Alloc<T>(filter_grad, filter_grad->numel() * sizeof(T));
       BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr);
     }
     if (input_grad) {
-      T *input_grad_ptr = input_grad->mutable_data<T>(place);
+      T *input_grad_ptr =
+          ctx.template Alloc<T>(input_grad, input_grad->numel() * sizeof(T));
       BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr, use_addto);
     }
   }
...
@@ -127,7 +127,6 @@ class CudnnScaleBiasAddRelu {
                Tensor *bitmask) {
     ForwardInit(ctx);
     auto handle = ctx.cudnn_handle();
-    auto place = ctx.GetPlace();
     auto workspace_handle = ctx.cudnn_workspace_handle();
     fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle);
     // Set variant_param
@@ -156,8 +155,9 @@ class CudnnScaleBiasAddRelu {
         CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_);
     // output ptr
-    T *out_ptr = out->mutable_data<T>(place);
-    int32_t *bitmask_ptr = bitmask->mutable_data<int32_t>(place);
+    T *out_ptr = ctx.template Alloc<T>(out, out->numel() * sizeof(T));
+    int32_t *bitmask_ptr = ctx.template Alloc<int32_t>(
+        bitmask, bitmask->numel() * sizeof(int32_t));
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, out_ptr);
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr);
@@ -186,7 +186,6 @@ class CudnnScaleBiasAddRelu {
                 double eps) {
     BackwardInit(ctx);
     auto handle = ctx.cudnn_handle();
-    auto place = ctx.GetPlace();
     auto workspace_handle = ctx.cudnn_workspace_handle();
     bwd_workspace_byte_ = bwd_op_.GetWorkspaceSizeInBytes(handle);
     // Set variant_param
@@ -199,10 +198,15 @@ class CudnnScaleBiasAddRelu {
     float *saved_invstd_ptr = const_cast<float *>(saved_invstd.data<float>());
     int32_t *bitmask_ptr =
         bitmask ? const_cast<int32_t *>(bitmask->data<int32_t>()) : nullptr;
-    T *dx_ptr = dx->mutable_data<T>(place);
-    T *dz_ptr = dz ? dz->mutable_data<T>(place) : nullptr;
-    float *dscale_ptr = dscale ? dscale->mutable_data<float>(place) : nullptr;
-    float *dbias_ptr = dbias ? dbias->mutable_data<float>(place) : nullptr;
+    T *dx_ptr = ctx.template Alloc<T>(dx, dx->numel() * sizeof(T));
+    T *dz_ptr =
+        dz ? ctx.template Alloc<T>(dz, dz->numel() * sizeof(T)) : nullptr;
+    float *dscale_ptr = dscale ? ctx.template Alloc<float>(
+                                     dscale, dscale->numel() * sizeof(float))
+                               : nullptr;
+    float *dbias_ptr =
+        dbias ? ctx.template Alloc<float>(dbias, dbias->numel() * sizeof(float))
+              : nullptr;
     bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
     bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, dy_ptr);
...
@@ -64,7 +64,7 @@ static void AllReduce(framework::Tensor &tensor,  // NOLINT
   int64_t numel = tensor.numel();
   const void *sendbuff = tensor.data<T>();
   auto place = ctx.GetPlace();
-  void *recvbuff = tensor.mutable_data<T>(place);
+  void *recvbuff = ctx.template Alloc<T>(&tensor, tensor.numel() * sizeof(T));
   auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
   auto stream = ctx.stream();
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
@@ -83,7 +83,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     using U = LayerNormParamType<T>;
     auto *input_x = ctx.Input<Tensor>("X");
-
+    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
     const float epsilon = ctx.Attr<float>("epsilon");
     auto *ln_scale = ctx.Input<Tensor>("LnScale");
@@ -145,40 +145,53 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     auto *x_data = input_x->data<T>();
     auto *qkv_weight_data = qkv_weight->data<T>();
     auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
-    auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
+    auto *qkv_out_data =
+        dev_ctx.template Alloc<T>(qkv_out, qkv_out->numel() * sizeof(T));
     auto *qkv_bias_out_data =
-        (qkv_bias == nullptr) ? nullptr
-                              : qkv_bias_out->mutable_data<T>(ctx.GetPlace());
+        (qkv_bias == nullptr)
+            ? nullptr
+            : dev_ctx.template Alloc<T>(qkv_bias_out,
+                                        qkv_bias_out->numel() * sizeof(T));
     // get data ptr for FMHA.
-    auto *transpose_out_2_data =
-        transpose_out_2->mutable_data<T>(ctx.GetPlace());
+    auto *transpose_out_2_data = dev_ctx.template Alloc<T>(
+        transpose_out_2, transpose_out_2->numel() * sizeof(T));
     auto *cache_kv_out_data =
         (cache_kv_out == nullptr)
            ? nullptr
-            : cache_kv_out->mutable_data<T>(ctx.GetPlace());
-    auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
-    auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
+            : dev_ctx.template Alloc<T>(cache_kv_out,
+                                        cache_kv_out->numel() * sizeof(T));
+    auto *qk_out_data =
+        dev_ctx.template Alloc<T>(qk_out, qk_out->numel() * sizeof(T));
+    auto *qktv_out_data =
+        dev_ctx.template Alloc<T>(qktv_out, qktv_out->numel() * sizeof(T));
     auto *src_mask_out_data =
-        (src_mask == nullptr) ? nullptr
-                              : src_mask_out->mutable_data<T>(ctx.GetPlace());
-    auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
-    auto *attn_dropout_mask_out_data =
-        attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
-    auto *attn_dropout_out_data =
-        attn_dropout_out->mutable_data<T>(ctx.GetPlace());
-    auto *fmha_out_data = fmha_out->mutable_data<T>(ctx.GetPlace());
+        (src_mask == nullptr)
+            ? nullptr
+            : dev_ctx.template Alloc<T>(src_mask_out,
+                                        src_mask_out->numel() * sizeof(T));
+    auto *softmax_out_data = dev_ctx.template Alloc<T>(
+        softmax_out, softmax_out->numel() * sizeof(T));
+    auto *attn_dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
+        attn_dropout_mask_out,
+        attn_dropout_mask_out->numel() * sizeof(uint8_t));
+    auto *attn_dropout_out_data = dev_ctx.template Alloc<T>(
+        attn_dropout_out, attn_dropout_out->numel() * sizeof(T));
+    auto *fmha_out_data =
+        dev_ctx.template Alloc<T>(fmha_out, fmha_out->numel() * sizeof(T));
     // get data ptr for out_linear.
     auto *out_linear_weight_data = out_linear_weight->data<T>();
     auto *out_linear_bias_data =
         (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
-    auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());
+    auto *out_linear_out_data = dev_ctx.template Alloc<T>(
+        out_linear_out, out_linear_out->numel() * sizeof(T));
     // get data ptr for bias+dropout+residual+layernorm
-    auto *dropout_mask_out_data =
-        dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
-    auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());
+    auto *dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
+        dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t));
+    auto *final_out_data =
+        dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
     int batch_size = input_x_dims[0];
     int max_seq_len = input_x_dims[1];
@@ -248,9 +261,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
       auto *ln_scale_data =
           (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
       auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
-      auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
-      auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
-      auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());
+      auto *ln_mean_data =
+          dev_ctx.template Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
+      auto *ln_var_data =
+          dev_ctx.template Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
+      auto *ln_out_data =
+          dev_ctx.template Alloc<T>(ln_out, ln_out->numel() * sizeof(T));
       layer_norm_compute.ComputeForward(x_data,
                                         ln_scale_data,
@@ -321,10 +337,13 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
       const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
       const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
-      T *bias_dropout_residual_out_ptr =
-          bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
-      U *ln_mean_2_ptr = ln_mean_2->mutable_data<U>(ctx.GetPlace());
-      U *ln_var_2_ptr = ln_var_2->mutable_data<U>(ctx.GetPlace());
+      T *bias_dropout_residual_out_ptr = dev_ctx.template Alloc<T>(
+          bias_dropout_residual_out,
+          bias_dropout_residual_out->numel() * sizeof(T));
+      U *ln_mean_2_ptr =
+          dev_ctx.template Alloc<U>(ln_mean_2, ln_mean_2->numel() * sizeof(U));
+      U *ln_var_2_ptr =
+          dev_ctx.template Alloc<U>(ln_var_2, ln_var_2->numel() * sizeof(U));
       // output = layernorm(residual + dropout(input + bias))
       fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
           ctx.cuda_device_context(),
@@ -352,6 +371,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     const float ln2epsilon = ctx.Attr<float>("ln_epsilon");
     float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
+    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     bool is_test_1 = ctx.Attr<bool>("is_test");
     auto &dropout_implementation_1 =
         ctx.Attr<std::string>("attn_dropout_implementation");
@@ -432,29 +452,37 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         ctx.Output<Tensor>(framework::GradVarName("OutLinearOut"));
     auto *d_bias_dropout_residual_out =
         ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
-    auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
+    auto *d_x_data = dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
     // when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the
     // space can be reused.
     auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
                                ? nullptr
-                               : d_qkv_out->mutable_data<T>(ctx.GetPlace());
+                               : dev_ctx.template Alloc<T>(
+                                     d_qkv_out, d_qkv_out->numel() * sizeof(T));
     auto *d_qkv_bias_out_data =
         (d_qkv_bias_out == nullptr)
            ? nullptr
-            : d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_transpose_out_2_data =
-        d_transpose_out_2->mutable_data<T>(ctx.GetPlace());
-    auto *d_qk_out_data = d_qk_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_attn_dropout_out_data =
-        d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
+            : dev_ctx.template Alloc<T>(d_qkv_bias_out,
+                                        d_qkv_bias_out->numel() * sizeof(T));
+    auto *d_qktv_out_data =
+        dev_ctx.template Alloc<T>(d_qktv_out, d_qktv_out->numel() * sizeof(T));
+    auto *d_transpose_out_2_data = dev_ctx.template Alloc<T>(
+        d_transpose_out_2, d_transpose_out_2->numel() * sizeof(T));
+    auto *d_qk_out_data =
+        dev_ctx.template Alloc<T>(d_qk_out, d_qk_out->numel() * sizeof(T));
+    auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
+        d_softmax_out, d_softmax_out->numel() * sizeof(T));
+    auto *d_attn_dropout_out_data = dev_ctx.template Alloc<T>(
+        d_attn_dropout_out, d_attn_dropout_out->numel() * sizeof(T));
     auto *d_src_mask_out_data =
-        (src_mask == nullptr) ? nullptr
-                              : d_src_mask_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_out_linear_out_data =
-        d_out_linear_out->mutable_data<T>(ctx.GetPlace());
+        (src_mask == nullptr)
+            ? nullptr
+            : dev_ctx.template Alloc<T>(d_src_mask_out,
+                                        d_src_mask_out->numel() * sizeof(T));
+    auto *d_fmha_out_data =
+        dev_ctx.template Alloc<T>(d_fmha_out, d_fmha_out->numel() * sizeof(T));
+    auto *d_out_linear_out_data = dev_ctx.template Alloc<T>(
+        d_out_linear_out, d_out_linear_out->numel() * sizeof(T));
     // parameter grad
     auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
@@ -466,16 +494,20 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
     auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
-    auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
-    auto *d_qkv_bias_data = (d_qkv_bias == nullptr)
+    auto *d_qkv_weight_data = dev_ctx.template Alloc<T>(
+        d_qkv_weight, d_qkv_weight->numel() * sizeof(T));
+    auto *d_qkv_bias_data =
+        (d_qkv_bias == nullptr)
            ? nullptr
-            : d_qkv_bias->mutable_data<T>(ctx.GetPlace());
-    auto *d_out_linear_weight_data =
-        d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
+            : dev_ctx.template Alloc<T>(d_qkv_bias,
+                                        d_qkv_bias->numel() * sizeof(T));
+    auto *d_out_linear_weight_data = dev_ctx.template Alloc<T>(
+        d_out_linear_weight, d_out_linear_weight->numel() * sizeof(T));
     auto *d_out_linear_bias_data =
         (d_out_linear_bias == nullptr)
            ? nullptr
-            : d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
+            : dev_ctx.template Alloc<T>(d_out_linear_bias,
+                                        d_out_linear_bias->numel() * sizeof(T));
     const auto input_x_dims = input_x->dims();
     const auto qkv_w_dims = qkv_weight->dims();
@@ -496,7 +528,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     T *d_residual_data = nullptr;
     if (add_residual) {
       d_residual.Resize(input_x_dims);
-      d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
+      d_residual_data = dev_ctx.template Alloc<T>(
+          &d_residual, d_residual.numel() * sizeof(T));
     }
     bool transA = false;
@@ -560,13 +593,16 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
       auto *d_ln_2_scale_data =
          (d_ln_2_scale == nullptr
               ? nullptr
-               : d_ln_2_scale->mutable_data<U>(ctx.GetPlace()));
+               : dev_ctx.template Alloc<U>(d_ln_2_scale,
+                                           d_ln_2_scale->numel() * sizeof(U)));
      auto *d_ln_2_bias_data =
          (d_ln_2_bias == nullptr
              ? nullptr
-               : d_ln_2_bias->mutable_data<U>(ctx.GetPlace()));
-      auto *d_bias_dropout_residual_out_data =
-          d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
+               : dev_ctx.template Alloc<U>(d_ln_2_bias,
+                                           d_ln_2_bias->numel() * sizeof(U)));
+      auto *d_bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
+          d_bias_dropout_residual_out,
+          d_bias_dropout_residual_out->numel() * sizeof(T));
      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
          ctx.cuda_device_context(),
@@ -638,13 +674,18 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
      auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
      auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
      auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
-      auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
+      auto *d_ln_out_data =
+          dev_ctx.template Alloc<T>(d_ln_out, d_ln_out->numel() * sizeof(T));
      auto *d_ln_scale_data =
-          (d_ln_scale == nullptr ? nullptr
-                                 : d_ln_scale->mutable_data<U>(ctx.GetPlace()));
+          (d_ln_scale == nullptr
+               ? nullptr
+               : dev_ctx.template Alloc<U>(d_ln_scale,
+                                           d_ln_scale->numel() * sizeof(U)));
      auto *d_ln_bias_data =
-          (d_ln_bias == nullptr ? nullptr
-                                : d_ln_bias->mutable_data<U>(ctx.GetPlace()));
+          (d_ln_bias == nullptr
+               ? nullptr
+               : dev_ctx.template Alloc<U>(d_ln_bias,
+                                           d_ln_bias->numel() * sizeof(U)));
      if (qkv_bias != nullptr) {
        qkv_compute.ComputeBackward(ln_out,
                                    qkv_weight,
...
@@ -31,6 +31,7 @@ template <typename T>
 class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     using U = LayerNormParamType<T>;
     auto *input_x = ctx.Input<Tensor>("X");
     auto *bias = ctx.Input<Tensor>("Bias");
@@ -50,12 +51,14 @@ class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel<T> {
     auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
     auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
     auto *bias_dropout_residual_out_data =
-        bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
-    auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
-    auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
-    auto *dropout_mask_out_data =
-        dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
-    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
+        dev_ctx.Alloc<T>(bias_dropout_residual_out,
+                         bias_dropout_residual_out->numel() * sizeof(T));
+    auto *ln_mean_data =
+        dev_ctx.Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
+    auto *ln_var_data = dev_ctx.Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
+    auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
+        dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t));
+    auto *y_data = dev_ctx.Alloc<T>(y, y->numel() * sizeof(T));
     const auto input_x_dims = input_x->dims();
     int bsz_seq = 1;
@@ -92,7 +95,7 @@ class FusedBiasDropoutResidualLnGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     using U = LayerNormParamType<T>;
     const float ln_epsilon = ctx.Attr<float>("ln_epsilon");
-
+    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
     auto *ln_scale = ctx.Input<Tensor>("LnScale");
     auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
@@ -114,18 +117,24 @@ class FusedBiasDropoutResidualLnGradKernel : public framework::OpKernel<T> {
         ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
     auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
     auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
-    auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
-    auto *d_residual_data = d_residual->mutable_data<T>(ctx.GetPlace());
+    auto *d_x_data = dev_ctx.Alloc<T>(d_x, d_x->numel() * sizeof(T));
+    auto *d_residual_data =
+        dev_ctx.Alloc<T>(d_residual, d_residual->numel() * sizeof(T));
     auto *d_bias_dropout_residual_out_data =
-        d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
+        dev_ctx.Alloc<T>(d_bias_dropout_residual_out,
+                         d_bias_dropout_residual_out->numel() * sizeof(T));
     auto *d_bias_data =
-        (d_bias == nullptr ? nullptr : d_bias->mutable_data<T>(ctx.GetPlace()));
+        (d_bias == nullptr
+             ? nullptr
+             : dev_ctx.Alloc<T>(d_bias, d_bias->numel() * sizeof(T)));
     auto *d_ln_scale_data =
-        (d_ln_scale == nullptr ? nullptr
-                               : d_ln_scale->mutable_data<U>(ctx.GetPlace()));
+        (d_ln_scale == nullptr
+             ? nullptr
+             : dev_ctx.Alloc<U>(d_ln_scale, d_ln_scale->numel() * sizeof(U)));
     auto *d_ln_bias_data =
-        (d_ln_bias == nullptr ? nullptr
-                              : d_ln_bias->mutable_data<U>(ctx.GetPlace()));
+        (d_ln_bias == nullptr
+             ? nullptr
+             : dev_ctx.Alloc<U>(d_ln_bias, d_ln_bias->numel() * sizeof(U)));
     const auto input_x_dims = d_y->dims();
     int bsz_seq = 1;
...
@@ -45,6 +45,7 @@ class FusedBatchNormActKernel<phi::GPUContext, T>
         platform::is_gpu_place(ctx.GetPlace()),
         true,
         platform::errors::PreconditionNotMet("It must use CUDAPlace."));
+    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
     float momentum = ctx.Attr<float>("momentum");
     std::string act_type = ctx.Attr<std::string>("act_type");
@@ -73,22 +74,26 @@ class FusedBatchNormActKernel<phi::GPUContext, T>
     // initialize them.
     auto *mean_out = ctx.Output<Tensor>("MeanOut");
     auto *variance_out = ctx.Output<Tensor>("VarianceOut");
-    mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
-    variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        mean_out, mean_out->numel() * sizeof(BatchNormParamType<T>));
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        variance_out, variance_out->numel() * sizeof(BatchNormParamType<T>));
     auto *saved_mean = ctx.Output<Tensor>("SavedMean");
     auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
-    saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
-    saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        saved_mean, saved_mean->numel() * sizeof(BatchNormParamType<T>));
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        saved_variance,
+        saved_variance->numel() * sizeof(BatchNormParamType<T>));
     auto *y = ctx.Output<Tensor>("Y");
-    y->mutable_data<T>(ctx.GetPlace());
+    dev_ctx.Alloc<T>(y, y->numel() * sizeof(T));
     int N, C, H, W, D;
     const DataLayout data_layout = DataLayout::kNHWC;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
-    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     if ((N * H * W * D) == 1) {
       // Only 1 element in normalization dimension,
       // skip the batch norm calculation, let y = act(x).
@@ -172,10 +177,17 @@ class FusedBatchNormActKernel<phi::GPUContext, T>
              /*xDesc=*/data_desc_,
              /*sizeInBytes=*/&reserve_space_size));
-      reserve_space_ptr = reserve_space->mutable_data(
-          ctx.GetPlace(), x->dtype(), reserve_space_size);
-      workspace_ptr = workspace_tensor.mutable_data(
-          ctx.GetPlace(), x->dtype(), workspace_size);
+      reserve_space->Resize({static_cast<int64_t>(
+          (reserve_space_size + experimental::SizeOf(x->dtype()) - 1) /
+          experimental::SizeOf(x->dtype()))});
+      reserve_space_ptr =
+          dev_ctx.Alloc<T>(reserve_space, reserve_space->numel() * sizeof(T));
+      workspace_tensor.Resize({static_cast<int64_t>(
+          (workspace_size + experimental::SizeOf(x->dtype()) - 1) /
+          experimental::SizeOf(x->dtype()))});
+      workspace_ptr = dev_ctx.Alloc<T>(&workspace_tensor,
+                                       workspace_tensor.numel() * sizeof(T));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
              handle,
@@ -193,15 +205,18 @@ class FusedBatchNormActKernel<phi::GPUContext, T>
              scale->template data<BatchNormParamType<T>>(),
              bias->template data<BatchNormParamType<T>>(),
              this_factor,
-              mean_out->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
-              variance_out->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
+              dev_ctx.template Alloc<BatchNormParamType<T>>(
+                  mean_out, mean_out->numel() * sizeof(BatchNormParamType<T>)),
+              dev_ctx.template Alloc<BatchNormParamType<T>>(
+                  variance_out,
+                  variance_out->numel() * sizeof(BatchNormParamType<T>)),
              epsilon,
-              saved_mean->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
-              saved_variance->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
+              dev_ctx.template Alloc<BatchNormParamType<T>>(
+                  saved_mean,
+                  saved_mean->numel() * sizeof(BatchNormParamType<T>)),
+              dev_ctx.template Alloc<BatchNormParamType<T>>(
+                  saved_variance,
+                  saved_variance->numel() * sizeof(BatchNormParamType<T>)),
              activation_desc_,
              workspace_ptr,
              workspace_size,
@@ -227,7 +242,7 @@ class FusedBatchNormActGradKernel<phi::GPUContext, T>
         platform::errors::PreconditionNotMet("It must use CUDAPlace."));
     double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
     std::string act_type = ctx.Attr<std::string>("act_type");
-
+    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     const auto *x = ctx.Input<Tensor>("X");
     const auto *y = ctx.Input<Tensor>("Y");
     const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
@@ -250,14 +265,16 @@ class FusedBatchNormActGradKernel<phi::GPUContext, T>
     auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-    d_x->mutable_data<T>(ctx.GetPlace());
+    dev_ctx.Alloc<T>(d_x, d_x->numel() * sizeof(T));
     PADDLE_ENFORCE_EQ(
         d_scale && d_bias,
         true,
         platform::errors::PreconditionNotMet(
             "Both the scale grad and the bias grad must not be null."));
-    d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
-    d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        d_scale, d_scale->numel() * sizeof(BatchNormParamType<T>));
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        d_bias, d_bias->numel() * sizeof(BatchNormParamType<T>));
     PADDLE_ENFORCE_EQ(scale->dims().size(),
                       1UL,
                       platform::errors::PreconditionNotMet(
@@ -268,7 +285,6 @@ class FusedBatchNormActGradKernel<phi::GPUContext, T>
                       platform::errors::PreconditionNotMet(
                           "The size of scale is equal to the channel of Input(X)."));
-    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     if ((N * H * W * D) == 1) {
       if (act_type == "relu") {
         auto x_v = framework::EigenVector<T>::Flatten(*x);
@@ -344,8 +360,11 @@ class FusedBatchNormActGradKernel<phi::GPUContext, T>
              /*activationDesc=*/activation_desc_,
              /*sizeInBytes=*/&workspace_size));
-      workspace_ptr = workspace_tensor.mutable_data(
-          ctx.GetPlace(), x->type(), workspace_size);
+      workspace_tensor.Resize({static_cast<int64_t>(
+          (workspace_size + experimental::SizeOf(x->dtype()) - 1) /
+          experimental::SizeOf(x->dtype()))});
+      workspace_ptr = dev_ctx.Alloc<T>(&workspace_tensor,
+                                       workspace_tensor.numel() * sizeof(T));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnBatchNormalizationBackwardEx(
@@ -365,16 +384,17 @@ class FusedBatchNormActGradKernel<phi::GPUContext, T>
              /*dzDesc=*/nullptr,
              /*dzData=*/nullptr,
              /*dxDesc=*/data_desc_,
-              /*dxData=*/d_x->template mutable_data<T>(ctx.GetPlace()),
+              /*dxData=*/
+              dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T)),
              /*dBnScaleBiasDesc=*/bn_param_desc_,
              /*bnScaleData=*/scale->template data<BatchNormParamType<T>>(),
              /*bnBiasData=*/bias->template data<BatchNormParamType<T>>(),
              /*dBnScaleData=*/
-              d_scale->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
+              dev_ctx.template Alloc<BatchNormParamType<T>>(
+                  d_scale, d_scale->numel() * sizeof(BatchNormParamType<T>)),
              /*dBnBiasData=*/
-              d_bias->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
+              dev_ctx.template Alloc<BatchNormParamType<T>>(
+                  d_bias, d_bias->numel() * sizeof(BatchNormParamType<T>)),
              /*epsilon=*/epsilon,
              /*savedMean=*/saved_mean_data,
              /*savedInvVariance=*/saved_var_data,
...
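The cuDNN workspace and reserve-space buffers used to be allocated with the untyped `mutable_data(place, dtype, size_in_bytes)` overload. Because the typed `Alloc<T>` works in elements rather than raw bytes, the batch-norm hunks above and below first resize the tensor to the number of elements that covers the requested byte count (rounding up) and only then allocate. Condensed from the diff, with `workspace_size` being the byte count reported by cuDNN:

    // ceil(workspace_size / sizeof(dtype)) elements hold at least
    // workspace_size bytes.
    workspace_tensor.Resize({static_cast<int64_t>(
        (workspace_size + experimental::SizeOf(x->dtype()) - 1) /
        experimental::SizeOf(x->dtype()))});
    workspace_ptr = dev_ctx.Alloc<T>(&workspace_tensor,
                                     workspace_tensor.numel() * sizeof(T));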
@@ -23,6 +23,7 @@
 #include "paddle/fluid/operators/norm_utils.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 DECLARE_bool(cudnn_batchnorm_spatial_persistent);
@@ -44,6 +45,7 @@ class FusedBatchNormAddActKernel<phi::GPUContext, T>
         platform::is_gpu_place(ctx.GetPlace()),
         true,
         platform::errors::PreconditionNotMet("It must use CUDAPlace."));
+    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
     float momentum = ctx.Attr<float>("momentum");
     std::string act_type = ctx.Attr<std::string>("act_type");
@@ -66,23 +68,26 @@ class FusedBatchNormAddActKernel<phi::GPUContext, T>
     auto *mean_out = ctx.Output<Tensor>("MeanOut");
     auto *variance_out = ctx.Output<Tensor>("VarianceOut");
-    mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
-    variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        mean_out, mean_out->numel() * sizeof(BatchNormParamType<T>));
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        variance_out, variance_out->numel() * sizeof(BatchNormParamType<T>));
     auto *saved_mean = ctx.Output<Tensor>("SavedMean");
     auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
-    saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
-    saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        saved_mean, saved_mean->numel() * sizeof(BatchNormParamType<T>));
+    dev_ctx.Alloc<BatchNormParamType<T>>(
+        saved_variance,
+        saved_variance->numel() * sizeof(BatchNormParamType<T>));
     auto *y = ctx.Output<Tensor>("Y");
-    y->mutable_data<T>(ctx.GetPlace());
+    dev_ctx.Alloc<T>(y, y->numel() * sizeof(T));
     int N, C, H, W, D;
     const DataLayout data_layout = DataLayout::kNHWC;
     ExtractNCWHD(in_dims, data_layout, &N, &C, &H, &W, &D);
-    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     // ------------------- cudnn descriptors ---------------------
     auto handle = dev_ctx.cudnn_handle();
     cudnnTensorDescriptor_t data_desc_;
@@ -149,10 +154,17 @@ class FusedBatchNormAddActKernel<phi::GPUContext, T>
            /*xDesc=*/data_desc_,
            /*sizeInBytes=*/&reserve_space_size));
-    reserve_space_ptr = reserve_space->mutable_data(
-        ctx.GetPlace(), x->dtype(), reserve_space_size);
-    workspace_ptr = workspace_tensor.mutable_data(
-        ctx.GetPlace(), x->dtype(), workspace_size);
+    reserve_space->Resize({static_cast<int64_t>(
+        (reserve_space_size + experimental::SizeOf(x->dtype()) - 1) /
+        experimental::SizeOf(x->dtype()))});
+    reserve_space_ptr =
+        dev_ctx.Alloc<T>(reserve_space, reserve_space->numel() * sizeof(T));
+    workspace_tensor.Resize({static_cast<int64_t>(
+        (workspace_size + experimental::SizeOf(x->dtype()) - 1) /
+        experimental::SizeOf(x->dtype()))});
+    workspace_ptr = dev_ctx.Alloc<T>(&workspace_tensor,
+                                     workspace_tensor.numel() * sizeof(T));
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
            handle,
@@ -170,15 +182,18 @@ class FusedBatchNormAddActKernel<phi::GPUContext, T>
            scale->template data<BatchNormParamType<T>>(),
            bias->template data<BatchNormParamType<T>>(),
            this_factor,
-            mean_out->template mutable_data<BatchNormParamType<T>>(
-                ctx.GetPlace()),
-            variance_out->template mutable_data<BatchNormParamType<T>>(
-                ctx.GetPlace()),
+            dev_ctx.template Alloc<BatchNormParamType<T>>(
+                mean_out, mean_out->numel() * sizeof(BatchNormParamType<T>)),
+            dev_ctx.template Alloc<BatchNormParamType<T>>(
+                variance_out,
+                variance_out->numel() * sizeof(BatchNormParamType<T>)),
            epsilon,
-            saved_mean->template mutable_data<BatchNormParamType<T>>(
-                ctx.GetPlace()),
-            saved_variance->template mutable_data<BatchNormParamType<T>>(
-                ctx.GetPlace()),
+            dev_ctx.template Alloc<BatchNormParamType<T>>(
+                saved_mean,
+                saved_mean->numel() * sizeof(BatchNormParamType<T>)),
+            dev_ctx.template Alloc<BatchNormParamType<T>>(
+                saved_variance,
+                saved_variance->numel() * sizeof(BatchNormParamType<T>)),
            activation_desc_,
            workspace_ptr,
            workspace_size,
@@ -212,6 +227,7 @@ class FusedBatchNormAddActGradKernel<phi::GPUContext, T>
     const auto *bias = ctx.Input<Tensor>("Bias");
     const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
+    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     const auto &in_dims = x->dims();
     int N, C, H, W, D;
@@ -243,8 +259,6 @@ class FusedBatchNormAddActGradKernel<phi::GPUContext, T>
                       platform::errors::PreconditionNotMet(
                           "The size of scale is equal to the channel of Input(X)."));
-    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
-
     std::vector<int> dims = {N, C, H, W, D};
     std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
     // ------------------- cudnn descriptors ---------------------
...
...@@ -57,7 +57,7 @@ static void AllReduce(framework::Tensor& tensor, // NOLINT ...@@ -57,7 +57,7 @@ static void AllReduce(framework::Tensor& tensor, // NOLINT
int64_t numel = tensor.numel(); int64_t numel = tensor.numel();
const void* sendbuff = tensor.data<T>(); const void* sendbuff = tensor.data<T>();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
void* recvbuff = tensor.mutable_data<T>(place); void* recvbuff = ctx.Alloc<T>(&tensor, tensor.numel() * sizeof(T));
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream(); auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
...@@ -125,7 +125,6 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -125,7 +125,6 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper( FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx, bsz_seq, d_model, dropout_param2, epsilon2); ctx, bsz_seq, d_model, dropout_param2, epsilon2);
auto place = ctx.GetPlace();
using U = LayerNormParamType<T>; using U = LayerNormParamType<T>;
const framework::Tensor* in = &x; const framework::Tensor* in = &x;
...@@ -158,7 +157,8 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -158,7 +157,8 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
dropout1_out->data<T>(), dropout1_out->data<T>(),
dropout1_mask->data<uint8_t>()); dropout1_mask->data<uint8_t>());
framework::Tensor linear2_out; framework::Tensor linear2_out;
linear2_out.mutable_data<T>({bsz_seq, d_model}, place); linear2_out.Resize({bsz_seq, d_model});
ctx.Alloc<T>(&linear2_out, linear2_out.numel() * sizeof(T));
MatMul(ctx, *dropout1_out, linear2_weight, &linear2_out); MatMul(ctx, *dropout1_out, linear2_weight, &linear2_out);
// tensor model parallel // tensor model parallel
...@@ -203,6 +203,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -203,6 +203,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
auto* linear2_weight = context.Input<framework::Tensor>("Linear2Weight"); auto* linear2_weight = context.Input<framework::Tensor>("Linear2Weight");
auto* linear2_bias = context.Input<framework::Tensor>("Linear2Bias"); auto* linear2_bias = context.Input<framework::Tensor>("Linear2Bias");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm"); const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
auto& dev_ctx = context.template device_context<phi::GPUContext>();
auto* ln1_scale = auto* ln1_scale =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Scale") : nullptr; pre_layer_norm ? context.Input<framework::Tensor>("Ln1Scale") : nullptr;
...@@ -245,22 +246,23 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -245,22 +246,23 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
DropoutParam dropout_param2(context, 2); DropoutParam dropout_param2(context, 2);
using U = LayerNormParamType<T>; using U = LayerNormParamType<T>;
auto place = context.GetPlace(); dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
out->mutable_data<T>(place); dev_ctx.Alloc<uint8_t>(dropout1_mask,
dropout1_mask->mutable_data<uint8_t>(place); dropout1_mask->numel() * sizeof(uint8_t));
dropout2_mask->mutable_data<uint8_t>(place); dev_ctx.Alloc<uint8_t>(dropout2_mask,
dropout2_mask->numel() * sizeof(uint8_t));
if (pre_layer_norm) { if (pre_layer_norm) {
ln1_mean->mutable_data<U>(place); dev_ctx.Alloc<U>(ln1_mean, ln1_mean->numel() * sizeof(U));
ln1_variance->mutable_data<U>(place); dev_ctx.Alloc<U>(ln1_variance, ln1_variance->numel() * sizeof(U));
ln1_out->mutable_data<T>(place); dev_ctx.Alloc<T>(ln1_out, ln1_out->numel() * sizeof(T));
} else { } else {
ln2_mean->mutable_data<U>(place); dev_ctx.Alloc<U>(ln2_mean, ln2_mean->numel() * sizeof(U));
ln2_variance->mutable_data<U>(place); dev_ctx.Alloc<U>(ln2_variance, ln2_variance->numel() * sizeof(U));
} }
linear1_out->mutable_data<T>(place); dev_ctx.Alloc<T>(linear1_out, linear1_out->numel() * sizeof(T));
dropout1_out->mutable_data<T>(place); dev_ctx.Alloc<T>(dropout1_out, dropout1_out->numel() * sizeof(T));
dropout2_out->mutable_data<T>(place); dev_ctx.Alloc<T>(dropout2_out, dropout2_out->numel() * sizeof(T));
auto x_dim = x->dims(); auto x_dim = x->dims();
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor( auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
...@@ -374,7 +376,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -374,7 +376,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper( FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx, bsz_seq, d_model, dropout_param2, epsilon2); ctx, bsz_seq, d_model, dropout_param2, epsilon2);
auto place = ctx.GetPlace();
using U = LayerNormParamType<T>; using U = LayerNormParamType<T>;
const U* ln1_gamma_ptr = const U* ln1_gamma_ptr =
ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>(); ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
...@@ -396,12 +397,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -396,12 +397,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>(); U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();
framework::Tensor d_linear2_out, d_dropout2_out, d_residual; framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place); d_linear2_out.Resize({bsz_seq, d_model});
d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place); ctx.Alloc<T>(&d_linear2_out, d_linear2_out.numel() * sizeof(T));
d_dropout2_out.Resize({bsz_seq, d_model});
ctx.Alloc<T>(&d_dropout2_out, d_dropout2_out.numel() * sizeof(T));
T* d_residual_ptr = nullptr; T* d_residual_ptr = nullptr;
if (add_residual) { if (add_residual) {
d_residual_ptr = d_residual.mutable_data<T>(d_x->dims(), place); d_residual.Resize(d_x->dims());
d_residual_ptr =
ctx.Alloc<T>(&d_residual, d_residual.numel() * sizeof(T));
} }
if (pre_layer_norm) { if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad( fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
...@@ -429,7 +434,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -429,7 +434,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
} }
framework::Tensor d_dropout1_out; framework::Tensor d_dropout1_out;
d_dropout1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place); d_dropout1_out.Resize({bsz_seq, dim_feedforward});
ctx.Alloc<T>(&d_dropout1_out, d_dropout1_out.numel() * sizeof(T));
MatMulGrad(ctx, MatMulGrad(ctx,
d_linear2_out, d_linear2_out,
dropout1_out, dropout1_out,
...@@ -438,7 +444,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -438,7 +444,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
d_linear2_weight); d_linear2_weight);
framework::Tensor d_linear1_out; framework::Tensor d_linear1_out;
d_linear1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place); d_linear1_out.Resize({bsz_seq, dim_feedforward});
ctx.Alloc<T>(&d_linear1_out, d_linear1_out.numel() * sizeof(T));
fused_act_dropout_helper.DropoutActBiasGrad(ctx, fused_act_dropout_helper.DropoutActBiasGrad(ctx,
d_dropout1_out.data<T>(), d_dropout1_out.data<T>(),
linear1_out.data<T>(), linear1_out.data<T>(),
...@@ -450,7 +457,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -450,7 +457,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
if (pre_layer_norm) { if (pre_layer_norm) {
framework::Tensor d_ln1_out; framework::Tensor d_ln1_out;
d_ln1_out.mutable_data<T>({bsz_seq, d_model}, place); d_ln1_out.Resize({bsz_seq, d_model});
ctx.Alloc<T>(&d_ln1_out, d_ln1_out.numel() * sizeof(T));
MatMulGrad(ctx, MatMulGrad(ctx,
d_linear1_out, d_linear1_out,
*ln1_out, *ln1_out,
...@@ -485,6 +493,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -485,6 +493,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
using U = LayerNormParamType<T>; using U = LayerNormParamType<T>;
auto& dev_ctx = context.template device_context<phi::GPUContext>();
auto d_out = auto d_out =
*context.Input<framework::Tensor>(framework::GradVarName("Out")); *context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto x = *context.Input<framework::Tensor>("X"); auto x = *context.Input<framework::Tensor>("X");
...@@ -550,28 +559,27 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -550,28 +559,27 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
DropoutParam dropout_param1(context, 1); DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2); DropoutParam dropout_param2(context, 2);
auto place = context.GetPlace(); dev_ctx.Alloc<T>(d_x, d_x->numel() * sizeof(T));
d_x->mutable_data<T>(place);
if (d_ln1_scale) { if (d_ln1_scale) {
d_ln1_scale->mutable_data<U>(place); dev_ctx.Alloc<U>(d_ln1_scale, d_ln1_scale->numel() * sizeof(U));
} }
if (d_ln1_bias) { if (d_ln1_bias) {
d_ln1_bias->mutable_data<U>(place); dev_ctx.Alloc<U>(d_ln1_bias, d_ln1_bias->numel() * sizeof(U));
} }
if (d_ln2_scale) { if (d_ln2_scale) {
d_ln2_scale->mutable_data<U>(place); dev_ctx.Alloc<U>(d_ln2_scale, d_ln2_scale->numel() * sizeof(U));
} }
if (d_ln2_bias) { if (d_ln2_bias) {
d_ln2_bias->mutable_data<U>(place); dev_ctx.Alloc<U>(d_ln2_bias, d_ln2_bias->numel() * sizeof(U));
} }
if (d_linear1_bias) { if (d_linear1_bias) {
d_linear1_bias->mutable_data<T>(place); dev_ctx.Alloc<T>(d_linear1_bias, d_linear1_bias->numel() * sizeof(T));
} }
if (d_linear2_bias) { if (d_linear2_bias) {
d_linear2_bias->mutable_data<T>(place); dev_ctx.Alloc<T>(d_linear2_bias, d_linear2_bias->numel() * sizeof(T));
} }
d_linear1_weight->mutable_data<T>(place); dev_ctx.Alloc<T>(d_linear1_weight, d_linear1_weight->numel() * sizeof(T));
d_linear2_weight->mutable_data<T>(place); dev_ctx.Alloc<T>(d_linear2_weight, d_linear2_weight->numel() * sizeof(T));
auto x_dim = x.dims(); auto x_dim = x.dims();
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor( auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
......
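For reference, the pattern applied throughout this patch replaces Tensor::mutable_data<T>(dims, place) with an explicit Resize followed by DeviceContext::Alloc<T>, so the requested byte count can be derived from numel(). Below is a minimal sketch of that idiom; the helper name and include paths are assumptions for illustration and are not part of the patch.

// Illustrative only: the Resize-then-Alloc idiom used across this patch.
// Include paths are assumptions based on the PaddlePaddle source layout.
#include "paddle/fluid/framework/tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"

template <typename T>
T* ResizeAndAlloc(const phi::GPUContext& dev_ctx,
                  const phi::DDim& dims,
                  paddle::framework::Tensor* t) {
  // Set the shape first so that numel() reflects the requested size,
  // then allocate numel() * sizeof(T) bytes on the device held by dev_ctx.
  t->Resize(dims);
  return dev_ctx.Alloc<T>(t, t->numel() * sizeof(T));
}

Resizing before the Alloc call matters: numel() is computed from the current dims, so reversing the two steps would request a byte count based on a stale shape.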
...@@ -47,7 +47,7 @@ template <typename T> ...@@ -47,7 +47,7 @@ template <typename T>
void AllocWithDebugInfo(const phi::GPUContext& dev_ctx, void AllocWithDebugInfo(const phi::GPUContext& dev_ctx,
const std::string& info, const std::string& info,
Tensor* t) { Tensor* t) {
t->mutable_data<T>(dev_ctx.GetPlace()); dev_ctx.Alloc<T>(t, t->numel() * sizeof(T));
VLOG(4) << info << ": " << MemoryDebugString(*t); VLOG(4) << info << ": " << MemoryDebugString(*t);
} }
...@@ -505,9 +505,12 @@ class FMHAGateRef { ...@@ -505,9 +505,12 @@ class FMHAGateRef {
k_transpose_out_grad.Resize(config->kv_transpose_out_dims); k_transpose_out_grad.Resize(config->kv_transpose_out_dims);
v_transpose_out_grad.Resize(config->kv_transpose_out_dims); v_transpose_out_grad.Resize(config->kv_transpose_out_dims);
q_grad_ptr = q_transpose_out_grad.mutable_data<T>(dev_ctx_.GetPlace()); q_grad_ptr = dev_ctx_.Alloc<T>(&q_transpose_out_grad,
k_grad_ptr = k_transpose_out_grad.mutable_data<T>(dev_ctx_.GetPlace()); q_transpose_out_grad.numel() * sizeof(T));
v_grad_ptr = v_transpose_out_grad.mutable_data<T>(dev_ctx_.GetPlace()); k_grad_ptr = dev_ctx_.Alloc<T>(&k_transpose_out_grad,
k_transpose_out_grad.numel() * sizeof(T));
v_grad_ptr = dev_ctx_.Alloc<T>(&v_transpose_out_grad,
v_transpose_out_grad.numel() * sizeof(T));
} }
Tensor softmax_out_grad; Tensor softmax_out_grad;
......
...@@ -90,7 +90,8 @@ void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx, ...@@ -90,7 +90,8 @@ void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
auto *qkv_weight = ctx.Input<Tensor>("QKVWeight"); auto *qkv_weight = ctx.Input<Tensor>("QKVWeight");
auto *qkv_weight_grad = auto *qkv_weight_grad =
ctx.Output<Tensor>(framework::GradVarName("QKVWeight")); ctx.Output<Tensor>(framework::GradVarName("QKVWeight"));
qkv_weight_grad->mutable_data<T>(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
dev_ctx.Alloc<T>(qkv_weight_grad, qkv_weight_grad->numel() * sizeof(T));
// Gradient of GEMM(query, qkv_weight) // Gradient of GEMM(query, qkv_weight)
int m = config.batch_size * config.seq_len_m * config.seq_len_r; int m = config.batch_size * config.seq_len_m * config.seq_len_r;
...@@ -160,7 +161,8 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx, ...@@ -160,7 +161,8 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
const auto *key_weight = ctx.Input<Tensor>("KeyWeight"); const auto *key_weight = ctx.Input<Tensor>("KeyWeight");
auto *key_weight_grad = auto *key_weight_grad =
ctx.Output<Tensor>(framework::GradVarName("KeyWeight")); ctx.Output<Tensor>(framework::GradVarName("KeyWeight"));
key_weight_grad->mutable_data<T>(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
dev_ctx.Alloc<T>(key_weight_grad, key_weight_grad->numel() * sizeof(T));
int kv_m = config.batch_size * config.seq_len_m * config.m_size; int kv_m = config.batch_size * config.seq_len_m * config.m_size;
int kv_n = config.num_heads * config.head_dim; int kv_n = config.num_heads * config.head_dim;
...@@ -174,7 +176,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx, ...@@ -174,7 +176,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
auto *value_weight = ctx.Input<Tensor>("ValueWeight"); auto *value_weight = ctx.Input<Tensor>("ValueWeight");
auto *value_weight_grad = auto *value_weight_grad =
ctx.Output<Tensor>(framework::GradVarName("ValueWeight")); ctx.Output<Tensor>(framework::GradVarName("ValueWeight"));
value_weight_grad->mutable_data<T>(ctx.GetPlace()); dev_ctx.Alloc<T>(value_weight_grad, value_weight_grad->numel() * sizeof(T));
kv_compute.ComputeBackward(key, kv_compute.ComputeBackward(key,
value_weight, value_weight,
...@@ -188,7 +190,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx, ...@@ -188,7 +190,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
const auto *query_weight = ctx.Input<Tensor>("QueryWeight"); const auto *query_weight = ctx.Input<Tensor>("QueryWeight");
auto *query_weight_grad = auto *query_weight_grad =
ctx.Output<Tensor>(framework::GradVarName("QueryWeight")); ctx.Output<Tensor>(framework::GradVarName("QueryWeight"));
query_weight_grad->mutable_data<T>(ctx.GetPlace()); dev_ctx.Alloc<T>(query_weight_grad, query_weight_grad->numel() * sizeof(T));
int q_m = config.batch_size * config.seq_len_m * config.seq_len_r; int q_m = config.batch_size * config.seq_len_m * config.seq_len_r;
int q_n = config.num_heads * config.head_dim; int q_n = config.num_heads * config.head_dim;
...@@ -242,11 +244,11 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, ...@@ -242,11 +244,11 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
Tensor *fmha_out_grad) { Tensor *fmha_out_grad) {
const auto *gate_weight = ctx.Input<Tensor>("GateWeight"); const auto *gate_weight = ctx.Input<Tensor>("GateWeight");
const auto *gate_bias = ctx.Input<Tensor>("GateBias"); const auto *gate_bias = ctx.Input<Tensor>("GateBias");
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
// Re-compute gate_bias_out // Re-compute gate_bias_out
Tensor gate_bias_out; Tensor gate_bias_out;
gate_bias_out.Resize(config.gate_out_dims); gate_bias_out.Resize(config.gate_out_dims);
gate_bias_out.mutable_data<T>(ctx.GetPlace()); dev_ctx.Alloc<T>(&gate_bias_out, gate_bias_out.numel() * sizeof(T));
int m = config.batch_size * config.seq_len_m * config.seq_len_r; int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.num_heads * config.head_dim; int n = config.num_heads * config.head_dim;
...@@ -267,8 +269,8 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, ...@@ -267,8 +269,8 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
auto *gate_weight_grad = auto *gate_weight_grad =
ctx.Output<Tensor>(framework::GradVarName("GateWeight")); ctx.Output<Tensor>(framework::GradVarName("GateWeight"));
auto *gate_bias_grad = ctx.Output<Tensor>(framework::GradVarName("GateBias")); auto *gate_bias_grad = ctx.Output<Tensor>(framework::GradVarName("GateBias"));
gate_weight_grad->mutable_data<T>(ctx.GetPlace()); dev_ctx.Alloc<T>(gate_weight_grad, gate_weight_grad->numel() * sizeof(T));
gate_bias_grad->mutable_data<T>(ctx.GetPlace()); dev_ctx.Alloc<T>(gate_bias_grad, gate_bias_grad->numel() * sizeof(T));
gate_attn_compute.ComputeBackward(query, gate_attn_compute.ComputeBackward(query,
gate_weight, gate_weight,
...@@ -301,6 +303,7 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, ...@@ -301,6 +303,7 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config, const GateAttentionGradConfig<T> &config,
const Tensor *input, const Tensor *input,
Tensor *input_grad) { Tensor *input_grad) {
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
const auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); const auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight"); const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight");
...@@ -309,8 +312,10 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, ...@@ -309,8 +312,10 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
auto *out_linear_bias_grad = auto *out_linear_bias_grad =
ctx.Output<Tensor>(framework::GradVarName("OutLinearBias")); ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
out_linear_weight_grad->mutable_data<T>(ctx.GetPlace()); dev_ctx.Alloc<T>(out_linear_weight_grad,
out_linear_bias_grad->mutable_data<T>(ctx.GetPlace()); out_linear_weight_grad->numel() * sizeof(T));
dev_ctx.Alloc<T>(out_linear_bias_grad,
out_linear_bias_grad->numel() * sizeof(T));
int m = config.batch_size * config.seq_len_m * config.seq_len_r; int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.q_dim; int n = config.q_dim;
......
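The gate-attention helpers above now fetch the GPU device context from the ExecutionContext before allocating their gradient outputs. A hedged sketch of that access pattern follows; the function name is illustrative and the include path is an assumption.

// Illustrative sketch: obtain the device context inside an
// ExecutionContext-based helper and allocate a gradient tensor with it.
#include "paddle/fluid/framework/operator.h"

template <typename T>
void AllocGradOutput(const paddle::framework::ExecutionContext& ctx,
                     paddle::framework::Tensor* grad) {
  // device_context<phi::GPUContext>() returns a const reference on which
  // the typed Alloc<T> can be called directly.
  auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
  dev_ctx.Alloc<T>(grad, grad->numel() * sizeof(T));
}

Hoisting the device_context lookup to the top of each helper, as the diff does, avoids repeating the lookup before every allocation in the same function.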
...@@ -46,7 +46,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -46,7 +46,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
<< " , activation = " << activation; << " , activation = " << activation;
bool enable_auxiliary = reserve_space == nullptr ? false : true; bool enable_auxiliary = reserve_space == nullptr ? false : true;
out->mutable_data<T>(ctx.GetPlace()); dev_ctx->Alloc<T>(out, out->numel() * sizeof(T));
auto* out_data = out->data<T>(); auto* out_data = out->data<T>();
auto x_mat_dims = auto x_mat_dims =
...@@ -110,8 +110,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -110,8 +110,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
} else { } else {
reserve_space_size = phi::product(out->dims()) * sizeof(T); reserve_space_size = phi::product(out->dims()) * sizeof(T);
} }
reserve_space->mutable_data( dev_ctx->Alloc(reserve_space, out->type(), reserve_space_size);
ctx.GetPlace(), out->type(), reserve_space_size);
void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>()); void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
...@@ -493,7 +492,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -493,7 +492,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
workspace_size, workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace()); auto* dx_data = dev_ctx->Alloc<T>(dx, dx->numel() * sizeof(T));
const auto* y_data = y->data<T>(); const auto* y_data = y->data<T>();
const auto* dout_data = dout->data<T>(); const auto* dout_data = dout->data<T>();
const auto* a_data = kXGradAIsDZ ? dout_data : y_data; const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
...@@ -601,7 +600,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -601,7 +600,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
sizeof(epiloque_func_for_dy))); sizeof(epiloque_func_for_dy)));
if (dbias) { if (dbias) {
auto* dbias_data = dbias->mutable_data<T>(ctx.GetPlace()); auto* dbias_data = dev_ctx->Alloc<T>(dbias, dbias->numel() * sizeof(T));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, dy_operation_desc,
...@@ -614,7 +613,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -614,7 +613,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
dev_ctx.GetPlace(), dev_ctx.GetPlace(),
workspace_size, workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace()); auto* dy_data = dev_ctx->Alloc<T>(dy, dy->numel() * sizeof(T));
const auto* dout_data = dout->data<T>(); const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>(); const auto* x_data = x->data<T>();
const auto* a_data = kYGradAIsDZ ? dout_data : x_data; const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
......
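When the element type is only known at run time, as with reserve_space in fused_gemm_epilogue_op above, the patch uses the untyped Alloc overload that takes the data type and a byte count instead of a template argument. A minimal sketch follows, assuming the overload signature matches the call site in this diff; the helper name is illustrative.

// Illustrative: allocate a scratch tensor whose dtype follows another
// tensor, mirroring the Alloc(tensor, dtype, bytes) call seen above.
void AllocScratchLike(const phi::GPUContext& dev_ctx,
                      const paddle::framework::Tensor& ref,
                      size_t size_in_bytes,
                      paddle::framework::Tensor* scratch) {
  // The dtype is taken from the reference tensor; the size is passed in
  // bytes rather than derived from numel() and sizeof(T).
  dev_ctx.Alloc(scratch, ref.type(), size_in_bytes);
}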
...@@ -70,7 +70,7 @@ static void AllReduce(framework::Tensor &tensor, // NOLINT ...@@ -70,7 +70,7 @@ static void AllReduce(framework::Tensor &tensor, // NOLINT
int64_t numel = tensor.numel(); int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>(); const void *sendbuff = tensor.data<T>();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
void *recvbuff = tensor.mutable_data<T>(place); void *recvbuff = ctx.Alloc<T>(&tensor, tensor.numel() * sizeof(T));
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream(); auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
...@@ -1161,7 +1161,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -1161,7 +1161,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>; using U = LayerNormParamType<T>;
auto place = ctx.GetPlace();
auto &dev_ctx = ctx.cuda_device_context(); auto &dev_ctx = ctx.cuda_device_context();
auto *time_step = ctx.Input<Tensor>("TimeStep"); auto *time_step = ctx.Input<Tensor>("TimeStep");
...@@ -1181,8 +1180,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -1181,8 +1180,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed); auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
Tensor ln_mean, ln_var; Tensor ln_mean, ln_var;
auto *ln_mean_data = ln_mean.mutable_data<U>({bsz_seq}, place); ln_mean.Resize({{bsz_seq}});
auto *ln_var_data = ln_var.mutable_data<U>({bsz_seq}, place); auto *ln_mean_data =
dev_ctx.Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
ln_var.Resize({{bsz_seq}});
auto *ln_var_data = dev_ctx.Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));
// 2. qkv // 2. qkv
// x: qkv's input [batch_size, seq_len, dim_embed] // x: qkv's input [batch_size, seq_len, dim_embed]
...@@ -1207,8 +1209,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -1207,8 +1209,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
input_size, input_size,
compute_bias); compute_bias);
Tensor qkv_out; Tensor qkv_out;
qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}});
auto *qkv_out_data = auto *qkv_out_data =
qkv_out.mutable_data<T>({bsz, seq_len, 3, num_head, dim_head}, place); dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
// 3. fmha // 3. fmha
AttnDropoutParam attn_param( AttnDropoutParam attn_param(
...@@ -1243,26 +1246,32 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -1243,26 +1246,32 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
} }
Tensor transpose_out_2, qk_out; Tensor transpose_out_2, qk_out;
auto *transpose_out_2_data = transpose_out_2.mutable_data<T>( transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}});
{3, bsz, num_head, seq_len, dim_head}, place); auto *transpose_out_2_data =
auto *qk_out_data = dev_ctx.Alloc<T>(&transpose_out_2, transpose_out_2.numel() * sizeof(T));
qk_out.mutable_data<T>({bsz, num_head, seq_len, out_seq_len}, place); qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));
Tensor softmax_out; Tensor softmax_out;
Tensor attn_dropout_mask_out, attn_dropout_out; Tensor attn_dropout_mask_out, attn_dropout_out;
Tensor qktv_out, fmha_out; Tensor qktv_out, fmha_out;
auto *softmax_out_data = softmax_out.mutable_data<T>( softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
{bsz, num_head, seq_len, out_seq_len}, place); auto *softmax_out_data =
dev_ctx.Alloc<T>(&softmax_out, softmax_out.numel() * sizeof(T));
auto *attn_dropout_mask_out_data = attn_dropout_mask_out.mutable_data<T>(
{bsz, num_head, seq_len, out_seq_len}, place); attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *attn_dropout_data_data = attn_dropout_out.mutable_data<T>( auto *attn_dropout_mask_out_data = dev_ctx.Alloc<T>(
{bsz, num_head, seq_len, out_seq_len}, place); &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T));
attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *attn_dropout_data_data = dev_ctx.Alloc<T>(
&attn_dropout_out, attn_dropout_out.numel() * sizeof(T));
qktv_out.Resize({{bsz, num_head, seq_len, dim_head}});
auto *qktv_out_data = auto *qktv_out_data =
qktv_out.mutable_data<T>({bsz, num_head, seq_len, dim_head}, place); dev_ctx.Alloc<T>(&qktv_out, qktv_out.numel() * sizeof(T));
fmha_out.Resize({{bsz, seq_len, num_head, dim_head}});
auto *fmha_out_data = auto *fmha_out_data =
fmha_out.mutable_data<T>({bsz, seq_len, num_head, dim_head}, place); dev_ctx.Alloc<T>(&fmha_out, fmha_out.numel() * sizeof(T));
// 4. out_linear // 4. out_linear
auto out_linear_weights = ctx.MultiInput<Tensor>("OutLinearW"); auto out_linear_weights = ctx.MultiInput<Tensor>("OutLinearW");
...@@ -1281,12 +1290,14 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -1281,12 +1290,14 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
Tensor bias_dropout_residual_out, dropout_mask_out; Tensor bias_dropout_residual_out, dropout_mask_out;
T *bias_dropout_residual_out_data = nullptr; T *bias_dropout_residual_out_data = nullptr;
if (pre_layer_norm) { if (pre_layer_norm) {
bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}});
bias_dropout_residual_out_data = bias_dropout_residual_out_data =
bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed}, dev_ctx.Alloc<T>(&bias_dropout_residual_out,
place); bias_dropout_residual_out.numel() * sizeof(T));
} }
auto *dropout_mask_out_data = dropout_mask_out.mutable_data<uint8_t>( dropout_mask_out.Resize({{bsz, seq_len, dim_embed}});
{bsz, seq_len, dim_embed}, place); auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
&dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));
// 6. ffn matmul1 // 6. ffn matmul1
auto ffn1_weights = ctx.MultiInput<Tensor>("FFN1Weight"); auto ffn1_weights = ctx.MultiInput<Tensor>("FFN1Weight");
...@@ -1297,17 +1308,21 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -1297,17 +1308,21 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ffn1_linear_compute = AttnMatMul<T>( auto ffn1_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_ffn, dim_embed, false); dev_ctx, false, false, bsz_seq, dim_ffn, dim_embed, false);
Tensor ffn1_out; Tensor ffn1_out;
auto *ffn1_out_data = ffn1_out.mutable_data<T>({bsz_seq, dim_ffn}, place); ffn1_out.Resize({{bsz_seq, dim_ffn}});
auto *ffn1_out_data =
dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));
// 7. ffn act + bias // 7. ffn act + bias
DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper( FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param);
Tensor ffn1_dropout_out, ffn1_dropout_mask; Tensor ffn1_dropout_out, ffn1_dropout_mask;
auto *ffn1_dropout_out_data = ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}});
ffn1_dropout_out.mutable_data<T>({bsz_seq, dim_ffn}, place); auto *ffn1_dropout_out_data = dev_ctx.Alloc<T>(
auto *ffn1_dropout_mask_data = &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T));
ffn1_dropout_mask.mutable_data<uint8_t>({bsz_seq, dim_ffn}, place); ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}});
auto *ffn1_dropout_mask_data = dev_ctx.Alloc<uint8_t>(
&ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t));
// 8. ffn2 matmul // 8. ffn2 matmul
auto ffn2_weights = ctx.MultiInput<Tensor>("FFN2Weight"); auto ffn2_weights = ctx.MultiInput<Tensor>("FFN2Weight");
...@@ -1322,11 +1337,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -1322,11 +1337,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// calc // calc
auto *out = ctx.Output<Tensor>("Out"); auto *out = ctx.Output<Tensor>("Out");
auto *from_data = out->mutable_data<T>(place); auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
Tensor *from_tensor = out; Tensor *from_tensor = out;
Tensor tmp_out; Tensor tmp_out;
tmp_out.Resize({{bsz, seq_len, dim_embed}});
auto *tmp_out_data = auto *tmp_out_data =
tmp_out.mutable_data<T>({bsz, seq_len, dim_embed}, place); dev_ctx.Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));
auto *x_data = input_x->data<T>(); auto *x_data = input_x->data<T>();
Tensor *buf0 = nullptr; Tensor *buf0 = nullptr;
......
...@@ -426,7 +426,7 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> { ...@@ -426,7 +426,7 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto inputs = ctx.MultiInput<LoDTensor>("X"); auto inputs = ctx.MultiInput<LoDTensor>("X");
auto outputs = ctx.MultiOutput<framework::Tensor>("Out"); auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
const auto slot_size = inputs.size(); const auto slot_size = inputs.size();
std::vector<const float *> input_data(slot_size); std::vector<const float *> input_data(slot_size);
std::vector<const size_t *> lods_data(slot_size); std::vector<const size_t *> lods_data(slot_size);
...@@ -478,13 +478,13 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> { ...@@ -478,13 +478,13 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
} else { } else {
output->Resize({batch_size, embedding_size - cvm_offset}); output->Resize({batch_size, embedding_size - cvm_offset});
} }
output_data[i] = output_data[i] = reinterpret_cast<T *>(
reinterpret_cast<T *>(output->mutable_data<T>(ctx.GetPlace())); dev_ctx.Alloc<T>(output, output->numel() * sizeof(T)));
mix_lods_v[i] = new paddle::framework::MixVector<size_t>(&lods); mix_lods_v[i] = new paddle::framework::MixVector<size_t>(&lods);
lods_data[i] = mix_lods_v[i]->CUDAData(ctx.GetPlace()); lods_data[i] = mix_lods_v[i]->CUDAData(ctx.GetPlace());
seqpool_output_data[i] = seqpool_outputs[i].Resize({batch_size, embedding_size});
reinterpret_cast<T *>(seqpool_outputs[i].mutable_data<T>( seqpool_output_data[i] = reinterpret_cast<T *>(dev_ctx.Alloc<T>(
{batch_size, embedding_size}, ctx.GetPlace())); &seqpool_outputs[i], seqpool_outputs[i].numel() * sizeof(T)));
} }
FusedSeqpoolCVM(ctx, FusedSeqpoolCVM(ctx,
...@@ -512,7 +512,7 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> { ...@@ -512,7 +512,7 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> {
auto out_grads = ctx.MultiInput<LoDTensor>(framework::GradVarName("Out")); auto out_grads = ctx.MultiInput<LoDTensor>(framework::GradVarName("Out"));
auto in_grads = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X")); auto in_grads = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
auto *cvm = ctx.Input<LoDTensor>("CVM"); auto *cvm = ctx.Input<LoDTensor>("CVM");
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
std::string pooltype = ctx.Attr<std::string>("pooltype"); std::string pooltype = ctx.Attr<std::string>("pooltype");
auto use_cvm = ctx.Attr<bool>("use_cvm"); auto use_cvm = ctx.Attr<bool>("use_cvm");
const int cvm_offset = ctx.Attr<int>("cvm_offset"); const int cvm_offset = ctx.Attr<int>("cvm_offset");
...@@ -559,8 +559,8 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> { ...@@ -559,8 +559,8 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> {
auto *out_grad = out_grads[i]; auto *out_grad = out_grads[i];
out_grads_data[i] = reinterpret_cast<const T *>(out_grad->data<T>()); out_grads_data[i] = reinterpret_cast<const T *>(out_grad->data<T>());
in_grads_data[i] = in_grads_data[i] = reinterpret_cast<T *>(
reinterpret_cast<T *>(in_grad->mutable_data<T>(ctx.GetPlace())); dev_ctx.Alloc<T>(in_grad, in_grad->numel() * sizeof(T)));
mix_lods_v[i] = new paddle::framework::MixVector<size_t>(&lods); mix_lods_v[i] = new paddle::framework::MixVector<size_t>(&lods);
lods_data[i] = mix_lods_v[i]->CUDAData(ctx.GetPlace()); lods_data[i] = mix_lods_v[i]->CUDAData(ctx.GetPlace());
cvm_data[i] = reinterpret_cast<const T *>(cvm->data<T>()); cvm_data[i] = reinterpret_cast<const T *>(cvm->data<T>());
......
...@@ -55,8 +55,10 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> { ...@@ -55,8 +55,10 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB")); static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = dev_ctx.Alloc<T>(output, output->numel() * sizeof(T));
T* temp_data = temp_outs[0]->mutable_data<T>(input->dims(), ctx.GetPlace()); temp_outs[0]->Resize(input->dims());
T* temp_data =
dev_ctx.Alloc<T>(temp_outs[0], temp_outs[0]->numel() * sizeof(T));
DataLayout layout = DataLayout::kNCHW; DataLayout layout = DataLayout::kNCHW;
std::vector<int> in_dim = phi::vectorize<int>(input->dims()); std::vector<int> in_dim = phi::vectorize<int>(input->dims());
...@@ -254,8 +256,9 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> { ...@@ -254,8 +256,9 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
in_datas.push_back(static_cast<const void*>(input_data)); in_datas.push_back(static_cast<const void*>(input_data));
in_datas.push_back( in_datas.push_back(
static_cast<const void*>(output_data + (oc0 + oc1) * h * w)); static_cast<const void*>(output_data + (oc0 + oc1) * h * w));
T* temp2_data = temp_outs[1]->mutable_data<T>(phi::make_ddim(out_dims[2]), temp_outs[1]->Resize(phi::make_ddim(out_dims[2]));
ctx.GetPlace()); T* temp2_data =
dev_ctx.Alloc<T>(temp_outs[1], temp_outs[1]->numel() * sizeof(T));
in_datas.push_back(static_cast<const void*>(temp2_data + oc2 * h * w)); in_datas.push_back(static_cast<const void*>(temp2_data + oc2 * h * w));
std::vector<void*> out_datas; std::vector<void*> out_datas;
......