// instance_norm_op.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/instance_norm_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

// Short-hand names for framework types used throughout this file.
typedef framework::Tensor Tensor;
typedef framework::DataLayout DataLayout;

// cuDNN data-type traits for element type T.
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;

// Element type used in this file for the scale / bias / saved-statistics
// buffers handed to cuDNN's batch-norm routines alongside data of type T.
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;

// Tiles the C-element array `input` `repeat_num` times into `output`, so that
// output[k * C + c] == input[c] for every k < repeat_num and c < C.
// Each thread starts at its global index and advances by the total number of
// launched threads (blockDim.x * gridDim.x), so any grid size covers all
// repeat_num * C outputs.
template <typename T>
static __global__ void repeat_param(const T *input, T *output,
                                    const int repeat_num, const int C) {
  const int total = repeat_num * C;
  const int stride = blockDim.x * gridDim.x;
  int pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < total) {
    output[pos] = input[pos % C];
    pos += stride;
  }
}

// Reduces a row-major (repeat_num x C) array over its first axis:
//   output[c] = sum_{k < repeat_num} input[k * C + c]
// and divides the sum by repeat_num when AVG is true.
//
// Launch layout: each block handles one channel at a time, looping over
// channels with a stride of gridDim.x (so any grid size >= 1 is correct).
// The block must have exactly BlockDim threads, because cub::BlockReduce is
// instantiated with BlockDim.
template <typename T, int BlockDim, bool AVG>
static __global__ void add_param(const T *input, T *output,
                                 const int repeat_num, const int C) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ou_storage;
  for (int i = blockIdx.x; i < C; i += gridDim.x) {
    T ou = static_cast<T>(0);
    for (int j = threadIdx.x; j < repeat_num; j += blockDim.x) {
      ou += input[j * C + i];
    }
    // cub::BlockReduce returns the aggregate only in thread 0.
    ou = BlockReduce(ou_storage).Reduce(ou, cub::Sum());
    if (threadIdx.x == 0) {
      // The average is finished by the single thread that owns the result,
      // so output[i] is written exactly once; dividing in every thread would
      // race and could divide the value several times.
      output[i] = AVG ? ou / static_cast<T>(repeat_num) : ou;
    }
    // ou_storage is reused for the next channel; all threads must be done
    // with it before the next Reduce.
    __syncthreads();
  }
}

