// auc_op.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/metrics/auc_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

// Retires the sliding-window slot that is about to be reused.
//
// `pos` / `neg` are laid out as `slide_steps` window buckets followed by one
// running-sum bucket (each `bucket_length` counters long), with a step counter
// stored at pos[(slide_steps + 1) * bucket_length]. The slot selected by
// (step counter % slide_steps) is subtracted from the running sum and then
// zeroed. `slide_steps` must be > 0 (it is used as a modulus).
// Iterates with Paddle's CUDA_KERNEL_LOOP macro, so any 1-D launch covering
// `bucket_length` elements works.
__global__ void ClearObsoleteDataKernel(int64_t *pos, int64_t *neg,
                                        const int bucket_length,
                                        const int slide_steps) {
  const int counter_offset = (slide_steps + 1) * bucket_length;
  const int slot = static_cast<int>(pos[counter_offset]) % slide_steps;
  const int slot_offset = slot * bucket_length;
  const int sum_offset = slide_steps * bucket_length;

  int64_t *pos_window = pos + slot_offset;
  int64_t *neg_window = neg + slot_offset;
  int64_t *pos_sum = pos + sum_offset;
  int64_t *neg_sum = neg + sum_offset;

  CUDA_KERNEL_LOOP(i, bucket_length) {
    pos_sum[i] -= pos_window[i];
    neg_sum[i] -= neg_window[i];
    neg_window[i] = 0;
    pos_window[i] = 0;
  }
}

// Folds the current sliding-window slot into the running-sum bucket.
//
// Uses the same layout as ClearObsoleteDataKernel: `slide_steps` window
// buckets, then the sum bucket, then a step counter at
// pos[(slide_steps + 1) * bucket_length] that selects the current slot
// (counter % slide_steps). `slide_steps` must be > 0.
__global__ void UpdateSumDataKernel(int64_t *pos, int64_t *neg,
                                    const int bucket_length,
                                    const int slide_steps) {
  const int counter_offset = (slide_steps + 1) * bucket_length;
  const int slot_offset =
      (static_cast<int>(pos[counter_offset]) % slide_steps) * bucket_length;
  const int sum_offset = slide_steps * bucket_length;

  const int64_t *pos_window = pos + slot_offset;
  const int64_t *neg_window = neg + slot_offset;
  int64_t *pos_sum = pos + sum_offset;
  int64_t *neg_sum = neg + sum_offset;

  CUDA_KERNEL_LOOP(i, bucket_length) {
    pos_sum[i] += pos_window[i];
    neg_sum[i] += neg_window[i];
  }
}

// Bins each sample of the batch into the positive or negative histogram.
//
// For sample i the score is the last column of its prediction row
// (pred_data[i * inference_width + inference_width - 1]). It must lie in
// [0, 1] (enforced on the device) and maps to bucket
// floor(score * num_thresholds), i.e. one of num_thresholds + 1 buckets.
// A nonzero label counts as positive, zero as negative. With slide_steps > 0
// the counts go into the window slot selected by the step counter stored at
// pos[(slide_steps + 1) * (1 + num_thresholds)]; otherwise into bucket 0.
// Counters are bumped with atomics because samples share buckets.
template <typename T>
__global__ void AddDataKernel(const int64_t *label_data, const T *pred_data,
                              const int inference_width,
                              const int num_thresholds, int64_t *pos,
                              int64_t *neg, const int numel,
                              const int slide_steps) {
  const int bucket_length = 1 + num_thresholds;
  const int cur_step_begin =
      slide_steps > 0
          ? (static_cast<int>(pos[(slide_steps + 1) * bucket_length]) %
             slide_steps) *
                bucket_length
          : 0;

  CUDA_KERNEL_LOOP(i, numel) {
    auto score = pred_data[i * inference_width + (inference_width - 1)];
    PADDLE_ENFORCE(score <= 1, "The predict data must less or equal 1.");
    PADDLE_ENFORCE(score >= 0,
                   "The predict data must gather or equal 0.");
    uint32_t bucket = static_cast<uint32_t>(score * num_thresholds);
    int64_t *histogram = label_data[i] ? pos : neg;
    paddle::platform::CudaAtomicAdd(histogram + cur_step_begin + bucket, 1);
  }
}
// Computes the AUC from one bucket of positive/negative histograms.
//
// Walks the buckets from index num_thresholds down to 0, accumulating
// cumulative positive and negative counts, and sums trapezoids of
// (delta negatives) x (mean cumulative positives). When both totals are
// non-zero the area is normalized by totPos * totNeg. The result goes to *auc.
// If need_add_batch_num is set, the step counter stored right after the
// bucket (index num_thresholds + 1) is incremented in both arrays.
// The kernel does no thread indexing; the operator launches it as <<<1, 1>>>.
__global__ void CalcAucKernel(int64_t *stat_pos, int64_t *stat_neg,
                              int num_thresholds, double *auc,
                              bool need_add_batch_num) {
  double area = 0.0;
  double cum_pos = 0.0;
  double cum_neg = 0.0;

  for (int bucket = num_thresholds; bucket >= 0; --bucket) {
    const double prev_pos = cum_pos;
    const double prev_neg = cum_neg;
    cum_pos += stat_pos[bucket];
    cum_neg += stat_neg[bucket];
    area += (cum_neg - prev_neg) * (cum_pos + prev_pos) / 2.0;
  }

  if (cum_pos > 0.0 && cum_neg > 0.0) {
    area = area / cum_pos / cum_neg;
  }
  *auc = area;

  if (need_add_batch_num) {
    const int counter = num_thresholds + 1;
    stat_pos[counter] += 1;
    stat_neg[counter] += 1;
  }
}

