diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 599be6912b760ec97586bf821725781f68fa8385..2630c12db2fc9a0ba7f2a718ba89ca738a02d3a3 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -80,5 +80,6 @@ if (WITH_GPU OR WITH_ROCM) endif() if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000)) cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory) + cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory) endif() endif() diff --git a/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7229754cb8ed82ed6f9da427c044bcb5de388bb9 --- /dev/null +++ b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc @@ -0,0 +1,380 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h" +#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/float16.h" + +DECLARE_bool(cudnn_batchnorm_spatial_persistent); + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace op = paddle::operators; +using Tensor = paddle::framework::Tensor; + +USE_OP(batch_norm); + +template +void InitRandomTensor(const std::vector &dims, + framework::Tensor *cpu_out) { + T *cpu_out_ptr = cpu_out->mutable_data(framework::make_ddim(dims), + platform::CPUPlace()); + std::default_random_engine random(0); + std::uniform_real_distribution dis(0.0, 1.0); + for (int i = 0; i < cpu_out->numel(); ++i) { + cpu_out_ptr[i] = static_cast(dis(random)); + } +} + +template +void InitConstantTensor(const std::vector &dims, T value, + framework::Tensor *cpu_out) { + T *cpu_out_ptr = cpu_out->mutable_data(framework::make_ddim(dims), + platform::CPUPlace()); + for (int i = 0; i < cpu_out->numel(); ++i) { + cpu_out_ptr[i] = value; + } +} + +template +void CheckOutput(std::string name, const framework::Tensor &cpu_res, + const framework::Tensor &cpu_base, float diff, + bool is_relative_atol = false) { + if (cpu_res.dims().size() == cpu_base.dims().size()) { + EXPECT_EQ(cpu_res.dims(), cpu_base.dims()); + } else { + EXPECT_EQ(cpu_res.numel(), cpu_base.numel()); + } + + const T *cpu_res_ptr = cpu_res.data(); + const T *cpu_base_ptr = cpu_base.data(); + float max_diff = 0; + int index = 0; + for (int i = 0; i < cpu_res.numel(); ++i) { + float cur_diff; + if (is_relative_atol) { + cur_diff = static_cast( + std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) / cpu_base_ptr[i])); + EXPECT_LT(static_cast(std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) / + cpu_base_ptr[i])), + diff); + } else { + cur_diff = static_cast(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i])); + EXPECT_LT(static_cast(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i])), + diff); + } + if (cur_diff > max_diff) { + max_diff = cur_diff; + index = i; + } + } + std::string error_type = is_relative_atol ? "relative" : "absolute"; + LOG(INFO) << "[" << name << "], The dims is [" << cpu_res.dims() + << "], maximum " << error_type << " error is " << max_diff << ": " + << cpu_res_ptr[index] << " vs " << cpu_base_ptr[index]; +} + +template +void ComputeSumAndSquareSum(const framework::Tensor &cpu_x, + framework::Tensor *cpu_sum, + framework::Tensor *cpu_sum_of_square) { + // x is in NHWC format. + auto dims = cpu_x.dims(); + int64_t c = dims[3]; + + const T *cpu_x_ptr = cpu_x.data(); + float *cpu_sum_ptr = + cpu_sum->mutable_data({1, 1, 1, c}, platform::CPUPlace()); + float *cpu_sum_square_ptr = cpu_sum_of_square->mutable_data( + {1, 1, 1, c}, platform::CPUPlace()); + + for (int j = 0; j < c; ++j) { + float tmp_sum = 0.0f; + float tmp_sum_of_squares = 0.0f; + for (int i = 0; i < cpu_x.numel() / c; ++i) { + float tmp_x = static_cast(cpu_x_ptr[i * c + j]); + tmp_sum += tmp_x; + tmp_sum_of_squares += tmp_x * tmp_x; + } + cpu_sum_ptr[j] = tmp_sum; + cpu_sum_square_ptr[j] = tmp_sum_of_squares; + } +} + +// get paddle batchnorm op results as baseline +void ComputeBatchNormForward(const platform::CUDADeviceContext &ctx, + const Tensor &cpu_x, const Tensor &cpu_scale, + const Tensor &cpu_bias, Tensor *cpu_mean, + Tensor *cpu_var, Tensor *cpu_saved_mean, + Tensor *cpu_saved_var, Tensor *cpu_y, + Tensor *cpu_reserve_space) { + framework::Scope scope; + auto *x = scope.Var("X")->GetMutable(); + auto *scale = scope.Var("Scale")->GetMutable(); + auto *bias = scope.Var("Bias")->GetMutable(); + auto *mean = scope.Var("Mean")->GetMutable(); + auto *var = scope.Var("Variance")->GetMutable(); + auto *y = scope.Var("Y")->GetMutable(); + auto *saved_mean = scope.Var("SavedMean")->GetMutable(); + auto *saved_var = + scope.Var("SavedVariance")->GetMutable(); + auto *reserve_space = + scope.Var("ReserveSpace")->GetMutable(); + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_x, place, x); + TensorCopySync(cpu_scale, place, scale); + TensorCopySync(cpu_bias, place, bias); + TensorCopySync(*cpu_mean, place, mean); + TensorCopySync(*cpu_var, place, var); + + int64_t channels = x->dims()[3]; + scale->Resize({channels}); + bias->Resize({channels}); + mean->Resize({channels}); + var->Resize({channels}); + + framework::AttributeMap attrs; + std::string data_layout = "NHWC"; + attrs.insert({"data_layout", data_layout}); + + auto op = framework::OpRegistry::CreateOp( + "batch_norm", {{"X", {"X"}}, + {"Scale", {"Scale"}}, + {"Bias", {"Bias"}}, + {"Mean", {"Mean"}}, + {"Variance", {"Variance"}}}, + {{"Y", {"Y"}}, + {"MeanOut", {"Mean"}}, + {"VarianceOut", {"Variance"}}, + {"SavedMean", {"SavedMean"}}, + {"SavedVariance", {"SavedVariance"}}, + {"ReserveSpace", {"ReserveSpace"}}}, + attrs); + op->Run(scope, ctx.GetPlace()); + + TensorCopySync(*y, platform::CPUPlace(), cpu_y); + TensorCopySync(*mean, platform::CPUPlace(), cpu_mean); + TensorCopySync(*var, platform::CPUPlace(), cpu_var); + TensorCopySync(*saved_mean, platform::CPUPlace(), cpu_saved_mean); + TensorCopySync(*saved_var, platform::CPUPlace(), cpu_saved_var); + TensorCopySync(*reserve_space, platform::CPUPlace(), cpu_reserve_space); +} + +template +class CudnnBNAddReluTester { + public: + CudnnBNAddReluTester(int batch_size, int height, int width, int channels) { + batch_size_ = batch_size; + height_ = height; + width_ = width; + channels_ = channels; + ele_count_ = batch_size_ * height_ * width_; + SetUp(); + } + + ~CudnnBNAddReluTester() {} + + void CheckForward(float diff, bool is_relative_atol = false) { + platform::CUDADeviceContext *ctx = + static_cast( + platform::DeviceContextPool::Instance().Get( + platform::CUDAPlace(0))); + + framework::Tensor cpu_mean_base; + framework::Tensor cpu_var_base; + framework::Tensor cpu_saved_mean_base; + framework::Tensor cpu_saved_var_base; + framework::Tensor cpu_y_base; + framework::Tensor cpu_reserve_space_base; + BaselineForward(*ctx, &cpu_mean_base, &cpu_var_base, &cpu_saved_mean_base, + &cpu_saved_var_base, &cpu_y_base, &cpu_reserve_space_base); + + framework::Tensor cpu_mean; + framework::Tensor cpu_var; + framework::Tensor cpu_saved_mean; + framework::Tensor cpu_saved_var; + framework::Tensor cpu_y; + framework::Tensor cpu_bitmask; + FusedForward(*ctx, &cpu_mean, &cpu_var, &cpu_saved_mean, &cpu_saved_var, + &cpu_y, &cpu_bitmask); + + CheckOutput("Mean", cpu_mean, cpu_mean_base, diff, is_relative_atol); + CheckOutput("Variance", cpu_var, cpu_var_base, diff, + is_relative_atol); + CheckOutput("SavedMean", cpu_saved_mean, cpu_saved_mean_base, diff, + is_relative_atol); + CheckOutput("SavedVariance", cpu_saved_var, cpu_saved_var_base, diff, + is_relative_atol); + CheckOutput("Y", cpu_y, cpu_y_base, diff, is_relative_atol); + } + + private: + void SetUp() { + // Initialize input data + InitRandomTensor({batch_size_, height_, width_, channels_}, &cpu_x_); + ComputeSumAndSquareSum(cpu_x_, &cpu_sum_, &cpu_sum_of_square_); + + // scale and bias should be initialized randomly. + InitConstantTensor({channels_}, static_cast(1.0f), + &cpu_bn_scale_); + InitConstantTensor({channels_}, static_cast(0.0f), + &cpu_bn_bias_); + } + + void InitMeanVar(Tensor *cpu_mean, Tensor *cpu_var, Tensor *cpu_saved_mean, + Tensor *cpu_saved_var) { + InitConstantTensor({channels_}, static_cast(0.0f), cpu_mean); + InitConstantTensor({channels_}, static_cast(1.0f), cpu_var); + InitConstantTensor({channels_}, static_cast(0.0f), + cpu_saved_mean); + InitConstantTensor({channels_}, static_cast(0.0f), + cpu_saved_var); + } + + void BaselineForward(const platform::CUDADeviceContext &ctx, Tensor *cpu_mean, + Tensor *cpu_var, Tensor *cpu_saved_mean, + Tensor *cpu_saved_var, Tensor *cpu_y, + Tensor *cpu_reserve_space) { + InitMeanVar(cpu_mean, cpu_var, cpu_saved_mean, cpu_saved_var); + ComputeBatchNormForward(ctx, cpu_x_, cpu_bn_scale_, cpu_bn_bias_, cpu_mean, + cpu_var, cpu_saved_mean, cpu_saved_var, cpu_y, + cpu_reserve_space); + } + + // Get forward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu + void FusedForward(const platform::CUDADeviceContext &ctx, Tensor *cpu_mean, + Tensor *cpu_var, Tensor *cpu_saved_mean, + Tensor *cpu_saved_var, Tensor *cpu_y, Tensor *cpu_bitmask) { + framework::Tensor x; + framework::Tensor sum; + framework::Tensor sum_of_square; + framework::Tensor bn_scale; + framework::Tensor bn_bias; + + auto place = ctx.GetPlace(); + TensorCopySync(cpu_x_, place, &x); + TensorCopySync(cpu_sum_, place, &sum); + TensorCopySync(cpu_sum_of_square_, place, &sum_of_square); + TensorCopySync(cpu_bn_scale_, place, &bn_scale); + TensorCopySync(cpu_bn_bias_, place, &bn_bias); + + bn_scale.Resize({1, 1, 1, channels_}); + bn_bias.Resize({1, 1, 1, channels_}); + + T *x_ptr = x.data(); + float *sum_ptr = sum.data(); + float *sum_of_square_ptr = sum_of_square.data(); + float *bn_scale_ptr = bn_scale.data(); + float *bn_bias_ptr = bn_bias.data(); + + framework::Tensor mean; + framework::Tensor var; + framework::Tensor saved_mean; + framework::Tensor saved_var; + framework::Tensor equiv_scale; + framework::Tensor equiv_bias; + framework::Tensor y; + framework::Tensor bitmask; + + InitMeanVar(cpu_mean, cpu_var, cpu_saved_mean, cpu_saved_var); + TensorCopySync(*cpu_mean, place, &mean); + TensorCopySync(*cpu_var, place, &var); + + mean.Resize({1, 1, 1, channels_}); + var.Resize({1, 1, 1, channels_}); + + float *mean_ptr = mean.data(); + float *var_ptr = var.data(); + float *saved_mean_ptr = + saved_mean.mutable_data({1, 1, 1, channels_}, place); + float *saved_var_ptr = + saved_var.mutable_data({1, 1, 1, channels_}, place); + T *equiv_scale_ptr = + equiv_scale.mutable_data({1, 1, 1, channels_}, place); + T *equiv_bias_ptr = equiv_bias.mutable_data({1, 1, 1, channels_}, place); + T *y_ptr = + y.mutable_data({batch_size_, height_, width_, channels_}, place); + + // bitmask + int c = channels_; + int64_t nhw = ele_count_; + int32_t c_int32_elems = ((c + 63) & ~63) / 32; + int32_t nhw_int32_elems = (nhw + 31) & ~31; + int32_t *bitmask_ptr = bitmask.mutable_data( + {nhw_int32_elems, c_int32_elems, 1}, place); + + auto data_shape = framework::vectorize(x.dims()); + auto param_shape = framework::vectorize(bn_scale.dims()); + auto bitmask_shape = framework::vectorize(bitmask.dims()); + + // 1. BN Stats Finalize + op::CudnnBNStatsFinalize bn_op(ctx, param_shape); + bn_op.Forward(ctx, sum_ptr, sum_of_square_ptr, bn_scale_ptr, bn_bias_ptr, + saved_mean_ptr, saved_var_ptr, mean_ptr, var_ptr, + equiv_scale_ptr, equiv_bias_ptr, eps_, momentum_, ele_count_, + true); + + // 2. Scale Bias + Relu (not fused add) + std::string act_type = ""; + op::CudnnScaleBiasAddRelu sbar_op( + ctx, act_type, false, false, data_shape, param_shape, bitmask_shape); + sbar_op.Forward(ctx, x_ptr, equiv_scale_ptr, equiv_bias_ptr, y_ptr, + bitmask_ptr); + + TensorCopySync(mean, platform::CPUPlace(), cpu_mean); + TensorCopySync(var, platform::CPUPlace(), cpu_var); + TensorCopySync(saved_mean, platform::CPUPlace(), cpu_saved_mean); + TensorCopySync(saved_var, platform::CPUPlace(), cpu_saved_var); + TensorCopySync(y, platform::CPUPlace(), cpu_y); + TensorCopySync(bitmask, platform::CPUPlace(), cpu_bitmask); + } + + private: + int batch_size_; + int height_; + int width_; + int channels_; + int ele_count_; + + // Forward input + framework::Tensor cpu_x_; + framework::Tensor cpu_sum_; + framework::Tensor cpu_sum_of_square_; + framework::Tensor cpu_bn_scale_; + framework::Tensor cpu_bn_bias_; + + double eps_ = 1e-5; + float momentum_ = 0.9; +}; + +TEST(CudnnBNAddReluForward, GPUCudnnBNAddReluForwardFp16) { + int batch_size = 4; + int height = 8; + int width = 8; + int channels = 64; + FLAGS_cudnn_batchnorm_spatial_persistent = true; + CudnnBNAddReluTester test(batch_size, height, + width, channels); + test.CheckForward(2e-3); +} diff --git a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h new file mode 100644 index 0000000000000000000000000000000000000000..7d4b24cd4fc3de427b02e1230fbcadc4cc01e3ad --- /dev/null +++ b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h @@ -0,0 +1,181 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" +#include "paddle/fluid/platform/cudnn_desc.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +namespace dynload = platform::dynload; +template +using BatchNormParamType = + typename platform::CudnnDataType::BatchNormParamType; + +#if CUDNN_VERSION >= 8000 + +template +struct BNStatsFinalizeArgs { + BNStatsFinalizeArgs() { + dtype = platform::CudnnDataType::type; + param_dtype = platform::CudnnDataType>::type; + format = CUDNN_TENSOR_NHWC; + } + + void Set(const std::vector ¶m_shape) { + PADDLE_ENFORCE_EQ( + param_shape.size(), 4U, + platform::errors::InvalidArgument( + "The size of param_shape is expected to 4. But recieved " + "param_shape's size is %d, param_shape is [%s].", + param_shape.size(), framework::make_ddim(param_shape))); + + in_desc.set(param_shape, format, param_dtype); + out_desc.set(param_shape, format, dtype); + } + + cudnnDataType_t dtype; + cudnnDataType_t param_dtype; + cudnnTensorFormat_t format; + + platform::TensorDescriptor in_desc; + platform::TensorDescriptor out_desc; +}; + +template +class CudnnBNStatsFinalize { + public: + CudnnBNStatsFinalize(const platform::CUDADeviceContext &ctx, + const std::vector ¶m_shape) + : train_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING), + inference_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE) { + args_.Set(param_shape); + } + ~CudnnBNStatsFinalize() {} + + void Forward(const platform::CUDADeviceContext &ctx, float *sum_ptr, + float *sum_of_squares_ptr, float *scale_ptr, float *bias_ptr, + float *saved_mean_ptr, float *saved_invstd_ptr, + float *running_mean_ptr, float *running_var_ptr, + T *equiv_scale_ptr, T *equiv_bias_ptr, double eps, + float momentum, int64_t ele_count, bool is_train) { + if (is_train) { + TrainInit(ctx); + } else { + InferenceInit(ctx); + } + auto &op = is_train ? train_op_ : inference_op_; + + // Set variant_param for both inference_op_ and train_op_ + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_VAR, running_var_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, equiv_scale_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, equiv_bias_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EPSILON, &eps); + + // Set extra variant_param only for train_op_: + if (is_train) { + op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr); + op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD, saved_invstd_ptr); + double avg_factor = 1.0 - momentum; + op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT, + &ele_count); + op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR, + &avg_factor); + } + // fused op execute + auto handle = ctx.cudnn_handle(); + op.Execute(handle); + } + + private: + void TrainInit(const platform::CUDADeviceContext &ctx) { + // Set constant_param for train op + train_op_.SetOpConstParamAttr( + {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER, + CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER, + CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER, + CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER, + CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER, + CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER, + CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + // Set input and output desc for train op + train_op_.SetOpConstParamDesc( + {CUDNN_PARAM_YSTATS_DESC, CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC}, + args_.in_desc.desc()); + train_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC, + args_.out_desc.desc()); + + // Get workspace + auto handle = ctx.cudnn_handle(); + train_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + // Check workspace size, also creates plan. + size_t workspace_size_bytes = train_op_.GetWorkspaceSizeInBytes(handle); + PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U, + platform::errors::InvalidArgument( + "Unexpected non-zero workspace size for " + "CudnnBNStatsFinalize.")); + train_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, + static_cast(nullptr)); + train_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, + &workspace_size_bytes); + } + + void InferenceInit(const platform::CUDADeviceContext &ctx) { + // Set constant_param for inference op + inference_op_.SetOpConstParamAttr( + {CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER, + CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER, + CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER, + CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + // Set input and output desc for inference op + inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC, + args_.in_desc.desc()); + inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC, + args_.out_desc.desc()); + + // Get workspace + auto handle = ctx.cudnn_handle(); + inference_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + // Check workspace size, also creates plan. + size_t workspace_size_bytes = inference_op_.GetWorkspaceSizeInBytes(handle); + PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U, + platform::errors::InvalidArgument( + "Unexpected non-zero workspace size for " + "CudnnBNStatsFinalize.")); + inference_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, + static_cast(nullptr)); + inference_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, + &workspace_size_bytes); + } + + BNStatsFinalizeArgs args_; + CudnnFusionOp train_op_; + CudnnFusionOp inference_op_; +}; +#endif +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h new file mode 100644 index 0000000000000000000000000000000000000000..2fdb3635e2e1496e41f0a4642fb34455f490e957 --- /dev/null +++ b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h @@ -0,0 +1,292 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" +#include "paddle/fluid/platform/cudnn_desc.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +template +using CudnnDataType = platform::CudnnDataType; +namespace dynload = platform::dynload; +template +using BatchNormParamType = + typename platform::CudnnDataType::BatchNormParamType; + +#if CUDNN_VERSION >= 8000 + +template +struct ScaleBiasAddReluArgs { + ScaleBiasAddReluArgs() { + dtype = platform::CudnnDataType::type; + param_dtype = platform::CudnnDataType>::type; + format = CUDNN_TENSOR_NHWC; + } + + void Set(const std::string &act_type, const std::vector &data_shape, + const std::vector ¶m_shape, + const std::vector &bitmask_shape) { + PADDLE_ENFORCE_EQ( + data_shape.size(), 4U, + platform::errors::InvalidArgument( + "The size of data_shape is expected to 4. But recieved " + "data_shape's size is %d, data_shape is [%s].", + data_shape.size(), framework::make_ddim(data_shape))); + PADDLE_ENFORCE_EQ( + param_shape.size(), 4U, + platform::errors::InvalidArgument( + "The size of param_shape is expected to 4. But recieved " + "param_shape's size is %d, param_shape is [%s].", + param_shape.size(), framework::make_ddim(param_shape))); + PADDLE_ENFORCE_EQ( + bitmask_shape.size(), 3U, + platform::errors::InvalidArgument( + "The size of bitmask_shape is expected to 3. But recieved " + "bitmask_shape's size is %d, bitmask_shape is [%s].", + bitmask_shape.size(), framework::make_ddim(bitmask_shape))); + + in_desc.set(data_shape, format, dtype); + out_desc.set(data_shape, format, dtype); + equiv_scale_bias_desc.set(param_shape, format, dtype); + scale_bias_mean_var_desc.set(param_shape, format, param_dtype); + bitmask_desc.set(bitmask_shape, format, CUDNN_DATA_INT32); + // set activation desc + cudnnActivationMode_t mode = CUDNN_ACTIVATION_IDENTITY; + if (act_type != "") { + PADDLE_ENFORCE_EQ( + act_type, "relu", + platform::errors::InvalidArgument( + "Only relu activation supported in normalized convolution.")); + mode = CUDNN_ACTIVATION_RELU; + } + double dummy_clip = 0.0; + activation_desc.set(mode, dummy_clip); + } + + cudnnDataType_t dtype; + cudnnDataType_t param_dtype; + cudnnTensorFormat_t format; + + platform::TensorDescriptor in_desc; + platform::TensorDescriptor out_desc; + platform::TensorDescriptor equiv_scale_bias_desc; + platform::TensorDescriptor scale_bias_mean_var_desc; + platform::TensorDescriptor bitmask_desc; + platform::ActivationDescriptor activation_desc; +}; + +template +class CudnnScaleBiasAddRelu { + public: + CudnnScaleBiasAddRelu(const platform::CUDADeviceContext &ctx, + const std::string &act_type, bool fused_add, + bool has_shortcut, const std::vector &data_shape, + const std::vector ¶m_shape, + const std::vector &bitmask_shape) + : fwd_op_(CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK), + bwd_op_(CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM) { + fused_add_ = fused_add; + has_shortcut_ = has_shortcut; + args_.Set(act_type, data_shape, param_shape, bitmask_shape); + } + + ~CudnnScaleBiasAddRelu() {} + + void Forward(const platform::CUDADeviceContext &ctx, T *x_ptr, T *x_scale_ptr, + T *x_bias_ptr, T *out_ptr, int32_t *bitmask_ptr, + T *z_ptr = nullptr, T *z_scale_ptr = nullptr, + T *z_bias_ptr = nullptr) { + ForwardInit(ctx); + auto handle = ctx.cudnn_handle(); + auto workspace_handle = ctx.cudnn_workspace_handle(); + fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle); + // Set variant_param + // input ptr + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr); + if (has_shortcut_) { + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr); + } else { + if (fused_add_) { + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr); + } + } + + fwd_op_.SetOpVariantParamAttrPtr( + CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_); + + // output ptr + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, out_ptr); + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr); + + workspace_handle.RunFunc( + [&](void *workspace_ptr) { + // workspace ptr + fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr); + // workspace ptr + fwd_op_.Execute(handle); + }, + fwd_workspace_byte_); + } + + void Backward(const platform::CUDADeviceContext &ctx, T *dy_ptr, T *x_ptr, + float *scale_ptr, float *bias_ptr, float *saved_mean_ptr, + float *saved_invstd_ptr, int32_t *bitmask_ptr, T *dx_ptr, + T *dz_ptr, float *dscale_ptr, float *dbias_ptr, double eps) { + BackwardInit(ctx); + auto handle = ctx.cudnn_handle(); + auto workspace_handle = ctx.cudnn_workspace_handle(); + bwd_workspace_byte_ = bwd_op_.GetWorkspaceSizeInBytes(handle); + // Set variant_param + // input ptr + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, dy_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD, + saved_invstd_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr); + + bwd_op_.SetOpVariantParamAttrPtr( + CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &bwd_workspace_byte_); + + // output ptr + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DXDATA, dx_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DSCALE, dscale_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DBIAS, dbias_ptr); + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EPSILON, + &eps); + if (has_shortcut_ || fused_add_) { + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DZDATA, dz_ptr); + } + + workspace_handle.RunFunc( + [&](void *workspace_ptr) { + // workspace ptr + bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr); + // workspace ptr + bwd_op_.Execute(handle); + }, + bwd_workspace_byte_); + } + + private: + void ForwardInit(const platform::CUDADeviceContext &ctx) { + // Set constant_param + fwd_op_.SetOpConstParamAttr( + {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, + CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER, CUDNN_PARAM_YDATA_PLACEHOLDER, + CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + if (has_shortcut_) { + fwd_op_.SetOpConstParamAttr( + {CUDNN_PARAM_ZDATA_PLACEHOLDER, CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER, + CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + } else if (fused_add_) { + fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_ZDATA_PLACEHOLDER, + CUDNN_PTR_16B_ALIGNED); + } + + // input desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc()); + if (has_shortcut_ || fused_add_) { + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ZDESC, args_.in_desc.desc()); + } + + // equiv scale/bias desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC, + args_.equiv_scale_bias_desc.desc()); + if (has_shortcut_) { + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC, + args_.equiv_scale_bias_desc.desc()); + } + + // output desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc()); + + // bitmask desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_BITMASK_DESC, + args_.bitmask_desc.desc()); + + // activation desc + fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_DESC, + args_.activation_desc.desc()); + + // others + fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + } + + void BackwardInit(const platform::CUDADeviceContext &ctx) { + // Set constant_param + bwd_op_.SetOpConstParamAttr( + {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_DYDATA_PLACEHOLDER, + CUDNN_PARAM_DXDATA_PLACEHOLDER, CUDNN_PARAM_BN_SCALE_PLACEHOLDER, + CUDNN_PARAM_BN_BIAS_PLACEHOLDER, CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER, + CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER, + CUDNN_PARAM_BN_DSCALE_PLACEHOLDER, CUDNN_PARAM_BN_DBIAS_PLACEHOLDER, + CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER}, + CUDNN_PTR_16B_ALIGNED); + if (has_shortcut_ || fused_add_) { + bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_DZDATA_PLACEHOLDER, + CUDNN_PTR_16B_ALIGNED); + } + + // input desc + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc()); + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DXDESC, args_.in_desc.desc()); + if (has_shortcut_ || fused_add_) { + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DZDESC, args_.in_desc.desc()); + } + + // scale/bias/mean/var desc for backward + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC, + args_.scale_bias_mean_var_desc.desc()); + + // output desc + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DYDESC, args_.out_desc.desc()); + + // bitmask desc + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_BITMASK_DESC, + args_.bitmask_desc.desc()); + + // activation desc + bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_DESC, + args_.activation_desc.desc()); + + // others + bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT); + } + + bool fused_add_ = false; + bool has_shortcut_ = false; + size_t fwd_workspace_byte_; + size_t bwd_workspace_byte_; + ScaleBiasAddReluArgs args_; + CudnnFusionOp fwd_op_; + CudnnFusionOp bwd_op_; +}; +#endif +} // namespace operators +} // namespace paddle