未验证 提交 7e68bc89 编写于 作者: H hutuxian 提交者: GitHub

refactor AUC OP and add its CUDA Kernel (#21336)

* refactor AUC OP and add its CUDA Kernel
* the layout of global auc doesn't change
上级 1f57ac12
......@@ -49,9 +49,12 @@ class AucOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AUC", {1});
slide_steps = slide_steps == 0 ? 1 : slide_steps;
ctx->SetOutputDim("StatPosOut", {slide_steps, num_pred_buckets});
ctx->SetOutputDim("StatNegOut", {slide_steps, num_pred_buckets});
// slide_steps = slide_steps == 0 ? 1 : slide_steps;
int need_batch_id = slide_steps ? 1 : 0;
ctx->SetOutputDim("StatPosOut",
{(1 + slide_steps) * num_pred_buckets + need_batch_id});
ctx->SetOutputDim("StatNegOut",
{(1 + slide_steps) * num_pred_buckets + need_batch_id});
}
protected:
......@@ -59,7 +62,7 @@ class AucOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Predict"),
platform::CPUPlace());
ctx.device_context());
}
};
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/metrics/auc_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
__global__ void ClearObsoleteDataKernel(int64_t *pos, int64_t *neg,
const int bucket_length,
const int slide_steps) {
int cur_step_index =
static_cast<int>(pos[(slide_steps + 1) * bucket_length]) % slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
CUDA_KERNEL_LOOP(i, bucket_length) {
pos[sum_step_begin + i] -= pos[cur_step_begin + i];
neg[sum_step_begin + i] -= neg[cur_step_begin + i];
pos[cur_step_begin + i] = neg[cur_step_begin + i] = 0;
}
}
__global__ void UpdateSumDataKernel(int64_t *pos, int64_t *neg,
const int bucket_length,
const int slide_steps) {
int cur_step_index =
static_cast<int>(pos[(slide_steps + 1) * bucket_length]) % slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
CUDA_KERNEL_LOOP(i, bucket_length) {
pos[sum_step_begin + i] += pos[cur_step_begin + i];
neg[sum_step_begin + i] += neg[cur_step_begin + i];
}
}
template <typename T>
__global__ void AddDataKernel(const int64_t *label_data, const T *pred_data,
const int inference_width,
const int num_thresholds, int64_t *pos,
int64_t *neg, const int numel,
const int slide_steps) {
int cur_step_begin = 0;
if (slide_steps > 0) {
int cur_step_index =
static_cast<int>(pos[(slide_steps + 1) * (1 + num_thresholds)]) %
slide_steps;
cur_step_begin = cur_step_index * (1 + num_thresholds);
}
CUDA_KERNEL_LOOP(i, numel) {
auto predict_data = pred_data[i * inference_width + (inference_width - 1)];
PADDLE_ENFORCE(predict_data <= 1, "The predict data must less or equal 1.");
PADDLE_ENFORCE(predict_data >= 0,
"The predict data must gather or equal 0.");
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) {
paddle::platform::CudaAtomicAdd(pos + cur_step_begin + binIdx, 1);
} else {
paddle::platform::CudaAtomicAdd(neg + cur_step_begin + binIdx, 1);
}
}
}
__global__ void CalcAucKernel(int64_t *stat_pos, int64_t *stat_neg,
int num_thresholds, double *auc,
bool need_add_batch_num) {
*auc = 0.0f;
double totPos = 0.0;
double totNeg = 0.0;
double totPosPrev = 0.0;
double totNegPrev = 0.0;
int idx = num_thresholds;
while (idx >= 0) {
totPosPrev = totPos;
totNegPrev = totNeg;
totPos += stat_pos[idx];
totNeg += stat_neg[idx];
*auc += (totNeg - totNegPrev) * (totPos + totPosPrev) / 2.0;
--idx;
}
if (totPos > 0.0 && totNeg > 0.0) {
*auc = *auc / totPos / totNeg;
}
if (need_add_batch_num) {
stat_pos[num_thresholds + 1] += 1;
stat_neg[num_thresholds + 1] += 1;
}
}
template <typename DeviceContext, typename T>
class AucCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *predict = ctx.Input<Tensor>("Predict");
auto *label = ctx.Input<Tensor>("Label");
int num_thresholds = ctx.Attr<int>("num_thresholds");
int slide_steps = ctx.Attr<int>("slide_steps");
// Only use output var for now, make sure it's persistable and
// not cleaned up for each batch.
auto *auc_tensor = ctx.Output<Tensor>("AUC");
auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
auto *origin_stat_pos = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
auto *origin_stat_neg = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
auto *auc_value = auc_tensor->mutable_data<double>(ctx.GetPlace());
auto *stat_pos_in_tensor = ctx.Input<Tensor>("StatPos");
auto *pos_in_data = stat_pos_in_tensor->data<int64_t>();
auto *stat_neg_in_tensor = ctx.Input<Tensor>("StatNeg");
auto *neg_in_data = stat_neg_in_tensor->data<int64_t>();
if (stat_pos_in_tensor != stat_pos) {
cudaMemcpy(origin_stat_pos, pos_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
cudaMemcpyDeviceToDevice);
}
if (stat_neg_in_tensor != stat_neg) {
cudaMemcpy(origin_stat_neg, neg_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t),
cudaMemcpyDeviceToDevice);
}
statAuc(ctx, label, predict, num_thresholds, slide_steps, origin_stat_pos,
origin_stat_neg);
int sum_offset = slide_steps * (num_thresholds + 1);
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
CalcAucKernel<<<1, 1, 0, stream>>>(
origin_stat_pos + sum_offset, origin_stat_neg + sum_offset,
num_thresholds, auc_value, slide_steps > 0);
}
private:
inline static double trapezoidArea(double X1, double X2, double Y1,
double Y2) {
return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
}
inline static void statAuc(const framework::ExecutionContext &ctx,
const framework::Tensor *label,
const framework::Tensor *predict,
const int num_thresholds, const int slide_steps,
int64_t *origin_stat_pos,
int64_t *origin_stat_neg) {
size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>();
const auto *label_data = label->data<int64_t>();
const int bucket_length = num_thresholds + 1;
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
if (slide_steps == 0) {
AddDataKernel<<<(batch_size + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
label_data, inference_data, inference_width, num_thresholds,
origin_stat_pos, origin_stat_neg, batch_size, slide_steps);
return;
}
// the last number of origin_stat_pos store the index should be used in
// current step
int cur_step_index =
static_cast<int>(origin_stat_pos[(slide_steps + 1) * bucket_length]) %
slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
ClearObsoleteDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
AddDataKernel<<<(batch_size + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
label_data, inference_data, inference_width, num_thresholds,
origin_stat_pos, origin_stat_neg, batch_size, slide_steps);
UpdateSumDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(auc,
ops::AucCUDAKernel<paddle::platform::CUDAPlace, float>);
......@@ -30,31 +30,46 @@ class AucKernel : public framework::OpKernel<T> {
auto *predict = ctx.Input<Tensor>("Predict");
auto *label = ctx.Input<Tensor>("Label");
std::string curve = ctx.Attr<std::string>("curve");
int num_thresholds = ctx.Attr<int>("num_thresholds");
// buckets contain numbers from 0 to num_thresholds
int num_pred_buckets = num_thresholds + 1;
int slide_steps = ctx.Attr<int>("slide_steps");
// Only use output var for now, make sure it's persistable and
// not cleaned up for each batch.
auto *auc = ctx.Output<Tensor>("AUC");
auto *auc_tensor = ctx.Output<Tensor>("AUC");
auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
auto *origin_stat_pos = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
auto *origin_stat_neg = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
std::vector<int64_t> stat_pos_data(num_pred_buckets, 0);
std::vector<int64_t> stat_neg_data(num_pred_buckets, 0);
auto stat_pos_calc = stat_pos_data.data();
auto stat_neg_calc = stat_neg_data.data();
statAuc(label, predict, num_pred_buckets, num_thresholds, slide_steps,
origin_stat_pos, origin_stat_neg, &stat_pos_calc, &stat_neg_calc);
calcAuc(ctx, stat_pos_calc, stat_neg_calc, num_thresholds, auc);
auto *auc_value = auc_tensor->mutable_data<double>(ctx.GetPlace());
// Just for pass UT, since UT's input & output connot be set same var
auto *stat_pos_in_tensor = ctx.Input<Tensor>("StatPos");
auto *pos_in_data = stat_pos_in_tensor->data<int64_t>();
auto *stat_neg_in_tensor = ctx.Input<Tensor>("StatNeg");
auto *neg_in_data = stat_neg_in_tensor->data<int64_t>();
if (stat_pos_in_tensor != stat_pos) {
memcpy(origin_stat_pos, pos_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t));
}
if (stat_neg_in_tensor != stat_neg) {
memcpy(origin_stat_neg, neg_in_data,
((1 + slide_steps) * (num_thresholds + 1) +
(slide_steps > 0 ? 1 : 0)) *
sizeof(int64_t));
}
statAuc(label, predict, num_thresholds, slide_steps, origin_stat_pos,
origin_stat_neg);
int sum_offset = slide_steps * (num_thresholds + 1);
calcAuc(origin_stat_pos + sum_offset, origin_stat_neg + sum_offset,
num_thresholds, auc_value);
if (slide_steps) {
origin_stat_pos[(slide_steps + 1) * (num_thresholds + 1)] += 1;
origin_stat_neg[(slide_steps + 1) * (num_thresholds + 1)] += 1;
}
}
private:
......@@ -65,14 +80,54 @@ class AucKernel : public framework::OpKernel<T> {
inline static void statAuc(const framework::Tensor *label,
const framework::Tensor *predict,
const int num_pred_buckets,
const int num_thresholds, const int slide_steps,
int64_t *origin_stat_pos, int64_t *origin_stat_neg,
int64_t **stat_pos, int64_t **stat_neg) {
int64_t *origin_stat_pos,
int64_t *origin_stat_neg) {
size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>();
const auto *label_data = label->data<int64_t>();
const int bucket_length = num_thresholds + 1;
if (slide_steps == 0) {
for (size_t i = 0; i < batch_size; i++) {
// if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob
// if predict_data[i] has dim of 1, then predict_data[i][0] is pos prob
auto predict_data =
inference_data[i * inference_width + (inference_width - 1)];
PADDLE_ENFORCE_LE(predict_data, 1,
platform::errors::PreconditionNotMet(
"The predict data must less or equal 1."));
PADDLE_ENFORCE_GE(predict_data, 0,
platform::errors::PreconditionNotMet(
"The predict data must gather or equal 0."));
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) {
origin_stat_pos[binIdx] += 1;
} else {
origin_stat_neg[binIdx] += 1;
}
}
return;
}
// the last number of origin_stat_pos store the index should be used in
// current step
int cur_step_index =
static_cast<int>(origin_stat_pos[(slide_steps + 1) * bucket_length]) %
slide_steps;
int cur_step_begin = cur_step_index * bucket_length;
int sum_step_begin = slide_steps * bucket_length;
for (int i = 0; i < bucket_length; ++i) {
origin_stat_pos[sum_step_begin + i] -=
origin_stat_pos[cur_step_begin + i];
origin_stat_neg[sum_step_begin + i] -=
origin_stat_neg[cur_step_begin + i];
}
std::memset(origin_stat_pos + cur_step_begin, 0,
bucket_length * sizeof(int64_t));
std::memset(origin_stat_neg + cur_step_begin, 0,
bucket_length * sizeof(int64_t));
for (size_t i = 0; i < batch_size; i++) {
// if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob
......@@ -80,67 +135,29 @@ class AucKernel : public framework::OpKernel<T> {
auto predict_data =
inference_data[i * inference_width + (inference_width - 1)];
PADDLE_ENFORCE_LE(predict_data, 1,
"The predict data must less or equal 1.");
platform::errors::PreconditionNotMet(
"The predict data must less or equal 1."));
PADDLE_ENFORCE_GE(predict_data, 0,
"The predict data must gather or equal 0.");
platform::errors::PreconditionNotMet(
"The predict data must gather or equal 0."));
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) {
(*stat_pos)[binIdx] += 1.0;
origin_stat_pos[cur_step_begin + binIdx] += 1;
} else {
(*stat_neg)[binIdx] += 1.0;
origin_stat_neg[cur_step_begin + binIdx] += 1;
}
}
int bucket_length = num_pred_buckets * sizeof(int64_t);
// will stat auc unlimited.
if (slide_steps == 0) {
for (int slide = 0; slide < num_pred_buckets; ++slide) {
origin_stat_pos[slide] += (*stat_pos)[slide];
origin_stat_neg[slide] += (*stat_neg)[slide];
}
*stat_pos = origin_stat_pos;
*stat_neg = origin_stat_neg;
} else {
for (int slide = 1; slide < slide_steps; ++slide) {
int dst_idx = (slide - 1) * num_pred_buckets;
int src_inx = slide * num_pred_buckets;
std::memcpy(origin_stat_pos + dst_idx, origin_stat_pos + src_inx,
bucket_length);
std::memcpy(origin_stat_neg + dst_idx, origin_stat_neg + src_inx,
bucket_length);
}
std::memcpy(origin_stat_pos + (slide_steps - 1) * num_pred_buckets,
*stat_pos, bucket_length);
std::memcpy(origin_stat_neg + (slide_steps - 1) * num_pred_buckets,
*stat_neg, bucket_length);
std::memset(*stat_pos, 0, bucket_length);
std::memset(*stat_neg, 0, bucket_length);
for (int slide = 0; slide < num_pred_buckets; ++slide) {
int stat_pos_steps = 0;
int stat_neg_steps = 0;
for (int step = 0; step < slide_steps; ++step) {
stat_pos_steps += origin_stat_pos[slide + step * num_pred_buckets];
stat_neg_steps += origin_stat_neg[slide + step * num_pred_buckets];
}
(*stat_pos)[slide] += stat_pos_steps;
(*stat_neg)[slide] += stat_neg_steps;
}
for (int i = 0; i < bucket_length; ++i) {
origin_stat_pos[sum_step_begin + i] +=
origin_stat_pos[cur_step_begin + i];
origin_stat_neg[sum_step_begin + i] +=
origin_stat_neg[cur_step_begin + i];
}
}
inline static void calcAuc(const framework::ExecutionContext &ctx,
int64_t *stat_pos, int64_t *stat_neg,
int num_thresholds,
framework::Tensor *auc_tensor) {
auto *auc = auc_tensor->mutable_data<double>(ctx.GetPlace());
inline static void calcAuc(const int64_t *stat_pos, const int64_t *stat_neg,
int num_thresholds, double *auc) {
*auc = 0.0f;
double totPos = 0.0;
......
......@@ -167,25 +167,31 @@ def auc(input,
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
# for batch auc
# we create slide_step+1 buckets, the first slide_steps buckets store
# historical batch-level values, and the last bucket stores the sum values of
# previous slide_step buckets.
# The index of bucket that the newest batch will use is determined by batch_id mod slide_steps,
# and batch_id is store in the last posision of following variable
batch_stat_pos = helper.create_global_variable(
persistable=True,
dtype='int64',
shape=[slide_steps, num_thresholds + 1])
shape=[(1 + slide_steps) * (num_thresholds + 1) + 1])
batch_stat_neg = helper.create_global_variable(
persistable=True,
dtype='int64',
shape=[slide_steps, num_thresholds + 1])
shape=[(1 + slide_steps) * (num_thresholds + 1) + 1])
# for global auc
# Needn't maintain the batch id
stat_pos = helper.create_global_variable(
persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
persistable=True, dtype='int64', shape=[num_thresholds + 1])
stat_neg = helper.create_global_variable(
persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
persistable=True, dtype='int64', shape=[num_thresholds + 1])
for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]:
helper.set_variable_initializer(
var, Constant(
value=0.0, force_cpu=True))
value=0.0, force_cpu=False))
# Batch AUC
helper.append_op(
......
......@@ -26,9 +26,12 @@ class TestAucOp(OpTest):
pred = np.random.random((128, 2)).astype("float32")
labels = np.random.randint(0, 2, (128, 1)).astype("int64")
num_thresholds = 200
slide_steps = 1
stat_pos = np.zeros((num_thresholds + 1, )).astype("int64")
stat_neg = np.zeros((num_thresholds + 1, )).astype("int64")
stat_pos = np.zeros((1 + slide_steps) * (num_thresholds + 1) + 1,
).astype("int64")
stat_neg = np.zeros((1 + slide_steps) * (num_thresholds + 1) + 1,
).astype("int64")
self.inputs = {
'Predict': pred,
......@@ -39,7 +42,7 @@ class TestAucOp(OpTest):
self.attrs = {
'curve': 'ROC',
'num_thresholds': num_thresholds,
"slide_steps": 1
"slide_steps": slide_steps
}
python_auc = metrics.Auc(name="auc",
......@@ -47,10 +50,14 @@ class TestAucOp(OpTest):
num_thresholds=num_thresholds)
python_auc.update(pred, labels)
pos = python_auc._stat_pos * 2
pos.append(1)
neg = python_auc._stat_neg * 2
neg.append(1)
self.outputs = {
'AUC': np.array(python_auc.eval()),
'StatPosOut': np.array(python_auc._stat_pos),
'StatNegOut': np.array(python_auc._stat_neg)
'StatPosOut': np.array(pos),
'StatNegOut': np.array(neg)
}
def test_check_output(self):
......
......@@ -27,9 +27,12 @@ class TestAucSinglePredOp(OpTest):
pred0 = pred[:, 0].reshape(128, 1)
labels = np.random.randint(0, 2, (128, 1)).astype("int64")
num_thresholds = 200
slide_steps = 1
stat_pos = np.zeros((num_thresholds + 1, )).astype("int64")
stat_neg = np.zeros((num_thresholds + 1, )).astype("int64")
stat_pos = np.zeros((1 + slide_steps) * (num_thresholds + 1) + 1,
).astype("int64")
stat_neg = np.zeros((1 + slide_steps) * (num_thresholds + 1) + 1,
).astype("int64")
self.inputs = {
'Predict': pred0,
......@@ -40,7 +43,7 @@ class TestAucSinglePredOp(OpTest):
self.attrs = {
'curve': 'ROC',
'num_thresholds': num_thresholds,
"slide_steps": 1
"slide_steps": slide_steps
}
python_auc = metrics.Auc(name="auc",
......@@ -50,10 +53,14 @@ class TestAucSinglePredOp(OpTest):
pred[i][1] = pred[i][0]
python_auc.update(pred, labels)
pos = python_auc._stat_pos * 2
pos.append(1)
neg = python_auc._stat_neg * 2
neg.append(1)
self.outputs = {
'AUC': np.array(python_auc.eval()),
'StatPosOut': np.array(python_auc._stat_pos),
'StatNegOut': np.array(python_auc._stat_neg)
'StatPosOut': np.array(pos),
'StatNegOut': np.array(neg)
}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册