/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cub/cub.cuh>
#include "paddle/fluid/operators/layer_norm_op.h"

namespace paddle {
namespace operators {

// Returns the block size to use for a reduction over `block_dim` items: the
// largest power of two that is <= block_dim, capped at 512. Inputs below 2
// (including 0 and negative values) yield 1; 1 matches none of the
// FIXED_BLOCK_DIM_CASE labels, so a switch on the result reaches its
// `default` label (if any).
inline static int GetDesiredBlockDim(int block_dim) {
  const int kMaxBlockDim = 512;
  if (block_dim >= kMaxBlockDim) return kMaxBlockDim;
  // Exact integer floor-power-of-two. This avoids the float round trip
  // through log2f, whose result for block_dim <= 0 (-inf / NaN) cannot be
  // converted to int without undefined behavior.
  int pow2 = 1;
  while (pow2 * 2 <= block_dim) pow2 *= 2;
  return pow2;
}

// Expands to one `case` label of a switch over a block dimension. When the
// switched value equals (1 << log2_block_dim), it defines the compile-time
// constant kBlockDim with that value, evaluates the statement(s) passed in
// __VA_ARGS__ (which may use kBlockDim, e.g. as a template argument and as
// the launch block size), and then breaks out of the switch.
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

// Expands to the case labels for the power-of-two block dimensions 512 (2^9)
// down to 2 (2^1). There is no label for 1 (2^0): a switched value of 1, or
// any value that is not one of these powers of two, goes to the enclosing
// switch's `default` label, or does nothing if the switch has none.
#define FIXED_BLOCK_DIM_CASE(...)              \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)

// Device-side square root overloaded on precision, so templated kernels can
// call a single name for both float (sqrtf) and double (sqrt).
static __device__ __forceinline__ float real_sqrt(float value) {
  return sqrtf(value);
}
static __device__ __forceinline__ double real_sqrt(double value) {
  return sqrt(value);
}

// Two values of type T that are reduced together in a single
// cub::BlockReduce pass (for example a running sum and a running sum of
// squares).
template <typename T>
struct PairForLayerNorm {
  // Leaves both members uninitialized.
  __device__ __forceinline__ PairForLayerNorm() {}

  __device__ __forceinline__ PairForLayerNorm(const T &a, const T &b)
      : first_(a), second_(b) {}

  T first_;   // first reduced quantity
  T second_;  // second reduced quantity
};

// Reduction operator for PairForLayerNorm: adds the two pairs member-wise.
template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T> &lhs, const PairForLayerNorm<T> &rhs) {
    const T sum_first = lhs.first_ + rhs.first_;
    const T sum_second = lhs.second_ + rhs.second_;
    return PairForLayerNorm<T>(sum_first, sum_second);
  }
};

// Layer-normalization forward pass, one row per block.
//
// Expected launch: gridDim.x == number of rows (batch size) and
// blockDim.x == BlockDim. Block b handles x[b * feature_size,
// (b + 1) * feature_size); each thread strides over that row in steps of
// BlockDim. Uses the static shared memory of one cub::BlockReduce.
//
// Writes, for each row b:
//   mean[b] = E[x]
//   var[b]  = E[x^2] - E[x]^2 (biased, divided by feature_size; clamped >= 0)
//   y       = scale * (x - mean) / sqrt(var + epsilon) + bias
// where scale and bias are skipped when they are nullptr.
template <typename T, int BlockDim>
__global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
                                 T *y, T *mean, T *var, float epsilon,
                                 int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  // Step 1: Reduce to calculate mean and var
  T mean_val = static_cast<T>(0);
  T var_val = static_cast<T>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T tmp = x[i];
    mean_val += tmp;
    var_val += (tmp * tmp);
  }
  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<T>(mean_val, var_val),
                          PairForLayerNormAddFunctor<T>());
  if (threadIdx.x == 0) {
    auto tmp = pair.first_ / feature_size;
    mean[blockIdx.x] = tmp;
    // E[x^2] - E[x]^2 suffers from cancellation and can come out slightly
    // negative for (nearly) constant rows. Clamp at 0 so that var + epsilon
    // (used here and by the backward kernels) never drops below epsilon.
    // A NaN still propagates, since `NaN < 0` is false.
    T row_var = pair.second_ / feature_size - tmp * tmp;
    var[blockIdx.x] =
        row_var < static_cast<T>(0) ? static_cast<T>(0) : row_var;
  }
  // Make thread 0's writes to mean/var visible to every thread in the block.
  __syncthreads();
  mean_val = mean[blockIdx.x];
  // From here on var_val holds the standard deviation sqrt(var + epsilon).
  var_val = static_cast<T>(real_sqrt(var[blockIdx.x] + epsilon));

  // Step 2: Calculate y
  if (scale != nullptr) {
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = scale[j] * (x[i] - mean_val) / var_val + bias[j];
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = scale[j] * (x[i] - mean_val) / var_val;
      }
    }
  } else {  // scale == nullptr
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = (x[i] - mean_val) / var_val + bias[j];
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = (x[i] - mean_val) / var_val;
      }
    }
  }
}

// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
//
// Computes d_scale and d_bias. When HasDx is true, it also writes the first
// stage of d_x:
//   d_x = d_y * scale / sqrt(var + epsilon).
// That first stage still has to be completed by
// LayerNormBackwardPostProcessToCalculateDX.
//
// Expected launch: gridDim.x == feature_size (one block per feature column)
// and blockDim.x == BlockDim. Element (row, col) lives at
// row * feature_size + col. The threads of block c walk down column c,
// starting at row threadIdx.x and stepping by BlockDim rows.
template <typename T, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                             T *d_scale, T *d_bias, T *d_x,
                                             const T *mean, const T *var,
                                             const T *scale, float epsilon,
                                             int batch_size, int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = threadIdx.x * feature_size + blockIdx.x;
  int end_idx = batch_size * feature_size + blockIdx.x;
  int stride = BlockDim * feature_size;

  T d_scale_partial = 0, d_bias_partial = 0;

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
    d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
    d_bias_partial += d_y[i];
    if (HasDx) {
      d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
    }
  }

  // Reduce both column sums in a single block-wide pass.
  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<T>(d_scale_partial, d_bias_partial),
                          PairForLayerNormAddFunctor<T>());

  // Only thread 0 holds the valid reduction result.
  if (threadIdx.x == 0) {
    d_scale[blockIdx.x] = pair.first_;
    d_bias[blockIdx.x] = pair.second_;
  }
}

// Make sure that there is only one true expression: d_scale != nullptr
// or d_bias != nullptr
// Notice: scale may be nullptr
//
// Computes exactly one of d_scale (HasDScale == true) or d_bias
// (HasDScale == false). When HasDx is true, it also writes the first stage
// of d_x:
//   d_x = d_y * scale / sqrt(var + epsilon)   (scale omitted when nullptr).
// That first stage still has to be completed by
// LayerNormBackwardPostProcessToCalculateDX.
//
// Expected launch: gridDim.x == feature_size (one block per feature column)
// and blockDim.x == BlockDim. The threads of block c walk down column c,
// starting at row threadIdx.x and stepping by BlockDim rows.
template <typename T, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias(
    const T *x, const T *d_y, T *d_scale, T *d_bias, T *d_x, const T *mean,
    const T *var, const T *scale, float epsilon, int batch_size,
    int feature_size) {
  using BlockReduce = cub::BlockReduce<T, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int beg_idx = threadIdx.x * feature_size + blockIdx.x;
  int end_idx = batch_size * feature_size + blockIdx.x;
  int stride = BlockDim * feature_size;
  T d_scale_or_d_bias_partial = 0;

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
    if (HasDScale) {
      d_scale_or_d_bias_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
    } else {  // d_bias != nullptr
      d_scale_or_d_bias_partial += d_y[i];
    }

    if (HasDx) {
      if (scale != nullptr) {
        d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
      } else {
        d_x[i] = d_y[i] / var_val;
      }
    }
  }

  d_scale_or_d_bias_partial =
      BlockReduce(temp_storage).Reduce(d_scale_or_d_bias_partial, cub::Sum());

  // Only thread 0 holds the valid reduction result.
  if (threadIdx.x == 0) {
    if (HasDScale) {
      d_scale[blockIdx.x] = d_scale_or_d_bias_partial;
    } else {
      d_bias[blockIdx.x] = d_scale_or_d_bias_partial;
    }
  }
}

