/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/instance_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;

// Tiles a length-C vector repeat_num times:
//   output[i] = input[i % C]  for i in [0, repeat_num * C).
// Each thread starts at its global index and advances by the total number of
// launched threads, so any 1-D launch configuration covers the whole output.
template <typename T>
static __global__ void repeat_param(const T *input, T *output,
                                    const int repeat_num, const int C) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < repeat_num * C;
       i += blockDim.x * gridDim.x) {
    int index = i % C;
    output[i] = input[index];
  }
}

// Reduces a (repeat_num x C) row-major buffer over its first axis:
//   output[c] = sum_j input[j * C + c]      (divided by repeat_num if AVG).
// One block handles one channel at a time (channels are strided by gridDim.x)
// and its threads cooperate through cub::BlockReduce, so the kernel must be
// launched with blockDim.x == BlockDim.
template <typename T, int BlockDim, bool AVG>
static __global__ void add_param(const T *input, T *output,
                                 const int repeat_num, const int C) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ou_storage;
  for (int i = blockIdx.x; i < C; i += gridDim.x) {
    T ou = static_cast<T>(0);
    for (int j = threadIdx.x; j < repeat_num; j += blockDim.x) {
      const int index = j * C + i;
      ou += static_cast<T>(input[index]);
    }
    ou = BlockReduce(ou_storage).Reduce(ou, cub::Sum());
    // Only thread 0 holds the valid reduction result. Doing the optional
    // averaging here (instead of letting every thread divide output[i]) keeps
    // the write single-threaded and race free.
    if (threadIdx.x == 0) {
      output[i] = AVG ? ou / repeat_num : ou;
    }
    // ou_storage is reused by the next channel handled by this block.
    __syncthreads();
  }
}

// Forward instance normalization on GPU.
//
// The NC[DHW] input is viewed as a single batch with N*C channels
// (1 x N*C x H x W x D) and handed to cuDNN spatial batch normalization, so
// statistics are computed per (n, c) instance. Scale and Bias (length C) are
// tiled N times to match the N*C "channels".
template <typename T>
class InstanceNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                      "It must be CUDAPlace.");
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));

    auto *x = ctx.Input<Tensor>("X");
    auto &x_dims = x->dims();
    PADDLE_ENFORCE_GE(
        x_dims.size(), 2,
        "the dimension of input X must greater than or equal to 2");
    PADDLE_ENFORCE_LE(
        x_dims.size(), 5,
        "the dimension of input X must smaller than or equal to 5");
    int N, C, H, W, D;
    ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
    int NxC = N * C;
    Tensor x_tmp;
    x_tmp.ShareDataWith(*x).Resize({1, NxC, H, W, D});

    auto *y = ctx.Output<Tensor>("Y");
    y->mutable_data<T>(ctx.GetPlace());

    cudnnTensorDescriptor_t data_desc_;
    cudnnTensorDescriptor_t in_param_desc_;

    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
    CUDNN_ENFORCE(
        platform::dynload::cudnnCreateTensorDescriptor(&in_param_desc_));

    if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
      LOG(ERROR) << "Provided epsilon is smaller than "
                 << "CUDNN_BN_MIN_EPSILON. Setting it to "
                 << "CUDNN_BN_MIN_EPSILON instead.";
    }
    epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);

    VLOG(3) << "Setting descriptors.";
    std::vector<int> dims;
    std::vector<int> strides;
    dims = {1, NxC, H, W, D};
    strides = {NxC * H * W * D, H * W * D, W * D, D, 1};

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
        data_desc_, CudnnDataType<T>::type,
        x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
    CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
        in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL));

    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");

    // Per-instance copies of the per-channel affine parameters.
    Tensor scale_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    scale_tmp.mutable_data<T>(ctx.GetPlace());
    Tensor bias_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    bias_tmp.mutable_data<T>(ctx.GetPlace());

    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    const int grid = std::min((NxC + block - 1) / block, max_blocks);

    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
        scale->data<T>(), scale_tmp.data<T>(), N, C);
    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
        bias->data<T>(), bias_tmp.data<T>(), N, C);

    auto handle = dev_ctx.cudnn_handle();

    math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
        functor;

    auto *saved_mean = ctx.Output<Tensor>("SavedMean");
    auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
    saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
    saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
    functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
    functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));

    // No running statistics are tracked: the exponential-average factor is 0
    // and the running mean/variance pointers are null.
    CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
        handle, CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
        CudnnDataType<T>::kZero(), data_desc_, x_tmp.template data<T>(),
        data_desc_, y->template mutable_data<T>(ctx.GetPlace()), in_param_desc_,
        scale_tmp.template data<BatchNormParamType<T>>(),
        bias_tmp.template data<BatchNormParamType<T>>(), 0, nullptr, nullptr,
        epsilon,
        saved_mean->template mutable_data<BatchNormParamType<T>>(
            ctx.GetPlace()),
        saved_variance->template mutable_data<BatchNormParamType<T>>(
            ctx.GetPlace())));

    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
    CUDNN_ENFORCE(
        platform::dynload::cudnnDestroyTensorDescriptor(in_param_desc_));
  }
};

