check_numerics_kernel.cu 23.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/check_numerics_kernel.h"

#include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
19
#include "paddle/phi/common/amp_type_traits.h"
20
#include "paddle/phi/common/float16.h"
21
#include "paddle/phi/common/memory_utils.h"
22 23
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/check_numerics_utils.h"
24 25
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"

26
namespace phi {
// Guards the one-time call to InitMultiGPUOpVarMap() (see std::call_once in
// GetGpuHintStringPtr).
static std::once_flag init_multi_gpu_op_var_map_flag;

// Returns the process-wide, per-device cache mapping an "op/var" hint string
// to the device allocation that holds a copy of that string.
//
// The vector is a function-local static, so it is constructed lazily on first
// use and is empty until InitMultiGPUOpVarMap() sizes it to one map per GPU.
// GetGpuHintStringPtr() only touches entry i while holding
// multi_op_var2gpu_str_mutex()[i].
static std::vector<
    std::unordered_map<std::string, phi::Allocator::AllocationPtr>>&
multi_op_var2gpu_str() {
  static std::vector<
      std::unordered_map<std::string, phi::Allocator::AllocationPtr>>
      per_device_hint_cache;
  return per_device_hint_cache;
}

// Returns the per-device mutexes guarding the matching entries of
// multi_op_var2gpu_str(). Lazily constructed and empty until
// InitMultiGPUOpVarMap() sizes it.
static std::vector<std::mutex>& multi_op_var2gpu_str_mutex() {
  static std::vector<std::mutex> per_device_hint_mutexes;
  return per_device_hint_mutexes;
}

// Sizes the per-device hint-string cache and its mutexes to the number of
// visible GPUs. Invoked through std::call_once on
// init_multi_gpu_op_var_map_flag.
//
// Raises NotFound when no GPU is visible.
//
// std::mutex is neither copyable nor movable, so a std::vector<std::mutex>
// cannot be resize()d. Instead, a freshly constructed vector of the right size
// is swapped in. See
// https://stackoverflow.com/questions/16465633/how-can-i-use-something-like-stdvectorstdmutex
static void InitMultiGPUOpVarMap() {
  const int dev_count = phi::backends::gpu::GetGPUDeviceCount();
  PADDLE_ENFORCE_GT(dev_count,
                    0,
                    phi::errors::NotFound(
                        "cuda device must > 0, now dev_count=%d", dev_count));

  // Swap freshly sized temporaries into the function-local statics; the
  // temporaries (holding the previous, empty contents) die at the semicolon.
  std::vector<std::unordered_map<std::string, phi::Allocator::AllocationPtr>>(
      dev_count)
      .swap(multi_op_var2gpu_str());
  std::vector<std::mutex>(dev_count).swap(multi_op_var2gpu_str_mutex());
}

// Device helper called by CheckNanInfKernel once a block has found a NaN/Inf
// in its slice of `value`.
//
// Every thread walks the indices tid, tid + blockDim.x * gridDim.x, ... and
// classifies each element as NaN, Inf or a regular number, counting them in
// block-local __shared__ counters. The first `print_num` elements of each
// category that reach the counters of this block are printed. Afterwards
// thread 0 prints the block's totals and raises an error through
// PADDLE_ENFORCE with `debug_info`.
//
// Contains __syncthreads(), so it must be reached by all threads of the block.
template <typename T>
__device__ __forceinline__ void PrintNanInfKernel(const T* value,
                                                  const size_t numel,
                                                  int print_num,
                                                  char* debug_info) {
  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;

  __shared__ unsigned int nan_count, inf_count, num_count;
  if (threadIdx.x == 0) nan_count = inf_count = num_count = 0;
  // The barrier must actually be *called*: no thread may atomicAdd into the
  // counters before thread 0 has zeroed them.
  __syncthreads();

  for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
    unsigned int count = 0;
    if (isnan(value[i])) {
      count = atomicAdd(&nan_count, 1);
    } else if (isinf(value[i])) {
      count = atomicAdd(&inf_count, 1);
    } else {
      count = atomicAdd(&num_count, 1);
    }
    // For cuda, print in every block. Compare in a signed 64-bit type so a
    // negative print_num prints nothing instead of wrapping to a huge
    // unsigned limit.
    if (static_cast<int64_t>(count) < print_num) {
      printf("numel:%lu idx:%lu value:%f\n",
             static_cast<uint64_t>(numel),
             static_cast<uint64_t>(i),
             static_cast<float>(value[i]));
    }
  }
  // All counter updates of the block must be complete before thread 0 reads
  // the totals below.
  __syncthreads();

#ifdef __HIPCC__
  if (hipThreadIdx_x == 0) {
    printf("In block %d, there has %u,%u,%u nan,inf,num\n",
           hipBlockIdx_x,
           nan_count,
           inf_count,
           num_count);
#else
  if (threadIdx.x == 0) {
    printf("In block %d, there has %u,%u,%u nan,inf,num\n",
           blockIdx.x,
           nan_count,
           inf_count,
           num_count);
#endif
    PADDLE_ENFORCE(false, "===ERROR: in %s find nan or inf===", debug_info);
  }
}

// Resnet 2gpus speed test, no check 270 images/s, this check 229 images/s
//
// NaN/Inf check launched on the HIP path of CheckNumericsKernel. Each block
// first decides whether its strided slice of `value` contains a NaN or Inf.
// Only blocks that found one call PrintNanInfKernel, which prints and raises
// an error. Blocks decide independently, so clean blocks exit silently.
template <typename T>
__global__ void CheckNanInfKernel(const T* value,
                                  const size_t numel,
                                  int print_num,
                                  char* debug_info) {
  // Step 1: judge whether this block has seen a NaN or Inf.
  __shared__ volatile int has_nan_inf;
  if (threadIdx.x == 0) {
    has_nan_inf = false;
  }
  __syncthreads();

  // x - x is 0 for finite x and NaN for NaN or +/-Inf, so a non-finite
  // partial sum means this thread saw at least one NaN/Inf.
  // Todo(wangxi). simd speed up
  const size_t grid_stride = blockDim.x * gridDim.x;
  T partial = static_cast<T>(0.0);
  for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < numel;
       idx += grid_stride) {
    partial += (value[idx] - value[idx]);
  }
  if (isnan(partial) || isinf(partial)) {
    has_nan_inf = true;
  }
  __syncthreads();

