提交 b02bc6ca 编写于 作者: M malin10

update metrics

上级 bee67d95
...@@ -21,6 +21,7 @@ from paddlerec.core.metric import Metric ...@@ -21,6 +21,7 @@ from paddlerec.core.metric import Metric
from paddle.fluid.layers import nn, accuracy from paddle.fluid.layers import nn, accuracy
from paddle.fluid.initializer import Constant from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.tensor import Variable
class AUC(Metric): class AUC(Metric):
...@@ -30,12 +31,22 @@ class AUC(Metric): ...@@ -30,12 +31,22 @@ class AUC(Metric):
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """ """ """
if "input" not in kwargs or "label" not in kwargs:
raise ValueError("AUC expect input and label as inputs.")
predict = kwargs.get("input") predict = kwargs.get("input")
label = kwargs.get("label") label = kwargs.get("label")
curve = kwargs.get("curve", 'ROC') curve = kwargs.get("curve", 'ROC')
num_thresholds = kwargs.get("num_thresholds", 2**12 - 1) num_thresholds = kwargs.get("num_thresholds", 2**12 - 1)
topk = kwargs.get("topk", 1) topk = kwargs.get("topk", 1)
slide_steps = kwargs.get("slide_steps", 1) slide_steps = kwargs.get("slide_steps", 1)
if not isinstance(predict, Variable):
raise ValueError("input must be Variable, but received %s" %
type(predict))
if not isinstance(label, Variable):
raise ValueError("label must be Variable, but received %s" %
type(label))
auc_out, batch_auc_out, [ auc_out, batch_auc_out, [
batch_stat_pos, batch_stat_neg, stat_pos, stat_neg batch_stat_pos, batch_stat_neg, stat_pos, stat_neg
] = fluid.layers.auc(predict, ] = fluid.layers.auc(predict,
......
...@@ -21,6 +21,7 @@ from paddlerec.core.metric import Metric ...@@ -21,6 +21,7 @@ from paddlerec.core.metric import Metric
from paddle.fluid.layers import nn, accuracy from paddle.fluid.layers import nn, accuracy
from paddle.fluid.initializer import Constant from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.tensor import Variable
class PrecisionRecall(Metric): class PrecisionRecall(Metric):
...@@ -30,12 +31,23 @@ class PrecisionRecall(Metric): ...@@ -30,12 +31,23 @@ class PrecisionRecall(Metric):
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """ """ """
helper = LayerHelper("PaddleRec_PrecisionRecall", **kwargs) if "input" not in kwargs or "label" not in kwargs or "class_num" not in kwargs:
raise ValueError(
"PrecisionRecall expect input, label and class_num as inputs.")
predict = kwargs.get("input") predict = kwargs.get("input")
origin_label = kwargs.get("label") label = kwargs.get("label")
label = fluid.layers.cast(origin_label, dtype="int32")
label.stop_gradient = True
num_cls = kwargs.get("class_num") num_cls = kwargs.get("class_num")
if not isinstance(predict, Variable):
raise ValueError("input must be Variable, but received %s" %
type(predict))
if not isinstance(label, Variable):
raise ValueError("label must be Variable, but received %s" %
type(label))
helper = LayerHelper("PaddleRec_PrecisionRecall", **kwargs)
label = fluid.layers.cast(label, dtype="int32")
label.stop_gradient = True
max_probs, indices = fluid.layers.nn.topk(predict, k=1) max_probs, indices = fluid.layers.nn.topk(predict, k=1)
indices = fluid.layers.cast(indices, dtype="int32") indices = fluid.layers.cast(indices, dtype="int32")
indices.stop_gradient = True indices.stop_gradient = True
......
...@@ -21,6 +21,7 @@ from paddlerec.core.metric import Metric ...@@ -21,6 +21,7 @@ from paddlerec.core.metric import Metric
from paddle.fluid.layers import nn, accuracy from paddle.fluid.layers import nn, accuracy
from paddle.fluid.initializer import Constant from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.tensor import Variable
class PosNegRatio(Metric): class PosNegRatio(Metric):
...@@ -31,9 +32,19 @@ class PosNegRatio(Metric): ...@@ -31,9 +32,19 @@ class PosNegRatio(Metric):
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """ """ """
helper = LayerHelper("PaddleRec_PosNegRatio", **kwargs) helper = LayerHelper("PaddleRec_PosNegRatio", **kwargs)
if "pos_score" not in kwargs or "neg_score" not in kwargs:
raise ValueError(
"PosNegRatio expect pos_score and neg_score as inputs.")
pos_score = kwargs.get('pos_score') pos_score = kwargs.get('pos_score')
neg_score = kwargs.get('neg_score') neg_score = kwargs.get('neg_score')
if not isinstance(pos_score, Variable):
raise ValueError("pos_score must be Variable, but received %s" %
type(pos_score))
if not isinstance(neg_score, Variable):
raise ValueError("neg_score must be Variable, but received %s" %
type(neg_score))
wrong = fluid.layers.cast( wrong = fluid.layers.cast(
fluid.layers.less_equal(pos_score, neg_score), dtype='float32') fluid.layers.less_equal(pos_score, neg_score), dtype='float32')
wrong_cnt = fluid.layers.reduce_sum(wrong) wrong_cnt = fluid.layers.reduce_sum(wrong)
......
...@@ -21,6 +21,7 @@ from paddlerec.core.metric import Metric ...@@ -21,6 +21,7 @@ from paddlerec.core.metric import Metric
from paddle.fluid.layers import nn, accuracy from paddle.fluid.layers import nn, accuracy
from paddle.fluid.initializer import Constant from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.tensor import Variable
class RecallK(Metric): class RecallK(Metric):
...@@ -30,71 +31,58 @@ class RecallK(Metric): ...@@ -30,71 +31,58 @@ class RecallK(Metric):
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """ """ """
if "input" not in kwargs or "label" not in kwargs:
raise ValueError("RecallK expect input and label as inputs.")
predict = kwargs.get('input')
label = kwargs.get('label')
k = kwargs.get("k", 20)
if not isinstance(predict, Variable):
raise ValueError("input must be Variable, but received %s" %
type(predict))
if not isinstance(label, Variable):
raise ValueError("label must be Variable, but received %s" %
type(label))
helper = LayerHelper("PaddleRec_RecallK", **kwargs) helper = LayerHelper("PaddleRec_RecallK", **kwargs)
predict = kwargs.get("input") batch_accuracy = accuracy(predict, label, k)
origin_label = kwargs.get("label") global_ins_cnt, _ = helper.create_or_get_global_variable(
label = fluid.layers.cast(origin_label, dtype="int32") name="ins_cnt", persistable=True, dtype='float32', shape=[1])
label.stop_gradient = True global_pos_cnt, _ = helper.create_or_get_global_variable(
num_cls = kwargs.get("class_num") name="pos_cnt", persistable=True, dtype='float32', shape=[1])
max_probs, indices = fluid.layers.nn.topk(predict, k=1)
indices = fluid.layers.cast(indices, dtype="int32") for var in [global_ins_cnt, global_pos_cnt]:
indices.stop_gradient = True helper.set_variable_initializer(
var, Constant(
states_info, _ = helper.create_or_get_global_variable( value=0.0, force_cpu=True))
name="states_info",
persistable=True, tmp_ones = fluid.layers.fill_constant(
dtype='float32', shape=fluid.layers.shape(label), dtype="float32", value=1.0)
shape=[num_cls, 4]) batch_ins = fluid.layers.reduce_sum(tmp_ones)
states_info.stop_gradient = True batch_pos = batch_ins * batch_accuracy
helper.set_variable_initializer(
states_info, Constant(
value=0.0, force_cpu=True))
batch_metrics, _ = helper.create_or_get_global_variable(
name="batch_metrics",
persistable=False,
dtype='float32',
shape=[6])
accum_metrics, _ = helper.create_or_get_global_variable(
name="global_metrics",
persistable=False,
dtype='float32',
shape=[6])
batch_states = fluid.layers.fill_constant(
shape=[num_cls, 4], value=0.0, dtype="float32")
batch_states.stop_gradient = True
helper.append_op( helper.append_op(
type="precision_recall", type="elementwise_add",
attrs={'class_number': num_cls}, inputs={"X": [global_ins_cnt],
inputs={ "Y": [batch_ins]},
'MaxProbs': [max_probs], outputs={"Out": [global_ins_cnt]})
'Indices': [indices],
'Labels': [label],
'StatesInfo': [states_info]
},
outputs={
'BatchMetrics': [batch_metrics],
'AccumMetrics': [accum_metrics],
'AccumStatesInfo': [batch_states]
})
helper.append_op( helper.append_op(
type="assign", type="elementwise_add",
inputs={'X': [batch_states]}, inputs={"X": [global_pos_cnt],
outputs={'Out': [states_info]}) "Y": [batch_pos]},
outputs={"Out": [global_pos_cnt]})
batch_states.stop_gradient = True self.acc = global_pos_cnt / global_ins_cnt
states_info.stop_gradient = True
self._need_clear_list = [("states_info", "float32")] self._need_clear_list = [("ins_cnt", "float32"),
("pos_cnt", "float32")]
metric_name = "Recall@%d_ACC" % k
self.metrics = dict() self.metrics = dict()
self.metrics["precision_recall_f1"] = accum_metrics self.metrics["ins_cnt"] = global_ins_cnt
self.metrics["accum_states"] = states_info self.metrics["pos_cnt"] = global_pos_cnt
self.metrics[metric_name] = self.acc
# self.metrics["batch_metrics"] = batch_metrics
def get_result(self): def get_result(self):
return self.metrics return self.metrics
...@@ -25,10 +25,14 @@ class TestAUC(unittest.TestCase): ...@@ -25,10 +25,14 @@ class TestAUC(unittest.TestCase):
def setUp(self): def setUp(self):
self.ins_num = 64 self.ins_num = 64
self.batch_nums = 3 self.batch_nums = 3
self.probs = np.random.uniform(0, 1.0,
(self.ins_num, 2)).astype('float32') self.datas = []
self.labels = np.random.choice(range(2), self.ins_num).reshape( for i in range(self.batch_nums):
(self.ins_num, 1)).astype('int64') probs = np.random.uniform(0, 1.0,
(self.ins_num, 2)).astype('float32')
labels = np.random.choice(range(2), self.ins_num).reshape(
(self.ins_num, 1)).astype('int64')
self.datas.append((probs, labels))
self.place = fluid.core.CPUPlace() self.place = fluid.core.CPUPlace()
...@@ -37,7 +41,7 @@ class TestAUC(unittest.TestCase): ...@@ -37,7 +41,7 @@ class TestAUC(unittest.TestCase):
curve='ROC', curve='ROC',
num_thresholds=self.num_thresholds) num_thresholds=self.num_thresholds)
for i in range(self.batch_nums): for i in range(self.batch_nums):
python_auc.update(self.probs, self.labels) python_auc.update(self.datas[i][0], self.datas[i][1])
self.auc = np.array(python_auc.eval()) self.auc = np.array(python_auc.eval())
...@@ -65,15 +69,21 @@ class TestAUC(unittest.TestCase): ...@@ -65,15 +69,21 @@ class TestAUC(unittest.TestCase):
exe = fluid.Executor(self.place) exe = fluid.Executor(self.place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for i in range(self.batch_nums): for i in range(self.batch_nums):
outs = exe.run(fluid.default_main_program(), outs = exe.run(
feed={'predict': self.probs, fluid.default_main_program(),
'label': self.labels}, feed={'predict': self.datas[i][0],
fetch_list=fetch_vars, 'label': self.datas[i][1]},
return_numpy=True) fetch_list=fetch_vars,
return_numpy=True)
outs = dict(zip(metric_keys, outs)) outs = dict(zip(metric_keys, outs))
self.assertTrue(np.allclose(outs['AUC'], self.auc)) self.assertTrue(np.allclose(outs['AUC'], self.auc))
def test_exception(self):
self.assertRaises(Exception, AUC)
self.assertRaises(
Exception, AUC, input=self.datas[0][0], label=self.datas[0][1]),
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -21,12 +21,12 @@ import paddle ...@@ -21,12 +21,12 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
class TestAUC(unittest.TestCase): class TestPosNegRatio(unittest.TestCase):
def setUp(self): def setUp(self):
self.ins_num = 64 self.ins_num = 64
self.batch_nums = 3 self.batch_nums = 3
self.probs = [] self.datas = []
self.right_cnt = 0.0 self.right_cnt = 0.0
self.wrong_cnt = 0.0 self.wrong_cnt = 0.0
for i in range(self.batch_nums): for i in range(self.batch_nums):
...@@ -40,7 +40,7 @@ class TestAUC(unittest.TestCase): ...@@ -40,7 +40,7 @@ class TestAUC(unittest.TestCase):
'int32') 'int32')
self.right_cnt += float(right_cnt) self.right_cnt += float(right_cnt)
self.wrong_cnt += float(wrong_cnt) self.wrong_cnt += float(wrong_cnt)
self.probs.append((pos_score, neg_score)) self.datas.append((pos_score, neg_score))
self.place = fluid.core.CPUPlace() self.place = fluid.core.CPUPlace()
...@@ -68,8 +68,8 @@ class TestAUC(unittest.TestCase): ...@@ -68,8 +68,8 @@ class TestAUC(unittest.TestCase):
for i in range(self.batch_nums): for i in range(self.batch_nums):
outs = exe.run(fluid.default_main_program(), outs = exe.run(fluid.default_main_program(),
feed={ feed={
'pos_score': self.probs[i][0], 'pos_score': self.datas[i][0],
'neg_score': self.probs[i][1] 'neg_score': self.datas[i][1]
}, },
fetch_list=fetch_vars, fetch_list=fetch_vars,
return_numpy=True) return_numpy=True)
...@@ -82,6 +82,14 @@ class TestAUC(unittest.TestCase): ...@@ -82,6 +82,14 @@ class TestAUC(unittest.TestCase):
np.array((self.right_cnt + 1.0) / (self.wrong_cnt + 1.0 np.array((self.right_cnt + 1.0) / (self.wrong_cnt + 1.0
)))) ))))
def test_exception(self):
self.assertRaises(Exception, PosNegRatio)
self.assertRaises(
Exception,
PosNegRatio,
pos_score=self.datas[0][0],
neg_score=self.datas[0][1]),
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -148,6 +148,15 @@ class TestPrecisionRecall(unittest.TestCase): ...@@ -148,6 +148,15 @@ class TestPrecisionRecall(unittest.TestCase):
self.assertTrue(np.allclose(outs['accum_states'], self.states)) self.assertTrue(np.allclose(outs['accum_states'], self.states))
self.assertTrue(np.allclose(outs['precision_recall_f1'], self.metrics)) self.assertTrue(np.allclose(outs['precision_recall_f1'], self.metrics))
def test_exception(self):
self.assertRaises(Exception, PrecisionRecall)
self.assertRaises(
Exception,
PrecisionRecall,
input=self.datas[0][0],
label=self.datas[0][1],
class_num=self.cls_num)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -85,6 +85,12 @@ class TestRecallK(unittest.TestCase): ...@@ -85,6 +85,12 @@ class TestRecallK(unittest.TestCase):
np.array(self.match_num / (self.ins_num * np.array(self.match_num / (self.ins_num *
self.batch_nums)))) self.batch_nums))))
def test_exception(self):
self.assertRaises(Exception, RecallK)
self.assertRaises(
Exception, RecallK, input=self.datas[0][0],
label=self.datas[0][1]),
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册