template <typename T>
class InstanceNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Forward pass of instance normalization on the GPU.
  //
  // X is interpreted with an NCHW-style layout (2 to 5 dims). It is viewed as
  // a single batch of N * C channels, [1, N*C, H, W, D], so that cuDNN's
  // spatial batch norm computes statistics per (sample, channel) instance.
  // The optional Scale and Bias are tiled N times (first C entries repeated)
  // to match. When Scale is absent it defaults to 1, and when Bias is absent
  // it defaults to 0.
  //
  // Outputs: Y, plus SavedMean / SavedVariance, which are passed as the
  // save-mean / save-inverse-variance buffers of
  // cudnnBatchNormalizationForwardTraining.
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::PreconditionNotMet("It must be CUDAPlace."));
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));

    auto *x = ctx.Input<Tensor>("X");
    auto &x_dims = x->dims();
    PADDLE_ENFORCE_GE(x_dims.size(), 2,
                      platform::errors::InvalidArgument(
                          "The `shape` in InstanceNormOp is invalid: "
                          "the size of X's dimensions must greater than "
                          "or equal to 2. But received: "
                          "the size of X's dimensions is [%d]",
                          x_dims.size()));
    PADDLE_ENFORCE_LE(x_dims.size(), 5,
                      platform::errors::InvalidArgument(
                          "The `shape` in InstanceNormOp is invalid: "
                          "the size of X's dimensions must smaller than "
                          "or equal to 5. But received: "
                          "the size of X's dimensions is [%d]",
                          x_dims.size()));
    int N, C, H, W, D;
    ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
    const int NxC = N * C;
    // View X as one batch of N * C channels so each instance is a "channel".
    Tensor x_tmp;
    x_tmp.ShareDataWith(*x).Resize({1, NxC, H, W, D});

    auto *y = ctx.Output<Tensor>("Y");
    y->mutable_data<T>(ctx.GetPlace());

    cudnnTensorDescriptor_t data_desc_;
    cudnnTensorDescriptor_t in_param_desc_;

    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&in_param_desc_));

    // Epsilon is clamped to CUDNN_BN_MIN_EPSILON; an error is logged when the
    // user-provided value is smaller.
    if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
      LOG(ERROR) << "Provided epsilon is smaller than "
                 << "CUDNN_BN_MIN_EPSILON. Setting it to "
                 << "CUDNN_BN_MIN_EPSILON instead.";
    }
    epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);

    VLOG(3) << "Setting descriptors.";
    // Fully packed [1, N*C, H, W, D]; for inputs of rank <= 4 only the first
    // four entries are used by the descriptor below.
    const std::vector<int> dims = {1, NxC, H, W, D};
    const std::vector<int> strides = {NxC * H * W * D, H * W * D, W * D, D, 1};

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
        data_desc_, CudnnDataType<T>::type,
        x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDeriveBNTensorDescriptor(
            in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL));

    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");

    // Per-instance (N * C) copies of the per-channel parameters.
    Tensor scale_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    scale_tmp.mutable_data<T>(ctx.GetPlace());
    Tensor bias_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    bias_tmp.mutable_data<T>(ctx.GetPlace());

    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    // repeat_param strides over all launched threads, so capping the grid
    // is safe.
    const int grid = std::min((NxC + block - 1) / block, max_blocks);

    math::SetConstant<platform::CUDADeviceContext, T> set_constant;
    if (scale) {
      repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
          scale->data<T>(), scale_tmp.data<T>(), N, C);
    } else {
      set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
    }
    if (bias) {
      repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
          bias->data<T>(), bias_tmp.data<T>(), N, C);
    } else {
      set_constant(dev_ctx, &bias_tmp, static_cast<T>(0));
    }

    auto handle = dev_ctx.cudnn_handle();

    math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
        functor;

    auto *saved_mean = ctx.Output<Tensor>("SavedMean");
    auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
    saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
    saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
    functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
    functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));

    // Running statistics are not tracked (factor 0, null running buffers).
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnBatchNormalizationForwardTraining(
            handle, CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
            CudnnDataType<T>::kZero(), data_desc_, x_tmp.template data<T>(),
            data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
            in_param_desc_, scale_tmp.template data<BatchNormParamType<T>>(),
            bias_tmp.template data<BatchNormParamType<T>>(), 0, nullptr,
            nullptr, epsilon,
            saved_mean->template mutable_data<BatchNormParamType<T>>(
                ctx.GetPlace()),
            saved_variance->template mutable_data<BatchNormParamType<T>>(
                ctx.GetPlace())));

    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(in_param_desc_));
  }
};

