nan_inf_utils_detail.cu 17.3 KB
Newer Older
W
WangXi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"

W
WangXi 已提交
18 19 20 21
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>
22

23
#include "paddle/fluid/framework/convert_utils.h"
24
#include "paddle/fluid/framework/scope.h"
25 26 27
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"

28
DECLARE_int32(check_nan_inf_level);
W
WangXi 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49

namespace paddle {
namespace framework {
namespace details {

// Guards the one-time sizing of the per-device tables below
// (performed by InitMultiGPUOpVarMap, triggered from tensor_check).
static std::once_flag init_multi_gpu_op_var_map_flag;

// Per-device cache, indexed by GPU id, mapping an
// "[op=...] [tensor=...] [dtype=...]" label to a device allocation that
// holds a copy of that label. Constructed empty on first use
// (function-local static) and sized later by InitMultiGPUOpVarMap.
static auto multi_op_var2gpu_str()
    -> std::vector<std::unordered_map<std::string, memory::AllocationPtr>>& {
  static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>
      instance;
  return instance;
}

// One mutex per device; entry i guards entry i of multi_op_var2gpu_str().
static auto multi_op_var2gpu_str_mutex() -> std::vector<std::mutex>& {
  static std::vector<std::mutex> instance;
  return instance;
}

// Sizes the per-device label cache and its mutexes to the number of visible
// GPUs. Invoked exactly once via std::call_once from tensor_check.
// Raises platform::errors::NotFound when no GPU is visible.
static void InitMultiGPUOpVarMap() {
  const int gpu_count = platform::GetGPUDeviceCount();
  PADDLE_ENFORCE_GT(gpu_count,
                    0,
                    platform::errors::NotFound(
                        "cuda device must > 0, now dev_count=%d", gpu_count));

  // std::mutex is neither copyable nor movable, so the vector cannot be
  // resize()d in place. Build correctly sized temporaries and swap them in.
  // https://stackoverflow.com/questions/16465633/how-can-i-use-something-like-stdvectorstdmutex
  std::vector<std::unordered_map<std::string, memory::AllocationPtr>>(
      gpu_count)
      .swap(multi_op_var2gpu_str());
  std::vector<std::mutex>(gpu_count).swap(multi_op_var2gpu_str_mutex());
}

// Device helper executed by every thread of a block that has already found a
// nan/inf (see CheckNanInfKernel). Each thread walks its grid-stride slice of
// `value`. It classifies each element as nan, inf or an ordinary number using
// per-block shared counters, and prints the first `print_num` elements of
// each category seen by this block (values printed as float). Thread 0 then
// prints the block's counts and fails through the device-side PADDLE_ENFORCE.
//
// Contains __syncthreads(), so it must be reached by all threads of the block.
template <typename T>
__device__ __forceinline__ void PrintNanInfKernel(const T* value,
                                                  const size_t numel,
                                                  int print_num,
                                                  char* debug_info) {
  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;

  __shared__ unsigned int nan_count, inf_count, num_count;
  if (threadIdx.x == 0) nan_count = inf_count = num_count = 0;
  // Must be a call: the original bare `__syncthreads;` was a no-op expression,
  // letting other threads atomicAdd before the counters were zeroed.
  __syncthreads();

  for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
    unsigned int count = 0;
    if (isnan(value[i])) {
      count = atomicAdd(&nan_count, 1);
    } else if (isinf(value[i])) {
      count = atomicAdd(&inf_count, 1);
    } else {
      count = atomicAdd(&num_count, 1);
    }
    // For cuda, print in every block. The explicit cast matches the implicit
    // unsigned comparison the original performed.
    if (count < static_cast<unsigned int>(print_num)) {
      // uint64_t is not `unsigned long` on every platform (e.g. Windows),
      // so print through unsigned long long with %llu.
      printf("numel:%llu idx:%llu value:%f\n",
             static_cast<unsigned long long>(numel),
             static_cast<unsigned long long>(i),
             static_cast<float>(value[i]));
    }
  }
  // All counter updates must be visible before thread 0 reads them.
  __syncthreads();

