// group_norm_op.cu
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#ifdef __NVCC__
16
#include "cub/cub.cuh"
17 18 19
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
20
namespace cub = hipcub;
21 22
#endif

D
Dun 已提交
23
#include "paddle/fluid/operators/group_norm_op.h"
24 25
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
D
Dun 已提交
26 27 28 29

namespace paddle {
namespace operators {

30
typedef framework::DataLayout DataLayout;

// Bit flags describing which optional affine inputs a kernel specialization
// uses; the host adds/ORs them together into the `flags` template parameter.
enum GroupNormKernelFlags { kHasScale = 1 << 0, kHasBias = 1 << 1 };

P
peizhilin 已提交
33 34 35
// Launches `kernel_name<T, case_id>` on the caller's `grid`/`threads`
// configuration and `dev_ctx` stream when the runtime value `flags` equals the
// compile-time constant `case_id`. `T`, `grid`, `threads` and `dev_ctx` must
// be visible at the expansion site.
#define CHECK_CASE(case_id, flags, kernel_name, ...)                 \
  if ((flags) == (case_id)) {                                        \
    kernel_name<T, case_id><<<grid, threads, 0, dev_ctx.stream()>>>( \
        __VA_ARGS__);                                                \
  }

// Maps a runtime GroupNormKernelFlags combination onto one of the four
// compile-time kernel specializations:
//   0 -> no scale, no bias
//   1 -> scale only
//   2 -> bias only
//   3 -> scale and bias
// Values outside 0..3 launch nothing.
#define UNROLL_ALL_CASES(flags, kernel_name, ...) \
  CHECK_CASE(0, flags, kernel_name, __VA_ARGS__)  \
  CHECK_CASE(1, flags, kernel_name, __VA_ARGS__)  \
  CHECK_CASE(2, flags, kernel_name, __VA_ARGS__)  \
  CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)
47 48 49 50 51 52 53 54 55

// Sums `value` across the calling warp with cub::WarpReduce, then lane 0 adds
// the warp total into `*sum` with a single atomic, so each warp issues one
// atomic instead of one per thread.
// NOTE(review): the full-warp Sum overload expects every lane of the warp to
// participate; launch with block sizes that are whole warps.
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
  using WarpReduce = cub::WarpReduce<T>;
  typename WarpReduce::TempStorage storage;
  const T warp_total = WarpReduce(storage).Sum(value);
  if (cub::LaneId() == 0) {
    platform::CudaAtomicAdd(sum, warp_total);
  }
}

D
Dun 已提交
56
// Accumulates the raw first and second moments of `x` per (sample, group).
//
// Launch layout used by the host: grid = (group_size, groups, N) with
// blockDim.x threads per block; blockIdx.x selects the channel inside the
// group, blockIdx.y the group and blockIdx.z the sample. Each block strides
// over the imsize spatial positions of one channel, divides its partial sums
// by (channels in group * imsize) and atomically adds them into
//   mean[sample * groups + group]  (contribution to E[x])
//   var[sample * groups + group]   (contribution to E[x^2])
// Both buffers are accumulated into and therefore must start zeroed.
template <typename T>
__global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W,
                                              int imsize, int groups,
                                              int group_size, T* mean, T* var,
                                              const DataLayout data_layout) {
  const int channel_in_group = blockIdx.x;
  const int group = blockIdx.y;
  const int sample = blockIdx.z;
  const int channel = group * group_size + channel_in_group;
  // `channel` depends only on blockIdx, so the whole block exits together.
  if (channel >= C) return;

  const int H = imsize / W;
  const int channels_in_group =
      min(group_size, static_cast<int>(C - group * group_size));
  const bool is_nchw = (data_layout == DataLayout::kNCHW);

  T partial_sum = 0;
  T partial_sq_sum = 0;
  for (int pos = threadIdx.x; pos < imsize; pos += blockDim.x) {
    const T val =
        is_nchw ? x[(sample * C + channel) * imsize + pos]
                : x[(sample * H + pos / W) * W * C + (pos % W) * C + channel];
    partial_sum += val;
    partial_sq_sum += val * val;
  }
  partial_sum /= channels_in_group * imsize;
  partial_sq_sum /= channels_in_group * imsize;

  CudaAtomicAddWithWarp(&mean[sample * groups + group], partial_sum);
  CudaAtomicAddWithWarp(&var[sample * groups + group], partial_sq_sum);
}