// Computes dX only (used when the parameter gradients are not requested).
//
// Launch contract: gridDim.x == N*C (blockIdx.x is the instance id and the
// kernel does not stride over instances) and blockDim.x == BlockDim.
// `variance` is consumed as a multiplicative factor (inv_var_val), i.e. it is
// expected to hold the saved inverse standard deviation.
//
//   dx = (dy - mean(dy) - (x - mean) * sum(dy * (x - mean)) * inv^2 / m)
//        * scale[c] * inv
template <typename T, int BlockDim>
static __global__ void GradComputeDX(const T *dy,
                                     const BatchNormParamType<T> *scale,
                                     const BatchNormParamType<T> *mean,
                                     const T *x,
                                     const BatchNormParamType<T> *variance,
                                     const int C, const int sample_size,
                                     T *dx) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  BatchNormParamType<T> mean_val = mean[ncid];
  BatchNormParamType<T> inv_var_val = variance[ncid];

  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
  __shared__ BatchNormParamType<T> dy_sum_val;
  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;

  BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
  BatchNormParamType<T> dy_x_sub_mean_sum =
      static_cast<BatchNormParamType<T>>(0);

  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    BatchNormParamType<T> dy_i = static_cast<BatchNormParamType<T>>(dy[i]);
    dy_sum += dy_i;
    dy_x_sub_mean_sum +=
        dy_i * (static_cast<BatchNormParamType<T>>(x[i]) - mean_val);
  }
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  dy_x_sub_mean_sum =
      BlockReduce(dy_x_sub_mean_storage).Reduce(dy_x_sub_mean_sum, cub::Sum());

  // Broadcast the block-wide sums (valid in thread 0 only) to all threads.
  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
  }
  __syncthreads();

  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    dx[i] =
        (static_cast<BatchNormParamType<T>>(dy[i]) -
         dy_sum_val / static_cast<BatchNormParamType<T>>(sample_size) -
         (static_cast<BatchNormParamType<T>>(x[i]) - mean_val) *
             dy_x_sub_mean_sum_val * inv_var_val * inv_var_val / sample_size) *
        scale[c] * inv_var_val;
  }
}