  // threadIdx/blockIdx are available under both CUDA and HIP (this function
  // already relies on threadIdx.x above), so one branch serves both.
  if (threadIdx.x == 0) {
    printf("In block %u, there has %u,%u,%u nan,inf,num\n",
           static_cast<unsigned int>(blockIdx.x),
           nan_count,
           inf_count,
           num_count);
    PADDLE_ENFORCE(false, "===ERROR: in %s find nan or inf===", debug_info);
  }
}

// Resnet 2gpus speed test, no check 270 images/s, this check 229 images/s
//
// Used on the HIP path. Phase 1: every thread accumulates (x - x) over its
// grid-stride slice. For finite x that term is exactly 0, while nan or inf
// yields nan, so a non-finite sum means the slice holds a nan or inf. Phase 2:
// blocks whose shared flag was raised call PrintNanInfKernel to report and
// fail. Because the flag lives in shared memory, every thread of a block takes
// the same path. Different blocks may behave differently.
template <typename T>
__global__ void CheckNanInfKernel(const T* value,
                                  const size_t numel,
                                  int print_num,
                                  char* debug_info) {
  __shared__ volatile int has_nan_inf;
  if (threadIdx.x == 0) has_nan_inf = false;
  __syncthreads();

  const size_t stride = blockDim.x * gridDim.x;
  T acc = static_cast<T>(0.0);
  // TODO(wangxi): SIMD speed up.
  for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < numel;
       idx += stride) {
    acc += (value[idx] - value[idx]);
  }

  if (isnan(acc) || isinf(acc)) has_nan_inf = true;
  __syncthreads();

  if (has_nan_inf) {
    PrintNanInfKernel(value, numel, print_num, debug_info);
  }
}

141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234
// Complex overload: block-level max/min/mean statistics are not implemented
// for complex types yet, so this writes nothing to max_ptr/min_ptr/mean_ptr.
// NOTE(review): callers that read those buffers afterwards will see whatever
// they held before. Confirm that this is acceptable for complex tensors.
template <
    typename T,
    std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
                         std::is_same<T, phi::dtype::complex<double>>::value,
                     bool> = true>
__device__ void BlockReduceMaxMinAndWrite(const T max_value,
                                          const T min_value,
                                          const T mean_value,
                                          int64_t offset,
                                          T* max_ptr,
                                          T* min_ptr,
                                          T* mean_ptr) {
  // TODO(Xreki): support complex
  (void)max_value;
  (void)min_value;
  (void)mean_value;
  (void)offset;
  (void)max_ptr;
  (void)min_ptr;
  (void)mean_ptr;
}

// Real-type overload: reduces each thread's partial max, min and mean across
// the block (max/min/sum via phi::funcs block reductions). Thread 0 then
// stores the results at index `offset` of the three output buffers.
// The null check depends only on the pointer arguments. In the visible caller
// (FindNanInfAndBlockMaxMin) these are kernel arguments, so every thread of
// the block either reaches the barriers below or skips them together.
template <
    typename T,
    std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
                         !std::is_same<T, phi::dtype::complex<double>>::value,
                     bool> = true>
__device__ void BlockReduceMaxMinAndWrite(const T max_value,
                                          const T min_value,
                                          const T mean_value,
                                          int64_t offset,
                                          T* max_ptr,
                                          T* min_ptr,
                                          T* mean_ptr) {
  if (!(max_ptr && min_ptr && mean_ptr)) {
    return;
  }
  __syncthreads();

  const T reduced_max = phi::funcs::blockReduceMax<T>(max_value, FINAL_MASK);
  const T reduced_min = phi::funcs::blockReduceMin<T>(min_value, FINAL_MASK);
  const T reduced_sum = phi::funcs::blockReduceSum<T>(mean_value, FINAL_MASK);

  if (threadIdx.x != 0) {
    return;
  }
  max_ptr[offset] = reduced_max;
  min_ptr[offset] = reduced_min;
  mean_ptr[offset] = reduced_sum;
}

// Per-block pass over `value_ptr` using a grid-stride loop.
// Each thread:
//   * sets found_nan_inf_ptr[0] / [1] to 1 if it saw a nan / inf (it only
//     ever writes 1, so the caller is expected to zero these beforehand);
//   * accumulates its local max, min and sum(value / numel) in MT precision.
// BlockReduceMaxMinAndWrite then combines the thread partials and stores the
// block result at index blockIdx.x of the three output buffers.
// When all three output buffers are null, a thread stops scanning as soon as
// it has seen a nan or inf.
// NOTE(review): value_ptr[0] is read unconditionally, so numel must be > 0.
template <typename T, typename MT>
__global__ void FindNanInfAndBlockMaxMin(const T* value_ptr,
                                         const int64_t numel,
                                         int* found_nan_inf_ptr,
                                         MT* tensor_block_max_ptr,
                                         MT* tensor_block_min_ptr,
                                         MT* tensor_block_mean_ptr) {
  const bool can_stop_early = !(tensor_block_max_ptr || tensor_block_min_ptr ||
                                tensor_block_mean_ptr);
  const int64_t stride = blockDim.x * gridDim.x;
  int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;

  // Threads past the end seed their extremes with element 0 so they never
  // contribute a value that is not in the tensor.
  const MT seed = static_cast<MT>(idx < numel ? value_ptr[idx] : value_ptr[0]);
  MT local_max = seed;
  MT local_min = seed;
  MT local_mean = static_cast<MT>(0);
  bool saw_nan = false;
  bool saw_inf = false;

  for (; idx < numel; idx += stride) {
    const MT v = static_cast<MT>(value_ptr[idx]);

    local_max = v > local_max ? v : local_max;
    local_min = v < local_min ? v : local_min;
    local_mean += v / static_cast<MT>(numel);

    saw_nan = saw_nan || isnan(v);
    saw_inf = saw_inf || isinf(v);

    if (can_stop_early && (saw_nan || saw_inf)) {
      break;
    }
  }

  if (saw_nan) found_nan_inf_ptr[0] = 1;
  if (saw_inf) found_nan_inf_ptr[1] = 1;

  BlockReduceMaxMinAndWrite<MT>(local_max,
                                local_min,
                                local_mean,
                                blockIdx.x,
                                tensor_block_max_ptr,
                                tensor_block_min_ptr,
                                tensor_block_mean_ptr);
}

235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
// Statistics-printing policy for float tensors that contain no nan/inf:
//   level >= 3 : always print;
//   level == 2 : print only if max/min lie outside
//                [-numeric_limits<float16>::max(), numeric_limits<float16>::max()];
//   otherwise  : never print.
template <typename T,
          typename MT,
          std::enable_if_t<std::is_same<T, float>::value, bool> = true>
__device__ bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
  if (check_nan_inf_level < 2) {
    return false;
  }
  if (check_nan_inf_level >= 3) {
    return true;
  }
  const MT fp16_limit =
      static_cast<MT>(std::numeric_limits<phi::dtype::float16>::max());
  return min_value < -fp16_limit || max_value > fp16_limit;
}