// Gradient of instance norm with respect to X (used by the grad kernel when
// Scale/Bias gradients are not requested).
//
// Launch layout: one block per (sample, channel) instance, with gridDim.x
// expected to be N * C, and exactly BlockDim threads per block (required by
// cub::BlockReduce). Instance b owns elements [b * sample_size,
// (b + 1) * sample_size). For each instance the kernel writes
//   dx = (dy - sum(dy) / S
//         - (x - mean) * sum(dy * (x - mean)) * inv^2 / S) * scale[c] * inv
// where S = sample_size and inv = variance[b]. `variance` is only ever
// multiplied, i.e. it is treated as an inverse standard deviation
// (presumably cuDNN's saved inverse variance; confirm with caller).
template <typename T, int BlockDim>
static __global__ void GradComputeDX(const T *dy,
                                     const BatchNormParamType<T> *scale,
                                     const BatchNormParamType<T> *mean,
                                     const T *x,
                                     const BatchNormParamType<T> *variance,
                                     const int C, const int sample_size,
                                     T *dx) {
  using AccT = BatchNormParamType<T>;
  typedef cub::BlockReduce<AccT, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
  __shared__ AccT dy_sum_val;
  __shared__ AccT dy_x_sub_mean_sum_val;

  const int instance = blockIdx.x;
  const int channel = instance % C;
  const int first = instance * sample_size + threadIdx.x;
  const int last = (instance + 1) * sample_size;

  const AccT mu = mean[instance];
  const AccT inv_std = variance[instance];

  // Per-thread partial sums of dy and dy * (x - mean).
  AccT partial_dy = static_cast<AccT>(0);
  AccT partial_dy_xmu = static_cast<AccT>(0);
  for (int k = first; k < last; k += BlockDim) {
    const AccT g = static_cast<AccT>(dy[k]);
    partial_dy += g;
    partial_dy_xmu += g * (static_cast<AccT>(x[k]) - mu);
  }

  // Block-wide totals (valid in thread 0), broadcast through shared memory.
  const AccT total_dy = BlockReduce(dy_storage).Reduce(partial_dy, cub::Sum());
  const AccT total_dy_xmu = BlockReduce(dy_x_sub_mean_storage)
                                .Reduce(partial_dy_xmu, cub::Sum());
  if (threadIdx.x == 0) {
    dy_sum_val = total_dy;
    dy_x_sub_mean_sum_val = total_dy_xmu;
  }
  __syncthreads();

  for (int k = first; k < last; k += BlockDim) {
    dx[k] = (static_cast<AccT>(dy[k]) -
             dy_sum_val / static_cast<AccT>(sample_size) -
             (static_cast<AccT>(x[k]) - mu) * dy_x_sub_mean_sum_val * inv_std *
                 inv_std / sample_size) *
            scale[channel] * inv_std;
  }
}

