提交 d1e6d552 编写于 作者: T typhoonzero

update

上级 4d988ed2
......@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
class AccuracyOp : public framework::OperatorWithKernel {
class AucOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -76,5 +76,5 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AccuracyOp, ops::AccuracyOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker);
REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel<paddle::platform::CPUPlace, float>);
......@@ -23,7 +23,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class AccuracyKernel : public framework::OpKernel {
class AucKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference");
......@@ -45,7 +45,7 @@ class AccuracyKernel : public framework::OpKernel {
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
const int* inference_data = inference->data<int>();
const T* inference_prob_data = inference->data<T>();
const T* inference_prob_data = inference_prob->data<T>();
const T* label_data = label->data<T>();
size_t num_samples = inference->dims()[0];
......@@ -54,17 +54,17 @@ class AccuracyKernel : public framework::OpKernel {
// create local tensor for storing the curve: TP, FN, TN, FP
// TODO(typhoonzero): put these tensors in Scope
// TODO(typhoonzero): use op to caculate these values.
Tensor true_positive, false_positeve, true_negative, false_negative;
Tensor true_positive, false_positive, true_negative, false_negative;
true_positive.Resize({num_thresholds});
false_negative.Resize({num_thresholds});
true_negative.Resize({num_thresholds});
false_positive.Resize({num_thresholds});
int* tp_data = true_positive.mutable_data<int>();
int* fn_data = false_negative.mutable_data<int>();
int* tn_data = true_negative.mutable_data<int>();
int* fp_data = false_positive.mutable_data<int>();
int* tp_data = true_positive.mutable_data<int>(ctx.GetPlace());
int* fn_data = false_negative.mutable_data<int>(ctx.GetPlace());
int* tn_data = true_negative.mutable_data<int>(ctx.GetPlace());
int* fp_data = false_positive.mutable_data<int>(ctx.GetPlace());
for (auto thresh = thresholds_list.begin(); thresh != thresholds_list.end();
thresh++) {
......@@ -101,15 +101,15 @@ class AccuracyKernel : public framework::OpKernel {
tp_rate.Resize({num_thresholds});
fp_rate.Resize({num_thresholds});
rec_rate.Resize({num_thresholds});
float* tp_rate_data = tp_rate.mutable_data<float>();
float* fp_rate_data = fp_rate.mutable_data<float>();
float* rec_rate_data = rec_rate.mutable_data<float>();
float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace());
float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace());
float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace());
for (int i = 0; i < num_thresholds; i++) {
tp_rate_data[i] = ((float)tp_data[i + epsilon) / (tp_data[i] + fn_data[i] + epsilon);
fp_rate_data[i] =
(float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon);
rec_rate_data[i] =
((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon);
tp_rate_data[i] =
((float)tp_data[i] + epsilon) / (tp_data[i] + fn_data[i] + epsilon);
fp_rate_data[i] = (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon);
rec_rate_data[i] =
((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon);
}
if (curve == "ROC") {
......@@ -118,7 +118,7 @@ class AccuracyKernel : public framework::OpKernel {
auto y = (tp_rate_data[i] + tp_rate_data[i - 1]) / 2.0f;
*auc_data = *auc_data + dx * y;
}
} else if (curve = "PR") {
} else if (curve == "PR") {
for (int i = 1; i < num_thresholds; i++) {
auto dx = tp_rate_data[i] - tp_rate_data[i - 1];
auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f;
......
......@@ -50,6 +50,7 @@ USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
USE_OP(top_k);
USE_CPU_ONLY_OP(auc);
USE_OP(squared_l2_distance);
namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册