From 87233cbc4f0d4c114cc3dd03df82741836a5efaf Mon Sep 17 00:00:00 2001 From: pkpk Date: Tue, 24 Sep 2019 18:38:06 +0800 Subject: [PATCH] fix metricbase (#1316) --- doc/fluid/api_cn/metrics_cn/MetricBase_cn.rst | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/doc/fluid/api_cn/metrics_cn/MetricBase_cn.rst b/doc/fluid/api_cn/metrics_cn/MetricBase_cn.rst index 717405e42..d24aa0470 100644 --- a/doc/fluid/api_cn/metrics_cn/MetricBase_cn.rst +++ b/doc/fluid/api_cn/metrics_cn/MetricBase_cn.rst @@ -5,43 +5,60 @@ MetricBase .. py:class:: paddle.fluid.metrics.MetricBase(name) -所有Metrics的基类。MetricBase为模型估计方法定义一组接口。Metrics累积连续的两个minibatch之间的度量状态,对每个minibatch用最新接口将当前minibatch值添加到全局状态。用eval函数来计算last reset()或者scratch on()中累积的度量值。如果需要定制一个新的metric,请继承自MetricBase和自定义实现类。 +在评估神经网络效果的时候,由于我们常常需要把测试数据切分成mini-batch,并逐次将每个mini-batch送入神经网络进行预测和评估,因此我们每次只能获得当前batch下的评估结果,而并不能一次性获得整个测试集的评估结果。paddle.fluid.metrics正是为了解决这些问题而设计的,大部分paddle.fluid.metrics下的类都具有如下功能: -参数: - - **name** (str) - metric实例名。例如准确率(accuracy)。如果想区分一个模型里不同的metrics,则需要实例名。 +1. 接受模型对一个batch的预测结果(numpy.array)和这个batch的原始标签(numpy.array)作为输入,并进行特定的计算(如计算准确率,召回率等)。 -.. py:method:: reset() +2. 将当前batch评估结果和历史评估结果累计起来,以获取目前处理过的所有batch的整体评估结果。 - reset()清除度量(metric)的状态(state)。默认情况下,状态(state)包含没有 ``_`` 前缀的metric。reset将这些状态设置为初始状态。如果不想使用隐式命名规则,请自定义reset接口。 +MetricBase是所有paddle.fluid.metrics下定义的所有python类的基类,它定义了一组接口,并需要所有继承他的类实现具体的计算逻辑,包括: -.. py:method:: get_config() +1. update(preds, labels):给定当前计算当前batch的预测结果(preds)和标签(labels),计算这个batch的评估结果。 + +2. eval():合并当前累积的每个batch的评估结果,并返回整体评估结果。 + +3. reset():清空累积的每个batch的评估结果。 + +.. py:method:: __init__(name) + +构造函数,参数name表示当前创建的评估器的名字。 + +参数: + - **name** (str) - 当前创建的评估器的名字,用于区分不同的评估器,例如准确率(accuracy)或者其他自定义名字(如,my_evaluator)。 + +返回:一个python对象,表示一个具体的评估器。 -获取度量(metric)状态和当前状态。状态(state)包含没有 ``_`` 前缀的成员。 - -返回:metric对应到state的字典 +返回类型:python对象 + +.. py:method:: reset() -返回类型:字典(dict) +空累积的每个batch的评估结果。 +返回:无 .. py:method:: update(preds,labels) -更新每个minibatch的度量状态(metric states),用户可通过Python或者C++操作符计算minibatch度量值(metric)。 +给定当前计算当前batch的预测结果(preds)和标签(labels),计算这个batch的评估结果,并将这个评估结果在评估器内部记录下来,注意update函数并不会返回评估结果。 参数: - - **preds** (numpy.array) - 当前minibatch的预测 - - **labels** (numpy.array) - 当前minibatch的标签,如果标签为one-hot或者soft-label,应该自定义相应的更新规则。 + - **preds** (numpy.array) - 当前minibatch的预测结果。 + - **labels** (numpy.array) - 当前minibatch的标签。 + +返回:无 .. py:method:: eval() -基于累积状态(accumulated states)评估当前度量(current metric)。 +合并当前累积的每个batch的评估结果,并返回整体评估结果。 -返回:metrics(Python中) +返回:当前累积batch的整体评估结果。 返回类型:float|list(float)|numpy.array +.. py:method:: get_config() +获取当前评估器的状态,特指评估器内部没有 ``_`` 前缀的所有成员变量。 +返回:一个python字典,包含了当前评估器内部的状态。 - - +返回类型:python字典(dict) -- GitLab