87
// Normalizes `x` with the per-(sample, group) moments produced by
// GroupNormForwardGetMeanAndVar and applies the optional per-channel affine
// transform selected at compile time by `flags` (kHasScale / kHasBias).
//
// Launch layout: grid = (group_size, groups, N); blockIdx.x is the channel
// inside the group, blockIdx.y the group, blockIdx.z the sample.
//
// `mean` holds E[x] and `var` holds E[x^2] for each (sample, group). The
// biased variance E[x^2] - E[x]^2 (clamped at 0) is stored to `real_var` by
// thread 0 of the group's first channel block.
template <typename T, int flags>
__global__ void GroupNormForward(const T* x, const T* mean, const T* var,
                                 const T* scale, const T* bias, int N, int C,
                                 int W, int imsize, int groups, int group_size,
                                 T epsilon, T* y, T* real_var,
                                 const DataLayout data_layout) {
  int gid = blockIdx.y;
  int cid = blockIdx.x;
  int bid = blockIdx.z;
  int H = imsize / W;
  int ccid = gid * group_size + cid;
  if (ccid >= C) return;

  T x_mean = mean[bid * groups + gid];
  T x_var = var[bid * groups + gid];
  x_var = x_var - x_mean * x_mean;
  // E[x^2] - E[x]^2 can come out slightly negative through floating-point
  // cancellation (e.g. a (nearly) constant group). Clamp it so sqrt never sees
  // a negative argument and the output does not turn into NaN.
  if (x_var < static_cast<T>(0)) x_var = static_cast<T>(0);
  // Stay in T: a bare `1.0` would turn the float instantiation's reciprocal
  // into a double-precision division.
  T var_inv = static_cast<T>(1) / sqrt(x_var + epsilon);
  if (cid == 0 && threadIdx.x == 0) real_var[bid * groups + gid] = x_var;

  // Per-channel affine parameters are loop-invariant; read them once.
  T x_scale = static_cast<T>(1);
  T x_bias = static_cast<T>(0);
  if (flags & kHasScale) x_scale = scale[ccid];
  if (flags & kHasBias) x_bias = bias[ccid];

  for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
    // Same element offset is used for the load from x and the store to y.
    int index;
    if (data_layout == DataLayout::kNCHW) {
      index = (bid * C + ccid) * imsize + imid;
    } else {
      int hid = imid / W;
      int wid = imid % W;
      index = (bid * H + hid) * W * C + wid * C + ccid;
    }
    T val = (x[index] - x_mean) * var_inv;
    if (flags & kHasScale) val *= x_scale;
    if (flags & kHasBias) val += x_bias;
    y[index] = val;
  }
}

template <typename T>
class GroupNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Forward group normalization on the GPU.
  //
  // Reads X, optional Scale/Bias and the data_layout/epsilon/groups
  // attributes; writes Y, Mean and Variance. Two launches over a
  // (group_size, groups, N) grid:
  //   1. GroupNormForwardGetMeanAndVar atomically accumulates E[x] into Mean
  //      and E[x^2] into a temporary buffer.
  //   2. GroupNormForward writes Y and Variance = E[x^2] - E[x]^2.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const float epsilon = ctx.Attr<float>("epsilon");
    auto* scale = ctx.Input<Tensor>("Scale");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* x = ctx.Input<Tensor>("X");

    auto* y = ctx.Output<Tensor>("Y");
    auto* mean = ctx.Output<Tensor>("Mean");
    auto* var = ctx.Output<Tensor>("Variance");
    const auto groups = ctx.Attr<int>("groups");

    const auto x_dims = x->dims();
    // Channel axis is 1 for NCHW and the last axis otherwise.
    const int C =
        (data_layout == DataLayout::kNCHW ? x_dims[1]
                                          : x_dims[x_dims.size() - 1]);
    const int group_size = C / groups;

    const int W =
        (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
                                          : x_dims[x_dims.size() - 2]);

    y->mutable_data<T>(ctx.GetPlace());
    mean->mutable_data<T>(ctx.GetPlace());
    var->mutable_data<T>(ctx.GetPlace());
    phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_zero;
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    // Receives E[x^2]; the Variance output is written by GroupNormForward.
    Tensor temp_var;
    temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());

    // Both reduction targets are filled with atomic adds and must start at 0.
    set_zero(dev_ctx, mean, static_cast<T>(0));
    set_zero(dev_ctx, &temp_var, static_cast<T>(0));

    auto* x_data = x->data<T>();
    auto* y_data = y->data<T>();
    auto* mean_data = mean->data<T>();
    auto* var_data = var->data<T>();
    auto* temp_var_data = temp_var.data<T>();

    const T* scale_data = nullptr;
    if (scale) scale_data = scale->data<T>();
    const T* bias_data = nullptr;
    if (bias) bias_data = bias->data<T>();

    // Number of spatial elements per channel: every axis except N and C.
    int imsize = 1;
    if (data_layout == DataLayout::kNCHW) {
      for (int i = 2; i < x_dims.size(); ++i) {
        imsize *= x_dims[i];
      }
    } else {
      for (int i = 1; i < x_dims.size() - 1; ++i) {
        imsize *= x_dims[i];
      }
    }