// Completes d_x in place, one row per block (gridDim.x == number of rows,
// blockDim.x == BlockDim).
//
// On entry, d_x holds the first stage g = d_y * scale / sqrt(var + epsilon)
// (scale omitted when absent). For each row with N = feature_size it
// applies:
//   d_x -= sum(d_x) / N
//   d_x -= (x - mean) * sum(d_x * (x - mean)) / (N * (var + epsilon))
// Both sums are taken over the entry values of the row.
template <typename T, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
                                                          const T *mean,
                                                          const T *var,
                                                          float epsilon,
                                                          int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  // [0]: mean of d_x over the row; [1]: coefficient of the (x - mean) term.
  __shared__ T row_coeffs[2];

  const int first = blockIdx.x * feature_size + threadIdx.x;
  const int last = (blockIdx.x + 1) * feature_size;

  const T row_mean = mean[blockIdx.x];
  const T row_var = var[blockIdx.x];

  // Per-thread partial sums of d_x and of d_x * (x - mean).
  T sum_dx = 0, sum_dx_centered = 0;
  for (int i = first; i < last; i += BlockDim) {
    const T dx = d_x[i];
    sum_dx += dx;
    sum_dx_centered += dx * (x[i] - row_mean);
  }

  const auto sums =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<T>(sum_dx, sum_dx_centered),
                  PairForLayerNormAddFunctor<T>());

  // Thread 0 owns the reduced values; publish the coefficients via shared
  // memory before anyone reads them.
  if (threadIdx.x == 0) {
    row_coeffs[0] = sums.first_ / feature_size;
    row_coeffs[1] = sums.second_ / (feature_size * (row_var + epsilon));
  }
  __syncthreads();

  const T dx_mean = row_coeffs[0];
  const T dx_coeff = row_coeffs[1];
  for (int i = first; i < last; i += BlockDim) {
    T value = d_x[i] - dx_mean;
    value -= (x[i] - row_mean) * dx_coeff;
    d_x[i] = value;
  }
}

// Here, we only calculate d_x
// Computes d_x when neither d_scale nor d_bias is requested. Launched with
// one block per row (gridDim.x == number of rows) and BlockDim threads per
// block.
//
// First forms g = d_y * scale / sqrt(var + epsilon) (scale omitted when
// nullptr). Then, with N = feature_size, it applies:
//   d_x = g - sum(g) / N - (x - mean) * sum(g * (x - mean)) / (N * (var + eps))
template <typename T, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                T *d_x, const T *mean,
                                                const T *var, const T *scale,
                                                float epsilon,
                                                int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
  // The row's standard deviation is loop-invariant; compute it once.
  auto var_val = static_cast<T>(real_sqrt(block_var + epsilon));
  T d_x_mean_partial = 0, d_x_var_partial = 0;
  // col_idx equals i % feature_size without a per-element modulo: i starts
  // at column threadIdx.x of this row, and both advance by BlockDim.
  for (int i = beg_idx, col_idx = threadIdx.x; i < end_idx;
       i += BlockDim, col_idx += BlockDim) {
    if (scale != nullptr) {
      d_x[i] = d_y[i] * scale[col_idx] / var_val;
    } else {
      d_x[i] = d_y[i] / var_val;
    }
    d_x_mean_partial += d_x[i];
    d_x_var_partial += d_x[i] * (x[i] - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<T>());

  // Thread 0 owns the reduced sums; share the coefficients via shared memory.
  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = pair.first_ / feature_size;
    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= d_x_mean_partial;
    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
  }
}