  // Step 2: has_nan_inf is read after the barrier, so this branch is uniform
  // across the block and every thread enters the printer together.
  if (has_nan_inf) {
    PrintNanInfKernel(value, numel, print_num, debug_info);
  }
}

// Block-wide reduction of one value per thread through a __shared__ buffer.
// ReduceType selects the operation: 0 = max, 1 = min, 2 = sum. Every thread of
// the block receives the reduced value.
//
// Preconditions (not checked at runtime):
//   * blockDim.x is a power of two and at most 1024 (the buffer size);
//   * all threads of the block reach the call, since it uses __syncthreads().
//
// Each template instantiation <T, ReduceType> has a single __shared__ buffer
// per block. Back-to-back calls on the same instantiation (as in
// BlockReduceNumNanInfAndWrite) therefore reuse it, which is why the result is
// read before a final barrier.
template <typename T, int ReduceType>
__device__ T BlockReduce(T value) {
  static_assert(ReduceType >= 0 && ReduceType <= 2,
                "BlockReduce: ReduceType must be 0 (max), 1 (min) or 2 (sum)");
  __shared__ T shared_mem[1024];

  shared_mem[threadIdx.x] = value;
  __syncthreads();

  for (int stride = blockDim.x >> 1; stride > 0; stride = stride >> 1) {
    if (threadIdx.x < stride) {
      T value0 = shared_mem[threadIdx.x];
      T value1 = shared_mem[threadIdx.x + stride];
      T reduce_value = value0;
      if (ReduceType == 0) {
        // max
        reduce_value = value0 > value1 ? value0 : value1;
      } else if (ReduceType == 1) {
        // min
        reduce_value = value0 < value1 ? value0 : value1;
      } else if (ReduceType == 2) {
        // sum
        reduce_value = value0 + value1;
      }
      shared_mem[threadIdx.x] = reduce_value;
    }
    // Synchronize after every step. Skipping the barrier once stride <= 16
    // relied on implicit warp-synchronous execution over a non-volatile
    // buffer. Independent thread scheduling (Volta+) does not guarantee that,
    // so a lane could read its partner's slot before it was updated.
    // blockDim.x is uniform, so all threads run the same number of iterations.
    __syncthreads();
  }

  const T result = shared_mem[0];
  // Every thread must have read the result before any thread can start another
  // call on this instantiation and overwrite shared_mem[0].
  __syncthreads();
  return result;
}

