// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"

#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"

namespace paddle {
namespace framework {
namespace details {

static std::once_flag init_multi_gpu_op_var_map_flag;

// Accessor for the per-GPU label cache.
//
// The returned vector holds one map per GPU once InitMultiGPUOpVarMap() has
// run; each map goes from an "[op=...] [tensor=...]" label to the allocation
// that stores that label. The vector is a function-local static, so it is
// created lazily on the first call and is empty until it is initialized.
static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>&
multi_op_var2gpu_str() {
  using GpuStrMap = std::unordered_map<std::string, memory::AllocationPtr>;
  static std::vector<GpuStrMap> per_device_maps;
  return per_device_maps;
}

// Accessor for the per-GPU mutexes that guard the maps returned by
// multi_op_var2gpu_str(). Lazily created function-local static; empty until
// InitMultiGPUOpVarMap() swaps in one mutex per GPU.
static std::vector<std::mutex>& multi_op_var2gpu_str_mutex() {
  using MutexList = std::vector<std::mutex>;
  static MutexList per_device_mutexes;
  return per_device_mutexes;
}

// Sizes the per-GPU label maps and their mutexes to the number of GPUs
// reported by platform::GetGPUDeviceCount(). Raises NotFound when no GPU is
// reported.
static void InitMultiGPUOpVarMap() {
  using GpuStrMap = std::unordered_map<std::string, memory::AllocationPtr>;

  int gpu_count = platform::GetGPUDeviceCount();
  PADDLE_ENFORCE_GT(gpu_count, 0,
                    platform::errors::NotFound(
                        "cuda device must > 0, now dev_count=%d", gpu_count));

  // std::mutex can be neither copied nor moved, so the mutex vector is built
  // at its final size and swapped into place instead of being resized. See
  // https://stackoverflow.com/questions/16465633/how-can-i-use-something-like-stdvectorstdmutex
  std::vector<GpuStrMap> fresh_maps(gpu_count);
  std::vector<std::mutex> fresh_mutexes(gpu_count);

  multi_op_var2gpu_str().swap(fresh_maps);
  multi_op_var2gpu_str_mutex().swap(fresh_mutexes);
}

// Device-side reporter, called by CheckNanInfKernel once a block has detected
// a NaN or Inf.
//
// Every block counts the NaN, Inf and remaining elements it visits in shared
// counters and prints the first `print_num` elements of each category. After
// that, thread 0 of the block prints the three totals and raises the error
// through PADDLE_ENFORCE(false, ...).
//
// Must be reached by all threads of the calling block, because the body
// contains block-wide barriers around the shared counters.
//
//   value      - elements to inspect
//   numel      - number of elements in `value`
//   print_num  - per-block, per-category cap on the number of printed elements
//   debug_info - string substituted for %s in the final error message
template <typename T>
__device__ __forceinline__ void PrintNanInfKernel(const T* value,
                                                  const size_t numel,
                                                  int print_num,
                                                  char* debug_info) {
  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;

  __shared__ unsigned int nan_count, inf_count, num_count;
  if (threadIdx.x == 0) nan_count = inf_count = num_count = 0;
  // Barrier: no thread may bump a counter before thread 0 has zeroed it.
  // This has to be a call; a bare `__syncthreads;` is a no-op expression.
  __syncthreads();

  // The counters are unsigned, so the limit is compared as unsigned as well
  // (this is the value the implicit int -> unsigned conversion would yield).
  const unsigned int print_limit = static_cast<unsigned int>(print_num);

  for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
    unsigned int count = 0;
    if (isnan(value[i])) {
      count = atomicAdd(&nan_count, 1);
    } else if (isinf(value[i])) {
      count = atomicAdd(&inf_count, 1);
    } else {
      count = atomicAdd(&num_count, 1);
    }
    // for cuda, print in every block
    if (count < print_limit) {
      // %llu with unsigned long long is 64-bit on every host ABI, whereas
      // %lu is only 32-bit where long is 32-bit.
      printf("numel:%llu idx:%llu value:%f\n",
             static_cast<unsigned long long>(numel),
             static_cast<unsigned long long>(i), static_cast<float>(value[i]));
    }
  }
  // Barrier: every increment must have landed before thread 0 reads the
  // totals below.
  __syncthreads();

#ifdef __HIPCC__
  if (hipThreadIdx_x == 0) {
    printf("In block %d, there has %u,%u,%u nan,inf,num\n", hipBlockIdx_x,
           nan_count, inf_count, num_count);
#else
  if (threadIdx.x == 0) {
    printf("In block %d, there has %u,%u,%u nan,inf,num\n", blockIdx.x,
           nan_count, inf_count, num_count);
#endif
    PADDLE_ENFORCE(false, "===ERROR: in %s find nan or inf===", debug_info);
  }
}

// Benchmark note (Resnet, 2 GPUs): 270 images/s without this check,
// 229 images/s with it.
//
// Entry kernel of the NaN/Inf check. Each block first decides whether any of
// the elements it visits is NaN or Inf; only blocks that found one go on to
// PrintNanInfKernel, which prints details and raises the error.
//
//   value      - elements to inspect
//   numel      - number of elements in `value`
//   print_num  - forwarded to PrintNanInfKernel
//   debug_info - forwarded to PrintNanInfKernel
template <typename T>
__global__ void CheckNanInfKernel(const T* value, const size_t numel,
                                  int print_num, char* debug_info) {
  /// step 1, judge whether this block has nan or inf
  __shared__ volatile int has_nan_inf;
  if (threadIdx.x == 0) {
    has_nan_inf = false;
  }
  __syncthreads();

  // For IEEE floating-point values x - x is zero when x is finite and NaN
  // when x is NaN or +/-Inf, so the accumulated difference becomes NaN as
  // soon as one visited element is bad.
  const size_t stride = blockDim.x * gridDim.x;
  T diff_acc = static_cast<T>(0.0);
  // Todo(wangxi). simd speed up
  size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
  while (idx < numel) {
    diff_acc += (value[idx] - value[idx]);
    idx += stride;
  }

  const bool found_bad = isnan(diff_acc) || isinf(diff_acc);
  if (found_bad) has_nan_inf = true;
  // Make every thread's verdict visible before the block-wide decision.
  __syncthreads();

  /// Note. different blocks may behave differently
  if (has_nan_inf) {
    PrintNanInfKernel(value, numel, print_num, debug_info);
  }
}

// Launches the NaN/Inf check kernel over `tensor_` on the tensor's GPU.
//
// This overload is selected for floating-point element types and for
// paddle::platform::complex<float> / complex<double>.
//
// The label "[op=<op_type_>] [tensor=<var_name_>]" is copied to the device the
// first time it is seen on a given GPU and cached in multi_op_var2gpu_str(),
// so later checks of the same op/tensor pair reuse the device string. The copy
// and the kernel are both enqueued on the device context's stream.
//
// Raises OutOfRange when the tensor's device id has no slot in the per-GPU
// tables, and a GPU error when the copy or the kernel launch is rejected.
template <>
template <typename T>
void TensorCheckerVisitor<platform::CUDADeviceContext>::apply(
    typename std::enable_if<
        std::is_floating_point<T>::value ||
        std::is_same<T, ::paddle::platform::complex<float>>::value ||
        std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
    const {
  // An empty tensor has nothing to inspect, and it would produce a grid of
  // zero blocks, which is not a valid launch configuration.
  if (tensor_.numel() <= 0) return;

  int print_num = 3;

  auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
      platform::DeviceContextPool::Instance().Get(tensor_.place()));
  int dev_id = tensor_.place().device;
  PADDLE_ENFORCE_EQ(
      (dev_id >= 0 &&
       static_cast<size_t>(dev_id) < multi_op_var2gpu_str_mutex().size()),
      true,
      platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d",
                                   multi_op_var2gpu_str_mutex().size()));