template <typename T>
class InstanceNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Backward pass of instance normalization on the GPU.
  //
  // X and dY are viewed as [1, N*C, H, W, D], as in the forward kernel.
  // - If both Scale@GRAD and Bias@GRAD are requested, cuDNN's batch-norm
  //   backward computes dX together with per-instance parameter gradients,
  //   which add_param then sums over N into the per-channel outputs.
  // - Otherwise dX is computed by the custom GradComputeDX kernel.
  // - When the spatial size H*W*D is 1, cuDNN is bypassed entirely: dY is
  //   copied into dX and the requested parameter gradients are zero-filled.
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::PreconditionNotMet("It must use CUDAPlace."));
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *x = ctx.Input<Tensor>("X");
    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const auto &x_dims = x->dims();

    int N, C, H, W, D;
    ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
    const int NxC = N * C;

    Tensor x_tmp, d_y_tmp;
    x_tmp.ShareDataWith(*x).Resize({1, NxC, H, W, D});
    d_y_tmp.ShareDataWith(*d_y).Resize({1, NxC, H, W, D});

    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
    // Parameter gradients are produced only when both outputs are requested.
    const bool need_param_grads = (d_scale != nullptr) && (d_bias != nullptr);

    d_x->mutable_data<T>(ctx.GetPlace());
    if (need_param_grads) {
      d_scale->mutable_data<T>(ctx.GetPlace());
      d_bias->mutable_data<T>(ctx.GetPlace());
    }
    if (scale) {
      PADDLE_ENFORCE_EQ(
          scale->dims().size(), 1UL,
          platform::errors::InvalidArgument(
              "The `shape` in InstanceNormOp is invalid: "
              "the size of scale's dimensions must be equal to 1. But "
              "received: the size of scale's dimensions "
              "is [%d]",
              scale->dims().size()));
      PADDLE_ENFORCE_EQ(scale->dims()[0], C,
                        platform::errors::InvalidArgument(
                            "The `shape` in InstanceNormOp is invalid: "
                            "the first dimension of scale must be equal to "
                            "Channels([%d]). But received: "
                            "the first dimension of scale is [%d], "
                            "the dimensions of scale is [%s], ",
                            C, scale->dims()[0], scale->dims()));
    }

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    math::SetConstant<platform::CUDADeviceContext, T> set_constant;

    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    // repeat_param strides over all launched threads and add_param loops over
    // channels with a stride of gridDim.x, so capping both grids is safe.
    // add_param reduces one channel per block, hence up to C blocks.
    const int grid = std::min(NxC, max_blocks);
    const int grid1 = std::min(C, max_blocks);

    Tensor scale_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    scale_tmp.mutable_data<T>(ctx.GetPlace());
    Tensor d_scale_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    Tensor d_bias_tmp =
        ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
    if (scale) {
      repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
          scale->data<T>(), scale_tmp.data<T>(), N, C);
    } else {
      set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
    }

    const std::vector<int> dims = {1, NxC, H, W, D};
    const std::vector<int> strides = {NxC * H * W * D, H * W * D, W * D, D, 1};

    if ((H * W * D) == 1) {
      framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
      // Only touch the parameter gradients when they were requested (and
      // therefore allocated above); either pointer may be null otherwise.
      if (need_param_grads) {
        math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
            functor;
        functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
        functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
      }
      return;
    }

    cudnnTensorDescriptor_t data_desc_;
    cudnnTensorDescriptor_t in_param_desc_;

    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&in_param_desc_));
    // Same epsilon clamp as the forward kernel.
    if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
      LOG(ERROR) << "Provided epsilon is smaller than "
                 << "CUDNN_BN_MIN_EPSILON. Setting it to "
                 << "CUDNN_BN_MIN_EPSILON instead.";
    }
    epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);

    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
        data_desc_, CudnnDataType<T>::type,
        x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDeriveBNTensorDescriptor(
            in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL));

    const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
    const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
    const auto *saved_mean_data =
        saved_mean->template data<BatchNormParamType<T>>();
    const auto *saved_var_data =
        saved_var->template data<BatchNormParamType<T>>();
    if (need_param_grads) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnBatchNormalizationBackward(
              dev_ctx.cudnn_handle(), CUDNN_BATCHNORM_SPATIAL,
              CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
              CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
              x_tmp.template data<T>(), data_desc_, d_y_tmp.template data<T>(),
              data_desc_, d_x->template mutable_data<T>(ctx.GetPlace()),
              in_param_desc_, scale_tmp.template data<BatchNormParamType<T>>(),
              d_scale_tmp.template mutable_data<BatchNormParamType<T>>(
                  ctx.GetPlace()),
              d_bias_tmp.template mutable_data<BatchNormParamType<T>>(
                  ctx.GetPlace()),
              epsilon, saved_mean_data, saved_var_data));
    } else {
      if (d_x) {
        // One block per (sample, channel) instance.
        GradComputeDX<T, block><<<NxC, block, 0, dev_ctx.stream()>>>(
            d_y->data<T>(), scale_tmp.data<BatchNormParamType<T>>(),
            saved_mean_data, x->data<T>(), saved_var_data, C, H * W * D,
            d_x->data<T>());
      }
    }

    if (need_param_grads) {
      // Sum the N per-instance parameter gradients into per-channel outputs.
      add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
          d_scale_tmp.data<T>(), d_scale->data<T>(), N, C);
      add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
          d_bias_tmp.data<T>(), d_bias->data<T>(), N, C);
    }

    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(in_param_desc_));
  }
};

// NOTE: despite the name, these return the reciprocal square root 1 / sqrt(x).
// The float overload uses a float literal so the division stays in single
// precision instead of being promoted to double.
static __device__ __forceinline__ float real_sqrt(float x) {
  return 1.0f / sqrtf(x);
}
static __device__ __forceinline__ double real_sqrt(double x) {
  return 1.0 / sqrt(x);
}