template <typename DeviceContext, typename T>
class AucCUDAKernel : public framework::OpKernel<T> {
 public:
  // Accumulates the current batch into the persistent positive/negative
  // histograms (StatPosOut / StatNegOut) and writes the resulting AUC
  // (double) to the "AUC" output. All device work, including the stat
  // copies, is ordered on the execution context's CUDA stream.
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *predict = ctx.Input<Tensor>("Predict");
    auto *label = ctx.Input<Tensor>("Label");
    auto stream =
        ctx.template device_context<platform::CUDADeviceContext>().stream();

    // A batch whose first instance-tag weight equals 0 is treated as fake.
    const bool is_fake_data =
        IsFakeData(ctx.Input<Tensor>("InsTagWeight"), stream);

    int num_thresholds = ctx.Attr<int>("num_thresholds");
    int slide_steps = ctx.Attr<int>("slide_steps");

    // Only use output var for now, make sure it's persistable and
    // not cleaned up for each batch.
    auto *auc_tensor = ctx.Output<Tensor>("AUC");
    auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
    auto *stat_neg = ctx.Output<Tensor>("StatNegOut");

    auto *origin_stat_pos = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
    auto *origin_stat_neg = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
    auto *auc_value = auc_tensor->mutable_data<double>(ctx.GetPlace());

    auto *stat_pos_in_tensor = ctx.Input<Tensor>("StatPos");
    auto *pos_in_data = stat_pos_in_tensor->data<int64_t>();
    auto *stat_neg_in_tensor = ctx.Input<Tensor>("StatNeg");
    auto *neg_in_data = stat_neg_in_tensor->data<int64_t>();

    // Stat buffers hold (slide_steps + 1) buckets of (num_thresholds + 1)
    // counters (the window slots followed by the running sum) plus, when
    // slide_steps > 0, one trailing step counter.
    const size_t stat_bytes =
        (static_cast<size_t>(1 + slide_steps) * (num_thresholds + 1) +
         (slide_steps > 0 ? 1 : 0)) *
        sizeof(int64_t);
    // Copy on the context stream (not the default stream) so the copies are
    // ordered after the producers of the inputs and before the kernels below.
    if (stat_pos_in_tensor != stat_pos) {
      EnforceCudaSuccess(
          cudaMemcpyAsync(origin_stat_pos, pos_in_data, stat_bytes,
                          cudaMemcpyDeviceToDevice, stream),
          "cudaMemcpyAsync(StatPos -> StatPosOut)");
    }
    if (stat_neg_in_tensor != stat_neg) {
      EnforceCudaSuccess(
          cudaMemcpyAsync(origin_stat_neg, neg_in_data, stat_bytes,
                          cudaMemcpyDeviceToDevice, stream),
          "cudaMemcpyAsync(StatNeg -> StatNegOut)");
    }

    // Without a sliding window a fake batch is not accumulated and the AUC
    // output is not recomputed.
    if (slide_steps == 0 && is_fake_data) {
      return;
    }

    statAuc(ctx, label, predict, num_thresholds, slide_steps, origin_stat_pos,
            origin_stat_neg, is_fake_data);