// Sums the per-thread NaN / Inf / zero counters across the block with
// BlockReduce. Thread 0 then stores the three totals at index `offset` of
// num_nan_ptr / num_inf_ptr / num_zero_ptr.
// Must be called by every thread of the block, because BlockReduce
// synchronizes.
__device__ void BlockReduceNumNanInfAndWrite(const int64_t num_nan,
                                             const int64_t num_inf,
                                             const int64_t num_zero,
                                             int64_t offset,
                                             int64_t* num_nan_ptr,
                                             int64_t* num_inf_ptr,
                                             int64_t* num_zero_ptr) {
  constexpr int kSumReduce = 2;  // ReduceType 2 selects summation.
  const int64_t nan_total = BlockReduce<int64_t, kSumReduce>(num_nan);
  const int64_t inf_total = BlockReduce<int64_t, kSumReduce>(num_inf);
  const int64_t zero_total = BlockReduce<int64_t, kSumReduce>(num_zero);

  if (threadIdx.x != 0) return;
  num_nan_ptr[offset] = nan_total;
  num_inf_ptr[offset] = inf_total;
  num_zero_ptr[offset] = zero_total;
}

// Complex overload of BlockReduceMaxMinAndWrite. Max/min/mean are not reduced
// for complex element types yet, so this overload does nothing and leaves
// max_ptr / min_ptr / mean_ptr untouched.
// NOTE(review): FindGlobalMaxMinAndPrint still reads those arrays whenever the
// pointers are non-null, so for complex tensors it consumes unwritten memory.
// Confirm that the downstream phi::funcs helpers ignore max/min/mean for
// complex types.
template <
    typename T,
    std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
                         std::is_same<T, phi::dtype::complex<double>>::value,
                     bool> = true>
__device__ void BlockReduceMaxMinAndWrite(const T max_value,
                                          const T min_value,
                                          const T mean_value,
                                          int64_t offset,
                                          T* max_ptr,
                                          T* min_ptr,
                                          T* mean_ptr) {
  // TODO(Xreki): support complex
  (void)max_value;
  (void)min_value;
  (void)mean_value;
  (void)offset;
  (void)max_ptr;
  (void)min_ptr;
  (void)mean_ptr;
}

// Real-valued overload of BlockReduceMaxMinAndWrite.
//
// Reduces the per-thread partial max, min and mean contribution across the
// block with the phi::funcs block reductions. Thread 0 stores the results at
// index `offset` of max_ptr / min_ptr / mean_ptr. mean_value holds each
// thread's partial mean, so the block result is their sum.
//
// Does nothing when any output pointer is null. The pointers are the same for
// the whole block, so either every thread returns early or none does. Must be
// called by every thread of the block.
template <
    typename T,
    std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
                         !std::is_same<T, phi::dtype::complex<double>>::value,
                     bool> = true>
__device__ void BlockReduceMaxMinAndWrite(const T max_value,
                                          const T min_value,
                                          const T mean_value,
                                          int64_t offset,
                                          T* max_ptr,
                                          T* min_ptr,
                                          T* mean_ptr) {
  const bool has_outputs = max_ptr && min_ptr && mean_ptr;
  if (!has_outputs) return;

  // NOTE(review): presumably separates the phi::funcs reductions' shared
  // scratch space from earlier shared-memory use in this block; confirm.
  __syncthreads();

  const T block_max = phi::funcs::BlockReduceMax<T>(max_value, FINAL_MASK);
  const T block_min = phi::funcs::BlockReduceMin<T>(min_value, FINAL_MASK);
  const T block_sum = phi::funcs::BlockReduceSum<T>(mean_value, FINAL_MASK);

  if (threadIdx.x == 0) {
    max_ptr[offset] = block_max;
    min_ptr[offset] = block_min;
    mean_ptr[offset] = block_sum;
  }
}

