From 7e68bc896bc6e01d22f00f26cfa9ec076e6c293a Mon Sep 17 00:00:00 2001
From: hutuxian
Date: Mon, 2 Dec 2019 19:19:18 +0800
Subject: [PATCH] refactor AUC OP and add its CUDA Kernel (#21336)

* refactor AUC OP and add its CUDA Kernel
* the layout of global auc doesn't change
---
 paddle/fluid/operators/metrics/auc_op.cc      |  11 +-
 paddle/fluid/operators/metrics/auc_op.cu      | 218 ++++++++++++++++++
 paddle/fluid/operators/metrics/auc_op.h       | 155 +++++++------
 python/paddle/fluid/layers/metric_op.py       |  16 +-
 .../fluid/tests/unittests/test_auc_op.py      |  17 +-
 .../unittests/test_auc_single_pred_op.py      |  17 +-
 6 files changed, 346 insertions(+), 88 deletions(-)
 create mode 100644 paddle/fluid/operators/metrics/auc_op.cu

diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc
index 90734307acc..9a4a30b9cbd 100644
--- a/paddle/fluid/operators/metrics/auc_op.cc
+++ b/paddle/fluid/operators/metrics/auc_op.cc
@@ -49,9 +49,12 @@ class AucOp : public framework::OperatorWithKernel {
 
     ctx->SetOutputDim("AUC", {1});
 
-    slide_steps = slide_steps == 0 ? 1 : slide_steps;
-    ctx->SetOutputDim("StatPosOut", {slide_steps, num_pred_buckets});
-    ctx->SetOutputDim("StatNegOut", {slide_steps, num_pred_buckets});
+    // slide_steps = slide_steps == 0 ? 1 : slide_steps;
+    int need_batch_id = slide_steps ? 1 : 0;
+    ctx->SetOutputDim("StatPosOut",
+                      {(1 + slide_steps) * num_pred_buckets + need_batch_id});
+    ctx->SetOutputDim("StatNegOut",
+                      {(1 + slide_steps) * num_pred_buckets + need_batch_id});
   }
 
  protected:
@@ -59,7 +62,7 @@ class AucOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "Predict"),
-        platform::CPUPlace());
+        ctx.device_context());
   }
 };
 
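[Editor's note, not part of the patch: the InferShape change above flattens
each stat output into a 1-D buffer. A minimal Python sketch of the implied
layout (the function name and toy values are ours, for illustration only):

    def stat_buffer_size(num_thresholds, slide_steps):
        # Buckets cover predictions binned into 0..num_thresholds.
        num_pred_buckets = num_thresholds + 1
        # With sliding enabled, one trailing slot stores the batch id.
        need_batch_id = 1 if slide_steps > 0 else 0
        # slide_steps per-step histograms, one running-sum histogram,
        # and optionally the batch-id slot.
        return (1 + slide_steps) * num_pred_buckets + need_batch_id

    assert stat_buffer_size(num_thresholds=200, slide_steps=0) == 201
    assert stat_buffer_size(num_thresholds=200, slide_steps=1) == 403
]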
diff --git a/paddle/fluid/operators/metrics/auc_op.cu b/paddle/fluid/operators/metrics/auc_op.cu
new file mode 100644
index 00000000000..04af6c51c73
--- /dev/null
+++ b/paddle/fluid/operators/metrics/auc_op.cu
@@ -0,0 +1,218 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/metrics/auc_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+using platform::PADDLE_CUDA_NUM_THREADS;
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+__global__ void ClearObsoleteDataKernel(int64_t *pos, int64_t *neg,
+                                        const int bucket_length,
+                                        const int slide_steps) {
+  int cur_step_index =
+      static_cast<int>(pos[(slide_steps + 1) * bucket_length]) % slide_steps;
+  int cur_step_begin = cur_step_index * bucket_length;
+  int sum_step_begin = slide_steps * bucket_length;
+  CUDA_KERNEL_LOOP(i, bucket_length) {
+    pos[sum_step_begin + i] -= pos[cur_step_begin + i];
+    neg[sum_step_begin + i] -= neg[cur_step_begin + i];
+    pos[cur_step_begin + i] = neg[cur_step_begin + i] = 0;
+  }
+}
+
+__global__ void UpdateSumDataKernel(int64_t *pos, int64_t *neg,
+                                    const int bucket_length,
+                                    const int slide_steps) {
+  int cur_step_index =
+      static_cast<int>(pos[(slide_steps + 1) * bucket_length]) % slide_steps;
+  int cur_step_begin = cur_step_index * bucket_length;
+  int sum_step_begin = slide_steps * bucket_length;
+  CUDA_KERNEL_LOOP(i, bucket_length) {
+    pos[sum_step_begin + i] += pos[cur_step_begin + i];
+    neg[sum_step_begin + i] += neg[cur_step_begin + i];
+  }
+}
+
+template <typename T>
+__global__ void AddDataKernel(const int64_t *label_data, const T *pred_data,
+                              const int inference_width,
+                              const int num_thresholds, int64_t *pos,
+                              int64_t *neg, const int numel,
+                              const int slide_steps) {
+  int cur_step_begin = 0;
+  if (slide_steps > 0) {
+    int cur_step_index =
+        static_cast<int>(pos[(slide_steps + 1) * (1 + num_thresholds)]) %
+        slide_steps;
+    cur_step_begin = cur_step_index * (1 + num_thresholds);
+  }
+  CUDA_KERNEL_LOOP(i, numel) {
+    auto predict_data = pred_data[i * inference_width + (inference_width - 1)];
+    PADDLE_ENFORCE(predict_data <= 1, "The predict data must be less than or equal to 1.");
+    PADDLE_ENFORCE(predict_data >= 0,
+                   "The predict data must be greater than or equal to 0.");
+    uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
+    if (label_data[i]) {
+      paddle::platform::CudaAtomicAdd(pos + cur_step_begin + binIdx, 1);
+    } else {
+      paddle::platform::CudaAtomicAdd(neg + cur_step_begin + binIdx, 1);
+    }
+  }
+}
+
+__global__ void CalcAucKernel(int64_t *stat_pos, int64_t *stat_neg,
+                              int num_thresholds, double *auc,
+                              bool need_add_batch_num) {
+  *auc = 0.0f;
+  double totPos = 0.0;
+  double totNeg = 0.0;
+  double totPosPrev = 0.0;
+  double totNegPrev = 0.0;
+
+  int idx = num_thresholds;
+
+  while (idx >= 0) {
+    totPosPrev = totPos;
+    totNegPrev = totNeg;
+    totPos += stat_pos[idx];
+    totNeg += stat_neg[idx];
+    *auc += (totNeg - totNegPrev) * (totPos + totPosPrev) / 2.0;
+    --idx;
+  }
+
+  if (totPos > 0.0 && totNeg > 0.0) {
+    *auc = *auc / totPos / totNeg;
+  }
+  if (need_add_batch_num) {
+    stat_pos[num_thresholds + 1] += 1;
+    stat_neg[num_thresholds + 1] += 1;
+  }
+}
+
+template <typename T>
+class AucCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *predict = ctx.Input<Tensor>("Predict");
+    auto *label = ctx.Input<Tensor>("Label");
+
+    int num_thresholds = ctx.Attr<int>("num_thresholds");
+    int slide_steps = ctx.Attr<int>("slide_steps");
+
+    // Only use output var for now, make sure it's persistable and
+    // not cleaned up for each batch.
+    auto *auc_tensor = ctx.Output<Tensor>("AUC");
+    auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
+    auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
+
+    auto *origin_stat_pos = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
+    auto *origin_stat_neg = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
+    auto *auc_value = auc_tensor->mutable_data<double>(ctx.GetPlace());
+
+    auto *stat_pos_in_tensor = ctx.Input<Tensor>("StatPos");
+    auto *pos_in_data = stat_pos_in_tensor->data<int64_t>();
+    auto *stat_neg_in_tensor = ctx.Input<Tensor>("StatNeg");
+    auto *neg_in_data = stat_neg_in_tensor->data<int64_t>();
+    if (stat_pos_in_tensor != stat_pos) {
+      cudaMemcpy(origin_stat_pos, pos_in_data,
+                 ((1 + slide_steps) * (num_thresholds + 1) +
+                  (slide_steps > 0 ? 1 : 0)) *
+                     sizeof(int64_t),
+                 cudaMemcpyDeviceToDevice);
+    }
+    if (stat_neg_in_tensor != stat_neg) {
+      cudaMemcpy(origin_stat_neg, neg_in_data,
+                 ((1 + slide_steps) * (num_thresholds + 1) +
+                  (slide_steps > 0 ? 1 : 0)) *
+                     sizeof(int64_t),
+                 cudaMemcpyDeviceToDevice);
+    }
+
+    statAuc(ctx, label, predict, num_thresholds, slide_steps, origin_stat_pos,
+            origin_stat_neg);
+    int sum_offset = slide_steps * (num_thresholds + 1);
+    auto stream =
+        ctx.template device_context<platform::CUDADeviceContext>().stream();
+    CalcAucKernel<<<1, 1, 0, stream>>>(
+        origin_stat_pos + sum_offset, origin_stat_neg + sum_offset,
+        num_thresholds, auc_value, slide_steps > 0);
+  }
+
+ private:
+  inline static double trapezoidArea(double X1, double X2, double Y1,
+                                     double Y2) {
+    return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
+  }
+
+  inline static void statAuc(const framework::ExecutionContext &ctx,
+                             const framework::Tensor *label,
+                             const framework::Tensor *predict,
+                             const int num_thresholds, const int slide_steps,
+                             int64_t *origin_stat_pos,
+                             int64_t *origin_stat_neg) {
+    size_t batch_size = predict->dims()[0];
+    size_t inference_width = predict->dims()[1];
+    const T *inference_data = predict->data<T>();
+    const auto *label_data = label->data<int64_t>();
+    const int bucket_length = num_thresholds + 1;
+    auto stream =
+        ctx.template device_context<platform::CUDADeviceContext>().stream();
+    if (slide_steps == 0) {
+      AddDataKernel<<<(batch_size + PADDLE_CUDA_NUM_THREADS - 1) /
+                          PADDLE_CUDA_NUM_THREADS,
+                      PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+          label_data, inference_data, inference_width, num_thresholds,
+          origin_stat_pos, origin_stat_neg, batch_size, slide_steps);
+      return;
+    }
+    // The last element of origin_stat_pos stores the bucket index to be used
+    // in the current step.
+    int cur_step_index =
+        static_cast<int>(origin_stat_pos[(slide_steps + 1) * bucket_length]) %
+        slide_steps;
+    int cur_step_begin = cur_step_index * bucket_length;
+    int sum_step_begin = slide_steps * bucket_length;
+
+    ClearObsoleteDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
+                                  PADDLE_CUDA_NUM_THREADS,
+                              PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+        origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
+
+    AddDataKernel<<<(batch_size + PADDLE_CUDA_NUM_THREADS - 1) /
+                        PADDLE_CUDA_NUM_THREADS,
+                    PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+        label_data, inference_data, inference_width, num_thresholds,
+        origin_stat_pos, origin_stat_neg, batch_size, slide_steps);
+    UpdateSumDataKernel<<<(bucket_length + PADDLE_CUDA_NUM_THREADS - 1) /
+                              PADDLE_CUDA_NUM_THREADS,
+                          PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+        origin_stat_pos, origin_stat_neg, bucket_length, slide_steps);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(auc,
+                        ops::AucCUDAKernel<float>);
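[Editor's note, not part of the patch: the three kernels above implement one
sliding-window update. A sequential NumPy sketch of the same step, assuming a
precomputed per-batch histogram in place of AddDataKernel's atomic binning:

    import numpy as np

    def sliding_update(stat, batch_bins, num_thresholds, slide_steps):
        """stat layout: slide_steps step slots, one sum slot, one batch id."""
        bucket_length = num_thresholds + 1
        cur = int(stat[(slide_steps + 1) * bucket_length]) % slide_steps
        step = slice(cur * bucket_length, (cur + 1) * bucket_length)
        total = slice(slide_steps * bucket_length,
                      (slide_steps + 1) * bucket_length)
        stat[total] -= stat[step]  # ClearObsoleteDataKernel
        stat[step] = batch_bins    # AddDataKernel (atomics, in the real code)
        stat[total] += stat[step]  # UpdateSumDataKernel
        stat[(slide_steps + 1) * bucket_length] += 1  # CalcAucKernel's +1
]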
diff --git a/paddle/fluid/operators/metrics/auc_op.h b/paddle/fluid/operators/metrics/auc_op.h
index 6fb4749b35a..2dfcdaa5db0 100644
--- a/paddle/fluid/operators/metrics/auc_op.h
+++ b/paddle/fluid/operators/metrics/auc_op.h
@@ -30,31 +30,46 @@ class AucKernel : public framework::OpKernel<T> {
     auto *predict = ctx.Input<Tensor>("Predict");
     auto *label = ctx.Input<Tensor>("Label");
 
-    std::string curve = ctx.Attr<std::string>("curve");
     int num_thresholds = ctx.Attr<int>("num_thresholds");
-    // buckets contain numbers from 0 to num_thresholds
-    int num_pred_buckets = num_thresholds + 1;
     int slide_steps = ctx.Attr<int>("slide_steps");
 
     // Only use output var for now, make sure it's persistable and
     // not cleaned up for each batch.
-    auto *auc = ctx.Output<Tensor>("AUC");
+    auto *auc_tensor = ctx.Output<Tensor>("AUC");
     auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
     auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
 
     auto *origin_stat_pos = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
     auto *origin_stat_neg = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
-
-    std::vector<int64_t> stat_pos_data(num_pred_buckets, 0);
-    std::vector<int64_t> stat_neg_data(num_pred_buckets, 0);
-
-    auto stat_pos_calc = stat_pos_data.data();
-    auto stat_neg_calc = stat_neg_data.data();
-
-    statAuc(label, predict, num_pred_buckets, num_thresholds, slide_steps,
-            origin_stat_pos, origin_stat_neg, &stat_pos_calc, &stat_neg_calc);
-
-    calcAuc(ctx, stat_pos_calc, stat_neg_calc, num_thresholds, auc);
+    auto *auc_value = auc_tensor->mutable_data<double>(ctx.GetPlace());
+
+    // Just to pass the UT, since a UT's input & output cannot be the same var
+    auto *stat_pos_in_tensor = ctx.Input<Tensor>("StatPos");
+    auto *pos_in_data = stat_pos_in_tensor->data<int64_t>();
+    auto *stat_neg_in_tensor = ctx.Input<Tensor>("StatNeg");
+    auto *neg_in_data = stat_neg_in_tensor->data<int64_t>();
+    if (stat_pos_in_tensor != stat_pos) {
+      memcpy(origin_stat_pos, pos_in_data,
+             ((1 + slide_steps) * (num_thresholds + 1) +
+              (slide_steps > 0 ? 1 : 0)) *
+                 sizeof(int64_t));
+    }
+    if (stat_neg_in_tensor != stat_neg) {
+      memcpy(origin_stat_neg, neg_in_data,
+             ((1 + slide_steps) * (num_thresholds + 1) +
+              (slide_steps > 0 ? 1 : 0)) *
+                 sizeof(int64_t));
+    }
+    statAuc(label, predict, num_thresholds, slide_steps, origin_stat_pos,
+            origin_stat_neg);
+
+    int sum_offset = slide_steps * (num_thresholds + 1);
+    calcAuc(origin_stat_pos + sum_offset, origin_stat_neg + sum_offset,
+            num_thresholds, auc_value);
+    if (slide_steps) {
+      origin_stat_pos[(slide_steps + 1) * (num_thresholds + 1)] += 1;
+      origin_stat_neg[(slide_steps + 1) * (num_thresholds + 1)] += 1;
+    }
   }
 
  private:
@@ -65,14 +80,54 @@ class AucKernel : public framework::OpKernel<T> {
 
   inline static void statAuc(const framework::Tensor *label,
                              const framework::Tensor *predict,
-                             const int num_pred_buckets,
                              const int num_thresholds, const int slide_steps,
-                             int64_t *origin_stat_pos, int64_t *origin_stat_neg,
-                             int64_t **stat_pos, int64_t **stat_neg) {
+                             int64_t *origin_stat_pos,
+                             int64_t *origin_stat_neg) {
     size_t batch_size = predict->dims()[0];
     size_t inference_width = predict->dims()[1];
     const T *inference_data = predict->data<T>();
     const auto *label_data = label->data<int64_t>();
+    const int bucket_length = num_thresholds + 1;
+    if (slide_steps == 0) {
+      for (size_t i = 0; i < batch_size; i++) {
+        // if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob
+        // if predict_data[i] has dim of 1, then predict_data[i][0] is pos prob
+        auto predict_data =
+            inference_data[i * inference_width + (inference_width - 1)];
+        PADDLE_ENFORCE_LE(predict_data, 1,
+                          platform::errors::PreconditionNotMet(
+                              "The predict data must be less than or equal to 1."));
+        PADDLE_ENFORCE_GE(predict_data, 0,
+                          platform::errors::PreconditionNotMet(
+                              "The predict data must be greater than or equal to 0."));
+
+        uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
+        if (label_data[i]) {
+          origin_stat_pos[binIdx] += 1;
+        } else {
+          origin_stat_neg[binIdx] += 1;
+        }
+      }
+      return;
+    }
+    // The last element of origin_stat_pos stores the bucket index to be used
+    // in the current step.
+    int cur_step_index =
+        static_cast<int>(origin_stat_pos[(slide_steps + 1) * bucket_length]) %
+        slide_steps;
+    int cur_step_begin = cur_step_index * bucket_length;
+    int sum_step_begin = slide_steps * bucket_length;
+    for (int i = 0; i < bucket_length; ++i) {
+      origin_stat_pos[sum_step_begin + i] -=
+          origin_stat_pos[cur_step_begin + i];
+      origin_stat_neg[sum_step_begin + i] -=
+          origin_stat_neg[cur_step_begin + i];
+    }
+
+    std::memset(origin_stat_pos + cur_step_begin, 0,
+                bucket_length * sizeof(int64_t));
+    std::memset(origin_stat_neg + cur_step_begin, 0,
+                bucket_length * sizeof(int64_t));
 
     for (size_t i = 0; i < batch_size; i++) {
       // if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob
@@ -80,67 +135,29 @@ class AucKernel : public framework::OpKernel<T> {
       auto predict_data =
           inference_data[i * inference_width + (inference_width - 1)];
       PADDLE_ENFORCE_LE(predict_data, 1,
-                        "The predict data must less or equal 1.");
+                        platform::errors::PreconditionNotMet(
+                            "The predict data must be less than or equal to 1."));
       PADDLE_ENFORCE_GE(predict_data, 0,
-                        "The predict data must gather or equal 0.");
+                        platform::errors::PreconditionNotMet(
+                            "The predict data must be greater than or equal to 0."));
 
       uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
       if (label_data[i]) {
-        (*stat_pos)[binIdx] += 1.0;
+        origin_stat_pos[cur_step_begin + binIdx] += 1;
       } else {
-        (*stat_neg)[binIdx] += 1.0;
+        origin_stat_neg[cur_step_begin + binIdx] += 1;
       }
     }
-
-    int bucket_length = num_pred_buckets * sizeof(int64_t);
-
-    // will stat auc unlimited.
-    if (slide_steps == 0) {
-      for (int slide = 0; slide < num_pred_buckets; ++slide) {
-        origin_stat_pos[slide] += (*stat_pos)[slide];
-        origin_stat_neg[slide] += (*stat_neg)[slide];
-      }
-
-      *stat_pos = origin_stat_pos;
-      *stat_neg = origin_stat_neg;
-
-    } else {
-      for (int slide = 1; slide < slide_steps; ++slide) {
-        int dst_idx = (slide - 1) * num_pred_buckets;
-        int src_inx = slide * num_pred_buckets;
-        std::memcpy(origin_stat_pos + dst_idx, origin_stat_pos + src_inx,
-                    bucket_length);
-        std::memcpy(origin_stat_neg + dst_idx, origin_stat_neg + src_inx,
-                    bucket_length);
-      }
-
-      std::memcpy(origin_stat_pos + (slide_steps - 1) * num_pred_buckets,
-                  *stat_pos, bucket_length);
-      std::memcpy(origin_stat_neg + (slide_steps - 1) * num_pred_buckets,
-                  *stat_neg, bucket_length);
-
-      std::memset(*stat_pos, 0, bucket_length);
-      std::memset(*stat_neg, 0, bucket_length);
-
-      for (int slide = 0; slide < num_pred_buckets; ++slide) {
-        int stat_pos_steps = 0;
-        int stat_neg_steps = 0;
-        for (int step = 0; step < slide_steps; ++step) {
-          stat_pos_steps += origin_stat_pos[slide + step * num_pred_buckets];
-          stat_neg_steps += origin_stat_neg[slide + step * num_pred_buckets];
-        }
-        (*stat_pos)[slide] += stat_pos_steps;
-        (*stat_neg)[slide] += stat_neg_steps;
-      }
+    for (int i = 0; i < bucket_length; ++i) {
+      origin_stat_pos[sum_step_begin + i] +=
+          origin_stat_pos[cur_step_begin + i];
+      origin_stat_neg[sum_step_begin + i] +=
+          origin_stat_neg[cur_step_begin + i];
     }
   }
 
-  inline static void calcAuc(const framework::ExecutionContext &ctx,
-                             int64_t *stat_pos, int64_t *stat_neg,
-                             int num_thresholds,
-                             framework::Tensor *auc_tensor) {
-    auto *auc = auc_tensor->mutable_data<double>(ctx.GetPlace());
-
+  inline static void calcAuc(const int64_t *stat_pos, const int64_t *stat_neg,
+                             int num_thresholds, double *auc) {
     *auc = 0.0f;
     double totPos = 0.0;
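[Editor's note, not part of the patch: calcAuc above sweeps the histogram from
the highest bucket down and accumulates trapezoid areas in (TP, FP) space. An
equivalent Python sketch (ours, for illustration):

    def calc_auc(stat_pos, stat_neg):
        auc = tot_pos = tot_neg = 0.0
        for idx in range(len(stat_pos) - 1, -1, -1):
            prev_pos, prev_neg = tot_pos, tot_neg
            tot_pos += stat_pos[idx]
            tot_neg += stat_neg[idx]
            # trapezoid between consecutive operating points
            auc += (tot_neg - prev_neg) * (tot_pos + prev_pos) / 2.0
        return auc / tot_pos / tot_neg if tot_pos and tot_neg else 0.0

    assert calc_auc([0, 1], [1, 0]) == 1.0  # perfect separation
]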
diff --git a/python/paddle/fluid/layers/metric_op.py b/python/paddle/fluid/layers/metric_op.py
index 403d92fadf7..3517d3ed824 100644
--- a/python/paddle/fluid/layers/metric_op.py
+++ b/python/paddle/fluid/layers/metric_op.py
@@ -167,25 +167,31 @@ def auc(input,
     # make tp, tn, fp, fn persistable, so that can accumulate all batches.
 
     # for batch auc
+    # we create slide_steps + 1 buckets: the first slide_steps buckets store
+    # historical batch-level values, and the last bucket stores the sum of the
+    # previous slide_steps buckets.
+    # The bucket index used by the newest batch is batch_id mod slide_steps,
+    # and batch_id is stored in the last position of the following variables.
     batch_stat_pos = helper.create_global_variable(
         persistable=True,
         dtype='int64',
-        shape=[slide_steps, num_thresholds + 1])
+        shape=[(1 + slide_steps) * (num_thresholds + 1) + 1])
     batch_stat_neg = helper.create_global_variable(
         persistable=True,
         dtype='int64',
-        shape=[slide_steps, num_thresholds + 1])
+        shape=[(1 + slide_steps) * (num_thresholds + 1) + 1])
 
     # for global auc
+    # No need to maintain the batch id here
    stat_pos = helper.create_global_variable(
-        persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
+        persistable=True, dtype='int64', shape=[num_thresholds + 1])
     stat_neg = helper.create_global_variable(
-        persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
+        persistable=True, dtype='int64', shape=[num_thresholds + 1])
 
     for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]:
         helper.set_variable_initializer(
             var, Constant(
-                value=0.0, force_cpu=True))
+                value=0.0, force_cpu=False))
 
     # Batch AUC
     helper.append_op(
diff --git a/python/paddle/fluid/tests/unittests/test_auc_op.py b/python/paddle/fluid/tests/unittests/test_auc_op.py
index b75abd424a4..4835c38f5fc 100644
--- a/python/paddle/fluid/tests/unittests/test_auc_op.py
+++ b/python/paddle/fluid/tests/unittests/test_auc_op.py
@@ -26,9 +26,12 @@ class TestAucOp(OpTest):
         pred = np.random.random((128, 2)).astype("float32")
         labels = np.random.randint(0, 2, (128, 1)).astype("int64")
         num_thresholds = 200
+        slide_steps = 1
 
-        stat_pos = np.zeros((num_thresholds + 1, )).astype("int64")
-        stat_neg = np.zeros((num_thresholds + 1, )).astype("int64")
+        stat_pos = np.zeros((1 + slide_steps) * (num_thresholds + 1) + 1,
+                            ).astype("int64")
+        stat_neg = np.zeros((1 + slide_steps) * (num_thresholds + 1) + 1,
+                            ).astype("int64")
 
         self.inputs = {
             'Predict': pred,
@@ -39,7 +42,7 @@ class TestAucOp(OpTest):
         self.attrs = {
             'curve': 'ROC',
             'num_thresholds': num_thresholds,
-            "slide_steps": 1
+            "slide_steps": slide_steps
         }
 
         python_auc = metrics.Auc(name="auc",
@@ -47,10 +50,14 @@
                                  num_thresholds=num_thresholds)
         python_auc.update(pred, labels)
 
+        pos = python_auc._stat_pos * 2
+        pos.append(1)
+        neg = python_auc._stat_neg * 2
+        neg.append(1)
         self.outputs = {
             'AUC': np.array(python_auc.eval()),
-            'StatPosOut': np.array(python_auc._stat_pos),
-            'StatNegOut': np.array(python_auc._stat_neg)
+            'StatPosOut': np.array(pos),
+            'StatNegOut': np.array(neg)
         }
 
     def test_check_output(self):
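[Editor's note, not part of the patch: with slide_steps = 1 the expected
buffers above are the per-batch histogram twice (the single step slot and the
sum slot agree after one batch) plus a trailing batch counter of 1. A toy
check with made-up bucket counts:

    num_thresholds = 3            # illustrative; the test uses 200
    hist = [5, 0, 2, 1]           # pretend per-bucket positive counts
    expected = hist * 2 + [1]     # step slot + sum slot + batch id
    assert len(expected) == (1 + 1) * (num_thresholds + 1) + 1
]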
diff --git a/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py b/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py
index 6d3e93fa57b..574a820deca 100644
--- a/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py
+++ b/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py
@@ -27,9 +27,12 @@ class TestAucSinglePredOp(OpTest):
         pred0 = pred[:, 0].reshape(128, 1)
         labels = np.random.randint(0, 2, (128, 1)).astype("int64")
         num_thresholds = 200
+        slide_steps = 1
 
-        stat_pos = np.zeros((num_thresholds + 1, )).astype("int64")
-        stat_neg = np.zeros((num_thresholds + 1, )).astype("int64")
+        stat_pos = np.zeros((1 + slide_steps) * (num_thresholds + 1) + 1,
+                            ).astype("int64")
+        stat_neg = np.zeros((1 + slide_steps) * (num_thresholds + 1) + 1,
+                            ).astype("int64")
 
         self.inputs = {
             'Predict': pred0,
@@ -40,7 +43,7 @@ class TestAucSinglePredOp(OpTest):
         self.attrs = {
             'curve': 'ROC',
             'num_thresholds': num_thresholds,
-            "slide_steps": 1
+            "slide_steps": slide_steps
         }
 
         python_auc = metrics.Auc(name="auc",
@@ -50,10 +53,14 @@
             pred[i][1] = pred[i][0]
         python_auc.update(pred, labels)
 
+        pos = python_auc._stat_pos * 2
+        pos.append(1)
+        neg = python_auc._stat_neg * 2
+        neg.append(1)
         self.outputs = {
             'AUC': np.array(python_auc.eval()),
-            'StatPosOut': np.array(python_auc._stat_pos),
-            'StatNegOut': np.array(python_auc._stat_neg)
+            'StatPosOut': np.array(pos),
+            'StatNegOut': np.array(neg)
         }
 
     def test_check_output(self):
-- 
GitLab
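[Editor's note, not part of the patch: a self-contained NumPy check of the
window semantics end to end — after three batches with slide_steps = 2, the
sum slot should cover only the last two batches (all values illustrative):

    import numpy as np

    num_thresholds, slide_steps = 3, 2
    L = num_thresholds + 1
    stat = np.zeros((1 + slide_steps) * L + 1, dtype=np.int64)
    batches = ([1, 0, 2, 0], [0, 3, 0, 1], [2, 2, 2, 2])
    for bins in batches:
        cur = int(stat[(slide_steps + 1) * L]) % slide_steps
        step = slice(cur * L, (cur + 1) * L)
        total = slice(slide_steps * L, (slide_steps + 1) * L)
        stat[total] -= stat[step]              # drop the slot being reused
        stat[step] = bins                      # write the new histogram
        stat[total] += stat[step]              # fold it into the sum
        stat[(slide_steps + 1) * L] += 1       # advance the batch id
    assert (stat[total] == np.add(batches[1], batches[2])).all()
]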