nanmedian_kernel.cu 9.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/phi/kernels/nanmedian_kernel.h"

17
#include "paddle/phi/backends/gpu/gpu_context.h"
18
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
W
Wang Xin 已提交
19
#include "paddle/phi/backends/gpu/gpu_primitives.h"
20
#include "paddle/phi/common/memory_utils.h"
21 22
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
23
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"
24 25 26 27
#include "paddle/phi/kernels/top_k_kernel.h"

namespace phi {

W
Wang Xin 已提交
28
using phi::PADDLE_CUDA_NUM_THREADS;
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58

// Returns how many blocks of PADDLE_CUDA_NUM_THREADS threads are needed to
// give each of `N` work items its own thread.
// The count is taken as int64_t so large element counts (e.g. numel) are not
// truncated on the way in. The result is clamped to [1, INT_MAX]:
//   * at least one block keeps the launch configuration valid when N == 0
//     (a zero-sized grid is a launch error);
//   * the upper clamp keeps the value representable as int.
// Kernels launched with this count should use a grid-stride loop so that a
// clamped grid still covers all N items.
inline int GET_BLOCKS(const int64_t N) {
  const int64_t threads = PADDLE_CUDA_NUM_THREADS;
  const int64_t blocks = (N + threads - 1) / threads;
  const int64_t max_blocks = std::numeric_limits<int>::max();
  if (blocks < 1) {
    return 1;
  }
  return static_cast<int>(blocks < max_blocks ? blocks : max_blocks);
}

// Counts the NaN values of every row of `input`, viewed as a row-major
// [numel / stride, stride] matrix, and adds them to nan_counts[row].
//
// Preconditions (established by ProcessMedianKernel):
//   * stride > 0;
//   * nan_counts holds numel / stride entries and has been zero-initialized
//     on the same stream before the launch (the kernel only accumulates).
// Iterates with a grid-stride loop, so any 1-D launch shape is valid; no
// shared memory is used and global atomics are issued only for NaN elements.
template <typename T>
__global__ void KernelNanCounts(const T* input,
                                const int64_t numel,
                                const int64_t stride,
                                int64_t* nan_counts) {
  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < numel;
       index += step) {
    // Integral dtypes can never be NaN; the float cast lets one code path
    // serve every registered dtype.
    if (isnan(static_cast<float>(input[index]))) {
      phi::CudaAtomicAdd(&nan_counts[index / stride], static_cast<int64_t>(1));
    }
  }
}

// Reduces the per-row NaN counts into two scalars:
//   nan_total[0] += sum over rows of nan_counts[row]            (total NaNs)
//   nan_total[1]  = max(nan_total[1], max over rows of
//                       stride - nan_counts[row])    (most valid values/row)
// nan_total must be zero-initialized by the caller before the launch.
// Each thread first reduces its rows in registers, then each block combines
// its threads through two __shared__ scalars, so only one pair of global
// atomics is issued per block. Grid-stride loop over rows: any 1-D launch
// shape is valid.
__global__ void KernelNanStats(const int64_t* nan_counts,
                               const int64_t pre_dim,
                               const int64_t stride,
                               int64_t* nan_total) {
  __shared__ int64_t block_nan_sum;
  __shared__ int64_t block_valid_max;
  if (threadIdx.x == 0) {
    block_nan_sum = 0;
    block_valid_max = 0;
  }
  __syncthreads();

  int64_t thread_nan_sum = 0;
  int64_t thread_valid_max = 0;
  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t row =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       row < pre_dim;
       row += step) {
    const int64_t nan_cnt = nan_counts[row];
    thread_nan_sum += nan_cnt;
    const int64_t valid_cnt = stride - nan_cnt;
    if (valid_cnt > thread_valid_max) {
      thread_valid_max = valid_cnt;
    }
  }
  if (thread_nan_sum > 0) {
    phi::CudaAtomicAdd(&block_nan_sum, thread_nan_sum);
  }
  if (thread_valid_max > 0) {
    phi::CudaAtomicMax(&block_valid_max, thread_valid_max);
  }
  // Every thread of the block reaches this barrier (no early return above).
  __syncthreads();

  if (threadIdx.x == 0) {
    if (block_nan_sum > 0) {
      phi::CudaAtomicAdd(&nan_total[0], block_nan_sum);
    }
    if (block_valid_max > 0) {
      phi::CudaAtomicMax(&nan_total[1], block_valid_max);
    }
  }
}