// First pass of the CUDA check, launched by CheckNumericsKernel.
//
// Each thread visits indices first, first + stride, ... with
// stride = blockDim.x * gridDim.x, and accumulates:
//   * NaN and Inf counts, plus a separate count of values equal to 0;
//   * a running max and min, seeded with its first element (threads past the
//     end use value_ptr[0], so the seed is always a real element);
//   * its share of the mean, sum(value / numel).
// The per-thread results are reduced across the block, and thread 0 stores
// them at index blockIdx.x of the output arrays.
//
// Requires numel > 0, because value_ptr[0] is read unconditionally.
// Comparisons with NaN are false, so a NaN met in the loop never replaces the
// running max/min, but a NaN seed is never replaced either.
template <typename T, typename MT>
__global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
                                         const int64_t numel,
                                         int64_t* block_num_nan_ptr,
                                         int64_t* block_num_inf_ptr,
                                         int64_t* block_num_zero_ptr,
                                         MT* tensor_block_max_ptr,
                                         MT* tensor_block_min_ptr,
                                         MT* tensor_block_mean_ptr) {
  const int64_t first = threadIdx.x + blockIdx.x * blockDim.x;
  const int64_t stride = blockDim.x * gridDim.x;

  int64_t num_nan = 0;
  int64_t num_inf = 0;
  int64_t num_zero = 0;

  const MT seed =
      static_cast<MT>(first < numel ? value_ptr[first] : value_ptr[0]);
  MT max_value = seed;
  MT min_value = seed;
  MT mean_value = static_cast<MT>(0);
  const MT numel_mt = static_cast<MT>(numel);

  for (int64_t idx = first; idx < numel; idx += stride) {
    const MT value = static_cast<MT>(value_ptr[idx]);

    if (value > max_value) max_value = value;
    if (value < min_value) min_value = value;
    mean_value += value / numel_mt;

    if (isnan(value)) {
      ++num_nan;
    } else if (isinf(value)) {
      ++num_inf;
    }
    if (value == static_cast<MT>(0)) {
      ++num_zero;
    }
  }

  BlockReduceNumNanInfAndWrite(num_nan,
                               num_inf,
                               num_zero,
                               blockIdx.x,
                               block_num_nan_ptr,
                               block_num_inf_ptr,
                               block_num_zero_ptr);

  BlockReduceMaxMinAndWrite<MT>(max_value,
                                min_value,
                                mean_value,
                                blockIdx.x,
                                tensor_block_max_ptr,
                                tensor_block_min_ptr,
                                tensor_block_mean_ptr);
}

// Second pass of the CUDA check. Folds the numel_max_min per-block partials
// written by FindNanInfAndBlockMaxMin into tensor-wide results.
//
// Only thread 0 of block 0 does any work; CheckNumericsKernel launches it as
// <<<1, 1>>>. The NaN/Inf/zero counts are always summed. When all three
// per-block max/min/mean pointers are non-null:
//   * the global max, min and mean are folded;
//   * the results are stored through phi::funcs::SaveStatsAndValues into
//     stats_ptr / values_ptr.
// In every case the results go to phi::funcs::PrintForDifferentLevel. When the
// pointers are null, max, min and mean stay 0.
template <typename T, typename MT>
__global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr,
                                         const int64_t* block_num_inf_ptr,
                                         const int64_t* block_num_zero_ptr,
                                         const MT* tensor_block_max_ptr,
                                         const MT* tensor_block_min_ptr,
                                         const MT* tensor_block_mean_ptr,
                                         const char* debug_info,
                                         int64_t numel,
                                         int64_t numel_max_min,
                                         int check_nan_inf_level,
                                         int64_t* stats_ptr,
                                         float* values_ptr) {
  if (blockIdx.x != 0 || threadIdx.x != 0) return;

  // numel_max_min <= 128, so a serial fold is cheap.
  int64_t num_nan = 0;
  int64_t num_inf = 0;
  int64_t num_zero = 0;
  for (int64_t b = 0; b < numel_max_min; ++b) {
    num_nan += block_num_nan_ptr[b];
    num_inf += block_num_inf_ptr[b];
    num_zero += block_num_zero_ptr[b];
  }

  MT max_value = static_cast<MT>(0);
  MT min_value = static_cast<MT>(0);
  MT mean_value = static_cast<MT>(0);

  const bool has_block_values =
      tensor_block_max_ptr && tensor_block_min_ptr && tensor_block_mean_ptr;
  if (has_block_values) {
    max_value = tensor_block_max_ptr[0];
    min_value = tensor_block_min_ptr[0];
    mean_value = tensor_block_mean_ptr[0];

    for (int64_t b = 1; b < numel_max_min; ++b) {
      const MT block_max = tensor_block_max_ptr[b];
      const MT block_min = tensor_block_min_ptr[b];
      if (block_max > max_value) max_value = block_max;
      if (block_min < min_value) min_value = block_min;
      mean_value += tensor_block_mean_ptr[b];
    }

    phi::funcs::SaveStatsAndValues<MT>(num_nan,
                                       num_inf,
                                       num_zero,
                                       max_value,
                                       min_value,
                                       mean_value,
                                       stats_ptr,
                                       values_ptr);
  }

  phi::funcs::PrintForDifferentLevel<T, MT>(debug_info,
                                            numel,
                                            num_nan,
                                            num_inf,
                                            num_zero,
                                            max_value,
                                            min_value,
                                            mean_value,
                                            check_nan_inf_level);
}