// Second-order backward of instance norm: accumulates (+=) into dx the
// contributions coming from ddx (when non-null) and from ddscale (when
// non-null). dx must therefore be zero-initialized by the caller.
//
// Launch layout: one block per (sample, channel) instance (gridDim.x is
// expected to be N * C) and exactly BlockDim threads per block, as required
// by cub::BlockReduce. Instance b owns elements [b * sample_size,
// (b + 1) * sample_size) and uses channel c = b % C for scale/ddscale.
// `variance` is only multiplied, i.e. treated as an inverse standard
// deviation (presumably SavedVariance from the forward pass; confirm with the
// caller). `epsilon` is unused.
template <typename T, int BlockDim>
__global__ void DoubleGradComputeDX(const T *x, const T *mean,
                                    const T *variance, const T *ddx,
                                    const T *dy, const T *scale,
                                    const T *ddscale, int C, int sample_size,
                                    const double epsilon, T *dx) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  T mean_val = mean[ncid];
  T var_val = variance[ncid];

  // ddx is optional. A missing ddx is treated as all zeros so the reduction
  // loop below never dereferences a null pointer. The flag depends only on a
  // kernel argument, so every thread agrees and the barriers stay uniform.
  const bool has_ddx = (ddx != nullptr);

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T dy_sum_val;
  __shared__ T ddx_sum_val;
  __shared__ T dy_mul_ddx_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  T dy_sum = 0;
  T ddx_sum = 0;
  T dy_mul_ddx_sum = 0;
  T dy_mul_x_sub_mean_sum = 0;
  T ddx_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T ddx_i = has_ddx ? ddx[i] : static_cast<T>(0);
    T dy_i = dy[i];
    T tmp = x[i] - mean_val;

    dy_sum += dy_i;
    ddx_sum += ddx_i;
    dy_mul_ddx_sum += (ddx_i * dy_i);

    dy_mul_x_sub_mean_sum += (dy_i * tmp);
    ddx_mul_x_sub_mean_sum += (ddx_i * tmp);
  }

  // Block-wide totals are valid in thread 0 and broadcast via shared memory.
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  dy_mul_ddx_sum =
      BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum());
  dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                              .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    ddx_sum_val = ddx_sum;
    dy_mul_ddx_sum_val = dy_mul_ddx_sum;
    dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();

  if (has_ddx) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      // static_cast<T>(3) keeps float instantiations in single precision.
      dx[i] +=
          ((x[i] - mean_val) * var_val * var_val * var_val / sample_size *
               (ddx_sum_val * dy_sum_val / sample_size - dy_mul_ddx_sum_val +
                static_cast<T>(3) * dy_mul_x_sub_mean_sum_val * var_val *
                    ddx_mul_x_sub_mean_sum_val * var_val / sample_size) +
           ddx_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val * (dy_sum_val / sample_size - dy[i]) +
           dy_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val * (ddx_sum_val / sample_size - ddx[i])) *
          scale[c];
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dx[i] += (dy[i] * var_val - dy_sum_val / sample_size * var_val -
                (x[i] - mean_val) * var_val * dy_mul_x_sub_mean_sum_val *
                    var_val / sample_size) *
               ddscale[c];
    }
  }
}

// Second-order backward of instance norm: accumulates (+=) into ddy
//   scale[c] * inv * (ddx - sum(ddx) / S
//                     - (x - mean) * inv^2 * sum(ddx * (x - mean)) / S)
//     if ddx is non-null,
//   (x - mean) * inv * ddscale[c]   if ddscale is non-null,
//   ddbias[c]                       if ddbias is non-null,
// where S = sample_size and inv = variance[b]. ddy must be zero-initialized
// by the caller.
//
// Launch layout: one block per (sample, channel) instance (gridDim.x is
// expected to be N * C) and exactly BlockDim threads per block, as required
// by cub::BlockReduce. `variance` is treated as an inverse standard
// deviation (presumably SavedVariance; confirm with the caller). `epsilon`
// is unused.
template <typename T, int BlockDim>
__global__ void DoubleGradComputeDDY(const T *x, const T *mean,
                                     const T *variance, const T *ddscale,
                                     const T *ddbias, const T *ddx,
                                     const T *scale, int C, int sample_size,
                                     const double epsilon, T *ddy) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  T mean_val = mean[ncid];
  T var_val = variance[ncid];

  // ddx is optional. A missing ddx contributes zeros instead of being
  // dereferenced. The flag depends only on a kernel argument, so every
  // thread agrees and the reductions/barriers below stay uniform.
  const bool has_ddx = (ddx != nullptr);

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T ddx_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  T ddx_sum = 0;
  T ddx_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T ddx_i = has_ddx ? ddx[i] : static_cast<T>(0);
    ddx_sum += ddx_i;
    ddx_mul_x_sub_mean_sum += (ddx_i * (x[i] - mean_val));
  }
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

  if (threadIdx.x == 0) {
    ddx_sum_val = ddx_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();

  if (has_ddx) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += scale[c] * var_val *
                (ddx[i] - ddx_sum_val / sample_size -
                 (x[i] - mean_val) * var_val * ddx_mul_x_sub_mean_sum_val *
                     var_val / sample_size);
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += (x[i] - mean_val) * var_val * ddscale[c];
    }
  }
  __syncthreads();
  if (ddbias != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += ddbias[c];
    }
  }
}

