From 3e2dec5b837397d2e8ecc006e302512c26adba9c Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Tue, 12 Oct 2021 21:46:37 +0800 Subject: [PATCH] Change the input param of fusion op interface from pointer to tensor (#36349) --- .../operators/fused/cudnn_bn_add_relu_test.cc | 64 ++++--------- .../fused/cudnn_bn_stats_finalize.cu.h | 24 +++-- .../operators/fused/cudnn_norm_conv.cu.h | 94 +++++++++++++++---- .../operators/fused/cudnn_norm_conv_test.cc | 61 ++++-------- .../fused/cudnn_scale_bias_add_relu.cu.h | 40 ++++++-- 5 files changed, 161 insertions(+), 122 deletions(-) diff --git a/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc index 837bca6c2c..709d69214c 100644 --- a/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc +++ b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc @@ -536,32 +536,20 @@ class CudnnBNAddReluTester { bn_bias->Resize({1, 1, 1, channels_}); // input - float *sum_ptr = sum->data(); - float *sum_of_square_ptr = sum_of_square->data(); - float *bn_scale_ptr = bn_scale->data(); - float *bn_bias_ptr = bn_bias->data(); - mean->Resize({1, 1, 1, channels_}); var->Resize({1, 1, 1, channels_}); // output - float *mean_ptr = mean->data(); - float *var_ptr = var->data(); - float *saved_mean_ptr = - saved_mean->mutable_data({1, 1, 1, channels_}, place); - float *saved_var_ptr = - saved_var->mutable_data({1, 1, 1, channels_}, place); - T *equiv_scale_ptr = - equiv_scale->mutable_data({1, 1, 1, channels_}, place); - T *equiv_bias_ptr = - equiv_bias->mutable_data({1, 1, 1, channels_}, place); + equiv_scale->Resize({1, 1, 1, channels_}); + equiv_bias->Resize({1, 1, 1, channels_}); + saved_mean->Resize({1, 1, 1, channels_}); + saved_var->Resize({1, 1, 1, channels_}); auto param_shape = framework::vectorize(bn_scale->dims()); op::CudnnBNStatsFinalize bn_op(ctx, param_shape); - bn_op.Forward(ctx, sum_ptr, sum_of_square_ptr, bn_scale_ptr, bn_bias_ptr, - saved_mean_ptr, saved_var_ptr, mean_ptr, var_ptr, - equiv_scale_ptr, equiv_bias_ptr, eps_, momentum_, ele_count_, - true); + bn_op.Forward(ctx, *sum, *sum_of_square, *bn_scale, *bn_bias, saved_mean, + saved_var, mean, var, equiv_scale, equiv_bias, eps_, + momentum_, ele_count_, true); } // Get forward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu @@ -627,21 +615,13 @@ class CudnnBNAddReluTester { &saved_var_z, &equiv_scale_z, &equiv_bias_z); } - T *x_ptr = x.data(); - T *z_ptr = (fuse_add_ || has_shortcut_) ? z.data() : nullptr; - T *equiv_scale_x_ptr = equiv_scale_x.data(); - T *equiv_bias_x_ptr = equiv_bias_x.data(); - T *equiv_scale_z_ptr = has_shortcut_ ? equiv_scale_z.data() : nullptr; - T *equiv_bias_z_ptr = has_shortcut_ ? equiv_bias_z.data() : nullptr; - T *y_ptr = - y.mutable_data({batch_size_, height_, width_, channels_}, place); + y.Resize(framework::make_ddim({batch_size_, height_, width_, channels_})); int c = channels_; int64_t nhw = ele_count_; int32_t c_int32_elems = ((c + 63) & ~63) / 32; int32_t nhw_int32_elems = (nhw + 31) & ~31; - int32_t *bitmask_ptr = bitmask.mutable_data( - {nhw_int32_elems, c_int32_elems, 1}, place); + bitmask.Resize(framework::make_ddim({nhw_int32_elems, c_int32_elems, 1})); auto data_shape = framework::vectorize(x.dims()); auto param_shape = framework::vectorize(bn_scale_x.dims()); @@ -651,8 +631,8 @@ class CudnnBNAddReluTester { op::CudnnScaleBiasAddRelu sbar_op(ctx, act_type_, fuse_add_, has_shortcut_, data_shape, param_shape, bitmask_shape); - sbar_op.Forward(ctx, x_ptr, equiv_scale_x_ptr, equiv_bias_x_ptr, y_ptr, - bitmask_ptr, z_ptr, equiv_scale_z_ptr, equiv_bias_z_ptr); + sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, z, equiv_scale_z, + equiv_bias_z, &y, &bitmask); TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x); TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x); @@ -697,19 +677,10 @@ class CudnnBNAddReluTester { saved_mean.Resize({1, 1, 1, channels_}); saved_var.Resize({1, 1, 1, channels_}); - T *dy_ptr = dy.data(); - T *x_ptr = x.data(); - float *bn_scale_ptr = bn_scale.data(); - float *bn_bias_ptr = bn_bias.data(); - float *saved_mean_ptr = saved_mean.data(); - float *saved_var_ptr = saved_var.data(); - int32_t *bitmask_ptr = bitmask.data(); - T *dx_ptr = - dx.mutable_data({batch_size_, height_, width_, channels_}, place); - T *dz_ptr = - dz.mutable_data({batch_size_, height_, width_, channels_}, place); - float *dscale_ptr = dscale.mutable_data({1, 1, 1, channels_}, place); - float *dbias_ptr = dbias.mutable_data({1, 1, 1, channels_}, place); + dx.Resize(framework::make_ddim({batch_size_, height_, width_, channels_})); + dz.Resize(framework::make_ddim({batch_size_, height_, width_, channels_})); + dscale.Resize(framework::make_ddim({1, 1, 1, channels_})); + dbias.Resize(framework::make_ddim({1, 1, 1, channels_})); auto data_shape = framework::vectorize(x.dims()); auto param_shape = framework::vectorize(bn_scale.dims()); @@ -718,9 +689,8 @@ class CudnnBNAddReluTester { std::string act_type = "relu"; op::CudnnScaleBiasAddRelu sbar_op(ctx, act_type, true, false, data_shape, param_shape, bitmask_shape); - sbar_op.Backward(ctx, dy_ptr, x_ptr, bn_scale_ptr, bn_bias_ptr, - saved_mean_ptr, saved_var_ptr, bitmask_ptr, dx_ptr, dz_ptr, - dscale_ptr, dbias_ptr, eps_); + sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var, + bitmask, &dx, &dz, &dscale, &dbias, eps_); TensorCopySync(dx, platform::CPUPlace(), cpu_dx); TensorCopySync(dz, platform::CPUPlace(), cpu_dz); diff --git a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h index 7d4b24cd4f..dc703f9a82 100644 --- a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h +++ b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h @@ -68,12 +68,13 @@ class CudnnBNStatsFinalize { } ~CudnnBNStatsFinalize() {} - void Forward(const platform::CUDADeviceContext &ctx, float *sum_ptr, - float *sum_of_squares_ptr, float *scale_ptr, float *bias_ptr, - float *saved_mean_ptr, float *saved_invstd_ptr, - float *running_mean_ptr, float *running_var_ptr, - T *equiv_scale_ptr, T *equiv_bias_ptr, double eps, - float momentum, int64_t ele_count, bool is_train) { + void Forward(const platform::CUDADeviceContext &ctx, const Tensor &sum, + const Tensor &sum_of_squares, const Tensor &scale, + const Tensor &bias, Tensor *saved_mean, Tensor *saved_invstd, + Tensor *running_mean, Tensor *running_var, Tensor *equiv_scale, + Tensor *equiv_bias, double eps, float momentum, + int64_t ele_count, bool is_train) { + auto place = ctx.GetPlace(); if (is_train) { TrainInit(ctx); } else { @@ -82,6 +83,17 @@ class CudnnBNStatsFinalize { auto &op = is_train ? train_op_ : inference_op_; // Set variant_param for both inference_op_ and train_op_ + float *sum_ptr = const_cast(sum.data()); + float *sum_of_squares_ptr = + const_cast(sum_of_squares.data()); + float *scale_ptr = const_cast(scale.data()); + float *bias_ptr = const_cast(bias.data()); + float *saved_mean_ptr = saved_mean->mutable_data(place); + float *saved_invstd_ptr = saved_invstd->mutable_data(place); + float *running_mean_ptr = running_mean->mutable_data(place); + float *running_var_ptr = running_var->mutable_data(place); + T *equiv_scale_ptr = equiv_scale->mutable_data(place); + T *equiv_bias_ptr = equiv_bias->mutable_data(place); op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr); op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr); op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr); diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h index 1a73281cb8..9b9328a5ca 100644 --- a/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h +++ b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h @@ -38,7 +38,8 @@ struct NormConvolutionArgs { compute_type = platform::CudnnDataType::type; } - void Set(const std::vector &input_shape, + void Set(const platform::CUDADeviceContext &ctx, + const std::vector &input_shape, const std::vector &filter_shape, const std::vector &output_shape, int padding, int stride, int dilation, int group) { @@ -61,12 +62,33 @@ struct NormConvolutionArgs { "The filter_shape is expected to store as nhwc, and " "h = w = 1 or 3. But recieved filter_shape is [%s].", framework::make_ddim(filter_shape))); + PADDLE_ENFORCE_EQ((filter_shape[0] % 32 == 0 && filter_shape[3] % 8 == 0), + true, + platform::errors::InvalidArgument( + "The input channel is expected to be multiple of 8, " + "and the output channel is expected to be multiple " + "of 32. But recieved input channel is %d, output " + "channel is %d.", + filter_shape[3], filter_shape[0])); PADDLE_ENFORCE_EQ( output_shape.size(), 4U, platform::errors::InvalidArgument( "The size of output_shape is expected to 4. But recieved " "filter_shape's size is %d, filter_shape is [%s].", output_shape.size(), framework::make_ddim(output_shape))); + is_support = IsSupport(ctx, filter_shape, stride, dilation, group); + PADDLE_ENFORCE_EQ( + is_support, true, + platform::errors::InvalidArgument( + "Current test is only supported in the platforms with " + "compatiblity greater than or equal to 70 and the kernel size " + "must be equal to 1 or 3. When the kernel size is 1, " + "the stride must be 1 if the compatiblity is equal to 70. " + "Besides, the dilation and group must be equal to 1. But recieved " + "compatiblity is %d, kernel size is %d, stride is %d, " + "dilation is %d, group is %d", + ctx.GetComputeCapability(), filter_shape[1], stride, dilation, + group)); for (size_t i = 0; i < input_shape.size(); ++i) { in_dims.push_back(input_shape[i]); @@ -89,6 +111,25 @@ struct NormConvolutionArgs { conv_desc.set(dtype, paddings, strides, dilations, false, group); } + bool IsSupport(const platform::CUDADeviceContext &ctx, + const std::vector &filter_shape, int stride, int dilation, + int group) { + int kernel_size = filter_shape[1]; + if (dilation != 1 || group != 1) { + return false; + } + if (ctx.GetComputeCapability() == 70) { + if ((kernel_size == 3) || ((kernel_size == 1) && (stride == 1))) { + return true; + } + } else if (ctx.GetComputeCapability() > 70) { + if ((kernel_size == 3) || (kernel_size == 1)) { + return true; + } + } + return false; + } + cudnnDataType_t dtype; cudnnTensorFormat_t format; cudnnDataType_t compute_type; @@ -104,6 +145,8 @@ struct NormConvolutionArgs { platform::TensorDescriptor out_desc; platform::TensorDescriptor out_stats_desc; platform::ConvolutionDescriptor conv_desc; + + bool is_support; }; template @@ -115,15 +158,16 @@ class CudnnNormConvolution { const std::vector &output_shape, const int &padding, const int &stride, const int &dilation, const int &group) { - args_.Set(input_shape, filter_shape, output_shape, padding, stride, + args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride, dilation, group); } ~CudnnNormConvolution() {} - void Forward(const platform::CUDADeviceContext &ctx, T *input_ptr, - T *filter_ptr, T *output_ptr, float *sum_ptr, - float *sum_of_squares_ptr) { + void Forward(const platform::CUDADeviceContext &ctx, const Tensor &input, + const Tensor &filter, Tensor *output, Tensor *sum, + Tensor *sum_of_squares) { auto cudnn_handle = ctx.cudnn_handle(); + auto place = ctx.GetPlace(); CudnnFusionOp *fwd_op = GetForwardOp(ctx); size_t workspace_size = RoundUp( @@ -132,12 +176,17 @@ class CudnnNormConvolution { // Set variant_param // input ptr + T *input_ptr = const_cast(input.data()); + T *filter_ptr = const_cast(filter.data()); fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr); fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr); fwd_op->SetOpVariantParamAttrPtr( CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size); // output ptr + T *output_ptr = output->mutable_data(place); + float *sum_ptr = sum->mutable_data(place); + float *sum_of_squares_ptr = sum_of_squares->mutable_data(place); fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr); fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr); fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr); @@ -209,28 +258,34 @@ class CudnnNormConvolutionGrad { const std::vector &output_shape, const int &padding, const int &stride, const int &dilation, const int &group) { - args_.Set(input_shape, filter_shape, output_shape, padding, stride, + args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride, dilation, group); dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; } ~CudnnNormConvolutionGrad() {} - void Backward(const platform::CUDADeviceContext &ctx, T *input_ptr, - T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr, - T *filter_grad_ptr, bool use_addto = false) { - if (filter_grad_ptr) { - BackwardFilter(ctx, input_ptr, output_grad_ptr, filter_ptr, - filter_grad_ptr); + void Backward(const platform::CUDADeviceContext &ctx, const Tensor &input, + const Tensor &filter, const Tensor &output_grad, + Tensor *input_grad, Tensor *filter_grad, + bool use_addto = false) { + auto place = ctx.GetPlace(); + T *input_ptr = const_cast(input.data()); + T *filter_ptr = const_cast(filter.data()); + T *output_grad_ptr = const_cast(output_grad.data()); + + if (filter_grad) { + T *filter_grad_ptr = filter_grad->mutable_data(place); + BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr); } - if (input_grad_ptr) { - BackwardData(ctx, input_ptr, output_grad_ptr, filter_ptr, input_grad_ptr, - use_addto); + if (input_grad) { + T *input_grad_ptr = input_grad->mutable_data(place); + BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr, use_addto); } } private: - void BackwardFilter(const platform::CUDADeviceContext &ctx, T *input_ptr, - T *output_grad_ptr, T *filter_ptr, T *filter_grad_ptr) { + void BackwardFilter(const platform::CUDADeviceContext &ctx, + T *output_grad_ptr, T *input_ptr, T *filter_grad_ptr) { auto cudnn_handle = ctx.cudnn_handle(); CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx); @@ -255,9 +310,8 @@ class CudnnNormConvolutionGrad { workspace_size); } - void BackwardData(const platform::CUDADeviceContext &ctx, T *input_ptr, - T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr, - bool use_addto = false) { + void BackwardData(const platform::CUDADeviceContext &ctx, T *output_grad_ptr, + T *filter_ptr, T *input_grad_ptr, bool use_addto = false) { auto cudnn_handle = ctx.cudnn_handle(); size_t workspace_size = GetWorkspaceSizeBwdData(ctx); diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc index 4c14029b99..23983d447e 100644 --- a/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc +++ b/paddle/fluid/operators/fused/cudnn_norm_conv_test.cc @@ -229,15 +229,6 @@ class CudnnNormConvolutionTester { platform::DeviceContextPool::Instance().Get( platform::CUDAPlace(0))); - if (!Support(*ctx)) { - LOG(INFO) - << "Current test is only supported in the platforms with " - << "compatiblity greater than or equal to 70 and the kernel size " - << "must be equal to 1 or 3. Besides, when the kernel size is 1, " - << "the stride must be 1 if the compatiblity is equal to 70."; - return; - } - framework::Tensor cpu_output_base; framework::Tensor cpu_sum_base; framework::Tensor cpu_sum_of_square_base; @@ -325,14 +316,10 @@ class CudnnNormConvolutionTester { TensorCopySync(cpu_input_, place, &input); TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc); - T *input_ptr = input.data(); - T *filter_ptr = filter_nhwc.data(); - T *output_ptr = output.mutable_data( - {batch_size_, out_height_, out_width_, output_channels_}, place); - float *sum_ptr = - sum.mutable_data({1, 1, 1, output_channels_}, place); - float *sum_of_square_ptr = - sum_of_square.mutable_data({1, 1, 1, output_channels_}, place); + output.Resize(framework::make_ddim( + {batch_size_, out_height_, out_width_, output_channels_})); + sum.Resize(framework::make_ddim({1, 1, 1, output_channels_})); + sum_of_square.Resize(framework::make_ddim({1, 1, 1, output_channels_})); auto input_shape = framework::vectorize(input.dims()); auto filter_shape = framework::vectorize(filter_nhwc.dims()); @@ -340,8 +327,7 @@ class CudnnNormConvolutionTester { op::CudnnNormConvolution conv_op(ctx, input_shape, filter_shape, output_shape, padding_, stride_, dilation_, group_); - conv_op.Forward(ctx, input_ptr, filter_ptr, output_ptr, sum_ptr, - sum_of_square_ptr); + conv_op.Forward(ctx, input, filter_nhwc, &output, &sum, &sum_of_square); TensorCopySync(output, platform::CPUPlace(), cpu_output); TensorCopySync(sum, platform::CPUPlace(), cpu_sum); @@ -362,11 +348,8 @@ class CudnnNormConvolutionTester { TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc); TensorCopySync(cpu_output_grad_, place, &output_grad); - T *input_ptr = input.data(); - T *filter_ptr = filter_nhwc.data(); - T *output_grad_ptr = output_grad.data(); - T *input_grad_ptr = input_grad.mutable_data(input.dims(), place); - T *filter_grad_ptr = filter_grad.mutable_data(filter_nhwc.dims(), place); + input_grad.Resize(input.dims()); + filter_grad.Resize(filter_nhwc.dims()); auto input_shape = framework::vectorize(input.dims()); auto filter_shape = framework::vectorize(filter_nhwc.dims()); @@ -374,26 +357,13 @@ class CudnnNormConvolutionTester { op::CudnnNormConvolutionGrad conv_grad_op(ctx, input_shape, filter_shape, output_shape, padding_, stride_, dilation_, group_); - conv_grad_op.Backward(ctx, input_ptr, output_grad_ptr, filter_ptr, - input_grad_ptr, filter_grad_ptr); + conv_grad_op.Backward(ctx, input, filter_nhwc, output_grad, &input_grad, + &filter_grad); TensorCopySync(input_grad, platform::CPUPlace(), cpu_input_grad); TensorCopySync(filter_grad, platform::CPUPlace(), cpu_filter_grad); } - bool Support(const platform::CUDADeviceContext &ctx) { - if (ctx.GetComputeCapability() == 70) { - if ((kernel_size_ == 3) || ((kernel_size_ == 1) && (stride_ == 1))) { - return true; - } - } else if (ctx.GetComputeCapability() > 70) { - if ((kernel_size_ == 3) || (kernel_size_ == 1)) { - return true; - } - } - return false; - } - private: int batch_size_; int height_; @@ -477,6 +447,15 @@ TEST(CudnnNormConvFp16, K1S2O4) { CudnnNormConvolutionTester test( batch_size, height, width, input_channels, output_channels, kernel_size, stride); - test.CheckForward(1e-3, true); - test.CheckBackward(1e-3); + platform::CUDADeviceContext *ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0))); + + if (ctx->GetComputeCapability() <= 70) { + ASSERT_THROW(test.CheckForward(1e-3, true), + paddle::platform::EnforceNotMet); + ASSERT_THROW(test.CheckBackward(1e-3), paddle::platform::EnforceNotMet); + } else { + ASSERT_NO_THROW(test.CheckForward(1e-3, true)); + ASSERT_NO_THROW(test.CheckBackward(1e-3)); + } } diff --git a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h index 2fdb3635e2..b48c964d26 100644 --- a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h +++ b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h @@ -107,25 +107,33 @@ class CudnnScaleBiasAddRelu { ~CudnnScaleBiasAddRelu() {} - void Forward(const platform::CUDADeviceContext &ctx, T *x_ptr, T *x_scale_ptr, - T *x_bias_ptr, T *out_ptr, int32_t *bitmask_ptr, - T *z_ptr = nullptr, T *z_scale_ptr = nullptr, - T *z_bias_ptr = nullptr) { + void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x, + const Tensor &x_scale, const Tensor &x_bias, const Tensor &z, + const Tensor &z_scale, const Tensor &z_bias, Tensor *out, + Tensor *bitmask) { ForwardInit(ctx); auto handle = ctx.cudnn_handle(); + auto place = ctx.GetPlace(); auto workspace_handle = ctx.cudnn_workspace_handle(); fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle); // Set variant_param // input ptr + T *x_ptr = const_cast(x.data()); + T *x_scale_ptr = const_cast(x_scale.data()); + T *x_bias_ptr = const_cast(x_bias.data()); fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr); fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr); fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr); if (has_shortcut_) { + T *z_ptr = const_cast(z.data()); + T *z_scale_ptr = const_cast(z_scale.data()); + T *z_bias_ptr = const_cast(z_bias.data()); fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr); fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr); fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr); } else { if (fused_add_) { + T *z_ptr = const_cast(z.data()); fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr); } } @@ -134,6 +142,8 @@ class CudnnScaleBiasAddRelu { CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_); // output ptr + T *out_ptr = out->mutable_data(place); + int32_t *bitmask_ptr = bitmask->mutable_data(place); fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, out_ptr); fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr); @@ -147,16 +157,30 @@ class CudnnScaleBiasAddRelu { fwd_workspace_byte_); } - void Backward(const platform::CUDADeviceContext &ctx, T *dy_ptr, T *x_ptr, - float *scale_ptr, float *bias_ptr, float *saved_mean_ptr, - float *saved_invstd_ptr, int32_t *bitmask_ptr, T *dx_ptr, - T *dz_ptr, float *dscale_ptr, float *dbias_ptr, double eps) { + void Backward(const platform::CUDADeviceContext &ctx, const Tensor &dy, + const Tensor &x, const Tensor &scale, const Tensor &bias, + const Tensor &saved_mean, const Tensor &saved_invstd, + const Tensor &bitmask, Tensor *dx, Tensor *dz, Tensor *dscale, + Tensor *dbias, double eps) { BackwardInit(ctx); auto handle = ctx.cudnn_handle(); + auto place = ctx.GetPlace(); auto workspace_handle = ctx.cudnn_workspace_handle(); bwd_workspace_byte_ = bwd_op_.GetWorkspaceSizeInBytes(handle); // Set variant_param // input ptr + T *dy_ptr = const_cast(dy.data()); + T *x_ptr = const_cast(x.data()); + float *scale_ptr = const_cast(scale.data()); + float *bias_ptr = const_cast(bias.data()); + float *saved_mean_ptr = const_cast(saved_mean.data()); + float *saved_invstd_ptr = const_cast(saved_invstd.data()); + int32_t *bitmask_ptr = const_cast(bitmask.data()); + T *dx_ptr = dx->mutable_data(place); + T *dz_ptr = dz ? dz->mutable_data(place) : nullptr; + float *dscale_ptr = dscale ? dscale->mutable_data(place) : nullptr; + float *dbias_ptr = dbias ? dbias->mutable_data(place) : nullptr; + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr); bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, dy_ptr); bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr); -- GitLab