// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/instance_norm_kernel.h"

#include "paddle/fluid/operators/norm_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/instance_norm_utils.h"

namespace phi {

template <typename T, typename Context>
void InstanceNormKernel(const Context &dev_ctx,
                        const DenseTensor &x,
                        paddle::optional<const DenseTensor &> scale,
                        paddle::optional<const DenseTensor &> bias,
                        float epsilon_f,
                        DenseTensor *y,
                        DenseTensor *saved_mean,
                        DenseTensor *saved_variance) {
  double epsilon = static_cast<double>(epsilon_f);
  auto &x_dims = x.dims();
  PADDLE_ENFORCE_GE(x_dims.size(),
                    2,
                    phi::errors::InvalidArgument(
                        "The `shape` in InstanceNormOp is invalid: "
                        "the size of X's dimensions must be greater than "
                        "or equal to 2. But received: "
                        "the size of X's dimensions is [%d]",
                        x_dims.size()));
  PADDLE_ENFORCE_LE(x_dims.size(),
                    5,
                    phi::errors::InvalidArgument(
                        "The `shape` in InstanceNormOp is invalid: "
                        "the size of X's dimensions must be smaller than "
                        "or equal to 5. But received: "
                        "the size of X's dimensions is [%d]",
                        x_dims.size()));
  int N, C, H, W, D;
  paddle::operators::ExtractNCWHD(
      x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
  int NxC = N * C;
  // View the input as (1, N * C, H, W, D) so that spatial batch normalization
  // computes one mean/variance per (n, c) pair, which is exactly instance
  // normalization.
  DenseTensor x_tmp;
  x_tmp.ShareDataWith(x).Resize({1, NxC, H, W, D});
  dev_ctx.template Alloc<T>(y);

#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t data_desc_;
  miopenTensorDescriptor_t in_param_desc_;

  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::miopenCreateTensorDescriptor(&in_param_desc_));
#else
  cudnnTensorDescriptor_t data_desc_;
  cudnnTensorDescriptor_t in_param_desc_;

  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnCreateTensorDescriptor(&in_param_desc_));
#endif

  if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
    LOG(ERROR) << "Provided epsilon is smaller than "
               << "CUDNN_BN_MIN_EPSILON. Setting it to "
               << "CUDNN_BN_MIN_EPSILON instead.";
  }
  epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);

  VLOG(3) << "Setting descriptors.";
  std::vector<int> dims;
  std::vector<int> strides;
  dims = {1, NxC, H, W, D};
  strides = {NxC * H * W * D, H * W * D, W * D, D, 1};

#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::miopenSetTensorDescriptor(
          data_desc_,
          CudnnDataType<T>::type,
          x_dims.size() > 3 ? x_dims.size() : 4,
          const_cast<int *>(dims.data()),
          const_cast<int *>(strides.data())));
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::miopenDeriveBNTensorDescriptor(
          in_param_desc_, data_desc_, miopenBNSpatial));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnSetTensorNdDescriptor(
          data_desc_,
          CudnnDataType<T>::type,
          x_dims.size() > 3 ? x_dims.size() : 4,
          dims.data(),
          strides.data()));
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnDeriveBNTensorDescriptor(
          in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL));
#endif

  const auto scale_ptr = scale.get_ptr();
  const auto bias_ptr = bias.get_ptr();

  DenseTensor scale_tmp;
  scale_tmp.Resize({NxC});
  dev_ctx.template Alloc<T>(&scale_tmp);
  DenseTensor bias_tmp;
  bias_tmp.Resize({NxC});
  dev_ctx.template Alloc<T>(&bias_tmp);

  const int n = x.numel();
  const int block = 512;
  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / block, 1);
  const int grid = std::min((NxC + block - 1) / block, max_blocks);

  phi::funcs::SetConstant<GPUContext, T> set_constant;
  // Tile the per-channel scale/bias of length C to length N * C; if they are
  // not provided, fall back to scale = 1 and bias = 0.
  if (scale_ptr) {
    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
        scale_ptr->data<T>(), scale_tmp.data<T>(), N, C);
  } else {
    set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
  }
  if (bias_ptr) {
    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
        bias_ptr->data<T>(), bias_tmp.data<T>(), N, C);
  } else {
    set_constant(dev_ctx, &bias_tmp, static_cast<T>(0));
  }

  auto handle = dev_ctx.cudnn_handle();

  phi::funcs::SetConstant<GPUContext, BatchNormParamType<T>> functor;
  dev_ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
  dev_ctx.template Alloc<BatchNormParamType<T>>(saved_variance);
  functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
  functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));

  // Running statistics are not tracked (factor 0, nullptr buffers); only
  // saved_mean and saved_variance are produced by the forward-training call.
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::miopenBatchNormalizationForwardTraining(
          handle,
          miopenBNSpatial,
          const_cast<void *>(
              static_cast<const void *>(CudnnDataType<T>::kOne())),
          const_cast<void *>(
              static_cast<const void *>(CudnnDataType<T>::kZero())),
          data_desc_,
          static_cast<const void *>(x_tmp.template data<T>()),
          data_desc_,
          static_cast<void *>(y->template data<T>()),
          in_param_desc_,
          const_cast<void *>(static_cast<const void *>(
              scale_tmp.template data<BatchNormParamType<T>>())),
          const_cast<void *>(static_cast<const void *>(
              bias_tmp.template data<BatchNormParamType<T>>())),
          0,
          nullptr,
          nullptr,
          epsilon,
          static_cast<void *>(
              saved_mean->template data<BatchNormParamType<T>>()),
          static_cast<void *>(
              saved_variance->template data<BatchNormParamType<T>>())));

  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::miopenDestroyTensorDescriptor(in_param_desc_));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnBatchNormalizationForwardTraining(
          handle,
          CUDNN_BATCHNORM_SPATIAL,
          CudnnDataType<T>::kOne(),
          CudnnDataType<T>::kZero(),
          data_desc_,
          x_tmp.template data<T>(),
          data_desc_,
          y->template data<T>(),
          in_param_desc_,
          scale_tmp.template data<BatchNormParamType<T>>(),
          bias_tmp.template data<BatchNormParamType<T>>(),
          0,
          nullptr,
          nullptr,
          epsilon,
          saved_mean->template data<BatchNormParamType<T>>(),
          saved_variance->template data<BatchNormParamType<T>>()));

  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnDestroyTensorDescriptor(in_param_desc_));
#endif
}

}  // namespace phi

#ifdef PADDLE_WITH_HIP
// MIOpen does not support double.
PD_REGISTER_KERNEL(
    instance_norm, GPU, ALL_LAYOUT, phi::InstanceNormKernel, float) {}
#else
PD_REGISTER_KERNEL(
    instance_norm, GPU, ALL_LAYOUT, phi::InstanceNormKernel, float, double) {}
#endif