提交 1e09619d 编写于 作者: D dzhwinter

"highlight code"

上级 52a615fd
......@@ -21,20 +21,20 @@ Fluid中包含了常用分类指标,例如Precision, Recall, Accuracy等,更
.. code-block:: python
labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
comp = fluid.metrics.CompositeMetric()
acc = fluid.metrics.Precision()
recall = fluid.metrics.Recall()
comp.add_metric(acc)
comp.add_metric(recall)
for pass in range(PASSES):
comp.reset()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
comp.update(preds=preds, labels=labels)
numpy_acc, numpy_recall = comp.eval()
>>> labels = fluid.layers.data(name="data", shape=[1], dtype="int32")
>>> data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
>>> pred = fluid.layers.fc(input=data, size=1000, act="tanh")
>>> comp = fluid.metrics.CompositeMetric()
>>> acc = fluid.metrics.Precision()
>>> recall = fluid.metrics.Recall()
>>> comp.add_metric(acc)
>>> comp.add_metric(recall)
>>> for pass in range(PASSES):
>>> comp.reset()
>>> for data in train_reader():
>>> loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
>>> comp.update(preds=preds, labels=labels)
>>> numpy_acc, numpy_recall = comp.eval()
其他任务例如MultiTask Learning,Metric Learning,Learning To Rank各种指标构造方法请参考API文档。
......@@ -46,20 +46,20 @@ Fluid支持自定义指标,灵活支持各类计算任务。下文通过一个
.. code-block:: python
class MyMetric(MetricBase):
def __init__(self, name=None):
super(MyMetric, self).__init__(name)
self.counter = 0 # simple counter
>>> class MyMetric(MetricBase):
>>> def __init__(self, name=None):
>>> super(MyMetric, self).__init__(name)
>>> self.counter = 0 # simple counter
def reset(self):
self.counter = 0
>>> def reset(self):
>>> self.counter = 0
def update(self, preds, labels):
if not _is_numpy_(preds):
raise ValueError("The 'preds' must be a numpy ndarray.")
if not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray.")
self.counter += sum(preds == labels)
>>> def update(self, preds, labels):
>>> if not _is_numpy_(preds):
>>> raise ValueError("The 'preds' must be a numpy ndarray.")
>>> if not _is_numpy_(labels):
>>> raise ValueError("The 'labels' must be a numpy ndarray.")
>>> self.counter += sum(preds == labels)
def eval(self):
return self.counter
>>> def eval(self):
>>> return self.counter
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册