// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sync_batch_norm_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"

namespace phi {

template <typename T, typename Context>
void SyncBatchNormKernel(const Context &ctx,
                         const DenseTensor &x,
                         const DenseTensor &scale,
                         const DenseTensor &bias,
                         const DenseTensor &mean,
                         const DenseTensor &variance,
                         float momentum,
                         float epsilon_f,
                         const std::string &data_layout_str,
                         bool is_test,
                         bool use_global_stats,
                         bool trainable_statistics,
                         bool fuse_with_relu,
                         DenseTensor *y,
                         DenseTensor *mean_out,
                         DenseTensor *variance_out,
                         DenseTensor *saved_mean,
                         DenseTensor *saved_variance,
                         DenseTensor *reserve_space) {
  PADDLE_ENFORCE_EQ(use_global_stats,
                    false,
                    phi::errors::InvalidArgument(
                        "sync_batch_norm doesn't support "
                        "to set use_global_stats True. Please use batch_norm "
                        "in this case."));

  double epsilon = epsilon_f;
  const bool trainable_stats = trainable_statistics;
  const DataLayout layout = phi::StringToDataLayout(data_layout_str);
  bool test_mode = is_test && (!trainable_stats);
  const auto &x_dims = x.dims();
  PADDLE_ENFORCE_GE(x_dims.size(),
                    2,
                    phi::errors::InvalidArgument(
                        "The Input dim size should be larger than 1."));
  PADDLE_ENFORCE_LE(x_dims.size(),
                    5,
                    phi::errors::InvalidArgument(
                        "The Input dim size should be less than 6."));
  int N, C, H, W, D;
  funcs::ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);
  int x_numel = x.numel();

  const T *x_d = x.template data<T>();
  const auto *s_d = scale.template data<BatchNormParamType<T>>();
  const auto *b_d = bias.template data<BatchNormParamType<T>>();

  T *y_d = ctx.template Alloc<T>(y);

  const BatchNormParamType<T> *mean_data = nullptr;
  const BatchNormParamType<T> *var_data = nullptr;

  auto stream = ctx.stream();
  const int block = 512;
  int max_threads = ctx.GetMaxPhysicalThreadCount();

  paddle::memory::AllocationPtr alloc_ptr{nullptr};

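  // Inference reads the running statistics directly; training computes and
  // synchronizes fresh batch statistics below.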
  if (test_mode) {
    mean_data = mean.template data<BatchNormParamType<T>>();
    var_data = variance.template data<BatchNormParamType<T>>();
  } else {
    // Workspace layout: per-channel statistics of x (C slots) and of x^2
    // (C slots), plus one trailing slot used to compute the device count.
    // The device count could also be obtained from platform::DeviceContextPool.
    const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType<T>);
    alloc_ptr = paddle::memory::Alloc(
        ctx.GetPlace(),
        bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));

    auto *stats = reinterpret_cast<BatchNormParamType<T> *>(alloc_ptr->ptr());
    const int threads = 256;
    int grid = std::min(C, (max_threads + threads - 1) / threads);
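    // KeLocalStats computes this device's per-channel statistics of x and
    // x^2 and writes 1 into the trailing workspace slot, so that summing the
    // buffers across devices also yields the device count.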
    if (layout == phi::DataLayout::kNCHW) {
      KeLocalStats<T, threads, phi::DataLayout::kNCHW>
          <<<grid, threads, 0, stream>>>(x_d, N, H * W * D, C, stats);
    } else {
      KeLocalStats<T, threads, phi::DataLayout::kNHWC>
          <<<grid, threads, 0, stream>>>(x_d, N, H * W * D, C, stats);
    }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
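    // Prefer the communicator registered for this place via GetCCLComm and
    // fall back to the device context's default NCCL communicator.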
    ncclComm_t comm = static_cast<ncclComm_t>(detail::GetCCLComm(x.place(), 0));
    if (comm == nullptr) {
      comm = ctx.nccl_comm();
    }

    if (comm) {
      int dtype = paddle::platform::ToNCCLDataType(
          paddle::framework::TransToProtoVarType(mean_out->dtype()));
      // In-place all-reduce: sums the 2 * C + 1 local statistics across all
      // participating devices.
      PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
          stats,
          stats,
          2 * C + 1,
          static_cast<ncclDataType_t>(dtype),
          ncclSum,
          comm,
          stream));
      VLOG(3) << "Sync result using all reduce";
    }
#endif

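    // Allocate the running (moving) statistics and the saved batch
    // statistics that the backward pass consumes.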
    auto *est_mean_data = ctx.template Alloc<BatchNormParamType<T>>(mean_out);
    auto *est_var_data =
        ctx.template Alloc<BatchNormParamType<T>>(variance_out);

    auto *sv_mean_data = ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
    auto *sv_inv_var_data =
        ctx.template Alloc<BatchNormParamType<T>>(saved_variance);

    // Note: Input('Mean')/Input('Variance') share storage with
    // Output('MeanOut')/Output('VarianceOut').
    KeSyncAndMovingStats<T>
        <<<(C + block - 1) / block, block, 0, stream>>>(stats,
                                                        stats + C,
                                                        stats + 2 * C,
                                                        C,
                                                        momentum,
                                                        epsilon,
                                                        sv_mean_data,
                                                        sv_inv_var_data,
                                                        est_mean_data,
                                                        est_var_data);

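    // KeSyncAndMovingStats derives the global mean and variance from the
    // reduced buffers and writes the variance back into stats + C, which is
    // what var_data points at for the normalization pass below.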
    mean_data = sv_mean_data;
    var_data = stats + C;
  }

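  // Normalize and apply the affine transform on each device's local data:
  //   y = scale * (x - mean) / sqrt(var + epsilon) + bias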
  int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
  if (layout == phi::DataLayout::kNCHW) {
    KeNormAffine<T, phi::DataLayout::kNCHW>
        <<<grid2, block, 0, stream>>>(x_d,
                                      s_d,
                                      b_d,
                                      mean_data,
                                      var_data,
                                      epsilon,
                                      C,
                                      H * W * D,
                                      x_numel,
                                      y_d);
  } else {
    KeNormAffine<T, phi::DataLayout::kNHWC>
        <<<grid2, block, 0, stream>>>(x_d,
                                      s_d,
                                      b_d,
                                      mean_data,
                                      var_data,
                                      epsilon,
                                      C,
                                      H * W * D,
                                      x_numel,
                                      y_d);
  }
}

}  // namespace phi

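// Kernel registration. When the kernel runs on FP16 data, the
// scale/bias/mean/variance parameters and the statistic outputs stay in
// FP32 for numerical stability; the ROCm (HIP) build omits the double
// variant.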
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sync_batch_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::SyncBatchNormKernel,
                   float,
                   phi::dtype::float16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
  }
}
#else
PD_REGISTER_KERNEL(sync_batch_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::SyncBatchNormKernel,
                   float,
                   double,
                   phi::dtype::float16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
  }
}
#endif