// Writes the median of every row when the whole input contains no NaN.
//
// sort_out_ptr / sort_indices_ptr hold, per row, the `stride` smallest values
// in ascending order together with their last-axis indices (as returned by
// TopkKernel). The caller passes stride = n / 2 + 1 for an original row
// length n, so the last sorted entry of a row is the middle element (odd n)
// or the upper of the two middle elements (even n).
// For each row: output[row] = median; median_val[2 * row] and
// median_val[2 * row + 1] = indices of the lower / upper middle element
// (identical when is_odd). Grid-stride loop over the pre_dim rows.
template <typename T>
__global__ void CalcMedianKernel(const T* sort_out_ptr,
                                 const int64_t* sort_indices_ptr,
                                 int64_t* median_val,
                                 T* output,
                                 T div_factor,
                                 const bool is_odd,
                                 const int64_t pre_dim,
                                 const int64_t stride) {
  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < pre_dim;
       index += step) {
    // Last sorted entry of this row.
    const int64_t pos = (index + 1) * stride - 1;
    if (is_odd) {
      median_val[index * 2] = sort_indices_ptr[pos];
      median_val[index * 2 + 1] = sort_indices_ptr[pos];
      output[index] = sort_out_ptr[pos];
    } else {
      const int64_t left = pos > 0 ? pos - 1 : pos;
      median_val[index * 2] = sort_indices_ptr[left];
      median_val[index * 2 + 1] = sort_indices_ptr[pos];
      const T median_val_left = sort_out_ptr[left];
      const T median_val_right = sort_out_ptr[pos];
      output[index] = (median_val_left + median_val_right) / div_factor;
    }
  }
}

// Writes the NaN-ignoring median of every row when the input contains NaN.
//
// sort_out_ptr / sort_indices_ptr hold, per row, the max_valid_num smallest
// values in ascending order and their last-axis indices; nan_counts[row] is
// the number of NaNs in that row of length `stride`.
// NOTE(review): this relies on TopkKernel(largest=false) ordering NaN after
// every non-NaN value, so the first stride - nan_counts[row] sorted entries
// are exactly the row's valid values — confirm against TopkKernel.
// Rows that are entirely NaN produce nan_val and index pair (-1, -1).
// Grid-stride loop over the pre_dim rows.
template <typename T>
__global__ void CalcNanmedianKernel(const T* sort_out_ptr,
                                    const int64_t* sort_indices_ptr,
                                    const int64_t* nan_counts,
                                    int64_t* median_val,
                                    T* output,
                                    const int64_t pre_dim,
                                    const int64_t max_valid_num,
                                    const int64_t stride,
                                    const T div_factor,
                                    const T nan_val) {
  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < pre_dim;
       index += step) {
    const int64_t nan_cnt = nan_counts[index];
    if (nan_cnt == stride) {
      median_val[index * 2] = -1;
      median_val[index * 2 + 1] = -1;
      output[index] = nan_val;
      continue;
    }

    // Number of non-NaN values in this row; always <= max_valid_num, which
    // is the row width of the sorted buffers.
    const int64_t valid_cnt = stride - nan_cnt;
    const int64_t pos = index * max_valid_num + (valid_cnt >> 1);

    if (valid_cnt & 1) {
      median_val[index * 2] = sort_indices_ptr[pos];
      median_val[index * 2 + 1] = sort_indices_ptr[pos];
      output[index] = sort_out_ptr[pos];
    } else {
      // valid_cnt is even and >= 2 here, so pos - 1 stays inside the row.
      const int64_t left = pos > 0 ? pos - 1 : pos;
      median_val[index * 2] = sort_indices_ptr[left];
      median_val[index * 2 + 1] = sort_indices_ptr[pos];
      const T median_val_left = sort_out_ptr[left];
      const T median_val_right = sort_out_ptr[pos];
      output[index] = (median_val_left + median_val_right) / div_factor;
    }
  }
}

