diff --git a/core/metrics/__init__.py b/core/metrics/__init__.py index ef015087b23774b06da81053538459291847cdb2..2820518c02ebffd1c0c4e847bb30b14cf0a689f9 100755 --- a/core/metrics/__init__.py +++ b/core/metrics/__init__.py @@ -14,6 +14,7 @@ from .recall_k import RecallK from .pairwise_pn import PosNegRatio -from .binary_class import * +from .precision_recall import PrecisionRecall +from .auc import AUC -__all__ = ['RecallK', 'PosNegRatio'] + binary_class.__all__ +__all__ = ['RecallK', 'PosNegRatio', 'AUC', 'PrecisionRecall'] diff --git a/core/metrics/binary_class/auc.py b/core/metrics/auc.py similarity index 91% rename from core/metrics/binary_class/auc.py rename to core/metrics/auc.py index 129b8bc7eb0854f3a19b2ae0c2a101ccaf7d1d74..672a1ffa84291782963d32bd58875170253e41d1 100755 --- a/core/metrics/binary_class/auc.py +++ b/core/metrics/auc.py @@ -26,34 +26,31 @@ class AUC(Metric): Metric For Fluid Model """ - def __init__(self, **kwargs): + def __init__(self, + input, + label, + curve='ROC', + num_thresholds=2**12 - 1, + topk=1, + slide_steps=1): """ """ - if "input" not in kwargs or "label" not in kwargs: - raise ValueError("AUC expect input and label as inputs.") - predict = kwargs.get("input") - label = kwargs.get("label") - curve = kwargs.get("curve", 'ROC') - num_thresholds = kwargs.get("num_thresholds", 2**12 - 1) - topk = kwargs.get("topk", 1) - slide_steps = kwargs.get("slide_steps", 1) - - if not isinstance(predict, Variable): + if not isinstance(input, Variable): raise ValueError("input must be Variable, but received %s" % - type(predict)) + type(input)) if not isinstance(label, Variable): raise ValueError("label must be Variable, but received %s" % type(label)) auc_out, batch_auc_out, [ batch_stat_pos, batch_stat_neg, stat_pos, stat_neg - ] = fluid.layers.auc(predict, + ] = fluid.layers.auc(input, label, curve=curve, num_thresholds=num_thresholds, topk=topk, slide_steps=slide_steps) - prob = fluid.layers.slice(predict, axes=[1], starts=[1], ends=[2]) + prob = fluid.layers.slice(input, axes=[1], starts=[1], ends=[2]) label_cast = fluid.layers.cast(label, dtype="float32") label_cast.stop_gradient = True sqrerr, abserr, prob, q, pos, total = \ diff --git a/core/metrics/binary_class/__init__.py b/core/metrics/binary_class/__init__.py deleted file mode 100755 index 5a3354f8ed0b36354033b31dac61eb12838963c9..0000000000000000000000000000000000000000 --- a/core/metrics/binary_class/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .auc import AUC -from .precision_recall import PrecisionRecall - -__all__ = ['PrecisionRecall', 'AUC'] diff --git a/core/metrics/pairwise_pn.py b/core/metrics/pairwise_pn.py index 156a86063efbe8380fa1314fa7613aa378f35302..fb10e1fc349d1120255f421cd510c40842eca557 100755 --- a/core/metrics/pairwise_pn.py +++ b/core/metrics/pairwise_pn.py @@ -28,8 +28,11 @@ class PosNegRatio(Metric): Metric For Fluid Model """ - def __init__(self, **kwargs): + def __init__(self, pos_score, neg_score): """ """ + kwargs = locals() + del kwargs['self'] + helper = LayerHelper("PaddleRec_PosNegRatio", **kwargs) if "pos_score" not in kwargs or "neg_score" not in kwargs: raise ValueError( diff --git a/core/metrics/binary_class/precision_recall.py b/core/metrics/precision_recall.py similarity index 91% rename from core/metrics/binary_class/precision_recall.py rename to core/metrics/precision_recall.py index a40b1e191b1cee7df9b4f457a0087ff3f58cce69..f7f25ca808642c4a8543bdd464b4748c421653e8 100755 --- a/core/metrics/binary_class/precision_recall.py +++ b/core/metrics/precision_recall.py @@ -28,19 +28,17 @@ class PrecisionRecall(Metric): Metric For Fluid Model """ - def __init__(self, **kwargs): + def __init__(self, input, label, class_num): """R """ - if "input" not in kwargs or "label" not in kwargs or "class_num" not in kwargs: - raise ValueError( - "PrecisionRecall expect input, label and class_num as inputs.") - predict = kwargs.get("input") - label = kwargs.get("label") - self.num_cls = kwargs.get("class_num") - - if not isinstance(predict, Variable): + kwargs = locals() + del kwargs['self'] + + self.num_cls = class_num + + if not isinstance(input, Variable): raise ValueError("input must be Variable, but received %s" % - type(predict)) + type(input)) if not isinstance(label, Variable): raise ValueError("label must be Variable, but received %s" % type(label)) @@ -48,7 +46,7 @@ class PrecisionRecall(Metric): helper = LayerHelper("PaddleRec_PrecisionRecall", **kwargs) label = fluid.layers.cast(label, dtype="int32") label.stop_gradient = True - max_probs, indices = fluid.layers.nn.topk(predict, k=1) + max_probs, indices = fluid.layers.nn.topk(input, k=1) indices = fluid.layers.cast(indices, dtype="int32") indices.stop_gradient = True diff --git a/core/metrics/recall_k.py b/core/metrics/recall_k.py index 27ade14503fe6d558c7f2345517bed831f57dccf..f727c25e97bf1486886310c30e2304cba568c8b8 100755 --- a/core/metrics/recall_k.py +++ b/core/metrics/recall_k.py @@ -29,23 +29,21 @@ class RecallK(Metric): Metric For Fluid Model """ - def __init__(self, **kwargs): + def __init__(self, input, label, k=20): """ """ - if "input" not in kwargs or "label" not in kwargs: - raise ValueError("RecallK expect input and label as inputs.") - predict = kwargs.get('input') - label = kwargs.get('label') - self.k = kwargs.get("k", 20) + kwargs = locals() + del kwargs['self'] + self.k = k - if not isinstance(predict, Variable): + if not isinstance(input, Variable): raise ValueError("input must be Variable, but received %s" % - type(predict)) + type(input)) if not isinstance(label, Variable): raise ValueError("label must be Variable, but received %s" % type(label)) helper = LayerHelper("PaddleRec_RecallK", **kwargs) - batch_accuracy = accuracy(predict, label, self.k) + batch_accuracy = accuracy(input, label, self.k) global_ins_cnt, _ = helper.create_or_get_global_variable( name="ins_cnt", persistable=True, dtype='float32', shape=[1]) global_pos_cnt, _ = helper.create_or_get_global_variable( diff --git a/doc/metrics.md b/doc/metrics.md new file mode 100644 index 0000000000000000000000000000000000000000..32efa0224023cd020c7b4ffd809d4dd55c808e4e --- /dev/null +++ b/doc/metrics.md @@ -0,0 +1,124 @@ +# 如何给模型增加Metric + +## PaddleRec Metric使用示例 +``` +from paddlerec.core.model import ModelBase +from paddlerec.core.metrics import RecallK + +class Model(ModelBase): + def __init__(self, config): + ModelBase.__init__(self, config) + + def net(self, inputs, is_infer=False): + ... + acc = RecallK(input=logits, label=label, k=20) + self._metrics["Train_P@20"] = acc +``` +## Metric类 +### 成员变量 +> _global_metric_state_vars(dict), +字典类型,用以存储metric计算过程中需要的中间状态变量。一般情况下,这些中间状态需要是Persistable=True的变量,所以会在模型保存的时候也会被保存下来。因此infer阶段需手动将这些中间状态值清零,进而保证预测结果的正确性。 + +### 成员函数 +> clear(self, scope): +从scope中将self._global_metric_state_vars中的状态值全清零。该函数一般用在**infer**阶段开始的时候。用以保证预测指标的正确性。 + +> calc_global_metrics(self, fleet, scope=None): +将self._global_metric_state_vars中的状态值在所有训练节点上做all_reduce操作,进而下一步调用_calculate()函数计算全局指标。若fleet=None,则all_reduce的结果为自己本身,即单机全局指标计算。 + +> get_result(self): 返回训练过程中需要fetch,并定期打印至屏幕的变量。返回类型为dict。 + +## Metrics +### AUC +> AUC(input ,label, curve='ROC', num_thresholds=2**12 - 1, topk=1, slide_steps=1) + +Auc,全称Area Under the Curve(AUC),该层根据前向输出和标签计算AUC,在二分类(binary classification)估计中广泛使用。在二分类(binary classification)中广泛使用。相关定义参考 https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve 。 + +#### 参数 +- **input(Tensor|LoDTensor)**: 数据类型为float32,float64。浮点二维变量。输入为网络的预测值。shape为[batch_size, 2]。 +- **label(Tensor|LoDTensor)**: 数据类型为int64,int32。输入为数据集的标签。shape为[batch_size, 1]。 +- **curve(str)**: 曲线类型,可以为 ROC 或 PR,默认 ROC。 +- **num_thresholds(int)**: 将roc曲线离散化时使用的临界值数。默认200。 +- **topk(int)**: 取topk的输出值用于计算。 +- **slide_steps(int)**: - 当计算batch auc时,不仅用当前步也用于先前步。slide_steps=1,表示用当前步;slide_steps = 3表示用当前步和前两步;slide_steps = 0,则用所有步。 + +#### 返回值 +该指标训练过程中定期的变量有两个: +- **AUC**: 整体AUC值 +- **BATCH_AUC**:当前batch的AUC值 + + +### PrecisionRecall +> PrecisionRecall(input, label, class_num) + +计算precison, recall, f1。 + +#### 参数 +- **input(Tensor|LoDTensor)**: 数据类型为float32,float64。输入为网络的预测值。shape为[batch_size, class_num] +- **label(Tensor|LoDTensor)**: 数据类型为int32。输入为数据集的标签。shape为 [batch_size, 1] +- **class_num(int)**: 类别个数。 + +#### 返回值 +- **[TP FP TN FN]**: 形状为[class_num, 4]的变量,用以表征每种类型的TP,FP,TN和FN值。TP=true positive, FP=false positive, TN=true negative, FN=false negative。若需计算每种类型的precison, recall,f1, 则可根据如下公式进行计算: +precision = TP / (TP + FP); recall = TP = TP / (TP + FN); F1 = 2 * precision * recall / (precision + recall)。 + +- **precision_recall_f1**: 形状为[6],分别代表[macro_avg_precision, macro_avg_recall, macro_avg_f1, micro_avg_precision, micro_avg_recall, micro_avg_f1],这里macro代表先计算每种类型的准确率,召回率,F1,然后求平均。micro代表先计算所有类型的整体TP,TN, FP, FN等中间值,然后在计算准确率,召回率,F1. + + +### RecallK +> RecallK(input, label, k=20) + +TopK的召回准确率,对于任意一条样本来说,若前top_k个分类结果中包含正确分类标签,则视为正样本。 + +#### 参数 +- **input(Tensor|LoDTensor)**: 数据类型为float32,float64。输入为网络的预测值。shape为[batch_size, class_dim] +- **label(Tensor|LoDTensor)**: 数据类型为int64,int32。输入为数据集的标签。shape为 [batch_size, 1] +- **k(int)**: 取每个类别中top_k个预测值用于计算召回准确率。 + +#### 返回值 +- **InsCnt**:样本总数 +- **RecallCnt**: topk可以正确被召回的样本数 +- **Acc(Recall@k)**: RecallCnt/InsCnt,即Topk召回准确率。 + +## PairWise_PN +> PosNegRatio(pos_score, neg_score) + +正逆序指标,一般用在输入是pairwise的模型中。例如输入既包含正样本,也包含负样本,模型需要去学习最大化正负样本打分的差异。 + +#### 参数 +- **pos_score(Tensor|LoDTensor)**: 正样本的打分,数据类型为float32,float64。浮点二维变量,值的范围为[0,1]。 +- **neg_score(Tensor|LoDTensor)**:负样本的打分。数据类型为float32,float64。浮点二维变量,值的范围为[0,1]。 + +#### 返回值 +- **RightCnt**: pos_score > neg_score的样本数 +- **WrongCnt**: pos_score <= neg_score的样本数 +- **PN**: (RightCnt + 1.0) / (WrongCnt + 1.0), 正逆序,+1.0是为了避免除0错误。 + +### Customized_Metric +如果你需要在自定义metric,那么你需要按如下步骤操作: +1. 继承paddlerec.core.Metric,定义你的MyMetric类。 +2. 在MyMetric的构造函数中,自定义Metric组网,声明self._global_metric_state_vars私有变量。 +3. 定义_calculate(global_metrics),全局指标计算。该函数的输入globla_metrics,存储了self._global_metric_state_vars中所有中间状态变量的全局统计值。最终结果以str格式返回。 + +自定义Metric模版如下,你可以参考注释,或paddlerec.core.metrics下已经实现的precision_recall, auc, pairwise_pn, recall_k等指标的计算方式,自定义自己的Metric类。 +``` +from paddlerec.core.Metric import Metric + +class MyMetric(Metric): + def __init__(self): + # 1. 自定义Metric组网 + ** 1. your code ** + + # 2. 设置中间状态字典 + self._global_metric_state_vars = dict() + ** 2. your code ** + + def get_result(self): + # 3. 定义训练过程中需要打印的变量,以字典格式返回 + self. _metrics = dict() + ** 3. your code ** + + def _calculate(self, global_metrics): + # 4. 全局指标计算,global_metrics为字典类型,存储了self._global_metric_state_vars中所有中间状态变量的全局统计值。返回格式为str。 + ** your code ** +``` diff --git a/doc/model_develop.md b/doc/model_develop.md index da9523fac2e20258cd488f61ca07900772f5ce78..f21a92a78cba4e8e81c8b1932db7cb07e6392d9c 100644 --- a/doc/model_develop.md +++ b/doc/model_develop.md @@ -113,6 +113,8 @@ def input_data(self, is_infer=False, **kwargs): 可以参考官方模型的示例学习net的构造方法。 +除可以使用Paddle的Metrics接口外,PaddleRec也统一封装了一些常见的Metrics评价指标,并允许开发者定义自己的Metrics类,相关文件参考[Metrics开发文档](metrics.md)。 + ## 如何运行自定义模型 记录`model.py`,`config.yaml`及数据读取`reader.py`的文件路径,建议置于同一文件夹下,如`/home/custom_model`下,更改`config.yaml`中的配置选项