// Statistics-printing policy for non-float tensors that contain no nan/inf:
// only level >= 3 prints. The value range is not consulted.
template <typename T,
          typename MT,
          std::enable_if_t<!std::is_same<T, float>::value, bool> = true>
__device__ bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) {
  return check_nan_inf_level >= 3;
}

// Reduction/report kernel; only thread 0 of block 0 does any work (the
// visible caller launches it <<<1, 1>>>). It folds the numel_max_min
// per-block partials written by FindNanInfAndBlockMaxMin into tensor-wide
// max/min/mean. Mean partials are already divided by numel, so they are
// summed. If any block buffer is null, max/min/mean are reported as 0.
//
// found_nan_inf_ptr[0] / [1]: non-zero if a nan / inf was found.
// check_nan_inf_level:
//   nan/inf found, level == 0 -> device-side PADDLE_ENFORCE failure;
//   nan/inf found, level >= 1 -> print an error line only;
//   no nan/inf                -> print statistics if NeedPrint<T, MT> says so.
template <typename T, typename MT>
__global__ void FindGlobalMaxMinAndPrint(const int* found_nan_inf_ptr,
                                         const MT* tensor_block_max_ptr,
                                         const MT* tensor_block_min_ptr,
                                         const MT* tensor_block_mean_ptr,
                                         const char* debug_info,
                                         int64_t numel,
                                         int64_t numel_max_min,
                                         int check_nan_inf_level) {
  if (blockIdx.x != 0 || threadIdx.x != 0) {
    return;
  }

  const int has_nan = found_nan_inf_ptr[0];
  const int has_inf = found_nan_inf_ptr[1];

  MT max_value = static_cast<MT>(0);
  MT min_value = static_cast<MT>(0);
  MT mean_value = static_cast<MT>(0);
  if (tensor_block_max_ptr && tensor_block_min_ptr && tensor_block_mean_ptr) {
    max_value = tensor_block_max_ptr[0];
    min_value = tensor_block_min_ptr[0];
    mean_value = tensor_block_mean_ptr[0];

    // numel_max_min <= 128
    for (int64_t i = 1; i < numel_max_min; ++i) {
      MT tmp_max_value = tensor_block_max_ptr[i];
      MT tmp_min_value = tensor_block_min_ptr[i];
      MT tmp_mean_value = tensor_block_mean_ptr[i];

      max_value = tmp_max_value > max_value ? tmp_max_value : max_value;
      min_value = tmp_min_value < min_value ? tmp_min_value : min_value;
      mean_value += tmp_mean_value;
    }
  }

  // int64_t is `long` on LP64 Linux but `long long` on Windows, so %ld would
  // misread the varargs there. Always print through long long with %lld.
  const long long numel_ll = static_cast<long long>(numel);

  if (has_nan || has_inf) {
    if (check_nan_inf_level == 0) {
      PADDLE_ENFORCE(false,
                     "===[PRECISION] [ERROR] in %s, numel=%lld, find_nan=%d, "
                     "find_inf=%d, "
                     "max=%e, min=%e, mean=%e===\n",
                     debug_info,
                     numel_ll,
                     has_nan,
                     has_inf,
                     static_cast<float>(max_value),
                     static_cast<float>(min_value),
                     static_cast<float>(mean_value));
    } else if (check_nan_inf_level >= 1) {
      printf(
          "===[PRECISION] [ERROR] in %s, numel=%lld, find_nan=%d, "
          "find_inf=%d, "
          "max=%e, min=%e, mean=%e===\n",
          debug_info,
          numel_ll,
          has_nan,
          has_inf,
          static_cast<float>(max_value),
          static_cast<float>(min_value),
          static_cast<float>(mean_value));
    }
  } else if (NeedPrint<T, MT>(max_value, min_value, check_nan_inf_level)) {
    printf("[PRECISION] in %s, numel=%lld, max=%e, min=%e, mean=%e\n",
           debug_info,
           numel_ll,
           static_cast<float>(max_value),
           static_cast<float>(min_value),
           static_cast<float>(mean_value));
  }
}

