nanmedian_kernel.cc 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/nanmedian_kernel.h"
16

17 18
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
19
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h"
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
#include "paddle/phi/kernels/top_k_kernel.h"

namespace phi {

template <typename T, typename Context>
void CalcMedianFunc(const Context& dev_ctx,
                    const DenseTensor& x,
                    const std::vector<int64_t>& nan_counts,
                    bool ignore_nan,
                    int64_t sort_k,
                    int64_t stride,
                    int64_t pre_dim,
                    T* o_ptr,
                    int64_t* m_ptr) {
  bool should_ignore_nan = ignore_nan;
  DenseTensor sort_out;
  DenseTensor sort_indices;
  auto sort_dim = x.dims();
  int64_t rank = sort_dim.size();
  sort_dim[rank - 1] = sort_k;
  sort_out.Resize(sort_dim);
  sort_indices.Resize(sort_dim);

  dev_ctx.template Alloc<T>(&sort_out);
  T* sort_out_ptr = sort_out.data<T>();
  dev_ctx.template Alloc<int64_t>(&sort_indices);
  int64_t* sort_indices_ptr = sort_indices.data<int64_t>();

  TopkKernel<T, Context>(
      dev_ctx, x, Scalar(sort_k), -1, false, true, &sort_out, &sort_indices);

  T div_factor = static_cast<T>(2.0);
  int64_t offset = 0;
  int64_t i = 0;
  bool is_ori_odd = stride & 1;
  if (should_ignore_nan) {
    for (i = 0; i < pre_dim; i++) {
      offset = i * sort_k;
      if (nan_counts[i] == stride) {
        m_ptr[i * 2] = -1;
        m_ptr[i * 2 + 1] = -1;
        o_ptr[i] = sort_out_ptr[offset];
      } else {
        int64_t nan_k = nan_counts[i] > 0
                            ? static_cast<int64_t>(stride - nan_counts[i])
                            : sort_k;
        int64_t row_pos = static_cast<int64_t>(nan_k >> 1);
        int64_t pos = offset + row_pos;
        if (nan_k & 1) {
          m_ptr[2 * i] = sort_indices_ptr[pos];
          m_ptr[2 * i + 1] = sort_indices_ptr[pos];
          o_ptr[i] = sort_out_ptr[pos];
        } else {
          m_ptr[2 * i] =
              row_pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos];
          m_ptr[2 * i + 1] = sort_indices_ptr[pos];
          T m_val_left =
              row_pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos];
          T m_val_right = sort_out_ptr[pos];
          o_ptr[i] = (m_val_left + m_val_right) / div_factor;
        }
      }
    }
  } else {
    if (is_ori_odd) {
      for (i = 0; i < pre_dim; i++) {
        offset = i * sort_k;
        int64_t pos = offset + sort_k - 1;
        o_ptr[i] = sort_out_ptr[pos];
        m_ptr[2 * i] = sort_indices_ptr[pos];
        m_ptr[2 * i + 1] = sort_indices_ptr[pos];
      }
    } else {
      for (i = 0; i < pre_dim; i++) {
        offset = i * sort_k;
        int64_t pos = offset + sort_k - 1;
        m_ptr[2 * i] =
            sort_k > 1 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos];
        m_ptr[2 * i + 1] = sort_indices_ptr[pos];
        T m_val_left = sort_k > 1 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos];
        T m_val_right = sort_out_ptr[pos];
        o_ptr[i] = (m_val_left + m_val_right) / div_factor;
      }
    }
  }
}

