// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"

#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"

DECLARE_int32(check_nan_inf_level);

namespace paddle {
namespace framework {
namespace details {

static std::once_flag init_multi_gpu_op_var_map_flag;

// lazy init
static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>&
multi_op_var2gpu_str() {
  static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>
      _multi_op_var2gpu_str;
  return _multi_op_var2gpu_str;
}

static std::vector<std::mutex>& multi_op_var2gpu_str_mutex() {
  static std::vector<std::mutex> _multi_op_var2gpu_str_mutex;
  return _multi_op_var2gpu_str_mutex;
}

static void InitMultiGPUOpVarMap() {
  int dev_count = platform::GetGPUDeviceCount();
  PADDLE_ENFORCE_GT(dev_count,
                    0,
                    platform::errors::NotFound(
                        "CUDA device count must be > 0, but got dev_count=%d",
                        dev_count));

  // https://stackoverflow.com/questions/16465633/how-can-i-use-something-like-stdvectorstdmutex
  std::vector<std::unordered_map<std::string, memory::AllocationPtr>> tmp_multi(
      dev_count);
  std::vector<std::mutex> tmp_multi_mutex(dev_count);

  multi_op_var2gpu_str().swap(tmp_multi);
  multi_op_var2gpu_str_mutex().swap(tmp_multi_mutex);
}
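
// Usage pattern (sketch, mirroring GetGpuHintStringPtr() below): callers
// pair the per-device mutex with the per-device map:
//
//   auto& mu  = multi_op_var2gpu_str_mutex().at(dev_id);
//   auto& map = multi_op_var2gpu_str().at(dev_id);
//   std::lock_guard<std::mutex> guard(mu);
//   // ... insert into or look up `map` ...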

template <typename T>
__device__ __forceinline__ void PrintNanInfKernel(const T* value,
                                                  const size_t numel,
                                                  int print_num,
                                                  char* debug_info) {
  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;

  __shared__ unsigned int nan_count, inf_count, num_count;
  if (threadIdx.x == 0) nan_count = inf_count = num_count = 0;
  __syncthreads();

  for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
    unsigned int count = 0;
    if (isnan(value[i])) {
      count = atomicAdd(&nan_count, 1);
    } else if (isinf(value[i])) {
      count = atomicAdd(&inf_count, 1);
    } else {
      count = atomicAdd(&num_count, 1);
    }
    // for cuda, print at most print_num values per category in every block
    if (count < static_cast<unsigned int>(print_num)) {
      printf("numel:%lu idx:%lu value:%f\n",
             static_cast<uint64_t>(numel),
             static_cast<uint64_t>(i),
             static_cast<float>(value[i]));
    }
  }
  __syncthreads();

#ifdef __HIPCC__
  if (hipThreadIdx_x == 0) {
    printf("In block %d, found %u nan, %u inf, %u finite values\n",
           hipBlockIdx_x,
           nan_count,
           inf_count,
           num_count);
#else
  if (threadIdx.x == 0) {
    printf("In block %d, found %u nan, %u inf, %u finite values\n",
           blockIdx.x,
           nan_count,
           inf_count,
           num_count);
#endif
    PADDLE_ENFORCE(false, "===ERROR: found nan or inf in %s===", debug_info);
  }
}

// ResNet 2-GPU speed test: 270 images/s with no check, 229 images/s with
// this check.
template <typename T>
__global__ void CheckNanInfKernel(const T* value,
                                  const size_t numel,
                                  int print_num,
                                  char* debug_info) {
  /// step 1: check whether any element is nan or inf
  __shared__ volatile int has_nan_inf;
  if (threadIdx.x == 0) has_nan_inf = false;
  __syncthreads();

  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
  T sum = static_cast<T>(0.0);
  // TODO(wangxi): SIMD speed up
  for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
    // x - x is 0 for finite x and NaN for NaN/Inf, so a non-finite sum
    // flags any nan/inf element in this thread's slice.
    sum += (value[i] - value[i]);
  }

  if (isnan(sum) || isinf(sum)) has_nan_inf = true;
  __syncthreads();

  /// Note: different blocks may take different branches here.
  if (!has_nan_inf) return;

  PrintNanInfKernel(value, numel, print_num, debug_info);
}

template <typename T, int ReduceType>
__device__ T BlockReduce(T value) {
  __shared__ T shared_mem[1024];

  shared_mem[threadIdx.x] = value;
  __syncthreads();

  for (int stride = blockDim.x >> 1; stride > 0; stride = stride >> 1) {
    if (threadIdx.x < stride) {
      T value0 = shared_mem[threadIdx.x];
      T value1 = shared_mem[threadIdx.x + stride];
      T reduce_value;
      if (ReduceType == 0) {
        // max
        reduce_value = value0 > value1 ? value0 : value1;
      } else if (ReduceType == 1) {
        // min
        reduce_value = value0 < value1 ? value0 : value1;
      } else if (ReduceType == 2) {
        // sum
        reduce_value = value0 + value1;
      }
      shared_mem[threadIdx.x] = reduce_value;
    }

    // Skipping __syncthreads() for small strides assumed warp-synchronous
    // execution, which is not guaranteed on Volta and newer architectures;
    // synchronize on every iteration to stay correct.
    __syncthreads();
  }

  __syncthreads();
  return shared_mem[0];
}
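
// Usage sketch (illustrative only, not used by this file): every thread
// passes its private value and receives the block-wide result, e.g. a
// block-wide sum written out by thread 0:
//
//   __global__ void BlockSumKernel(const float* in, float* out) {
//     float v = in[threadIdx.x + blockIdx.x * blockDim.x];
//     float block_sum = BlockReduce<float, 2>(v);  // ReduceType 2 == sum
//     if (threadIdx.x == 0) out[blockIdx.x] = block_sum;
//   }
//
// (Assumes the grid is sized so that every load above is in bounds.)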

__device__ void BlockReduceNumNanInfAndWrite(const int64_t num_nan,
                                             const int64_t num_inf,
                                             const int64_t num_zero,
                                             int64_t offset,
                                             int64_t* num_nan_ptr,
                                             int64_t* num_inf_ptr,
                                             int64_t* num_zero_ptr) {
  int64_t block_num_nan = BlockReduce<int64_t, 2>(num_nan);
  int64_t block_num_inf = BlockReduce<int64_t, 2>(num_inf);
  int64_t block_num_zero = BlockReduce<int64_t, 2>(num_zero);

  if (threadIdx.x == 0) {
    num_nan_ptr[offset] = block_num_nan;
    num_inf_ptr[offset] = block_num_inf;
    num_zero_ptr[offset] = block_num_zero;
  }
}
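
// Note: num_nan_ptr, num_inf_ptr and num_zero_ptr are three contiguous
// sub-arrays (numel_max_min elements each) of one int64_t buffer; see
// block_num_nan_inf_zero in the host code below. Each block writes its
// partial counts at slot offset == blockIdx.x.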

template <
    typename T,
    std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
                         std::is_same<T, phi::dtype::complex<double>>::value,
                     bool> = true>
__device__ void BlockReduceMaxMinAndWrite(const T max_value,
                                          const T min_value,
                                          const T mean_value,
                                          int64_t offset,
                                          T* max_ptr,
                                          T* min_ptr,
                                          T* mean_ptr) {
  // TODO(Xreki): support complex
}

template <
    typename T,
    std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
                         !std::is_same<T, phi::dtype::complex<double>>::value,
                     bool> = true>
__device__ void BlockReduceMaxMinAndWrite(const T max_value,
                                          const T min_value,
                                          const T mean_value,
                                          int64_t offset,
                                          T* max_ptr,
                                          T* min_ptr,
                                          T* mean_ptr) {
  if (max_ptr && min_ptr && mean_ptr) {
    __syncthreads();

    T block_max_value = phi::funcs::BlockReduceMax<T>(max_value, FINAL_MASK);
    T block_min_value = phi::funcs::BlockReduceMin<T>(min_value, FINAL_MASK);
    T block_mean_value = phi::funcs::BlockReduceSum<T>(mean_value, FINAL_MASK);

    if (threadIdx.x == 0) {
      max_ptr[offset] = block_max_value;
      min_ptr[offset] = block_min_value;
      mean_ptr[offset] = block_mean_value;
    }
  }
}
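
// Note: unlike the shared-memory BlockReduce further up, the non-complex
// overload above uses the warp-shuffle based block reductions from
// math_cuda_utils.h, with FINAL_MASK as the full-warp participation mask.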

template <typename T, typename MT>
__global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
                                         const int64_t numel,
                                         int64_t* block_num_nan_ptr,
                                         int64_t* block_num_inf_ptr,
                                         int64_t* block_num_zero_ptr,
                                         MT* tensor_block_max_ptr,
                                         MT* tensor_block_min_ptr,
                                         MT* tensor_block_mean_ptr) {
  int64_t i = threadIdx.x + blockIdx.x * blockDim.x;

  int64_t num_nan = 0;
  int64_t num_inf = 0;
  int64_t num_zero = 0;

  // Threads with i >= numel fall back to value_ptr[0] so that they still
  // hold a valid value for the max/min reduction.
  MT max_value = static_cast<MT>(i < numel ? value_ptr[i] : value_ptr[0]);
  MT min_value = static_cast<MT>(i < numel ? value_ptr[i] : value_ptr[0]);
  MT mean_value = static_cast<MT>(0);
  for (; i < numel; i += blockDim.x * gridDim.x) {
    MT value = static_cast<MT>(value_ptr[i]);

    max_value = value > max_value ? value : max_value;
    min_value = value < min_value ? value : min_value;
    mean_value += value / static_cast<MT>(numel);

    if (isnan(value)) {
      num_nan += 1;
    } else if (isinf(value)) {
      num_inf += 1;
    }
    if (value == static_cast<MT>(0)) {
      num_zero += 1;
    }
  }

  BlockReduceNumNanInfAndWrite(num_nan,
                               num_inf,
                               num_zero,
                               blockIdx.x,
                               block_num_nan_ptr,
                               block_num_inf_ptr,
                               block_num_zero_ptr);

  BlockReduceMaxMinAndWrite<MT>(max_value,
                                min_value,
                                mean_value,
                                blockIdx.x,
                                tensor_block_max_ptr,
                                tensor_block_min_ptr,
                                tensor_block_mean_ptr);
}

template <typename T, typename MT>
__global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
                                         const int64_t* block_num_inf_ptr,
                                         const int64_t* block_num_zero_ptr,
                                         const MT* tensor_block_max_ptr,
                                         const MT* tensor_block_min_ptr,
                                         const MT* tensor_block_mean_ptr,
                                         const char* debug_info,
                                         int64_t numel,
                                         int64_t numel_max_min,
                                         int check_nan_inf_level,
                                         int64_t* nan_inf_zero) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    int64_t num_nan = 0;
    int64_t num_inf = 0;
    int64_t num_zero = 0;

    // numel_max_min <= 128
    for (int64_t i = 0; i < numel_max_min; ++i) {
      num_nan += block_num_nan_ptr[i];
      num_inf += block_num_inf_ptr[i];
      num_zero += block_num_zero_ptr[i];
    }

    MT max_value = static_cast<MT>(0);
    MT min_value = static_cast<MT>(0);
    MT mean_value = static_cast<MT>(0);
    if (tensor_block_max_ptr && tensor_block_min_ptr && tensor_block_mean_ptr) {
      max_value = tensor_block_max_ptr[0];
      min_value = tensor_block_min_ptr[0];
      mean_value = tensor_block_mean_ptr[0];

      // numel_max_min <= 128
      for (int64_t i = 1; i < numel_max_min; ++i) {
        MT tmp_max_value = tensor_block_max_ptr[i];
        MT tmp_min_value = tensor_block_min_ptr[i];
        MT tmp_mean_value = tensor_block_mean_ptr[i];

        max_value = tmp_max_value > max_value ? tmp_max_value : max_value;
        min_value = tmp_min_value < min_value ? tmp_min_value : min_value;
        mean_value += tmp_mean_value;
      }
      if (check_nan_inf_level == 0) {
        nan_inf_zero[0] = num_nan;
        nan_inf_zero[1] = num_inf;
        nan_inf_zero[2] = num_zero;
      }
    }
    PrintForDifferentLevel<T, MT>(debug_info,
                                  numel,
                                  num_nan,
                                  num_inf,
                                  num_zero,
                                  max_value,
                                  min_value,
                                  mean_value,
                                  check_nan_inf_level);
  }
}
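
// Note: FindGlobalMaxMinAndPrint is launched with a single thread
// (<<<1, 1>>> in the host code below); that is cheap because the host caps
// the grid of FindNanInfAndBlockMaxMin at 128 blocks, so at most 128
// partial results are folded here.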

template <typename T>
inline std::string GetHintString(const std::string& op_type,
                                 const std::string& var_name,
                                 const phi::Place& place,
                                 int dev_id = -1) {
  std::string op_var = GetCpuHintString<T>(op_type, var_name, place, dev_id);
  PADDLE_ENFORCE_EQ(
      (dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()),
      true,
      platform::errors::OutOfRange(
          "GPU dev_id must be >= 0 and < dev_count=%d",
          multi_op_var2gpu_str_mutex().size()));
  return op_var;
}

template <typename T>
static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
                                 const std::string& op_type,
                                 const std::string& var_name,
                                 int dev_id) {
  std::string op_var =
      GetHintString<T>(op_type, var_name, ctx.GetPlace(), dev_id);
  char* gpu_str_ptr = nullptr;

  {
    auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id);
    auto& op_var2gpu_str = multi_op_var2gpu_str().at(dev_id);

    std::lock_guard<std::mutex> guard(op_var2gpu_str_mutex);
    if (op_var2gpu_str.find(op_var) == op_var2gpu_str.end()) {  // insert
      auto gpu_str_tensor = paddle::memory::Alloc(
          ctx.GetPlace(),
          op_var.length() + 1,
          phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
      gpu_str_ptr = reinterpret_cast<char*>(gpu_str_tensor->ptr());

      op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor));

      auto iter = op_var2gpu_str.find(op_var);
      PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(),
                        true,
                        platform::errors::PreconditionNotMet(
                            "op_var=%s should have been inserted into "
                            "op_var2gpu_str, but the insertion failed",
                            op_var));

#ifdef __HIPCC__
      PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(gpu_str_ptr,
                                                iter->first.c_str(),
                                                op_var.length() + 1,
                                                hipMemcpyHostToDevice,
                                                ctx.stream()));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(gpu_str_ptr,
                                                 iter->first.c_str(),
                                                 op_var.length() + 1,
                                                 cudaMemcpyHostToDevice,
                                                 ctx.stream()));
#endif
    } else {  // get
      auto iter = op_var2gpu_str.find(op_var);
      PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(),
                        true,
                        platform::errors::PreconditionNotMet(
                            "op_var=%s should be in op_var2gpu_str, but it "
                            "was not found",
                            op_var));
      gpu_str_ptr = reinterpret_cast<char*>(iter->second->ptr());
    }
  }
  return gpu_str_ptr;
}
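
// Note: the device-side copy of each hint string is allocated and filled
// once per (op_type, var_name, device) and cached in multi_op_var2gpu_str(),
// so repeated checks of the same op/var reuse the same device buffer.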

template <>
template <typename T>
void TensorCheckerVisitor<phi::GPUContext>::apply(
    typename std::enable_if<
        std::is_floating_point<T>::value ||
        std::is_same<T, ::paddle::platform::complex<float>>::value ||
        std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
    const {
  auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
      platform::DeviceContextPool::Instance().Get(tensor.place()));
  int dev_id = tensor.place().device;
  // Write log to file
  auto file_path = GetNanPath();
  if (file_path.size() > 0) {
    phi::DenseTensor cpu_tensor;
    platform::CPUPlace cpu_place;
    cpu_tensor.Resize(tensor.dims());
    // 1. copy from gpu to cpu
    paddle::framework::TensorCopySync(tensor, cpu_place, &cpu_tensor);
    const std::string debug_info =
        GetHintString<T>(op_type, var_name, place, dev_id);
    // 2. write log to file
    CheckNanInfCpuImpl(cpu_tensor.data<T>(), tensor.numel(), debug_info, "gpu");
    return;
  }

  // Write log to the console
  char* gpu_str_ptr =
      GetGpuHintStringPtr<T>(*dev_ctx, op_type, var_name, dev_id);

#ifdef __HIPCC__
  // HIP will throw GPU memory access fault if threads > 256
  const size_t threads = 256;
#else
  const size_t threads = 1024;
#endif
  // Cap the grid at 128 blocks so the single-thread reduction in
  // FindGlobalMaxMinAndPrint folds at most 128 partial results.
  size_t blocks =
      std::min(static_cast<size_t>(128),
               static_cast<size_t>((tensor.numel() + threads - 1) / threads));
#ifdef __HIPCC__
  int print_num = 3;

  hipLaunchKernelGGL(CheckNanInfKernel,
                     dim3(blocks),
                     dim3(threads),
                     0,
                     dev_ctx->stream(),
                     tensor.data<T>(),
                     tensor.numel(),
                     print_num,
                     gpu_str_ptr);
#else
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  int64_t numel_max_min = blocks;

  // One buffer holds the per-block nan/inf/zero counts as three contiguous
  // sub-arrays of numel_max_min elements each.
  phi::DenseTensor block_num_nan_inf_zero;
  block_num_nan_inf_zero.Resize({static_cast<int64_t>(3 * numel_max_min)});
  int64_t* block_num_nan_ptr =
      dev_ctx->template Alloc<int64_t>(&block_num_nan_inf_zero);
  int64_t* block_num_inf_ptr = block_num_nan_ptr + numel_max_min;
  int64_t* block_num_zero_ptr = block_num_inf_ptr + numel_max_min;

  // Likewise, one buffer holds the per-block max/min/mean values.
  phi::DenseTensor tensor_block_max_min;
  tensor_block_max_min.Resize({static_cast<int64_t>(3 * numel_max_min)});
  MT* tensor_block_max_ptr = dev_ctx->template Alloc<MT>(&tensor_block_max_min);
  MT* tensor_block_min_ptr = tensor_block_max_ptr + numel_max_min;
  MT* tensor_block_mean_ptr = tensor_block_max_ptr + 2 * numel_max_min;

  FindNanInfAndBlockMaxMin<T, MT>
      <<<blocks, threads, 0, dev_ctx->stream()>>>(tensor.data<T>(),
                                                  tensor.numel(),
                                                  block_num_nan_ptr,
                                                  block_num_inf_ptr,
                                                  block_num_zero_ptr,
                                                  tensor_block_max_ptr,
                                                  tensor_block_min_ptr,
                                                  tensor_block_mean_ptr);

  int check_nan_inf_level = FLAGS_check_nan_inf_level;
  phi::DenseTensor nan_inf_zero_tensor;
  nan_inf_zero_tensor.Resize({static_cast<int64_t>(3)});
  int64_t* nan_inf_zero =
      dev_ctx->template Alloc<int64_t>(&nan_inf_zero_tensor);
  FindGlobalMaxMinAndPrint<T, MT>
      <<<1, 1, 0, dev_ctx->stream()>>>(block_num_nan_ptr,
                                       block_num_inf_ptr,
                                       block_num_zero_ptr,
                                       tensor_block_max_ptr,
                                       tensor_block_min_ptr,
                                       tensor_block_mean_ptr,
                                       gpu_str_ptr,
                                       tensor.numel(),
                                       numel_max_min,
                                       check_nan_inf_level,
                                       nan_inf_zero_tensor.data<int64_t>());

  if (check_nan_inf_level == 0 && GetNanInfStackLimit() > 0) {
    auto nan_cpu =
        phi::memory_utils::Alloc(phi::CPUPlace(), sizeof(int64_t) * 3);
    int64_t* nan_cpu_ptr = reinterpret_cast<int64_t*>(nan_cpu->ptr());
    phi::memory_utils::Copy(phi::CPUPlace(),
                            nan_cpu_ptr,
                            place,
                            nan_inf_zero,
                            3 * sizeof(int64_t),
                            dev_ctx->stream());

    dev_ctx->Wait();
    if (nan_cpu_ptr[0] > 0 || nan_cpu_ptr[1] > 0) {
      const std::string debug_info =
          GetHintString<T>(op_type, var_name, place, dev_id);
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in "
          "%s.",
          static_cast<long long>(nan_cpu_ptr[0]),  // NOLINT
          static_cast<long long>(nan_cpu_ptr[1]),  // NOLINT
          static_cast<long long>(nan_cpu_ptr[2]),  // NOLINT
          debug_info));
    }
  }
#endif
}

template <>
void tensor_check<phi::GPUContext>(const std::string& op_type,
                                   const std::string& var_name,
                                   const phi::DenseTensor& tensor,
                                   const platform::Place& place) {
  std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap);

  TensorCheckerVisitor<phi::GPUContext> visitor(
      op_type, var_name, tensor, place);
  VisitDataType(framework::TransToProtoVarType(tensor.dtype()), visitor);
}

}  // namespace details
}  // namespace framework
}  // namespace paddle