diff --git a/paddle/fluid/operators/fused/fused_bn_activation_op.cu b/paddle/fluid/operators/fused/fused_bn_activation_op.cu deleted file mode 100644 index a93938fcfd043dc01dfa34f81890f4c76a012897..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/fused/fused_bn_activation_op.cu +++ /dev/null @@ -1,444 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#include "cub/cub.cuh" -#include "paddle/fluid/framework/data_layout.h" -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/fused/fused_bn_activation_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/core/flags.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/norm_utils.h" - -PHI_DECLARE_bool(cudnn_batchnorm_spatial_persistent); - -namespace paddle { -namespace operators { -template -using CudnnDataType = platform::CudnnDataType; -template -using BatchNormParamType = typename CudnnDataType::BatchNormParamType; - -template -class FusedBatchNormActKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { -#if CUDNN_VERSION < 7401 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_batch_norm_act operator is not supported on GPU " - "when CUDNN version < 7.4.1")); -#endif - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::PreconditionNotMet("It must use CUDAPlace.")); - auto &dev_ctx = ctx.template device_context(); - double epsilon = static_cast(ctx.Attr("epsilon")); - float momentum = ctx.Attr("momentum"); - std::string act_type = ctx.Attr("act_type"); - - if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { - LOG(ERROR) << "Provided epsilon is smaller than " - << "CUDNN_BN_MIN_EPSILON. Setting it to " - << "CUDNN_BN_MIN_EPSILON instead."; - } - epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); - - // Get the size for each dimension. - // NHWC [batch_size, in_height, in_width, in_channels] - const auto *x = ctx.Input("X"); - const auto &x_dims = x->dims(); - PADDLE_ENFORCE_EQ(x_dims.size() >= 2 && x_dims.size() <= 5, - true, - platform::errors::PreconditionNotMet( - "The Input dim size should be between 2 and 5")); - - const auto *scale = ctx.Input("Scale"); - const auto *bias = ctx.Input("Bias"); - - // Run training mode. - // obtain running mean and running inv var, and see if we need to - // initialize them. 
- auto *mean_out = ctx.Output("MeanOut"); - auto *variance_out = ctx.Output("VarianceOut"); - dev_ctx.Alloc>( - mean_out, mean_out->numel() * sizeof(BatchNormParamType)); - dev_ctx.Alloc>( - variance_out, variance_out->numel() * sizeof(BatchNormParamType)); - - auto *saved_mean = ctx.Output("SavedMean"); - auto *saved_variance = ctx.Output("SavedVariance"); - dev_ctx.Alloc>( - saved_mean, saved_mean->numel() * sizeof(BatchNormParamType)); - dev_ctx.Alloc>( - saved_variance, - saved_variance->numel() * sizeof(BatchNormParamType)); - - auto *y = ctx.Output("Y"); - dev_ctx.Alloc(y, y->numel() * sizeof(T)); - - int N, C, H, W, D; - const DataLayout data_layout = DataLayout::kNHWC; - phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); - - if ((N * H * W * D) == 1) { - // Only 1 element in normalization dimension, - // skip the batch norm calculation, let y = act(x). - auto x_v = framework::EigenVector::Flatten(*x); - auto y_v = framework::EigenVector::Flatten(*y); - auto &dev = *dev_ctx.eigen_device(); - if (act_type == "relu") { - ReluCUDAFunctor()(dev, x_v, y_v); - } else { - PADDLE_THROW( - platform::errors::Unimplemented("Unsupported activation type")); - } - return; - } - - // ------------------- cudnn descriptors --------------------- - auto handle = dev_ctx.cudnn_handle(); - cudnnTensorDescriptor_t data_desc_; - cudnnTensorDescriptor_t bn_param_desc_; - cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); - - VLOG(3) << "Setting descriptors."; - std::vector dims = {N, C, H, W, D}; - std::vector strides = {H * W * D * C, 1, W * D * C, D * C, C}; - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - data_desc_, - CudnnDataType::type, - x_dims.size() > 3 ? x_dims.size() : 4, - dims.data(), - strides.data())); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnDeriveBNTensorDescriptor( - bn_param_desc_, data_desc_, mode_)); - - double this_factor = 1. - momentum; - cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; - platform::ScopedActivationDescriptor scope_act_desc; - cudnnActivationDescriptor_t activation_desc_ = - scope_act_desc.descriptor(act_type); - size_t workspace_size = 0; - size_t reserve_space_size = 0; - void *reserve_space_ptr = nullptr; - void *workspace_ptr = nullptr; - phi::DenseTensor workspace_tensor; - // Create reserve space and workspace for batch norm. - // Create tensor for each batchnorm op, it will be used in the - // backward. Thus this tensor shouldn't be temp. 
- auto *reserve_space = ctx.Output("ReserveSpace"); - PADDLE_ENFORCE_NOT_NULL( - reserve_space, - platform::errors::NotFound( - "The argument ReserveSpace of batch_norm op is not found.")); - - // --------------- cudnn batchnorm workspace --------------- - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload:: - cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( - /*handle=*/handle, - /*mode=*/mode_, - /*bnOps=*/bnOps_, - /*xDesc=*/data_desc_, - /*zDesc=*/nullptr, - /*yDesc=*/data_desc_, - /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, - /*activationDesc=*/activation_desc_, - /*sizeInBytes=*/&workspace_size)); - - // -------------- cudnn batchnorm reserve space -------------- - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetBatchNormalizationTrainingExReserveSpaceSize( - /*handle=*/handle, - /*mode=*/mode_, - /*bnOps=*/bnOps_, - /*activationDesc=*/activation_desc_, - /*xDesc=*/data_desc_, - /*sizeInBytes=*/&reserve_space_size)); - - reserve_space->Resize({static_cast( - (reserve_space_size + phi::SizeOf(x->dtype()) - 1) / - phi::SizeOf(x->dtype()))}); - reserve_space_ptr = - dev_ctx.Alloc(reserve_space, reserve_space->numel() * sizeof(T)); - workspace_tensor.Resize( - {static_cast((workspace_size + phi::SizeOf(x->dtype()) - 1) / - phi::SizeOf(x->dtype()))}); - workspace_ptr = dev_ctx.Alloc(&workspace_tensor, - workspace_tensor.numel() * sizeof(T)); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnBatchNormalizationForwardTrainingEx( - handle, - mode_, - bnOps_, - CudnnDataType::kOne(), - CudnnDataType::kZero(), - data_desc_, - x->template data(), - nullptr, - nullptr, - data_desc_, - y->template data(), - bn_param_desc_, - scale->template data>(), - bias->template data>(), - this_factor, - dev_ctx.template Alloc>( - mean_out, mean_out->numel() * sizeof(BatchNormParamType)), - dev_ctx.template Alloc>( - variance_out, - variance_out->numel() * sizeof(BatchNormParamType)), - epsilon, - dev_ctx.template Alloc>( - saved_mean, - saved_mean->numel() * sizeof(BatchNormParamType)), - dev_ctx.template Alloc>( - saved_variance, - saved_variance->numel() * sizeof(BatchNormParamType)), - activation_desc_, - workspace_ptr, - workspace_size, - reserve_space_ptr, - reserve_space_size)); - - // clean when exit. 
- PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); - } -}; - -template -class FusedBatchNormActGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { -#if CUDNN_VERSION < 7401 - PADDLE_THROW(phi::errors::Unimplemented( - "The fused_batch_norm_act operator is not supported on GPU " - "when CUDNN version < 7.4.1")); -#endif - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::PreconditionNotMet("It must use CUDAPlace.")); - double epsilon = static_cast(ctx.Attr("epsilon")); - std::string act_type = ctx.Attr("act_type"); - auto &dev_ctx = ctx.template device_context(); - const auto *x = ctx.Input("X"); - const auto *y = ctx.Input("Y"); - const auto *d_y = ctx.Input(framework::GradVarName("Y")); - const auto *scale = ctx.Input("Scale"); - const auto *bias = ctx.Input("Bias"); - const auto *reserve_space = ctx.Input("ReserveSpace"); - - const auto &x_dims = x->dims(); - - PADDLE_ENFORCE_EQ(x_dims.size() >= 2 && x_dims.size() <= 5, - true, - platform::errors::PreconditionNotMet( - "The Input dim size should be between 2 and 5")); - int N, C, H, W, D; - const DataLayout data_layout = DataLayout::kNHWC; - phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); - - // init output - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_scale = - ctx.Output(framework::GradVarName("Scale")); - auto *d_bias = ctx.Output(framework::GradVarName("Bias")); - - dev_ctx.Alloc(d_x, d_x->numel() * sizeof(T)); - PADDLE_ENFORCE_EQ( - d_scale && d_bias, - true, - platform::errors::PreconditionNotMet( - "Both the scale grad and the bias grad must not be null.")); - dev_ctx.Alloc>( - d_scale, d_scale->numel() * sizeof(BatchNormParamType)); - dev_ctx.Alloc>( - d_bias, d_bias->numel() * sizeof(BatchNormParamType)); - PADDLE_ENFORCE_EQ(scale->dims().size(), - 1UL, - platform::errors::PreconditionNotMet( - "The scale only has one dimension.")); - PADDLE_ENFORCE_EQ( - scale->dims()[0], - C, - platform::errors::PreconditionNotMet( - "The size of scale is equal to the channel of Input(X).")); - - if ((N * H * W * D) == 1) { - if (act_type == "relu") { - auto x_v = framework::EigenVector::Flatten(*x); - auto y_v = framework::EigenVector::Flatten(*y); - auto dx_v = framework::EigenVector::Flatten(*d_x); - auto dy_v = framework::EigenVector::Flatten(*d_y); - auto &dev = *dev_ctx.eigen_device(); - ReluGradFunctor()(dev, x_v, y_v, dy_v, dx_v); - } else { - PADDLE_THROW( - platform::errors::Unimplemented("Unsupported activation type")); - } - phi::funcs::SetConstant> functor; - functor(dev_ctx, d_scale, static_cast>(0)); - functor(dev_ctx, d_bias, static_cast>(0)); - return; - } - - std::vector dims = {N, C, H, W, D}; - std::vector strides = {H * W * C * D, 1, W * D * C, D * C, C}; - // ------------------- cudnn descriptors --------------------- - cudnnTensorDescriptor_t data_desc_; - cudnnTensorDescriptor_t bn_param_desc_; - cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); - if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { - LOG(ERROR) << "Provided epsilon is smaller than " - << "CUDNN_BN_MIN_EPSILON. 
Setting it to " - << "CUDNN_BN_MIN_EPSILON instead."; - } - epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - data_desc_, - CudnnDataType::type, - x_dims.size() > 3 ? x_dims.size() : 4, - dims.data(), - strides.data())); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnDeriveBNTensorDescriptor( - bn_param_desc_, data_desc_, mode_)); - - const auto *saved_mean = ctx.Input("SavedMean"); - const auto *saved_var = ctx.Input("SavedVariance"); - const auto *saved_mean_data = - saved_mean->template data>(); - const auto *saved_var_data = - saved_var->template data>(); - - size_t workspace_size = 0; - void *workspace_ptr = nullptr; - phi::DenseTensor workspace_tensor; - auto reserve_space_size = reserve_space->memory_size(); - cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; - platform::ScopedActivationDescriptor scope_act_desc; - cudnnActivationDescriptor_t activation_desc_ = - scope_act_desc.descriptor(act_type); - // --------------- cudnn batchnorm workspace --------------- - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetBatchNormalizationBackwardExWorkspaceSize( - /*handle=*/dev_ctx.cudnn_handle(), - /*mode=*/mode_, - /*bnOps=*/bnOps_, - /*xDesc=*/data_desc_, - /*yDesc=*/data_desc_, - /*dyDesc=*/data_desc_, - /*dzDesc=*/nullptr, - /*dxDesc=*/data_desc_, - /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, - /*activationDesc=*/activation_desc_, - /*sizeInBytes=*/&workspace_size)); - - workspace_tensor.Resize( - {static_cast((workspace_size + phi::SizeOf(x->dtype()) - 1) / - phi::SizeOf(x->dtype()))}); - workspace_ptr = dev_ctx.Alloc(&workspace_tensor, - workspace_tensor.numel() * sizeof(T)); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnBatchNormalizationBackwardEx( - /*handle=*/dev_ctx.cudnn_handle(), - /*mode=*/mode_, - /*bnOps=*/bnOps_, - /*alphaDataDiff=*/CudnnDataType::kOne(), - /*betaDataDiff=*/CudnnDataType::kZero(), - /*alphaParamDiff=*/CudnnDataType::kOne(), - /*betaParamDiff=*/CudnnDataType::kZero(), - /*xDesc=*/data_desc_, - /*xData=*/x->template data(), - /*yDesc=*/data_desc_, - /*yData=*/y->template data(), - /*dyDesc=*/data_desc_, - /*dyData=*/d_y->template data(), - /*dzDesc=*/nullptr, - /*dzData=*/nullptr, - /*dxDesc=*/data_desc_, - /*dxData=*/ - dev_ctx.template Alloc(d_x, d_x->numel() * sizeof(T)), - /*dBnScaleBiasDesc=*/bn_param_desc_, - /*bnScaleData=*/scale->template data>(), - /*bnBiasData=*/bias->template data>(), - /*dBnScaleData=*/ - dev_ctx.template Alloc>( - d_scale, d_scale->numel() * sizeof(BatchNormParamType)), - /*dBnBiasData=*/ - dev_ctx.template Alloc>( - d_bias, d_bias->numel() * sizeof(BatchNormParamType)), - /*epsilon=*/epsilon, - /*savedMean=*/saved_mean_data, - /*savedInvVariance=*/saved_var_data, - /*activationDesc=*/activation_desc_, - /*workspace=*/workspace_ptr, - /*workSpaceSizeInBytes=*/workspace_size, - /*reserveSpace=*/const_cast(reserve_space->template data()), - /*reserveSpaceSizeInBytes=*/reserve_space_size)); - - // clean when exit. 
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-PD_REGISTER_STRUCT_KERNEL(fused_batch_norm_act,
-                          GPU,
-                          ALL_LAYOUT,
-                          ops::FusedBatchNormActKernel,
-                          float,
-                          double,
-                          plat::float16) {}
-PD_REGISTER_STRUCT_KERNEL(fused_batch_norm_act_grad,
-                          GPU,
-                          ALL_LAYOUT,
-                          ops::FusedBatchNormActGradKernel,
-                          float,
-                          double,
-                          plat::float16) {}
diff --git a/paddle/phi/kernels/fused_bn_activation_grad_kernel.h b/paddle/phi/kernels/fused_bn_activation_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..1f74e44cbb009e383837fbdba432864476f7bdee
--- /dev/null
+++ b/paddle/phi/kernels/fused_bn_activation_grad_kernel.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FusedBatchNormActGradKernel(const Context &dev_ctx,
+                                 const DenseTensor &x,
+                                 const DenseTensor &scale,
+                                 const DenseTensor &bias,
+                                 const DenseTensor &saved_mean,
+                                 const DenseTensor &saved_variance,
+                                 const DenseTensor &reserve_space,
+                                 const DenseTensor &y,
+                                 const DenseTensor &y_grad,
+                                 float momentum,
+                                 float epsilon,
+                                 const std::string &act_type,
+                                 DenseTensor *x_grad,
+                                 DenseTensor *scale_grad,
+                                 DenseTensor *bias_grad);
+}  // namespace phi
diff --git a/paddle/phi/kernels/fused_bn_activation_kernel.h b/paddle/phi/kernels/fused_bn_activation_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..d3f26d8c4e34135af8b8d14103c42c111701b770
--- /dev/null
+++ b/paddle/phi/kernels/fused_bn_activation_kernel.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FusedBatchNormActKernel(const Context &dev_ctx,
+                             const DenseTensor &x,
+                             const DenseTensor &scale,
+                             const DenseTensor &bias,
+                             const DenseTensor &mean,
+                             const DenseTensor &variance,
+                             float momentum,
+                             float epsilon,
+                             const std::string &act_type,
+                             DenseTensor *y,
+                             DenseTensor *mean_out,
+                             DenseTensor *variance_out,
+                             DenseTensor *saved_mean,
+                             DenseTensor *saved_variance,
+                             DenseTensor *reserve_space);
+}  // namespace phi
diff --git a/paddle/phi/kernels/fusion/gpu/fused_bn_activation_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bn_activation_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a27eb5149308e0a0a546a8c6f17b5318074b37ba
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/fused_bn_activation_grad_kernel.cu
@@ -0,0 +1,237 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <cfloat>
+#include <string>
+#include <vector>
+
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_dnn.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/flags.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/activation_functor.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
+#include "paddle/phi/kernels/fused_bn_activation_grad_kernel.h"
+
+PHI_DECLARE_bool(cudnn_batchnorm_spatial_persistent);
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void FusedBatchNormActGradKernel(const Context &dev_ctx,
+                                 const DenseTensor &x,
+                                 const DenseTensor &scale,
+                                 const DenseTensor &bias,
+                                 const DenseTensor &saved_mean,
+                                 const DenseTensor &saved_variance,
+                                 const DenseTensor &reserve_space,
+                                 const DenseTensor &y,
+                                 const DenseTensor &y_grad,
+                                 float momentum,
+                                 float epsilon,
+                                 const std::string &act_type,
+                                 DenseTensor *x_grad,
+                                 DenseTensor *scale_grad,
+                                 DenseTensor *bias_grad) {
+// Note(andsonder): Fused bn activation only used in the gpu place.
+#if defined(PADDLE_WITH_CUDA) and CUDNN_VERSION >= 7401
+  using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
+  using BatchNormParamType = typename CudnnDataType::BatchNormParamType;
+  bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU;
+  PADDLE_ENFORCE_EQ(is_gpu_place,
+                    true,
+                    phi::errors::PreconditionNotMet("It must use CUDAPlace."));
+  double epsilon1 = static_cast<double>(epsilon);
+
+  const auto *d_y = &y_grad;
+
+  const auto &x_dims = x.dims();
+
+  PADDLE_ENFORCE_EQ(x_dims.size() >= 2 && x_dims.size() <= 5,
+                    true,
+                    phi::errors::PreconditionNotMet(
+                        "The Input dim size should be between 2 and 5"));
+  int N, C, H, W, D;
+  const phi::DataLayout data_layout = phi::DataLayout::kNHWC;
+  phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+
+  // init output
+  auto *d_x = x_grad;
+  auto *d_scale = scale_grad;
+  auto *d_bias = bias_grad;
+
+  dev_ctx.template Alloc<T>(d_x);
+  PADDLE_ENFORCE_EQ(
+      d_scale && d_bias,
+      true,
+      phi::errors::PreconditionNotMet(
+          "Both the scale grad and the bias grad must not be null."));
+  dev_ctx.template Alloc<BatchNormParamType>(d_scale);
+  dev_ctx.template Alloc<BatchNormParamType>(d_bias);
+  PADDLE_ENFORCE_EQ(
+      scale.dims().size(),
+      1UL,
+      phi::errors::PreconditionNotMet("The scale only has one dimension."));
+  PADDLE_ENFORCE_EQ(
+      scale.dims()[0],
+      C,
+      phi::errors::PreconditionNotMet(
+          "The size of scale is equal to the channel of Input(X)."));
+
+  if ((N * H * W * D) == 1) {
+    if (act_type == "relu") {
+      auto x_v = phi::EigenVector<T>::Flatten(x);
+      auto y_v = phi::EigenVector<T>::Flatten(y);
+      auto dx_v = phi::EigenVector<T>::Flatten(*d_x);
+      auto dy_v = phi::EigenVector<T>::Flatten(*d_y);
+      auto &dev = *dev_ctx.eigen_device();
+      phi::funcs::ReluGradFunctor<T>()(dev, x_v, y_v, dy_v, dx_v);
+    } else {
+      PADDLE_THROW(phi::errors::Unimplemented("Unsupported activation type"));
+    }
+    phi::funcs::SetConstant<Context, BatchNormParamType> functor;
+    functor(dev_ctx, d_scale, static_cast<BatchNormParamType>(0));
+    functor(dev_ctx, d_bias, static_cast<BatchNormParamType>(0));
+    return;
+  }
+
+  std::vector<int> dims = {N, C, H, W, D};
+  std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
+  // ------------------- cudnn descriptors ---------------------
+  cudnnTensorDescriptor_t data_desc_;
+  cudnnTensorDescriptor_t bn_param_desc_;
+  cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnCreateTensorDescriptor(&data_desc_));
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
+  if (epsilon1 <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
+    LOG(ERROR) << "Provided epsilon is smaller than "
+               << "CUDNN_BN_MIN_EPSILON. Setting it to "
+               << "CUDNN_BN_MIN_EPSILON instead.";
+  }
+  epsilon1 = std::max(epsilon1, CUDNN_BN_MIN_EPSILON);
+
+  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor(
+      data_desc_,
+      CudnnDataType::type,
+      x_dims.size() > 3 ? x_dims.size() : 4,
+      dims.data(),
+      strides.data()));
+  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDeriveBNTensorDescriptor(
+      bn_param_desc_, data_desc_, mode_));
+
+  const auto *saved_mean_data =
+      saved_mean.template data<BatchNormParamType>();
+  const auto *saved_var_data =
+      saved_variance.template data<BatchNormParamType>();
+
+  size_t workspace_size = 0;
+  void *workspace_ptr = nullptr;
+  phi::DenseTensor workspace_tensor;
+  auto reserve_space_size = reserve_space.memory_size();
+  cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
+  phi::backends::gpu::ScopedActivationDescriptor scope_act_desc;
+  cudnnActivationDescriptor_t activation_desc_ =
+      scope_act_desc.descriptor<T>(act_type);
+  // --------------- cudnn batchnorm workspace ---------------
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnGetBatchNormalizationBackwardExWorkspaceSize(
+          /*handle=*/dev_ctx.cudnn_handle(),
+          /*mode=*/mode_,
+          /*bnOps=*/bnOps_,
+          /*xDesc=*/data_desc_,
+          /*yDesc=*/data_desc_,
+          /*dyDesc=*/data_desc_,
+          /*dzDesc=*/nullptr,
+          /*dxDesc=*/data_desc_,
+          /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
+          /*activationDesc=*/activation_desc_,
+          /*sizeInBytes=*/&workspace_size));
+
+  workspace_tensor.Resize({static_cast<int64_t>(
+      (workspace_size + phi::SizeOf(x.dtype()) - 1) / phi::SizeOf(x.dtype()))});
+  workspace_ptr = dev_ctx.template Alloc<T>(&workspace_tensor);
+
+  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBatchNormalizationBackwardEx(
+      /*handle=*/dev_ctx.cudnn_handle(),
+      /*mode=*/mode_,
+      /*bnOps=*/bnOps_,
+      /*alphaDataDiff=*/CudnnDataType::kOne(),
+      /*betaDataDiff=*/CudnnDataType::kZero(),
+      /*alphaParamDiff=*/CudnnDataType::kOne(),
+      /*betaParamDiff=*/CudnnDataType::kZero(),
+      /*xDesc=*/data_desc_,
+      /*xData=*/x.template data<T>(),
+      /*yDesc=*/data_desc_,
+      /*yData=*/y.template data<T>(),
+      /*dyDesc=*/data_desc_,
+      /*dyData=*/d_y->template data<T>(),
+      /*dzDesc=*/nullptr,
+      /*dzData=*/nullptr,
+      /*dxDesc=*/data_desc_,
+      /*dxData=*/
+      dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T)),
+      /*dBnScaleBiasDesc=*/bn_param_desc_,
+      /*bnScaleData=*/scale.template data<BatchNormParamType>(),
+      /*bnBiasData=*/bias.template data<BatchNormParamType>(),
+      /*dBnScaleData=*/
+      dev_ctx.template Alloc<BatchNormParamType>(d_scale),
+      /*dBnBiasData=*/
+      dev_ctx.template Alloc<BatchNormParamType>(d_bias),
+      /*epsilon=*/epsilon1,
+      /*savedMean=*/saved_mean_data,
+      /*savedInvVariance=*/saved_var_data,
+      /*activationDesc=*/activation_desc_,
+      /*workspace=*/workspace_ptr,
+      /*workSpaceSizeInBytes=*/workspace_size,
+      /*reserveSpace=*/const_cast<T *>(reserve_space.template data<T>()),
+      /*reserveSpaceSizeInBytes=*/reserve_space_size));
+
+  // clean when exit.
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnDestroyTensorDescriptor(data_desc_));
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The fused_batch_norm_act operator is not supported on GPU "
+      "when CUDNN version < 7.4.1"));
+#endif
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_batch_norm_act_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedBatchNormActGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+  }
+}
diff --git a/paddle/phi/kernels/fusion/gpu/fused_bn_activation_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bn_activation_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..700141f1e03318964290c77305e049a62a2e136c
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/fused_bn_activation_kernel.cu
@@ -0,0 +1,237 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <cfloat>
+#include <string>
+#include <vector>
+
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_dnn.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/flags.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/activation_functor.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
+
+PHI_DECLARE_bool(cudnn_batchnorm_spatial_persistent);
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void FusedBatchNormActKernel(const Context &dev_ctx,
+                             const DenseTensor &x,
+                             const DenseTensor &scale,
+                             const DenseTensor &bias,
+                             const DenseTensor &mean,
+                             const DenseTensor &variance,
+                             float momentum,
+                             float epsilon,
+                             const std::string &act_type,
+                             DenseTensor *y,
+                             DenseTensor *mean_out,
+                             DenseTensor *variance_out,
+                             DenseTensor *saved_mean,
+                             DenseTensor *saved_variance,
+                             DenseTensor *reserve_space) {
+// Note(andsonder): Fused bn activation only used in the gpu place.
+#if defined(PADDLE_WITH_CUDA) and CUDNN_VERSION >= 7401
+  using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
+  using BatchNormParamType = typename CudnnDataType::BatchNormParamType;
+  bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU;
+  PADDLE_ENFORCE_EQ(is_gpu_place,
+                    true,
+                    phi::errors::PreconditionNotMet("It must use CUDAPlace."));
+  double epsilon1 = static_cast<double>(epsilon);
+
+  if (epsilon1 <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
+    LOG(ERROR) << "Provided epsilon is smaller than "
+               << "CUDNN_BN_MIN_EPSILON. Setting it to "
+               << "CUDNN_BN_MIN_EPSILON instead.";
+  }
+  epsilon1 = std::max(epsilon1, CUDNN_BN_MIN_EPSILON);
+
+  // Get the size for each dimension.
+  // NHWC [batch_size, in_height, in_width, in_channels]
+  const auto &x_dims = x.dims();
+  PADDLE_ENFORCE_EQ(x_dims.size() >= 2 && x_dims.size() <= 5,
+                    true,
+                    phi::errors::PreconditionNotMet(
+                        "The Input dim size should be between 2 and 5"));
+
+  // Run training mode.
+  // obtain running mean and running inv var, and see if we need to
+  // initialize them.
+  dev_ctx.template Alloc<BatchNormParamType>(mean_out);
+  dev_ctx.template Alloc<BatchNormParamType>(variance_out);
+
+  dev_ctx.template Alloc<BatchNormParamType>(saved_mean);
+  dev_ctx.template Alloc<BatchNormParamType>(saved_variance);
+
+  dev_ctx.template Alloc<T>(y);
+
+  int N, C, H, W, D;
+  const DataLayout data_layout = phi::DataLayout::kNHWC;
+  phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+
+  if ((N * H * W * D) == 1) {
+    // Only 1 element in normalization dimension,
+    // skip the batch norm calculation, let y = act(x).
+    auto x_v = phi::EigenVector<T>::Flatten(x);
+    auto y_v = phi::EigenVector<T>::Flatten(*y);
+    auto &dev = *dev_ctx.eigen_device();
+    if (act_type == "relu") {
+      phi::funcs::ReluCUDAFunctor<T>()(dev, x_v, y_v);
+    } else {
+      PADDLE_THROW(phi::errors::Unimplemented("Unsupported activation type"));
+    }
+    return;
+  }
+
+  // ------------------- cudnn descriptors ---------------------
+  auto handle = dev_ctx.cudnn_handle();
+  cudnnTensorDescriptor_t data_desc_;
+  cudnnTensorDescriptor_t bn_param_desc_;
+  cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnCreateTensorDescriptor(&data_desc_));
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
+
+  VLOG(3) << "Setting descriptors.";
+  std::vector<int> dims = {N, C, H, W, D};
+  std::vector<int> strides = {H * W * D * C, 1, W * D * C, D * C, C};
+
+  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetTensorNdDescriptor(
+      data_desc_,
+      CudnnDataType::type,
+      x_dims.size() > 3 ? x_dims.size() : 4,
+      dims.data(),
+      strides.data()));
+
+  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDeriveBNTensorDescriptor(
+      bn_param_desc_, data_desc_, mode_));
+
+  double this_factor = 1. - momentum;
+  cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
+  phi::backends::gpu::ScopedActivationDescriptor scope_act_desc;
+  cudnnActivationDescriptor_t activation_desc_ =
+      scope_act_desc.descriptor<T>(act_type);
+  size_t workspace_size = 0;
+  size_t reserve_space_size = 0;
+  void *reserve_space_ptr = nullptr;
+  void *workspace_ptr = nullptr;
+  phi::DenseTensor workspace_tensor;
+
+  PADDLE_ENFORCE_NOT_NULL(
+      reserve_space,
+      phi::errors::NotFound(
+          "The argument ReserveSpace of batch_norm op is not found."));
+
+  // --------------- cudnn batchnorm workspace ---------------
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
+          /*handle=*/handle,
+          /*mode=*/mode_,
+          /*bnOps=*/bnOps_,
+          /*xDesc=*/data_desc_,
+          /*zDesc=*/nullptr,
+          /*yDesc=*/data_desc_,
+          /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
+          /*activationDesc=*/activation_desc_,
+          /*sizeInBytes=*/&workspace_size));
+
+  // -------------- cudnn batchnorm reserve space --------------
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
+          /*handle=*/handle,
+          /*mode=*/mode_,
+          /*bnOps=*/bnOps_,
+          /*activationDesc=*/activation_desc_,
+          /*xDesc=*/data_desc_,
+          /*sizeInBytes=*/&reserve_space_size));
+
+  reserve_space->Resize(
+      {static_cast<int64_t>((reserve_space_size + phi::SizeOf(x.dtype()) - 1) /
+                            phi::SizeOf(x.dtype()))});
+  reserve_space_ptr = dev_ctx.template Alloc<T>(reserve_space);
+  workspace_tensor.Resize({static_cast<int64_t>(
+      (workspace_size + phi::SizeOf(x.dtype()) - 1) / phi::SizeOf(x.dtype()))});
+  workspace_ptr = dev_ctx.template Alloc<T>(&workspace_tensor);
+
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnBatchNormalizationForwardTrainingEx(
+          handle,
+          mode_,
+          bnOps_,
+          CudnnDataType::kOne(),
+          CudnnDataType::kZero(),
+          data_desc_,
+          x.template data<T>(),
+          nullptr,
+          nullptr,
+          data_desc_,
+          y->template data<T>(),
+          bn_param_desc_,
+          scale.template data<BatchNormParamType>(),
+          bias.template data<BatchNormParamType>(),
+          this_factor,
+          dev_ctx.template Alloc<BatchNormParamType>(mean_out),
+          dev_ctx.template Alloc<BatchNormParamType>(variance_out),
+          epsilon1,
+          dev_ctx.template Alloc<BatchNormParamType>(saved_mean),
+          dev_ctx.template Alloc<BatchNormParamType>(saved_variance),
+          activation_desc_,
+          workspace_ptr,
+          workspace_size,
+          reserve_space_ptr,
+          reserve_space_size));
+
+  // clean when exit.
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnDestroyTensorDescriptor(data_desc_));
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The fused_batch_norm_act operator is not supported on GPU "
+      "when CUDNN version < 7.4.1"));
+#endif
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_batch_norm_act,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedBatchNormActKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+  }
+}
diff --git a/paddle/phi/ops/compat/fused_bn_activation_sig.cc b/paddle/phi/ops/compat/fused_bn_activation_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1d478057debc230f4d6320f6c9f863dc3bbbe934
--- /dev/null
+++ b/paddle/phi/ops/compat/fused_bn_activation_sig.cc
@@ -0,0 +1,53 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature BatchNormActFuseOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("fused_batch_norm_act",
+                         {"X", "Scale", "Bias", "Mean", "Variance"},
+                         {"momentum", "epsilon", "act_type"},
+                         {"Y",
+                          "MeanOut",
+                          "VarianceOut",
+                          "SavedMean",
+                          "SavedVariance",
+                          "ReserveSpace"});
+}
+
+KernelSignature BatchNormActGradFuseOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("fused_batch_norm_act_grad",
+                         {"X",
+                          "Scale",
+                          "Bias",
+                          "SavedMean",
+                          "SavedVariance",
+                          "ReserveSpace",
+                          "Y",
+                          "Y@GRAD"},
+                         {"momentum", "epsilon", "act_type"},
+                         {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(fused_batch_norm_act,
+                           phi::BatchNormActFuseOpArgumentMapping);
+
+PD_REGISTER_ARG_MAPPING_FN(fused_batch_norm_act_grad,
+                           phi::BatchNormActGradFuseOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 8ca0bc01ab31b8fc7f2ac760daf9c665bbb7cd68..0cf206084a21cb5994bf1617254d37589c6880f3 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1163,6 +1163,7 @@ set(STATIC_BUILD_TESTS
   test_fetch_lod_tensor_array
   test_fused_attention_op
   test_fused_attention_op_api
+  test_fuse_bn_act_pass
   test_fused_feedforward_op
   test_fused_feedforward_pass
   test_imperative_optimizer
@@ -1206,6 +1207,13 @@ endforeach()
 set_tests_properties(test_decoupled_py_reader_static_build PROPERTIES TIMEOUT 120)
+set_tests_properties(test_fuse_bn_act_pass_static_build PROPERTIES TIMEOUT 120)
+set_tests_properties(
+  test_fuse_bn_act_pass_static_build
+  PROPERTIES
+    ENVIRONMENT
+    "FLAGS_cudnn_deterministic=1;FLAGS_cudnn_batchnorm_spatial_persistent=1;FLAGS_conv_workspace_size_limit=1000"
+)
 set_tests_properties(test_imperative_optimizer_static_build PROPERTIES TIMEOUT 250)
 set_tests_properties(test_matmul_op_static_build PROPERTIES TIMEOUT 120)
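
Usage note (not part of the diff): the phi kernels above are not called directly from Python; they are reached through the fuse_bn_act_ops graph pass, which is what the newly listed test_fuse_bn_act_pass exercises. The sketch below is a minimal, illustrative example under assumptions: a CUDA build with cuDNN >= 7.4.1, a Paddle release where paddle.static.CompiledProgram accepts a build_strategy argument, and NHWC float16 inputs; the tensor name and shape are made up for illustration.

import numpy as np
import paddle

paddle.enable_static()

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[2, 8, 8, 16], dtype="float16")
    # batch_norm followed by a fuseable activation; NHWC is the layout the
    # fused cuDNN kernel is written for.
    y = paddle.static.nn.batch_norm(x, act="relu", data_layout="NHWC",
                                    is_test=False)

build_strategy = paddle.static.BuildStrategy()
# Ask the graph pass to rewrite batch_norm + relu into fused_batch_norm_act
# (and, when gradients are present, the backward pattern into
# fused_batch_norm_act_grad).
build_strategy.fuse_bn_act_ops = True

exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
compiled = paddle.static.CompiledProgram(main_prog,
                                         build_strategy=build_strategy)
(out,) = exe.run(compiled,
                 feed={"x": np.random.rand(2, 8, 8, 16).astype("float16")},
                 fetch_list=[y])
print(out.shape)  # (2, 8, 8, 16)

As the CMake hunk above suggests, the pass is typically run with FLAGS_cudnn_batchnorm_spatial_persistent=1 so cuDNN selects the spatial-persistent mode these kernels are built around.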