// Computes per-row medians of `x`, viewing its data as `numel / stride` rows
// of `stride` contiguous elements, where `stride` is the size of the last
// dimension.
//
// When `ignore_nan` is true the NaNs of every row are counted first:
//   * if every element of `x` is NaN, each output is set to `x_ptr[0]` and
//     both indices of each row to -1, and the function returns early;
//   * if no element is NaN, the plain (non NaN-aware) path is used;
//   * otherwise the largest per-row count of non-NaN elements is used as the
//     number of sorted entries requested from CalcMedianFunc.
//
// Args:
//   dev_ctx:    device context forwarded to CalcMedianFunc.
//   x:          input tensor; the reduction runs over its last dimension.
//   o_ptr:      output medians, one per row.
//   m_ptr:      output indices, two per row.
//   ignore_nan: whether NaN entries are excluded from the median.
//
// NOTE(review): `stride == 0` would divide by zero below; this presumably
// cannot happen for the shapes callers pass in -- confirm.
template <typename T, typename Context>
void ProcessMedianKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         T* o_ptr,
                         int64_t* m_ptr,
                         bool ignore_nan) {
  const T* x_ptr = x.data<T>();

  const int64_t numel = x.numel();
  auto x_dim = x.dims();
  const int64_t x_rank = x_dim.size();
  const int64_t stride = x_dim[x_rank - 1];
  const int64_t pre_dim = numel / stride;

  bool should_ignore_nan = ignore_nan;
  int64_t max_valid_num = 0;
  std::vector<int64_t> nan_counts;
  if (should_ignore_nan) {
    int64_t total_nan_num = 0;
    nan_counts.resize(pre_dim);
    for (int64_t row = 0; row < pre_dim; ++row) {
      // Count NaNs directly over the row's slice of the input buffer; no
      // temporary copy of the row is needed.
      const T* row_begin = x_ptr + row * stride;
      const T* row_end = row_begin + stride;
      nan_counts[row] = std::count_if(row_begin, row_end, [](const T& val) {
        return std::isnan(static_cast<float>(val));
      });
      total_nan_num += nan_counts[row];

      const int64_t valid_num = stride - nan_counts[row];
      if (valid_num > max_valid_num) {
        max_valid_num = valid_num;
      }
    }

    // All elements are NaN: nothing to sort.
    if (total_nan_num == numel) {
      for (int64_t row = 0; row < pre_dim; ++row) {
        o_ptr[row] = x_ptr[0];
        m_ptr[2 * row] = -1;
        m_ptr[2 * row + 1] = -1;
      }
      return;
    }

    // Without any NaN the cheaper plain-median path gives the same result.
    should_ignore_nan = total_nan_num > 0;
  }

  // NaN-aware path: enough sorted entries to cover the row with the most
  // valid values. Plain path: the lower half plus one reaches the median.
  const int64_t sort_k =
      should_ignore_nan ? max_valid_num : ((stride >> 1) + 1);
  CalcMedianFunc<T, Context>(dev_ctx,
                             x,
                             nan_counts,
                             should_ignore_nan,
                             sort_k,
                             stride,
                             pre_dim,
                             o_ptr,
                             m_ptr);
}

// Shared driver for the median kernels: prepares the input layout, allocates
// both outputs and delegates the computation to ProcessMedianKernel.
//
// Args:
//   dev_ctx:      device context used for allocation.
//   input:        tensor to reduce.
//   axes:         axes to reduce over; an empty list flattens the input.
//   out:          receives the median values.
//   median_index: receives the int64 indices written by ProcessMedianKernel.
//   ignore_nan:   forwarded to ProcessMedianKernel.
template <typename T, typename Context>
void BaseMedianKernel(const Context& dev_ctx,
                      const DenseTensor& input,
                      const IntArray& axes,
                      DenseTensor* out,
                      DenseTensor* median_index,
                      bool ignore_nan) {
  const auto input_rank = input.dims().size();
  const bool reduce_everything = (axes.size() == 0) || input_rank <= 1;

  DenseTensor x;
  if (reduce_everything) {
    // View the whole input as a single 1-D row.
    x = input;
    x.Resize({input.numel()});
  } else {
    PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
  }

  T* out_data = dev_ctx.template Alloc<T>(out);
  int64_t* index_data = dev_ctx.template Alloc<int64_t>(median_index);
  ProcessMedianKernel<T, Context>(dev_ctx, x, out_data, index_data, ignore_nan);

  out->Resize(out->dims());
}

// CPU entry point of the nanmedian kernel: runs BaseMedianKernel with NaN
// handling enabled. `keepdim` is not referenced in this function.
template <typename T, typename Context>
void NanmedianKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const IntArray& axes,
                     bool keepdim,
                     DenseTensor* out,
                     DenseTensor* median_index) {
  constexpr bool kIgnoreNan = true;
  BaseMedianKernel<T, Context>(
      dev_ctx, x, axes, out, median_index, kIgnoreNan);
}

}  // namespace phi

// Registers phi::NanmedianKernel as the CPU implementation of `nanmedian`
// for float, double, int and int64_t inputs.
PD_REGISTER_KERNEL(nanmedian,
                   CPU,
                   ALL_LAYOUT,
                   phi::NanmedianKernel,
                   float,
                   double,
                   int,
                   int64_t) {
  // Output 1 is int64 regardless of the input type (presumably the
  // median_index output -- confirm against the op definition).
  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}