From 970613fc152b77a4fa76876c1fb21fc8473affaa Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 1 Nov 2017 23:23:42 +0800 Subject: [PATCH] Refine and follow comments. --- paddle/operators/precision_recall_op.cc | 62 ++++++------ paddle/operators/precision_recall_op.h | 54 +++++------ .../tests/test_precision_recall_op.py | 97 ++++++++++--------- 3 files changed, 115 insertions(+), 98 deletions(-) diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc index a3f4c0749..39da1e0bf 100644 --- a/paddle/operators/precision_recall_op.cc +++ b/paddle/operators/precision_recall_op.cc @@ -22,8 +22,10 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Predictions"), - "Input(Predictions) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("MaxProbs"), + "Input(MaxProbs) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Indices"), + "Input(Indices) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Labels"), "Input(Labels) should not be null."); PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"), @@ -33,34 +35,36 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"), "Output(AccumStatesInfo) should not be null."); - auto predictions_dims = ctx->GetInputDim("Predictions"); + int64_t cls_num = + static_cast(ctx->Attrs().Get("class_number")); + auto max_probs_dims = ctx->GetInputDim("MaxProbs"); auto labels_dims = ctx->GetInputDim("Labels"); + PADDLE_ENFORCE_EQ(max_probs_dims[1], 1, + "Each instance contains one max probability, so the " + "shape of Input(MaxProbs) should be [batch_size, 1]."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Indices"), max_probs_dims, + "The shape of Input(Indices) should be [batch_size, 1]."); + PADDLE_ENFORCE_EQ(max_probs_dims[0], labels_dims[0], + "The 1st dimension of Input(MaxProbs) and " + "Input(Labels) both are batch_size and the shape should " + "be the same."); + PADDLE_ENFORCE_EQ(labels_dims[1], 1, + "The 2nd dimension of Input(Labels) contains instance " + "label and the shape should be equal to 1."); if (ctx->HasInput("Weights")) { auto weights_dims = ctx->GetInputDim("Weights"); PADDLE_ENFORCE_EQ(weights_dims, - framework::make_ddim({predictions_dims[0], 1}), + framework::make_ddim({max_probs_dims[0], 1}), "The shape of Input(Weights) should be " "[batch_size, 1]."); } if (ctx->HasInput("StatesInfo")) { auto states_dims = ctx->GetInputDim("StatesInfo"); - PADDLE_ENFORCE_EQ(states_dims, - framework::make_ddim({predictions_dims[1], 4}), + PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}), "The shape of Input(StatesInfo) should be " "[class_number, 4]."); } - PADDLE_ENFORCE_EQ(predictions_dims[0], labels_dims[0], - "The 1st dimension of Input(Predictions) and " - "Input(Labels) both are batch_size and the shape should " - "be the same."); - PADDLE_ENFORCE_EQ(labels_dims[1], 1, - "The 2nd dimension of Input(Labels) " - "contains instance label and the shape should be equal " - "to 1"); - PADDLE_ENFORCE_GE(predictions_dims[1], 1, - "The shape of Input(Predictions)'s 2nd dimension is " - "equal to class number and should be at least 1."); // Layouts of BatchMetrics and AccumMetrics both are: // [ @@ -72,13 +76,13 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { // Shape of AccumStatesInfo is [class_number, 4] // The layout of each row is: // [ TP, FP, TN, FN ] - ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4}); + ctx->SetOutputDim("AccumStatesInfo", {cls_num, 4}); } protected: framework::DataType IndicateDataType( const framework::ExecutionContext &ctx) const override { - return framework::ToDataType(ctx.Input("Predictions")->type()); + return framework::ToDataType(ctx.Input("MaxProbs")->type()); } }; @@ -87,11 +91,15 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { PrecisionRecallOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Predictions", - "(Tensor, default Tensor), a 2-D tensor with shape N x D, " - "where N is the batch size and D is the number of classes. " - "Each row contains probabilities for an instance which computed " - "by the previous operator."); + AddInput("MaxProbs", + "(Tensor, default Tensor), a 2-D tensor with shape N x 1, " + "where N is the batch size. Each row contains the max probability " + "of an instance which computed by the previous top_k (k=1) " + "operator."); + AddInput("Indices", + "(Tensor, default Tensor), a 2-D tensor with shape N x 1, " + "where N is the batch size. Each row contains the corresponding " + "index which computed by the previous top_k (k=1) operator."); AddInput("Labels", "(Tensor, default Tensor), a 2-D tensor with shape N x 1, " "where N is the batch size. Each element is a label and the " @@ -125,9 +133,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { "accumulated state variables used to compute metrics. The layout " "for each class is [true positives, false positives, " "true negatives, false negatives]."); - + AddAttr("class_number", "Number of classes to be evaluated."); AddComment(R"DOC( -When given 'Input(Predictions)' and 'Input(Labels)', this operator can be used +When given 'Input(Indices)' and 'Input(Labels)', this operator can be used to compute various metrics including: - macro average precision - macro average recall @@ -141,7 +149,7 @@ false positives and false negatives. Here count of true negatives is not necessary, but counting it may provide potential usage and the cost is trivial, so the operator also provides count of true negatives. -We define state as a 2-D tensor with shape [class number, 4]. Each row of a +We define state as a 2-D tensor with shape [class_number, 4]. Each row of a state contains statistic variables for corresponding class. Layout of each row is: TP(true positives), FP(false positives), TN(true negatives), FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be diff --git a/paddle/operators/precision_recall_op.h b/paddle/operators/precision_recall_op.h index 2e49bc3bb..4a871ce67 100644 --- a/paddle/operators/precision_recall_op.h +++ b/paddle/operators/precision_recall_op.h @@ -30,7 +30,7 @@ template class PrecisionRecallKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in0 = ctx.Input("Predictions"); + auto* in0 = ctx.Input("Indices"); auto* in1 = ctx.Input("Labels"); auto* in2 = ctx.Input("Weights"); auto* in3 = ctx.Input("StatesInfo"); @@ -38,8 +38,9 @@ class PrecisionRecallKernel : public framework::OpKernel { auto* out1 = ctx.Output("AccumMetrics"); auto* out2 = ctx.Output("AccumStatesInfo"); - const T* predictions_data = in0->data(); + const int* ids_data = in0->data(); const int* labels_data = in1->data(); + size_t cls_num = static_cast(ctx.Attr("class_number")); const T* weights_data = in2 ? in2->data() : nullptr; const T* states_data = in3 ? in3->data() : nullptr; double* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); @@ -50,43 +51,42 @@ class PrecisionRecallKernel : public framework::OpKernel { T* accum_states_data = out2->data(); size_t sample_num = in0->dims()[0]; - size_t class_dim = in0->dims()[1]; size_t state_var_num = 4; // TP FP TN FN // get states info for current batch for (size_t i = 0; i < sample_num; ++i) { - size_t max_idx = 0; - T max_val = predictions_data[i * class_dim]; - for (size_t j = 1; j < class_dim; ++j) { - if (max_val < predictions_data[i * class_dim + j]) { - max_idx = j; - max_val = predictions_data[i * class_dim + j]; - } - } + size_t idx = ids_data[i]; + size_t label = labels_data[i]; + + PADDLE_ENFORCE(idx >= 0 && idx < cls_num, + "Class index of each instance should be in " + "[0, class_number)."); + PADDLE_ENFORCE(label >= 0 && label < cls_num, + "Label of each instance should be in [0, class_number)."); T w = weights_data ? weights_data[i] : 1.0; - if (max_idx == labels_data[i]) { - accum_states_data[max_idx * state_var_num + TP] += w; - for (size_t j = 0; j < class_dim; ++j) { + if (idx == label) { + accum_states_data[idx * state_var_num + TP] += w; + for (size_t j = 0; j < cls_num; ++j) { accum_states_data[j * state_var_num + TN] += w; } - accum_states_data[max_idx * state_var_num + TN] -= w; + accum_states_data[idx * state_var_num + TN] -= w; } else { - accum_states_data[labels_data[i] * state_var_num + FN] += w; - accum_states_data[max_idx * state_var_num + FP] += w; - for (size_t j = 0; j < class_dim; ++j) { + accum_states_data[label * state_var_num + FN] += w; + accum_states_data[idx * state_var_num + FP] += w; + for (size_t j = 0; j < cls_num; ++j) { accum_states_data[j * state_var_num + TN] += w; } - accum_states_data[max_idx * state_var_num + TN] -= w; - accum_states_data[labels_data[i] * state_var_num + TN] -= w; + accum_states_data[idx * state_var_num + TN] -= w; + accum_states_data[label * state_var_num + TN] -= w; } } ComputeMetrics(accum_states_data, batch_metrics_data, state_var_num, - class_dim); + cls_num); if (states_data) { - for (size_t i = 0; i < class_dim; ++i) { + for (size_t i = 0; i < cls_num; ++i) { for (size_t j = 0; j < state_var_num; ++j) { size_t idx = i * state_var_num + j; accum_states_data[idx] += states_data[idx]; @@ -95,7 +95,7 @@ class PrecisionRecallKernel : public framework::OpKernel { } ComputeMetrics(accum_states_data, accum_metrics_data, state_var_num, - class_dim); + cls_num); } // expose to be reused @@ -122,14 +122,14 @@ class PrecisionRecallKernel : public framework::OpKernel { protected: void ComputeMetrics(const T* states_data, double* metrics_data, - size_t state_var_num, size_t class_dim) const { + size_t state_var_num, size_t cls_num) const { T total_tp_count = 0; T total_fp_count = 0; T total_fn_count = 0; T macro_avg_precision = 0.0; T macro_avg_recall = 0.0; - for (size_t i = 0; i < class_dim; ++i) { + for (size_t i = 0; i < cls_num; ++i) { T tp_count = states_data[i * state_var_num + TP]; T fp_count = states_data[i * state_var_num + FP]; T fn_count = states_data[i * state_var_num + FN]; @@ -139,8 +139,8 @@ class PrecisionRecallKernel : public framework::OpKernel { macro_avg_precision += CalcPrecision(tp_count, fp_count); macro_avg_recall += CalcRecall(tp_count, fn_count); } - macro_avg_precision /= class_dim; - macro_avg_recall /= class_dim; + macro_avg_precision /= cls_num; + macro_avg_recall /= cls_num; T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall); T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count); diff --git a/python/paddle/v2/framework/tests/test_precision_recall_op.py b/python/paddle/v2/framework/tests/test_precision_recall_op.py index 33efd717d..d3dbdb6e2 100644 --- a/python/paddle/v2/framework/tests/test_precision_recall_op.py +++ b/python/paddle/v2/framework/tests/test_precision_recall_op.py @@ -21,45 +21,44 @@ def calc_f1_score(precision, recall): return 0.0 -def get_states(predictions, labels, weights=None): - ins_num = predictions.shape[0] - class_num = predictions.shape[1] +def get_states(idxs, labels, cls_num, weights=None): + ins_num = idxs.shape[0] # TP FP TN FN - states = np.zeros((class_num, 4)).astype('float32') + states = np.zeros((cls_num, 4)).astype('float32') for i in xrange(ins_num): w = weights[i] if weights is not None else 1.0 - max_idx = np.argmax(predictions[i]) - if max_idx == labels[i][0]: - states[max_idx][0] += w - for j in xrange(class_num): + idx = idxs[i][0] + label = labels[i][0] + if idx == label: + states[idx][0] += w + for j in xrange(cls_num): states[j][2] += w - states[max_idx][2] -= w + states[idx][2] -= w else: - states[labels[i][0]][3] += w - states[max_idx][1] += w - for j in xrange(class_num): + states[label][3] += w + states[idx][1] += w + for j in xrange(cls_num): states[j][2] += w - states[labels[i][0]][2] -= w - states[max_idx][2] -= w + states[label][2] -= w + states[idx][2] -= w return states -def compute_metrics(states): - class_num = states.shape[0] +def compute_metrics(states, cls_num): total_tp_count = 0.0 total_fp_count = 0.0 total_fn_count = 0.0 macro_avg_precision = 0.0 macro_avg_recall = 0.0 - for i in xrange(class_num): + for i in xrange(cls_num): total_tp_count += states[i][0] total_fp_count += states[i][1] total_fn_count += states[i][3] macro_avg_precision += calc_precision(states[i][0], states[i][1]) macro_avg_recall += calc_recall(states[i][0], states[i][3]) metrics = [] - macro_avg_precision /= class_num - macro_avg_recall /= class_num + macro_avg_precision /= cls_num + macro_avg_recall /= cls_num metrics.append(macro_avg_precision) metrics.append(macro_avg_recall) metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall)) @@ -75,15 +74,18 @@ class TestPrecisionRecallOp_0(OpTest): def setUp(self): self.op_type = "precision_recall" ins_num = 64 - class_num = 10 - predictions = np.random.uniform(0, 1.0, - (ins_num, class_num)).astype('float32') - labels = np.random.choice(xrange(class_num), ins_num).reshape( + cls_num = 10 + max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + idxs = np.random.choice(xrange(cls_num), ins_num).reshape( (ins_num, 1)).astype('int32') - states = get_states(predictions, labels) - metrics = compute_metrics(states) + labels = np.random.choice(xrange(cls_num), ins_num).reshape( + (ins_num, 1)).astype('int32') + states = get_states(idxs, labels, cls_num) + metrics = compute_metrics(states, cls_num) + + self.attrs = {'class_number': cls_num} - self.inputs = {'Predictions': predictions, 'Labels': labels} + self.inputs = {'MaxProbs': max_probs, 'Indices': idxs, 'Labels': labels} self.outputs = { 'BatchMetrics': metrics, @@ -99,18 +101,22 @@ class TestPrecisionRecallOp_1(OpTest): def setUp(self): self.op_type = "precision_recall" ins_num = 64 - class_num = 10 - predictions = np.random.uniform(0, 1.0, - (ins_num, class_num)).astype('float32') + cls_num = 10 + max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + idxs = np.random.choice(xrange(cls_num), ins_num).reshape( + (ins_num, 1)).astype('int32') weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') - predictions = np.random.random((ins_num, class_num)).astype('float32') - labels = np.random.choice(xrange(class_num), ins_num).reshape( + labels = np.random.choice(xrange(cls_num), ins_num).reshape( (ins_num, 1)).astype('int32') - states = get_states(predictions, labels, weights) - metrics = compute_metrics(states) + states = get_states(idxs, labels, cls_num, weights) + metrics = compute_metrics(states, cls_num) + + self.attrs = {'class_number': cls_num} + self.inputs = { - 'Predictions': predictions, + 'MaxProbs': max_probs, + 'Indices': idxs, 'Labels': labels, 'Weights': weights } @@ -129,22 +135,25 @@ class TestPrecisionRecallOp_2(OpTest): def setUp(self): self.op_type = "precision_recall" ins_num = 64 - class_num = 10 - predictions = np.random.uniform(0, 1.0, - (ins_num, class_num)).astype('float32') + cls_num = 10 + max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + idxs = np.random.choice(xrange(cls_num), ins_num).reshape( + (ins_num, 1)).astype('int32') weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') - predictions = np.random.random((ins_num, class_num)).astype('float32') - labels = np.random.choice(xrange(class_num), ins_num).reshape( + labels = np.random.choice(xrange(cls_num), ins_num).reshape( (ins_num, 1)).astype('int32') - states = np.random.randint(0, 30, (class_num, 4)).astype('float32') + states = np.random.randint(0, 30, (cls_num, 4)).astype('float32') - accum_states = get_states(predictions, labels, weights) - batch_metrics = compute_metrics(accum_states) + accum_states = get_states(idxs, labels, cls_num, weights) + batch_metrics = compute_metrics(accum_states, cls_num) accum_states += states - accum_metrics = compute_metrics(accum_states) + accum_metrics = compute_metrics(accum_states, cls_num) + + self.attrs = {'class_number': cls_num} self.inputs = { - 'Predictions': predictions, + 'MaxProbs': max_probs, + 'Indices': idxs, 'Labels': labels, 'Weights': weights, 'StatesInfo': states -- GitLab