// Builds the "op/var" hint string for a tensor via
// phi::funcs::GetCpuHintString. It also checks that dev_id indexes the
// per-device tables created by InitMultiGPUOpVarMap().
//
// Raises OutOfRange when dev_id < 0 or dev_id >= the number of per-device
// tables. The check therefore also fails in two other cases:
//   * with the default dev_id = -1;
//   * before InitMultiGPUOpVarMap() has run, since the tables start empty.
template <typename T>
inline std::string GetHintString(const std::string& op_type,
                                 const std::string& var_name,
                                 const phi::Place& place,
                                 int dev_id = -1) {
  std::string op_var =
      phi::funcs::GetCpuHintString<T>(op_type, var_name, place, dev_id);
  // Compare in a signed type: dev_id is an int while size() is unsigned.
  const int64_t dev_count =
      static_cast<int64_t>(multi_op_var2gpu_str_mutex().size());
  PADDLE_ENFORCE_EQ(
      (dev_id >= 0 && dev_id < dev_count),
      true,
      phi::errors::OutOfRange(
          "GPU dev_id must >=0 and < dev_count=%d, but received dev_id=%d",
          dev_count,
          dev_id));
  return op_var;
}
// Returns a device pointer to a NUL-terminated copy of the hint string for
// (op_type, var_name), so kernels can print it.
//
// Copies are cached per device in multi_op_var2gpu_str()[dev_id]. Nothing in
// this file ever erases them. The first request for a string:
//   * allocates length + 1 bytes on ctx.GetPlace();
//   * enqueues an asynchronous host-to-device copy on ctx.stream().
// Later requests return the cached pointer. Access to the cache of a device is
// serialized by its mutex in multi_op_var2gpu_str_mutex().
template <typename T>
static char* GetGpuHintStringPtr(const phi::GPUContext& ctx,
                                 const std::string& op_type,
                                 const std::string& var_name,
                                 int dev_id) {
  std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap);

  // GetHintString also validates that dev_id indexes the per-device tables.
  std::string op_var =
      GetHintString<T>(op_type, var_name, ctx.GetPlace(), dev_id);

  auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id);
  auto& op_var2gpu_str = multi_op_var2gpu_str().at(dev_id);
  std::lock_guard<std::mutex> guard(op_var2gpu_str_mutex);

  // Cache hit: one lookup is enough.
  auto iter = op_var2gpu_str.find(op_var);
  if (iter != op_var2gpu_str.end()) {
    return reinterpret_cast<char*>(iter->second->ptr());
  }

  // Cache miss: allocate room for the string plus its terminating NUL.
  auto gpu_str_tensor = phi::memory_utils::Alloc(
      ctx.GetPlace(),
      op_var.length() + 1,
      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
  char* gpu_str_ptr = reinterpret_cast<char*>(gpu_str_tensor->ptr());

  // Reuse the iterator returned by emplace instead of searching again.
  auto emplace_result =
      op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor));
  PADDLE_ENFORCE_EQ(emplace_result.second,
                    true,
                    phi::errors::PreconditionNotMet(
                        "op_var=%s should successed insert into "
                        "op_var2gpu_str, but now failed",
                        op_var));
  iter = emplace_result.first;

  // Copy from the map's own key rather than the local op_var. The map node is
  // never erased here, so the host source outlives the asynchronous copy.
#ifdef __HIPCC__
  PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(gpu_str_ptr,
                                            iter->first.c_str(),
                                            op_var.length() + 1,
                                            hipMemcpyHostToDevice,
                                            ctx.stream()));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(gpu_str_ptr,
                                             iter->first.c_str(),
                                             op_var.length() + 1,
                                             cudaMemcpyHostToDevice,
                                             ctx.stream()));
#endif
  return gpu_str_ptr;
}