W
WangXi 已提交
329 330
// Checks the GPU tensor `tensor_` for nan/inf and reports according to
// FLAGS_check_nan_inf_level.
//
// The "[op=...] [tensor=...] [dtype=...]" label is uploaded to the device
// once per (device, label) pair and cached in multi_op_var2gpu_str(), so
// the kernels can print it.
//   CUDA: FindNanInfAndBlockMaxMin computes per-block flags and statistics,
//         then FindGlobalMaxMinAndPrint reduces and reports them.
//   HIP:  CheckNanInfKernel prints up to 3 values per category per offending
//         block and fails.
// All work is enqueued on dev_ctx->stream(); this function does not wait for
// the kernels to finish.
template <>
template <typename T>
void TensorCheckerVisitor<phi::GPUContext>::apply(
    typename std::enable_if<
        std::is_floating_point<T>::value ||
        std::is_same<T, ::paddle::platform::complex<float>>::value ||
        std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
    const {
  auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
      platform::DeviceContextPool::Instance().Get(tensor_.place()));
  int dev_id = tensor_.place().device;
  // dev_id is checked for >= 0 first, so the cast to size_t is safe.
  PADDLE_ENFORCE_EQ(
      (dev_id >= 0 &&
       static_cast<size_t>(dev_id) < multi_op_var2gpu_str_mutex().size()),
      true,
      platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d",
                                   multi_op_var2gpu_str_mutex().size()));

  // An empty tensor cannot hold nan/inf. Returning here also avoids a
  // zero-block kernel launch (an invalid launch configuration) and the
  // unconditional value_ptr[0] read in FindNanInfAndBlockMaxMin.
  const int64_t numel = tensor_.numel();
  if (numel <= 0) {
    return;
  }

  std::string dtype_str = DataTypeToString(DataTypeTrait<T>::DataType());
  if (dtype_str == "::paddle::platform::float16") {
    dtype_str = "float16";
  }
  std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ +
                       "] [dtype=" + dtype_str + "]";
  char* gpu_str_ptr = nullptr;

  {
    auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id);
    auto& op_var2gpu_str = multi_op_var2gpu_str().at(dev_id);

    std::lock_guard<std::mutex> guard(op_var2gpu_str_mutex);
    auto iter = op_var2gpu_str.find(op_var);
    if (iter == op_var2gpu_str.end()) {  // first use: allocate and upload
      auto gpu_str_tensor = paddle::memory::Alloc(
          dev_ctx->GetPlace(),
          op_var.length() + 1,
          phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx->stream())));
      gpu_str_ptr = reinterpret_cast<char*>(gpu_str_tensor->ptr());

      auto inserted =
          op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor));
      PADDLE_ENFORCE_EQ(inserted.second,
                        true,
                        platform::errors::PreconditionNotMet(
                            "op_var=%s should successed insert into "
                            "op_var2gpu_str, but now failed",
                            op_var));
      iter = inserted.first;

      // Copy from the map-owned key: entries are never erased in this file,
      // so the source string outlives the asynchronous copy.