  std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ + "]";
  char* gpu_str_ptr = nullptr;

  {
    auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id);
    auto& op_var2gpu_str = multi_op_var2gpu_str().at(dev_id);

    // The lock covers both the lookup and the insertion, so a label is
    // allocated and uploaded at most once per GPU.
    std::lock_guard<std::mutex> guard(op_var2gpu_str_mutex);
    auto iter = op_var2gpu_str.find(op_var);
    if (iter == op_var2gpu_str.end()) {  // insert
      auto gpu_str_tensor =
          paddle::memory::Alloc(*dev_ctx, op_var.length() + 1);
      // Take the raw pointer before ownership moves into the map.
      gpu_str_ptr = reinterpret_cast<char*>(gpu_str_tensor->ptr());

      auto inserted =
          op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor));
      PADDLE_ENFORCE_EQ(inserted.second, true,
                        platform::errors::PreconditionNotMet(
                            "op_var=%s should successed insert into "
                            "op_var2gpu_str, but now failed",
                            op_var));
      iter = inserted.first;

      // The copy source is the key stored in the map, not the local
      // `op_var` -- presumably so the host buffer outlives this call while
      // the asynchronous copy is pending.
#ifdef __HIPCC__
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1,
                         hipMemcpyHostToDevice, dev_ctx->stream()));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1,
                          cudaMemcpyHostToDevice, dev_ctx->stream()));
#endif
    } else {  // get
      gpu_str_ptr = reinterpret_cast<char*>(iter->second->ptr());
    }
  }

#ifdef __HIPCC__
  // HIP will throw GPU memory access fault if threads > 256
  const size_t threads = 256;
#else
  const size_t threads = 1024;
#endif
  // One thread per element, rounded up to whole blocks and capped at 128
  // blocks; the kernel loops over whatever the grid does not cover.
  size_t blocks =
      std::min(static_cast<size_t>(128),
               static_cast<size_t>((tensor_.numel() + threads - 1) / threads));
#ifdef __HIPCC__
  hipLaunchKernelGGL(CheckNanInfKernel, dim3(blocks), dim3(threads), 0,
                     dev_ctx->stream(), tensor_.data<T>(), tensor_.numel(),
                     print_num, gpu_str_ptr);
  // A launch does not return a status itself; query it explicitly.
  PADDLE_ENFORCE_GPU_SUCCESS(hipGetLastError());
#else
  CheckNanInfKernel<<<blocks, threads, 0, dev_ctx->stream()>>>(
      tensor_.data<T>(), tensor_.numel(), print_num, gpu_str_ptr);
  // A launch does not return a status itself; query it explicitly.
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
#endif
}

// GPU specialization of tensor_check: makes sure the per-GPU label tables
// exist (initialized exactly once per process) and then dispatches a
// TensorCheckerVisitor on the tensor's element type.
template <>
void tensor_check<platform::CUDADeviceContext>(const std::string& op_type,
                                               const std::string& var_name,
                                               const framework::Tensor& tensor,
                                               const platform::Place& place) {
  std::call_once(init_multi_gpu_op_var_map_flag, &InitMultiGPUOpVarMap);

  TensorCheckerVisitor<platform::CUDADeviceContext> checker(op_type, var_name,
                                                            tensor, place);
  const auto proto_type = framework::TransToProtoVarType(tensor.dtype());
  VisitDataType(proto_type, checker);
}

}  // namespace details
}  // namespace framework
}  // namespace paddle