// Backward instance normalization on GPU.
//
// When both parameter gradients are requested, cuDNN batch-norm backward is
// run on the 1 x N*C x H x W x D view and the per-instance parameter
// gradients are summed over N with add_param. Otherwise only dX is computed
// with GradComputeDX.
template <typename T>
class InstanceNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                      "It must use CUDAPlace.");
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *x = ctx.Input<Tensor>("X");
    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const auto &x_dims = x->dims();

    int N, C, H, W, D;
    ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
    int NxC = N * C;

    Tensor x_tmp, d_y_tmp;
    x_tmp.ShareDataWith(*x).Resize({1, NxC, H, W, D});
    d_y_tmp.ShareDataWith(*d_y).Resize({1, NxC, H, W, D});

    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    d_x->mutable_data<T>(ctx.GetPlace());
    if (d_scale && d_bias) {
      d_scale->mutable_data<T>(ctx.GetPlace());
      d_bias->mutable_data<T>(ctx.GetPlace());
    }
    PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
    PADDLE_ENFORCE_EQ(scale->dims()[0], C);

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    // Degenerate case: one element per instance. Handled before any temporary
    // buffers are allocated or kernels are launched.
    if ((H * W * D) == 1) {
      framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
      // d_scale / d_bias are optional outputs and are only allocated above
      // when both are present, so they must not be touched otherwise.
      // NOTE(review): both are zero-filled here; confirm that zero is the
      // intended d_bias when every instance has a single element.
      if (d_scale && d_bias) {
        math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
            functor;
        functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
        functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
      }
      return;
    }

    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    const int grid = std::min(NxC, max_blocks);
    const int grid1 = (C + block - 1) / block;

    Tensor scale_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    scale_tmp.mutable_data<T>(ctx.GetPlace());
    Tensor d_scale_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    Tensor d_bias_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
        scale->data<T>(), scale_tmp.data<T>(), N, C);

    std::vector<int> dims;
    std::vector<int> strides;
    dims = {1, NxC, H, W, D};
    strides = {NxC * H * W * D, H * W * D, W * D, D, 1};

    cudnnTensorDescriptor_t data_desc_;
    cudnnTensorDescriptor_t in_param_desc_;

    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
    CUDNN_ENFORCE(
        platform::dynload::cudnnCreateTensorDescriptor(&in_param_desc_));
    if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
      LOG(ERROR) << "Provided epsilon is smaller than "
                 << "CUDNN_BN_MIN_EPSILON. Setting it to "
                 << "CUDNN_BN_MIN_EPSILON instead.";
    }
    epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);

    CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
        data_desc_, CudnnDataType<T>::type,
        x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
    CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
        in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL));

    const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
    const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
    const auto *saved_mean_data =
        saved_mean->template data<BatchNormParamType<T>>();
    const auto *saved_var_data =
        saved_var->template data<BatchNormParamType<T>>();
    if (d_scale && d_bias) {
      CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
          dev_ctx.cudnn_handle(), CUDNN_BATCHNORM_SPATIAL,
          CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
          CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
          x_tmp.template data<T>(), data_desc_, d_y_tmp.template data<T>(),
          data_desc_, d_x->template mutable_data<T>(ctx.GetPlace()),
          in_param_desc_, scale_tmp.template data<BatchNormParamType<T>>(),
          d_scale_tmp.template mutable_data<BatchNormParamType<T>>(
              ctx.GetPlace()),
          d_bias_tmp.template mutable_data<BatchNormParamType<T>>(
              ctx.GetPlace()),
          epsilon, saved_mean_data, saved_var_data));
    } else {
      if (d_x) {
        // One block per (n, c) instance, as required by GradComputeDX.
        GradComputeDX<T, block><<<NxC, block, 0, dev_ctx.stream()>>>(
            d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
            saved_mean_data, x->data<T>(), saved_var_data, C, H * W * D,
            d_x->data<T>());
      }
    }

    if (d_scale && d_bias) {
      // Fold the N*C per-instance gradients back to C per-channel gradients.
      add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
          d_scale_tmp.data<T>(), d_scale->data<T>(), N, C);
      add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
          d_bias_tmp.data<T>(), d_bias->data<T>(), N, C);
    }

    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
    CUDNN_ENFORCE(
        platform::dynload::cudnnDestroyTensorDescriptor(in_param_desc_));
  }
};

// Returns 1 / sqrt(x). Despite the name this is the reciprocal square root.
static __device__ __forceinline__ float real_sqrt(float x) {
  // Float literal avoids promoting the division to double.
  return 1.0f / sqrtf(x);
}
// Returns 1 / sqrt(x) in double precision.
static __device__ __forceinline__ double real_sqrt(double x) {
  return 1. / sqrt(x);
}