#ifdef __HIPCC__
    constexpr int kWarpSize = 64;
    int block_size = std::max(std::min(256, imsize), 64);
#else
    constexpr int kWarpSize = 32;
    int block_size = std::min(1024, imsize);
#endif
    // CudaAtomicAddWithWarp uses the full-warp cub::WarpReduce::Sum, which is
    // undefined for a partially populated warp (e.g. imsize == 49). Round the
    // block up to whole warps; the extra threads skip the spatial loop and
    // contribute zero to the sums.
    block_size = (block_size + kWarpSize - 1) / kWarpSize * kWarpSize;

    dim3 grid(group_size, groups, x_dims[0]);
    dim3 threads(block_size, 1, 1);
    GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
        x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
        temp_var_data, data_layout);
    int flags =
        (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
    UNROLL_ALL_CASES(flags, GroupNormForward, x_data, mean_data, temp_var_data,
                     scale_data, bias_data, x_dims[0], C, W, imsize, groups,
                     group_size, epsilon, y_data, var_data, data_layout);
  }
};

203
// Accumulates the per-(sample, group) reductions consumed by GroupNormBackward
// and the per-channel Scale/Bias gradients.
//
// `x` is the forward output Y, because GroupNormGradKernel passes Y here. The
// normalized input is recovered as x_hat = (y - bias) / scale; when scale is 0
// the reciprocal is taken as 0. For channel `ccid` of group `gid` in sample
// `bid`, the block atomically adds:
//   d_mean[bid * groups + gid] += sum(d_y * scale)
//   d_var[bid * groups + gid]  += sum((y - bias) * d_y)
//   d_scale[ccid]              += sum(x_hat * d_y)   if kHasScale and non-null
//   d_bias[ccid]               += sum(d_y)           if kHasBias and non-null
// All targets must be zeroed beforehand. `N` and `epsilon` are unused.
template <typename T, int flags>
__global__ void GroupNormBackwardGetMeanAndVar(
    const T* x, const T* scale, const T* bias, const T* d_y, int N, int C,
    int W, int imsize, int groups, int group_size, T epsilon, T* d_mean,
    T* d_var, T* d_scale, T* d_bias, const DataLayout data_layout) {
  int gid = blockIdx.y;
  int cid = blockIdx.x;
  int bid = blockIdx.z;
  int H = imsize / W;
  int ccid = gid * group_size + cid;
  if (ccid >= C) return;

  T x_scale = (flags & kHasScale) ? scale[ccid] : static_cast<T>(1);
  T x_bias = (flags & kHasBias) ? bias[ccid] : static_cast<T>(0);
  T x_scale_inv = 0;
  // Keep the reciprocal in T (a bare 1.0 would force a double division).
  if (x_scale != static_cast<T>(0)) x_scale_inv = static_cast<T>(1) / x_scale;

  T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0;

  for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
    int index;
    if (data_layout == DataLayout::kNCHW) {
      index = (bid * C + ccid) * imsize + imid;
    } else {
      int hid = imid / W;
      int wid = imid % W;
      index = (bid * H + hid) * W * C + wid * C + ccid;
    }
    T val = x[index] - x_bias;
    T dval = d_y[index];

    d_var_data += val * dval;
    d_mean_data += dval * x_scale;

    val = val * x_scale_inv;
    d_bias_data += dval;
    d_scale_data += val * dval;
  }
  CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]), d_mean_data);
  CudaAtomicAddWithWarp(&(d_var[bid * groups + gid]), d_var_data);
  // Scale/Bias may be inputs while their gradients are not requested; the host
  // then passes nullptr, so the update must be skipped rather than written
  // through a null pointer. The condition is uniform across the block, so the
  // warp reduction inside stays convergent.
  if ((flags & kHasScale) && d_scale != nullptr) {
    CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
  }
  if ((flags & kHasBias) && d_bias != nullptr) {
    CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
  }
}