// Backward pass for a single row (batch size 1). There is one thread per
// feature element, and the launch must cover feature_size threads.
//
// mean and var hold one entry per row, so with a single row they are read
// at index 0. Writes each requested output:
//   d_x     = d_y * scale / sqrt(var + epsilon)  (scale omitted when nullptr;
//             still needs LayerNormBackwardPostProcessToCalculateDX)
//   d_scale = d_y * (x - mean) / sqrt(var + epsilon)
//   d_bias  = d_y
template <typename T>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
    const T *x, const T *d_y, T *d_x, T *d_scale, T *d_bias, const T *mean,
    const T *var, const T *scale, float epsilon, int feature_size) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < feature_size) {
    // mean/var have a single element here; indexing them with idx would read
    // out of bounds for every idx >= 1.
    auto var_val = static_cast<T>(real_sqrt(var[0] + epsilon));
    if (d_x != nullptr) {
      // Apply scale whenever it exists, even if its gradient (d_scale) was
      // not requested.
      if (scale == nullptr) {
        d_x[idx] = d_y[idx] / var_val;
      } else {
        d_x[idx] = d_y[idx] * scale[idx] / var_val;
      }
    }

    if (d_scale != nullptr) {
      d_scale[idx] = d_y[idx] * (x[idx] - mean[0]) / var_val;
    }

    if (d_bias != nullptr) d_bias[idx] = d_y[idx];
  }
}

// Launches LayerNormBackwardPostProcessToCalculateDX over `rows` rows (one
// block per row), choosing the block size from feature_size. It turns the
// first stage d_x = d_y * scale / sqrt(var + epsilon) into the full gradient.
template <typename T>
static void LaunchLayerNormBackwardPostProcess(const T *x, T *d_x,
                                               const T *mean, const T *var,
                                               float epsilon, int rows,
                                               int feature_size,
                                               cudaStream_t stream) {
  switch (GetDesiredBlockDim(feature_size)) {
    FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
                         T, kBlockDim><<<rows, kBlockDim, 0, stream>>>(
        x, d_x, mean, var, epsilon, feature_size));
  }
}

// Host-side dispatcher for the layer-norm backward pass. All kernels are
// enqueued on `stream`; nothing here synchronizes.
//
// Any of d_x, d_scale and d_bias may be nullptr, meaning that gradient is not
// requested; if none is requested nothing is launched. x and d_y are indexed
// as row-major [batch_size, feature_size] matrices. mean/var have one entry
// per row. scale (may be nullptr) has one entry per feature column.
template <typename T>
static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
                              const T *mean, const T *var, T *d_x, T *d_scale,
                              T *d_bias, float epsilon, int batch_size,
                              int feature_size, cudaStream_t stream) {
  const int kMaxBlockDim = 512;
  // Bit 2: d_x requested; bit 1: d_scale requested; bit 0: d_bias requested.
  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));
  if (gradient_flag == 0) return;

  if (batch_size == 1) {
    // A single row: each feature element gets its own thread.
    LayerNormBackwardWhenBatchSizeIsOne<
        T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
             stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
                       feature_size);

    if (d_x != nullptr) {
      LaunchLayerNormBackwardPostProcess<T>(x, d_x, mean, var, epsilon,
                                            /*rows=*/1, feature_size, stream);
    }
    return;
  }

  // The column-wise gradient kernels reduce over the batch, so their block
  // size is derived from batch_size.
  auto block_dim = GetDesiredBlockDim(batch_size);
  switch (gradient_flag) {
    case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
                             T, kBlockDim, false,
                             false><<<feature_size, kBlockDim, 0, stream>>>(
            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
            feature_size));
      }
      break;
    case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
                             T, kBlockDim, false,
                             true><<<feature_size, kBlockDim, 0, stream>>>(
            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
            feature_size));
      }
      break;
    case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientAll<
                T, kBlockDim, false><<<feature_size, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size));
      }
      break;
    case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
      // Row-wise kernel that also performs the d_x post-processing itself.
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientOnlyDX<
                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
      }
      break;
    case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
                             T, kBlockDim, true,
                             false><<<feature_size, kBlockDim, 0, stream>>>(
            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
            feature_size));
      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
                             T, kBlockDim, true,
                             true><<<feature_size, kBlockDim, 0, stream>>>(
            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
            feature_size));
      }
      break;
    case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientAll<
                T, kBlockDim, true><<<feature_size, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size));
      }
      break;
    default:
      break;
  }

  // For flags 5-7, the column-wise kernels above wrote only the first stage
  // of d_x; finish it row-wise. Flag 4 was already completed by
  // LayerNormBackwardGradientOnlyDX.
  if (gradient_flag >= 5) {
    LaunchLayerNormBackwardPostProcess<T>(x, d_x, mean, var, epsilon,
                                          batch_size, feature_size, stream);
  }
}