// Second-order backward: accumulates (+=) into dx the contributions that
// depend on ddx and on ddscale. Either of `ddx` / `ddscale` may be null, in
// which case its contribution is skipped.
//
// Launch contract: gridDim.x == N*C, blockDim.x == BlockDim; dx must be
// initialized by the caller. `variance` is used as a multiplicative factor
// (presumably the saved inverse standard deviation - confirm against the
// forward kernel). `epsilon` is accepted but not used.
template <typename T, int BlockDim>
__global__ void DoubleGradComputeDX(const T *x, const T *mean,
                                    const T *variance, const T *ddx,
                                    const T *dy, const T *scale,
                                    const T *ddscale, int C, int sample_size,
                                    const double epsilon, T *dx) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  T mean_val = mean[ncid];
  T var_val = variance[ncid];

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T dy_sum_val;
  __shared__ T ddx_sum_val;
  __shared__ T dy_mul_ddx_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  T dy_sum = 0;
  T ddx_sum = 0;
  T dy_mul_ddx_sum = 0;
  T dy_mul_x_sub_mean_sum = 0;
  T ddx_mul_x_sub_mean_sum = 0;
  // ddx is optional: never dereference it when null. The ddx-dependent sums
  // then stay zero and are only consumed inside the `ddx != nullptr` branch.
  const bool has_ddx = (ddx != nullptr);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T ddx_i = has_ddx ? ddx[i] : static_cast<T>(0);
    T dy_i = dy[i];
    T tmp = x[i] - mean_val;

    dy_sum += dy_i;
    ddx_sum += ddx_i;
    dy_mul_ddx_sum += (ddx_i * dy_i);

    dy_mul_x_sub_mean_sum += (dy_i * tmp);
    ddx_mul_x_sub_mean_sum += (ddx_i * tmp);
  }

  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  dy_mul_ddx_sum =
      BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum());
  dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                              .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

  // Broadcast the block-wide sums (valid in thread 0 only) to all threads.
  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    ddx_sum_val = ddx_sum;
    dy_mul_ddx_sum_val = dy_mul_ddx_sum;
    dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();

  if (ddx != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dx[i] +=
          ((x[i] - mean_val) * var_val * var_val * var_val / sample_size *
               (ddx_sum_val * dy_sum_val / sample_size - dy_mul_ddx_sum_val +
                3. * dy_mul_x_sub_mean_sum_val * var_val *
                    ddx_mul_x_sub_mean_sum_val * var_val / sample_size) +
           ddx_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val * (dy_sum_val / sample_size - dy[i]) +
           dy_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val * (ddx_sum_val / sample_size - ddx[i])) *
          scale[c];
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dx[i] += (dy[i] * var_val - dy_sum_val / sample_size * var_val -
                (x[i] - mean_val) * var_val * dy_mul_x_sub_mean_sum_val *
                    var_val / sample_size) *
               ddscale[c];
    }
  }
}

// Second-order backward: accumulates (+=) into ddy the contributions of ddx,
// ddscale and ddbias. Each of those three pointers may be null, in which case
// its contribution is skipped.
//
// Launch contract: gridDim.x == N*C, blockDim.x == BlockDim; ddy must be
// initialized by the caller. `epsilon` is accepted but not used.
template <typename T, int BlockDim>
__global__ void DoubleGradComputeDDY(const T *x, const T *mean,
                                     const T *variance, const T *ddscale,
                                     const T *ddbias, const T *ddx,
                                     const T *scale, int C, int sample_size,
                                     const double epsilon, T *ddy) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  T mean_val = mean[ncid];
  T var_val = variance[ncid];

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T ddx_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  T ddx_sum = 0;
  T ddx_mul_x_sub_mean_sum = 0;
  // ddx is optional: never dereference it when null.
  const bool has_ddx = (ddx != nullptr);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T ddx_i = has_ddx ? ddx[i] : static_cast<T>(0);
    ddx_sum += ddx_i;
    ddx_mul_x_sub_mean_sum += (ddx_i * (x[i] - mean_val));
  }
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

  // Broadcast the block-wide sums (valid in thread 0 only) to all threads.
  if (threadIdx.x == 0) {
    ddx_sum_val = ddx_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();

  if (ddx != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += scale[c] * var_val *
                (ddx[i] - ddx_sum_val / sample_size -
                 (x[i] - mean_val) * var_val * ddx_mul_x_sub_mean_sum_val *
                     var_val / sample_size);
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += (x[i] - mean_val) * var_val * ddscale[c];
    }
  }
  __syncthreads();
  if (ddbias != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += ddbias[c];
    }
  }
}