    // The running-sum bucket follows the slide_steps window buckets.
    int sum_offset = slide_steps * (num_thresholds + 1);
    CalcAucKernel<<<1, 1, 0, stream>>>(
        origin_stat_pos + sum_offset, origin_stat_neg + sum_offset,
        num_thresholds, auc_value, slide_steps > 0);
  }

 private:
  // Area of the trapezoid spanning X1..X2 with parallel sides Y1 and Y2.
  inline static double trapezoidArea(double X1, double X2, double Y1,
                                     double Y2) {
    return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
  }

  // Raises a Paddle error when a CUDA runtime call did not succeed.
  static void EnforceCudaSuccess(cudaError_t err, const char *what) {
    PADDLE_ENFORCE_EQ(err, cudaSuccess,
                      platform::errors::External("%s failed: %s", what,
                                                 cudaGetErrorString(err)));
  }

  // Returns true when the (optional) InsTagWeight input is present, non-empty,
  // and its first element is 0. The element is fetched with cudaMemcpyDefault
  // (requires unified virtual addressing) so this works whether the tensor
  // lives in host or device memory; the host never dereferences a device
  // pointer directly.
  static bool IsFakeData(const Tensor *ins_tag_weight, cudaStream_t stream) {
    if (ins_tag_weight == nullptr || ins_tag_weight->numel() == 0) {
      return false;
    }
    float first_weight = 1.0f;
    EnforceCudaSuccess(
        cudaMemcpyAsync(&first_weight, ins_tag_weight->data<float>(),
                        sizeof(float), cudaMemcpyDefault, stream),
        "cudaMemcpyAsync(InsTagWeight[0])");
    EnforceCudaSuccess(cudaStreamSynchronize(stream),
                       "cudaStreamSynchronize");
    return first_weight == 0;
  }

  // Adds the batch to the histograms on the context stream.
  // slide_steps == 0: counts go straight into the single bucket.
  // slide_steps > 0: retire the oldest window slot, count the batch into it,
  // then (for non-fake batches) fold it into the running sum. Slot selection
  // happens on the device inside the kernels; origin_stat_pos/neg are device
  // pointers and must not be read here.
  inline static void statAuc(const framework::ExecutionContext &ctx,
                             const framework::Tensor *label,
                             const framework::Tensor *predict,
                             const int num_thresholds, const int slide_steps,
                             int64_t *origin_stat_pos, int64_t *origin_stat_neg,
                             const bool is_fake_data) {
    size_t batch_size = predict->dims()[0];
    size_t inference_width = predict->dims()[1];
    const T *inference_data = predict->data<T>();
    const auto *label_data = label->data<int64_t>();
    const int bucket_length = num_thresholds + 1;
    auto stream =
        ctx.template device_context<platform::CUDADeviceContext>().stream();
    const int bucket_blocks =
        (bucket_length + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
    const size_t batch_blocks =
        (batch_size + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;

    if (slide_steps > 0) {
      ClearObsoleteDataKernel<<<bucket_blocks, PADDLE_CUDA_NUM_THREADS, 0,
                                stream>>>(origin_stat_pos, origin_stat_neg,
                                          bucket_length, slide_steps);
    }

    // A zero-sized grid is an invalid launch configuration; an empty batch
    // contributes nothing anyway.
    if (batch_size > 0) {
      AddDataKernel<<<batch_blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
          label_data, inference_data, inference_width, num_thresholds,
          origin_stat_pos, origin_stat_neg, batch_size, slide_steps);
    }

    // NOTE(review): for fake batches the counts above still land in the
    // current window slot but are never added to the running sum, yet
    // ClearObsoleteDataKernel will subtract that slot from the sum when it is
    // retired. Confirm this is intended.
    if (slide_steps > 0 && !is_fake_data) {
      UpdateSumDataKernel<<<bucket_blocks, PADDLE_CUDA_NUM_THREADS, 0,
                            stream>>>(origin_stat_pos, origin_stat_neg,
                                      bucket_length, slide_steps);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// The first template argument fills AucCUDAKernel's DeviceContext parameter,
// so pass the CUDA device context (the type the kernel itself uses) rather
// than a Place. The registered place comes from REGISTER_OP_CUDA_KERNEL.
REGISTER_OP_CUDA_KERNEL(
    auc, ops::AucCUDAKernel<paddle::platform::CUDADeviceContext, float>);