提交 65dbbd57 编写于 作者: Y yangyaming

Add and pass unittests.

上级 06c7c8c8
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/precision_recall_op.h"
namespace paddle {
namespace operators {
......@@ -37,13 +39,15 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
if (ctx->HasInput("Weights")) {
auto weights_dims = ctx->GetInputDim("Weights");
PADDLE_ENFORCE_EQ(weights_dims, {predictions_dims[0], 1},
PADDLE_ENFORCE_EQ(weights_dims,
framework::make_ddim({predictions_dims[0], 1}),
"The shape of Input(Weights) should be "
"[batch_size, 1].");
}
if (ctx->HasInput("StatesInfo")) {
auto states_dims = ctx->GetInputDim("StatesInfo");
PADDLE_ENFORCE_EQ(states_dims, {predictions_dims[1], 4},
PADDLE_ENFORCE_EQ(states_dims,
framework::make_ddim({predictions_dims[1], 4}),
"The shape of Input(StatesInfo) should be "
"[class_number, 4].");
}
......@@ -71,6 +75,12 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
// [ TP, FP, TN, FN ]
ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4});
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Predictions")->type());
}
};
class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -98,6 +108,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
"provided, current state will be accumulated to this state and "
"the accumulation state will be as the output state.")
.AsDispensable();
AddOutput("BatchMetrics", "");
AddOutput("AccumMetrics", "");
AddOutput("AccumStatesInfo", "");
AddComment(R"DOC(
)DOC");
......@@ -113,6 +126,4 @@ REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp,
REGISTER_OP_CPU_KERNEL(
precision_recall,
ops::PrecisionRecallKernel<paddle::platform::CPUPlace, float>,
ops::PrecisionRecallKernel<paddle::platform::CPUPlace, int>,
ops::PrecisionRecallKernel<paddle::platform::CPUPlace, double>,
ops::PrecisionRecallKernel<paddle::platform::CPUPlace, int64_t>,
ops::PrecisionRecallKernel<paddle::platform::CPUPlace, double>);
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -37,7 +39,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
auto* out2 = ctx.Output<Tensor>("AccumStatesInfo");
const T* predictions_data = in0->data<T>();
const T* labels_data = in1->data<T>();
const int* labels_data = in1->data<int>();
const T* weights_data = in2 ? in2->data<T>() : nullptr;
const T* states_data = in3 ? in3->data<T>() : nullptr;
T* batch_metrics_data = out0->mutable_data<T>(ctx.GetPlace());
......@@ -45,7 +47,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
out2->mutable_data<T>(ctx.GetPlace());
auto accum_states = EigenMatrix<T>::From(*out2);
accum_states.setZero();
T* accum_states_data = out2->data<T>(ctx.GetPlace());
T* accum_states_data = out2->data<T>();
size_t sample_num = in0->dims()[0];
size_t class_dim = in0->dims()[1];
......@@ -76,7 +78,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
accum_states_data[j * state_var_num + TN] += w;
}
accum_states_data[max_idx * state_var_num + TN] -= w;
accum_states_data[labels_data[j] * state_var_num + TN] -= w;
accum_states_data[labels_data[i] * state_var_num + TN] -= w;
}
}
......@@ -108,7 +110,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
if (tp_count > 0.0 || fn_count > 0.0) {
return tp_count / (tp_count + fn_count);
}
return 1.0
return 1.0;
}
static inline T CalcF1Score(T precision, T recall) {
......@@ -120,7 +122,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
protected:
void ComputeMetrics(const T* states_data, T* metrics_data,
size_t state_var_num, size_t class_dim) {
size_t state_var_num, size_t class_dim) const {
T total_tp_count = 0;
T total_fp_count = 0;
T total_fn_count = 0;
......@@ -143,7 +145,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count);
T micro_avg_recall = CalcRecall(total_tp_count, total_fn_count);
T micro_f1_score = CalcRecall(micro_avg_precision, micro_avg_recall);
T micro_f1_score = CalcF1Score(micro_avg_precision, micro_avg_recall);
// fill metrics data
metrics_data[0] = macro_avg_precision;
......
import unittest
import numpy as np
from op_test import OpTest
def calc_precision(tp_count, fp_count):
if tp_count > 0.0 or fp_count > 0.0:
return tp_count / (tp_count + fp_count)
return 1.0
def calc_recall(tp_count, fn_count):
if tp_count > 0.0 or fn_count > 0.0:
return tp_count / (tp_count + fn_count)
return 1.0
def calc_f1_score(precision, recall):
if precision > 0.0 or recall > 0.0:
return 2 * precision * recall / (precision + recall)
return 0.0
def get_states(predictions, labels, weights=None):
ins_num = predictions.shape[0]
class_num = predictions.shape[1]
# TP FP TN FN
states = np.zeros((class_num, 4)).astype('float32')
for i in xrange(ins_num):
w = weights[i] if weights is not None else 1.0
max_idx = np.argmax(predictions[i])
if max_idx == labels[i][0]:
states[max_idx][0] += w
for j in xrange(class_num):
states[j][2] += w
states[max_idx][2] -= w
else:
states[labels[i][0]][3] += w
states[max_idx][1] += w
for j in xrange(class_num):
states[j][2] += w
states[labels[i][0]][2] -= w
states[max_idx][2] -= w
return states
def compute_metrics(states):
class_num = states.shape[0]
total_tp_count = 0.0
total_fp_count = 0.0
total_fn_count = 0.0
macro_avg_precision = 0.0
macro_avg_recall = 0.0
for i in xrange(class_num):
total_tp_count += states[i][0]
total_fp_count += states[i][1]
total_fn_count += states[i][3]
macro_avg_precision += calc_precision(states[i][0], states[i][1])
macro_avg_recall += calc_recall(states[i][0], states[i][3])
metrics = []
macro_avg_precision /= class_num
macro_avg_recall /= class_num
metrics.append(macro_avg_precision)
metrics.append(macro_avg_recall)
metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall))
micro_avg_precision = calc_precision(total_tp_count, total_fp_count)
metrics.append(micro_avg_precision)
micro_avg_recall = calc_recall(total_tp_count, total_fn_count)
metrics.append(micro_avg_recall)
metrics.append(calc_f1_score(micro_avg_precision, micro_avg_recall))
return np.array(metrics).astype('float32')
class TestPrecisionRecallOp_0(OpTest):
def setUp(self):
self.op_type = "precision_recall"
ins_num = 64
class_num = 10
predictions = np.random.uniform(0, 1.0,
(ins_num, class_num)).astype('float32')
labels = np.random.choice(xrange(class_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
states = get_states(predictions, labels)
metrics = compute_metrics(states)
self.inputs = {'Predictions': predictions, 'Labels': labels}
self.outputs = {
'BatchMetrics': metrics,
'AccumMetrics': metrics,
'AccumStatesInfo': states
}
def test_check_output(self):
self.check_output()
class TestPrecisionRecallOp_1(OpTest):
def setUp(self):
self.op_type = "precision_recall"
ins_num = 64
class_num = 10
predictions = np.random.uniform(0, 1.0,
(ins_num, class_num)).astype('float32')
weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
predictions = np.random.random((ins_num, class_num)).astype('float32')
labels = np.random.choice(xrange(class_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
states = get_states(predictions, labels, weights)
metrics = compute_metrics(states)
self.inputs = {
'Predictions': predictions,
'Labels': labels,
'Weights': weights
}
self.outputs = {
'BatchMetrics': metrics,
'AccumMetrics': metrics,
'AccumStatesInfo': states
}
def test_check_output(self):
self.check_output()
class TestPrecisionRecallOp_2(OpTest):
def setUp(self):
self.op_type = "precision_recall"
ins_num = 64
class_num = 10
predictions = np.random.uniform(0, 1.0,
(ins_num, class_num)).astype('float32')
weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
predictions = np.random.random((ins_num, class_num)).astype('float32')
labels = np.random.choice(xrange(class_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
states = np.random.randint(0, 30, (class_num, 4)).astype('float32')
accum_states = get_states(predictions, labels, weights)
batch_metrics = compute_metrics(accum_states)
accum_states += states
accum_metrics = compute_metrics(accum_states)
self.inputs = {
'Predictions': predictions,
'Labels': labels,
'Weights': weights,
'StatesInfo': states
}
self.outputs = {
'BatchMetrics': batch_metrics,
'AccumMetrics': accum_metrics,
'AccumStatesInfo': accum_states
}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册