246 247 248
// Computes d_x for one channel of one group from the reductions produced by
// GroupNormBackwardGetMeanAndVar:
//   d_x = (d_y * scale - (d_var * x_hat + d_mean) / (n_group * imsize))
//         / sqrt(var + epsilon)
// x_hat = (y - bias) / scale is recovered from the forward output; `x` is Y
// here. n_group is the number of channels in the group. The launch layout is
// the same as the other kernels: grid = (group_size, groups, N).
template <typename T, int flags>
__global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale,
                                  const T* bias, const T* var, const T* d_mean,
                                  const T* d_var, int N, int C, int W,
                                  int imsize, int groups, int group_size,
                                  T epsilon, T* d_x,
                                  const DataLayout data_layout) {
  int gid = blockIdx.y;
  int cid = blockIdx.x;
  int bid = blockIdx.z;
  int H = imsize / W;
  int number = min(group_size, static_cast<int>(C - gid * group_size));
  int ccid = gid * group_size + cid;
  if (ccid >= C) return;
  T x_var = var[bid * groups + gid];
  T d_x_mean = d_mean[bid * groups + gid];
  T d_x_var = d_var[bid * groups + gid];

  // Reciprocals are computed in T; bare `1.0` literals would promote the
  // float instantiation to double-precision divisions.
  T x_var_inv = static_cast<T>(1) / sqrt(x_var + epsilon);
  T number_inv = static_cast<T>(1) / static_cast<T>(number * imsize);

  T x_scale = (flags & kHasScale) ? scale[ccid] : static_cast<T>(1);
  T x_bias = (flags & kHasBias) ? bias[ccid] : static_cast<T>(0);
  T x_scale_inv = 0;
  if (x_scale != static_cast<T>(0)) x_scale_inv = static_cast<T>(1) / x_scale;

  for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
    // One offset addresses x, d_y and d_x for this element.
    int index;
    if (data_layout == DataLayout::kNCHW) {
      index = (bid * C + ccid) * imsize + imid;
    } else {
      int hid = imid / W;
      int wid = imid % W;
      index = (bid * H + hid) * W * C + wid * C + ccid;
    }
    T v_y = (x[index] - x_bias) * x_scale_inv;
    T dly = d_y[index];
    d_x[index] =
        x_var_inv *
        (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean);
  }
}