// Second-order backward of instance norm: per-instance contribution to
// dscale. For each (sample, channel) instance b it adds (+=) to dscale[b]
//   sum_i ddx_i * inv * (dy_i - sum(dy) / S
//                        - sum(dy * (x - mean)) * (x_i - mean) * inv^2 / S)
// where S = sample_size and inv = variance[b]. dscale must be
// zero-initialized and has one entry per instance (N * C). When ddx is null,
// nothing is written.
//
// Launch layout: one block per instance (gridDim.x is expected to be N * C)
// and exactly BlockDim threads per block, as required by cub::BlockReduce.
// `C` and `epsilon` are unused.
template <typename T, int BlockDim>
__global__ void DoubleGradComputeDScale(const T *x, const T *mean,
                                        const T *variance, const T *ddx,
                                        const T *dy, int C, int sample_size,
                                        const double epsilon, T *dscale) {
  // Guard clause: ddx is a kernel argument, so the whole block leaves
  // together and no barrier below is skipped by only part of the block.
  if (ddx == nullptr) {
    return;
  }

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage dscale_tmp_storage;
  __shared__ T dy_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;

  const int instance = blockIdx.x;
  const int first = instance * sample_size + threadIdx.x;
  const int last = (instance + 1) * sample_size;
  const T mu = mean[instance];
  const T inv_std = variance[instance];

  // First pass: sum(dy) and sum(dy * (x - mean)) across the instance.
  T local_dy = 0;
  T local_dy_xmu = 0;
  for (int k = first; k < last; k += BlockDim) {
    const T g = dy[k];
    local_dy += g;
    local_dy_xmu += (g * (x[k] - mu));
  }
  const T block_dy = BlockReduce(dy_storage).Reduce(local_dy, cub::Sum());
  const T block_dy_xmu = BlockReduce(dy_mul_x_sub_mean_storage)
                             .Reduce(local_dy_xmu, cub::Sum());
  if (threadIdx.x == 0) {
    dy_sum_val = block_dy;
    dy_mul_x_sub_mean_sum_val = block_dy_xmu;
  }
  __syncthreads();

  // Second pass: reduce the ddx-weighted term and add it in thread 0.
  T local_dscale = 0;
  for (int k = first; k < last; k += BlockDim) {
    local_dscale +=
        ddx[k] * inv_std * (dy[k] - dy_sum_val / sample_size -
                            dy_mul_x_sub_mean_sum_val * (x[k] - mu) * inv_std *
                                inv_std / sample_size);
  }
  const T block_dscale =
      BlockReduce(dscale_tmp_storage).Reduce(local_dscale, cub::Sum());
  if (threadIdx.x == 0) {
    dscale[instance] += block_dscale;
  }
}