// Computes the NaN-ignoring median along the last axis of `x`.
//
// out receives one value per row (rows = numel / size of the last axis).
// median_index receives two int64 entries per row: the last-axis indices of
// the lower and upper middle element (equal when the row has an odd number
// of valid values), or -1 for rows without any valid value. Both output
// tensors are only allocated here, so they are expected to be sized by the
// caller.
//
// Steps:
//   1. Count NaNs per row and reduce them to (total NaN count, largest number
//      of valid values in any row); the two scalars are copied to the host,
//      which waits for the stream.
//   2. Partially sort every row with TopkKernel, keeping only as many of the
//      smallest values as the median needs.
//   3. Pick the median out of each row's sorted prefix.
template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         DenseTensor* out,
                         DenseTensor* median_index) {
  auto stream = dev_ctx.stream();
  T* out_data = dev_ctx.template Alloc<T>(out);
  int64_t* m_data = dev_ctx.template Alloc<int64_t>(median_index);

  const int64_t numel = x.numel();
  const auto x_dim = x.dims();
  const int64_t x_rank = x_dim.size();
  const int64_t stride = x_dim[x_rank - 1];
  // NaN for floating-point T; std::numeric_limits yields 0 for integral T.
  const T nan_val = std::numeric_limits<T>::quiet_NaN();

  // Result used when no row has any valid value: nan_val medians, index -1.
  auto fill_no_valid_result = [&]() {
    if (out->numel() > 0) {
      phi::funcs::SetConstant<Context, T> set_nan;
      set_nan(dev_ctx, out, nan_val);
    }
    if (median_index->numel() > 0) {
      phi::funcs::SetConstant<Context, int64_t> set_negative;
      set_negative(dev_ctx, median_index, static_cast<int64_t>(-1));
    }
  };

  // Nothing to sort. This also guards `numel / stride` below against a
  // zero-length last axis (integer division by zero).
  if (numel == 0) {
    fill_no_valid_result();
    return;
  }

  const T* x_data = x.data<T>();
  const int64_t pre_dim = numel / stride;

  // ---- Step 1: NaN statistics -------------------------------------------
  // Both buffers are zeroed on `stream` before the kernels accumulate into
  // them; the kernels themselves never reset them, so blocks cannot race
  // with each other on initialization.
  phi::funcs::SetConstant<Context, int64_t> set_zero;

  DenseTensor nan_counts;
  nan_counts.Resize(phi::make_ddim({pre_dim}));
  int64_t* nan_counts_ptr = dev_ctx.template Alloc<int64_t>(&nan_counts);
  set_zero(dev_ctx, &nan_counts, static_cast<int64_t>(0));

  DenseTensor nan_stat;
  nan_stat.Resize(phi::make_ddim({2}));
  int64_t* nan_stat_ptr = dev_ctx.template Alloc<int64_t>(&nan_stat);
  set_zero(dev_ctx, &nan_stat, static_cast<int64_t>(0));

  KernelNanCounts<T><<<GET_BLOCKS(numel), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
      x_data, numel, stride, nan_counts_ptr);
  KernelNanStats<<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
      nan_counts_ptr, pre_dim, stride, nan_stat_ptr);

  int64_t nan_stat_cpu[2] = {0, 0};
  phi::memory_utils::Copy(phi::CPUPlace(),
                          nan_stat_cpu,
                          dev_ctx.GetPlace(),
                          nan_stat_ptr,
                          sizeof(nan_stat_cpu),
                          stream);
  // The host reads the statistics immediately below; make sure the copy
  // issued on `stream` has completed first.
  dev_ctx.Wait();
  const int64_t total_nan = nan_stat_cpu[0];
  const int64_t max_valid_num = nan_stat_cpu[1];

  // Every element is NaN.
  if (total_nan == numel) {
    fill_no_valid_result();
    return;
  }

  // ---- Step 2: partial sort ---------------------------------------------
  // Without NaNs the median only needs the n / 2 + 1 smallest values of a
  // row; with NaNs it needs up to the largest valid count of any row.
  const bool ignore_nan = total_nan > 0;
  const int64_t sort_k = ignore_nan ? max_valid_num : ((stride >> 1) + 1);
  const bool is_ori_odd = (stride & 1) != 0;

  DenseTensor sort_out, sort_indices;
  auto sort_dim = x.dims();
  sort_dim[x_rank - 1] = sort_k;
  sort_out.Resize(sort_dim);
  sort_indices.Resize(sort_dim);
  dev_ctx.template Alloc<T>(&sort_out);
  dev_ctx.template Alloc<int64_t>(&sort_indices);

  // k smallest along the last axis (largest = false), in sorted order.
  TopkKernel<T, Context>(
      dev_ctx, x, Scalar(sort_k), -1, false, true, &sort_out, &sort_indices);
  const T* sort_out_ptr = sort_out.data<T>();
  const int64_t* sort_indices_ptr = sort_indices.data<int64_t>();

  // ---- Step 3: pick the medians -----------------------------------------
  const T div_factor = static_cast<T>(2.0);
  if (ignore_nan) {
    CalcNanmedianKernel<T>
        <<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
            sort_out_ptr,
            sort_indices_ptr,
            nan_counts_ptr,
            m_data,
            out_data,
            pre_dim,
            max_valid_num,
            stride,
            div_factor,
            nan_val);
  } else {
    CalcMedianKernel<T>
        <<<GET_BLOCKS(pre_dim), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
            sort_out_ptr,
            sort_indices_ptr,
            m_data,
            out_data,
            div_factor,
            is_ori_odd,
            pre_dim,
            sort_k);
  }
}

// GPU entry point of the nanmedian operator.
//
// When `axes` is empty or x has rank <= 1, x is viewed as a flat 1-D tensor,
// so the median is taken over all of its elements. Otherwise
// funcs::PreprocessMedianKernel builds tmp_x from x and axes (presumably
// moving the reduced axes to the last dimension — confirm in
// nanmedian_utils.h). The median is then computed along tmp_x's last axis.
// `keepdim` is not read here; output shapes come from the already-sized
// `out` / `median_index` tensors.
template <typename T, typename Context>
void NanmedianKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const IntArray& axes,
                     bool keepdim,
                     DenseTensor* out,
                     DenseTensor* median_index) {
  DenseTensor tmp_x;
  const auto rank = x.dims().size();
  const bool reduce_all = (axes.size() == 0) || rank <= 1;
  if (reduce_all) {
    // Shares x's storage; only the shape metadata is changed.
    tmp_x = x;
    tmp_x.Resize({x.numel()});
  } else {
    funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
  }

  ProcessMedianKernel<T, Context>(dev_ctx, tmp_x, out, median_index);
}

}  // namespace phi

// Registers phi::NanmedianKernel as the GPU implementation of `nanmedian`
// for float, double, int, int64_t and float16 inputs. Output 1
// (median_index) is always int64, whatever the input dtype.
PD_REGISTER_KERNEL(nanmedian,
                   GPU,
                   ALL_LAYOUT,
                   phi::NanmedianKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16) {
  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}