template <typename T>
class GroupNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Backward group normalization on the GPU.
  //
  // Reads Y, Variance, optional Scale/Bias and Y@GRAD; writes whichever of
  // X@GRAD, Scale@GRAD and Bias@GRAD are present. Two launches over a
  // (group_size, groups, N) grid:
  //   1. GroupNormBackwardGetMeanAndVar reduces per-(sample, group) terms into
  //      temporaries and accumulates Scale/Bias gradients.
  //   2. GroupNormBackward writes X@GRAD (only when it is requested).
  void Compute(const framework::ExecutionContext& ctx) const override {
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const float epsilon = ctx.Attr<float>("epsilon");
    // The gradient is computed from the forward *output* Y; the kernels
    // recover the normalized input as (Y - Bias) / Scale.
    auto* y = ctx.Input<Tensor>("Y");
    auto* var = ctx.Input<Tensor>("Variance");
    auto* scale = ctx.Input<Tensor>("Scale");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto groups = ctx.Attr<int>("groups");

    // Gradient outputs; each may be nullptr and is checked before use.
    auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    const auto& y_dims = y->dims();
    // Channel axis is 1 for NCHW and the last axis otherwise.
    const int C =
        (data_layout == DataLayout::kNCHW ? y_dims[1]
                                          : y_dims[y_dims.size() - 1]);
    const int group_size = C / groups;
    const int W =
        (data_layout == DataLayout::kNCHW ? y_dims[y_dims.size() - 1]
                                          : y_dims[y_dims.size() - 2]);

    phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_zero;
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    // Per-(sample, group) reductions; accumulated atomically, so zeroed first.
    Tensor temp_var;
    temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
    set_zero(dev_ctx, &temp_var, static_cast<T>(0));
    T* temp_var_data = temp_var.data<T>();

    Tensor temp_mean;
    temp_mean.mutable_data<T>(var->dims(), ctx.GetPlace());
    set_zero(dev_ctx, &temp_mean, static_cast<T>(0));
    T* temp_mean_data = temp_mean.data<T>();

    auto* y_data = y->data<T>();
    // Allocate X@GRAD only after the null check; it is not always requested.
    T* d_x_data = nullptr;
    if (d_x) {
      d_x->mutable_data<T>(ctx.GetPlace());
      d_x_data = d_x->data<T>();
    }
    auto* d_y_data = d_y->data<T>();
    auto* var_data = var->data<T>();
    T* d_scale_data = nullptr;
    if (d_scale) {
      d_scale->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, d_scale, static_cast<T>(0));
      d_scale_data = d_scale->data<T>();
    }
    T* d_bias_data = nullptr;
    if (d_bias) {
      d_bias->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, d_bias, static_cast<T>(0));
      d_bias_data = d_bias->data<T>();
    }

    const T* scale_data = nullptr;
    if (scale) scale_data = scale->data<T>();
    const T* bias_data = nullptr;
    if (bias) bias_data = bias->data<T>();

    // Number of spatial elements per channel: every axis except N and C.
    int imsize = 1;
    if (data_layout == DataLayout::kNCHW) {
      for (int i = 2; i < y_dims.size(); ++i) {
        imsize *= y_dims[i];
      }
    } else {
      for (int i = 1; i < y_dims.size() - 1; ++i) {
        imsize *= y_dims[i];
      }
    }

#ifdef __HIPCC__
    constexpr int kWarpSize = 64;
    int block_size = std::max(std::min(256, imsize), 64);
#else
    constexpr int kWarpSize = 32;
    int block_size = std::min(1024, imsize);
#endif
    // The reduction kernel uses the full-warp cub::WarpReduce::Sum, which is
    // undefined for a partially populated warp. Round up to whole warps; the
    // extra threads skip the spatial loop and contribute zero.
    block_size = (block_size + kWarpSize - 1) / kWarpSize * kWarpSize;

    dim3 grid(group_size, groups, y_dims[0]);
    dim3 threads(block_size, 1, 1);
    int flags =
        (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
    UNROLL_ALL_CASES(flags, GroupNormBackwardGetMeanAndVar, y_data, scale_data,
                     bias_data, d_y_data, y_dims[0], C, W, imsize, groups,
                     group_size, epsilon, temp_mean_data, temp_var_data,
                     d_scale_data, d_bias_data, data_layout);
    if (d_x_data != nullptr) {
      UNROLL_ALL_CASES(flags, GroupNormBackward, y_data, d_y_data, scale_data,
                       bias_data, var_data, temp_mean_data, temp_var_data,
                       y_dims[0], C, W, imsize, groups, group_size, epsilon,
                       d_x_data, data_layout);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the CUDA forward and backward group_norm kernels for float and
// double.
REGISTER_OP_CUDA_KERNEL(
    group_norm, ops::GroupNormKernel<plat::CUDADeviceContext, float>,
    ops::GroupNormKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    group_norm_grad, ops::GroupNormGradKernel<plat::CUDADeviceContext, float>,
    ops::GroupNormGradKernel<plat::CUDADeviceContext, double>);