// CUDA kernel for the layer_norm forward op. Flattens X into a
// [batch_size, feature_size] matrix at begin_norm_axis and normalizes every
// row with LayerNormForward, writing Y, Mean and Variance. Scale and Bias are
// optional inputs.
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");

    // Inputs (Scale and Bias may be absent).
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");
    const auto *x = ctx.Input<Tensor>("X");

    // Outputs.
    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");

    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto input_dims = x->dims();

    const T *x_data = x->data<T>();
    T *y_data = y->mutable_data<T>(ctx.GetPlace());
    T *mean_data = mean->mutable_data<T>(ctx.GetPlace());
    T *var_data = var->mutable_data<T>(ctx.GetPlace());
    const T *scale_data = nullptr;
    if (scale != nullptr) scale_data = scale->data<T>();
    const T *bias_data = nullptr;
    if (bias != nullptr) bias_data = bias->data<T>();

    const auto matrix_dim =
        framework::flatten_to_2d(input_dims, begin_norm_axis);
    const int batch_size = static_cast<int>(matrix_dim[0]);
    const int feature_size = static_cast<int>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

    // One block per row; the block size is a power of two chosen from the
    // row length. Rows of length 1 match no case label and are rejected.
    switch (GetDesiredBlockDim(feature_size)) {
      FIXED_BLOCK_DIM_CASE(
          LayerNormForward<T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
              x_data, scale_data, bias_data, y_data, mean_data, var_data,
              epsilon, feature_size));
      default:
        PADDLE_THROW(
            "Product from begin_norm_axis to end must be larger than 1");
        break;
    }
  }
};

// CUDA kernel for the layer_norm_grad op. The gradients of X, Scale and Bias
// are each optional (nullptr when not requested); only the requested ones are
// allocated. All computation is delegated to LayerNormBackward.
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");

    // Requested gradients; any of them may be nullptr.
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    // Allocates a requested gradient on this op's place, or yields nullptr
    // when that gradient was not requested.
    auto alloc_if_requested = [&ctx](Tensor *grad) -> T * {
      return grad == nullptr ? nullptr : grad->mutable_data<T>(ctx.GetPlace());
    };

    const T *x_data = x->data<T>();
    const T *d_y_data = d_y->data<T>();
    const T *mean_data = mean->data<T>();
    const T *var_data = var->data<T>();
    const T *scale_data = nullptr;
    if (scale != nullptr) scale_data = scale->data<T>();
    T *d_scale_data = alloc_if_requested(d_scale);
    T *d_bias_data = alloc_if_requested(d_bias);
    T *d_x_data = alloc_if_requested(d_x);

    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto matrix_dim = framework::flatten_to_2d(x->dims(), begin_norm_axis);
    const int batch_size = static_cast<int>(matrix_dim[0]);
    const int feature_size = static_cast<int>(matrix_dim[1]);

    LayerNormBackward<T>(x_data, d_y_data, scale_data, mean_data, var_data,
                         d_x_data, d_scale_data, d_bias_data, epsilon,
                         batch_size, feature_size,
                         ctx.cuda_device_context().stream());
  }
};

#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_CASE
}  // namespace operators
}  // namespace paddle

// Register the float and double CUDA kernels for the forward and backward
// layer_norm ops.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>);