提交 970613fc 编写于 作者: Y yangyaming

Refine and follow comments.

上级 d2b10cc0
......@@ -22,8 +22,10 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Predictions"),
"Input(Predictions) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("MaxProbs"),
"Input(MaxProbs) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Indices"),
"Input(Indices) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"),
......@@ -33,34 +35,36 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"),
"Output(AccumStatesInfo) should not be null.");
auto predictions_dims = ctx->GetInputDim("Predictions");
int64_t cls_num =
static_cast<int64_t>(ctx->Attrs().Get<int>("class_number"));
auto max_probs_dims = ctx->GetInputDim("MaxProbs");
auto labels_dims = ctx->GetInputDim("Labels");
PADDLE_ENFORCE_EQ(max_probs_dims[1], 1,
"Each instance contains one max probability, so the "
"shape of Input(MaxProbs) should be [batch_size, 1].");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Indices"), max_probs_dims,
"The shape of Input(Indices) should be [batch_size, 1].");
PADDLE_ENFORCE_EQ(max_probs_dims[0], labels_dims[0],
"The 1st dimension of Input(MaxProbs) and "
"Input(Labels) both are batch_size and the shape should "
"be the same.");
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
"The 2nd dimension of Input(Labels) contains instance "
"label and the shape should be equal to 1.");
if (ctx->HasInput("Weights")) {
auto weights_dims = ctx->GetInputDim("Weights");
PADDLE_ENFORCE_EQ(weights_dims,
framework::make_ddim({predictions_dims[0], 1}),
framework::make_ddim({max_probs_dims[0], 1}),
"The shape of Input(Weights) should be "
"[batch_size, 1].");
}
if (ctx->HasInput("StatesInfo")) {
auto states_dims = ctx->GetInputDim("StatesInfo");
PADDLE_ENFORCE_EQ(states_dims,
framework::make_ddim({predictions_dims[1], 4}),
PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}),
"The shape of Input(StatesInfo) should be "
"[class_number, 4].");
}
PADDLE_ENFORCE_EQ(predictions_dims[0], labels_dims[0],
"The 1st dimension of Input(Predictions) and "
"Input(Labels) both are batch_size and the shape should "
"be the same.");
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
"The 2nd dimension of Input(Labels) "
"contains instance label and the shape should be equal "
"to 1");
PADDLE_ENFORCE_GE(predictions_dims[1], 1,
"The shape of Input(Predictions)'s 2nd dimension is "
"equal to class number and should be at least 1.");
// BatchMetrics and AccumMetrics share the same layout:
// [
......@@ -72,13 +76,13 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
// Shape of AccumStatesInfo is [class_number, 4]
// The layout of each row is:
// [ TP, FP, TN, FN ]
ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4});
ctx->SetOutputDim("AccumStatesInfo", {cls_num, 4});
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Predictions")->type());
return framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type());
}
};
......@@ -87,11 +91,15 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
PrecisionRecallOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Predictions",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
"where N is the batch size and D is the number of classes. "
"Each row contains probabilities for an instance which computed "
"by the previous operator.");
AddInput("MaxProbs",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the max probability "
"of an instance which computed by the previous top_k (k=1) "
"operator.");
AddInput("Indices",
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the corresponding "
"index which computed by the previous top_k (k=1) operator.");
AddInput("Labels",
"(Tensor, default Tensor<int>), a 2-D tensor with shape N x 1, "
"where N is the batch size. Each element is a label and the "
......@@ -125,9 +133,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
"accumulated state variables used to compute metrics. The layout "
"for each class is [true positives, false positives, "
"true negatives, false negatives].");
AddAttr<int>("class_number", "Number of classes to be evaluated.");
AddComment(R"DOC(
When given 'Input(Predictions)' and 'Input(Labels)', this operator can be used
When given 'Input(Indices)' and 'Input(Labels)', this operator can be used
to compute various metrics including:
- macro average precision
- macro average recall
......@@ -141,7 +149,7 @@ false positives and false negatives. Here count of true negatives is not
necessary, but counting it may provide potential usage and the cost is
trivial, so the operator also provides count of true negatives.
We define state as a 2-D tensor with shape [class number, 4]. Each row of a
We define state as a 2-D tensor with shape [class_number, 4]. Each row of a
state contains statistic variables for corresponding class. Layout of each row
is: TP(true positives), FP(false positives), TN(true negatives),
FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be
......
......@@ -30,7 +30,7 @@ template <typename Place, typename T>
class PrecisionRecallKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in0 = ctx.Input<Tensor>("Predictions");
auto* in0 = ctx.Input<Tensor>("Indices");
auto* in1 = ctx.Input<Tensor>("Labels");
auto* in2 = ctx.Input<Tensor>("Weights");
auto* in3 = ctx.Input<Tensor>("StatesInfo");
......@@ -38,8 +38,9 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
auto* out1 = ctx.Output<Tensor>("AccumMetrics");
auto* out2 = ctx.Output<Tensor>("AccumStatesInfo");
const T* predictions_data = in0->data<T>();
const int* ids_data = in0->data<int>();
const int* labels_data = in1->data<int>();
size_t cls_num = static_cast<size_t>(ctx.Attr<int>("class_number"));
const T* weights_data = in2 ? in2->data<T>() : nullptr;
const T* states_data = in3 ? in3->data<T>() : nullptr;
double* batch_metrics_data = out0->mutable_data<double>(ctx.GetPlace());
......@@ -50,43 +51,42 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
T* accum_states_data = out2->data<T>();
size_t sample_num = in0->dims()[0];
size_t class_dim = in0->dims()[1];
size_t state_var_num = 4; // TP FP TN FN
// get states info for current batch
for (size_t i = 0; i < sample_num; ++i) {
size_t max_idx = 0;
T max_val = predictions_data[i * class_dim];
for (size_t j = 1; j < class_dim; ++j) {
if (max_val < predictions_data[i * class_dim + j]) {
max_idx = j;
max_val = predictions_data[i * class_dim + j];
}
}
size_t idx = ids_data[i];
size_t label = labels_data[i];
PADDLE_ENFORCE(idx >= 0 && idx < cls_num,
"Class index of each instance should be in "
"[0, class_number).");
PADDLE_ENFORCE(label >= 0 && label < cls_num,
"Label of each instance should be in [0, class_number).");
T w = weights_data ? weights_data[i] : 1.0;
if (max_idx == labels_data[i]) {
accum_states_data[max_idx * state_var_num + TP] += w;
for (size_t j = 0; j < class_dim; ++j) {
if (idx == label) {
accum_states_data[idx * state_var_num + TP] += w;
for (size_t j = 0; j < cls_num; ++j) {
accum_states_data[j * state_var_num + TN] += w;
}
accum_states_data[max_idx * state_var_num + TN] -= w;
accum_states_data[idx * state_var_num + TN] -= w;
} else {
accum_states_data[labels_data[i] * state_var_num + FN] += w;
accum_states_data[max_idx * state_var_num + FP] += w;
for (size_t j = 0; j < class_dim; ++j) {
accum_states_data[label * state_var_num + FN] += w;
accum_states_data[idx * state_var_num + FP] += w;
for (size_t j = 0; j < cls_num; ++j) {
accum_states_data[j * state_var_num + TN] += w;
}
accum_states_data[max_idx * state_var_num + TN] -= w;
accum_states_data[labels_data[i] * state_var_num + TN] -= w;
accum_states_data[idx * state_var_num + TN] -= w;
accum_states_data[label * state_var_num + TN] -= w;
}
}
ComputeMetrics(accum_states_data, batch_metrics_data, state_var_num,
class_dim);
cls_num);
if (states_data) {
for (size_t i = 0; i < class_dim; ++i) {
for (size_t i = 0; i < cls_num; ++i) {
for (size_t j = 0; j < state_var_num; ++j) {
size_t idx = i * state_var_num + j;
accum_states_data[idx] += states_data[idx];
......@@ -95,7 +95,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
}
ComputeMetrics(accum_states_data, accum_metrics_data, state_var_num,
class_dim);
cls_num);
}
// expose to be reused
......@@ -122,14 +122,14 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
protected:
void ComputeMetrics(const T* states_data, double* metrics_data,
size_t state_var_num, size_t class_dim) const {
size_t state_var_num, size_t cls_num) const {
T total_tp_count = 0;
T total_fp_count = 0;
T total_fn_count = 0;
T macro_avg_precision = 0.0;
T macro_avg_recall = 0.0;
for (size_t i = 0; i < class_dim; ++i) {
for (size_t i = 0; i < cls_num; ++i) {
T tp_count = states_data[i * state_var_num + TP];
T fp_count = states_data[i * state_var_num + FP];
T fn_count = states_data[i * state_var_num + FN];
......@@ -139,8 +139,8 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
macro_avg_precision += CalcPrecision(tp_count, fp_count);
macro_avg_recall += CalcRecall(tp_count, fn_count);
}
macro_avg_precision /= class_dim;
macro_avg_recall /= class_dim;
macro_avg_precision /= cls_num;
macro_avg_recall /= cls_num;
T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall);
T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count);
......
......@@ -21,45 +21,44 @@ def calc_f1_score(precision, recall):
return 0.0
def get_states(predictions, labels, weights=None):
ins_num = predictions.shape[0]
class_num = predictions.shape[1]
def get_states(idxs, labels, cls_num, weights=None):
ins_num = idxs.shape[0]
# TP FP TN FN
states = np.zeros((class_num, 4)).astype('float32')
states = np.zeros((cls_num, 4)).astype('float32')
for i in xrange(ins_num):
w = weights[i] if weights is not None else 1.0
max_idx = np.argmax(predictions[i])
if max_idx == labels[i][0]:
states[max_idx][0] += w
for j in xrange(class_num):
idx = idxs[i][0]
label = labels[i][0]
if idx == label:
states[idx][0] += w
for j in xrange(cls_num):
states[j][2] += w
states[max_idx][2] -= w
states[idx][2] -= w
else:
states[labels[i][0]][3] += w
states[max_idx][1] += w
for j in xrange(class_num):
states[label][3] += w
states[idx][1] += w
for j in xrange(cls_num):
states[j][2] += w
states[labels[i][0]][2] -= w
states[max_idx][2] -= w
states[label][2] -= w
states[idx][2] -= w
return states
def compute_metrics(states):
class_num = states.shape[0]
def compute_metrics(states, cls_num):
total_tp_count = 0.0
total_fp_count = 0.0
total_fn_count = 0.0
macro_avg_precision = 0.0
macro_avg_recall = 0.0
for i in xrange(class_num):
for i in xrange(cls_num):
total_tp_count += states[i][0]
total_fp_count += states[i][1]
total_fn_count += states[i][3]
macro_avg_precision += calc_precision(states[i][0], states[i][1])
macro_avg_recall += calc_recall(states[i][0], states[i][3])
metrics = []
macro_avg_precision /= class_num
macro_avg_recall /= class_num
macro_avg_precision /= cls_num
macro_avg_recall /= cls_num
metrics.append(macro_avg_precision)
metrics.append(macro_avg_recall)
metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall))
......@@ -75,15 +74,18 @@ class TestPrecisionRecallOp_0(OpTest):
def setUp(self):
self.op_type = "precision_recall"
ins_num = 64
class_num = 10
predictions = np.random.uniform(0, 1.0,
(ins_num, class_num)).astype('float32')
labels = np.random.choice(xrange(class_num), ins_num).reshape(
cls_num = 10
max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
idxs = np.random.choice(xrange(cls_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
states = get_states(predictions, labels)
metrics = compute_metrics(states)
labels = np.random.choice(xrange(cls_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
states = get_states(idxs, labels, cls_num)
metrics = compute_metrics(states, cls_num)
self.attrs = {'class_number': cls_num}
self.inputs = {'Predictions': predictions, 'Labels': labels}
self.inputs = {'MaxProbs': max_probs, 'Indices': idxs, 'Labels': labels}
self.outputs = {
'BatchMetrics': metrics,
......@@ -99,18 +101,22 @@ class TestPrecisionRecallOp_1(OpTest):
def setUp(self):
self.op_type = "precision_recall"
ins_num = 64
class_num = 10
predictions = np.random.uniform(0, 1.0,
(ins_num, class_num)).astype('float32')
cls_num = 10
max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
idxs = np.random.choice(xrange(cls_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
predictions = np.random.random((ins_num, class_num)).astype('float32')
labels = np.random.choice(xrange(class_num), ins_num).reshape(
labels = np.random.choice(xrange(cls_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
states = get_states(predictions, labels, weights)
metrics = compute_metrics(states)
states = get_states(idxs, labels, cls_num, weights)
metrics = compute_metrics(states, cls_num)
self.attrs = {'class_number': cls_num}
self.inputs = {
'Predictions': predictions,
'MaxProbs': max_probs,
'Indices': idxs,
'Labels': labels,
'Weights': weights
}
......@@ -129,22 +135,25 @@ class TestPrecisionRecallOp_2(OpTest):
def setUp(self):
self.op_type = "precision_recall"
ins_num = 64
class_num = 10
predictions = np.random.uniform(0, 1.0,
(ins_num, class_num)).astype('float32')
cls_num = 10
max_probs = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
idxs = np.random.choice(xrange(cls_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
predictions = np.random.random((ins_num, class_num)).astype('float32')
labels = np.random.choice(xrange(class_num), ins_num).reshape(
labels = np.random.choice(xrange(cls_num), ins_num).reshape(
(ins_num, 1)).astype('int32')
states = np.random.randint(0, 30, (class_num, 4)).astype('float32')
states = np.random.randint(0, 30, (cls_num, 4)).astype('float32')
accum_states = get_states(predictions, labels, weights)
batch_metrics = compute_metrics(accum_states)
accum_states = get_states(idxs, labels, cls_num, weights)
batch_metrics = compute_metrics(accum_states, cls_num)
accum_states += states
accum_metrics = compute_metrics(accum_states)
accum_metrics = compute_metrics(accum_states, cls_num)
self.attrs = {'class_number': cls_num}
self.inputs = {
'Predictions': predictions,
'MaxProbs': max_probs,
'Indices': idxs,
'Labels': labels,
'Weights': weights,
'StatesInfo': states
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册