提交 970613fc 编写于 作者: Y yangyaming

Refine and follow comments.

上级 d2b10cc0
...@@ -22,8 +22,10 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { ...@@ -22,8 +22,10 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Predictions"), PADDLE_ENFORCE(ctx->HasInput("MaxProbs"),
"Input(Predictions) should not be null."); "Input(MaxProbs) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Indices"),
"Input(Indices) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Labels"), PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) should not be null."); "Input(Labels) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"), PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"),
...@@ -33,34 +35,36 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { ...@@ -33,34 +35,36 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"), PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"),
"Output(AccumStatesInfo) should not be null."); "Output(AccumStatesInfo) should not be null.");
auto predictions_dims = ctx->GetInputDim("Predictions"); int64_t cls_num =
static_cast<int64_t>(ctx->Attrs().Get<int>("class_number"));
auto max_probs_dims = ctx->GetInputDim("MaxProbs");
auto labels_dims = ctx->GetInputDim("Labels"); auto labels_dims = ctx->GetInputDim("Labels");
PADDLE_ENFORCE_EQ(max_probs_dims[1], 1,
"Each instance contains one max probability, so the "
"shape of Input(MaxProbs) should be [batch_size, 1].");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Indices"), max_probs_dims,
"The shape of Input(Indices) should be [batch_size, 1].");
PADDLE_ENFORCE_EQ(max_probs_dims[0], labels_dims[0],
"The 1st dimension of Input(MaxProbs) and "
"Input(Labels) both are batch_size and the shape should "
"be the same.");
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
"The 2nd dimension of Input(Labels) contains instance "
"label and the shape should be equal to 1.");
if (ctx->HasInput("Weights")) { if (ctx->HasInput("Weights")) {
auto weights_dims = ctx->GetInputDim("Weights"); auto weights_dims = ctx->GetInputDim("Weights");
PADDLE_ENFORCE_EQ(weights_dims, PADDLE_ENFORCE_EQ(weights_dims,
framework::make_ddim({predictions_dims[0], 1}), framework::make_ddim({max_probs_dims[0], 1}),
"The shape of Input(Weights) should be " "The shape of Input(Weights) should be "
"[batch_size, 1]."); "[batch_size, 1].");
} }
if (ctx->HasInput("StatesInfo")) { if (ctx->HasInput("StatesInfo")) {
auto states_dims = ctx->GetInputDim("StatesInfo"); auto states_dims = ctx->GetInputDim("StatesInfo");
PADDLE_ENFORCE_EQ(states_dims, PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}),
framework::make_ddim({predictions_dims[1], 4}),
"The shape of Input(StatesInfo) should be " "The shape of Input(StatesInfo) should be "
"[class_number, 4]."); "[class_number, 4].");
} }
PADDLE_ENFORCE_EQ(predictions_dims[0], labels_dims[0],
"The 1st dimension of Input(Predictions) and "
"Input(Labels) both are batch_size and the shape should "
"be the same.");
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
"The 2nd dimension of Input(Labels) "
"contains instance label and the shape should be equal "
"to 1");
PADDLE_ENFORCE_GE(predictions_dims[1], 1,
"The shape of Input(Predictions)'s 2nd dimension is "
"equal to class number and should be at least 1.");
// Layouts of BatchMetrics and AccumMetrics both are: // Layouts of BatchMetrics and AccumMetrics both are:
// [ // [
...@@ -72,13 +76,13 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { ...@@ -72,13 +76,13 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
// Shape of AccumStatesInfo is [class_number, 4] // Shape of AccumStatesInfo is [class_number, 4]
// The layout of each row is: // The layout of each row is:
// [ TP, FP, TN, FN ] // [ TP, FP, TN, FN ]
ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4}); ctx->SetOutputDim("AccumStatesInfo", {cls_num, 4});
} }
protected: protected:
framework::DataType IndicateDataType( framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Predictions")->type()); return framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type());
} }
}; };
...@@ -87,11 +91,15 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -87,11 +91,15 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
PrecisionRecallOpMaker(framework::OpProto *proto, PrecisionRecallOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Predictions", AddInput("MaxProbs",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, " "(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
"where N is the batch size and D is the number of classes. " "where N is the batch size. Each row contains the max probability "
"Each row contains probabilities for an instance which computed " "of an instance which computed by the previous top_k (k=1) "
"by the previous operator."); "operator.");
AddInput("Indices",
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the corresponding "
"index which computed by the previous top_k (k=1) operator.");
AddInput("Labels", AddInput("Labels",
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, " "(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each element is a label and the " "where N is the batch size. Each element is a label and the "
...@@ -125,9 +133,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -125,9 +133,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
"accumulated state variables used to compute metrics. The layout " "accumulated state variables used to compute metrics. The layout "
"for each class is [true positives, false positives, " "for each class is [true positives, false positives, "
"true negatives, false negatives]."); "true negatives, false negatives].");
AddAttr<int>("class_number", "Number of classes to be evaluated.");
AddComment(R"DOC( AddComment(R"DOC(
When given 'Input(Predictions)' and 'Input(Labels)', this operator can be used When given 'Input(Indices)' and 'Input(Labels)', this operator can be used
to compute various metrics including: to compute various metrics including:
- macro average precision - macro average precision
- macro average recall - macro average recall
...@@ -141,7 +149,7 @@ false positives and false negatives. Here count of true negatives is not ...@@ -141,7 +149,7 @@ false positives and false negatives. Here count of true negatives is not
necessary, but counting it may provide potential usage and the cost is necessary, but counting it may provide potential usage and the cost is
trivial, so the operator also provides count of true negatives. trivial, so the operator also provides count of true negatives.
We define state as a 2-D tensor with shape [class number, 4]. Each row of a We define state as a 2-D tensor with shape [class_number, 4]. Each row of a
state contains statistic variables for corresponding class. Layout of each row state contains statistic variables for corresponding class. Layout of each row
is: TP(true positives), FP(false positives), TN(true negatives), is: TP(true positives), FP(false positives), TN(true negatives),
FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be
......
...@@ -30,7 +30,7 @@ template <typename Place, typename T> ...@@ -30,7 +30,7 @@ template <typename Place, typename T>
class PrecisionRecallKernel : public framework::OpKernel<T> { class PrecisionRecallKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in0 = ctx.Input<Tensor>("Predictions"); auto* in0 = ctx.Input<Tensor>("Indices");
auto* in1 = ctx.Input<Tensor>("Labels"); auto* in1 = ctx.Input<Tensor>("Labels");
auto* in2 = ctx.Input<Tensor>("Weights"); auto* in2 = ctx.Input<Tensor>("Weights");
auto* in3 = ctx.Input<Tensor>("StatesInfo"); auto* in3 = ctx.Input<Tensor>("StatesInfo");
...@@ -38,8 +38,9 @@ class PrecisionRecallKernel : public framework::OpKernel<T> { ...@@ -38,8 +38,9 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
auto* out1 = ctx.Output<Tensor>("AccumMetrics"); auto* out1 = ctx.Output<Tensor>("AccumMetrics");
auto* out2 = ctx.Output<Tensor>("AccumStatesInfo"); auto* out2 = ctx.Output<Tensor>("AccumStatesInfo");
const T* predictions_data = in0->data<T>(); const int* ids_data = in0->data<int>();
const int* labels_data = in1->data<int>(); const int* labels_data = in1->data<int>();
size_t cls_num = static_cast<size_t>(ctx.Attr<int>("class_number"));
const T* weights_data = in2 ? in2->data<T>() : nullptr; const T* weights_data = in2 ? in2->data<T>() : nullptr;
const T* states_data = in3 ? in3->data<T>() : nullptr; const T* states_data = in3 ? in3->data<T>() : nullptr;
double* batch_metrics_data = out0->mutable_data<double>(ctx.GetPlace()); double* batch_metrics_data = out0->mutable_data<double>(ctx.GetPlace());
...@@ -50,43 +51,42 @@ class PrecisionRecallKernel : public framework::OpKernel<T> { ...@@ -50,43 +51,42 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
T* accum_states_data = out2->data<T>(); T* accum_states_data = out2->data<T>();
size_t sample_num = in0->dims()[0]; size_t sample_num = in0->dims()[0];
size_t class_dim = in0->dims()[1];
size_t state_var_num = 4; // TP FP TN FN size_t state_var_num = 4; // TP FP TN FN
// get states info for current batch // get states info for current batch
for (size_t i = 0; i < sample_num; ++i) { for (size_t i = 0; i < sample_num; ++i) {
size_t max_idx = 0; size_t idx = ids_data[i];
T max_val = predictions_data[i * class_dim]; size_t label = labels_data[i];
for (size_t j = 1; j < class_dim; ++j) {
if (max_val < predictions_data[i * class_dim + j]) { PADDLE_ENFORCE(idx >= 0 && idx < cls_num,
max_idx = j; "Class index of each instance should be in "
max_val = predictions_data[i * class_dim + j]; "[0, class_number).");
} PADDLE_ENFORCE(label >= 0 && label < cls_num,
} "Label of each instance should be in [0, class_number).");
T w = weights_data ? weights_data[i] : 1.0; T w = weights_data ? weights_data[i] : 1.0;
if (max_idx == labels_data[i]) { if (idx == label) {
accum_states_data[max_idx * state_var_num + TP] += w; accum_states_data[idx * state_var_num + TP] += w;
for (size_t j = 0; j < class_dim; ++j) { for (size_t j = 0; j < cls_num; ++j) {
accum_states_data[j * state_var_num + TN] += w; accum_states_data[j * state_var_num + TN] += w;
} }
accum_states_data[max_idx * state_var_num + TN] -= w; accum_states_data[idx * state_var_num + TN] -= w;
} else { } else {
accum_states_data[labels_data[i] * state_var_num + FN] += w; accum_states_data[label * state_var_num + FN] += w;
accum_states_data[max_idx * state_var_num + FP] += w; accum_states_data[idx * state_var_num + FP] += w;
for (size_t j = 0; j < class_dim; ++j) { for (size_t j = 0; j < cls_num; ++j) {
accum_states_data[j * state_var_num + TN] += w; accum_states_data[j * state_var_num + TN] += w;
} }
accum_states_data[max_idx * state_var_num + TN] -= w; accum_states_data[idx * state_var_num + TN] -= w;
accum_states_data[labels_data[i] * state_var_num + TN] -= w; accum_states_data[label * state_var_num + TN] -= w;
} }
} }
ComputeMetrics(accum_states_data, batch_metrics_data, state_var_num, ComputeMetrics(accum_states_data, batch_metrics_data, state_var_num,
class_dim); cls_num);
if (states_data) { if (states_data) {
for (size_t i = 0; i < class_dim; ++i) { for (size_t i = 0; i < cls_num; ++i) {
for (size_t j = 0; j < state_var_num; ++j) { for (size_t j = 0; j < state_var_num; ++j) {
size_t idx = i * state_var_num + j; size_t idx = i * state_var_num + j;
accum_states_data[idx] += states_data[idx]; accum_states_data[idx] += states_data[idx];
...@@ -95,7 +95,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> { ...@@ -95,7 +95,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
} }
ComputeMetrics(accum_states_data, accum_metrics_data, state_var_num, ComputeMetrics(accum_states_data, accum_metrics_data, state_var_num,
class_dim); cls_num);
} }
// expose to be reused // expose to be reused
...@@ -122,14 +122,14 @@ class PrecisionRecallKernel : public framework::OpKernel<T> { ...@@ -122,14 +122,14 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
protected: protected:
void ComputeMetrics(const T* states_data, double* metrics_data, void ComputeMetrics(const T* states_data, double* metrics_data,
size_t state_var_num, size_t class_dim) const { size_t state_var_num, size_t cls_num) const {
T total_tp_count = 0; T total_tp_count = 0;
T total_fp_count = 0; T total_fp_count = 0;
T total_fn_count = 0; T total_fn_count = 0;
T macro_avg_precision = 0.0; T macro_avg_precision = 0.0;
T macro_avg_recall = 0.0; T macro_avg_recall = 0.0;
for (size_t i = 0; i < class_dim; ++i) { for (size_t i = 0; i < cls_num; ++i) {
T tp_count = states_data[i * state_var_num + TP]; T tp_count = states_data[i * state_var_num + TP];
T fp_count = states_data[i * state_var_num + FP]; T fp_count = states_data[i * state_var_num + FP];
T fn_count = states_data[i * state_var_num + FN]; T fn_count = states_data[i * state_var_num + FN];
...@@ -139,8 +139,8 @@ class PrecisionRecallKernel : public framework::OpKernel<T> { ...@@ -139,8 +139,8 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
macro_avg_precision += CalcPrecision(tp_count, fp_count); macro_avg_precision += CalcPrecision(tp_count, fp_count);
macro_avg_recall += CalcRecall(tp_count, fn_count); macro_avg_recall += CalcRecall(tp_count, fn_count);
} }
macro_avg_precision /= class_dim; macro_avg_precision /= cls_num;
macro_avg_recall /= class_dim; macro_avg_recall /= cls_num;
T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall); T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall);
T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count); T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count);
......
...@@ -21,45 +21,44 @@ def calc_f1_score(precision, recall): ...@@ -21,45 +21,44 @@ def calc_f1_score(precision, recall):
return 0.0 return 0.0
def get_states(predictions, labels, weights=None): def get_states(idxs, labels, cls_num, weights=None):
ins_num = predictions.shape[0] ins_num = idxs.shape[0]
class_num = predictions.shape[1]
# TP FP TN FN # TP FP TN FN
states = np.zeros((class_num, 4)).astype('float32') states = np.zeros((cls_num, 4)).astype('float32')
for i in xrange(ins_num): for i in xrange(ins_num):
w = weights[i] if weights is not None else 1.0 w = weights[i] if weights is not None else 1.0
max_idx = np.argmax(predictions[i]) idx = idxs[i][0]
if max_idx == labels[i][0]: label = labels[i][0]
states[max_idx][0] += w if idx == label:
for j in xrange(class_num): states[idx][0] += w
for j in xrange(cls_num):
states[j][2] += w states[j][2] += w
states[max_idx][2] -= w states[idx][2] -= w
else: else:
states[labels[i][0]][3] += w states[label][3] += w
states[max_idx][1] += w states[idx][1] += w
for j in xrange(class_num): for j in xrange(cls_num):
states[j][2] += w states[j][2] += w
states[labels[i][0]][2] -= w states[label][2] -= w
states[max_idx][2] -= w states[idx][2] -= w
return states return states
def compute_metrics(states): def compute_metrics(states, cls_num):
class_num = states.shape[0]
total_tp_count = 0.0 total_tp_count = 0.0
total_fp_count = 0.0 total_fp_count = 0.0
total_fn_count = 0.0 total_fn_count = 0.0
macro_avg_precision = 0.0 macro_avg_precision = 0.0
macro_avg_recall = 0.0 macro_avg_recall = 0.0
for i in xrange(class_num): for i in xrange(cls_num):
total_tp_count += states[i][0] total_tp_count += states[i][0]
total_fp_count += states[i][1] total_fp_count += states[i][1]
total_fn_count += states[i][3] total_fn_count += states[i][3]
macro_avg_precision += calc_precision(states[i][0], states[i][1]) macro_avg_precision += calc_precision(states[i][0], states[i][1])
macro_avg_recall += calc_recall(states[i][0], states[i][3]) macro_avg_recall += calc_recall(states[i][0], states[i][3])
metrics = [] metrics = []
macro_avg_precision /= class_num macro_avg_precision /= cls_num
macro_avg_recall /= class_num macro_avg_recall /= cls_num
metrics.append(macro_avg_precision) metrics.append(macro_avg_precision)
metrics.append(macro_avg_recall) metrics.append(macro_avg_recall)
metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall)) metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall))
...@@ -75,15 +74,18 @@ class TestPrecisionRecallOp_0(OpTest): ...@@ -75,15 +74,18 @@ class TestPrecisionRecallOp_0(OpTest):
def setUp(self): def setUp(self):
self.op_type = "precision_recall" self.op_type = "precision_recall"
ins_num = 64 ins_num = 64
class_num = 10 cls_num = 10
predictions = np.random.uniform(0, 1.0, max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
(ins_num, class_num)).astype('float32') idxs = np.random.choice(xrange(cls_num), ins_num).reshape(
labels = np.random.choice(xrange(class_num), ins_num).reshape(
(ins_num, 1)).astype('int32') (ins_num, 1)).astype('int32')
states = get_states(predictions, labels) labels = np.random.choice(xrange(cls_num), ins_num).reshape(
metrics = compute_metrics(states) (ins_num, 1)).astype('int32')
states = get_states(idxs, labels, cls_num)
metrics = compute_metrics(states, cls_num)
self.attrs = {'class_number': cls_num}
self.inputs = {'Predictions': predictions, 'Labels': labels} self.inputs = {'MaxProbs': max_probs, 'Indices': idxs, 'Labels': labels}
self.outputs = { self.outputs = {
'BatchMetrics': metrics, 'BatchMetrics': metrics,
...@@ -99,18 +101,22 @@ class TestPrecisionRecallOp_1(OpTest): ...@@ -99,18 +101,22 @@ class TestPrecisionRecallOp_1(OpTest):
def setUp(self): def setUp(self):
self.op_type = "precision_recall" self.op_type = "precision_recall"
ins_num = 64 ins_num = 64
class_num = 10 cls_num = 10
predictions = np.random.uniform(0, 1.0, max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
(ins_num, class_num)).astype('float32') idxs = np.random.choice(xrange(cls_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
predictions = np.random.random((ins_num, class_num)).astype('float32') labels = np.random.choice(xrange(cls_num), ins_num).reshape(
labels = np.random.choice(xrange(class_num), ins_num).reshape(
(ins_num, 1)).astype('int32') (ins_num, 1)).astype('int32')
states = get_states(predictions, labels, weights) states = get_states(idxs, labels, cls_num, weights)
metrics = compute_metrics(states) metrics = compute_metrics(states, cls_num)
self.attrs = {'class_number': cls_num}
self.inputs = { self.inputs = {
'Predictions': predictions, 'MaxProbs': max_probs,
'Indices': idxs,
'Labels': labels, 'Labels': labels,
'Weights': weights 'Weights': weights
} }
...@@ -129,22 +135,25 @@ class TestPrecisionRecallOp_2(OpTest): ...@@ -129,22 +135,25 @@ class TestPrecisionRecallOp_2(OpTest):
def setUp(self): def setUp(self):
self.op_type = "precision_recall" self.op_type = "precision_recall"
ins_num = 64 ins_num = 64
class_num = 10 cls_num = 10
predictions = np.random.uniform(0, 1.0, max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
(ins_num, class_num)).astype('float32') idxs = np.random.choice(xrange(cls_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
predictions = np.random.random((ins_num, class_num)).astype('float32') labels = np.random.choice(xrange(cls_num), ins_num).reshape(
labels = np.random.choice(xrange(class_num), ins_num).reshape(
(ins_num, 1)).astype('int32') (ins_num, 1)).astype('int32')
states = np.random.randint(0, 30, (class_num, 4)).astype('float32') states = np.random.randint(0, 30, (cls_num, 4)).astype('float32')
accum_states = get_states(predictions, labels, weights) accum_states = get_states(idxs, labels, cls_num, weights)
batch_metrics = compute_metrics(accum_states) batch_metrics = compute_metrics(accum_states, cls_num)
accum_states += states accum_states += states
accum_metrics = compute_metrics(accum_states) accum_metrics = compute_metrics(accum_states, cls_num)
self.attrs = {'class_number': cls_num}
self.inputs = { self.inputs = {
'Predictions': predictions, 'MaxProbs': max_probs,
'Indices': idxs,
'Labels': labels, 'Labels': labels,
'Weights': weights, 'Weights': weights,
'StatesInfo': states 'StatesInfo': states
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册