提交 d1e6d552 编写于 作者: T typhoonzero

update

上级 4d988ed2
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class AccuracyOp : public framework::OperatorWithKernel { class AucOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -76,5 +76,5 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -76,5 +76,5 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AccuracyOp, ops::AccuracyOpMaker); REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker);
REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel<paddle::platform::CPUPlace, float>);
...@@ -23,7 +23,7 @@ namespace operators { ...@@ -23,7 +23,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class AccuracyKernel : public framework::OpKernel { class AucKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference"); auto* inference = ctx.Input<Tensor>("Inference");
...@@ -45,7 +45,7 @@ class AccuracyKernel : public framework::OpKernel { ...@@ -45,7 +45,7 @@ class AccuracyKernel : public framework::OpKernel {
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
const int* inference_data = inference->data<int>(); const int* inference_data = inference->data<int>();
const T* inference_prob_data = inference->data<T>(); const T* inference_prob_data = inference_prob->data<T>();
const T* label_data = label->data<T>(); const T* label_data = label->data<T>();
size_t num_samples = inference->dims()[0]; size_t num_samples = inference->dims()[0];
...@@ -54,17 +54,17 @@ class AccuracyKernel : public framework::OpKernel { ...@@ -54,17 +54,17 @@ class AccuracyKernel : public framework::OpKernel {
// create local tensor for storing the curve: TP, FN, TN, FP // create local tensor for storing the curve: TP, FN, TN, FP
// TODO(typhoonzero): put these tensors in Scope // TODO(typhoonzero): put these tensors in Scope
// TODO(typhoonzero): use op to caculate these values. // TODO(typhoonzero): use op to caculate these values.
Tensor true_positive, false_positeve, true_negative, false_negative; Tensor true_positive, false_positive, true_negative, false_negative;
true_positive.Resize({num_thresholds}); true_positive.Resize({num_thresholds});
false_negative.Resize({num_thresholds}); false_negative.Resize({num_thresholds});
true_negative.Resize({num_thresholds}); true_negative.Resize({num_thresholds});
false_positive.Resize({num_thresholds}); false_positive.Resize({num_thresholds});
int* tp_data = true_positive.mutable_data<int>(); int* tp_data = true_positive.mutable_data<int>(ctx.GetPlace());
int* fn_data = false_negative.mutable_data<int>(); int* fn_data = false_negative.mutable_data<int>(ctx.GetPlace());
int* tn_data = true_negative.mutable_data<int>(); int* tn_data = true_negative.mutable_data<int>(ctx.GetPlace());
int* fp_data = false_positive.mutable_data<int>(); int* fp_data = false_positive.mutable_data<int>(ctx.GetPlace());
for (auto thresh = thresholds_list.begin(); thresh != thresholds_list.end(); for (auto thresh = thresholds_list.begin(); thresh != thresholds_list.end();
thresh++) { thresh++) {
...@@ -101,13 +101,13 @@ class AccuracyKernel : public framework::OpKernel { ...@@ -101,13 +101,13 @@ class AccuracyKernel : public framework::OpKernel {
tp_rate.Resize({num_thresholds}); tp_rate.Resize({num_thresholds});
fp_rate.Resize({num_thresholds}); fp_rate.Resize({num_thresholds});
rec_rate.Resize({num_thresholds}); rec_rate.Resize({num_thresholds});
float* tp_rate_data = tp_rate.mutable_data<float>(); float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace());
float* fp_rate_data = fp_rate.mutable_data<float>(); float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace());
float* rec_rate_data = rec_rate.mutable_data<float>(); float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace());
for (int i = 0; i < num_thresholds; i++) { for (int i = 0; i < num_thresholds; i++) {
tp_rate_data[i] = ((float)tp_data[i + epsilon) / (tp_data[i] + fn_data[i] + epsilon); tp_rate_data[i] =
fp_rate_data[i] = ((float)tp_data[i] + epsilon) / (tp_data[i] + fn_data[i] + epsilon);
(float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon); fp_rate_data[i] = (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon);
rec_rate_data[i] = rec_rate_data[i] =
((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon);
} }
...@@ -118,7 +118,7 @@ class AccuracyKernel : public framework::OpKernel { ...@@ -118,7 +118,7 @@ class AccuracyKernel : public framework::OpKernel {
auto y = (tp_rate_data[i] + tp_rate_data[i - 1]) / 2.0f; auto y = (tp_rate_data[i] + tp_rate_data[i - 1]) / 2.0f;
*auc_data = *auc_data + dx * y; *auc_data = *auc_data + dx * y;
} }
} else if (curve = "PR") { } else if (curve == "PR") {
for (int i = 1; i < num_thresholds; i++) { for (int i = 1; i < num_thresholds; i++) {
auto dx = tp_rate_data[i] - tp_rate_data[i - 1]; auto dx = tp_rate_data[i] - tp_rate_data[i - 1];
auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f; auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f;
......
...@@ -50,6 +50,7 @@ USE_OP(cos_sim); ...@@ -50,6 +50,7 @@ USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter); USE_CPU_ONLY_OP(scatter);
USE_OP(top_k); USE_OP(top_k);
USE_CPU_ONLY_OP(auc);
USE_OP(squared_l2_distance); USE_OP(squared_l2_distance);
namespace paddle { namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册