// Copies the three counters in `stats` ({num_nan, num_inf, num_zero}) to the
// host on ctx.stream() and waits for them. If any NaN or Inf was found, it
// calls phi::funcs::PrintAndThrowError with the tensor's hint string.
template <typename T>
static void PrintStack(const phi::GPUContext& ctx,
                       const DenseTensor& stats,
                       const std::string& op_type,
                       const std::string& var_name,
                       int dev_id) {
  constexpr size_t kNumStats = 3;
  auto host_buffer =
      phi::memory_utils::Alloc(phi::CPUPlace(), sizeof(int64_t) * kNumStats);
  int64_t* host_stats = reinterpret_cast<int64_t*>(host_buffer->ptr());
  phi::memory_utils::Copy(phi::CPUPlace(),
                          host_stats,
                          stats.place(),
                          stats.data(),
                          kNumStats * sizeof(int64_t),
                          ctx.stream());
  ctx.Wait();

  const int64_t num_nan = host_stats[0];
  const int64_t num_inf = host_stats[1];
  const int64_t num_zero = host_stats[2];
  if (num_nan <= 0 && num_inf <= 0) return;

  const std::string debug_info =
      GetHintString<T>(op_type, var_name, stats.place(), dev_id);
  phi::funcs::PrintAndThrowError(
      debug_info.c_str(), num_nan, num_inf, num_zero);
}

// Copies `stats` ({num_nan, num_inf, num_zero}) and `values`
// ({max, min, mean}) to the host. It then passes them, together with the
// tensor's hint string and the log name "gpu.<dev_id>", to
// phi::funcs::WriteToFileForDifferentLevel, presumably to write a log under
// `output_dir`.
template <typename T, typename MT>
static void WriteToOutputDir(const phi::GPUContext& ctx,
                             const DenseTensor& tensor,
                             const DenseTensor& stats,
                             const DenseTensor& values,
                             const std::string& op_type,
                             const std::string& var_name,
                             const std::string& output_dir,
                             const int check_nan_inf_level) {
  phi::DenseTensor host_stats;
  host_stats.Resize({static_cast<int64_t>(3)});
  phi::Copy(ctx, stats, phi::CPUPlace(), false, &host_stats);

  phi::DenseTensor host_values;
  host_values.Resize({static_cast<int64_t>(3)});
  phi::Copy(ctx, values, phi::CPUPlace(), false, &host_values);

  // Both copies are issued with blocking = false; wait before reading them.
  ctx.Wait();

  const int device_id = tensor.place().device;
  const std::string debug_info =
      GetHintString<T>(op_type, var_name, tensor.place(), device_id);
  std::string log_name = "gpu." + std::to_string(device_id);

  const int64_t* host_stats_ptr = host_stats.data<int64_t>();
  const float* host_values_ptr = host_values.data<float>();
  const int64_t num_nan = host_stats_ptr[0];
  const int64_t num_inf = host_stats_ptr[1];
  const int64_t num_zero = host_stats_ptr[2];
  const float max_value = host_values_ptr[0];
  const float min_value = host_values_ptr[1];
  const float mean_value = host_values_ptr[2];

  phi::funcs::WriteToFileForDifferentLevel<T, MT>(debug_info.c_str(),
                                                  tensor.numel(),
                                                  num_nan,
                                                  num_inf,
                                                  num_zero,
                                                  max_value,
                                                  min_value,
                                                  mean_value,
                                                  check_nan_inf_level,
                                                  log_name,
                                                  output_dir);
}

