未验证 提交 27decacb 编写于 作者: H hutuxian 提交者: GitHub

fix aucop stat shape (#21846)

* fix stat shape back in global auc scenario
* add UT to cover global auc
上级 0d61653c
...@@ -49,12 +49,15 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -49,12 +49,15 @@ class AucOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AUC", {1}); ctx->SetOutputDim("AUC", {1});
// slide_steps = slide_steps == 0 ? 1 : slide_steps; if (slide_steps) {
int need_batch_id = slide_steps ? 1 : 0; ctx->SetOutputDim("StatPosOut",
ctx->SetOutputDim("StatPosOut", {(1 + slide_steps) * num_pred_buckets + 1});
{(1 + slide_steps) * num_pred_buckets + need_batch_id}); ctx->SetOutputDim("StatNegOut",
ctx->SetOutputDim("StatNegOut", {(1 + slide_steps) * num_pred_buckets + 1});
{(1 + slide_steps) * num_pred_buckets + need_batch_id}); } else {
ctx->SetOutputDim("StatPosOut", {1, num_pred_buckets});
ctx->SetOutputDim("StatNegOut", {1, num_pred_buckets});
}
} }
protected: protected:
......
...@@ -184,9 +184,9 @@ def auc(input, ...@@ -184,9 +184,9 @@ def auc(input,
# for global auc # for global auc
# Needn't maintain the batch id # Needn't maintain the batch id
stat_pos = helper.create_global_variable( stat_pos = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds + 1]) persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
stat_neg = helper.create_global_variable( stat_neg = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds + 1]) persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]: for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]:
helper.set_variable_initializer( helper.set_variable_initializer(
......
...@@ -64,5 +64,45 @@ class TestAucOp(OpTest): ...@@ -64,5 +64,45 @@ class TestAucOp(OpTest):
self.check_output() self.check_output()
class TestGlobalAucOp(OpTest):
def setUp(self):
self.op_type = "auc"
pred = np.random.random((128, 2)).astype("float32")
labels = np.random.randint(0, 2, (128, 1)).astype("int64")
num_thresholds = 200
slide_steps = 0
stat_pos = np.zeros((1, (num_thresholds + 1))).astype("int64")
stat_neg = np.zeros((1, (num_thresholds + 1))).astype("int64")
self.inputs = {
'Predict': pred,
'Label': labels,
"StatPos": stat_pos,
"StatNeg": stat_neg
}
self.attrs = {
'curve': 'ROC',
'num_thresholds': num_thresholds,
"slide_steps": slide_steps
}
python_auc = metrics.Auc(name="auc",
curve='ROC',
num_thresholds=num_thresholds)
python_auc.update(pred, labels)
pos = python_auc._stat_pos
neg = python_auc._stat_neg
self.outputs = {
'AUC': np.array(python_auc.eval()),
'StatPosOut': np.array(pos),
'StatNegOut': np.array(neg)
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -67,5 +67,48 @@ class TestAucSinglePredOp(OpTest): ...@@ -67,5 +67,48 @@ class TestAucSinglePredOp(OpTest):
self.check_output() self.check_output()
class TestAucGlobalSinglePredOp(OpTest):
def setUp(self):
self.op_type = "auc"
pred = np.random.random((128, 2)).astype("float32")
pred0 = pred[:, 0].reshape(128, 1)
labels = np.random.randint(0, 2, (128, 1)).astype("int64")
num_thresholds = 200
slide_steps = 0
stat_pos = np.zeros((1, (num_thresholds + 1))).astype("int64")
stat_neg = np.zeros((1, (num_thresholds + 1))).astype("int64")
self.inputs = {
'Predict': pred0,
'Label': labels,
"StatPos": stat_pos,
"StatNeg": stat_neg
}
self.attrs = {
'curve': 'ROC',
'num_thresholds': num_thresholds,
"slide_steps": slide_steps
}
python_auc = metrics.Auc(name="auc",
curve='ROC',
num_thresholds=num_thresholds)
for i in range(128):
pred[i][1] = pred[i][0]
python_auc.update(pred, labels)
pos = python_auc._stat_pos
neg = python_auc._stat_neg
self.outputs = {
'AUC': np.array(python_auc.eval()),
'StatPosOut': np.array(pos),
'StatNegOut': np.array(neg)
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册