// Second-order backward: accumulates (+=) into dscale[ncid] the per-instance
// scale gradient that depends on ddx. Nothing is written when ddx is null.
//
// Launch contract: gridDim.x == N*C, blockDim.x == BlockDim; dscale holds one
// entry per (n, c) instance and must be initialized by the caller. `C` and
// `epsilon` are accepted but not used.
template <typename T, int BlockDim>
__global__ void DoubleGradComputeDScale(const T *x, const T *mean,
                                        const T *variance, const T *ddx,
                                        const T *dy, int C, int sample_size,
                                        const double epsilon, T *dscale) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;

  T mean_val = mean[ncid];
  T var_val = variance[ncid];

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage dscale_tmp_storage;
  __shared__ T dy_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;

  T dy_sum = 0;
  T dy_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T dy_i = dy[i];
    dy_sum += dy_i;
    dy_mul_x_sub_mean_sum += (dy_i * (x[i] - mean_val));
  }
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                              .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());

  // Broadcast the block-wide sums (valid in thread 0 only) to all threads.
  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
  }
  __syncthreads();

  // The condition is uniform across the block, so the barrier inside the
  // branch is reached by all threads or by none.
  if (ddx != nullptr) {
    T dscale_tmp = 0;
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dscale_tmp +=
          ddx[i] * var_val * (dy[i] - dy_sum_val / sample_size -
                              dy_mul_x_sub_mean_sum_val * (x[i] - mean_val) *
                                  var_val * var_val / sample_size);
    }
    dscale_tmp = BlockReduce(dscale_tmp_storage).Reduce(dscale_tmp, cub::Sum());

    if (threadIdx.x == 0) {
      dscale[ncid] += dscale_tmp;
    }
    __syncthreads();
  }
}

// Second-order (double) backward of instance normalization on GPU.
//
// DDX, DDScale and DDBias are optional inputs; DX, DScale and DDY are
// optional outputs. Every requested output is zero-filled and then
// accumulated into by the matching DoubleGradCompute* kernel, launched with
// one block per (n, c) instance.
template <typename T>
class InstanceNormDoubleGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *X = ctx.Input<Tensor>("X");
    const auto *Scale = ctx.Input<Tensor>("Scale");
    const auto *dY = ctx.Input<Tensor>("DY");
    const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
    const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
    const auto *ddX = ctx.Input<Tensor>("DDX");
    const auto *ddScale = ctx.Input<Tensor>("DDScale");
    const auto *ddBias = ctx.Input<Tensor>("DDBias");
    const double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));

    auto *dX = ctx.Output<Tensor>("DX");
    auto *dScale = ctx.Output<Tensor>("DScale");
    auto *ddY = ctx.Output<Tensor>("DDY");

    const T *x_data = X->data<T>();
    const T *scale_data = Scale->data<T>();
    const T *dy_data = dY->data<T>();
    const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());

    const T *ddscale_data = (ddScale == nullptr ? nullptr : ddScale->data<T>());
    // Each optional input is gated on its own pointer: DDBias may be absent
    // while DDScale is present, and vice versa.
    const T *ddbias_data = (ddBias == nullptr ? nullptr : ddBias->data<T>());

    const T *mean_data = Saved_mean->data<T>();
    const T *variance_data = Saved_variance->data<T>();

    auto &x_dims = X->dims();
    int N, C, H, W, D;
    ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
    int NxC = N * C;
    const int n = X->numel();
    int sample_size = n / N / C;

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    const int block = 512;
    // One block per (n, c) instance, as required by the kernels above.
    const int grid = NxC;
    const int grid1 = (C + block - 1) / block;

    math::SetConstant<platform::CUDADeviceContext, T> set_zero;

    if (dX) {
      T *dx_data = dX->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, dX, static_cast<T>(0));
      DoubleGradComputeDX<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, mean_data, variance_data, ddx_data, dy_data, scale_data,
          ddscale_data, C, sample_size, epsilon, dx_data);
    }
    if (dScale) {
      // Per-instance partial results, reduced over N into dScale afterwards.
      Tensor dscale_tmp =
          ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
      set_zero(dev_ctx, &dscale_tmp, static_cast<T>(0));
      T *dscale_tmp_data = dscale_tmp.mutable_data<T>(ctx.GetPlace());

      dScale->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, dScale, static_cast<T>(0));
      DoubleGradComputeDScale<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, mean_data, variance_data, ddx_data, dy_data, C, sample_size,
          epsilon, dscale_tmp_data);
      add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
          dscale_tmp.data<T>(), dScale->data<T>(), N, C);
    }
    if (ddY) {
      T *ddy_data = ddY->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, ddY, static_cast<T>(0));
      DoubleGradComputeDDY<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, mean_data, variance_data, ddscale_data, ddbias_data, ddx_data,
          scale_data, C, sample_size, epsilon, ddy_data);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    instance_norm, ops::InstanceNormKernel<plat::CUDADeviceContext, float>,
    ops::InstanceNormKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    instance_norm_grad,
    ops::InstanceNormGradKernel<plat::CUDADeviceContext, float>,
    ops::InstanceNormGradKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    instance_norm_grad_grad,
    ops::InstanceNormDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                      float>,
    ops::InstanceNormDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                      double>);