// GPU implementation of check_numerics.
//
// CUDA path, a two-pass reduction:
//   1. FindNanInfAndBlockMaxMin writes per-block NaN/Inf/zero counts and
//      max/min/mean partials into temporary buffers.
//   2. A single-thread FindGlobalMaxMinAndPrint folds them into `stats`
//      ({num_nan, num_inf, num_zero}) and `values` ({max, min, mean}), and
//      prints according to check_nan_inf_level.
// If output_dir is non-empty, the results are also handed to
// WriteToOutputDir(). With check_nan_inf_level == 0 and
// stack_height_limit > 0, PrintStack() rechecks the counts on the host.
//
// HIP path: only CheckNanInfKernel runs, and `stats` / `values` are not
// written.
//
// Tensors with numel <= 0 return immediately without touching `stats` /
// `values`.
template <typename T, typename Context>
void CheckNumericsKernel(const Context& ctx,
                         const DenseTensor& tensor,
                         const std::string& op_type,
                         const std::string& var_name,
                         const int check_nan_inf_level,
                         const int stack_height_limit,
                         const std::string& output_dir,
                         DenseTensor* stats,
                         DenseTensor* values) {
  const int dev_id = tensor.place().device;
  const int64_t numel = tensor.numel();
  VLOG(6) << "op_type=" << op_type << ", var_name=" << var_name
          << ", dev_id=gpu:" << dev_id << ", numel=" << numel
          << ", stack_height_limit=" << stack_height_limit
          << ", output_dir=" << output_dir;

  if (numel <= 0) return;

  // Device-resident copy of the "op/var" hint string, printed by the kernels.
  char* gpu_str_ptr = GetGpuHintStringPtr<T>(ctx, op_type, var_name, dev_id);

#ifdef __HIPCC__
  // HIP will throw GPU memory access fault if threads > 256
  const size_t threads = 256;
#else
  const size_t threads = 1024;
#endif
  // At most 128 blocks; the kernels cover the rest with strided loops.
  const size_t blocks =
      std::min(static_cast<size_t>(128),
               static_cast<size_t>((numel + threads - 1) / threads));

#ifdef __HIPCC__
  int print_num = 3;

  hipLaunchKernelGGL(CheckNanInfKernel,
                     dim3(blocks),
                     dim3(threads),
                     0,
                     ctx.stream(),
                     tensor.data<T>(),
                     tensor.numel(),
                     print_num,
                     gpu_str_ptr);
#else
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  const int64_t numel_max_min = blocks;

  // Per-block NaN / Inf / zero counters: three consecutive arrays of
  // numel_max_min entries inside one int64 buffer.
  phi::DenseTensor block_counts;
  block_counts.Resize({static_cast<int64_t>(3 * numel_max_min)});
  int64_t* nan_per_block = ctx.template Alloc<int64_t>(&block_counts);
  int64_t* inf_per_block = nan_per_block + numel_max_min;
  int64_t* zero_per_block = nan_per_block + 2 * numel_max_min;

  // Per-block max / min / mean partials, same three-array layout in MT.
  phi::DenseTensor block_extrema;
  block_extrema.Resize({static_cast<int64_t>(3 * numel_max_min)});
  MT* max_per_block = ctx.template Alloc<MT>(&block_extrema);
  MT* min_per_block = max_per_block + numel_max_min;
  MT* mean_per_block = max_per_block + 2 * numel_max_min;

  // Pass 1: one partial result per block.
  FindNanInfAndBlockMaxMin<T, MT>
      <<<blocks, threads, 0, ctx.stream()>>>(tensor.data<T>(),
                                             numel,
                                             nan_per_block,
                                             inf_per_block,
                                             zero_per_block,
                                             max_per_block,
                                             min_per_block,
                                             mean_per_block);

  // stats stores the checking result of num_nan, num_inf and num_zero.
  stats->Resize({static_cast<int64_t>(3)});
  int64_t* stats_ptr = ctx.template Alloc<int64_t>(stats);

  // values stores the max_value, min_value and mean_value.
  values->Resize({static_cast<int64_t>(3)});
  float* values_ptr = ctx.template Alloc<float>(values);

  // Pass 2: a single thread folds the per-block partials and prints.
  FindGlobalMaxMinAndPrint<T, MT>
      <<<1, 1, 0, ctx.stream()>>>(nan_per_block,
                                  inf_per_block,
                                  zero_per_block,
                                  max_per_block,
                                  min_per_block,
                                  mean_per_block,
                                  gpu_str_ptr,
                                  numel,
                                  numel_max_min,
                                  check_nan_inf_level,
                                  stats_ptr,
                                  values_ptr);

  if (!output_dir.empty()) {
    // Write log to output_dir.
    WriteToOutputDir<T, MT>(ctx,
                            tensor,
                            *stats,
                            *values,
                            op_type,
                            var_name,
                            output_dir,
                            check_nan_inf_level);
  }

  if (check_nan_inf_level == 0 && stack_height_limit > 0) {
    PrintStack<T>(ctx, *stats, op_type, var_name, dev_id);
  }
#endif
}

597 598 599 600 601 602 603 604 605 606 607 608
}  // namespace phi

// Registers phi::CheckNumericsKernel as the GPU, ALL_LAYOUT implementation of
// check_numerics for each element type listed below.
PD_REGISTER_KERNEL(check_numerics,
                   GPU,
                   ALL_LAYOUT,
                   phi::CheckNumericsKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}