#ifdef __HIPCC__
      PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(gpu_str_ptr,
                                                iter->first.c_str(),
                                                op_var.length() + 1,
                                                hipMemcpyHostToDevice,
                                                dev_ctx->stream()));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(gpu_str_ptr,
                                                 iter->first.c_str(),
                                                 op_var.length() + 1,
                                                 cudaMemcpyHostToDevice,
                                                 dev_ctx->stream()));
#endif
    } else {  // cached
      gpu_str_ptr = reinterpret_cast<char*>(iter->second->ptr());
    }
  }

#ifdef __HIPCC__
  // HIP will throw GPU memory access fault if threads > 256
  const size_t threads = 256;
#else
  const size_t threads = 1024;
#endif
  size_t blocks =
      std::min(static_cast<size_t>(128),
               static_cast<size_t>((numel + threads - 1) / threads));

#ifdef __HIPCC__
  int print_num = 3;

  hipLaunchKernelGGL(CheckNanInfKernel,
                     dim3(blocks),
                     dim3(threads),
                     0,
                     dev_ctx->stream(),
                     tensor_.data<T>(),
                     numel,
                     print_num,
                     gpu_str_ptr);
  // Surface launch-configuration errors here rather than at some later call.
  PADDLE_ENFORCE_GPU_SUCCESS(hipGetLastError());
#else
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  // [0]: nan found, [1]: inf found. Zeroed because the kernel only writes 1.
  phi::DenseTensor found_nan_inf;
  found_nan_inf.Resize({2});
  int* found_nan_inf_ptr = found_nan_inf.mutable_data<int>(tensor_.place());
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
      found_nan_inf_ptr, 0, 2 * sizeof(int), dev_ctx->stream()));

  int64_t numel_max_min = blocks;

  // One contiguous buffer holding the per-block max | min | mean partials.
  // NOTE(review): for complex T the block reduction is a stub and leaves these
  // partials unwritten. Confirm the report is meaningful in that case.
  phi::DenseTensor tensor_block_max_min;
  tensor_block_max_min.Resize({static_cast<int64_t>(3 * numel_max_min)});
  MT* tensor_block_max_ptr =
      tensor_block_max_min.mutable_data<MT>(tensor_.place());
  MT* tensor_block_min_ptr = tensor_block_max_ptr + numel_max_min;
  MT* tensor_block_mean_ptr = tensor_block_max_ptr + 2 * numel_max_min;

  FindNanInfAndBlockMaxMin<T, MT>
      <<<blocks, threads, 0, dev_ctx->stream()>>>(tensor_.data<T>(),
                                                  numel,
                                                  found_nan_inf_ptr,
                                                  tensor_block_max_ptr,
                                                  tensor_block_min_ptr,
                                                  tensor_block_mean_ptr);
  // Surface launch-configuration errors here rather than at some later call.
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());

  int check_nan_inf_level = FLAGS_check_nan_inf_level;
  FindGlobalMaxMinAndPrint<T, MT>
      <<<1, 1, 0, dev_ctx->stream()>>>(found_nan_inf_ptr,
                                       tensor_block_max_ptr,
                                       tensor_block_min_ptr,
                                       tensor_block_mean_ptr,
                                       gpu_str_ptr,
                                       numel,
                                       numel_max_min,
                                       check_nan_inf_level);
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
#endif
}

// GPU entry point of the nan/inf checker. It sizes the per-device label cache
// on first use, then dispatches on the tensor's runtime dtype through
// VisitDataType. That presumably invokes TensorCheckerVisitor::apply<T>;
// confirm against VisitDataType.
template <>
void tensor_check<phi::GPUContext>(const std::string& op_type,
                                   const std::string& var_name,
                                   const phi::DenseTensor& tensor,
                                   const platform::Place& place) {
  std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap);

  TensorCheckerVisitor<phi::GPUContext> visitor(
      op_type, var_name, tensor, place);
  const auto proto_dtype = framework::TransToProtoVarType(tensor.dtype());
  VisitDataType(proto_dtype, visitor);
}

}  // namespace details
}  // namespace framework
}  // namespace paddle