template <typename T>
class InstanceNormDoubleGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Second-order gradient of instance normalization on the GPU.
  //
  // Inputs X, Scale (optional, defaults to ones), DY, SavedMean,
  // SavedVariance, and the optional DDX / DDScale / DDBias. Each requested
  // output (DX, DScale, DDY) is zero-filled and then accumulated by the
  // corresponding kernel, launched with one block per (sample, channel)
  // instance. DScale is computed per instance and then summed over N by
  // add_param.
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *X = ctx.Input<Tensor>("X");
    const auto *Scale = ctx.Input<Tensor>("Scale");
    const auto *dY = ctx.Input<Tensor>("DY");
    const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
    const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
    const auto *ddX = ctx.Input<Tensor>("DDX");
    const auto *ddScale = ctx.Input<Tensor>("DDScale");
    const auto *ddBias = ctx.Input<Tensor>("DDBias");
    const double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));

    auto *dX = ctx.Output<Tensor>("DX");
    auto *dScale = ctx.Output<Tensor>("DScale");
    auto *ddY = ctx.Output<Tensor>("DDY");

    const T *x_data = X->data<T>();
    const T *dy_data = dY->data<T>();
    // Optional second-order inputs: each pointer is gated on its own tensor.
    const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());
    const T *ddscale_data = (ddScale == nullptr ? nullptr : ddScale->data<T>());
    const T *ddbias_data = (ddBias == nullptr ? nullptr : ddBias->data<T>());

    const T *mean_data = Saved_mean->data<T>();
    const T *variance_data = Saved_variance->data<T>();

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    math::SetConstant<platform::CUDADeviceContext, T> set_constant;

    auto &x_dims = X->dims();
    int N, C, H, W, D;
    ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
    const int NxC = N * C;
    const int n = X->numel();
    const int sample_size = n / N / C;

    // Default Scale of all ones when the input is absent.
    Tensor scale_tmp;
    if (!Scale) {
      scale_tmp.mutable_data<T>({C}, ctx.GetPlace());
      set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
    }
    const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();

    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    // The DoubleGradCompute* kernels require exactly one block per instance.
    const int grid = NxC;
    // add_param reduces one channel per block and loops over channels with a
    // stride of gridDim.x, so up to C blocks can run in parallel.
    const int grid1 = std::min(C, max_blocks);

    if (dX) {
      T *dx_data = dX->mutable_data<T>(ctx.GetPlace());
      set_constant(dev_ctx, dX, static_cast<T>(0));
      DoubleGradComputeDX<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, mean_data, variance_data, ddx_data, dy_data, scale_data,
          ddscale_data, C, sample_size, epsilon, dx_data);
    }
    if (dScale) {
      Tensor dscale_tmp =
          ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC}, dev_ctx);
      T *dscale_tmp_data = dscale_tmp.mutable_data<T>(ctx.GetPlace());
      set_constant(dev_ctx, &dscale_tmp, static_cast<T>(0));

      dScale->mutable_data<T>(ctx.GetPlace());
      set_constant(dev_ctx, dScale, static_cast<T>(0));
      DoubleGradComputeDScale<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, mean_data, variance_data, ddx_data, dy_data, C, sample_size,
          epsilon, dscale_tmp_data);
      // Sum the N per-instance contributions into the per-channel output.
      add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
          dscale_tmp_data, dScale->data<T>(), N, C);
    }
    if (ddY) {
      T *ddy_data = ddY->mutable_data<T>(ctx.GetPlace());
      set_constant(dev_ctx, ddY, static_cast<T>(0));
      DoubleGradComputeDDY<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, mean_data, variance_data, ddscale_data, ddbias_data, ddx_data,
          scale_data, C, sample_size, epsilon, ddy_data);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// GPU kernels for the forward, backward and double-backward instance_norm
// ops, each registered for float and double.
REGISTER_OP_CUDA_KERNEL(
    instance_norm, ops::InstanceNormKernel<plat::CUDADeviceContext, float>,
    ops::InstanceNormKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    instance_norm_grad,
    ops::InstanceNormGradKernel<plat::CUDADeviceContext, float>,
    ops::InstanceNormGradKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    instance_norm_grad_grad,
    ops::InstanceNormDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::InstanceNormDoubleGradKernel<plat::CUDADeviceContext, double>);